scouter-types 0.25.0

Client and server contract for scouter
Documentation
use crate::error::TypeError;
use crate::genai::{
    AgentAssertionTask, AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask,
};
use pyo3::prelude::*;
use pyo3::types::PyList;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;

/// A collection of evaluation tasks, grouped by task kind.
///
/// Field declaration order is also the serde (de)serialization order and the
/// `Debug` output order.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct AssertionTasks {
    /// Plain assertion tasks ([`AssertionTask`]).
    pub assertion: Vec<AssertionTask>,
    /// LLM judge tasks ([`LLMJudgeTask`]).
    pub judge: Vec<LLMJudgeTask>,
    /// Trace assertion tasks ([`TraceAssertionTask`]).
    pub trace: Vec<TraceAssertionTask>,
    /// Agent assertion tasks ([`AgentAssertionTask`]).
    pub agent: Vec<AgentAssertionTask>,
}

impl AssertionTasks {
    /// Returns the IDs of every assertion, trace, and agent task.
    ///
    /// LLM judge tasks are deliberately excluded. The IDs are collected into a
    /// [`BTreeSet`], so the result is sorted. Duplicate IDs among these tasks are
    /// merged silently rather than reported.
    pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
        self.assertion
            .iter()
            .map(|t| t.id.clone())
            .chain(self.trace.iter().map(|t| t.id.clone()))
            .chain(self.agent.iter().map(|t| t.id.clone()))
            .collect()
    }

    /// Returns the IDs of every task of every kind: assertion, judge, trace, and agent.
    ///
    /// # Errors
    ///
    /// Returns [`TypeError::DuplicateTaskIds`] if the same ID is used by more than
    /// one task. This applies across task kinds too, e.g. an assertion task and a
    /// judge task sharing an ID.
    pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
        let all_ids = self
            .assertion
            .iter()
            .map(|t| &t.id)
            .chain(self.judge.iter().map(|t| &t.id))
            .chain(self.trace.iter().map(|t| &t.id))
            .chain(self.agent.iter().map(|t| &t.id));

        let mut task_ids = BTreeSet::new();
        for id in all_ids {
            // `insert` returns false when the ID is already present, so the first
            // duplicate ends the scan instead of being detected after the fact.
            if !task_ids.insert(id.clone()) {
                return Err(TypeError::DuplicateTaskIds);
            }
        }

        Ok(task_ids)
    }

    /// Splits the tasks of a [`TasksFile`] into per-kind vectors.
    ///
    /// Within each kind, tasks keep their order from `tasks.tasks`. Judge tasks
    /// are unboxed (`*t`) as they are moved into `judge`.
    pub fn from_tasks_file(tasks: TasksFile) -> Self {
        let mut grouped = AssertionTasks {
            assertion: Vec::new(),
            judge: Vec::new(),
            trace: Vec::new(),
            agent: Vec::new(),
        };

        for task in tasks.tasks {
            match task {
                TaskConfig::Assertion(t) => grouped.assertion.push(t),
                TaskConfig::LLMJudge(t) => grouped.judge.push(*t),
                TaskConfig::TraceAssertion(t) => grouped.trace.push(t),
                TaskConfig::AgentAssertion(t) => grouped.agent.push(t),
            }
        }

        grouped
    }
}

/// Extracts evaluation tasks from a Python list and groups them by kind.
///
/// Each element must be an instance of [`AssertionTask`], [`LLMJudgeTask`],
/// [`TraceAssertionTask`], or [`AgentAssertionTask`]. Each element is placed in the
/// matching field of the returned [`AssertionTasks`], and list order is kept
/// within each kind. Task IDs are not checked for duplicates here.
///
/// # Errors
///
/// - [`TypeError::InvalidAssertionTaskType`] if an element is not one of the four
///   task types. Extraction stops at the first such element.
/// - Any error raised while extracting an element into its Rust type, converted
///   into [`TypeError`] via `?`.
pub fn extract_assertion_tasks_from_pylist(
    list: &Bound<'_, PyList>,
) -> Result<AssertionTasks, TypeError> {
    let mut tasks = AssertionTasks {
        assertion: Vec::new(),
        judge: Vec::new(),
        trace: Vec::new(),
        agent: Vec::new(),
    };

    for item in list.iter() {
        if item.is_instance_of::<AssertionTask>() {
            tasks.assertion.push(item.extract::<AssertionTask>()?);
        } else if item.is_instance_of::<LLMJudgeTask>() {
            tasks.judge.push(item.extract::<LLMJudgeTask>()?);
        } else if item.is_instance_of::<TraceAssertionTask>() {
            tasks.trace.push(item.extract::<TraceAssertionTask>()?);
        } else if item.is_instance_of::<AgentAssertionTask>() {
            tasks.agent.push(item.extract::<AgentAssertionTask>()?);
        } else {
            return Err(TypeError::InvalidAssertionTaskType);
        }
    }

    Ok(tasks)
}