scouter_evaluate/
util.rs

1use crate::error::EvaluationError;
2use crate::types::{EvaluationConfig, LLMEvalRecord, LLMEvalResults, LLMEvalTaskResult};
3use itertools::iproduct;
4use num_traits::FromPrimitive;
5use potato_head::{
6    Embedder, EmbeddingInput, PyEmbedder, Score, StructuredOutput, TaskStatus, Workflow,
7    WorkflowError,
8};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use simsimd::SpatialSimilarity;
12use std::collections::BTreeMap;
13use std::sync::{Arc, RwLock};
14use tokio::task::JoinSet;
15use tracing::{error, warn};
16/// Process a workflow result and extract scores from completed tasks
17pub fn process_workflow_result(
18    workflow_result: Arc<RwLock<Workflow>>,
19) -> Result<BTreeMap<String, Score>, EvaluationError> {
20    let mut metrics = BTreeMap::new();
21
22    let workflow = workflow_result
23        .read()
24        .map_err(|_| WorkflowError::LockAcquireError)?;
25
26    let tasks = workflow.task_list.tasks();
27
28    // iterate of each task and extract score
29    for task in tasks.values() {
30        if let (TaskStatus::Completed, Some(result)) = (&task.status, &task.result) {
31            if let Some(content) = result.content() {
32                match Score::model_validate_json_str(&content) {
33                    Ok(score) => {
34                        metrics.insert(task.id.clone(), score);
35                    }
36                    Err(e) => {
37                        error!("Failed to validate score for task {}: {:?}", task.id, e);
38                        // Continue processing other tasks instead of failing completely
39                    }
40                }
41            }
42        }
43    }
44
45    Ok(metrics)
46}
47
48/// Spawn tasks without embedding support
49/// This function will spawn a task that runs the workflows and extracts the results
50/// If there is an error during workflow execution, it will log the error and return None for that record
51/// # Arguments
52/// * `workflow` - The workflow to execute for each record.
53/// * `records` - The list of LLMEvalRecords to process.
54/// # Returns
55/// A JoinSet containing tuples of record ID and optional LLMEvalTaskResult.
56pub async fn spawn_evaluation_tasks_without_embeddings(
57    workflow: Workflow,
58    records: Vec<LLMEvalRecord>,
59) -> JoinSet<(String, Option<LLMEvalTaskResult>)> {
60    let mut join_set = JoinSet::new();
61
62    for record in records {
63        let inner_workflow = workflow.clone();
64
65        join_set.spawn(async move {
66            let record_id = record.id.clone();
67
68            match inner_workflow.run(Some(record.context)).await {
69                Ok(workflow_result) => {
70                    // parse the workflow result
71                    match process_workflow_result(workflow_result) {
72                        Ok(metrics) => (
73                            record_id.clone(),
74                            Some(LLMEvalTaskResult::new(record_id, metrics, BTreeMap::new())),
75                        ),
76                        Err(error) => {
77                            error!(
78                                "Failed to process workflow result for record {}: {}",
79                                record_id, error
80                            );
81                            (record_id, None)
82                        }
83                    }
84                }
85                Err(workflow_error) => {
86                    error!(
87                        "Workflow execution failed for record {}: {}",
88                        record_id, workflow_error
89                    );
90                    (record_id, None)
91                }
92            }
93        });
94    }
95
96    join_set
97}
98
99/// Spawn tasks to run evaluation workflows with embedding calculations
100/// # Arguments
101/// * `workflow` - The workflow to execute for each record.
102/// * `records` - The list of LLMEvalRecords to process.
103/// * `embedder` - The Embedder instance to use for generating embeddings.
104/// * `embedding_targets` - The list of keys in the record's context to generate embeddings.
105/// # Returns
106/// A JoinSet containing LLMEvalTaskResults for each record.
107pub async fn spawn_evaluation_tasks_with_embeddings(
108    workflow: Workflow,
109    records: Vec<LLMEvalRecord>,
110    embedder: Arc<Embedder>,
111    config: &Arc<EvaluationConfig>,
112) -> JoinSet<(String, Option<LLMEvalTaskResult>)> {
113    let mut join_set = JoinSet::new();
114
115    for record in records {
116        let inner_workflow = workflow.clone();
117        let cloned_embedder = embedder.clone();
118        let cloned_config = config.clone();
119
120        join_set.spawn(async move {
121            let record_id = record.id.clone();
122
123            // Generate embeddings
124            // We do this first because the workflow will consume the context
125            let embeddings = generate_embeddings_for_record(
126                &record,
127                &cloned_embedder,
128                &cloned_config.embedding_targets,
129            )
130            .await;
131
132            // Run workflow
133            match inner_workflow.run(Some(record.context)).await {
134                Ok(workflow_result) => {
135                    // parse the workflow result
136                    match process_workflow_result(workflow_result) {
137                        Ok(metrics) => (
138                            record_id.clone(),
139                            Some(LLMEvalTaskResult::new(record_id, metrics, embeddings)),
140                        ),
141                        Err(error) => {
142                            error!(
143                                "Failed to process workflow result for record {}: {}",
144                                record_id, error
145                            );
146                            (record_id, None)
147                        }
148                    }
149                }
150                Err(workflow_error) => {
151                    error!(
152                        "Workflow execution failed for record {}: {}",
153                        record_id, workflow_error
154                    );
155                    (record_id, None)
156                }
157            }
158        });
159    }
160
161    join_set
162}
163
164/// Helper for extracting embeddings for a single record. Used in the llm evaulation workflow.
165/// # Arguments
166/// * `record` - The LLMEvalRecord to extract embeddings from.
167/// * `embedder` - The Embedder instance to use for generating embeddings.
168/// * `embedding_targets` - The list of keys in the record's context to generate embeddings for.
169/// # Returns
170pub async fn generate_embeddings_for_record(
171    record: &LLMEvalRecord,
172    embedder: &Arc<Embedder>,
173    embedding_targets: &[String],
174) -> BTreeMap<String, Vec<f32>> {
175    let mut embeddings = BTreeMap::new();
176
177    for target in embedding_targets {
178        let texts = record
179            .context
180            .get(target)
181            .and_then(|v| v.as_str())
182            .map(|s| vec![s.to_string()]);
183
184        if let Some(texts) = texts {
185            match embedder.embed(EmbeddingInput::Texts(texts)).await {
186                Ok(embedding_response) => match embedding_response.values() {
187                    Ok(values) => {
188                        // move ownership of values into Embedding struct
189                        embeddings.insert(target.clone(), values.to_vec());
190                    }
191                    Err(e) => {
192                        error!(
193                            "Failed to extract embedding values for target {}: {:?}",
194                            target, e
195                        );
196                    }
197                },
198                Err(e) => {
199                    error!(
200                        "Failed to generate embedding for target {}: {:?}",
201                        target, e
202                    );
203                }
204            }
205        } else {
206            warn!("No text found for embedding target: {}", target);
207        }
208    }
209
210    embeddings
211}
212
213/// Enhanced result collection with proper error handling
214pub async fn collect_evaluation_results(
215    mut join_set: JoinSet<(String, Option<LLMEvalTaskResult>)>,
216) -> Result<LLMEvalResults, EvaluationError> {
217    let mut eval_results = LLMEvalResults::new();
218
219    while let Some(join_result) = join_set.join_next().await {
220        match join_result {
221            Ok(task_result) => {
222                let (record_id, task_result_opt) = task_result;
223                if let Some(task_result) = task_result_opt {
224                    eval_results.results.entry(record_id).or_insert(task_result);
225                } else {
226                    eval_results.errored_tasks.push(record_id);
227                }
228            }
229            Err(join_error) => {
230                error!("Task join error: {:?}", join_error);
231                // We can't associate this error with a specific record ID
232            }
233        }
234    }
235
236    Ok(eval_results)
237}
238
239/// Helper function for extracting embedder and runtime from optional PyEmbedder
240/// # Arguments
241/// * `embedder` - Optional reference to a PyEmbedder instance.
242/// # Returns
243/// An optional Arc-wrapped Embedder instance if provided, otherwise None.
244pub fn parse_embedder(
245    embedder: Option<&Bound<'_, PyAny>>,
246) -> Result<Option<Arc<Embedder>>, EvaluationError> {
247    // Extract embedder and runtime if PyEmbedder is provided
248    let embedder_arc = if let Some(embedder_bound) = embedder {
249        if embedder_bound.is_instance_of::<PyEmbedder>() {
250            let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
251            Some(py_embedder.embedder.clone())
252        } else {
253            // embedder provided but not a PyEmbedder instance
254            return Err(EvaluationError::InvalidEmbedderType);
255        }
256    } else {
257        None
258    };
259    Ok(embedder_arc)
260}
261
262/// Calculate the mean of for a slice of f32 values
263/// There's no need for a generic implementation here, as we only need f32 for embeddings
264pub fn compute_mean(vec: &[f32]) -> Option<f64> {
265    match vec.len() {
266        0 => None,
267        _ => {
268            let sum = vec.iter().sum::<f32>();
269            let length = f32::from_usize(vec.len())?;
270
271            let mean = sum / length;
272            Some(mean as f64)
273        }
274    }
275}
276
277pub fn compute_similarity(
278    targets: &Vec<String>,
279    embeddings: &BTreeMap<String, Vec<f32>>,
280    scores: &mut BTreeMap<String, f64>,
281) {
282    for (a, b) in iproduct!(targets, targets) {
283        // only want unique pairs
284        if a == b {
285            continue;
286        }
287        if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
288            if vec_a.len() != vec_b.len() {
289                warn!(
290                    "Embedding length mismatch for targets {} and {}: {} vs {}",
291                    a,
292                    b,
293                    vec_a.len(),
294                    vec_b.len()
295                );
296                continue;
297            }
298
299            let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
300            let key = format!("{}_{}_cosine", a, b);
301            scores.insert(key, similarity);
302        } else {
303            warn!("Missing embeddings for targets {} or {}", a, b);
304        }
305    }
306}
307
308pub fn post_process(results: &mut LLMEvalResults, config: &Arc<EvaluationConfig>) {
309    // compute means for each embedding target
310    results.results.par_iter_mut().for_each(|(_, task_result)| {
311        for (target, values) in task_result.embedding.iter() {
312            let mean = compute_mean(values).unwrap_or(-1.0);
313            task_result.mean_embeddings.insert(target.clone(), mean);
314        }
315        compute_similarity(
316            &config.embedding_targets,
317            &task_result.embedding,
318            &mut task_result.similarity_scores,
319        );
320    });
321}