use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use crate::coerce::coerce_value;
use crate::schema::response_schema_for;
/// A piece of text that may be JSON, plus a record of how it was found.
#[derive(Debug, Clone)]
pub struct Candidate {
/// The candidate JSON text, as it will be handed to the deserializer.
pub json: String,
/// Where in the response (or by which repair) this candidate came from.
pub source: CandidateSource,
}
/// How a [`Candidate`] was obtained. `collect_candidates` produces them in
/// declaration order: direct, markdown, grepped, then fixed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandidateSource {
/// The entire response (after optional unescaping) looked like JSON.
Direct,
/// The body of a fenced markdown code block that looked like JSON.
MarkdownBlock,
/// A balanced `{...}` or `[...]` span found inside surrounding text.
Grepped,
/// Produced by a repair: syntax fixes (quotes, comments, brackets, trailing
/// commas) or truncation recovery.
Fixed,
}
/// A successfully deserialized value together with where it came from.
#[derive(Debug)]
pub struct ParseResult<T> {
/// The deserialized value.
pub value: T,
/// Which kind of candidate produced `value`.
pub source: CandidateSource,
/// How many candidates were attempted, including the successful one, in the
/// pass that produced `value` (for `parse_flexible_coerced`'s coercion pass
/// this counts only that pass's own attempts).
pub candidates_tried: usize,
}
/// Returned when no candidate could be deserialized into the target type.
#[derive(Debug)]
pub struct ParseError {
/// Every candidate that was attempted, paired with the error message it produced.
pub candidates: Vec<(Candidate, String)>,
/// The input text exactly as passed to the parser (before any unescaping).
pub raw: String,
}
impl std::fmt::Display for ParseError {
    /// Writes a one-line summary, then one line per failed candidate showing
    /// its index, its source, and its error message capped at 100 bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tried = self.candidates.len();
        write!(f, "Failed to parse into target type. {} candidates tried", tried)?;
        self.candidates
            .iter()
            .enumerate()
            .try_for_each(|(i, (candidate, err))| {
                write!(f, "\n [{i}] {:?}: {}", candidate.source, truncate(err, 100))
            })
    }
}

impl std::error::Error for ParseError {}
pub fn parse_flexible<T: DeserializeOwned>(raw: &str) -> Result<ParseResult<T>, ParseError> {
let candidates = collect_candidates(raw);
let mut errors = Vec::new();
for candidate in &candidates {
match serde_json::from_str::<T>(&candidate.json) {
Ok(value) => {
return Ok(ParseResult {
value,
source: candidate.source,
candidates_tried: errors.len() + 1,
});
}
Err(e) => {
errors.push((candidate.clone(), e.to_string()));
}
}
}
Err(ParseError {
candidates: errors,
raw: raw.to_string(),
})
}
pub fn parse_flexible_coerced<T: JsonSchema + DeserializeOwned>(
raw: &str,
) -> Result<ParseResult<T>, ParseError> {
if let Ok(result) = parse_flexible::<T>(raw) {
return Ok(result);
}
let candidates = collect_candidates(raw);
let schema = response_schema_for::<T>();
let mut errors = Vec::new();
for candidate in &candidates {
if let Ok(mut value) = serde_json::from_str::<serde_json::Value>(&candidate.json) {
coerce_value(&mut value, &schema);
match serde_json::from_value::<T>(value) {
Ok(parsed) => {
return Ok(ParseResult {
value: parsed,
source: candidate.source,
candidates_tried: errors.len() + 1,
});
}
Err(e) => {
errors.push((candidate.clone(), format!("coerced: {}", e)));
}
}
} else {
errors.push((candidate.clone(), "invalid JSON even for Value".into()));
}
}
Err(ParseError {
candidates: errors,
raw: raw.to_string(),
})
}
pub fn collect_candidates(raw: &str) -> Vec<Candidate> {
let mut candidates = Vec::new();
let effective = try_unescape_json_string(raw).unwrap_or_else(|| raw.to_string());
let raw = effective.as_str();
if looks_like_json(raw) {
candidates.push(Candidate {
json: raw.to_string(),
source: CandidateSource::Direct,
});
}
for block in extract_markdown_blocks(raw) {
candidates.push(Candidate {
json: block,
source: CandidateSource::MarkdownBlock,
});
}
for json in extract_json_objects(raw) {
if !candidates.iter().any(|c| c.json == json) {
candidates.push(Candidate {
json,
source: CandidateSource::Grepped,
});
}
}
let fixable: Vec<String> = candidates.iter().map(|c| c.json.clone()).collect();
for json in &fixable {
if let Some(fixed) = try_fix_json(json)
&& !candidates.iter().any(|c| c.json == fixed)
{
candidates.push(Candidate {
json: fixed,
source: CandidateSource::Fixed,
});
}
}
if (candidates.is_empty()
|| !candidates
.iter()
.any(|c| c.source == CandidateSource::Direct))
&& let Some(fixed) = try_fix_json(raw)
&& !candidates.iter().any(|c| c.json == fixed)
{
candidates.push(Candidate {
json: fixed,
source: CandidateSource::Fixed,
});
}
for json_source in [raw]
.iter()
.chain(fixable.iter().map(|s| s as &str).collect::<Vec<_>>().iter())
{
for recovered in truncation_recovery_candidates(json_source) {
if !candidates.iter().any(|c| c.json == recovered) {
candidates.push(Candidate {
json: recovered,
source: CandidateSource::Fixed,
});
}
}
}
candidates
}
fn extract_markdown_blocks(text: &str) -> Vec<String> {
let mut blocks = Vec::new();
let mut rest = text;
while let Some(start) = rest.find("```") {
let after_ticks = &rest[start + 3..];
let content_start = if let Some(newline) = after_ticks.find('\n') {
newline + 1
} else {
break;
};
let content = &after_ticks[content_start..];
if let Some(end) = content.find("```") {
let block = content[..end].trim();
if !block.is_empty() && looks_like_json(block) {
blocks.push(block.to_string());
}
rest = &content[end + 3..];
} else {
let block = content.trim();
if !block.is_empty() && looks_like_json(block) {
blocks.push(block.to_string());
}
break;
}
}
blocks
}
/// Find balanced JSON-looking spans embedded in free text.
///
/// First collects `{...}` spans scanning left to right. After a match,
/// scanning resumes past its closing brace, so objects nested inside it are
/// not reported separately. When a `{` has no match (e.g. truncated text),
/// scanning resumes just after it, so complete objects nested inside are
/// still found. A second, independent pass does the same for `[...]`, so
/// arrays inside objects are also reported. Duplicates are dropped.
fn extract_json_objects(text: &str) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    for (open, close) in [('{', '}'), ('[', ']')] {
        let mut cursor = 0;
        while let Some(rel) = text[cursor..].find(open) {
            let begin = cursor + rel;
            cursor = match find_matching_bracket(text, begin, open, close) {
                Some(end) => {
                    let span = &text[begin..=end];
                    if !found.iter().any(|f| f.as_str() == span) {
                        found.push(span.to_owned());
                    }
                    end + 1
                }
                None => begin + 1,
            };
        }
    }
    found
}
/// Return the byte index of the `close` that balances the `open` at byte
/// index `start`, or `None` if the text ends first (or `start` is past the end).
///
/// Delimiters inside double-quoted strings (with backslash escapes) are
/// ignored. Scanning is byte-wise. Every byte of a multi-byte UTF-8 sequence
/// is `>= 0x80`, so none compares equal to an ASCII delimiter.
fn find_matching_bracket(text: &str, start: usize, open: char, close: char) -> Option<usize> {
    let tail = text.as_bytes().get(start..)?;
    let mut depth = 0i32;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, &byte) in tail.iter().enumerate() {
        let ch = byte as char;
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            depth -= 1;
            if depth == 0 {
                return Some(start + offset);
            }
        }
    }
    None
}
/// Try to repair common syntax slips in LLM-produced JSON.
///
/// Applies, in order:
/// 1. converting single quotes to double quotes;
/// 2. removing `//` and `/* */` comments;
/// 3. closing an unterminated string and any open brackets;
/// 4. removing trailing commas.
///
/// Returns `Some(repaired)` only if the trimmed input was not already valid
/// JSON and the repaired text parses; otherwise returns `None`.
fn try_fix_json(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
        // Already valid: nothing to repair, and the caller already has this text.
        return None;
    }
    // Pass order matters:
    // - Quotes first. The later passes only recognise double-quoted strings,
    //   so `{'a': '{'}` must become `{"a": "{"}` before brackets are counted,
    //   and `'http://x'` must be a real string before `//` looks like a comment.
    // - Comments before brackets and commas. A `{` inside a comment must not
    //   be counted, and in `[1, // note\n]` the comma is only trailing once
    //   the comment is gone.
    // - Close brackets before stripping trailing commas, so output cut off
    //   right after a comma (`[1, 2,`) becomes `[1, 2,]` and then `[1, 2]`.
    let fixed = fix_single_quotes(trimmed);
    let fixed = strip_comments(&fixed);
    let fixed = close_brackets(&fixed);
    let fixed = strip_trailing_commas(&fixed);
    // `trimmed` failed to parse above, so if no pass changed anything this
    // check fails too; no separate "changed" flag is needed.
    if serde_json::from_str::<serde_json::Value>(&fixed).is_ok() {
        Some(fixed)
    } else {
        None
    }
}
/// Drop every comma whose next non-whitespace character is `}` or `]`.
///
/// Text inside double-quoted strings (including backslash-escaped
/// characters) is copied verbatim.
fn strip_trailing_commas(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len());
    let mut in_string = false;
    let mut escaped = false;
    for (idx, &c) in chars.iter().enumerate() {
        if in_string {
            out.push(c);
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        if c == '"' {
            in_string = true;
        } else if c == ',' {
            let next = chars[idx + 1..].iter().copied().find(|ch| !ch.is_whitespace());
            if matches!(next, Some('}' | ']')) {
                continue;
            }
        }
        out.push(c);
    }
    out
}
/// Append whatever is needed to close an unterminated double-quoted string
/// and any unclosed `{` / `[`, innermost first.
///
/// Brackets inside strings are ignored. A closer outside a string just pops
/// the most recent opener without checking that the two kinds match.
fn close_brackets(s: &str) -> String {
    let mut pending: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for c in s.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' => pending.push('}'),
            '[' => pending.push(']'),
            '}' | ']' => {
                pending.pop();
            }
            _ => {}
        }
    }
    let mut closed = String::with_capacity(s.len() + pending.len() + 1);
    closed.push_str(s);
    if in_string {
        closed.push('"');
    }
    closed.extend(pending.iter().rev());
    closed
}
/// Build valid-JSON prefixes of `s`, longest first, for recovering from
/// truncated output.
///
/// Cut points are the position of each `,` and the position just after each
/// `}` / `]` that lie outside double-quoted strings. Cuts at offset 0 or at the
/// very end are skipped. Each prefix is closed by `try_close_at`; only
/// prefixes that then parse are kept, without duplicates.
fn truncation_recovery_candidates(s: &str) -> Vec<String> {
    let mut cuts: Vec<usize> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for (at, c) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            ',' => cuts.push(at),
            '}' | ']' => cuts.push(at + 1),
            _ => {}
        }
    }

    let mut out: Vec<String> = Vec::new();
    for cut in cuts.into_iter().rev().filter(|&cut| cut != 0 && cut < s.len()) {
        if let Some(candidate) = try_close_at(s, cut)
            && !out.contains(&candidate)
        {
            out.push(candidate);
        }
    }
    out
}
/// Cut `s` at byte `pos` and try to close the prefix into valid JSON.
///
/// Trailing whitespace and then a single trailing comma are removed. An open
/// string is closed, and open brackets are closed innermost first. Returns
/// the result only if it parses. `pos` must be a char boundary; the caller
/// passes positions of ASCII delimiters.
fn try_close_at(s: &str, pos: usize) -> Option<String> {
    let head = s[..pos].trim_end();
    let head = head.strip_suffix(',').unwrap_or(head);

    let mut closers: Vec<char> = Vec::new();
    let mut in_str = false;
    let mut esc = false;
    for ch in head.chars() {
        if in_str {
            if esc {
                esc = false;
            } else if ch == '\\' {
                esc = true;
            } else if ch == '"' {
                in_str = false;
            }
            continue;
        }
        match ch {
            '"' => in_str = true,
            '{' => closers.push('}'),
            '[' => closers.push(']'),
            '}' | ']' => {
                closers.pop();
            }
            _ => {}
        }
    }

    let mut repaired = String::with_capacity(head.len() + closers.len() + 1);
    repaired.push_str(head);
    if in_str {
        repaired.push('"');
    }
    repaired.extend(closers.iter().rev());

    match serde_json::from_str::<serde_json::Value>(&repaired) {
        Ok(_) => Some(repaired),
        Err(_) => None,
    }
}
/// Replace every `'` that is outside a double-quoted string with `"`.
///
/// Backslash escapes are honoured only inside double-quoted strings. This is
/// a heuristic: a `"` or `\'` inside a single-quoted string is not handled.
fn fix_single_quotes(s: &str) -> String {
    let mut inside_double = false;
    let mut skip_next = false;
    s.chars()
        .map(|c| {
            if skip_next {
                skip_next = false;
                return c;
            }
            match c {
                '\\' => {
                    skip_next = inside_double;
                    c
                }
                '"' => {
                    inside_double = !inside_double;
                    c
                }
                '\'' if !inside_double => '"',
                other => other,
            }
        })
        .collect()
}
/// Remove `//` line comments and `/* */` block comments outside
/// double-quoted strings.
///
/// A line comment stops before its newline, so the newline is kept. An
/// unterminated block comment swallows the rest of the input. String
/// contents, including backslash escapes, are copied verbatim.
fn strip_comments(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let len = chars.len();
    let mut out = String::with_capacity(s.len());
    let mut pos = 0;
    while pos < len {
        let c = chars[pos];
        match (c, chars.get(pos + 1).copied()) {
            ('"', _) => {
                out.push(c);
                pos += 1;
                while pos < len {
                    let sc = chars[pos];
                    out.push(sc);
                    pos += 1;
                    if sc == '\\' {
                        if pos < len {
                            out.push(chars[pos]);
                            pos += 1;
                        }
                    } else if sc == '"' {
                        break;
                    }
                }
            }
            ('/', Some('/')) => {
                pos = chars[pos..]
                    .iter()
                    .position(|&ch| ch == '\n')
                    .map_or(len, |off| pos + off);
            }
            ('/', Some('*')) => {
                let body = pos + 2;
                pos = chars
                    .get(body..)
                    .and_then(|rest| rest.windows(2).position(|w| w[0] == '*' && w[1] == '/'))
                    .map_or(len, |off| body + off + 2);
            }
            _ => {
                out.push(c);
                pos += 1;
            }
        }
    }
    out
}
/// Unwrap JSON that was itself encoded as a JSON string literal.
///
/// Applies when the trimmed input is `"..."`, at least 3 bytes long, and its
/// interior contains an escaped quote (`\"`). Returns the decoded string only
/// if decoding succeeds and the result looks like JSON.
fn try_unescape_json_string(raw: &str) -> Option<String> {
    let quoted = raw.trim();
    let inner = quoted.strip_prefix('"')?.strip_suffix('"')?;
    if quoted.len() < 3 || !inner.contains("\\\"") {
        return None;
    }
    serde_json::from_str::<String>(quoted)
        .ok()
        .filter(|decoded| looks_like_json(decoded))
}
/// Cheap syntactic check for whether `s` (trimmed) could be JSON.
///
/// True for text bracketed as an object or array, the literals `null`,
/// `true` and `false`, or anything starting with a double quote. Bare numbers
/// are not recognised.
fn looks_like_json(s: &str) -> bool {
    let t = s.trim();
    let wrapped_in = |open: char, close: char| t.starts_with(open) && t.ends_with(close);
    wrapped_in('{', '}')
        || wrapped_in('[', ']')
        || matches!(t, "null" | "true" | "false")
        || t.starts_with('"')
}
/// Return at most the first `max` bytes of `s`.
///
/// When byte `max` falls inside a multi-byte character, the cut moves back
/// to the previous character boundary. Used to cap error text in
/// `ParseError`'s `Display` output.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Byte 0 is always a char boundary, so this search always succeeds.
    let cut = (0..=max).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0);
    &s[..cut]
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Answer {
        answer: String,
        confidence: f64,
    }

    #[test]
    fn parses_clean_json() {
        let raw = r#"{"answer": "42", "confidence": 0.95}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "42");
        assert_eq!(result.source, CandidateSource::Direct);
    }

    #[test]
    fn parses_from_markdown_block() {
        let raw = r#"Here's my answer:
```json
{"answer": "hello", "confidence": 0.8}
```
Hope that helps!"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "hello");
        assert_eq!(result.source, CandidateSource::MarkdownBlock);
    }

    #[test]
    fn parses_from_unlabeled_markdown_block() {
        let raw = r#"Sure:
```
{"answer": "test", "confidence": 0.5}
```"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "test");
        assert_eq!(result.source, CandidateSource::MarkdownBlock);
    }

    #[test]
    fn extracts_json_from_surrounding_text() {
        let raw =
            r#"I think the answer is {"answer": "yes", "confidence": 0.9} based on my analysis."#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "yes");
        assert_eq!(result.source, CandidateSource::Grepped);
    }

    #[test]
    fn extracts_json_after_chain_of_thought() {
        let raw = r#"Let me think step by step...
First, I need to consider the question carefully.
The answer seems clear.
{"answer": "deep thought", "confidence": 0.99}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "deep thought");
    }

    #[test]
    fn fixes_trailing_comma() {
        let raw = r#"{"answer": "fixed", "confidence": 0.7,}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "fixed");
        assert_eq!(result.source, CandidateSource::Fixed);
    }

    #[test]
    fn fixes_unclosed_brackets() {
        let raw = r#"{"answer": "partial", "confidence": 0.6"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "partial");
        assert_eq!(result.source, CandidateSource::Fixed);
    }

    #[test]
    fn fixes_single_quotes() {
        let raw = r#"{'answer': 'quoted', 'confidence': 0.5}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "quoted");
        assert_eq!(result.source, CandidateSource::Fixed);
    }

    #[test]
    fn fixes_js_comments() {
        let raw = r#"{
// This is the answer
"answer": "commented",
"confidence": 0.4
}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "commented");
        assert_eq!(result.source, CandidateSource::Fixed);
    }

    #[test]
    fn prefers_direct_over_markdown() {
        let raw = r#"{"answer": "direct", "confidence": 1.0}"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.source, CandidateSource::Direct);
    }

    #[test]
    fn handles_multiple_json_objects_picks_matching() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Config {
            model: String,
            temperature: f64,
        }
        let raw = r#"Here are two objects:
{"answer": "wrong type", "confidence": 0.5}
{"model": "gemini", "temperature": 0.3}"#;
        let result = parse_flexible::<Config>(raw).unwrap();
        assert_eq!(result.value.model, "gemini");
    }

    #[test]
    fn error_shows_all_candidates() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // the field is never read: only the failure path is exercised
        struct Impossible {
            xyz_field_that_wont_match: i64,
        }
        let raw = "Just some plain text with no JSON";
        let err = parse_flexible::<Impossible>(raw).unwrap_err();
        assert!(err.to_string().contains("Failed to parse"));
    }

    #[test]
    fn error_lists_each_failed_candidate() {
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // the field is never read: only the failure path is exercised
        struct Impossible {
            xyz_field_that_wont_match: i64,
        }
        let raw = r#"{"a": 1}"#;
        let err = parse_flexible::<Impossible>(raw).unwrap_err();
        assert_eq!(err.candidates.len(), 1);
        assert_eq!(err.candidates[0].0.source, CandidateSource::Direct);
        assert_eq!(err.raw, raw);
        assert!(err.to_string().contains("[0] Direct: "));
    }

    #[test]
    fn handles_nested_json() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Nested {
            outer: Inner,
        }
        #[derive(Debug, Deserialize, PartialEq)]
        struct Inner {
            value: String,
        }
        let raw = r#"{"outer": {"value": "deep"}}"#;
        let result = parse_flexible::<Nested>(raw).unwrap();
        assert_eq!(result.value.outer.value, "deep");
    }

    #[test]
    fn handles_array_response() {
        let raw = r#"```json
[{"answer": "one", "confidence": 0.5}, {"answer": "two", "confidence": 0.8}]
```"#;
        let result = parse_flexible::<Vec<Answer>>(raw).unwrap();
        assert_eq!(result.value.len(), 2);
        assert_eq!(result.value[1].answer, "two");
    }

    #[test]
    fn handles_empty_input() {
        // Empty input produces no candidates, so the error must list none
        // (the previous assertion `x.is_empty() || !x.is_empty()` was a tautology).
        let err = parse_flexible::<Answer>("").unwrap_err();
        assert!(err.candidates.is_empty());
        assert!(err.raw.is_empty());
    }

    #[test]
    fn handles_unclosed_markdown_block() {
        let raw = r#"```json
{"answer": "streaming", "confidence": 0.3}
"#;
        let result = parse_flexible::<Answer>(raw).unwrap();
        assert_eq!(result.value.answer, "streaming");
    }

    #[test]
    fn strip_trailing_commas_works() {
        assert_eq!(strip_trailing_commas(r#"{"a": 1,}"#), r#"{"a": 1}"#);
        assert_eq!(strip_trailing_commas(r#"[1, 2,]"#), r#"[1, 2]"#);
        assert_eq!(strip_trailing_commas(r#"{"a": "b,"}"#), r#"{"a": "b,"}"#);
    }

    #[test]
    fn close_brackets_works() {
        assert_eq!(close_brackets(r#"{"a": 1"#), r#"{"a": 1}"#);
        assert_eq!(close_brackets(r#"[1, [2"#), r#"[1, [2]]"#);
        assert_eq!(close_brackets(r#"{"a": "hello"#), r#"{"a": "hello"}"#);
    }

    #[test]
    fn truncation_recovery_drops_incomplete_element() {
        let raw = r#"{"items":[{"id":1,"name":"ok"},{"id":2,"na"#;
        let candidates = truncation_recovery_candidates(raw);
        assert!(!candidates.is_empty(), "Should produce recovery candidates");
        let has_valid = candidates.iter().any(|c| {
            if let Ok(val) = serde_json::from_str::<serde_json::Value>(c) {
                val["items"]
                    .as_array()
                    .is_some_and(|a| !a.is_empty() && a[0]["id"] == 1)
            } else {
                false
            }
        });
        assert!(
            has_valid,
            "At least one candidate should have first complete element"
        );
    }

    #[test]
    fn truncation_recovery_streaming_action() {
        #[derive(Debug, Deserialize)]
        struct Step {
            situation: String,
            actions: Vec<serde_json::Value>,
        }
        let raw = r#"{"situation":"working","actions":[{"tool":"read","path":"a.rs"},{"tool":"edit","path":"b.rs","old"#;
        let result = parse_flexible::<Step>(raw);
        assert!(result.is_ok(), "Should recover from truncated streaming");
        let step = result.unwrap().value;
        assert_eq!(step.situation, "working");
        assert!(!step.actions.is_empty());
    }

    #[test]
    fn unescape_double_wrapped_json() {
        #[derive(Debug, Deserialize)]
        struct Simple {
            msg: String,
        }
        let raw = r#""{\"msg\": \"hello world\"}""#;
        let result = parse_flexible::<Simple>(raw);
        assert!(result.is_ok(), "Should unescape double-wrapped JSON");
        assert_eq!(result.unwrap().value.msg, "hello world");
    }

    #[test]
    fn unescape_ignores_normal_strings() {
        let result = try_unescape_json_string("\"just a normal string\"");
        assert!(result.is_none());
    }

    #[test]
    fn fix_single_quotes_works() {
        assert_eq!(fix_single_quotes("{'a': 'b'}"), r#"{"a": "b"}"#);
        assert_eq!(
            fix_single_quotes(r#"{"it's": "fine"}"#),
            r#"{"it's": "fine"}"#
        );
    }

    #[test]
    fn strip_comments_works() {
        assert_eq!(
            strip_comments("{\n// comment\n\"a\": 1\n}"),
            "{\n\n\"a\": 1\n}"
        );
        assert_eq!(strip_comments("{/* block */\"a\": 1}"), "{\"a\": 1}");
    }

    #[test]
    fn extract_markdown_blocks_multiple() {
        let raw = r#"First:
```json
{"a": 1}
```
Second:
```json
{"b": 2}
```"#;
        let blocks = extract_markdown_blocks(raw);
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn extract_json_objects_finds_multiple() {
        let raw = r#"text {"a": 1} middle {"b": 2} end"#;
        let objects = extract_json_objects(raw);
        assert_eq!(objects.len(), 2);
    }

    #[test]
    fn extract_json_objects_nested_returns_outer() {
        let raw = r#"text {"outer": {"inner": 1}} more text"#;
        let objects = extract_json_objects(raw);
        assert_eq!(objects.len(), 1);
        assert!(objects[0].contains("outer"));
    }

    #[test]
    fn collect_candidates_deduplicates() {
        let raw = r#"{"answer": "test", "confidence": 0.5}"#;
        let candidates = collect_candidates(raw);
        let jsons: Vec<&str> = candidates.iter().map(|c| c.json.as_str()).collect();
        let unique: std::collections::HashSet<&&str> = jsons.iter().collect();
        assert_eq!(jsons.len(), unique.len());
    }
}