atomr_agents_parser/
auto_repair.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10use atomr_agents_core::{AgentError, Result};
11
12use crate::Parser;
13
14#[async_trait]
15pub trait RepairModel: Send + Sync + 'static {
16 async fn repair(&self, original: &str, hint: &str) -> Result<String>;
19}
20
21pub struct OutputFixingParser<P, T>
22where
23 P: Parser<T> + 'static,
24 T: Send + 'static,
25{
26 pub inner: Arc<P>,
27 pub model: Arc<dyn RepairModel>,
28 pub max_attempts: u32,
29 _marker: std::marker::PhantomData<fn() -> T>,
30}
31
32impl<P, T> OutputFixingParser<P, T>
33where
34 P: Parser<T> + 'static,
35 T: Send + 'static,
36{
37 pub fn new(inner: P, model: Arc<dyn RepairModel>, max_attempts: u32) -> Self {
38 Self {
39 inner: Arc::new(inner),
40 model,
41 max_attempts,
42 _marker: std::marker::PhantomData,
43 }
44 }
45}
46
47#[async_trait]
48impl<P, T> Parser<T> for OutputFixingParser<P, T>
49where
50 P: Parser<T> + 'static,
51 T: Send + 'static,
52{
53 async fn parse(&self, raw: &str) -> Result<T> {
54 let mut last_err = None;
55 let mut current = raw.to_string();
56 let instructions = self.inner.format_instructions();
57 for _ in 0..self.max_attempts.max(1) {
58 match self.inner.parse(¤t).await {
59 Ok(v) => return Ok(v),
60 Err(e) => {
61 last_err = Some(e);
62 let hint = format!(
63 "Output below failed format instructions. Re-emit corrected output.\n\nFormat:\n{instructions}\n\nFailed output:\n{current}"
64 );
65 current = self.model.repair(¤t, &hint).await?;
66 }
67 }
68 }
69 Err(last_err.unwrap_or_else(|| AgentError::Internal("repair exhausted".into())))
70 }
71 fn format_instructions(&self) -> String {
72 self.inner.format_instructions()
73 }
74}
75
76pub struct RetryWithErrorParser<P, T>
77where
78 P: Parser<T> + 'static,
79 T: Send + 'static,
80{
81 pub inner: Arc<P>,
82 pub model: Arc<dyn RepairModel>,
83 pub max_attempts: u32,
84 pub original_prompt: String,
86 _marker: std::marker::PhantomData<fn() -> T>,
87}
88
89impl<P, T> RetryWithErrorParser<P, T>
90where
91 P: Parser<T> + 'static,
92 T: Send + 'static,
93{
94 pub fn new(
95 inner: P,
96 model: Arc<dyn RepairModel>,
97 max_attempts: u32,
98 original_prompt: impl Into<String>,
99 ) -> Self {
100 Self {
101 inner: Arc::new(inner),
102 model,
103 max_attempts,
104 original_prompt: original_prompt.into(),
105 _marker: std::marker::PhantomData,
106 }
107 }
108}
109
110#[async_trait]
111impl<P, T> Parser<T> for RetryWithErrorParser<P, T>
112where
113 P: Parser<T> + 'static,
114 T: Send + 'static,
115{
116 async fn parse(&self, raw: &str) -> Result<T> {
117 let mut current = raw.to_string();
118 let mut last_err = None;
119 for _ in 0..self.max_attempts.max(1) {
120 match self.inner.parse(¤t).await {
121 Ok(v) => return Ok(v),
122 Err(e) => {
123 let hint = format!(
124 "Original prompt:\n{}\n\nError on previous output:\n{e}\n\nReply again, conforming to the prompt.",
125 self.original_prompt
126 );
127 last_err = Some(e);
128 current = self.model.repair(¤t, &hint).await?;
129 }
130 }
131 }
132 Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted".into())))
133 }
134 fn format_instructions(&self) -> String {
135 self.inner.format_instructions()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basic::JsonParser;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Repair model that hands out pre-scripted replies in order and errors
    /// once they run out.
    struct ScriptedRepair {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl RepairModel for ScriptedRepair {
        async fn repair(&self, _original: &str, _hint: &str) -> Result<String> {
            let mut queue = self.replies.lock();
            if queue.is_empty() {
                Err(AgentError::Inference("no scripted reply".into()))
            } else {
                // Replies are consumed front-to-back.
                Ok(queue.remove(0))
            }
        }
    }

    #[tokio::test]
    async fn output_fixing_recovers_after_one_repair() {
        let scripted = vec![r#"{"ok": true}"#.to_string()];
        let model = Arc::new(ScriptedRepair {
            replies: Mutex::new(scripted),
        });
        let parser: OutputFixingParser<JsonParser, Value> =
            OutputFixingParser::new(JsonParser, model, 3);

        let parsed = parser.parse("not json at all").await.unwrap();

        assert_eq!(parsed, serde_json::json!({"ok": true}));
    }

    #[tokio::test]
    async fn retry_with_error_re_prompts_with_failure() {
        // The first reply is still invalid, so a second repair round is needed.
        let scripted = vec!["still bad".to_string(), r#"{"ok": true}"#.to_string()];
        let model = Arc::new(ScriptedRepair {
            replies: Mutex::new(scripted),
        });
        let parser: RetryWithErrorParser<JsonParser, Value> =
            RetryWithErrorParser::new(JsonParser, model, 5, "Reply with JSON.");

        let parsed = parser.parse("nope").await.unwrap();

        assert_eq!(parsed, serde_json::json!({"ok": true}));
    }
}