autoagents_llm/evaluator/parallel.rs

//! Module for parallel evaluation of multiple LLM providers.
//!
//! This module provides functionality to run the same prompt through multiple LLMs
//! in parallel and select the best response based on scoring functions.

use std::time::Instant;

use futures::future::join_all;

use crate::{
    chat::{ChatMessage, Tool},
    completion::CompletionRequest,
    error::LLMError,
    LLMProvider,
};

use super::ScoringFn;

/// Result of a parallel evaluation including response, score, and timing information
#[derive(Debug)]
pub struct ParallelEvalResult {
    /// The text response from the LLM
    pub text: String,
    /// Score assigned by the scoring function
    pub score: f32,
    /// Time taken to generate the response in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}

/// Evaluator for running multiple LLM providers in parallel and selecting the best response
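///
/// # Example
///
/// A minimal end-to-end sketch, to be run inside an async context. It assumes
/// `providers` is a `Vec<(String, Box<dyn LLMProvider>)>` and `messages` is a
/// `Vec<ChatMessage>` built elsewhere; neither is constructed by this module.
///
/// ```ignore
/// let evaluator = ParallelEvaluator::new(providers)
///     // Simple heuristic: favor longer responses.
///     .scoring(|response| response.len() as f32);
///
/// let results = evaluator.evaluate_chat_parallel(&messages).await?;
/// if let Some(best) = evaluator.best_response(&results) {
///     println!("{} ({} ms): {}", best.provider_id, best.time_ms, best.text);
/// }
/// ```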
pub struct ParallelEvaluator {
    /// Collection of LLM providers to evaluate with their identifiers
    providers: Vec<(String, Box<dyn LLMProvider>)>,
    /// Scoring functions to evaluate responses
    scoring_fns: Vec<Box<ScoringFn>>,
    /// Whether to include timing information in results
    include_timing: bool,
}

impl ParallelEvaluator {
    /// Creates a new parallel evaluator
    ///
    /// # Arguments
    /// * `providers` - Vector of (id, provider) tuples to evaluate
    pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
        Self {
            providers,
            scoring_fns: Vec::new(),
            include_timing: true,
        }
    }

    /// Adds a scoring function to evaluate LLM responses
    ///
    /// # Arguments
    /// * `f` - Function that takes a response string and returns a score
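    ///
    /// When multiple scoring functions are registered, their scores are summed
    /// for each response. A sketch with two simple heuristics, assuming
    /// `providers` is built elsewhere:
    ///
    /// ```ignore
    /// let evaluator = ParallelEvaluator::new(providers)
    ///     // Reward longer responses.
    ///     .scoring(|r| r.len() as f32 * 0.01)
    ///     // Reward responses that mention an expected keyword.
    ///     .scoring(|r| if r.contains("Rust") { 10.0 } else { 0.0 });
    /// ```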
    pub fn scoring<F>(mut self, f: F) -> Self
    where
        F: Fn(&str) -> f32 + Send + Sync + 'static,
    {
        self.scoring_fns.push(Box::new(f));
        self
    }

    /// Sets whether to include timing information in results
    pub fn include_timing(mut self, include: bool) -> Self {
        self.include_timing = include;
        self
    }

    /// Evaluates chat responses from all providers in parallel for the given messages
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results containing responses, scores, and timing information
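    ///
    /// Providers that return an error are skipped (the error is logged), so the
    /// returned vector may contain fewer entries than there are providers.
    ///
    /// # Example
    ///
    /// A sketch assuming `messages` is a `Vec<ChatMessage>` built elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_chat_parallel(&messages).await?;
    /// for r in &results {
    ///     println!("{} scored {:.2} in {} ms", r.provider_id, r.score, r.time_ms);
    /// }
    /// ```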
    pub async fn evaluate_chat_parallel(
        &self,
        messages: &[ChatMessage],
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                async move {
                    let start = Instant::now();
                    let result = provider.chat(&messages, None, None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates chat responses with tools from all providers in parallel
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    /// * `tools` - Optional tools to use in the chat
    ///
    /// # Returns
    /// Vector of evaluation results
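    ///
    /// # Example
    ///
    /// A sketch assuming `messages` and `tools` (a `Vec<Tool>`) are built elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator
    ///     .evaluate_chat_with_tools_parallel(&messages, Some(&tools))
    ///     .await?;
    /// ```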
    pub async fn evaluate_chat_with_tools_parallel(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                let tools_clone = tools.map(|t| t.to_vec());
                async move {
                    let start = Instant::now();
                    let result = provider.chat(&messages, tools_clone.as_deref(), None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates completion responses from all providers in parallel
    ///
    /// # Arguments
    /// * `request` - Completion request to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results
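    ///
    /// # Example
    ///
    /// A sketch assuming `request` is a `CompletionRequest` built elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_completion_parallel(&request).await?;
    /// let best = evaluator.best_response(&results);
    /// ```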
    pub async fn evaluate_completion_parallel(
        &self,
        request: &CompletionRequest,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let request = request.clone();
                async move {
                    let start = Instant::now();
                    let result = provider.complete(&request, None).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let score = self.compute_score(&response.text);
                    eval_results.push(ParallelEvalResult {
                        text: response.text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {id}: {e}");
                }
            }
        }

        Ok(eval_results)
    }

    /// Returns the best response based on scoring
    ///
    /// # Arguments
    /// * `results` - Vector of evaluation results
    ///
    /// # Returns
    /// The best result or None if no results are available
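    ///
    /// Results are compared by score; incomparable scores (e.g. `NaN`) are
    /// treated as equal.
    ///
    /// # Example
    ///
    /// A sketch assuming `results` came from one of the `evaluate_*` methods:
    ///
    /// ```ignore
    /// if let Some(best) = evaluator.best_response(&results) {
    ///     println!("best provider: {}", best.provider_id);
    /// }
    /// ```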
    pub fn best_response<'a>(
        &self,
        results: &'a [ParallelEvalResult],
    ) -> Option<&'a ParallelEvalResult> {
        if results.is_empty() {
            return None;
        }

        results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    /// Computes the score for a given response
    ///
    /// # Arguments
    /// * `response` - The response to score
    ///
    /// # Returns
    /// The sum of the scores produced by all registered scoring functions
    fn compute_score(&self, response: &str) -> f32 {
        let mut total = 0.0;
        for sc in &self.scoring_fns {
            total += sc(response);
        }
        total
    }
}