Skip to main content

scouter_evaluate/
utils.rs

1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
4use crate::genai::GenAIEvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::GenAIEvalSet;
12use scouter_types::GenAIEvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
/// Output produced by a single spawned evaluation task.
///
/// The first element is the index of the evaluated record, used by the collector
/// to look the record back up. The second element is either the evaluation set
/// together with the per-target embeddings computed for that record (an empty map
/// when embeddings are not generated), or a message describing why the evaluation failed.
type EvalTaskResult = (
    usize, // Index into records array
    Result<(GenAIEvalSet, BTreeMap<String, Vec<f32>>), String>,
);
24
25/// Spawn tasks without embedding support
26/// This function will spawn a task that runs the workflows and extracts the results
27/// If there is an error during workflow execution, it will log the error and return None for that record
28/// # Arguments
29/// * `workflow` - The workflow to execute for each record.
30/// * `records` - The list of GenAIEvalRecords to process.
31/// # Returns
32/// A JoinSet containing tuples of record ID and optional GenAIEvalTaskResult.
33pub async fn spawn_evaluation_tasks_without_embeddings(
34    dataset: &GenAIEvalDataset,
35    _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37    let mut join_set = JoinSet::new();
38
39    for (idx, _) in dataset.records.iter().enumerate() {
40        // cloning here so we can reference inside async move
41        let record_ref = dataset.records.clone();
42        let profile_ref = dataset.profile.clone();
43
44        join_set.spawn(async move {
45            // Access record by index - no cloning
46            let record = &record_ref[idx];
47
48            debug!(
49                "Starting evaluation for record {} and index {}",
50                record.uid, idx
51            );
52
53            let result =
54                match GenAIEvaluator::process_event_record(record, profile_ref, Arc::new(vec![]))
55                    .await
56                {
57                    Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
58                    Err(e) => Err(format!("Evaluation failed: {}", e)),
59                };
60
61            (idx, result)
62        });
63    }
64
65    join_set
66}
67
68/// Spawn tasks to run evaluation workflows with embedding calculations
69/// # Arguments
70/// * `dataset` - The GenAIEvalDataset containing records to evaluate.
71/// * `embedder` - The Embedder instance to use for generating embeddings.
72/// * `config` - The EvaluationConfig containing evaluation settings.
73/// # Returns
74/// A JoinSet containing GenAIEvalTaskResults for each record.
75pub async fn spawn_evaluation_tasks_with_embeddings(
76    dataset: &GenAIEvalDataset,
77    embedder: Arc<Embedder>,
78    config: &Arc<EvaluationConfig>,
79) -> JoinSet<EvalTaskResult> {
80    let mut join_set = JoinSet::new();
81
82    for (idx, _) in dataset.records.iter().enumerate() {
83        let record_ref = dataset.records.clone();
84        let profile_ref = dataset.profile.clone();
85        let embedder_ref = embedder.clone();
86        let config_ref = config.clone();
87
88        join_set.spawn(async move {
89            let record = &record_ref[idx];
90
91            // Generate embeddings
92            let embeddings = generate_embeddings_for_record(
93                record,
94                &embedder_ref,
95                &config_ref.embedding_targets,
96            )
97            .await;
98
99            // Execute evaluation
100            let result =
101                match GenAIEvaluator::process_event_record(record, profile_ref, Arc::new(vec![]))
102                    .await
103                {
104                    Ok(eval_set) => Ok((eval_set, embeddings)),
105                    Err(e) => Err(format!("Evaluation failed: {}", e)),
106                };
107
108            (idx, result)
109        });
110    }
111
112    join_set
113}
114
115/// Helper for extracting embeddings for a single record. Used in the genai evaulation workflow.
116/// # Arguments
117/// * `record` - The GenAIEvalRecord to extract embeddings from.
118/// * `embedder` - The Embedder instance to use for generating embeddings.
119/// * `embedding_targets` - The list of keys in the record's context to generate embeddings for.
120/// # Returns
121pub async fn generate_embeddings_for_record(
122    record: &GenAIEvalRecord,
123    embedder: &Arc<Embedder>,
124    embedding_targets: &[String],
125) -> BTreeMap<String, Vec<f32>> {
126    let mut embeddings = BTreeMap::new();
127
128    for target in embedding_targets {
129        match FieldEvaluator::extract_field_value(&record.context, target) {
130            Ok(value) => {
131                let text = match value {
132                    Value::String(s) => Some(s.clone()),
133                    Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
134                    _ => {
135                        warn!(
136                            "Field '{}' has unsupported type for embedding: {:?}",
137                            target, value
138                        );
139                        None
140                    }
141                };
142
143                if let Some(text) = text {
144                    match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
145                        Ok(embedding_response) => match embedding_response.values() {
146                            Ok(values) => {
147                                embeddings.insert(target.clone(), values.to_vec());
148                            }
149                            Err(e) => {
150                                error!(
151                                    "Failed to extract embedding values for target '{}': {:?}",
152                                    target, e
153                                );
154                            }
155                        },
156                        Err(e) => {
157                            error!(
158                                "Failed to generate embedding for target '{}': {:?}",
159                                target, e
160                            );
161                        }
162                    }
163                }
164            }
165            Err(e) => {
166                warn!("Failed to extract field '{}' for embedding: {}", target, e);
167            }
168        }
169    }
170
171    embeddings
172}
173/// Collect and align results with original records
174pub async fn collect_and_align_results(
175    mut join_set: JoinSet<EvalTaskResult>,
176    records: &Arc<Vec<GenAIEvalRecord>>,
177) -> Result<GenAIEvalResults, EvaluationError> {
178    let mut results = GenAIEvalResults::new();
179
180    while let Some(join_result) = join_set.join_next().await {
181        match join_result {
182            Ok((idx, eval_result)) => {
183                let record = &records[idx];
184
185                match eval_result {
186                    Ok((eval_set, embeddings)) => {
187                        results.add_success(record, eval_set, embeddings);
188                    }
189                    Err(error_msg) => {
190                        results.add_failure(record, error_msg);
191                    }
192                }
193            }
194            Err(join_error) => {
195                error!("Task join error: {:?}", join_error);
196            }
197        }
198    }
199
200    Ok(results)
201}
202
203/// Post-process aligned results
204pub fn post_process_aligned_results(
205    results: &mut GenAIEvalResults,
206    config: &Arc<EvaluationConfig>,
207) -> Result<(), EvaluationError> {
208    results.aligned_results.par_iter_mut().for_each(|aligned| {
209        // Compute embedding means
210        for (target, values) in aligned.embeddings.iter() {
211            if let Some(mean) = compute_mean(values) {
212                aligned.mean_embeddings.insert(target.clone(), mean);
213            }
214        }
215
216        // Compute similarities
217        if config.compute_similarity {
218            compute_similarity(
219                &config.embedding_targets,
220                &aligned.embeddings,
221                &mut aligned.similarity_scores,
222            );
223        }
224    });
225
226    Ok(())
227}
228
229/// Helper function for extracting embedder and runtime from optional PyEmbedder
230/// # Arguments
231/// * `embedder` - Optional reference to a PyEmbedder instance.
232/// # Returns
233/// An optional Arc-wrapped Embedder instance if provided, otherwise None.
234pub fn parse_embedder(
235    embedder: Option<&Bound<'_, PyAny>>,
236) -> Result<Option<Arc<Embedder>>, EvaluationError> {
237    // Extract embedder and runtime if PyEmbedder is provided
238    let embedder_arc = if let Some(embedder_bound) = embedder {
239        if embedder_bound.is_instance_of::<PyEmbedder>() {
240            let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
241            Some(py_embedder.embedder.clone())
242        } else {
243            // embedder provided but not a PyEmbedder instance
244            return Err(EvaluationError::InvalidEmbedderType);
245        }
246    } else {
247        None
248    };
249    Ok(embedder_arc)
250}
251
252/// Calculate the mean of for a slice of f32 values
253/// There's no need for a generic implementation here, as we only need f32 for embeddings
254pub fn compute_mean(vec: &[f32]) -> Option<f64> {
255    match vec.len() {
256        0 => None,
257        _ => {
258            let sum = vec.iter().sum::<f32>();
259            let length = f32::from_usize(vec.len())?;
260
261            let mean = sum / length;
262            Some(mean as f64)
263        }
264    }
265}
266
267pub fn compute_similarity(
268    targets: &Vec<String>,
269    embeddings: &BTreeMap<String, Vec<f32>>,
270    scores: &mut BTreeMap<String, f64>,
271) {
272    for (a, b) in iproduct!(targets, targets) {
273        // only want unique pairs
274        if a == b {
275            continue;
276        }
277        if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
278            if vec_a.len() != vec_b.len() {
279                warn!(
280                    "Embedding length mismatch for targets {} and {}: {} vs {}",
281                    a,
282                    b,
283                    vec_a.len(),
284                    vec_b.len()
285                );
286                continue;
287            }
288
289            let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
290            let key = format!("{}_{}_cosine", a, b);
291            scores.insert(key, similarity);
292        } else {
293            warn!("Missing embeddings for targets {} or {}", a, b);
294        }
295    }
296}