1use crate::error::EvaluationError;
2use crate::types::{EvaluationConfig, LLMEvalRecord, LLMEvalResults, LLMEvalTaskResult};
3use itertools::iproduct;
4use num_traits::FromPrimitive;
5use potato_head::{
6 Embedder, EmbeddingInput, PyEmbedder, Score, StructuredOutput, TaskStatus, Workflow,
7 WorkflowError,
8};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use simsimd::SpatialSimilarity;
12use std::collections::BTreeMap;
13use std::sync::{Arc, RwLock};
14use tokio::task::JoinSet;
15use tracing::{error, warn};
16pub fn process_workflow_result(
18 workflow_result: Arc<RwLock<Workflow>>,
19) -> Result<BTreeMap<String, Score>, EvaluationError> {
20 let mut metrics = BTreeMap::new();
21
22 let workflow = workflow_result
23 .read()
24 .map_err(|_| WorkflowError::LockAcquireError)?;
25
26 let tasks = workflow.task_list.tasks();
27
28 for task in tasks.values() {
30 if let (TaskStatus::Completed, Some(result)) = (&task.status, &task.result) {
31 if let Some(content) = result.content() {
32 match Score::model_validate_json_str(&content) {
33 Ok(score) => {
34 metrics.insert(task.id.clone(), score);
35 }
36 Err(e) => {
37 error!("Failed to validate score for task {}: {:?}", task.id, e);
38 }
40 }
41 }
42 }
43 }
44
45 Ok(metrics)
46}
47
48pub async fn spawn_evaluation_tasks_without_embeddings(
57 workflow: Workflow,
58 records: Vec<LLMEvalRecord>,
59) -> JoinSet<(String, Option<LLMEvalTaskResult>)> {
60 let mut join_set = JoinSet::new();
61
62 for record in records {
63 let inner_workflow = workflow.clone();
64
65 join_set.spawn(async move {
66 let record_id = record.id.clone();
67
68 match inner_workflow.run(Some(record.context)).await {
69 Ok(workflow_result) => {
70 match process_workflow_result(workflow_result) {
72 Ok(metrics) => (
73 record_id.clone(),
74 Some(LLMEvalTaskResult::new(record_id, metrics, BTreeMap::new())),
75 ),
76 Err(error) => {
77 error!(
78 "Failed to process workflow result for record {}: {}",
79 record_id, error
80 );
81 (record_id, None)
82 }
83 }
84 }
85 Err(workflow_error) => {
86 error!(
87 "Workflow execution failed for record {}: {}",
88 record_id, workflow_error
89 );
90 (record_id, None)
91 }
92 }
93 });
94 }
95
96 join_set
97}
98
99pub async fn spawn_evaluation_tasks_with_embeddings(
108 workflow: Workflow,
109 records: Vec<LLMEvalRecord>,
110 embedder: Arc<Embedder>,
111 config: &Arc<EvaluationConfig>,
112) -> JoinSet<(String, Option<LLMEvalTaskResult>)> {
113 let mut join_set = JoinSet::new();
114
115 for record in records {
116 let inner_workflow = workflow.clone();
117 let cloned_embedder = embedder.clone();
118 let cloned_config = config.clone();
119
120 join_set.spawn(async move {
121 let record_id = record.id.clone();
122
123 let embeddings = generate_embeddings_for_record(
126 &record,
127 &cloned_embedder,
128 &cloned_config.embedding_targets,
129 )
130 .await;
131
132 match inner_workflow.run(Some(record.context)).await {
134 Ok(workflow_result) => {
135 match process_workflow_result(workflow_result) {
137 Ok(metrics) => (
138 record_id.clone(),
139 Some(LLMEvalTaskResult::new(record_id, metrics, embeddings)),
140 ),
141 Err(error) => {
142 error!(
143 "Failed to process workflow result for record {}: {}",
144 record_id, error
145 );
146 (record_id, None)
147 }
148 }
149 }
150 Err(workflow_error) => {
151 error!(
152 "Workflow execution failed for record {}: {}",
153 record_id, workflow_error
154 );
155 (record_id, None)
156 }
157 }
158 });
159 }
160
161 join_set
162}
163
164pub async fn generate_embeddings_for_record(
171 record: &LLMEvalRecord,
172 embedder: &Arc<Embedder>,
173 embedding_targets: &[String],
174) -> BTreeMap<String, Vec<f32>> {
175 let mut embeddings = BTreeMap::new();
176
177 for target in embedding_targets {
178 let texts = record
179 .context
180 .get(target)
181 .and_then(|v| v.as_str())
182 .map(|s| vec![s.to_string()]);
183
184 if let Some(texts) = texts {
185 match embedder.embed(EmbeddingInput::Texts(texts)).await {
186 Ok(embedding_response) => match embedding_response.values() {
187 Ok(values) => {
188 embeddings.insert(target.clone(), values.to_vec());
190 }
191 Err(e) => {
192 error!(
193 "Failed to extract embedding values for target {}: {:?}",
194 target, e
195 );
196 }
197 },
198 Err(e) => {
199 error!(
200 "Failed to generate embedding for target {}: {:?}",
201 target, e
202 );
203 }
204 }
205 } else {
206 warn!("No text found for embedding target: {}", target);
207 }
208 }
209
210 embeddings
211}
212
213pub async fn collect_evaluation_results(
215 mut join_set: JoinSet<(String, Option<LLMEvalTaskResult>)>,
216) -> Result<LLMEvalResults, EvaluationError> {
217 let mut eval_results = LLMEvalResults::new();
218
219 while let Some(join_result) = join_set.join_next().await {
220 match join_result {
221 Ok(task_result) => {
222 let (record_id, task_result_opt) = task_result;
223 if let Some(task_result) = task_result_opt {
224 eval_results.results.entry(record_id).or_insert(task_result);
225 } else {
226 eval_results.errored_tasks.push(record_id);
227 }
228 }
229 Err(join_error) => {
230 error!("Task join error: {:?}", join_error);
231 }
233 }
234 }
235
236 Ok(eval_results)
237}
238
239pub fn parse_embedder(
245 embedder: Option<&Bound<'_, PyAny>>,
246) -> Result<Option<Arc<Embedder>>, EvaluationError> {
247 let embedder_arc = if let Some(embedder_bound) = embedder {
249 if embedder_bound.is_instance_of::<PyEmbedder>() {
250 let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
251 Some(py_embedder.embedder.clone())
252 } else {
253 return Err(EvaluationError::InvalidEmbedderType);
255 }
256 } else {
257 None
258 };
259 Ok(embedder_arc)
260}
261
/// Computes the arithmetic mean of `vec`.
///
/// Every element is widened to `f64` before it is added, so both the sum and
/// the division are carried out in `f64`. Accumulating in `f32` and widening
/// only the final quotient would bake `f32` rounding error into the returned
/// `f64` (for example, `[16777216.0, 1.0, 1.0]` would lose both `1.0`s).
///
/// Returns `None` for an empty slice, since the mean is undefined there.
pub fn compute_mean(vec: &[f32]) -> Option<f64> {
    if vec.is_empty() {
        return None;
    }

    let sum: f64 = vec.iter().map(|&value| f64::from(value)).sum();

    // `usize -> f64` is exact for any length below 2^53, far beyond any slice
    // that can be held in memory.
    let count = vec.len() as f64;

    Some(sum / count)
}
276
277pub fn compute_similarity(
278 targets: &Vec<String>,
279 embeddings: &BTreeMap<String, Vec<f32>>,
280 scores: &mut BTreeMap<String, f64>,
281) {
282 for (a, b) in iproduct!(targets, targets) {
283 if a == b {
285 continue;
286 }
287 if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
288 if vec_a.len() != vec_b.len() {
289 warn!(
290 "Embedding length mismatch for targets {} and {}: {} vs {}",
291 a,
292 b,
293 vec_a.len(),
294 vec_b.len()
295 );
296 continue;
297 }
298
299 let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
300 let key = format!("{}_{}_cosine", a, b);
301 scores.insert(key, similarity);
302 } else {
303 warn!("Missing embeddings for targets {} or {}", a, b);
304 }
305 }
306}
307
308pub fn post_process(results: &mut LLMEvalResults, config: &Arc<EvaluationConfig>) {
309 results.results.par_iter_mut().for_each(|(_, task_result)| {
311 for (target, values) in task_result.embedding.iter() {
312 let mean = compute_mean(values).unwrap_or(-1.0);
313 task_result.mean_embeddings.insert(target.clone(), mean);
314 }
315 compute_similarity(
316 &config.embedding_targets,
317 &task_result.embedding,
318 &mut task_result.similarity_scores,
319 );
320 });
321}