autoagents_llm/evaluator/parallel.rs
//! Module for parallel evaluation of multiple LLM providers.
//!
//! This module provides functionality to run the same prompt through multiple LLMs
//! in parallel and select the best response based on scoring functions.

6use std::time::Instant;
7
8use futures::future::join_all;
9
10use crate::{
11 LLMProvider,
12 chat::{ChatMessage, Tool},
13 completion::CompletionRequest,
14 error::LLMError,
15};
16
17use super::ScoringFn;
18
/// Result of a parallel evaluation including response, score, and timing information.
#[derive(Debug)]
pub struct ParallelEvalResult {
    /// The text response from the LLM (empty string if the provider returned no text)
    pub text: String,
    /// Score assigned by the evaluator's scoring functions (sum over all registered scorers)
    pub score: f32,
    /// Wall-clock time taken to generate the response, in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}
31
/// Evaluator for running multiple LLM providers in parallel and selecting the best response.
pub struct ParallelEvaluator {
    /// Collection of LLM providers to evaluate, each paired with a caller-chosen string identifier
    providers: Vec<(String, Box<dyn LLMProvider>)>,
    /// Scoring functions to evaluate responses; scores from all functions are summed
    scoring_fns: Vec<Box<ScoringFn>>,
    /// Whether to include timing information in results.
    /// NOTE(review): this flag is stored but never read by the evaluate methods in this
    /// file — `time_ms` is always populated regardless; confirm intended behavior.
    include_timing: bool,
}
41
42impl ParallelEvaluator {
43 /// Creates a new parallel evaluator
44 ///
45 /// # Arguments
46 /// * `providers` - Vector of (id, provider) tuples to evaluate
47 pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
48 Self {
49 providers,
50 scoring_fns: Vec::new(),
51 include_timing: true,
52 }
53 }
54
55 /// Adds a scoring function to evaluate LLM responses
56 ///
57 /// # Arguments
58 /// * `f` - Function that takes a response string and returns a score
59 pub fn scoring<F>(mut self, f: F) -> Self
60 where
61 F: Fn(&str) -> f32 + Send + Sync + 'static,
62 {
63 self.scoring_fns.push(Box::new(f));
64 self
65 }
66
67 /// Sets whether to include timing information in results
68 pub fn include_timing(mut self, include: bool) -> Self {
69 self.include_timing = include;
70 self
71 }
72
73 /// Evaluates chat responses from all providers in parallel for the given messages
74 ///
75 /// # Arguments
76 /// * `messages` - Chat messages to send to each provider
77 ///
78 /// # Returns
79 /// Vector of evaluation results containing responses, scores, and timing information
80 pub async fn evaluate_chat_parallel(
81 &self,
82 messages: &[ChatMessage],
83 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
84 let futures = self
85 .providers
86 .iter()
87 .map(|(id, provider)| {
88 let id = id.clone();
89 let messages = messages.to_vec();
90 async move {
91 let start = Instant::now();
92 let result = provider.chat(&messages, None).await;
93 let elapsed = start.elapsed().as_millis();
94 (id, result, elapsed)
95 }
96 })
97 .collect::<Vec<_>>();
98
99 let results = join_all(futures).await;
100
101 let mut eval_results = Vec::new();
102 for (id, result, elapsed) in results {
103 match result {
104 Ok(response) => {
105 let text = response.text().unwrap_or_default();
106 let score = self.compute_score(&text);
107 eval_results.push(ParallelEvalResult {
108 text,
109 score,
110 time_ms: elapsed,
111 provider_id: id,
112 });
113 }
114 Err(e) => {
115 // Log the error but continue with other results
116 eprintln!("Error from provider {id}: {e}");
117 }
118 }
119 }
120
121 Ok(eval_results)
122 }
123
124 /// Evaluates chat responses with tools from all providers in parallel
125 ///
126 /// # Arguments
127 /// * `messages` - Chat messages to send to each provider
128 /// * `tools` - Optional tools to use in the chat
129 ///
130 /// # Returns
131 /// Vector of evaluation results
132 pub async fn evaluate_chat_with_tools_parallel(
133 &self,
134 messages: &[ChatMessage],
135 tools: Option<&[Tool]>,
136 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
137 let futures = self
138 .providers
139 .iter()
140 .map(|(id, provider)| {
141 let id = id.clone();
142 let messages = messages.to_vec();
143 async move {
144 let start = Instant::now();
145 let result = provider.chat_with_tools(&messages, tools, None).await;
146 let elapsed = start.elapsed().as_millis();
147 (id, result, elapsed)
148 }
149 })
150 .collect::<Vec<_>>();
151
152 let results = join_all(futures).await;
153
154 let mut eval_results = Vec::new();
155 for (id, result, elapsed) in results {
156 match result {
157 Ok(response) => {
158 let text = response.text().unwrap_or_default();
159 let score = self.compute_score(&text);
160 eval_results.push(ParallelEvalResult {
161 text,
162 score,
163 time_ms: elapsed,
164 provider_id: id,
165 });
166 }
167 Err(e) => {
168 // Log the error but continue with other results
169 eprintln!("Error from provider {id}: {e}");
170 }
171 }
172 }
173
174 Ok(eval_results)
175 }
176
177 /// Evaluates completion responses from all providers in parallel
178 ///
179 /// # Arguments
180 /// * `request` - Completion request to send to each provider
181 ///
182 /// # Returns
183 /// Vector of evaluation results
184 pub async fn evaluate_completion_parallel(
185 &self,
186 request: &CompletionRequest,
187 ) -> Result<Vec<ParallelEvalResult>, LLMError> {
188 let futures = self
189 .providers
190 .iter()
191 .map(|(id, provider)| {
192 let id = id.clone();
193 let request = request.clone();
194 async move {
195 let start = Instant::now();
196 let result = provider.complete(&request, None).await;
197 let elapsed = start.elapsed().as_millis();
198 (id, result, elapsed)
199 }
200 })
201 .collect::<Vec<_>>();
202
203 let results = join_all(futures).await;
204
205 let mut eval_results = Vec::new();
206 for (id, result, elapsed) in results {
207 match result {
208 Ok(response) => {
209 let score = self.compute_score(&response.text);
210 eval_results.push(ParallelEvalResult {
211 text: response.text,
212 score,
213 time_ms: elapsed,
214 provider_id: id,
215 });
216 }
217 Err(e) => {
218 // Log the error but continue with other results
219 eprintln!("Error from provider {id}: {e}");
220 }
221 }
222 }
223
224 Ok(eval_results)
225 }
226
227 /// Returns the best response based on scoring
228 ///
229 /// # Arguments
230 /// * `results` - Vector of evaluation results
231 ///
232 /// # Returns
233 /// The best result or None if no results are available
234 pub fn best_response<'a>(
235 &self,
236 results: &'a [ParallelEvalResult],
237 ) -> Option<&'a ParallelEvalResult> {
238 if results.is_empty() {
239 return None;
240 }
241
242 results.iter().max_by(|a, b| {
243 a.score
244 .partial_cmp(&b.score)
245 .unwrap_or(std::cmp::Ordering::Equal)
246 })
247 }
248
249 /// Computes the score for a given response
250 ///
251 /// # Arguments
252 /// * `response` - The response to score
253 ///
254 /// # Returns
255 /// The computed score
256 fn compute_score(&self, response: &str) -> f32 {
257 let mut total = 0.0;
258 for sc in &self.scoring_fns {
259 total += sc(response);
260 }
261 total
262 }
263}