1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvalResults, EvaluationConfig};
4use crate::genai::EvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::EvalSet;
12use scouter_types::EvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
20type EvalTaskResult = (
21 usize, Result<(EvalSet, BTreeMap<String, Vec<f32>>), String>,
23);
24
25pub async fn spawn_evaluation_tasks_without_embeddings(
34 dataset: &EvalDataset,
35 _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37 let mut join_set = JoinSet::new();
38
39 for (idx, _) in dataset.records.iter().enumerate() {
40 let record_ref = dataset.records.clone();
42 let profile_ref = dataset.profile.clone();
43 let spans_ref = dataset.spans.clone();
44
45 join_set.spawn(async move {
46 let record = &record_ref[idx];
48
49 debug!(
50 "Starting evaluation for record {} and index {}",
51 record.uid, idx
52 );
53
54 let result =
55 match GenAIEvaluator::process_event_record(record, profile_ref, spans_ref).await {
56 Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
57 Err(e) => Err(format!("Evaluation failed: {}", e)),
58 };
59
60 (idx, result)
61 });
62 }
63
64 join_set
65}
66
67pub async fn spawn_evaluation_tasks_with_embeddings(
75 dataset: &EvalDataset,
76 embedder: Arc<Embedder>,
77 config: &Arc<EvaluationConfig>,
78) -> JoinSet<EvalTaskResult> {
79 let mut join_set = JoinSet::new();
80
81 for (idx, _) in dataset.records.iter().enumerate() {
82 let record_ref = dataset.records.clone();
83 let profile_ref = dataset.profile.clone();
84 let spans_ref = dataset.spans.clone();
85 let embedder_ref = embedder.clone();
86 let config_ref = config.clone();
87
88 join_set.spawn(async move {
89 let record = &record_ref[idx];
90
91 let embeddings = generate_embeddings_for_record(
93 record,
94 &embedder_ref,
95 &config_ref.embedding_targets,
96 )
97 .await;
98
99 let result =
101 match GenAIEvaluator::process_event_record(record, profile_ref, spans_ref).await {
102 Ok(eval_set) => Ok((eval_set, embeddings)),
103 Err(e) => Err(format!("Evaluation failed: {}", e)),
104 };
105
106 (idx, result)
107 });
108 }
109
110 join_set
111}
112
113pub async fn generate_embeddings_for_record(
120 record: &EvalRecord,
121 embedder: &Arc<Embedder>,
122 embedding_targets: &[String],
123) -> BTreeMap<String, Vec<f32>> {
124 let mut embeddings = BTreeMap::new();
125
126 for target in embedding_targets {
127 match FieldEvaluator::extract_field_value(&record.context, target) {
128 Ok(value) => {
129 let text = match value {
130 Value::String(s) => Some(s.clone()),
131 Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
132 _ => {
133 warn!(
134 "Field '{}' has unsupported type for embedding: {:?}",
135 target, value
136 );
137 None
138 }
139 };
140
141 if let Some(text) = text {
142 match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
143 Ok(embedding_response) => match embedding_response.values() {
144 Ok(values) => {
145 embeddings.insert(target.clone(), values.to_vec());
146 }
147 Err(e) => {
148 error!(
149 "Failed to extract embedding values for target '{}': {:?}",
150 target, e
151 );
152 }
153 },
154 Err(e) => {
155 error!(
156 "Failed to generate embedding for target '{}': {:?}",
157 target, e
158 );
159 }
160 }
161 }
162 }
163 Err(e) => {
164 warn!("Failed to extract field '{}' for embedding: {}", target, e);
165 }
166 }
167 }
168
169 embeddings
170}
171pub async fn collect_and_align_results(
173 mut join_set: JoinSet<EvalTaskResult>,
174 records: &Arc<Vec<EvalRecord>>,
175) -> Result<EvalResults, EvaluationError> {
176 let mut results = EvalResults::new();
177
178 while let Some(join_result) = join_set.join_next().await {
179 match join_result {
180 Ok((idx, eval_result)) => {
181 let record = &records[idx];
182
183 match eval_result {
184 Ok((eval_set, embeddings)) => {
185 results.add_success(record, eval_set, embeddings);
186 }
187 Err(error_msg) => {
188 results.add_failure(record, error_msg);
189 }
190 }
191 }
192 Err(join_error) => {
193 error!("Task join error: {:?}", join_error);
194 }
195 }
196 }
197
198 Ok(results)
199}
200
201pub fn post_process_aligned_results(
203 results: &mut EvalResults,
204 config: &Arc<EvaluationConfig>,
205) -> Result<(), EvaluationError> {
206 results.aligned_results.par_iter_mut().for_each(|aligned| {
207 for (target, values) in aligned.embeddings.iter() {
209 if let Some(mean) = compute_mean(values) {
210 aligned.mean_embeddings.insert(target.clone(), mean);
211 }
212 }
213
214 if config.compute_similarity {
216 compute_similarity(
217 &config.embedding_targets,
218 &aligned.embeddings,
219 &mut aligned.similarity_scores,
220 );
221 }
222 });
223
224 Ok(())
225}
226
227pub fn parse_embedder(
233 embedder: Option<&Bound<'_, PyAny>>,
234) -> Result<Option<Arc<Embedder>>, EvaluationError> {
235 let embedder_arc = if let Some(embedder_bound) = embedder {
237 if embedder_bound.is_instance_of::<PyEmbedder>() {
238 let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
239 Some(py_embedder.embedder.clone())
240 } else {
241 return Err(EvaluationError::InvalidEmbedderType);
243 }
244 } else {
245 None
246 };
247 Ok(embedder_arc)
248}
249
250pub fn compute_mean(vec: &[f32]) -> Option<f64> {
253 match vec.len() {
254 0 => None,
255 _ => {
256 let sum = vec.iter().sum::<f32>();
257 let length = f32::from_usize(vec.len())?;
258
259 let mean = sum / length;
260 Some(mean as f64)
261 }
262 }
263}
264
265pub fn compute_similarity(
266 targets: &Vec<String>,
267 embeddings: &BTreeMap<String, Vec<f32>>,
268 scores: &mut BTreeMap<String, f64>,
269) {
270 for (a, b) in iproduct!(targets, targets) {
271 if a == b {
273 continue;
274 }
275 if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
276 if vec_a.len() != vec_b.len() {
277 warn!(
278 "Embedding length mismatch for targets {} and {}: {} vs {}",
279 a,
280 b,
281 vec_a.len(),
282 vec_b.len()
283 );
284 continue;
285 }
286
287 let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
288 let key = format!("{}_{}_cosine", a, b);
289 scores.insert(key, similarity);
290 } else {
291 warn!("Missing embeddings for targets {} or {}", a, b);
292 }
293 }
294}