use std::sync::Arc;
use async_trait::async_trait;
use atomr_agents_core::{AgentError, Result};
use crate::Parser;
/// A model that can be asked to rewrite a piece of output.
///
/// The parsers in this file call [`RepairModel::repair`] whenever the wrapped
/// parser rejects its input, and then re-parse whatever text comes back.
#[async_trait]
pub trait RepairModel: Send + Sync + 'static {
    /// Produces a replacement for `original`, guided by `hint`.
    ///
    /// `original` is the text that failed to parse; `hint` is the
    /// caller-built message describing what went wrong.
    ///
    /// # Errors
    ///
    /// Implementations report failure through the crate's `Result`; callers
    /// in this file propagate such an error unchanged.
    async fn repair(&self, original: &str, hint: &str) -> Result<String>;
}
/// Wraps a [`Parser`] and, when it fails, asks a [`RepairModel`] to re-emit
/// the output using the inner parser's format instructions as guidance.
pub struct OutputFixingParser<P, T>
where
    T: Send + 'static,
    P: Parser<T> + 'static,
{
    /// The parser whose failures are being repaired.
    pub inner: Arc<P>,
    /// Model asked to rewrite output that `inner` rejected.
    pub model: Arc<dyn RepairModel>,
    /// Upper bound on parse attempts; `parse` treats `0` as `1`.
    pub max_attempts: u32,
    /// Records the output type `T` without storing a value of it.
    _marker: std::marker::PhantomData<fn() -> T>,
}
impl<P, T> OutputFixingParser<P, T>
where
P: Parser<T> + 'static,
T: Send + 'static,
{
pub fn new(inner: P, model: Arc<dyn RepairModel>, max_attempts: u32) -> Self {
Self {
inner: Arc::new(inner),
model,
max_attempts,
_marker: std::marker::PhantomData,
}
}
}
#[async_trait]
impl<P, T> Parser<T> for OutputFixingParser<P, T>
where
    P: Parser<T> + 'static,
    T: Send + 'static,
{
    /// Parses `raw` with the inner parser, asking the repair model to rewrite
    /// the text between failed attempts.
    ///
    /// At most `max_attempts` parse attempts are made (a value of `0` is
    /// treated as `1`), counting the initial attempt on `raw`. Consequently
    /// `max_attempts == 1` never invokes the repair model.
    ///
    /// # Errors
    ///
    /// * The inner parser's error from the final attempt, if every attempt
    ///   fails.
    /// * Any error returned by the repair model, propagated immediately.
    async fn parse(&self, raw: &str) -> Result<T> {
        // The instructions do not change between attempts; fetch them once.
        let instructions = self.inner.format_instructions();
        // Always parse at least once, even if `max_attempts` is 0.
        let attempts = self.max_attempts.max(1);
        let mut candidate = raw.to_string();
        let mut last_err = None;

        for attempt in 1..=attempts {
            let err = match self.inner.parse(&candidate).await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };

            // Only ask for a repair when another parse attempt will consume
            // it. Repairing after the final attempt would waste a model call
            // and could replace the parse error with a repair error.
            if attempt < attempts {
                let hint = format!(
                    "Output below failed format instructions. Re-emit corrected output.\n\nFormat:\n{instructions}\n\nFailed output:\n{candidate}"
                );
                candidate = self.model.repair(&candidate, &hint).await?;
            }
            last_err = Some(err);
        }

        // The loop body runs at least once, so `last_err` is always set here;
        // the fallback only guards against future refactors.
        Err(last_err.unwrap_or_else(|| AgentError::Internal("repair exhausted".into())))
    }

    /// Returns the inner parser's format instructions unchanged.
    fn format_instructions(&self) -> String {
        self.inner.format_instructions()
    }
}
/// Wraps a [`Parser`] and, when it fails, asks a [`RepairModel`] to answer
/// again, showing it the original prompt and the parse error.
pub struct RetryWithErrorParser<P, T>
where
    T: Send + 'static,
    P: Parser<T> + 'static,
{
    /// The parser whose failures trigger a retry.
    pub inner: Arc<P>,
    /// Model asked to produce a new reply after a failed parse.
    pub model: Arc<dyn RepairModel>,
    /// Upper bound on parse attempts; `parse` treats `0` as `1`.
    pub max_attempts: u32,
    /// Prompt text embedded in the hint sent to the repair model.
    pub original_prompt: String,
    /// Records the output type `T` without storing a value of it.
    _marker: std::marker::PhantomData<fn() -> T>,
}
impl<P, T> RetryWithErrorParser<P, T>
where
P: Parser<T> + 'static,
T: Send + 'static,
{
pub fn new(
inner: P,
model: Arc<dyn RepairModel>,
max_attempts: u32,
original_prompt: impl Into<String>,
) -> Self {
Self {
inner: Arc::new(inner),
model,
max_attempts,
original_prompt: original_prompt.into(),
_marker: std::marker::PhantomData,
}
}
}
#[async_trait]
impl<P, T> Parser<T> for RetryWithErrorParser<P, T>
where
    P: Parser<T> + 'static,
    T: Send + 'static,
{
    /// Parses `raw` with the inner parser; between failed attempts, asks the
    /// repair model for a new reply, showing it the original prompt and the
    /// parse error.
    ///
    /// At most `max_attempts` parse attempts are made (a value of `0` is
    /// treated as `1`), counting the initial attempt on `raw`. Consequently
    /// `max_attempts == 1` never invokes the repair model.
    ///
    /// # Errors
    ///
    /// * The inner parser's error from the final attempt, if every attempt
    ///   fails.
    /// * Any error returned by the repair model, propagated immediately.
    async fn parse(&self, raw: &str) -> Result<T> {
        // Always parse at least once, even if `max_attempts` is 0.
        let attempts = self.max_attempts.max(1);
        let mut candidate = raw.to_string();
        let mut last_err = None;

        for attempt in 1..=attempts {
            let err = match self.inner.parse(&candidate).await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };

            // Only re-prompt when another parse attempt will consume the
            // reply. Re-prompting after the final attempt would waste a model
            // call and could replace the parse error with a model error.
            if attempt < attempts {
                let hint = format!(
                    "Original prompt:\n{}\n\nError on previous output:\n{err}\n\nReply again, conforming to the prompt.",
                    self.original_prompt
                );
                candidate = self.model.repair(&candidate, &hint).await?;
            }
            last_err = Some(err);
        }

        // The loop body runs at least once, so `last_err` is always set here;
        // the fallback only guards against future refactors.
        Err(last_err.unwrap_or_else(|| AgentError::Internal("retry exhausted".into())))
    }

    /// Returns the inner parser's format instructions unchanged.
    fn format_instructions(&self) -> String {
        self.inner.format_instructions()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basic::JsonParser;
    use atomr_agents_core::Value;
    use parking_lot::Mutex;

    /// Repair model that hands out pre-scripted replies in order and fails
    /// once the script is exhausted.
    struct ScriptedRepair {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl RepairModel for ScriptedRepair {
        async fn repair(&self, _original: &str, _hint: &str) -> Result<String> {
            let mut queue = self.replies.lock();
            if queue.is_empty() {
                Err(AgentError::Inference("no scripted reply".into()))
            } else {
                // Replies are consumed front-to-back.
                Ok(queue.remove(0))
            }
        }
    }

    /// A single scripted repair turns unparseable input into valid JSON.
    #[tokio::test]
    async fn output_fixing_recovers_after_one_repair() {
        let script: Vec<String> = vec![r#"{"ok": true}"#.to_string()];
        let model = Arc::new(ScriptedRepair {
            replies: Mutex::new(script),
        });
        let parser = OutputFixingParser::<JsonParser, Value>::new(JsonParser, model, 3);

        let parsed = parser.parse("not json at all").await.unwrap();

        assert_eq!(parsed, serde_json::json!({"ok": true}));
    }

    /// The first scripted reply is still invalid; the second one parses.
    #[tokio::test]
    async fn retry_with_error_re_prompts_with_failure() {
        let script: Vec<String> = vec!["still bad".to_string(), r#"{"ok": true}"#.to_string()];
        let model = Arc::new(ScriptedRepair {
            replies: Mutex::new(script),
        });
        let parser =
            RetryWithErrorParser::<JsonParser, Value>::new(JsonParser, model, 5, "Reply with JSON.");

        let parsed = parser.parse("nope").await.unwrap();

        assert_eq!(parsed, serde_json::json!({"ok": true}));
    }
}