Skip to main content

autoagents_llm/evaluator/
parallel.rs

1//! Module for parallel evaluation of multiple LLM providers.
2//!
3//! This module provides functionality to run the same prompt through multiple LLMs
4//! in parallel and select the best response based on scoring functions.
5
6use std::time::Instant;
7
8use futures::future::join_all;
9
10use crate::{
11    LLMProvider,
12    chat::{ChatMessage, Tool},
13    completion::CompletionRequest,
14    error::LLMError,
15};
16
17use super::ScoringFn;
18
/// Result of a parallel evaluation: one provider's response together with its
/// score and timing information.
///
/// Derives `Clone` so callers can keep an owned copy of a result (e.g. the one
/// selected as best) independently of the slice it was borrowed from, and
/// `PartialEq` so results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelEvalResult {
    /// The text response from the LLM
    pub text: String,
    /// Score assigned by the scoring function(s)
    pub score: f32,
    /// Time taken to generate the response in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}
31
32/// Evaluator for running multiple LLM providers in parallel and selecting the best response
33pub struct ParallelEvaluator {
34    /// Collection of LLM providers to evaluate with their identifiers
35    providers: Vec<(String, Box<dyn LLMProvider>)>,
36    /// Scoring functions to evaluate responses
37    scoring_fns: Vec<Box<ScoringFn>>,
38    /// Whether to include timing information in results
39    include_timing: bool,
40}
41
42impl ParallelEvaluator {
43    /// Creates a new parallel evaluator
44    ///
45    /// # Arguments
46    /// * `providers` - Vector of (id, provider) tuples to evaluate
47    pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
48        Self {
49            providers,
50            scoring_fns: Vec::new(),
51            include_timing: true,
52        }
53    }
54
55    /// Adds a scoring function to evaluate LLM responses
56    ///
57    /// # Arguments
58    /// * `f` - Function that takes a response string and returns a score
59    pub fn scoring<F>(mut self, f: F) -> Self
60    where
61        F: Fn(&str) -> f32 + Send + Sync + 'static,
62    {
63        self.scoring_fns.push(Box::new(f));
64        self
65    }
66
67    /// Sets whether to include timing information in results
68    pub fn include_timing(mut self, include: bool) -> Self {
69        self.include_timing = include;
70        self
71    }
72
73    /// Evaluates chat responses from all providers in parallel for the given messages
74    ///
75    /// # Arguments
76    /// * `messages` - Chat messages to send to each provider
77    ///
78    /// # Returns
79    /// Vector of evaluation results containing responses, scores, and timing information
80    pub async fn evaluate_chat_parallel(
81        &self,
82        messages: &[ChatMessage],
83    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
84        let futures = self
85            .providers
86            .iter()
87            .map(|(id, provider)| {
88                let id = id.clone();
89                let messages = messages.to_vec();
90                async move {
91                    let start = Instant::now();
92                    let result = provider.chat(&messages, None).await;
93                    let elapsed = start.elapsed().as_millis();
94                    (id, result, elapsed)
95                }
96            })
97            .collect::<Vec<_>>();
98
99        let results = join_all(futures).await;
100
101        let mut eval_results = Vec::new();
102        for (id, result, elapsed) in results {
103            match result {
104                Ok(response) => {
105                    let text = response.text().unwrap_or_default();
106                    let score = self.compute_score(&text);
107                    eval_results.push(ParallelEvalResult {
108                        text,
109                        score,
110                        time_ms: elapsed,
111                        provider_id: id,
112                    });
113                }
114                Err(e) => {
115                    // Log the error but continue with other results
116                    eprintln!("Error from provider {id}: {e}");
117                }
118            }
119        }
120
121        Ok(eval_results)
122    }
123
124    /// Evaluates chat responses with tools from all providers in parallel
125    ///
126    /// # Arguments
127    /// * `messages` - Chat messages to send to each provider
128    /// * `tools` - Optional tools to use in the chat
129    ///
130    /// # Returns
131    /// Vector of evaluation results
132    pub async fn evaluate_chat_with_tools_parallel(
133        &self,
134        messages: &[ChatMessage],
135        tools: Option<&[Tool]>,
136    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
137        let futures = self
138            .providers
139            .iter()
140            .map(|(id, provider)| {
141                let id = id.clone();
142                let messages = messages.to_vec();
143                async move {
144                    let start = Instant::now();
145                    let result = provider.chat_with_tools(&messages, tools, None).await;
146                    let elapsed = start.elapsed().as_millis();
147                    (id, result, elapsed)
148                }
149            })
150            .collect::<Vec<_>>();
151
152        let results = join_all(futures).await;
153
154        let mut eval_results = Vec::new();
155        for (id, result, elapsed) in results {
156            match result {
157                Ok(response) => {
158                    let text = response.text().unwrap_or_default();
159                    let score = self.compute_score(&text);
160                    eval_results.push(ParallelEvalResult {
161                        text,
162                        score,
163                        time_ms: elapsed,
164                        provider_id: id,
165                    });
166                }
167                Err(e) => {
168                    // Log the error but continue with other results
169                    eprintln!("Error from provider {id}: {e}");
170                }
171            }
172        }
173
174        Ok(eval_results)
175    }
176
177    /// Evaluates completion responses from all providers in parallel
178    ///
179    /// # Arguments
180    /// * `request` - Completion request to send to each provider
181    ///
182    /// # Returns
183    /// Vector of evaluation results
184    pub async fn evaluate_completion_parallel(
185        &self,
186        request: &CompletionRequest,
187    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
188        let futures = self
189            .providers
190            .iter()
191            .map(|(id, provider)| {
192                let id = id.clone();
193                let request = request.clone();
194                async move {
195                    let start = Instant::now();
196                    let result = provider.complete(&request, None).await;
197                    let elapsed = start.elapsed().as_millis();
198                    (id, result, elapsed)
199                }
200            })
201            .collect::<Vec<_>>();
202
203        let results = join_all(futures).await;
204
205        let mut eval_results = Vec::new();
206        for (id, result, elapsed) in results {
207            match result {
208                Ok(response) => {
209                    let score = self.compute_score(&response.text);
210                    eval_results.push(ParallelEvalResult {
211                        text: response.text,
212                        score,
213                        time_ms: elapsed,
214                        provider_id: id,
215                    });
216                }
217                Err(e) => {
218                    // Log the error but continue with other results
219                    eprintln!("Error from provider {id}: {e}");
220                }
221            }
222        }
223
224        Ok(eval_results)
225    }
226
227    /// Returns the best response based on scoring
228    ///
229    /// # Arguments
230    /// * `results` - Vector of evaluation results
231    ///
232    /// # Returns
233    /// The best result or None if no results are available
234    pub fn best_response<'a>(
235        &self,
236        results: &'a [ParallelEvalResult],
237    ) -> Option<&'a ParallelEvalResult> {
238        if results.is_empty() {
239            return None;
240        }
241
242        results.iter().max_by(|a, b| {
243            a.score
244                .partial_cmp(&b.score)
245                .unwrap_or(std::cmp::Ordering::Equal)
246        })
247    }
248
249    /// Computes the score for a given response
250    ///
251    /// # Arguments
252    /// * `response` - The response to score
253    ///
254    /// # Returns
255    /// The computed score
256    fn compute_score(&self, response: &str) -> f32 {
257        let mut total = 0.0;
258        for sc in &self.scoring_fns {
259            total += sc(response);
260        }
261        total
262    }
263}