// llm/evaluator/parallel.rs
1//! Module for parallel evaluation of multiple LLM providers.
2//!
3//! This module provides functionality to run the same prompt through multiple LLMs
4//! in parallel and select the best response based on scoring functions.
5
6use std::time::Instant;
7
8use futures::future::join_all;
9
10use crate::{
11 chat::{ChatMessage, Tool},
12 completion::CompletionRequest,
13 error::LLMError,
14 LLMProvider,
15};
16
17use super::ScoringFn;
18
/// Result of a parallel evaluation including response, score, and timing information
//
// Derives `Clone` and `PartialEq` in addition to `Debug` so callers can
// duplicate and compare results (e.g. when caching or testing); all fields
// are `Clone + PartialEq`, so the derives are free. (`Eq` is impossible
// because `score` is an `f32`.)
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelEvalResult {
    /// The text response from the LLM
    pub text: String,
    /// Score assigned by the scoring function
    pub score: f32,
    /// Time taken to generate the response in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}
31
/// Evaluator for running multiple LLM providers in parallel and selecting the best response
///
/// Built with [`ParallelEvaluator::new`] and configured via the builder-style
/// `scoring` / `include_timing` methods. Each registered scoring function is
/// applied to every response and the scores are summed.
pub struct ParallelEvaluator {
    /// Collection of LLM providers to evaluate, each paired with a caller-chosen
    /// string identifier that is echoed back in `ParallelEvalResult::provider_id`.
    providers: Vec<(String, Box<dyn LLMProvider>)>,
    /// Scoring functions applied to each response; their results are summed.
    scoring_fns: Vec<Box<ScoringFn>>,
    /// Whether to include timing information in results.
    // NOTE(review): no method in this file reads this flag — `time_ms` is
    // always populated regardless of its value. Wire it up or remove it.
    include_timing: bool,
}
41
42impl ParallelEvaluator {
43 /// Creates a new parallel evaluator
44 ///
45 /// # Arguments
46 /// * `providers` - Vector of (id, provider) tuples to evaluate
47 pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
48 Self {
49 providers,
50 scoring_fns: Vec::new(),
51 include_timing: true,
52 }
53 }
54
55 /// Adds a scoring function to evaluate LLM responses
56 ///
57 /// # Arguments
58 /// * `f` - Function that takes a response string and returns a score
59 pub fn scoring<F>(mut self, f: F) -> Self
60 where
61 F: Fn(&str) -> f32 + Send + Sync + 'static,
62 {
63 self.scoring_fns.push(Box::new(f));
64 self
65 }
66
67 /// Sets whether to include timing information in results
68 pub fn include_timing(mut self, include: bool) -> Self {
69 self.include_timing = include;
70 self
71 }
72
73 /// Evaluates chat responses from all providers in parallel for the given messages
74 ///
75 /// # Arguments
76 /// * `messages` - Chat messages to send to each provider
77 ///
78 /// # Returns
79 /// Vector of evaluation results containing responses, scores, and timing information
80 pub async fn evaluate_chat_parallel(
81 &self,
82 messages: &[ChatMessage],
83 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
84 let futures = self
85 .providers
86 .iter()
87 .map(|(id, provider)| {
88 let id = id.clone();
89 let messages = messages.to_vec();
90 async move {
91 let start = Instant::now();
92 let result = provider.chat(&messages).await;
93 let elapsed = start.elapsed().as_millis();
94 (id, result, elapsed)
95 }
96 })
97 .collect::<Vec<_>>();
98
99 let results = join_all(futures).await;
100
101 let mut eval_results = Vec::new();
102 for (id, result, elapsed) in results {
103 match result {
104 Ok(response) => {
105 let text = response.text().unwrap_or_default();
106 let score = self.compute_score(&text);
107 eval_results.push(ParallelEvalResult {
108 text,
109 score,
110 time_ms: elapsed,
111 provider_id: id,
112 });
113 }
114 Err(e) => {
115 // Log the error but continue with other results
116 eprintln!("Error from provider {}: {}", id, e);
117 }
118 }
119 }
120
121 Ok(eval_results)
122 }
123
124 /// Evaluates chat responses with tools from all providers in parallel
125 ///
126 /// # Arguments
127 /// * `messages` - Chat messages to send to each provider
128 /// * `tools` - Optional tools to use in the chat
129 ///
130 /// # Returns
131 /// Vector of evaluation results
132 pub async fn evaluate_chat_with_tools_parallel(
133 &self,
134 messages: &[ChatMessage],
135 tools: Option<&[Tool]>,
136 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
137 let futures = self
138 .providers
139 .iter()
140 .map(|(id, provider)| {
141 let id = id.clone();
142 let messages = messages.to_vec();
143 let tools_clone = tools.map(|t| t.to_vec());
144 async move {
145 let start = Instant::now();
146 let result = provider
147 .chat_with_tools(&messages, tools_clone.as_deref())
148 .await;
149 let elapsed = start.elapsed().as_millis();
150 (id, result, elapsed)
151 }
152 })
153 .collect::<Vec<_>>();
154
155 let results = join_all(futures).await;
156
157 let mut eval_results = Vec::new();
158 for (id, result, elapsed) in results {
159 match result {
160 Ok(response) => {
161 let text = response.text().unwrap_or_default();
162 let score = self.compute_score(&text);
163 eval_results.push(ParallelEvalResult {
164 text,
165 score,
166 time_ms: elapsed,
167 provider_id: id,
168 });
169 }
170 Err(e) => {
171 // Log the error but continue with other results
172 eprintln!("Error from provider {}: {}", id, e);
173 }
174 }
175 }
176
177 Ok(eval_results)
178 }
179
180 /// Evaluates completion responses from all providers in parallel
181 ///
182 /// # Arguments
183 /// * `request` - Completion request to send to each provider
184 ///
185 /// # Returns
186 /// Vector of evaluation results
187 pub async fn evaluate_completion_parallel(
188 &self,
189 request: &CompletionRequest,
190 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
191 let futures = self
192 .providers
193 .iter()
194 .map(|(id, provider)| {
195 let id = id.clone();
196 let request = request.clone();
197 async move {
198 let start = Instant::now();
199 let result = provider.complete(&request).await;
200 let elapsed = start.elapsed().as_millis();
201 (id, result, elapsed)
202 }
203 })
204 .collect::<Vec<_>>();
205
206 let results = join_all(futures).await;
207
208 let mut eval_results = Vec::new();
209 for (id, result, elapsed) in results {
210 match result {
211 Ok(response) => {
212 let score = self.compute_score(&response.text);
213 eval_results.push(ParallelEvalResult {
214 text: response.text,
215 score,
216 time_ms: elapsed,
217 provider_id: id,
218 });
219 }
220 Err(e) => {
221 // Log the error but continue with other results
222 eprintln!("Error from provider {}: {}", id, e);
223 }
224 }
225 }
226
227 Ok(eval_results)
228 }
229
230 /// Returns the best response based on scoring
231 ///
232 /// # Arguments
233 /// * `results` - Vector of evaluation results
234 ///
235 /// # Returns
236 /// The best result or None if no results are available
237 pub fn best_response<'a>(
238 &self,
239 results: &'a [ParallelEvalResult],
240 ) -> Option<&'a ParallelEvalResult> {
241 if results.is_empty() {
242 return None;
243 }
244
245 results.iter().max_by(|a, b| {
246 a.score
247 .partial_cmp(&b.score)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 })
250 }
251
252 /// Computes the score for a given response
253 ///
254 /// # Arguments
255 /// * `response` - The response to score
256 ///
257 /// # Returns
258 /// The computed score
259 fn compute_score(&self, response: &str) -> f32 {
260 let mut total = 0.0;
261 for sc in &self.scoring_fns {
262 total += sc(response);
263 }
264 total
265 }
266}