1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
4use crate::genai::GenAIEvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::GenAIEvalSet;
12use scouter_types::GenAIEvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
/// Output produced by a single spawned evaluation task.
///
/// The `usize` is the index of the evaluated record within the dataset's
/// record list. The `Result` is `Ok((eval_set, embeddings))` on success, where
/// `embeddings` maps an embedding target name to its embedding vector, or
/// `Err(message)` carrying a formatted failure description.
type EvalTaskResult = (
    usize, Result<(GenAIEvalSet, BTreeMap<String, Vec<f32>>), String>,
);
24
25pub async fn spawn_evaluation_tasks_without_embeddings(
34 dataset: &GenAIEvalDataset,
35 _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37 let mut join_set = JoinSet::new();
38
39 for (idx, _) in dataset.records.iter().enumerate() {
40 let record_ref = dataset.records.clone();
42 let profile_ref = dataset.profile.clone();
43
44 join_set.spawn(async move {
45 let record = &record_ref[idx];
47
48 debug!(
49 "Starting evaluation for record {} and index {}",
50 record.uid, idx
51 );
52
53 let result =
54 match GenAIEvaluator::process_event_record(record, profile_ref, Arc::new(vec![]))
55 .await
56 {
57 Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
58 Err(e) => Err(format!("Evaluation failed: {}", e)),
59 };
60
61 (idx, result)
62 });
63 }
64
65 join_set
66}
67
68pub async fn spawn_evaluation_tasks_with_embeddings(
76 dataset: &GenAIEvalDataset,
77 embedder: Arc<Embedder>,
78 config: &Arc<EvaluationConfig>,
79) -> JoinSet<EvalTaskResult> {
80 let mut join_set = JoinSet::new();
81
82 for (idx, _) in dataset.records.iter().enumerate() {
83 let record_ref = dataset.records.clone();
84 let profile_ref = dataset.profile.clone();
85 let embedder_ref = embedder.clone();
86 let config_ref = config.clone();
87
88 join_set.spawn(async move {
89 let record = &record_ref[idx];
90
91 let embeddings = generate_embeddings_for_record(
93 record,
94 &embedder_ref,
95 &config_ref.embedding_targets,
96 )
97 .await;
98
99 let result =
101 match GenAIEvaluator::process_event_record(record, profile_ref, Arc::new(vec![]))
102 .await
103 {
104 Ok(eval_set) => Ok((eval_set, embeddings)),
105 Err(e) => Err(format!("Evaluation failed: {}", e)),
106 };
107
108 (idx, result)
109 });
110 }
111
112 join_set
113}
114
115pub async fn generate_embeddings_for_record(
122 record: &GenAIEvalRecord,
123 embedder: &Arc<Embedder>,
124 embedding_targets: &[String],
125) -> BTreeMap<String, Vec<f32>> {
126 let mut embeddings = BTreeMap::new();
127
128 for target in embedding_targets {
129 match FieldEvaluator::extract_field_value(&record.context, target) {
130 Ok(value) => {
131 let text = match value {
132 Value::String(s) => Some(s.clone()),
133 Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
134 _ => {
135 warn!(
136 "Field '{}' has unsupported type for embedding: {:?}",
137 target, value
138 );
139 None
140 }
141 };
142
143 if let Some(text) = text {
144 match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
145 Ok(embedding_response) => match embedding_response.values() {
146 Ok(values) => {
147 embeddings.insert(target.clone(), values.to_vec());
148 }
149 Err(e) => {
150 error!(
151 "Failed to extract embedding values for target '{}': {:?}",
152 target, e
153 );
154 }
155 },
156 Err(e) => {
157 error!(
158 "Failed to generate embedding for target '{}': {:?}",
159 target, e
160 );
161 }
162 }
163 }
164 }
165 Err(e) => {
166 warn!("Failed to extract field '{}' for embedding: {}", target, e);
167 }
168 }
169 }
170
171 embeddings
172}
173pub async fn collect_and_align_results(
175 mut join_set: JoinSet<EvalTaskResult>,
176 records: &Arc<Vec<GenAIEvalRecord>>,
177) -> Result<GenAIEvalResults, EvaluationError> {
178 let mut results = GenAIEvalResults::new();
179
180 while let Some(join_result) = join_set.join_next().await {
181 match join_result {
182 Ok((idx, eval_result)) => {
183 let record = &records[idx];
184
185 match eval_result {
186 Ok((eval_set, embeddings)) => {
187 results.add_success(record, eval_set, embeddings);
188 }
189 Err(error_msg) => {
190 results.add_failure(record, error_msg);
191 }
192 }
193 }
194 Err(join_error) => {
195 error!("Task join error: {:?}", join_error);
196 }
197 }
198 }
199
200 Ok(results)
201}
202
203pub fn post_process_aligned_results(
205 results: &mut GenAIEvalResults,
206 config: &Arc<EvaluationConfig>,
207) -> Result<(), EvaluationError> {
208 results.aligned_results.par_iter_mut().for_each(|aligned| {
209 for (target, values) in aligned.embeddings.iter() {
211 if let Some(mean) = compute_mean(values) {
212 aligned.mean_embeddings.insert(target.clone(), mean);
213 }
214 }
215
216 if config.compute_similarity {
218 compute_similarity(
219 &config.embedding_targets,
220 &aligned.embeddings,
221 &mut aligned.similarity_scores,
222 );
223 }
224 });
225
226 Ok(())
227}
228
229pub fn parse_embedder(
235 embedder: Option<&Bound<'_, PyAny>>,
236) -> Result<Option<Arc<Embedder>>, EvaluationError> {
237 let embedder_arc = if let Some(embedder_bound) = embedder {
239 if embedder_bound.is_instance_of::<PyEmbedder>() {
240 let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
241 Some(py_embedder.embedder.clone())
242 } else {
243 return Err(EvaluationError::InvalidEmbedderType);
245 }
246 } else {
247 None
248 };
249 Ok(embedder_arc)
250}
251
/// Computes the arithmetic mean of `vec`.
///
/// The elements are widened to `f64` *before* being summed, so the result is
/// not degraded by `f32` rounding during accumulation (an `f32` running sum
/// silently drops small addends once the total grows large).
///
/// # Arguments
/// * `vec` - Values to average.
///
/// # Returns
/// `None` when `vec` is empty, otherwise `Some(mean)`.
pub fn compute_mean(vec: &[f32]) -> Option<f64> {
    if vec.is_empty() {
        return None;
    }

    let sum: f64 = vec.iter().map(|&value| f64::from(value)).sum();

    // usize -> f64 is exact for any length below 2^53, far beyond any slice
    // that fits in memory.
    Some(sum / vec.len() as f64)
}
266
267pub fn compute_similarity(
268 targets: &Vec<String>,
269 embeddings: &BTreeMap<String, Vec<f32>>,
270 scores: &mut BTreeMap<String, f64>,
271) {
272 for (a, b) in iproduct!(targets, targets) {
273 if a == b {
275 continue;
276 }
277 if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
278 if vec_a.len() != vec_b.len() {
279 warn!(
280 "Embedding length mismatch for targets {} and {}: {} vs {}",
281 a,
282 b,
283 vec_a.len(),
284 vec_b.len()
285 );
286 continue;
287 }
288
289 let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
290 let key = format!("{}_{}_cosine", a, b);
291 scores.insert(key, similarity);
292 } else {
293 warn!("Missing embeddings for targets {} or {}", a, b);
294 }
295 }
296}