llm/
validated_llm.rs

1//! A module providing validation capabilities for LLM responses through a wrapper implementation.
2//!
3//! This module enables adding custom validation logic to any LLM provider by wrapping it in a
4//! `ValidatedLLM` struct. The wrapper will validate responses and retry failed attempts with
5//! feedback to help guide the model toward producing valid output.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use llm::builder::{LLMBuilder, LLMBackend};
11//!
12//! let llm = LLMBuilder::new()
13//!     .backend(LLMBackend::OpenAI)
14//!     .validator(|response| {
15//!         if response.contains("unsafe content") {
16//!             Err("Response contains unsafe content".to_string())
17//!         } else {
18//!             Ok(())
19//!         }
20//!     })
21//!     .validator_attempts(3)
22//!     .build()
23//!     .unwrap();
24//! ```
25
26use async_trait::async_trait;
27
28use crate::chat::{ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, Tool};
29use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse};
30use crate::embedding::EmbeddingProvider;
31use crate::error::LLMError;
32use crate::models::ModelsProvider;
33use crate::stt::SpeechToTextProvider;
34use crate::tts::TextToSpeechProvider;
35use crate::{builder::ValidatorFn, LLMProvider};
36
/// A wrapper around an LLM provider that validates responses before returning them.
///
/// The wrapper implements validation by:
/// 1. Sending the request to the underlying provider
/// 2. Validating the response using the provided validator function
/// 3. If validation fails, retrying with feedback up to the configured number of attempts
///
/// # Type Parameters
///
/// The wrapped provider must implement the `LLMProvider` trait.
pub struct ValidatedLLM {
    /// The wrapped LLM provider that actually performs chat/completion calls.
    inner: Box<dyn LLMProvider>,
    /// Function used to validate responses; returns `Ok(())` if valid or
    /// `Err(reason)` with a human-readable message if invalid.
    validator: Box<ValidatorFn>,
    /// Maximum number of validation attempts before giving up and returning
    /// `LLMError::InvalidRequest`.
    attempts: usize,
}
55
56impl ValidatedLLM {
57    /// Creates a new ValidatedLLM wrapper around an existing LLM provider.
58    ///
59    /// # Arguments
60    ///
61    /// * `inner` - The LLM provider to wrap with validation
62    /// * `validator` - Function that takes a response string and returns Ok(()) if valid, or Err with error message if invalid
63    /// * `attempts` - Maximum number of validation attempts before failing
64    ///
65    /// # Returns
66    ///
67    /// A new ValidatedLLM instance configured with the provided parameters.
68    pub fn new(inner: Box<dyn LLMProvider>, validator: Box<ValidatorFn>, attempts: usize) -> Self {
69        Self {
70            inner,
71            validator,
72            attempts,
73        }
74    }
75}
76
77impl LLMProvider for ValidatedLLM {}
78
79#[async_trait]
80impl ChatProvider for ValidatedLLM {
81    /// Sends a chat request and validates the response.
82    ///
83    /// If validation fails, retries with feedback to the model about the validation error.
84    /// The feedback is appended as a new user message to help guide the model.
85    ///
86    /// # Arguments
87    ///
88    /// * `messages` - The chat messages to send to the model
89    ///
90    /// # Returns
91    ///
92    /// * `Ok(String)` - The validated response from the model
93    /// * `Err(LLMError)` - If validation fails after max attempts or other errors occur
94    async fn chat_with_tools(
95        &self,
96        messages: &[ChatMessage],
97        tools: Option<&[Tool]>,
98    ) -> Result<Box<dyn ChatResponse>, LLMError> {
99        let mut local_messages = messages.to_vec();
100        let mut remaining_attempts = self.attempts;
101
102        loop {
103            let response = match self.inner.chat_with_tools(&local_messages, tools).await {
104                Ok(resp) => resp,
105                Err(e) => return Err(e),
106            };
107
108            match (self.validator)(&response.text().unwrap_or_default()) {
109                Ok(()) => {
110                    return Ok(response);
111                }
112                Err(err) => {
113                    remaining_attempts -= 1;
114                    if remaining_attempts == 0 {
115                        return Err(LLMError::InvalidRequest(format!(
116                            "Validation error after max attempts: {}",
117                            err
118                        )));
119                    }
120
121                    log::debug!(
122                        "Completion validation failed (attempts remaining: {}). Reason: {}",
123                        remaining_attempts,
124                        err
125                    );
126
127                    log::debug!(
128                        "Validation failed (attempt remaining: {}). Reason: {}",
129                        remaining_attempts,
130                        err
131                    );
132
133                    local_messages.push(ChatMessage {
134                        role: ChatRole::User,
135                        message_type: MessageType::Text,
136                        content: format!(
137                            "Your previous output was invalid because: {}\n\
138                             Please try again and produce a valid response.",
139                            err
140                        ),
141                    });
142                }
143            }
144        }
145    }
146}
147
148#[async_trait]
149impl CompletionProvider for ValidatedLLM {
150    /// Sends a completion request and validates the response.
151    ///
152    /// If validation fails, retries up to the configured number of attempts.
153    /// Unlike chat, completion requests don't support adding feedback messages.
154    ///
155    /// # Arguments
156    ///
157    /// * `req` - The completion request to send
158    ///
159    /// # Returns
160    ///
161    /// * `Ok(CompletionResponse)` - The validated completion response
162    /// * `Err(LLMError)` - If validation fails after max attempts or other errors occur
163    async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
164        let mut remaining_attempts = self.attempts;
165
166        loop {
167            let response = match self.inner.complete(req).await {
168                Ok(resp) => resp,
169                Err(e) => return Err(e),
170            };
171
172            match (self.validator)(&response.text) {
173                Ok(()) => {
174                    return Ok(response);
175                }
176                Err(err) => {
177                    remaining_attempts -= 1;
178                    if remaining_attempts == 0 {
179                        return Err(LLMError::InvalidRequest(format!(
180                            "Validation error after max attempts: {}",
181                            err
182                        )));
183                    }
184                }
185            }
186        }
187    }
188}
189
#[async_trait]
impl EmbeddingProvider for ValidatedLLM {
    /// Passes through embedding requests to the inner provider without validation.
    ///
    /// Embeddings are numerical vectors rather than human-readable text, so the
    /// configured validator (which operates on response strings) does not apply.
    ///
    /// # Arguments
    ///
    /// * `input` - Vector of strings to generate embeddings for
    ///
    /// # Returns
    ///
    /// * `Ok(Vec<Vec<f32>>)` - One embedding vector per input string
    /// * `Err(LLMError)` - If the inner provider's embedding generation fails
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        // Pass through to inner provider since embeddings don't need validation
        self.inner.embed(input).await
    }
}
210
211#[async_trait]
212impl SpeechToTextProvider for ValidatedLLM {
213    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
214        Err(LLMError::ProviderError(
215            "Speech to text not supported".to_string(),
216        ))
217    }
218}
219
220#[async_trait]
221impl TextToSpeechProvider for ValidatedLLM {
222    async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
223        Err(LLMError::ProviderError(
224            "Text to speech not supported".to_string(),
225        ))
226    }
227}
228
/// Model listing relies entirely on `ModelsProvider`'s provided defaults —
/// no methods are overridden here and no validation applies.
#[async_trait]
impl ModelsProvider for ValidatedLLM {}