use crate::error::TypeError;
use crate::genai::{
AgentAssertionTask, AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask,
};
use pyo3::prelude::*;
use pyo3::types::PyList;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
/// Evaluation tasks grouped by task kind.
///
/// Built either from a parsed [`TasksFile`] (see `AssertionTasks::from_tasks_file`) or from a
/// Python list (see `extract_assertion_tasks_from_pylist`). Within each vector, tasks keep the
/// relative order in which they appeared in the source collection.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct AssertionTasks {
/// Plain assertion tasks (`TaskConfig::Assertion` / Python `AssertionTask`).
pub assertion: Vec<AssertionTask>,
/// LLM-as-judge tasks (`TaskConfig::LLMJudge` / Python `LLMJudgeTask`).
pub judge: Vec<LLMJudgeTask>,
/// Trace assertion tasks (`TaskConfig::TraceAssertion` / Python `TraceAssertionTask`).
pub trace: Vec<TraceAssertionTask>,
/// Agent assertion tasks (`TaskConfig::AgentAssertion` / Python `AgentAssertionTask`).
pub agent: Vec<AgentAssertionTask>,
}
impl AssertionTasks {
    /// Returns the IDs of every assertion, trace and agent task (judge tasks are excluded).
    ///
    /// IDs are returned in sorted order. Repeated IDs collapse into a single entry; no
    /// duplicate check is performed here.
    pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
        let assertion_ids = self.assertion.iter().map(|task| &task.id);
        let trace_ids = self.trace.iter().map(|task| &task.id);
        let agent_ids = self.agent.iter().map(|task| &task.id);

        assertion_ids
            .chain(trace_ids)
            .chain(agent_ids)
            .cloned()
            .collect()
    }

    /// Returns the sorted set of IDs across all four task kinds, including judge tasks.
    ///
    /// # Errors
    ///
    /// Returns [`TypeError::DuplicateTaskIds`] if any ID occurs more than once, whether
    /// within one kind or across different kinds.
    pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
        let every_id = self
            .assertion
            .iter()
            .map(|task| &task.id)
            .chain(self.judge.iter().map(|task| &task.id))
            .chain(self.trace.iter().map(|task| &task.id))
            .chain(self.agent.iter().map(|task| &task.id));

        let mut seen = BTreeSet::new();
        for id in every_id {
            // `insert` returns false when the ID is already present, i.e. a duplicate.
            if !seen.insert(id.clone()) {
                return Err(TypeError::DuplicateTaskIds);
            }
        }
        Ok(seen)
    }

    /// Partitions the tasks of a parsed [`TasksFile`] by their [`TaskConfig`] variant.
    ///
    /// Each task is moved into the vector for its kind, preserving file order within
    /// each kind. Boxed judge tasks are unboxed.
    pub fn from_tasks_file(tasks: TasksFile) -> Self {
        let mut grouped = AssertionTasks {
            assertion: Vec::new(),
            judge: Vec::new(),
            trace: Vec::new(),
            agent: Vec::new(),
        };

        for config in tasks.tasks {
            match config {
                TaskConfig::Assertion(task) => grouped.assertion.push(task),
                TaskConfig::LLMJudge(boxed) => grouped.judge.push(*boxed),
                TaskConfig::TraceAssertion(task) => grouped.trace.push(task),
                TaskConfig::AgentAssertion(task) => grouped.agent.push(task),
            }
        }

        grouped
    }
}
/// Sorts the items of a Python list into an [`AssertionTasks`] by their Python type.
///
/// Each item is tested with `is_instance_of` (Python `isinstance` semantics) against
/// `AssertionTask`, `LLMJudgeTask`, `TraceAssertionTask` and `AgentAssertionTask`, in that
/// order. The first match is extracted (cloned out of the Python object) into the
/// corresponding vector. List order is preserved within each kind.
///
/// # Errors
///
/// - [`TypeError::InvalidAssertionTaskType`] as soon as an item matches none of the four types.
/// - Any extraction failure is propagated via `?`. This presumably relies on a
///   `From<PyErr> for TypeError` conversion defined elsewhere; confirm there.
pub fn extract_assertion_tasks_from_pylist(
    list: &Bound<'_, PyList>,
) -> Result<AssertionTasks, TypeError> {
    let mut collected = AssertionTasks {
        assertion: Vec::new(),
        judge: Vec::new(),
        trace: Vec::new(),
        agent: Vec::new(),
    };

    for element in list.iter() {
        if element.is_instance_of::<AssertionTask>() {
            collected
                .assertion
                .push(element.extract::<AssertionTask>()?);
        } else if element.is_instance_of::<LLMJudgeTask>() {
            collected.judge.push(element.extract::<LLMJudgeTask>()?);
        } else if element.is_instance_of::<TraceAssertionTask>() {
            collected
                .trace
                .push(element.extract::<TraceAssertionTask>()?);
        } else if element.is_instance_of::<AgentAssertionTask>() {
            collected
                .agent
                .push(element.extract::<AgentAssertionTask>()?);
        } else {
            return Err(TypeError::InvalidAssertionTaskType);
        }
    }

    Ok(collected)
}