Skip to main content

atomr_agents_parser/
auto_repair.rs

1//! Auto-repair wrappers.
2//!
3//! `OutputFixingParser` calls a "repair" model with the malformed
4//! output + the parser's format instructions. `RetryWithErrorParser`
5//! re-prompts with the original prompt + the failure message.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_agents_core::{AgentError, Result};
11
12use crate::Parser;
13
14#[async_trait]
15pub trait RepairModel: Send + Sync + 'static {
16    /// Given the original output and a hint (instructions or error),
17    /// produce a corrected raw string.
18    async fn repair(&self, original: &str, hint: &str) -> Result<String>;
19}
20
21pub struct OutputFixingParser<P, T>
22where
23    P: Parser<T> + 'static,
24    T: Send + 'static,
25{
26    pub inner: Arc<P>,
27    pub model: Arc<dyn RepairModel>,
28    pub max_attempts: u32,
29    _marker: std::marker::PhantomData<fn() -> T>,
30}
31
32impl<P, T> OutputFixingParser<P, T>
33where
34    P: Parser<T> + 'static,
35    T: Send + 'static,
36{
37    pub fn new(inner: P, model: Arc<dyn RepairModel>, max_attempts: u32) -> Self {
38        Self {
39            inner: Arc::new(inner),
40            model,
41            max_attempts,
42            _marker: std::marker::PhantomData,
43        }
44    }
45}
46
47#[async_trait]
48impl<P, T> Parser<T> for OutputFixingParser<P, T>
49where
50    P: Parser<T> + 'static,
51    T: Send + 'static,
52{
53    async fn parse(&self, raw: &str) -> Result<T> {
54        let mut last_err = None;
55        let mut current = raw.to_string();
56        let instructions = self.inner.format_instructions();
57        for _ in 0..self.max_attempts.max(1) {
58            match self.inner.parse(&current).await {
59                Ok(v) => return Ok(v),
60                Err(e) => {
61                    last_err = Some(e);
62                    let hint = format!(
63                        "Output below failed format instructions. Re-emit corrected output.\n\nFormat:\n{instructions}\n\nFailed output:\n{current}"
64                    );
65                    current = self.model.repair(&current, &hint).await?;
66                }
67            }
68        }
69        Err(last_err.unwrap_or_else(|| AgentError::Internal("repair exhausted".into())))
70    }
71    fn format_instructions(&self) -> String {
72        self.inner.format_instructions()
73    }
74}
75
76pub struct RetryWithErrorParser<P, T>
77where
78    P: Parser<T> + 'static,
79    T: Send + 'static,
80{
81    pub inner: Arc<P>,
82    pub model: Arc<dyn RepairModel>,
83    pub max_attempts: u32,
84    /// The original prompt; passed to the repair model on each retry.
85    pub original_prompt: String,
86    _marker: std::marker::PhantomData<fn() -> T>,
87}
88
89impl<P, T> RetryWithErrorParser<P, T>
90where
91    P: Parser<T> + 'static,
92    T: Send + 'static,
93{
94    pub fn new(
95        inner: P,
96        model: Arc<dyn RepairModel>,
97        max_attempts: u32,
98        original_prompt: impl Into<String>,
99    ) -> Self {
100        Self {
101            inner: Arc::new(inner),
102            model,
103            max_attempts,
104            original_prompt: original_prompt.into(),
105            _marker: std::marker::PhantomData,
106        }
107    }
108}
109
110#[async_trait]
111impl<P, T> Parser<T> for RetryWithErrorParser<P, T>
112where
113    P: Parser<T> + 'static,
114    T: Send + 'static,
115{
116    async fn parse(&self, raw: &str) -> Result<T> {
117        let mut current = raw.to_string();
118        let mut last_err = None;
119        for _ in 0..self.max_attempts.max(1) {
120            match self.inner.parse(&current).await {
121                Ok(v) => return Ok(v),
122                Err(e) => {
123                    let hint = format!(
124                        "Original prompt:\n{}\n\nError on previous output:\n{e}\n\nReply again, conforming to the prompt.",
125                        self.original_prompt
126                    );
127                    last_err = Some(e);
128                    current = self.model.repair(&current, &hint).await?;
129                }
130            }
131        }
132        Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted".into())))
133    }
134    fn format_instructions(&self) -> String {
135        self.inner.format_instructions()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basic::JsonParser;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Repair model that replays canned replies in order and fails with
    /// `AgentError::Inference` once they run out.
    struct ScriptedRepair {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl RepairModel for ScriptedRepair {
        async fn repair(&self, _original: &str, _hint: &str) -> Result<String> {
            let mut g = self.replies.lock();
            if g.is_empty() {
                return Err(AgentError::Inference("no scripted reply".into()));
            }
            Ok(g.remove(0))
        }
    }

    /// Build a shared `ScriptedRepair` that will return `replies` in order.
    fn scripted(replies: &[&str]) -> Arc<ScriptedRepair> {
        Arc::new(ScriptedRepair {
            replies: Mutex::new(replies.iter().map(|s| s.to_string()).collect()),
        })
    }

    #[tokio::test]
    async fn output_fixing_recovers_after_one_repair() {
        let model = scripted(&[r#"{"ok": true}"#]);
        let p: OutputFixingParser<JsonParser, Value> = OutputFixingParser::new(JsonParser, model, 3);
        let v = p.parse("not json at all").await.unwrap();
        assert_eq!(v, serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn output_fixing_skips_model_for_valid_input() {
        // No scripted replies: any call to the model would fail the parse.
        let model = scripted(&[]);
        let p: OutputFixingParser<JsonParser, Value> = OutputFixingParser::new(JsonParser, model, 3);
        let v = p.parse(r#"{"ok": true}"#).await.unwrap();
        assert_eq!(v, serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn retry_with_error_re_prompts_with_failure() {
        let model = scripted(&["still bad", r#"{"ok": true}"#]);
        let p: RetryWithErrorParser<JsonParser, Value> =
            RetryWithErrorParser::new(JsonParser, model, 5, "Reply with JSON.");
        let v = p.parse("nope").await.unwrap();
        assert_eq!(v, serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn retry_with_error_propagates_repair_failure() {
        // The model errors on its first call; that error must surface unchanged.
        let model = scripted(&[]);
        let p: RetryWithErrorParser<JsonParser, Value> =
            RetryWithErrorParser::new(JsonParser, model, 2, "Reply with JSON.");
        let err = p.parse("nope").await.unwrap_err();
        assert!(matches!(err, AgentError::Inference(_)));
    }
}