Skip to main content

synaptic_parsers/
retry_parser.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, Message, RunnableConfig, SynapticError};
5use synaptic_runnables::Runnable;
6
7/// A parser that uses an LLM to fix outputs that fail to parse,
8/// including the original prompt context for better correction.
9///
10/// Wraps an inner `Runnable<String, O>`. If the inner parser fails,
11/// sends the original prompt, the completion, and the error to the LLM
12/// and retries parsing.
13pub struct RetryOutputParser<O: Send + Sync + 'static> {
14    inner: Box<dyn Runnable<String, O>>,
15    llm: Arc<dyn ChatModel>,
16    prompt: String,
17    max_retries: usize,
18}
19
20impl<O: Send + Sync + 'static> RetryOutputParser<O> {
21    /// Create a new `RetryOutputParser` wrapping the given inner parser, LLM,
22    /// and original prompt that generated the output.
23    /// Defaults to 1 retry attempt.
24    pub fn new(
25        inner: Box<dyn Runnable<String, O>>,
26        llm: Arc<dyn ChatModel>,
27        prompt: impl Into<String>,
28    ) -> Self {
29        Self {
30            inner,
31            llm,
32            prompt: prompt.into(),
33            max_retries: 1,
34        }
35    }
36
37    /// Set the maximum number of retry attempts.
38    pub fn with_max_retries(mut self, n: usize) -> Self {
39        self.max_retries = n;
40        self
41    }
42}
43
44#[async_trait]
45impl<O: Send + Sync + 'static> Runnable<String, O> for RetryOutputParser<O> {
46    async fn invoke(&self, input: String, config: &RunnableConfig) -> Result<O, SynapticError> {
47        // First attempt with the original input.
48        match self.inner.invoke(input.clone(), config).await {
49            Ok(value) => return Ok(value),
50            Err(first_err) => {
51                let mut last_err = first_err;
52                let mut current_input = input;
53
54                for _ in 0..self.max_retries {
55                    let retry_prompt = format!(
56                        "Prompt:\n{}\n\nCompletion:\n{}\n\nError:\n{}\n\nPlease provide a corrected completion that will parse successfully.",
57                        self.prompt, current_input, last_err
58                    );
59
60                    let request = ChatRequest::new(vec![
61                        Message::system("You are a helpful assistant that fixes parsing errors."),
62                        Message::human(retry_prompt),
63                    ]);
64
65                    let response = self.llm.chat(request).await?;
66                    let fixed = response.message.content().to_string();
67
68                    match self.inner.invoke(fixed.clone(), config).await {
69                        Ok(value) => return Ok(value),
70                        Err(e) => {
71                            last_err = e;
72                            current_input = fixed;
73                        }
74                    }
75                }
76
77                Err(last_err)
78            }
79        }
80    }
81}