// llm/evaluator/parallel.rs

//! Module for parallel evaluation of multiple LLM providers.
//!
//! This module provides functionality to run the same prompt through multiple LLMs
//! in parallel and select the best response based on scoring functions.
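//!
//! # Example
//!
//! A minimal sketch of the intended flow. The provider list and chat messages are
//! assumed to be constructed elsewhere with the crate's provider builders and
//! `ChatMessage` helpers, so the snippet is not compiled as a doc test.
//!
//! ```ignore
//! let evaluator = ParallelEvaluator::new(providers)
//!     // Scores from all registered scoring functions are summed per response.
//!     .scoring(|text| text.len() as f32)
//!     .include_timing(true);
//!
//! let results = evaluator.evaluate_chat_parallel(&messages).await?;
//! if let Some(best) = evaluator.best_response(&results) {
//!     println!("best from {} in {} ms: {}", best.provider_id, best.time_ms, best.text);
//! }
//! ```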

use std::time::Instant;

use futures::future::join_all;

use crate::{
    chat::{ChatMessage, Tool},
    completion::CompletionRequest,
    error::LLMError,
    LLMProvider,
};

use super::ScoringFn;

/// Result of a parallel evaluation including response, score, and timing information
#[derive(Debug)]
pub struct ParallelEvalResult {
    /// The text response from the LLM
    pub text: String,
    /// Score assigned by the scoring function
    pub score: f32,
    /// Time taken to generate the response in milliseconds
    pub time_ms: u128,
    /// Identifier of the provider that generated this response
    pub provider_id: String,
}

/// Evaluator for running multiple LLM providers in parallel and selecting the best response
pub struct ParallelEvaluator {
    /// Collection of LLM providers to evaluate with their identifiers
    providers: Vec<(String, Box<dyn LLMProvider>)>,
    /// Scoring functions to evaluate responses
    scoring_fns: Vec<Box<ScoringFn>>,
    /// Whether to include timing information in results
    include_timing: bool,
}

impl ParallelEvaluator {
    /// Creates a new parallel evaluator
    ///
    /// # Arguments
    /// * `providers` - Vector of (id, provider) tuples to evaluate
    pub fn new(providers: Vec<(String, Box<dyn LLMProvider>)>) -> Self {
        Self {
            providers,
            scoring_fns: Vec::new(),
            include_timing: true,
        }
    }

    /// Adds a scoring function to evaluate LLM responses
    ///
    /// # Arguments
    /// * `f` - Function that takes a response string and returns a score
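    ///
    /// # Example
    ///
    /// A sketch of two simple scorers; any `Fn(&str) -> f32` closure works, and the
    /// scores from every registered function are summed per response:
    ///
    /// ```ignore
    /// let evaluator = evaluator
    ///     // Reward longer answers.
    ///     .scoring(|response| response.len() as f32)
    ///     // Penalize refusals (illustrative heuristic only).
    ///     .scoring(|response| if response.contains("I can't") { -10.0 } else { 0.0 });
    /// ```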
    pub fn scoring<F>(mut self, f: F) -> Self
    where
        F: Fn(&str) -> f32 + Send + Sync + 'static,
    {
        self.scoring_fns.push(Box::new(f));
        self
    }

    /// Sets whether to include timing information in results
    pub fn include_timing(mut self, include: bool) -> Self {
        self.include_timing = include;
        self
    }

    /// Evaluates chat responses from all providers in parallel for the given messages
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results containing responses, scores, and timing information
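    ///
    /// Providers that return an error are skipped (the error is logged to stderr),
    /// so the result vector may contain fewer entries than there are providers.
    ///
    /// # Example
    ///
    /// A sketch assuming `messages` is a `Vec<ChatMessage>` built elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_chat_parallel(&messages).await?;
    /// for r in &results {
    ///     println!("{}: score {:.2} in {} ms", r.provider_id, r.score, r.time_ms);
    /// }
    /// ```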
    pub async fn evaluate_chat_parallel(
        &self,
        messages: &[ChatMessage],
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                async move {
                    let start = Instant::now();
                    let result = provider.chat(&messages).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {}: {}", id, e);
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates chat responses with tools from all providers in parallel
    ///
    /// # Arguments
    /// * `messages` - Chat messages to send to each provider
    /// * `tools` - Optional tools to use in the chat
    ///
    /// # Returns
    /// Vector of evaluation results
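    ///
    /// # Example
    ///
    /// A sketch assuming `messages` and `tools` are built elsewhere; pass `None`
    /// when no tools are needed:
    ///
    /// ```ignore
    /// let results = evaluator
    ///     .evaluate_chat_with_tools_parallel(&messages, Some(&tools))
    ///     .await?;
    /// ```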
    pub async fn evaluate_chat_with_tools_parallel(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let messages = messages.to_vec();
                let tools_clone = tools.map(|t| t.to_vec());
                async move {
                    let start = Instant::now();
                    let result = provider
                        .chat_with_tools(&messages, tools_clone.as_deref())
                        .await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let text = response.text().unwrap_or_default();
                    let score = self.compute_score(&text);
                    eval_results.push(ParallelEvalResult {
                        text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {}: {}", id, e);
                }
            }
        }

        Ok(eval_results)
    }

    /// Evaluates completion responses from all providers in parallel
    ///
    /// # Arguments
    /// * `request` - Completion request to send to each provider
    ///
    /// # Returns
    /// Vector of evaluation results
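    ///
    /// # Example
    ///
    /// A sketch assuming `request` is a `CompletionRequest` built elsewhere:
    ///
    /// ```ignore
    /// let results = evaluator.evaluate_completion_parallel(&request).await?;
    /// if let Some(best) = evaluator.best_response(&results) {
    ///     println!("best completion from {}", best.provider_id);
    /// }
    /// ```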
    pub async fn evaluate_completion_parallel(
        &self,
        request: &CompletionRequest,
    ) -> Result<Vec<ParallelEvalResult>, LLMError> {
        let futures = self
            .providers
            .iter()
            .map(|(id, provider)| {
                let id = id.clone();
                let request = request.clone();
                async move {
                    let start = Instant::now();
                    let result = provider.complete(&request).await;
                    let elapsed = start.elapsed().as_millis();
                    (id, result, elapsed)
                }
            })
            .collect::<Vec<_>>();

        let results = join_all(futures).await;

        let mut eval_results = Vec::new();
        for (id, result, elapsed) in results {
            match result {
                Ok(response) => {
                    let score = self.compute_score(&response.text);
                    eval_results.push(ParallelEvalResult {
                        text: response.text,
                        score,
                        time_ms: elapsed,
                        provider_id: id,
                    });
                }
                Err(e) => {
                    // Log the error but continue with other results
                    eprintln!("Error from provider {}: {}", id, e);
                }
            }
        }

        Ok(eval_results)
    }

    /// Returns the best response based on scoring
    ///
    /// # Arguments
    /// * `results` - Vector of evaluation results
    ///
    /// # Returns
    /// The best result or None if no results are available
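    ///
    /// # Example
    ///
    /// Selection is by highest total score from the registered scoring functions;
    /// a sketch assuming `results` came from one of the `evaluate_*_parallel` methods:
    ///
    /// ```ignore
    /// if let Some(best) = evaluator.best_response(&results) {
    ///     println!("{} won with score {:.2}", best.provider_id, best.score);
    /// }
    /// ```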
    pub fn best_response<'a>(
        &self,
        results: &'a [ParallelEvalResult],
    ) -> Option<&'a ParallelEvalResult> {
        if results.is_empty() {
            return None;
        }

        results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    /// Computes the score for a given response
    ///
    /// # Arguments
    /// * `response` - The response to score
    ///
    /// # Returns
    /// The computed score
    fn compute_score(&self, response: &str) -> f32 {
        let mut total = 0.0;
        for sc in &self.scoring_fns {
            total += sc(response);
        }
        total
    }
}