Skip to main content

cognis_core/output_parsers/
fixing.rs

1//! Self-correcting output parsers.
2//!
3//! [`OutputFixingParser`] wraps an inner [`OutputParser`] and, on parse
4//! failure, asks an LLM-like `Runnable<String, String>` to repair the
5//! malformed output, then re-parses once. [`RetryParser`] loops the
6//! repair-then-parse cycle up to `max_retries` times.
7//!
8//! Both parsers also implement `Runnable<String, T>` so they slot into
9//! a chain (`prompt | model | fixing_parser`).
10
11use std::marker::PhantomData;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use crate::output_parsers::OutputParser;
17use crate::runnable::{Runnable, RunnableConfig};
18use crate::{CognisError, Result};
19
20/// Single-shot fix: try the inner parser, and on failure ask `fixer`
21/// to rewrite the output, then parse the rewrite. If the second parse
22/// also fails, return the second error.
23pub struct OutputFixingParser<T, P> {
24    inner: P,
25    fixer: Arc<dyn Runnable<String, String>>,
26    _marker: PhantomData<fn() -> T>,
27}
28
29impl<T, P> OutputFixingParser<T, P>
30where
31    T: Send + 'static,
32    P: OutputParser<T>,
33{
34    /// Wrap `inner` with a single repair attempt via `fixer`.
35    pub fn new(inner: P, fixer: Arc<dyn Runnable<String, String>>) -> Self {
36        Self {
37            inner,
38            fixer,
39            _marker: PhantomData,
40        }
41    }
42
43    /// Async parse: try once, repair via `fixer` on failure, try once more.
44    pub async fn parse_with_fix(&self, text: &str) -> Result<T> {
45        match self.inner.parse(text) {
46            Ok(v) => Ok(v),
47            Err(parse_err) => {
48                let format_hint = self
49                    .inner
50                    .format_instructions()
51                    .unwrap_or_else(|| "Return only the requested format.".to_string());
52                let prompt = format!(
53                    "The previous output failed to parse with error:\n{parse_err}\n\n\
54                     Previous output:\n{text}\n\n\
55                     Format requirements:\n{format_hint}\n\n\
56                     Return a corrected version. Output ONLY the corrected content — no \
57                     explanations, no markdown fences."
58                );
59                let fixed = self.fixer.invoke(prompt, RunnableConfig::default()).await?;
60                self.inner.parse(&fixed)
61            }
62        }
63    }
64}
65
66impl<T, P> OutputParser<T> for OutputFixingParser<T, P>
67where
68    T: Send + 'static,
69    P: OutputParser<T>,
70{
71    fn parse(&self, text: &str) -> Result<T> {
72        // Sync path: cannot call the async fixer, so just defer.
73        self.inner.parse(text)
74    }
75    fn format_instructions(&self) -> Option<String> {
76        self.inner.format_instructions()
77    }
78}
79
80#[async_trait]
81impl<T, P> Runnable<String, T> for OutputFixingParser<T, P>
82where
83    T: Send + 'static,
84    P: OutputParser<T> + Send + Sync,
85{
86    async fn invoke(&self, input: String, _config: RunnableConfig) -> Result<T> {
87        self.parse_with_fix(&input).await
88    }
89    fn name(&self) -> &str {
90        "OutputFixingParser"
91    }
92}
93
94/// Looped fix: try the inner parser, ask `fixer` to rewrite on failure,
95/// repeat up to `max_retries` extra times. Returns the last parse error
96/// if every attempt fails.
97pub struct RetryParser<T, P> {
98    inner: P,
99    fixer: Arc<dyn Runnable<String, String>>,
100    max_retries: usize,
101    _marker: PhantomData<fn() -> T>,
102}
103
104impl<T, P> RetryParser<T, P>
105where
106    T: Send + 'static,
107    P: OutputParser<T>,
108{
109    /// Default to 3 repair attempts.
110    pub fn new(inner: P, fixer: Arc<dyn Runnable<String, String>>) -> Self {
111        Self::with_retries(inner, fixer, 3)
112    }
113
114    /// Customize the maximum number of repair attempts. `0` means no retries
115    /// — equivalent to calling `inner.parse(...)` directly.
116    pub fn with_retries(
117        inner: P,
118        fixer: Arc<dyn Runnable<String, String>>,
119        max_retries: usize,
120    ) -> Self {
121        Self {
122            inner,
123            fixer,
124            max_retries,
125            _marker: PhantomData,
126        }
127    }
128
129    /// Async parse with up to `max_retries` repair-then-parse cycles.
130    pub async fn parse_with_retries(&self, text: &str) -> Result<T> {
131        let mut current = text.to_string();
132        let mut last_err: Option<CognisError> = None;
133        for _ in 0..=self.max_retries {
134            match self.inner.parse(&current) {
135                Ok(v) => return Ok(v),
136                Err(e) => {
137                    last_err = Some(e);
138                    if self.max_retries == 0 {
139                        break;
140                    }
141                    let format_hint = self
142                        .inner
143                        .format_instructions()
144                        .unwrap_or_else(|| "Return only the requested format.".to_string());
145                    let prompt = format!(
146                        "Previous output failed to parse: {}\n\n\
147                         Previous output:\n{current}\n\n\
148                         Format requirements:\n{format_hint}\n\n\
149                         Return a corrected version. Output ONLY the corrected content.",
150                        last_err.as_ref().unwrap()
151                    );
152                    current = self.fixer.invoke(prompt, RunnableConfig::default()).await?;
153                }
154            }
155        }
156        Err(last_err
157            .unwrap_or_else(|| CognisError::Internal("RetryParser exhausted retries".into())))
158    }
159}
160
161impl<T, P> OutputParser<T> for RetryParser<T, P>
162where
163    T: Send + 'static,
164    P: OutputParser<T>,
165{
166    fn parse(&self, text: &str) -> Result<T> {
167        self.inner.parse(text)
168    }
169    fn format_instructions(&self) -> Option<String> {
170        self.inner.format_instructions()
171    }
172}
173
174#[async_trait]
175impl<T, P> Runnable<String, T> for RetryParser<T, P>
176where
177    T: Send + 'static,
178    P: OutputParser<T> + Send + Sync,
179{
180    async fn invoke(&self, input: String, _config: RunnableConfig) -> Result<T> {
181        self.parse_with_retries(&input).await
182    }
183    fn name(&self) -> &str {
184        "RetryParser"
185    }
186}
187
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::*;
    use crate::compose::lambda;
    use crate::output_parsers::JsonParser;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    /// Fixer that always replies with `value`, regardless of the prompt.
    fn fixer_returns(value: &'static str) -> Arc<dyn Runnable<String, String>> {
        let v = value.to_string();
        Arc::new(lambda(move |_: String| {
            let v = v.clone();
            async move { Ok::<_, CognisError>(v) }
        }))
    }

    /// Fixer that counts its invocations and replies with a valid `Person`.
    fn counting_fixer(calls: Arc<AtomicUsize>) -> Arc<dyn Runnable<String, String>> {
        Arc::new(lambda(move |_: String| {
            let c = calls.clone();
            async move {
                c.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CognisError>(String::from(r#"{"name":"X","age":0}"#))
            }
        }))
    }

    #[tokio::test]
    async fn fixing_parser_repairs_invalid_json() {
        let parser = OutputFixingParser::new(
            JsonParser::<Person>::new(),
            fixer_returns(r#"{"name":"Ada","age":36}"#),
        );
        let bad = r#"{name: Ada, age: 36"#; // malformed
        let p = parser.parse_with_fix(bad).await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Ada".into(),
                age: 36
            }
        );
    }

    #[tokio::test]
    async fn fixing_parser_passes_through_valid() {
        let calls = Arc::new(AtomicUsize::new(0));
        let parser = OutputFixingParser::new(JsonParser::<Person>::new(), counting_fixer(calls.clone()));
        let good = r#"{"name":"Bob","age":42}"#;
        let p = parser.parse_with_fix(good).await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Bob".into(),
                age: 42
            }
        );
        assert_eq!(
            calls.load(Ordering::Relaxed),
            0,
            "fixer must not be called for valid input"
        );
    }

    #[tokio::test]
    async fn fixing_parser_runnable_invoke_repairs() {
        let parser: OutputFixingParser<Person, _> = OutputFixingParser::new(
            JsonParser::<Person>::new(),
            fixer_returns(r#"{"name":"Ada","age":36}"#),
        );
        let p = parser
            .invoke("not json".to_string(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(
            p,
            Person {
                name: "Ada".into(),
                age: 36
            }
        );
    }

    #[tokio::test]
    async fn retry_parser_loops_until_valid() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let a = attempts.clone();
        let fixer: Arc<dyn Runnable<String, String>> = Arc::new(lambda(move |_: String| {
            let a = a.clone();
            async move {
                let n = a.fetch_add(1, Ordering::Relaxed);
                Ok::<_, CognisError>(if n < 2 {
                    "still invalid".into()
                } else {
                    r#"{"name":"Eve","age":29}"#.into()
                })
            }
        }));
        let parser = RetryParser::with_retries(JsonParser::<Person>::new(), fixer, 5);
        let p = parser.parse_with_retries("garbage").await.unwrap();
        assert_eq!(
            p,
            Person {
                name: "Eve".into(),
                age: 29
            }
        );
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn retry_parser_returns_last_error_after_exhaustion() {
        let fixer = fixer_returns("still bad");
        let parser = RetryParser::with_retries(JsonParser::<Person>::new(), fixer, 2);
        let err = parser.parse_with_retries("garbage").await.unwrap_err();
        // Should surface a parse error, not "exhausted retries".
        assert!(
            !err.to_string().contains("exhausted"),
            "expected a real parse error, got: {err}"
        );
    }

    #[tokio::test]
    async fn retry_parser_zero_retries_never_calls_fixer() {
        let calls = Arc::new(AtomicUsize::new(0));
        let parser: RetryParser<Person, _> =
            RetryParser::with_retries(JsonParser::<Person>::new(), counting_fixer(calls.clone()), 0);
        assert!(parser.parse_with_retries("garbage").await.is_err());
        assert_eq!(
            calls.load(Ordering::Relaxed),
            0,
            "max_retries = 0 must not call the fixer"
        );
    }
}