Skip to main content

cognis_core/output_parsers/
structured.rs

1//! Schema-driven structured output parser — accepts any
2//! `T: DeserializeOwned + JsonSchema`, builds an LLM prompt fragment from
3//! the schema, and parses the model's text response into `T`.
4//!
5//! Differs from [`super::JsonParser`] in two ways:
6//! 1. `format_instructions()` returns a *schema-aware* prompt fragment
7//!    (the JSON Schema rendered as a hint), not a generic "reply with
8//!    JSON" string.
//! 2. Tolerates JSON embedded in surrounding prose by extracting a
//!    balanced `{...}` / `[...]` block; useful for chatty models.
11//!
//! Customization: the parser is built from a [`StructuredOutputConfig`]
//! that lets callers swap the prompt template and the JSON-extraction
//! strategy without defining a new parser type.
15
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22
23use crate::output_parsers::OutputParser;
24use crate::runnable::{Runnable, RunnableConfig};
25use crate::{CognisError, Result};
26
27/// Strategy for locating a JSON value inside a string that may also
28/// contain prose. The default is [`JsonExtraction::FirstBalanced`] which
29/// picks the first `{...}` or `[...]` substring whose braces balance.
30///
31/// Implement [`JsonExtractor`] for fully custom logic (e.g. parsing a
32/// specific markdown section, or stripping `<json>...</json>` tags).
33#[derive(Clone)]
34pub enum JsonExtraction {
35    /// First balanced `{...}` or `[...]` block. Tolerates leading prose,
36    /// markdown fences, and trailing apologies.
37    FirstBalanced,
38    /// Treat the entire trimmed input as JSON. Strict; useful when you've
39    /// instructed the model in the prompt to emit only JSON.
40    Strict,
41    /// User-supplied extractor. The closure receives the raw model text
42    /// and returns a `&str` slice that should be JSON.
43    Custom(Arc<dyn JsonExtractor>),
44}
45
46impl std::fmt::Debug for JsonExtraction {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::FirstBalanced => f.write_str("FirstBalanced"),
50            Self::Strict => f.write_str("Strict"),
51            Self::Custom(_) => f.write_str("Custom(<extractor>)"),
52        }
53    }
54}
55
/// Pluggable JSON locator used by [`JsonExtraction::Custom`].
///
/// The trait is object-safe so it can be stored as `Arc<dyn JsonExtractor>`;
/// the `Send + Sync` supertraits let such an `Arc` cross thread boundaries.
pub trait JsonExtractor: Send + Sync {
    /// Find the JSON portion of `text` and return it as a `&str` that lives
    /// at least as long as `text` (typically a subslice of it). The returned
    /// slice is handed to `serde_json::from_str`; returning `None` makes the
    /// parser fail with a `Serialization` error.
    fn extract<'src>(&self, text: &'src str) -> Option<&'src str>;
}
63
/// Default extractor: returns the first balanced `{...}` / `[...]` block.
///
/// The scan begins at the first `{` or `[` in `text` and tracks nesting of
/// that bracket kind only. Bytes inside JSON string literals (honoring `\`
/// escapes) are skipped. Returns `None` if `text` has no opening bracket or
/// the first one is never closed.
fn extract_first_balanced(text: &str) -> Option<&str> {
    // `{` and `[` are ASCII, so the char-based search yields a byte offset
    // that is also a valid slice boundary.
    let start = text.find(|c: char| c == '{' || c == '[')?;
    let tail = &text.as_bytes()[start..];
    let opener = tail[0];
    let closer = match opener {
        b'{' => b'}',
        _ => b']',
    };

    let mut nesting: i32 = 0;
    let mut inside_str = false;
    let mut skip_next = false;

    for (offset, byte) in tail.iter().copied().enumerate() {
        if !inside_str {
            if byte == b'"' {
                inside_str = true;
            } else if byte == opener {
                nesting += 1;
            } else if byte == closer {
                nesting -= 1;
                if nesting == 0 {
                    return Some(&text[start..start + offset + 1]);
                }
            }
            continue;
        }

        // Inside a string literal: only escapes and the closing quote matter.
        if skip_next {
            skip_next = false;
        } else if byte == b'\\' {
            skip_next = true;
        } else if byte == b'"' {
            inside_str = false;
        }
    }
    None
}
99
/// Type-erased prompt-fragment template stored by [`StructuredOutputConfig`].
///
/// Closures given to [`StructuredOutputConfig::with_format_template`] are
/// wrapped in this type. It is called with the rendered JSON Schema and
/// returns the prompt text to embed.
pub type FormatTemplate = Arc<dyn Fn(&str) -> String + Sync + Send>;
104
105/// Configuration knobs exposed to callers without forcing parser-subclass
106/// gymnastics. Defaults are tuned for general-purpose chat models.
107#[derive(Clone)]
108pub struct StructuredOutputConfig {
109    extraction: JsonExtraction,
110    format_template: Option<FormatTemplate>,
111}
112
113impl Default for StructuredOutputConfig {
114    fn default() -> Self {
115        Self {
116            extraction: JsonExtraction::FirstBalanced,
117            format_template: None,
118        }
119    }
120}
121
122impl std::fmt::Debug for StructuredOutputConfig {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        f.debug_struct("StructuredOutputConfig")
125            .field("extraction", &self.extraction)
126            .field("format_template", &self.format_template.is_some())
127            .finish()
128    }
129}
130
131impl StructuredOutputConfig {
132    /// Construct with default settings.
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Override the JSON extraction strategy.
138    pub fn with_extraction(mut self, e: JsonExtraction) -> Self {
139        self.extraction = e;
140        self
141    }
142
143    /// Provide a custom prompt-fragment template. The closure receives
144    /// the JSON Schema rendered as a pretty-printed string and returns
145    /// the prompt text to embed.
146    pub fn with_format_template<F>(mut self, f: F) -> Self
147    where
148        F: Fn(&str) -> String + Send + Sync + 'static,
149    {
150        self.format_template = Some(Arc::new(f));
151        self
152    }
153}
154
155/// Schema-driven output parser.
156///
157/// Build a parser via:
158/// ```ignore
159/// use serde::Deserialize;
160/// use schemars::JsonSchema;
161/// use cognis_core::output_parsers::StructuredOutputParser;
162///
163/// #[derive(Deserialize, JsonSchema)]
164/// struct Answer { topic: String, summary: String }
165///
166/// let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
167/// let instructions = p.format_instructions().unwrap();
168/// // ... feed instructions into your prompt, then parse model output:
169/// let parsed: Answer = p.parse(model_output).unwrap();
170/// ```
171pub struct StructuredOutputParser<T> {
172    config: StructuredOutputConfig,
173    _t: PhantomData<fn() -> T>,
174}
175
176impl<T> Clone for StructuredOutputParser<T> {
177    fn clone(&self) -> Self {
178        Self {
179            config: self.config.clone(),
180            _t: PhantomData,
181        }
182    }
183}
184
185impl<T> Default for StructuredOutputParser<T> {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191impl<T> StructuredOutputParser<T> {
192    /// New parser with default config.
193    pub fn new() -> Self {
194        Self {
195            config: StructuredOutputConfig::default(),
196            _t: PhantomData,
197        }
198    }
199
200    /// New parser with caller-supplied config.
201    pub fn with_config(config: StructuredOutputConfig) -> Self {
202        Self {
203            config,
204            _t: PhantomData,
205        }
206    }
207
208    /// Borrow the active config (read-only).
209    pub fn config(&self) -> &StructuredOutputConfig {
210        &self.config
211    }
212}
213
214impl<T> StructuredOutputParser<T>
215where
216    T: JsonSchema,
217{
218    /// Render the JSON Schema for `T` as a pretty string. Used by
219    /// `format_instructions` and exposed for advanced callers that want
220    /// to embed the schema themselves.
221    pub fn schema_string(&self) -> String {
222        let schema = schemars::schema_for!(T);
223        serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string())
224    }
225}
226
227impl<T> OutputParser<T> for StructuredOutputParser<T>
228where
229    T: DeserializeOwned + JsonSchema + Send + 'static,
230{
231    fn parse(&self, text: &str) -> Result<T> {
232        let trimmed = text.trim();
233        let candidate: &str = match &self.config.extraction {
234            JsonExtraction::Strict => trimmed,
235            JsonExtraction::FirstBalanced => extract_first_balanced(trimmed).ok_or_else(|| {
236                CognisError::Serialization(
237                    "structured parser: no balanced JSON object/array found in output".into(),
238                )
239            })?,
240            JsonExtraction::Custom(extractor) => extractor.extract(trimmed).ok_or_else(|| {
241                CognisError::Serialization(
242                    "structured parser: custom extractor returned None".into(),
243                )
244            })?,
245        };
246        serde_json::from_str(candidate)
247            .map_err(|e| CognisError::Serialization(format!("structured parser: deserialize: {e}")))
248    }
249
250    fn format_instructions(&self) -> Option<String> {
251        let schema = self.schema_string();
252        if let Some(tmpl) = &self.config.format_template {
253            return Some(tmpl(&schema));
254        }
255        Some(format!(
256            "Reply with a single JSON value that conforms to this schema. \
257             Do not include any prose, markdown fences, or commentary outside \
258             the JSON.\n\nSchema:\n{schema}"
259        ))
260    }
261}
262
263#[async_trait]
264impl<T> Runnable<String, T> for StructuredOutputParser<T>
265where
266    T: DeserializeOwned + JsonSchema + Send + 'static,
267{
268    async fn invoke(&self, input: String, _: RunnableConfig) -> Result<T> {
269        OutputParser::parse(self, &input)
270    }
271    fn name(&self) -> &str {
272        "StructuredOutputParser"
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct Answer {
        topic: String,
        steps: Vec<String>,
    }

    #[test]
    fn parses_clean_json() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        let out = p.parse(r#"{"topic":"rust","steps":["a","b"]}"#).unwrap();
        assert_eq!(out.topic, "rust");
        assert_eq!(out.steps, vec!["a".to_string(), "b".into()]);
    }

    #[test]
    fn extracts_balanced_json_from_prose() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        let text = r#"Sure! Here is the answer:
{"topic":"rust","steps":["x"]}
Hope that helps!"#;
        let out = p.parse(text).unwrap();
        assert_eq!(out.topic, "rust");
    }

    #[test]
    fn handles_nested_braces() {
        #[derive(Deserialize, JsonSchema)]
        struct Wrap {
            outer: serde_json::Value,
        }
        let p: StructuredOutputParser<Wrap> = StructuredOutputParser::new();
        let text = r#"prelude {"outer":{"a":{"b":1}},"extra":"ignored"} suffix"#;
        let out = p.parse(text).unwrap();
        assert_eq!(out.outer["a"]["b"], 1);
    }

    #[test]
    fn strict_mode_rejects_prose() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(
            StructuredOutputConfig::new().with_extraction(JsonExtraction::Strict),
        );
        let err = p
            .parse(r#"prelude {"topic":"x","steps":[]} suffix"#)
            .unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn strict_mode_accepts_surrounding_whitespace() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(
            StructuredOutputConfig::new().with_extraction(JsonExtraction::Strict),
        );
        let out = p
            .parse("\n  {\"topic\":\"t\",\"steps\":[\"s\"]}  \n")
            .unwrap();
        assert_eq!(
            out,
            Answer {
                topic: "t".into(),
                steps: vec!["s".into()],
            }
        );
    }

    #[test]
    fn custom_extractor_used() {
        struct TagExtractor;
        impl JsonExtractor for TagExtractor {
            fn extract<'a>(&self, text: &'a str) -> Option<&'a str> {
                let start = text.find("<json>")? + "<json>".len();
                let end = text.find("</json>")?;
                Some(&text[start..end])
            }
        }
        let cfg = StructuredOutputConfig::new()
            .with_extraction(JsonExtraction::Custom(Arc::new(TagExtractor)));
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(cfg);
        let out = p
            .parse(r#"see <json>{"topic":"x","steps":[]}</json> done"#)
            .unwrap();
        assert_eq!(out.topic, "x");
    }

    #[test]
    fn custom_extractor_receives_trimmed_text() {
        struct Whole;
        impl JsonExtractor for Whole {
            fn extract<'a>(&self, text: &'a str) -> Option<&'a str> {
                // Only succeeds if the parser trimmed the input beforehand.
                (text.starts_with('{') && text.ends_with('}')).then_some(text)
            }
        }
        let cfg =
            StructuredOutputConfig::new().with_extraction(JsonExtraction::Custom(Arc::new(Whole)));
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(cfg);
        let out = p.parse("  {\"topic\":\"x\",\"steps\":[]}\n").unwrap();
        assert_eq!(out.topic, "x");
    }

    #[test]
    fn custom_extractor_returning_none_is_serialization_error() {
        struct Nothing;
        impl JsonExtractor for Nothing {
            fn extract<'a>(&self, _text: &'a str) -> Option<&'a str> {
                None
            }
        }
        let cfg = StructuredOutputConfig::new()
            .with_extraction(JsonExtraction::Custom(Arc::new(Nothing)));
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(cfg);
        let err = p.parse(r#"{"topic":"x","steps":[]}"#).unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn format_instructions_includes_schema() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        let s = OutputParser::format_instructions(&p).unwrap();
        assert!(s.contains("\"topic\""));
        assert!(s.contains("\"steps\""));
    }

    #[test]
    fn custom_format_template_is_used() {
        let cfg = StructuredOutputConfig::new()
            .with_format_template(|schema| format!("<custom>{schema}</custom>"));
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::with_config(cfg);
        let s = OutputParser::format_instructions(&p).unwrap();
        assert!(s.starts_with("<custom>"));
        assert!(s.ends_with("</custom>"));
    }

    #[test]
    fn invalid_json_returns_serialization_error() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        let err = p.parse("plain text, no JSON here").unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn unclosed_json_returns_serialization_error() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        let err = p.parse(r#"{"topic":"x","steps":["#).unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn ignores_braces_inside_strings() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        // Embedded `{` inside a string must not break the brace counter.
        let out = p.parse(r#"{"topic":"a {nested} b","steps":[]}"#).unwrap();
        assert_eq!(out.topic, "a {nested} b");
    }

    #[test]
    fn handles_escaped_quotes_inside_strings() {
        let p: StructuredOutputParser<Answer> = StructuredOutputParser::new();
        // An escaped quote must not end the string early and expose the `}`.
        let out = p
            .parse(r#"result: {"topic":"say \"}\" now","steps":[]} end"#)
            .unwrap();
        assert_eq!(out.topic, r#"say "}" now"#);
    }

    #[test]
    fn extracts_top_level_array() {
        let p: StructuredOutputParser<Vec<String>> = StructuredOutputParser::new();
        let out = p.parse(r#"Items: ["a", "b"] -- done"#).unwrap();
        assert_eq!(out, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn config_debug_hides_closures() {
        struct Nop;
        impl JsonExtractor for Nop {
            fn extract<'a>(&self, _text: &'a str) -> Option<&'a str> {
                None
            }
        }
        let cfg = StructuredOutputConfig::new()
            .with_extraction(JsonExtraction::Custom(Arc::new(Nop)))
            .with_format_template(|s| s.to_string());
        assert_eq!(
            format!("{cfg:?}"),
            "StructuredOutputConfig { extraction: Custom(<extractor>), format_template: true }"
        );
    }
}