cognis_core/output_parsers/
structured.rs1use std::marker::PhantomData;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22
23use crate::output_parsers::OutputParser;
24use crate::runnable::{Runnable, RunnableConfig};
25use crate::{CognisError, Result};
26
27#[derive(Clone)]
34pub enum JsonExtraction {
35 FirstBalanced,
38 Strict,
41 Custom(Arc<dyn JsonExtractor>),
44}
45
46impl std::fmt::Debug for JsonExtraction {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Self::FirstBalanced => f.write_str("FirstBalanced"),
50 Self::Strict => f.write_str("Strict"),
51 Self::Custom(_) => f.write_str("Custom(<extractor>)"),
52 }
53 }
54}
55
/// User-pluggable strategy for finding the JSON payload in model output.
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc` (see [`JsonExtraction::Custom`]).
pub trait JsonExtractor: Send + Sync {
    /// Returns the sub-slice of `text` that should be deserialized, or
    /// `None` when no payload can be located.
    ///
    /// The returned slice borrows from `text`, not from the extractor.
    fn extract<'t>(&self, text: &'t str) -> Option<&'t str>;
}
63
/// Returns the first bracket-balanced span of `text`.
///
/// The span starts at the first `{` or `[` and ends at the closer that
/// brings the nesting of *that same bracket kind* back to zero. Brackets of
/// the other kind are not counted, and anything inside a double-quoted
/// string (with backslash escapes honoured) is skipped.
///
/// Returns `None` when `text` contains no opener, or when the opener is
/// never balanced before the end of the input. The result is not validated
/// as JSON; that is left to the deserializer.
fn extract_first_balanced(text: &str) -> Option<&str> {
    let begin = text.find(|c: char| c == '{' || c == '[')?;
    // Everything before the opener is irrelevant; scan only the tail.
    let tail = &text[begin..];
    let (opener, closer) = if tail.starts_with('{') {
        (b'{', b'}')
    } else {
        (b'[', b']')
    };

    let mut nesting: i32 = 0;
    let mut inside_string = false;
    let mut skip_next = false;

    // Byte-wise scan is safe: every byte we react to is ASCII, and UTF-8
    // continuation bytes can never collide with ASCII values.
    for (offset, byte) in tail.bytes().enumerate() {
        if inside_string {
            if skip_next {
                // The byte after a backslash is taken literally.
                skip_next = false;
            } else {
                match byte {
                    b'\\' => skip_next = true,
                    b'"' => inside_string = false,
                    _ => {}
                }
            }
            continue;
        }

        if byte == b'"' {
            inside_string = true;
        } else if byte == opener {
            nesting += 1;
        } else if byte == closer {
            nesting -= 1;
            if nesting == 0 {
                return Some(&tail[..=offset]);
            }
        }
    }

    None
}
99
/// Shared, thread-safe callback that renders format instructions.
///
/// It receives the schema text as `&str` and returns the complete
/// instruction string.
pub type FormatTemplate = Arc<dyn Fn(&str) -> String + Send + Sync>;
104
105#[derive(Clone)]
108pub struct StructuredOutputConfig {
109 extraction: JsonExtraction,
110 format_template: Option<FormatTemplate>,
111}
112
113impl Default for StructuredOutputConfig {
114 fn default() -> Self {
115 Self {
116 extraction: JsonExtraction::FirstBalanced,
117 format_template: None,
118 }
119 }
120}
121
122impl std::fmt::Debug for StructuredOutputConfig {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 f.debug_struct("StructuredOutputConfig")
125 .field("extraction", &self.extraction)
126 .field("format_template", &self.format_template.is_some())
127 .finish()
128 }
129}
130
131impl StructuredOutputConfig {
132 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn with_extraction(mut self, e: JsonExtraction) -> Self {
139 self.extraction = e;
140 self
141 }
142
143 pub fn with_format_template<F>(mut self, f: F) -> Self
147 where
148 F: Fn(&str) -> String + Send + Sync + 'static,
149 {
150 self.format_template = Some(Arc::new(f));
151 self
152 }
153}
154
155pub struct StructuredOutputParser<T> {
172 config: StructuredOutputConfig,
173 _t: PhantomData<fn() -> T>,
174}
175
176impl<T> Clone for StructuredOutputParser<T> {
177 fn clone(&self) -> Self {
178 Self {
179 config: self.config.clone(),
180 _t: PhantomData,
181 }
182 }
183}
184
185impl<T> Default for StructuredOutputParser<T> {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191impl<T> StructuredOutputParser<T> {
192 pub fn new() -> Self {
194 Self {
195 config: StructuredOutputConfig::default(),
196 _t: PhantomData,
197 }
198 }
199
200 pub fn with_config(config: StructuredOutputConfig) -> Self {
202 Self {
203 config,
204 _t: PhantomData,
205 }
206 }
207
208 pub fn config(&self) -> &StructuredOutputConfig {
210 &self.config
211 }
212}
213
214impl<T> StructuredOutputParser<T>
215where
216 T: JsonSchema,
217{
218 pub fn schema_string(&self) -> String {
222 let schema = schemars::schema_for!(T);
223 serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string())
224 }
225}
226
227impl<T> OutputParser<T> for StructuredOutputParser<T>
228where
229 T: DeserializeOwned + JsonSchema + Send + 'static,
230{
231 fn parse(&self, text: &str) -> Result<T> {
232 let trimmed = text.trim();
233 let candidate: &str = match &self.config.extraction {
234 JsonExtraction::Strict => trimmed,
235 JsonExtraction::FirstBalanced => extract_first_balanced(trimmed).ok_or_else(|| {
236 CognisError::Serialization(
237 "structured parser: no balanced JSON object/array found in output".into(),
238 )
239 })?,
240 JsonExtraction::Custom(extractor) => extractor.extract(trimmed).ok_or_else(|| {
241 CognisError::Serialization(
242 "structured parser: custom extractor returned None".into(),
243 )
244 })?,
245 };
246 serde_json::from_str(candidate)
247 .map_err(|e| CognisError::Serialization(format!("structured parser: deserialize: {e}")))
248 }
249
250 fn format_instructions(&self) -> Option<String> {
251 let schema = self.schema_string();
252 if let Some(tmpl) = &self.config.format_template {
253 return Some(tmpl(&schema));
254 }
255 Some(format!(
256 "Reply with a single JSON value that conforms to this schema. \
257 Do not include any prose, markdown fences, or commentary outside \
258 the JSON.\n\nSchema:\n{schema}"
259 ))
260 }
261}
262
263#[async_trait]
264impl<T> Runnable<String, T> for StructuredOutputParser<T>
265where
266 T: DeserializeOwned + JsonSchema + Send + 'static,
267{
268 async fn invoke(&self, input: String, _: RunnableConfig) -> Result<T> {
269 OutputParser::parse(self, &input)
270 }
271 fn name(&self) -> &str {
272 "StructuredOutputParser"
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Small target type shared by most tests.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct Answer {
        topic: String,
        steps: Vec<String>,
    }

    /// Parser for [`Answer`] with the default configuration.
    fn default_parser() -> StructuredOutputParser<Answer> {
        StructuredOutputParser::new()
    }

    #[test]
    fn parses_clean_json() {
        let parsed = default_parser()
            .parse(r#"{"topic":"rust","steps":["a","b"]}"#)
            .unwrap();
        let expected = Answer {
            topic: "rust".into(),
            steps: vec!["a".into(), "b".into()],
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn extracts_balanced_json_from_prose() {
        let text = [
            "Sure! Here is the answer:",
            r#"{"topic":"rust","steps":["x"]}"#,
            "Hope that helps!",
        ]
        .join("\n");
        let parsed = default_parser().parse(&text).unwrap();
        assert_eq!(parsed.topic, "rust");
    }

    #[test]
    fn handles_nested_braces() {
        #[derive(Deserialize, JsonSchema)]
        struct Wrap {
            outer: serde_json::Value,
        }

        let parser = StructuredOutputParser::<Wrap>::new();
        let parsed = parser
            .parse(r#"prelude {"outer":{"a":{"b":1}},"extra":"ignored"} suffix"#)
            .unwrap();
        assert_eq!(parsed.outer["a"]["b"], 1);
    }

    #[test]
    fn strict_mode_rejects_prose() {
        let strict = StructuredOutputConfig::new().with_extraction(JsonExtraction::Strict);
        let parser = StructuredOutputParser::<Answer>::with_config(strict);
        let outcome = parser.parse(r#"prelude {"topic":"x","steps":[]} suffix"#);
        assert!(matches!(
            outcome.unwrap_err(),
            CognisError::Serialization(_)
        ));
    }

    #[test]
    fn custom_extractor_used() {
        /// Pulls out whatever sits between `<json>` and `</json>`.
        struct TagExtractor;

        impl JsonExtractor for TagExtractor {
            fn extract<'a>(&self, text: &'a str) -> Option<&'a str> {
                let open_tag = "<json>";
                let from = text.find(open_tag)? + open_tag.len();
                let to = text.find("</json>")?;
                Some(&text[from..to])
            }
        }

        let extraction = JsonExtraction::Custom(Arc::new(TagExtractor));
        let parser = StructuredOutputParser::<Answer>::with_config(
            StructuredOutputConfig::new().with_extraction(extraction),
        );
        let parsed = parser
            .parse(r#"see <json>{"topic":"x","steps":[]}</json> done"#)
            .unwrap();
        assert_eq!(parsed.topic, "x");
    }

    #[test]
    fn format_instructions_includes_schema() {
        let parser = default_parser();
        let instructions = OutputParser::format_instructions(&parser).unwrap();
        for key in ["\"topic\"", "\"steps\""] {
            assert!(instructions.contains(key));
        }
    }

    #[test]
    fn custom_format_template_is_used() {
        let config = StructuredOutputConfig::new()
            .with_format_template(|schema| format!("<custom>{schema}</custom>"));
        let parser = StructuredOutputParser::<Answer>::with_config(config);
        let instructions = OutputParser::format_instructions(&parser).unwrap();
        assert!(instructions.starts_with("<custom>"));
        assert!(instructions.ends_with("</custom>"));
    }

    #[test]
    fn invalid_json_returns_serialization_error() {
        // No `{` or `[` at all, so extraction itself fails.
        let outcome = default_parser().parse("plain text, no JSON here");
        assert!(matches!(
            outcome.unwrap_err(),
            CognisError::Serialization(_)
        ));
    }

    #[test]
    fn ignores_braces_inside_strings() {
        let parsed = default_parser()
            .parse(r#"{"topic":"a {nested} b","steps":[]}"#)
            .unwrap();
        assert_eq!(parsed.topic, "a {nested} b");
    }
}