Skip to main content

perspt_core/
normalize.rs

1//! Provider-Neutral Output Normalization
2//!
3//! PSP-5 Phase 4: Extracts structured content (JSON objects, JSON arrays) from
4//! raw LLM responses regardless of provider-specific formatting quirks.
5//!
//! Supported extraction strategies (see [`extract_json`] for the exact order
//! in which they are tried):
//! 1. Fenced JSON code block: ```json ... ```
//! 2. Generic fenced code block: ``` ... ``` containing JSON
//! 3. Direct JSON: response body starts with `{` or `[`
//! 4. Embedded JSON: the first `{` through its matching (balanced) `}` in
//!    wrapper text
11//!
12//! The module is provider-agnostic by design. Provider family classification
13//! is available for diagnostics and telemetry but does not change extraction
14//! behavior.
15
16use serde::de::DeserializeOwned;
17
/// Provider family for diagnostics and telemetry.
///
/// Does not affect extraction semantics — all providers go through the same
/// normalization pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderFamily {
    OpenAI,
    Anthropic,
    Gemini,
    Groq,
    Cohere,
    XAI,
    DeepSeek,
    Ollama,
    /// The model name matched none of the known patterns.
    Unknown,
}

impl ProviderFamily {
    /// Classify a provider family from a model name string.
    ///
    /// Matching is case-insensitive and uses prefix/substring heuristics that
    /// are checked in a fixed order (OpenAI, Anthropic, Gemini, Groq, Cohere,
    /// xAI, DeepSeek, Ollama); the first match wins. Returns `Unknown` when the
    /// model name does not match any known pattern.
    ///
    /// OpenAI's `o1`/`o3`/`o4` series is recognised both bare (`o3`) and with a
    /// hyphenated suffix (`o1-preview`, `o4-mini`).
    pub fn from_model_name(model: &str) -> Self {
        let lower = model.to_lowercase();
        if lower.starts_with("gpt-")
            || Self::is_openai_o_series(&lower)
            || lower.contains("openai")
        {
            ProviderFamily::OpenAI
        } else if lower.starts_with("claude") || lower.contains("anthropic") {
            ProviderFamily::Anthropic
        } else if lower.starts_with("gemini") || lower.contains("google") {
            ProviderFamily::Gemini
        } else if lower.contains("groq")
            || lower.starts_with("llama")
            || lower.starts_with("mixtral")
        {
            ProviderFamily::Groq
        } else if lower.starts_with("command") || lower.contains("cohere") {
            ProviderFamily::Cohere
        } else if lower.starts_with("grok") || lower.contains("xai") {
            ProviderFamily::XAI
        } else if lower.starts_with("deepseek") {
            ProviderFamily::DeepSeek
        } else if lower.contains("ollama") {
            ProviderFamily::Ollama
        } else {
            ProviderFamily::Unknown
        }
    }

    /// Returns `true` when `lower` is exactly `o1`, `o3` or `o4`, or one of
    /// those followed by a `-` suffix. Expects an already-lowercased name.
    fn is_openai_o_series(lower: &str) -> bool {
        ["o1", "o3", "o4"].iter().any(|&series| {
            matches!(
                lower.strip_prefix(series),
                Some(rest) if rest.is_empty() || rest.starts_with('-')
            )
        })
    }
}

impl std::fmt::Display for ProviderFamily {
    /// Writes the lowercase provider label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ProviderFamily::OpenAI => "openai",
            ProviderFamily::Anthropic => "anthropic",
            ProviderFamily::Gemini => "gemini",
            ProviderFamily::Groq => "groq",
            ProviderFamily::Cohere => "cohere",
            ProviderFamily::XAI => "xai",
            ProviderFamily::DeepSeek => "deepseek",
            ProviderFamily::Ollama => "ollama",
            ProviderFamily::Unknown => "unknown",
        })
    }
}
87
/// Which extraction strategy succeeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtractionMethod {
    /// Found inside a ```json ... ``` fence.
    FencedJson,
    /// Found inside a generic ``` ... ``` fence containing JSON.
    GenericFence,
    /// Response body started directly with `{` or `[`.
    DirectJson,
    /// Extracted from the first `{` through its matching (balanced) `}` in
    /// wrapper text.
    EmbeddedJson,
}

impl std::fmt::Display for ExtractionMethod {
    /// Writes the snake_case label of the strategy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ExtractionMethod::FencedJson => "fenced_json",
            ExtractionMethod::GenericFence => "generic_fence",
            ExtractionMethod::DirectJson => "direct_json",
            ExtractionMethod::EmbeddedJson => "embedded_json",
        })
    }
}

/// Result of a successful normalization.
///
/// Comparable with `==` so whole results can be asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NormalizedOutput {
    /// The extracted JSON body, trimmed. It is not validated during
    /// extraction and may still fail to parse.
    pub json_body: String,
    /// How the JSON was extracted.
    pub method: ExtractionMethod,
}
120
/// Failure to pull a JSON body out of a raw response.
///
/// Carries a human-readable `reason` together with the byte length of the
/// input that was looked at.
#[derive(Debug, Clone)]
pub struct NormalizationError {
    /// Human-readable explanation of why normalization failed.
    pub reason: String,
    /// Byte length of the raw input that was inspected.
    pub input_len: usize,
}

impl std::fmt::Display for NormalizationError {
    /// Renders as `normalization failed (input <len> bytes): <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { reason, input_len } = self;
        write!(
            f,
            "normalization failed (input {len} bytes): {why}",
            len = input_len,
            why = reason
        )
    }
}

/// Standard error impl. There is no wrapped cause, so the default `source()`
/// (which returns `None`) is kept.
impl std::error::Error for NormalizationError {}
141
142/// Extract a JSON body from a raw LLM response.
143///
144/// Tries extraction strategies in order of specificity:
145/// 1. Fenced JSON (`\`\`\`json`)
146/// 2. Generic fence (`\`\`\``) whose content parses as JSON
147/// 3. Direct JSON (trimmed input starts with `{` or `[`)
148/// 4. Embedded JSON (first `{` to last balanced `}`)
149///
150/// Returns the extracted body and the method used, or an error if no JSON
151/// could be found.
152pub fn extract_json(raw: &str) -> Result<NormalizedOutput, NormalizationError> {
153    let trimmed = raw.trim();
154
155    if trimmed.is_empty() {
156        return Err(NormalizationError {
157            reason: "empty input".to_string(),
158            input_len: 0,
159        });
160    }
161
162    // Strategy 1: fenced JSON code block
163    if let Some(body) = extract_fenced_json(trimmed) {
164        return Ok(NormalizedOutput {
165            json_body: body,
166            method: ExtractionMethod::FencedJson,
167        });
168    }
169
170    // Strategy 2: generic fenced code block containing JSON
171    if let Some(body) = extract_generic_fence_json(trimmed) {
172        return Ok(NormalizedOutput {
173            json_body: body,
174            method: ExtractionMethod::GenericFence,
175        });
176    }
177
178    // Strategy 3: direct JSON
179    if trimmed.starts_with('{') || trimmed.starts_with('[') {
180        return Ok(NormalizedOutput {
181            json_body: trimmed.to_string(),
182            method: ExtractionMethod::DirectJson,
183        });
184    }
185
186    // Strategy 4: embedded JSON via balanced brace matching
187    if let Some(body) = extract_embedded_json(trimmed) {
188        return Ok(NormalizedOutput {
189            json_body: body,
190            method: ExtractionMethod::EmbeddedJson,
191        });
192    }
193
194    Err(NormalizationError {
195        reason: "no JSON object or array found in response".to_string(),
196        input_len: raw.len(),
197    })
198}
199
200/// Convenience: extract JSON and deserialize into `T` in one step.
201pub fn extract_and_deserialize<T: DeserializeOwned>(
202    raw: &str,
203) -> Result<(T, ExtractionMethod), NormalizationError> {
204    let output = extract_json(raw)?;
205    match serde_json::from_str::<T>(&output.json_body) {
206        Ok(value) => Ok((value, output.method)),
207        Err(e) => Err(NormalizationError {
208            reason: format!(
209                "JSON extracted via {} but deserialization failed: {}",
210                output.method, e
211            ),
212            input_len: raw.len(),
213        }),
214    }
215}
216
217// ---------------------------------------------------------------------------
218// Internal extraction helpers
219// ---------------------------------------------------------------------------
220
/// Extract the body of the first ```` ```json ```` fence in `input`.
///
/// The closing fence is the first ```` ``` ```` that lies outside a JSON
/// string literal. This keeps a payload whose string values contain Markdown
/// fences (for example generated README content) from being cut short. If no
/// such fence exists, the first ```` ``` ```` after the opening marker is used,
/// as a plain text search would.
///
/// Returns `None` if there is no ```` ```json ```` marker, no closing fence, or
/// the fenced body is empty after trimming.
fn extract_fenced_json(input: &str) -> Option<String> {
    /// Byte offset of the first ```` ``` ```` in `text` that is not inside a
    /// double-quoted JSON string, or `None` if there is none.
    fn fence_outside_json_strings(text: &str) -> Option<usize> {
        let mut in_string = false;
        let mut escaped = false;
        for (i, ch) in text.char_indices() {
            if in_string {
                match ch {
                    // Raw newlines cannot occur inside valid JSON strings, so
                    // treat one as the end of a malformed string. This stops a
                    // stray quote from hiding the closing fence.
                    '\n' => {
                        in_string = false;
                        escaped = false;
                    }
                    _ if escaped => escaped = false,
                    '\\' => escaped = true,
                    '"' => in_string = false,
                    _ => {}
                }
            } else if ch == '"' {
                in_string = true;
            } else if text[i..].starts_with("```") {
                return Some(i);
            }
        }
        None
    }

    const OPEN: &str = "```json";

    let start = input.find(OPEN)?;
    let after_open = &input[start + OPEN.len()..];
    // The newline that ends the opening fence line is not part of the body.
    let content = after_open.strip_prefix('\n').unwrap_or(after_open);

    let end = fence_outside_json_strings(content).or_else(|| content.find("```"))?;
    let body = content[..end].trim();
    if body.is_empty() {
        return None;
    }
    Some(body.to_string())
}
238
/// Extract the body of the first generic ```` ``` ```` fence whose content looks
/// like JSON, i.e. starts with `{` or `[` after trimming.
///
/// Every fence in `input` is examined in order. A leading non-JSON fence
/// (a shell command or a Rust snippet, say) therefore no longer hides a JSON
/// fence that follows it.
///
/// Any info string after an opening fence (a language tag such as `rust`) is
/// skipped up to the end of that line. If there is no newline after the
/// opening fence, the body starts right after the backticks.
///
/// Returns `None` when no fence holds JSON-looking content, or as soon as an
/// opening fence has no closing fence.
fn extract_generic_fence_json(input: &str) -> Option<String> {
    const FENCE: &str = "```";

    let mut rest = input;
    while let Some(open) = rest.find(FENCE) {
        let after_open = &rest[open + FENCE.len()..];
        // Skip the info string (everything up to and including the first
        // newline); with no newline at all, the body starts immediately.
        let body_start = after_open.find('\n').map_or(0, |n| n + 1);
        let fenced = &after_open[body_start..];

        // NOTE: the closing fence is found by plain text search, so a ```
        // inside a JSON string value would end the body early.
        let close = fenced.find(FENCE)?;
        let body = fenced[..close].trim();
        if body.starts_with('{') || body.starts_with('[') {
            return Some(body.to_string());
        }

        // Not JSON: keep scanning after this fence's closing marker.
        rest = &fenced[close + FENCE.len()..];
    }
    None
}
260
/// Extract the first `{` in `input` together with everything up to the `}`
/// that balances it, ignoring braces that appear inside double-quoted strings
/// (with backslash escapes honoured).
///
/// Surrounding prose before and after the object is dropped. Returns `None`
/// if `input` contains no `{` or the first `{` is never balanced.
fn extract_embedded_json(input: &str) -> Option<String> {
    let start = input.find('{')?;
    let candidate = &input[start..];

    let mut nesting: i32 = 0;
    let mut inside_string = false;
    let mut skip_char = false;

    let end = candidate.char_indices().find_map(|(offset, c)| {
        // The character after a backslash inside a string is consumed as-is.
        if skip_char {
            skip_char = false;
            return None;
        }
        match (c, inside_string) {
            ('\\', true) => skip_char = true,
            ('"', _) => inside_string = !inside_string,
            ('{', false) => nesting += 1,
            ('}', false) => {
                nesting -= 1;
                if nesting == 0 {
                    return Some(offset);
                }
            }
            _ => {}
        }
        None
    })?;

    Some(candidate[..=end].to_string())
}
301
#[cfg(test)]
mod tests {
    use super::*;

    // -- extract_json --------------------------------------------------------

    #[test]
    fn test_direct_json_object() {
        let raw = r#"{"tasks": [{"id": "1"}]}"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::DirectJson);
        assert_eq!(out.json_body, raw);
    }

    #[test]
    fn test_direct_json_array() {
        let raw = r#"[{"id": 1}]"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::DirectJson);
        assert_eq!(out.json_body, raw);
    }

    #[test]
    fn test_fenced_json() {
        let raw = "Here is the plan:\n```json\n{\"tasks\": []}\n```\nDone.";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::FencedJson);
        assert_eq!(out.json_body, "{\"tasks\": []}");
    }

    #[test]
    fn test_generic_fence_with_json() {
        let raw = "Result:\n```\n{\"artifacts\": []}\n```";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::GenericFence);
        assert_eq!(out.json_body, "{\"artifacts\": []}");
    }

    #[test]
    fn test_generic_fence_with_language_hint() {
        let raw = "```rust\nfn main() {}\n```";
        // The fence body is Rust, not JSON, so the generic-fence strategy must
        // reject it. Extraction as a whole may still succeed, because the
        // embedded strategy picks up the `{}` of the function body. This test
        // only guarantees that GenericFence is not the reported method.
        let result = extract_json(raw);
        if let Ok(out) = &result {
            assert_ne!(out.method, ExtractionMethod::GenericFence);
        }
    }

    #[test]
    fn test_embedded_json_with_wrapper_text() {
        let raw = "Sure! Here is the bundle:\n{\"artifacts\": [{\"path\": \"main.rs\", \"operation\": \"write\", \"content\": \"fn main() {}\"}]}\nLet me know if you need changes.";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert!(out.json_body.starts_with('{'));
        assert!(out.json_body.ends_with('}'));
    }

    #[test]
    fn test_embedded_json_with_nested_braces() {
        let raw = "Plan: {\"a\": {\"b\": {\"c\": 1}}} end";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(out.json_body, "{\"a\": {\"b\": {\"c\": 1}}}");
    }

    #[test]
    fn test_embedded_json_with_strings_containing_braces() {
        let raw = r#"Output: {"msg": "hello { world }"} done"#;
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::EmbeddedJson);
        assert_eq!(out.json_body, r#"{"msg": "hello { world }"}"#);
    }

    #[test]
    fn test_empty_input() {
        let result = extract_json("");
        assert!(result.is_err());
    }

    #[test]
    fn test_whitespace_only_input() {
        // Trimming leaves nothing, so this must fail just like the empty string.
        assert!(extract_json("  \n\t ").is_err());
    }

    #[test]
    fn test_no_json_at_all() {
        let result = extract_json("This is just a plain text response with no JSON.");
        assert!(result.is_err());
    }

    #[test]
    fn test_fenced_json_takes_priority_over_embedded() {
        let raw = "Preamble {\"stray\": 1}\n```json\n{\"real\": 2}\n```";
        let out = extract_json(raw).unwrap();
        assert_eq!(out.method, ExtractionMethod::FencedJson);
        assert_eq!(out.json_body, "{\"real\": 2}");
    }

    // -- extract_and_deserialize ---------------------------------------------

    #[test]
    fn test_extract_and_deserialize_ok() {
        #[derive(serde::Deserialize)]
        struct Simple {
            value: i32,
        }
        let raw = "```json\n{\"value\": 42}\n```";
        let (obj, method): (Simple, _) = extract_and_deserialize(raw).unwrap();
        assert_eq!(obj.value, 42);
        assert_eq!(method, ExtractionMethod::FencedJson);
    }

    #[test]
    fn test_extract_and_deserialize_bad_schema() {
        #[derive(Debug, serde::Deserialize)]
        struct Strict {
            // Only needed to make deserialization fail; never read.
            #[allow(dead_code)]
            required_field: String,
        }
        let raw = "{\"other\": 1}";
        let result: Result<(Strict, _), _> = extract_and_deserialize(raw);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.reason.contains("deserialization failed"));
    }

    // -- ProviderFamily ------------------------------------------------------

    #[test]
    fn test_provider_family_classification() {
        assert_eq!(
            ProviderFamily::from_model_name("gpt-4o"),
            ProviderFamily::OpenAI
        );
        assert_eq!(
            ProviderFamily::from_model_name("claude-opus-4-20250514"),
            ProviderFamily::Anthropic
        );
        assert_eq!(
            ProviderFamily::from_model_name("gemini-2.5-pro"),
            ProviderFamily::Gemini
        );
        assert_eq!(
            ProviderFamily::from_model_name("deepseek-r1"),
            ProviderFamily::DeepSeek
        );
        assert_eq!(
            ProviderFamily::from_model_name("my-custom-model"),
            ProviderFamily::Unknown
        );
    }

    // -- Display impls -------------------------------------------------------

    #[test]
    fn test_display_strings() {
        assert_eq!(ProviderFamily::XAI.to_string(), "xai");
        assert_eq!(ProviderFamily::Unknown.to_string(), "unknown");
        assert_eq!(ExtractionMethod::GenericFence.to_string(), "generic_fence");
        assert_eq!(ExtractionMethod::EmbeddedJson.to_string(), "embedded_json");
        let err = NormalizationError {
            reason: "boom".to_string(),
            input_len: 4,
        };
        assert_eq!(err.to_string(), "normalization failed (input 4 bytes): boom");
    }

    // -- Realistic responses -------------------------------------------------

    #[test]
    fn test_extract_json_with_nested_code_fence() {
        // LLMs often wrap JSON in markdown code fences with extra prose
        let raw = r#"
Here is the plan I've created for you:

```json
{
  "steps": [
    {"id": "s1", "action": "create_file", "path": "src/lib.rs"},
    {"id": "s2", "action": "run_tests", "path": "."}
  ],
  "description": "Create and verify a new library"
}
```

Let me know if you'd like any changes.
"#;
        let output = extract_json(raw).unwrap();
        assert_eq!(output.method, ExtractionMethod::FencedJson);
        assert!(output.json_body.contains("create_file"));
        assert!(output.json_body.contains("run_tests"));
    }

    #[test]
    fn test_extract_and_deserialize_realistic_plan() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Step {
            id: String,
            action: String,
        }
        #[derive(Debug, serde::Deserialize)]
        struct Plan {
            steps: Vec<Step>,
        }

        let raw = r#"Sure! ```json
{"steps": [{"id": "1", "action": "lint"}, {"id": "2", "action": "test"}]}
```"#;

        let (plan, method): (Plan, _) = extract_and_deserialize(raw).unwrap();
        assert_eq!(method, ExtractionMethod::FencedJson);
        assert_eq!(plan.steps.len(), 2);
        assert_eq!(plan.steps[0].action, "lint");
        assert_eq!(plan.steps[1].action, "test");
    }
}