Skip to main content

brainwires_core/
output_parser.rs

1//! Structured output parsing for LLM responses
2//!
3//! Provides parsers that extract structured data from raw LLM text output.
//! Supports JSON extraction (from bare, fenced, or prose-embedded responses) and regex-based parsing.
5//!
6//! # Example
7//!
8//! ```rust
9//! use brainwires_core::output_parser::{JsonOutputParser, OutputParser};
10//! use serde::Deserialize;
11//!
12//! #[derive(Deserialize)]
13//! struct Review {
14//!     sentiment: String,
15//!     score: f32,
16//! }
17//!
18//! let parser = JsonOutputParser::<Review>::new();
19//! let raw = r#"Here's my analysis: {"sentiment": "positive", "score": 0.9}"#;
20//! let review = parser.parse(raw).unwrap();
21//! assert_eq!(review.sentiment, "positive");
22//! ```
23
24use anyhow::{Context, Result};
25use serde::de::DeserializeOwned;
26use std::marker::PhantomData;
27
28/// Trait for parsing structured output from LLM text responses.
29pub trait OutputParser: Send + Sync {
30    /// The output type produced by this parser.
31    type Output;
32
33    /// Parse the raw LLM response text into structured output.
34    fn parse(&self, text: &str) -> Result<Self::Output>;
35
36    /// Return format instructions to inject into the prompt.
37    ///
38    /// These instructions tell the LLM how to format its response so this
39    /// parser can extract structured data from it.
40    fn format_instructions(&self) -> String;
41}
42
/// Extracts JSON from LLM responses and deserializes it into `T`.
///
/// Handles common LLM quirks:
/// - JSON wrapped in markdown code fences
/// - JSON embedded in surrounding prose
/// - Leading/trailing whitespace
///
/// The parser never stores a `T`; the type parameter only selects the
/// deserialization target.
pub struct JsonOutputParser<T> {
    // `fn() -> T` instead of `T`: no `T` is ever held, so the parser's auto
    // traits (`Send`/`Sync`) and drop checking should not depend on `T`.
    // Variance (covariant in `T`) is the same as `PhantomData<T>`.
    _phantom: PhantomData<fn() -> T>,
}

impl<T> JsonOutputParser<T> {
    /// Create a new JSON output parser.
    ///
    /// This is a `const fn`, so a parser can also live in a `const` or `static`.
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for JsonOutputParser<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written rather than derived: `#[derive(Debug, Clone)]` would add
// unnecessary `T: Debug` / `T: Clone` bounds for a type that holds no `T`.
impl<T> std::fmt::Debug for JsonOutputParser<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonOutputParser").finish()
    }
}

impl<T> Clone for JsonOutputParser<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}
67
68impl<T: DeserializeOwned + Send + Sync> OutputParser for JsonOutputParser<T> {
69    type Output = T;
70
71    fn parse(&self, text: &str) -> Result<T> {
72        let json_str = extract_json(text).context("No JSON found in LLM response")?;
73        serde_json::from_str(&json_str).context("Failed to parse JSON from LLM response")
74    }
75
76    fn format_instructions(&self) -> String {
77        "Respond with valid JSON only. Do not include any other text before or after the JSON."
78            .to_string()
79    }
80}
81
/// Extracts a list of items from a JSON array in the LLM response.
///
/// The parser never stores a `T`; the type parameter only selects the element
/// type each array entry is deserialized into.
pub struct JsonListParser<T> {
    // `fn() -> T` instead of `T`: no `T` is ever held, so the parser's auto
    // traits (`Send`/`Sync`) and drop checking should not depend on `T`.
    // Variance (covariant in `T`) is the same as `PhantomData<T>`.
    _phantom: PhantomData<fn() -> T>,
}

impl<T> JsonListParser<T> {
    /// Create a new JSON list parser.
    ///
    /// This is a `const fn`, so a parser can also live in a `const` or `static`.
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for JsonListParser<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written rather than derived: `#[derive(Debug, Clone)]` would add
// unnecessary `T: Debug` / `T: Clone` bounds for a type that holds no `T`.
impl<T> std::fmt::Debug for JsonListParser<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonListParser").finish()
    }
}

impl<T> Clone for JsonListParser<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}
101
102impl<T: DeserializeOwned + Send + Sync> OutputParser for JsonListParser<T> {
103    type Output = Vec<T>;
104
105    fn parse(&self, text: &str) -> Result<Vec<T>> {
106        let json_str = extract_json(text).context("No JSON array found in LLM response")?;
107        serde_json::from_str(&json_str).context("Failed to parse JSON array from LLM response")
108    }
109
110    fn format_instructions(&self) -> String {
111        "Respond with a valid JSON array only. Do not include any other text.".to_string()
112    }
113}
114
115/// Parses LLM output using a regex pattern with named capture groups.
116pub struct RegexOutputParser {
117    pattern: regex::Regex,
118}
119
120impl RegexOutputParser {
121    /// Create a new regex parser.
122    ///
123    /// The pattern should use named capture groups like `(?P<name>...)`.
124    pub fn new(pattern: &str) -> Result<Self> {
125        let regex = regex::Regex::new(pattern).context("Invalid regex pattern")?;
126        Ok(Self { pattern: regex })
127    }
128}
129
130impl OutputParser for RegexOutputParser {
131    type Output = std::collections::HashMap<String, String>;
132
133    fn parse(&self, text: &str) -> Result<Self::Output> {
134        let caps = self
135            .pattern
136            .captures(text)
137            .context("Regex pattern did not match LLM output")?;
138
139        let mut result = std::collections::HashMap::new();
140        for name in self.pattern.capture_names().flatten() {
141            if let Some(m) = caps.name(name) {
142                result.insert(name.to_string(), m.as_str().to_string());
143            }
144        }
145        Ok(result)
146    }
147
148    fn format_instructions(&self) -> String {
149        format!(
150            "Format your response to match this pattern: {}",
151            self.pattern.as_str()
152        )
153    }
154}
155
/// Extract a JSON object or array from text that may contain markdown fences or
/// surrounding prose.
///
/// Candidates are tried in this order:
/// 1. The body of the first markdown code fence (three backticks, optionally
///    followed by a language tag such as `json`). The fence is ignored when its
///    opening backticks lie inside the first balanced JSON span, i.e. when they
///    are part of a JSON string value rather than a wrapper around the JSON.
/// 2. The first balanced `{...}` / `[...]` span in the text. Brackets inside JSON
///    string literals (including `\"` escapes) do not affect nesting. Trailing
///    prose such as `(see {x})`, or a second JSON value, is therefore not
///    swallowed into the result.
///
/// A response that is entirely one JSON object/array comes back unchanged
/// (after trimming), because its balanced span covers the whole text.
///
/// The result is not validated as JSON; callers deserialize it and report
/// errors. Returns `None` when no candidate is found.
fn extract_json(text: &str) -> Option<String> {
    let trimmed = text.trim();
    let span = first_balanced_span(trimmed);

    if let Some((fence_at, body)) = first_fenced_block(trimmed) {
        // Backticks inside the JSON span belong to a string value, not a fence.
        let fence_is_json_content = matches!(&span, Some(s) if s.contains(&fence_at));
        if !fence_is_json_content {
            return Some(body.to_string());
        }
    }

    span.map(|s| trimmed[s].to_string())
}

/// Byte offset of the first markdown code fence in `text`, plus its trimmed body.
///
/// The opening line is dropped as a language tag (e.g. `json`), unless it
/// already starts with `{` or `[`, meaning the JSON begins right after the
/// backticks. Returns `None` when there is no closing fence or the body is empty.
fn first_fenced_block(text: &str) -> Option<(usize, &str)> {
    const FENCE: &str = "```";
    let open = text.find(FENCE)?;
    let after_open = &text[open + FENCE.len()..];
    // Look for the closing fence first, so tag handling never reads past it.
    let inner = &after_open[..after_open.find(FENCE)?];
    let body = match inner.split_once('\n') {
        Some((first_line, rest))
            if !first_line
                .trim_start()
                .starts_with(|c: char| c == '{' || c == '[') =>
        {
            rest
        }
        _ => inner,
    };
    let body = body.trim();
    if body.is_empty() {
        None
    } else {
        Some((open, body))
    }
}

/// Byte range of the first balanced `{...}` or `[...]` span in `text`.
///
/// Runs in a single linear pass. A candidate starts at a `{`/`[` seen while no
/// candidate is open. Inside a candidate, `"`-delimited strings (honoring
/// backslash escapes) are skipped, so brackets in string values are ignored.
///
/// If a candidate hits a mismatched closer, or the text ends before the
/// candidate closes, the earliest nested span that balanced inside it is
/// returned. Failing that, a mismatch restarts the search after the offending
/// bracket, and an unterminated candidate yields `None`.
fn first_balanced_span(text: &str) -> Option<std::ops::Range<usize>> {
    // Expected closing byte and start offset of every currently open bracket.
    let mut open: Vec<(u8, usize)> = Vec::new();
    // Earliest-starting nested span that balanced inside the current candidate.
    let mut nested: Option<std::ops::Range<usize>> = None;
    let mut in_string = false;
    let mut escaped = false;

    // Byte-wise scanning is UTF-8 safe: every byte acted on is ASCII, and ASCII
    // bytes never occur inside multi-byte sequences, so returned offsets are
    // always char boundaries.
    for (i, &byte) in text.as_bytes().iter().enumerate() {
        if open.is_empty() {
            match byte {
                b'{' => open.push((b'}', i)),
                b'[' => open.push((b']', i)),
                _ => {}
            }
            continue;
        }
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }
        match byte {
            b'"' => in_string = true,
            b'{' => open.push((b'}', i)),
            b'[' => open.push((b']', i)),
            b'}' | b']' => match open.last() {
                Some(&(expected, start)) if expected == byte => {
                    open.pop();
                    let span = start..i + 1;
                    if open.is_empty() {
                        return Some(span);
                    }
                    let is_earlier = match &nested {
                        Some(best) => span.start < best.start,
                        None => true,
                    };
                    if is_earlier {
                        nested = Some(span);
                    }
                }
                _ => {
                    // Wrong closer: every bracket still open in this candidate
                    // can no longer balance.
                    if nested.is_some() {
                        return nested;
                    }
                    open.clear();
                }
            },
            _ => {}
        }
    }
    nested
}
205
#[cfg(test)]
mod tests {
    // Unit tests for the JSON and regex output parsers and the JSON extractor,
    // covering both the success paths and the error paths.
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct TestStruct {
        name: String,
        value: i32,
    }

    #[test]
    fn test_json_parser_clean() {
        let parser = JsonOutputParser::<TestStruct>::new();
        let result = parser.parse(r#"{"name": "test", "value": 42}"#).unwrap();
        assert_eq!(result.name, "test");
        assert_eq!(result.value, 42);
    }

    #[test]
    fn test_json_parser_with_prose() {
        let parser = JsonOutputParser::<TestStruct>::new();
        let input = r#"Here is the result: {"name": "test", "value": 42} Hope that helps!"#;
        let result = parser.parse(input).unwrap();
        assert_eq!(result.name, "test");
        assert_eq!(result.value, 42);
    }

    #[test]
    fn test_json_parser_with_code_fence() {
        let parser = JsonOutputParser::<TestStruct>::new();
        let input = "Here's the JSON:\n```json\n{\"name\": \"test\", \"value\": 42}\n```";
        let result = parser.parse(input).unwrap();
        assert_eq!(result.name, "test");
        assert_eq!(result.value, 42);
    }

    #[test]
    fn test_json_parser_type_mismatch() {
        // Valid JSON, but `value` is not an i32: must surface as an error.
        let parser = JsonOutputParser::<TestStruct>::new();
        assert!(parser
            .parse(r#"{"name": "test", "value": "forty-two"}"#)
            .is_err());
    }

    #[test]
    fn test_json_list_parser() {
        let parser = JsonListParser::<TestStruct>::new();
        let input = r#"[{"name": "a", "value": 1}, {"name": "b", "value": 2}]"#;
        let result = parser.parse(input).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].name, "a");
        assert_eq!(result[1].name, "b");
    }

    #[test]
    fn test_json_list_parser_with_code_fence() {
        let parser = JsonListParser::<TestStruct>::new();
        let input = "```json\n[{\"name\": \"a\", \"value\": 1}]\n```";
        let result = parser.parse(input).unwrap();
        assert_eq!(
            result,
            vec![TestStruct {
                name: "a".to_string(),
                value: 1
            }]
        );
    }

    #[test]
    fn test_json_list_parser_rejects_object() {
        let parser = JsonListParser::<TestStruct>::new();
        assert!(parser.parse(r#"{"name": "a", "value": 1}"#).is_err());
    }

    #[test]
    fn test_regex_parser() {
        let parser =
            RegexOutputParser::new(r"sentiment: (?P<sentiment>\w+), score: (?P<score>[\d.]+)")
                .unwrap();
        let result = parser
            .parse("The sentiment: positive, score: 0.95 overall")
            .unwrap();
        assert_eq!(result["sentiment"], "positive");
        assert_eq!(result["score"], "0.95");
    }

    #[test]
    fn test_regex_parser_no_match() {
        let parser = RegexOutputParser::new(r"score: (?P<score>\d+)").unwrap();
        assert!(parser.parse("no score given").is_err());
    }

    #[test]
    fn test_regex_parser_invalid_pattern() {
        assert!(RegexOutputParser::new(r"(?P<unclosed>abc").is_err());
    }

    #[test]
    fn test_regex_parser_skips_unmatched_groups() {
        // Only groups that participated in the match appear in the map.
        let parser = RegexOutputParser::new(r"(?P<yes>yes)|(?P<no>no)").unwrap();
        let result = parser.parse("the answer is no").unwrap();
        assert_eq!(result.get("no").map(String::as_str), Some("no"));
        assert!(!result.contains_key("yes"));
    }

    #[test]
    fn test_json_parser_no_json() {
        let parser = JsonOutputParser::<TestStruct>::new();
        assert!(parser.parse("no json here at all").is_err());
    }

    #[test]
    fn test_format_instructions() {
        let parser = JsonOutputParser::<TestStruct>::new();
        let instructions = parser.format_instructions();
        assert!(instructions.contains("JSON"));
    }

    #[test]
    fn test_list_format_instructions() {
        let parser = JsonListParser::<TestStruct>::new();
        assert!(parser.format_instructions().contains("JSON array"));
    }

    #[test]
    fn test_regex_format_instructions_include_pattern() {
        let parser = RegexOutputParser::new(r"answer: (?P<answer>\w+)").unwrap();
        assert!(parser
            .format_instructions()
            .contains(r"answer: (?P<answer>\w+)"));
    }

    #[test]
    fn test_extract_json_array_in_prose() {
        let input = r#"Here are the items: [1, 2, 3] done."#;
        let result = extract_json(input).unwrap();
        assert_eq!(result, "[1, 2, 3]");
    }

    #[test]
    fn test_extract_json_plain_fence() {
        assert_eq!(
            extract_json("```\n{\"a\": 1}\n```").as_deref(),
            Some("{\"a\": 1}")
        );
    }

    #[test]
    fn test_extract_json_blank_input() {
        assert_eq!(extract_json("   "), None);
    }
}