Skip to main content

scouter_evaluate/
utils.rs

1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvalResults, EvaluationConfig};
4use crate::genai::EvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::EvalSet;
12use scouter_types::EvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
20type EvalTaskResult = (
21    usize, // Index into records array
22    Result<(EvalSet, BTreeMap<String, Vec<f32>>), String>,
23);
24
25/// Spawn tasks without embedding support
26/// This function will spawn a task that runs the workflows and extracts the results
27/// If there is an error during workflow execution, it will log the error and return None for that record
28/// # Arguments
29/// * `workflow` - The workflow to execute for each record.
30/// * `records` - The list of EvalRecords to process.
31/// # Returns
32/// A JoinSet containing tuples of record ID and optional EvalTaskResult.
33pub async fn spawn_evaluation_tasks_without_embeddings(
34    dataset: &EvalDataset,
35    _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37    let mut join_set = JoinSet::new();
38
39    for (idx, _) in dataset.records.iter().enumerate() {
40        // cloning here so we can reference inside async move
41        let record_ref = dataset.records.clone();
42        let profile_ref = dataset.profile.clone();
43        let spans_ref = dataset.spans.clone();
44
45        join_set.spawn(async move {
46            // Access record by index - no cloning
47            let record = &record_ref[idx];
48
49            debug!(
50                "Starting evaluation for record {} and index {}",
51                record.uid, idx
52            );
53
54            let result =
55                match GenAIEvaluator::process_event_record(record, profile_ref, spans_ref).await {
56                    Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
57                    Err(e) => Err(format!("Evaluation failed: {}", e)),
58                };
59
60            (idx, result)
61        });
62    }
63
64    join_set
65}
66
67/// Spawn tasks to run evaluation workflows with embedding calculations
68/// # Arguments
69/// * `dataset` - The EvalDataset containing records to evaluate.
70/// * `embedder` - The Embedder instance to use for generating embeddings.
71/// * `config` - The EvaluationConfig containing evaluation settings.
72/// # Returns
73/// A JoinSet containing GenAIEvalTaskResults for each record.
74pub async fn spawn_evaluation_tasks_with_embeddings(
75    dataset: &EvalDataset,
76    embedder: Arc<Embedder>,
77    config: &Arc<EvaluationConfig>,
78) -> JoinSet<EvalTaskResult> {
79    let mut join_set = JoinSet::new();
80
81    for (idx, _) in dataset.records.iter().enumerate() {
82        let record_ref = dataset.records.clone();
83        let profile_ref = dataset.profile.clone();
84        let spans_ref = dataset.spans.clone();
85        let embedder_ref = embedder.clone();
86        let config_ref = config.clone();
87
88        join_set.spawn(async move {
89            let record = &record_ref[idx];
90
91            // Generate embeddings
92            let embeddings = generate_embeddings_for_record(
93                record,
94                &embedder_ref,
95                &config_ref.embedding_targets,
96            )
97            .await;
98
99            // Execute evaluation
100            let result =
101                match GenAIEvaluator::process_event_record(record, profile_ref, spans_ref).await {
102                    Ok(eval_set) => Ok((eval_set, embeddings)),
103                    Err(e) => Err(format!("Evaluation failed: {}", e)),
104                };
105
106            (idx, result)
107        });
108    }
109
110    join_set
111}
112
113/// Helper for extracting embeddings for a single record. Used in the genai evaulation workflow.
114/// # Arguments
115/// * `record` - The EvalRecord to extract embeddings from.
116/// * `embedder` - The Embedder instance to use for generating embeddings.
117/// * `embedding_targets` - The list of keys in the record's context to generate embeddings for.
118/// # Returns
119pub async fn generate_embeddings_for_record(
120    record: &EvalRecord,
121    embedder: &Arc<Embedder>,
122    embedding_targets: &[String],
123) -> BTreeMap<String, Vec<f32>> {
124    let mut embeddings = BTreeMap::new();
125
126    for target in embedding_targets {
127        match FieldEvaluator::extract_field_value(&record.context, target) {
128            Ok(value) => {
129                let text = match value {
130                    Value::String(s) => Some(s.clone()),
131                    Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
132                    _ => {
133                        warn!(
134                            "Field '{}' has unsupported type for embedding: {:?}",
135                            target, value
136                        );
137                        None
138                    }
139                };
140
141                if let Some(text) = text {
142                    match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
143                        Ok(embedding_response) => match embedding_response.values() {
144                            Ok(values) => {
145                                embeddings.insert(target.clone(), values.to_vec());
146                            }
147                            Err(e) => {
148                                error!(
149                                    "Failed to extract embedding values for target '{}': {:?}",
150                                    target, e
151                                );
152                            }
153                        },
154                        Err(e) => {
155                            error!(
156                                "Failed to generate embedding for target '{}': {:?}",
157                                target, e
158                            );
159                        }
160                    }
161                }
162            }
163            Err(e) => {
164                warn!("Failed to extract field '{}' for embedding: {}", target, e);
165            }
166        }
167    }
168
169    embeddings
170}
171/// Collect and align results with original records
172pub async fn collect_and_align_results(
173    mut join_set: JoinSet<EvalTaskResult>,
174    records: &Arc<Vec<EvalRecord>>,
175) -> Result<EvalResults, EvaluationError> {
176    let mut results = EvalResults::new();
177
178    while let Some(join_result) = join_set.join_next().await {
179        match join_result {
180            Ok((idx, eval_result)) => {
181                let record = &records[idx];
182
183                match eval_result {
184                    Ok((eval_set, embeddings)) => {
185                        results.add_success(record, eval_set, embeddings);
186                    }
187                    Err(error_msg) => {
188                        results.add_failure(record, error_msg);
189                    }
190                }
191            }
192            Err(join_error) => {
193                error!("Task join error: {:?}", join_error);
194            }
195        }
196    }
197
198    Ok(results)
199}
200
201/// Post-process aligned results
202pub fn post_process_aligned_results(
203    results: &mut EvalResults,
204    config: &Arc<EvaluationConfig>,
205) -> Result<(), EvaluationError> {
206    results.aligned_results.par_iter_mut().for_each(|aligned| {
207        // Compute embedding means
208        for (target, values) in aligned.embeddings.iter() {
209            if let Some(mean) = compute_mean(values) {
210                aligned.mean_embeddings.insert(target.clone(), mean);
211            }
212        }
213
214        // Compute similarities
215        if config.compute_similarity {
216            compute_similarity(
217                &config.embedding_targets,
218                &aligned.embeddings,
219                &mut aligned.similarity_scores,
220            );
221        }
222    });
223
224    Ok(())
225}
226
227/// Helper function for extracting embedder and runtime from optional PyEmbedder
228/// # Arguments
229/// * `embedder` - Optional reference to a PyEmbedder instance.
230/// # Returns
231/// An optional Arc-wrapped Embedder instance if provided, otherwise None.
232pub fn parse_embedder(
233    embedder: Option<&Bound<'_, PyAny>>,
234) -> Result<Option<Arc<Embedder>>, EvaluationError> {
235    // Extract embedder and runtime if PyEmbedder is provided
236    let embedder_arc = if let Some(embedder_bound) = embedder {
237        if embedder_bound.is_instance_of::<PyEmbedder>() {
238            let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
239            Some(py_embedder.embedder.clone())
240        } else {
241            // embedder provided but not a PyEmbedder instance
242            return Err(EvaluationError::InvalidEmbedderType);
243        }
244    } else {
245        None
246    };
247    Ok(embedder_arc)
248}
249
250/// Calculate the mean of for a slice of f32 values
251/// There's no need for a generic implementation here, as we only need f32 for embeddings
252pub fn compute_mean(vec: &[f32]) -> Option<f64> {
253    match vec.len() {
254        0 => None,
255        _ => {
256            let sum = vec.iter().sum::<f32>();
257            let length = f32::from_usize(vec.len())?;
258
259            let mean = sum / length;
260            Some(mean as f64)
261        }
262    }
263}
264
265pub fn compute_similarity(
266    targets: &Vec<String>,
267    embeddings: &BTreeMap<String, Vec<f32>>,
268    scores: &mut BTreeMap<String, f64>,
269) {
270    for (a, b) in iproduct!(targets, targets) {
271        // only want unique pairs
272        if a == b {
273            continue;
274        }
275        if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
276            if vec_a.len() != vec_b.len() {
277                warn!(
278                    "Embedding length mismatch for targets {} and {}: {} vs {}",
279                    a,
280                    b,
281                    vec_a.len(),
282                    vec_b.len()
283                );
284                continue;
285            }
286
287            let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
288            let key = format!("{}_{}_cosine", a, b);
289            scores.insert(key, similarity);
290        } else {
291            warn!("Missing embeddings for targets {} or {}", a, b);
292        }
293    }
294}