// cognis_core/output_parsers/fixing.rs
use std::marker::PhantomData;
use std::sync::Arc;

use async_trait::async_trait;

use crate::output_parsers::OutputParser;
use crate::runnable::{Runnable, RunnableConfig};
use crate::{CognisError, Result};
19
20pub struct OutputFixingParser<T, P> {
24 inner: P,
25 fixer: Arc<dyn Runnable<String, String>>,
26 _marker: PhantomData<fn() -> T>,
27}
28
29impl<T, P> OutputFixingParser<T, P>
30where
31 T: Send + 'static,
32 P: OutputParser<T>,
33{
34 pub fn new(inner: P, fixer: Arc<dyn Runnable<String, String>>) -> Self {
36 Self {
37 inner,
38 fixer,
39 _marker: PhantomData,
40 }
41 }
42
43 pub async fn parse_with_fix(&self, text: &str) -> Result<T> {
45 match self.inner.parse(text) {
46 Ok(v) => Ok(v),
47 Err(parse_err) => {
48 let format_hint = self
49 .inner
50 .format_instructions()
51 .unwrap_or_else(|| "Return only the requested format.".to_string());
52 let prompt = format!(
53 "The previous output failed to parse with error:\n{parse_err}\n\n\
54 Previous output:\n{text}\n\n\
55 Format requirements:\n{format_hint}\n\n\
56 Return a corrected version. Output ONLY the corrected content — no \
57 explanations, no markdown fences."
58 );
59 let fixed = self.fixer.invoke(prompt, RunnableConfig::default()).await?;
60 self.inner.parse(&fixed)
61 }
62 }
63 }
64}
65
66impl<T, P> OutputParser<T> for OutputFixingParser<T, P>
67where
68 T: Send + 'static,
69 P: OutputParser<T>,
70{
71 fn parse(&self, text: &str) -> Result<T> {
72 self.inner.parse(text)
74 }
75 fn format_instructions(&self) -> Option<String> {
76 self.inner.format_instructions()
77 }
78}
79
80#[async_trait]
81impl<T, P> Runnable<String, T> for OutputFixingParser<T, P>
82where
83 T: Send + 'static,
84 P: OutputParser<T> + Send + Sync,
85{
86 async fn invoke(&self, input: String, _config: RunnableConfig) -> Result<T> {
87 self.parse_with_fix(&input).await
88 }
89 fn name(&self) -> &str {
90 "OutputFixingParser"
91 }
92}
93
94pub struct RetryParser<T, P> {
98 inner: P,
99 fixer: Arc<dyn Runnable<String, String>>,
100 max_retries: usize,
101 _marker: PhantomData<fn() -> T>,
102}
103
104impl<T, P> RetryParser<T, P>
105where
106 T: Send + 'static,
107 P: OutputParser<T>,
108{
109 pub fn new(inner: P, fixer: Arc<dyn Runnable<String, String>>) -> Self {
111 Self::with_retries(inner, fixer, 3)
112 }
113
114 pub fn with_retries(
117 inner: P,
118 fixer: Arc<dyn Runnable<String, String>>,
119 max_retries: usize,
120 ) -> Self {
121 Self {
122 inner,
123 fixer,
124 max_retries,
125 _marker: PhantomData,
126 }
127 }
128
129 pub async fn parse_with_retries(&self, text: &str) -> Result<T> {
131 let mut current = text.to_string();
132 let mut last_err: Option<CognisError> = None;
133 for _ in 0..=self.max_retries {
134 match self.inner.parse(¤t) {
135 Ok(v) => return Ok(v),
136 Err(e) => {
137 last_err = Some(e);
138 if self.max_retries == 0 {
139 break;
140 }
141 let format_hint = self
142 .inner
143 .format_instructions()
144 .unwrap_or_else(|| "Return only the requested format.".to_string());
145 let prompt = format!(
146 "Previous output failed to parse: {}\n\n\
147 Previous output:\n{current}\n\n\
148 Format requirements:\n{format_hint}\n\n\
149 Return a corrected version. Output ONLY the corrected content.",
150 last_err.as_ref().unwrap()
151 );
152 current = self.fixer.invoke(prompt, RunnableConfig::default()).await?;
153 }
154 }
155 }
156 Err(last_err
157 .unwrap_or_else(|| CognisError::Internal("RetryParser exhausted retries".into())))
158 }
159}
160
161impl<T, P> OutputParser<T> for RetryParser<T, P>
162where
163 T: Send + 'static,
164 P: OutputParser<T>,
165{
166 fn parse(&self, text: &str) -> Result<T> {
167 self.inner.parse(text)
168 }
169 fn format_instructions(&self) -> Option<String> {
170 self.inner.format_instructions()
171 }
172}
173
174#[async_trait]
175impl<T, P> Runnable<String, T> for RetryParser<T, P>
176where
177 T: Send + 'static,
178 P: OutputParser<T> + Send + Sync,
179{
180 async fn invoke(&self, input: String, _config: RunnableConfig) -> Result<T> {
181 self.parse_with_retries(&input).await
182 }
183 fn name(&self) -> &str {
184 "RetryParser"
185 }
186}
187
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::*;
    use crate::compose::lambda;
    use crate::output_parsers::JsonParser;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    /// Fixer that ignores its prompt and always answers with `value`.
    fn fixer_returns(value: &'static str) -> Arc<dyn Runnable<String, String>> {
        let v = value.to_string();
        Arc::new(lambda(move |_: String| {
            let v = v.clone();
            async move { Ok::<_, CognisError>(v) }
        }))
    }

    /// Fixer that always answers with `value` and bumps `calls` on every invocation.
    fn counting_fixer(
        value: &'static str,
        calls: Arc<AtomicUsize>,
    ) -> Arc<dyn Runnable<String, String>> {
        Arc::new(lambda(move |_: String| {
            let calls = calls.clone();
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CognisError>(value.to_string())
            }
        }))
    }

    #[tokio::test]
    async fn fixing_parser_repairs_invalid_json() {
        let parser = OutputFixingParser::new(
            JsonParser::<Person>::new(),
            fixer_returns(r#"{"name":"Ada","age":36}"#),
        );
        let bad = r#"{name: Ada, age: 36"#;
        let p = parser.parse_with_fix(bad).await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Ada".into(),
                age: 36
            }
        );
    }

    #[tokio::test]
    async fn fixing_parser_passes_through_valid() {
        let calls = Arc::new(AtomicUsize::new(0));
        let fixer = counting_fixer(r#"{"name":"X","age":0}"#, calls.clone());
        let parser = OutputFixingParser::new(JsonParser::<Person>::new(), fixer);
        let good = r#"{"name":"Bob","age":42}"#;
        let p = parser.parse_with_fix(good).await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Bob".into(),
                age: 42
            }
        );
        assert_eq!(
            calls.load(Ordering::Relaxed),
            0,
            "fixer must not be called for valid input"
        );
    }

    #[tokio::test]
    async fn fixing_parser_errors_when_fix_is_still_invalid() {
        let parser: OutputFixingParser<Person, _> =
            OutputFixingParser::new(JsonParser::<Person>::new(), fixer_returns("still not json"));
        assert!(
            parser.parse_with_fix("garbage").await.is_err(),
            "a fix that still fails to parse must surface an error"
        );
    }

    #[tokio::test]
    async fn retry_parser_loops_until_valid() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let a = attempts.clone();
        let fixer: Arc<dyn Runnable<String, String>> = Arc::new(lambda(move |_: String| {
            let a = a.clone();
            async move {
                let n = a.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CognisError>(if n < 2 {
                    "still invalid".into()
                } else {
                    r#"{"name":"Eve","age":29}"#.into()
                })
            }
        }));
        let parser = RetryParser::with_retries(JsonParser::<Person>::new(), fixer, 5);
        let p = parser.parse_with_retries("garbage").await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Eve".into(),
                age: 29
            }
        );
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn retry_parser_returns_last_error_after_exhaustion() {
        let fixer = fixer_returns("still bad");
        let parser = RetryParser::with_retries(JsonParser::<Person>::new(), fixer, 2);
        let err = parser.parse_with_retries("garbage").await.unwrap_err();
        assert!(
            !err.to_string().contains("exhausted"),
            "expected a real parse error, got: {err}"
        );
    }

    #[tokio::test]
    async fn retry_parser_with_zero_retries_never_calls_fixer() {
        let calls = Arc::new(AtomicUsize::new(0));
        let parser: RetryParser<Person, _> = RetryParser::with_retries(
            JsonParser::<Person>::new(),
            counting_fixer(r#"{"name":"X","age":0}"#, calls.clone()),
            0,
        );
        assert!(parser.parse_with_retries("garbage").await.is_err());
        assert_eq!(
            calls.load(Ordering::Relaxed),
            0,
            "max_retries = 0 must never invoke the fixer"
        );
    }
}