//! A module providing validation capabilities for LLM responses through a wrapper implementation.
//!
//! This module enables adding custom validation logic to any LLM provider by wrapping it in a
//! `ValidatedLLM` struct. The wrapper validates responses and retries failed attempts with
//! feedback to help guide the model toward producing valid output.
//!
//! # Example
//!
//! ```no_run
//! use llm::builder::{LLMBuilder, LLMBackend};
//!
//! let llm = LLMBuilder::new()
//!     .backend(LLMBackend::OpenAI)
//!     .validator(|response| {
//!         if response.contains("unsafe content") {
//!             Err("Response contains unsafe content".to_string())
//!         } else {
//!             Ok(())
//!         }
//!     })
//!     .validator_attempts(3)
//!     .build()
//!     .unwrap();
//! ```
26use async_trait::async_trait;
27
28use crate::chat::{ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, Tool};
29use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse};
30use crate::embedding::EmbeddingProvider;
31use crate::error::LLMError;
32use crate::models::ModelsProvider;
33use crate::stt::SpeechToTextProvider;
34use crate::tts::TextToSpeechProvider;
35use crate::{builder::ValidatorFn, LLMProvider};
36
37/// A wrapper around an LLM provider that validates responses before returning them.
38///
39/// The wrapper implements validation by:
40/// 1. Sending the request to the underlying provider
41/// 2. Validating the response using the provided validator function
42/// 3. If validation fails, retrying with feedback up to the configured number of attempts
43///
44/// # Type Parameters
45///
46/// The wrapped provider must implement the `LLMProvider` trait.
47pub struct ValidatedLLM {
48 /// The wrapped LLM provider
49 inner: Box<dyn LLMProvider>,
50 /// Function used to validate responses, returns Ok(()) if valid or Err with message if invalid
51 validator: Box<ValidatorFn>,
52 /// Maximum number of validation attempts before giving up
53 attempts: usize,
54}
55
56impl ValidatedLLM {
57 /// Creates a new ValidatedLLM wrapper around an existing LLM provider.
58 ///
59 /// # Arguments
60 ///
61 /// * `inner` - The LLM provider to wrap with validation
62 /// * `validator` - Function that takes a response string and returns Ok(()) if valid, or Err with error message if invalid
63 /// * `attempts` - Maximum number of validation attempts before failing
64 ///
65 /// # Returns
66 ///
67 /// A new ValidatedLLM instance configured with the provided parameters.
68 pub fn new(inner: Box<dyn LLMProvider>, validator: Box<ValidatorFn>, attempts: usize) -> Self {
69 Self {
70 inner,
71 validator,
72 attempts,
73 }
74 }
75}
76
// Marker impl: ValidatedLLM satisfies `LLMProvider` purely through the
// sub-trait implementations below; no additional methods are overridden.
impl LLMProvider for ValidatedLLM {}
78
79#[async_trait]
80impl ChatProvider for ValidatedLLM {
81 /// Sends a chat request and validates the response.
82 ///
83 /// If validation fails, retries with feedback to the model about the validation error.
84 /// The feedback is appended as a new user message to help guide the model.
85 ///
86 /// # Arguments
87 ///
88 /// * `messages` - The chat messages to send to the model
89 ///
90 /// # Returns
91 ///
92 /// * `Ok(String)` - The validated response from the model
93 /// * `Err(LLMError)` - If validation fails after max attempts or other errors occur
94 async fn chat_with_tools(
95 &self,
96 messages: &[ChatMessage],
97 tools: Option<&[Tool]>,
98 ) -> Result<Box<dyn ChatResponse>, LLMError> {
99 let mut local_messages = messages.to_vec();
100 let mut remaining_attempts = self.attempts;
101
102 loop {
103 let response = match self.inner.chat_with_tools(&local_messages, tools).await {
104 Ok(resp) => resp,
105 Err(e) => return Err(e),
106 };
107
108 match (self.validator)(&response.text().unwrap_or_default()) {
109 Ok(()) => {
110 return Ok(response);
111 }
112 Err(err) => {
113 remaining_attempts -= 1;
114 if remaining_attempts == 0 {
115 return Err(LLMError::InvalidRequest(format!(
116 "Validation error after max attempts: {}",
117 err
118 )));
119 }
120
121 log::debug!(
122 "Completion validation failed (attempts remaining: {}). Reason: {}",
123 remaining_attempts,
124 err
125 );
126
127 log::debug!(
128 "Validation failed (attempt remaining: {}). Reason: {}",
129 remaining_attempts,
130 err
131 );
132
133 local_messages.push(ChatMessage {
134 role: ChatRole::User,
135 message_type: MessageType::Text,
136 content: format!(
137 "Your previous output was invalid because: {}\n\
138 Please try again and produce a valid response.",
139 err
140 ),
141 });
142 }
143 }
144 }
145 }
146}
147
148#[async_trait]
149impl CompletionProvider for ValidatedLLM {
150 /// Sends a completion request and validates the response.
151 ///
152 /// If validation fails, retries up to the configured number of attempts.
153 /// Unlike chat, completion requests don't support adding feedback messages.
154 ///
155 /// # Arguments
156 ///
157 /// * `req` - The completion request to send
158 ///
159 /// # Returns
160 ///
161 /// * `Ok(CompletionResponse)` - The validated completion response
162 /// * `Err(LLMError)` - If validation fails after max attempts or other errors occur
163 async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
164 let mut remaining_attempts = self.attempts;
165
166 loop {
167 let response = match self.inner.complete(req).await {
168 Ok(resp) => resp,
169 Err(e) => return Err(e),
170 };
171
172 match (self.validator)(&response.text) {
173 Ok(()) => {
174 return Ok(response);
175 }
176 Err(err) => {
177 remaining_attempts -= 1;
178 if remaining_attempts == 0 {
179 return Err(LLMError::InvalidRequest(format!(
180 "Validation error after max attempts: {}",
181 err
182 )));
183 }
184 }
185 }
186 }
187 }
188}
189
#[async_trait]
impl EmbeddingProvider for ValidatedLLM {
    /// Passes through embedding requests to the inner provider without validation.
    ///
    /// Embeddings are numerical vectors that represent text semantically and don't
    /// require validation since they're not human-readable content.
    ///
    /// # Arguments
    ///
    /// * `input` - Vector of strings to generate embeddings for
    ///
    /// # Returns
    ///
    /// * `Ok(Vec<Vec<f32>>)` - Vector of embedding vectors, one per input string
    /// * `Err(LLMError)` - If the inner provider's embedding generation fails
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        // Pass through to inner provider since embeddings don't need validation
        self.inner.embed(input).await
    }
}
210
211#[async_trait]
212impl SpeechToTextProvider for ValidatedLLM {
213 async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
214 Err(LLMError::ProviderError(
215 "Speech to text not supported".to_string(),
216 ))
217 }
218}
219
220#[async_trait]
221impl TextToSpeechProvider for ValidatedLLM {
222 async fn speech(&self, _text: &str) -> Result<Vec<u8>, LLMError> {
223 Err(LLMError::ProviderError(
224 "Text to speech not supported".to_string(),
225 ))
226 }
227}
228
// Marker impl: model listing is not customized by the validation wrapper;
// `ModelsProvider`'s default method implementations are used as-is.
#[async_trait]
impl ModelsProvider for ValidatedLLM {}