autoagents_llm/evaluator/parallel.rs
//! Module for parallel evaluation of multiple LLM providers.
//!
//! This module provides functionality to run the same prompt through multiple LLMs
//! in parallel and select the best response based on scoring functions.
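//!
//! # Example
//!
//! A minimal sketch of the intended flow. `openai`, `anthropic`, and
//! `messages` stand in for boxed [`LLMProvider`] values and chat messages
//! constructed elsewhere:
//!
//! ```ignore
//! let evaluator = ParallelEvaluator::new(vec![
//!     ("openai".to_string(), openai),
//!     ("anthropic".to_string(), anthropic),
//! ])
//! // Illustrative heuristic: prefer longer answers.
//! .scoring(|text| text.len() as f32);
//!
//! let results = evaluator.evaluate_chat_parallel(&messages).await?;
//! if let Some(best) = evaluator.best_response(&results) {
//!     println!("{} answered best in {} ms", best.provider_id, best.time_ms);
//! }
//! ```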

use std::time::Instant;

use futures::future::join_all;

use crate::{
    chat::{ChatMessage, Tool},
    completion::CompletionRequest,
    error::LLMError,
    LLMProvider,
};

use super::ScoringFn;

/// Result of a parallel evaluation including response, score, and timing information
#[derive(Debug)]
pub struct ParallelEvalResult {
    /// The text response from the LLM
    pub text: String,
    /// Score assigned by the scoring function
    pub score: f32,
    /// Time taken to generate the response in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}

/// Evaluator for running multiple LLM providers in parallel and selecting the best response
pub struct ParallelEvaluator {
    /// Collection of LLM providers to evaluate, paired with their identifiers
    providers: Vec<(String, Box<dyn LLMProvider>)>,
    /// Scoring functions used to evaluate responses
    scoring_fns: Vec<Box<ScoringFn>>,
    /// Whether to include timing information in results
    include_timing: bool,
}

impl ParallelEvaluator {
    /// Creates a new parallel evaluator
    ///
    /// # Arguments
    /// * `providers` - Vector of (id, provider) tuples to evaluate
    pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
        Self {
            providers,
            scoring_fns: Vec::new(),
            include_timing: true,
        }
    }

    /// Adds a scoring function to evaluate LLM responses
    ///
    /// # Arguments
    /// * `f` - Function that takes a response string and returns a score
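    ///
    /// # Example
    ///
    /// Scoring functions are additive: the evaluator sums every registered
    /// function's output when scoring a response. A sketch with two
    /// illustrative heuristics:
    ///
    /// ```ignore
    /// let evaluator = evaluator
    ///     // Reward responses that mention the expected keyword.
    ///     .scoring(|text| if text.contains("Rust") { 10.0 } else { 0.0 })
    ///     // Penalize very long responses.
    ///     .scoring(|text| -(text.len() as f32) / 1000.0);
    /// ```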
    pub fn scoring<F>(mut self, f: F) -> Self
    where
        F: Fn(&str) -> f32 + Send + Sync + 'static,
    {
        self.scoring_fns.push(Box::new(f));
        self
    }

    /// Sets whether to include timing information in results
    pub fn include_timing(mut self, include: bool) -> Self {
        self.include_timing = include;
        self
    }

    /// Evaluates chat responses from all providers in parallel for the given messages
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results containing responses, scores, and timing information
    pub async fn evaluate_chat_parallel(
        &self,
        messages: &[ChatMessage],
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                async move {
                    let start = Instant::now();
                    let result = provider.chat(&messages, None, None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates chat responses with tools from all providers in parallel
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    /// * `tools` - Optional tools to use in the chat
    ///
    /// # Returns
    /// Vector of evaluation results
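    ///
    /// # Example
    ///
    /// A sketch assuming `evaluator`, `messages`, and a `tools` vector are
    /// built elsewhere; pass `None` to run without tools:
    ///
    /// ```ignore
    /// let results = evaluator
    ///     .evaluate_chat_with_tools_parallel(&messages, Some(&tools))
    ///     .await?;
    /// ```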
    pub async fn evaluate_chat_with_tools_parallel(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                let tools_clone = tools.map(|t| t.to_vec());
                async move {
                    let start = Instant::now();
                    let result = provider.chat(&messages, tools_clone.as_deref(), None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates completion responses from all providers in parallel
    ///
    /// # Arguments
    /// * `request` - Completion request to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results
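    ///
    /// # Example
    ///
    /// A sketch assuming `evaluator` and a [`CompletionRequest`] are built
    /// elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_completion_parallel(&request).await?;
    /// for r in &results {
    ///     println!("{}: {:.2} ({} ms)", r.provider_id, r.score, r.time_ms);
    /// }
    /// ```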
    pub async fn evaluate_completion_parallel(
        &self,
        request: &CompletionRequest,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let request = request.clone();
                async move {
                    let start = Instant::now();
                    let result = provider.complete(&request, None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let score = self.compute_score(&response.text);
                    eval_results.push(ParallelEvalResult {
                        text: response.text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Returns the best response based on scoring
    ///
    /// # Arguments
    /// * `results` - Vector of evaluation results
    ///
    /// # Returns
    /// The best result, or `None` if no results are available
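    ///
    /// # Example
    ///
    /// A sketch reusing `evaluator` and `results` from the calls above:
    ///
    /// ```ignore
    /// if let Some(best) = evaluator.best_response(&results) {
    ///     println!("winner: {} (score {:.2})", best.provider_id, best.score);
    /// }
    /// ```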
    pub fn best_response<'a>(
        &self,
        results: &'a [ParallelEvalResult],
    ) -> Option<&'a ParallelEvalResult> {
        // `max_by` already yields `None` for an empty slice; incomparable
        // (NaN) scores are treated as equal so they never win outright.
        results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    /// Computes the aggregate score for a given response
    ///
    /// # Arguments
    /// * `response` - The response to score
    ///
    /// # Returns
    /// The sum of all registered scoring functions applied to the response
    fn compute_score(&self, response: &str) -> f32 {
        let mut total = 0.0;
        for sc in &self.scoring_fns {
            total += sc(response);
        }
        total
    }
}