brainwires_core/
output_parser.rs1use anyhow::{Context, Result};
25use serde::de::DeserializeOwned;
26use std::marker::PhantomData;
27
28pub trait OutputParser: Send + Sync {
30 type Output;
32
33 fn parse(&self, text: &str) -> Result<Self::Output>;
35
36 fn format_instructions(&self) -> String;
41}
42
/// Parses an LLM response into `T`.
///
/// It extracts a JSON value from the response text and deserializes it with
/// serde. The parser stores no data: `T` only selects the deserialization
/// target, so the type is zero-sized and freely copyable.
pub struct JsonOutputParser<T> {
    // `fn() -> T` instead of `T`: the parser never owns a `T`. It therefore
    // should not inherit `T`'s `Send`/`Sync` or drop-check behaviour.
    _phantom: PhantomData<fn() -> T>,
}

impl<T> JsonOutputParser<T> {
    /// Creates a parser that deserializes into `T`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for JsonOutputParser<T> {
    /// Equivalent to [`JsonOutputParser::new`].
    fn default() -> Self {
        Self::new()
    }
}

// These impls are hand-written rather than derived: a derive would add
// unnecessary `T: Clone` / `T: Copy` / `T: Debug` bounds.
impl<T> Clone for JsonOutputParser<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonOutputParser<T> {}

impl<T> std::fmt::Debug for JsonOutputParser<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonOutputParser")
            .field("target", &std::any::type_name::<T>())
            .finish()
    }
}
67
68impl<T: DeserializeOwned + Send + Sync> OutputParser for JsonOutputParser<T> {
69 type Output = T;
70
71 fn parse(&self, text: &str) -> Result<T> {
72 let json_str = extract_json(text).context("No JSON found in LLM response")?;
73 serde_json::from_str(&json_str).context("Failed to parse JSON from LLM response")
74 }
75
76 fn format_instructions(&self) -> String {
77 "Respond with valid JSON only. Do not include any other text before or after the JSON."
78 .to_string()
79 }
80}
81
/// Parses an LLM response containing a JSON array into `Vec<T>`.
///
/// The parser stores no data: `T` only selects the element type, so the type
/// is zero-sized and freely copyable.
pub struct JsonListParser<T> {
    // `fn() -> T` instead of `T`: the parser never owns a `T`. It therefore
    // should not inherit `T`'s `Send`/`Sync` or drop-check behaviour.
    _phantom: PhantomData<fn() -> T>,
}

impl<T> JsonListParser<T> {
    /// Creates a parser whose array elements deserialize into `T`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for JsonListParser<T> {
    /// Equivalent to [`JsonListParser::new`].
    fn default() -> Self {
        Self::new()
    }
}

// These impls are hand-written rather than derived: a derive would add
// unnecessary `T: Clone` / `T: Copy` / `T: Debug` bounds.
impl<T> Clone for JsonListParser<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonListParser<T> {}

impl<T> std::fmt::Debug for JsonListParser<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonListParser")
            .field("item", &std::any::type_name::<T>())
            .finish()
    }
}
101
102impl<T: DeserializeOwned + Send + Sync> OutputParser for JsonListParser<T> {
103 type Output = Vec<T>;
104
105 fn parse(&self, text: &str) -> Result<Vec<T>> {
106 let json_str = extract_json(text).context("No JSON array found in LLM response")?;
107 serde_json::from_str(&json_str).context("Failed to parse JSON array from LLM response")
108 }
109
110 fn format_instructions(&self) -> String {
111 "Respond with a valid JSON array only. Do not include any other text.".to_string()
112 }
113}
114
115pub struct RegexOutputParser {
117 pattern: regex::Regex,
118}
119
120impl RegexOutputParser {
121 pub fn new(pattern: &str) -> Result<Self> {
125 let regex = regex::Regex::new(pattern).context("Invalid regex pattern")?;
126 Ok(Self { pattern: regex })
127 }
128}
129
130impl OutputParser for RegexOutputParser {
131 type Output = std::collections::HashMap<String, String>;
132
133 fn parse(&self, text: &str) -> Result<Self::Output> {
134 let caps = self
135 .pattern
136 .captures(text)
137 .context("Regex pattern did not match LLM output")?;
138
139 let mut result = std::collections::HashMap::new();
140 for name in self.pattern.capture_names().flatten() {
141 if let Some(m) = caps.name(name) {
142 result.insert(name.to_string(), m.as_str().to_string());
143 }
144 }
145 Ok(result)
146 }
147
148 fn format_instructions(&self) -> String {
149 format!(
150 "Format your response to match this pattern: {}",
151 self.pattern.as_str()
152 )
153 }
154}
155
/// Extracts the most likely JSON payload from free-form LLM output.
///
/// Strategies, tried in order:
/// 1. The whole trimmed text, when it starts with `{`/`[` and ends with `}`/`]`.
/// 2. The body of the first Markdown code fence (three backticks). An info
///    string such as `json` on the opening line is skipped.
/// 3. The first balanced `{...}` or `[...]` span, starting at whichever of `{` /
///    `[` appears first. Brackets inside JSON string literals are ignored, so
///    trailing prose that happens to contain brackets is not swallowed.
///
/// Returns `None` when no candidate is found. The result is not validated;
/// callers deserialize it and report any error.
fn extract_json(text: &str) -> Option<String> {
    let trimmed = text.trim();

    if (trimmed.starts_with('{') && trimmed.ends_with('}'))
        || (trimmed.starts_with('[') && trimmed.ends_with(']'))
    {
        return Some(trimmed.to_string());
    }

    if let Some(fenced) = fenced_block(trimmed) {
        return Some(fenced.to_string());
    }

    let start = trimmed.find(|c: char| c == '{' || c == '[')?;
    let len = balanced_end(&trimmed[start..])?;
    Some(trimmed[start..start + len].to_string())
}

/// Returns the trimmed, non-empty body of the first closed code fence in `text`.
fn fenced_block(text: &str) -> Option<&str> {
    let open = text.find("```")?;
    let after_fence = &text[open + 3..];
    // The opening line normally carries an info string (e.g. `json`) and is
    // skipped. If it already starts the payload (`{` or `[`), keep it; otherwise
    // the payload would be truncated.
    let body_start = match after_fence.find('\n') {
        Some(nl) => match after_fence[..nl].trim_start().chars().next() {
            Some('{') | Some('[') => 0,
            _ => nl + 1,
        },
        None => 0,
    };
    let body = &after_fence[body_start..];
    let close = body.find("```")?;
    let payload = body[..close].trim();
    if payload.is_empty() {
        None
    } else {
        Some(payload)
    }
}

/// Given `s` starting with `{` or `[`, returns the byte length of the balanced
/// bracket span that begins at `s[0]`.
///
/// Returns `None` if the span never closes or a closing bracket does not match.
/// Bytes inside double-quoted strings (honouring `\` escapes) are skipped.
/// Scanning bytes is safe for UTF-8: multi-byte sequences never contain ASCII
/// bytes, so every bracket found lies on a char boundary.
fn balanced_end(s: &str) -> Option<usize> {
    let mut expected: Vec<u8> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;
    for (i, b) in s.bytes().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if b == b'\\' {
                escaped = true;
            } else if b == b'"' {
                in_string = false;
            }
            continue;
        }
        match b {
            b'"' => in_string = true,
            b'{' => expected.push(b'}'),
            b'[' => expected.push(b']'),
            b'}' | b']' => {
                if expected.pop() != Some(b) {
                    return None;
                }
                if expected.is_empty() {
                    return Some(i + 1);
                }
            }
            _ => {}
        }
    }
    None
}
205
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal deserialization target shared by the JSON parser tests.
    #[derive(Debug, Deserialize, PartialEq)]
    struct TestStruct {
        name: String,
        value: i32,
    }

    /// Shorthand for a fresh `JsonOutputParser<TestStruct>`.
    fn struct_parser() -> JsonOutputParser<TestStruct> {
        JsonOutputParser::new()
    }

    #[test]
    fn test_json_parser_clean() {
        let parsed = struct_parser()
            .parse(r#"{"name": "test", "value": 42}"#)
            .unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.value, 42);
    }

    #[test]
    fn test_json_parser_with_prose() {
        let response = r#"Here is the result: {"name": "test", "value": 42} Hope that helps!"#;
        let parsed = struct_parser().parse(response).unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.value, 42);
    }

    #[test]
    fn test_json_parser_with_code_fence() {
        let response = "Here's the JSON:\n```json\n{\"name\": \"test\", \"value\": 42}\n```";
        let parsed = struct_parser().parse(response).unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[test]
    fn test_json_list_parser() {
        let response = r#"[{"name": "a", "value": 1}, {"name": "b", "value": 2}]"#;
        let items = JsonListParser::<TestStruct>::new().parse(response).unwrap();
        assert_eq!(items.len(), 2);
        let names: Vec<&str> = items.iter().map(|item| item.name.as_str()).collect();
        assert_eq!(names[0], "a");
        assert_eq!(names[1], "b");
    }

    #[test]
    fn test_regex_parser() {
        let pattern = r"sentiment: (?P<sentiment>\w+), score: (?P<score>[\d.]+)";
        let fields = RegexOutputParser::new(pattern)
            .unwrap()
            .parse("The sentiment: positive, score: 0.95 overall")
            .unwrap();
        assert_eq!(fields["sentiment"], "positive");
        assert_eq!(fields["score"], "0.95");
    }

    #[test]
    fn test_json_parser_no_json() {
        let outcome = struct_parser().parse("no json here at all");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_format_instructions() {
        let text = struct_parser().format_instructions();
        assert!(text.contains("JSON"));
    }

    #[test]
    fn test_extract_json_array_in_prose() {
        let extracted = extract_json(r#"Here are the items: [1, 2, 3] done."#).unwrap();
        assert_eq!(extracted, "[1, 2, 3]");
    }
}