1use serde::de::DeserializeOwned;
17
/// Provider family inferred from a model identifier.
///
/// The value is produced by a best-effort string heuristic (see
/// [`ProviderFamily::from_model_name`]); identifiers that match no rule are
/// reported as [`ProviderFamily::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderFamily {
    /// `gpt-*`, `chatgpt-*`, o-series (`o1`, `o3`, `o4` and their `-suffix`
    /// variants), or any name containing `openai`.
    OpenAI,
    /// Names starting with `claude` or containing `anthropic`.
    Anthropic,
    /// Names starting with `gemini` or containing `google`.
    Gemini,
    /// Names containing `groq`, or starting with `llama` / `mixtral`.
    Groq,
    /// Names starting with `command` or containing `cohere`.
    Cohere,
    /// Names starting with `grok` or containing `xai`.
    XAI,
    /// Names starting with `deepseek`.
    DeepSeek,
    /// Names containing `ollama`.
    Ollama,
    /// No rule matched.
    Unknown,
}

impl ProviderFamily {
    /// Classifies a model identifier into a provider family.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Rules
    /// are evaluated top to bottom and the first match wins, so a name that
    /// satisfies several rules is attributed to the earliest one.
    ///
    /// Bare o-series identifiers such as `"o1"` and `"o3"` are recognised in
    /// addition to their suffixed forms (`"o3-mini"`); a longer token such as
    /// `"o10"` is deliberately not treated as o-series.
    pub fn from_model_name(model: &str) -> Self {
        let name = model.trim().to_lowercase();

        let starts_with_any =
            |prefixes: &[&str]| prefixes.iter().any(|prefix| name.starts_with(*prefix));

        // An o-series id is exactly `oN` or `oN-<suffix>`; requiring the dash
        // (or end of string) keeps unrelated names like `o10` out.
        let is_o_series = ["o1", "o3", "o4"].iter().any(|family| {
            matches!(
                name.strip_prefix(*family),
                Some(rest) if rest.is_empty() || rest.starts_with('-')
            )
        });

        if is_o_series || starts_with_any(&["gpt-", "chatgpt-"]) || name.contains("openai") {
            ProviderFamily::OpenAI
        } else if name.starts_with("claude") || name.contains("anthropic") {
            ProviderFamily::Anthropic
        } else if name.starts_with("gemini") || name.contains("google") {
            ProviderFamily::Gemini
        } else if name.contains("groq") || starts_with_any(&["llama", "mixtral"]) {
            ProviderFamily::Groq
        } else if name.starts_with("command") || name.contains("cohere") {
            ProviderFamily::Cohere
        } else if name.starts_with("grok") || name.contains("xai") {
            ProviderFamily::XAI
        } else if name.starts_with("deepseek") {
            ProviderFamily::DeepSeek
        } else if name.contains("ollama") {
            ProviderFamily::Ollama
        } else {
            ProviderFamily::Unknown
        }
    }
}

impl std::fmt::Display for ProviderFamily {
    /// Writes the lowercase family label (e.g. `openai`, `unknown`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ProviderFamily::OpenAI => "openai",
            ProviderFamily::Anthropic => "anthropic",
            ProviderFamily::Gemini => "gemini",
            ProviderFamily::Groq => "groq",
            ProviderFamily::Cohere => "cohere",
            ProviderFamily::XAI => "xai",
            ProviderFamily::DeepSeek => "deepseek",
            ProviderFamily::Ollama => "ollama",
            ProviderFamily::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
87
/// Identifies which extraction strategy produced a JSON body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractionMethod {
    /// Body taken from a fenced block opened with the `json` info string.
    FencedJson,
    /// Body taken from a fenced block without a `json` marker whose content
    /// starts with `{` or `[`.
    GenericFence,
    /// The trimmed response itself was used as the body.
    DirectJson,
    /// A balanced `{ ... }` object located inside surrounding text.
    EmbeddedJson,
}

impl std::fmt::Display for ExtractionMethod {
    /// Writes the snake_case label of the strategy (e.g. `fenced_json`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ExtractionMethod::FencedJson => "fenced_json",
            ExtractionMethod::GenericFence => "generic_fence",
            ExtractionMethod::DirectJson => "direct_json",
            ExtractionMethod::EmbeddedJson => "embedded_json",
        };
        f.write_str(label)
    }
}

/// Result of a successful extraction: the JSON text plus how it was found.
#[derive(Debug, Clone)]
pub struct NormalizedOutput {
    /// Extracted JSON text; it has not been parsed or validated.
    pub json_body: String,
    /// Strategy that located `json_body`.
    pub method: ExtractionMethod,
}
120
/// Error describing why a response could not be normalized.
#[derive(Debug, Clone)]
pub struct NormalizationError {
    /// Human-readable description of the failure.
    pub reason: String,
    /// Input length in bytes, as recorded by the code that built the error.
    pub input_len: usize,
}

impl std::fmt::Display for NormalizationError {
    /// Formats as `normalization failed (input <len> bytes): <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NormalizationError { reason, input_len } = self;
        write!(
            f,
            "normalization failed (input {} bytes): {}",
            input_len, reason
        )
    }
}

// Marker impl: the error carries no underlying source.
impl std::error::Error for NormalizationError {}
141
142pub fn extract_json(raw: &str) -> Result<NormalizedOutput, NormalizationError> {
153 let trimmed = raw.trim();
154
155 if trimmed.is_empty() {
156 return Err(NormalizationError {
157 reason: "empty input".to_string(),
158 input_len: 0,
159 });
160 }
161
162 if let Some(body) = extract_fenced_json(trimmed) {
164 return Ok(NormalizedOutput {
165 json_body: body,
166 method: ExtractionMethod::FencedJson,
167 });
168 }
169
170 if let Some(body) = extract_generic_fence_json(trimmed) {
172 return Ok(NormalizedOutput {
173 json_body: body,
174 method: ExtractionMethod::GenericFence,
175 });
176 }
177
178 if trimmed.starts_with('{') || trimmed.starts_with('[') {
180 return Ok(NormalizedOutput {
181 json_body: trimmed.to_string(),
182 method: ExtractionMethod::DirectJson,
183 });
184 }
185
186 if let Some(body) = extract_embedded_json(trimmed) {
188 return Ok(NormalizedOutput {
189 json_body: body,
190 method: ExtractionMethod::EmbeddedJson,
191 });
192 }
193
194 Err(NormalizationError {
195 reason: "no JSON object or array found in response".to_string(),
196 input_len: raw.len(),
197 })
198}
199
200pub fn extract_and_deserialize<T: DeserializeOwned>(
202 raw: &str,
203) -> Result<(T, ExtractionMethod), NormalizationError> {
204 let output = extract_json(raw)?;
205 match serde_json::from_str::<T>(&output.json_body) {
206 Ok(value) => Ok((value, output.method)),
207 Err(e) => Err(NormalizationError {
208 reason: format!(
209 "JSON extracted via {} but deserialization failed: {}",
210 output.method, e
211 ),
212 input_len: raw.len(),
213 }),
214 }
215}
216
/// Finds the byte offset of the "```" that closes a fenced block.
///
/// `body` is the text that follows the opening fence. The search is
/// string-aware: a "```" that sits inside a double-quoted string (with
/// backslash escapes honoured) is skipped, so JSON values that embed Markdown
/// fences do not truncate the block.
///
/// If no fence is found outside a string - for example because the quotes in
/// `body` are unbalanced - the first "```" anywhere in `body` is used instead.
fn find_closing_fence(body: &str) -> Option<usize> {
    const FENCE: &str = "```";

    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in body.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' if in_string => escaped = true,
            '"' => in_string = !in_string,
            '`' if !in_string && body[idx..].starts_with(FENCE) => return Some(idx),
            _ => {}
        }
    }

    // Quote tracking found nothing usable; fall back to the plain first match.
    body.find(FENCE)
}

/// Extracts the body of the first "```json" fenced block in `input`.
///
/// The body may start on the same line as the marker or on the next line;
/// surrounding whitespace is trimmed. Returns `None` when there is no
/// "```json" marker, no closing fence, or the block is empty. The body is not
/// validated as JSON.
fn extract_fenced_json(input: &str) -> Option<String> {
    const OPENING: &str = "```json";

    let opening_at = input.find(OPENING)?;
    let after_opening = &input[opening_at + OPENING.len()..];

    let closing_at = find_closing_fence(after_opening)?;
    // `trim` also removes the newline that normally follows the marker.
    let body = after_opening[..closing_at].trim();

    if body.is_empty() {
        None
    } else {
        Some(body.to_string())
    }
}

/// Extracts the body of the first "```" fenced block in `input`, provided the
/// body looks like JSON.
///
/// Everything up to and including the first newline after the opening fence
/// is skipped as the info-string line; when no newline follows, the body
/// starts right after the fence. Returns `None` when there is no fence, no
/// closing fence, or the trimmed body does not start with `{` or `[`.
fn extract_generic_fence_json(input: &str) -> Option<String> {
    const FENCE: &str = "```";

    let opening_at = input.find(FENCE)?;
    let after_opening = &input[opening_at + FENCE.len()..];

    let body_start = after_opening.find('\n').map_or(0, |newline| newline + 1);
    let candidate = &after_opening[body_start..];

    let closing_at = find_closing_fence(candidate)?;
    let body = candidate[..closing_at].trim();

    if body.starts_with('{') || body.starts_with('[') {
        Some(body.to_string())
    } else {
        None
    }
}
260
/// Returns the first balanced `{ ... }` object found in `input`.
///
/// Scanning starts at the first `{`. Braces inside double-quoted strings do
/// not affect nesting, and a backslash inside a string escapes the following
/// character. Returns `None` when `input` has no `{` or when the object
/// opened by the first `{` is never closed. The result is not validated as
/// JSON.
fn extract_embedded_json(input: &str) -> Option<String> {
    let start = input.find('{')?;
    let candidate = &input[start..];

    let mut nesting: i32 = 0;
    let mut inside_string = false;
    let mut skip_next = false;

    for (offset, c) in candidate.char_indices() {
        if skip_next {
            // Character following a backslash inside a string: take it literally.
            skip_next = false;
        } else if inside_string {
            if c == '\\' {
                skip_next = true;
            } else if c == '"' {
                inside_string = false;
            }
        } else if c == '"' {
            inside_string = true;
        } else if c == '{' {
            nesting += 1;
        } else if c == '}' {
            nesting -= 1;
            if nesting == 0 {
                // `}` is one byte, so the inclusive range ends on a char boundary.
                return Some(candidate[..=offset].to_string());
            }
        }
    }

    None
}
301
#[cfg(test)]
mod tests {
    use super::*;

    // ---- extract_json: strategy selection ----

    #[test]
    fn test_direct_json_object() {
        let raw = r#"{"tasks": [{"id": "1"}]}"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::DirectJson);
        assert_eq!(out.json_body, raw);
    }

    #[test]
    fn test_direct_json_array() {
        let raw = r#"[{"id": 1}]"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::DirectJson);
        // The body must be passed through untouched, not just classified.
        assert_eq!(out.json_body, raw);
    }

    #[test]
    fn test_fenced_json() {
        let raw = "Here is the plan:\n```json\n{\"tasks\": []}\n```\nDone.";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::FencedJson);
        assert_eq!(out.json_body, "{\"tasks\": []}");
    }

    #[test]
    fn test_generic_fence_with_json() {
        let raw = "Result:\n```\n{\"artifacts\": []}\n```";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::GenericFence);
        assert_eq!(out.json_body, "{\"artifacts\": []}");
    }

    #[test]
    fn test_generic_fence_with_language_hint() {
        let raw = "```rust\nfn main() {}\n```";

        // A fence whose body is not JSON-shaped must be rejected by the
        // generic-fence strategy itself...
        assert!(extract_generic_fence_json(raw).is_none());

        // ...and whatever the later strategies decide, the result must never
        // be attributed to that strategy.
        let result = extract_json(raw);
        assert!(!matches!(
            &result,
            Ok(out) if out.method == ExtractionMethod::GenericFence
        ));
    }

    #[test]
    fn test_embedded_json_with_wrapper_text() {
        let raw = "Sure! Here is the bundle:\n{\"artifacts\": [{\"path\": \"main.rs\", \"operation\": \"write\", \"content\": \"fn main() {}\"}]}\nLet me know if you need changes.";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert!(out.json_body.starts_with('{'));
        assert!(out.json_body.ends_with('}'));
        // None of the surrounding prose may leak into the body.
        assert!(!out.json_body.contains("Sure!"));
        assert!(!out.json_body.contains("Let me know"));
    }

    #[test]
    fn test_embedded_json_with_nested_braces() {
        let raw = "Plan: {\"a\": {\"b\": {\"c\": 1}}} end";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(out.json_body, "{\"a\": {\"b\": {\"c\": 1}}}");
    }

    #[test]
    fn test_embedded_json_with_strings_containing_braces() {
        let raw = r#"Output: {"msg": "hello { world }"} done"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(out.json_body, r#"{"msg": "hello { world }"}"#);
    }

    #[test]
    fn test_empty_input() {
        let err = extract_json("").unwrap_err();
        assert_eq!(err.reason, "empty input");
        assert_eq!(err.input_len, 0);

        // Whitespace-only input is treated the same as empty input.
        let err = extract_json(" \n\t ").unwrap_err();
        assert_eq!(err.reason, "empty input");
    }

    #[test]
    fn test_no_json_at_all() {
        let raw = "This is just a plain text response with no JSON.";
        let err = extract_json(raw).unwrap_err();
        assert!(err.reason.contains("no JSON"));
        assert_eq!(err.input_len, raw.len());
    }

    #[test]
    fn test_fenced_json_takes_priority_over_embedded() {
        let raw = "Preamble {\"stray\": 1}\n```json\n{\"real\": 2}\n```";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::FencedJson);
        assert_eq!(out.json_body, "{\"real\": 2}");
    }

    // ---- extract_and_deserialize ----

    #[test]
    fn test_extract_and_deserialize_ok() {
        #[derive(serde::Deserialize)]
        struct Simple {
            value: i32,
        }
        let raw = "```json\n{\"value\": 42}\n```";
        let (obj, method): (Simple, _) = extract_and_deserialize(raw).unwrap();
        assert_eq!(obj.value, 42);
        assert_eq!(method, ExtractionMethod::FencedJson);
    }

    #[test]
    fn test_extract_and_deserialize_bad_schema() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            // Only read by the derived deserializer.
            #[allow(dead_code)]
            required_field: String,
        }
        let raw = "{\"other\": 1}";
        let result: Result<(Strict, _), _> = extract_and_deserialize(raw);
        let err = result.unwrap_err();
        assert!(err.reason.contains("deserialization failed"));
        // The message must say which strategy produced the rejected body.
        assert!(err.reason.contains("direct_json"));
        assert_eq!(err.input_len, raw.len());
    }

    // ---- ProviderFamily ----

    #[test]
    fn test_provider_family_classification() {
        let cases = [
            ("gpt-4o", ProviderFamily::OpenAI),
            ("GPT-4O", ProviderFamily::OpenAI),
            ("o3-mini", ProviderFamily::OpenAI),
            ("claude-opus-4-20250514", ProviderFamily::Anthropic),
            ("gemini-2.5-pro", ProviderFamily::Gemini),
            ("llama-3.3-70b-versatile", ProviderFamily::Groq),
            ("mixtral-8x7b-32768", ProviderFamily::Groq),
            ("command-r-plus", ProviderFamily::Cohere),
            ("grok-2", ProviderFamily::XAI),
            ("deepseek-r1", ProviderFamily::DeepSeek),
            ("ollama/phi3", ProviderFamily::Ollama),
            ("my-custom-model", ProviderFamily::Unknown),
        ];
        for (model, expected) in cases.iter() {
            assert_eq!(
                ProviderFamily::from_model_name(model),
                *expected,
                "model name: {}",
                model
            );
        }
    }

    // ---- Display output ----

    #[test]
    fn test_display_labels() {
        assert_eq!(ProviderFamily::OpenAI.to_string(), "openai");
        assert_eq!(ProviderFamily::Unknown.to_string(), "unknown");
        assert_eq!(ExtractionMethod::FencedJson.to_string(), "fenced_json");
        assert_eq!(ExtractionMethod::EmbeddedJson.to_string(), "embedded_json");

        let err = NormalizationError {
            reason: "boom".to_string(),
            input_len: 7,
        };
        assert_eq!(err.to_string(), "normalization failed (input 7 bytes): boom");
    }

    // ---- realistic responses ----

    #[test]
    fn test_extract_json_with_nested_code_fence() {
        // A multi-line ```json block wrapped in prose on both sides.
        let raw = r#"
Here is the plan I've created for you:

```json
{
  "steps": [
    {"id": "s1", "action": "create_file", "path": "src/lib.rs"},
    {"id": "s2", "action": "run_tests", "path": "."}
  ],
  "description": "Create and verify a new library"
}
```

Let me know if you'd like any changes.
"#;
        let output = extract_json(raw).unwrap();
        assert_eq!(output.method, ExtractionMethod::FencedJson);
        assert!(output.json_body.starts_with('{'));
        assert!(output.json_body.ends_with('}'));
        assert!(output.json_body.contains("create_file"));
        assert!(output.json_body.contains("run_tests"));
    }

    #[test]
    fn test_extract_and_deserialize_realistic_plan() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Step {
            id: String,
            action: String,
        }
        #[derive(Debug, serde::Deserialize)]
        struct Plan {
            steps: Vec<Step>,
        }

        // The opening fence shares a line with leading prose.
        let raw = r#"Sure! ```json
{"steps": [{"id": "1", "action": "lint"}, {"id": "2", "action": "test"}]}
```"#;

        let (plan, method): (Plan, _) = extract_and_deserialize(raw).unwrap();
        assert_eq!(method, ExtractionMethod::FencedJson);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].id, "1");
        assert_eq!(plan.steps[0].action, "lint");
        assert_eq!(plan.steps[1].id, "2");
        assert_eq!(plan.steps[1].action, "test");
    }
}