use crate::error::TypeError;
use crate::genai::traits::TaskAccessor;
use crate::PyHelperFuncs;
use core::fmt::Debug;
use potato_head::prompt_types::Prompt;
use potato_head::Provider;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString};
use pyo3::IntoPyObjectExt;
use pythonize::{depythonize, pythonize};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Display;
use std::path::PathBuf;
use std::str::FromStr;
use tracing::error;
pub fn deserialize_from_path<T: DeserializeOwned>(path: PathBuf) -> Result<T, TypeError> {
let content = std::fs::read_to_string(&path)?;
let extension = path
.extension()
.and_then(|ext| ext.to_str())
.ok_or_else(|| TypeError::Error(format!("Invalid file path: {:?}", path)))?;
let item = match extension.to_lowercase().as_str() {
"json" => serde_json::from_str(&content)?,
"yaml" | "yml" => serde_yaml::from_str(&content)?,
_ => {
return Err(TypeError::Error(format!(
"Unsupported file extension '{}'. Expected .json, .yaml, or .yml",
extension
)))
}
};
Ok(item)
}
/// Serde default for `AssertionTask::task_type`; always yields
/// `EvaluationTaskType::Assertion`.
fn default_assertion_task_type() -> EvaluationTaskType {
EvaluationTaskType::Assertion
}
/// Serde default for `TraceAssertionTask::task_type`; always yields
/// `EvaluationTaskType::TraceAssertion`.
fn default_trace_assertion_task_type() -> EvaluationTaskType {
EvaluationTaskType::TraceAssertion
}
/// Serde default for `AgentAssertionTask::task_type`; always yields
/// `EvaluationTaskType::AgentAssertion`.
fn default_agent_assertion_task_type() -> EvaluationTaskType {
EvaluationTaskType::AgentAssertion
}
/// Outcome of evaluating one assertion, exposed to Python as a class.
///
/// `passed` and `message` get auto-generated Python getters; `actual` and
/// `expected` are stored as JSON values and have no `#[pyo3(get)]` attribute.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResult {
/// Whether the assertion passed.
#[pyo3(get)]
pub passed: bool,
/// The observed value, stored as JSON.
pub actual: Value,
/// Message attached to this result.
#[pyo3(get)]
pub message: String,
/// The expected value, stored as JSON.
pub expected: Value,
}
impl AssertionResult {
pub fn new(passed: bool, actual: Value, message: String, expected: Value) -> Self {
Self {
passed,
actual,
message,
expected,
}
}
pub fn to_metric_value(&self) -> f64 {
if self.passed {
1.0
} else {
0.0
}
}
}
#[pymethods]
impl AssertionResult {
pub fn __str__(&self) -> String {
PyHelperFuncs::__str__(self)
}
#[getter]
pub fn get_actual<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.actual)?;
Ok(py_value)
}
#[getter]
pub fn get_expected<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.expected)?;
Ok(py_value)
}
}
/// A set of assertion results keyed by string, exposed to Python as a class.
// NOTE(review): keys are presumably task ids — confirm against the code that fills this map.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionResults {
/// Map from key to the corresponding assertion result.
#[pyo3(get)]
pub results: HashMap<String, AssertionResult>,
}
#[pymethods]
impl AssertionResults {
    /// Python `__str__`: delegates to the shared `PyHelperFuncs::__str__` helper.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    /// Python `results[key]`: returns a clone of the stored result.
    ///
    /// # Errors
    ///
    /// Returns `TypeError::KeyNotFound` carrying `key` when no entry exists.
    pub fn __getitem__(&self, key: &str) -> Result<AssertionResult, TypeError> {
        self.results
            .get(key)
            .cloned()
            .ok_or_else(|| TypeError::KeyNotFound {
                key: key.to_string(),
            })
    }
}
/// An assertion task: compares a value against `expected_value` using `operator`.
///
/// Exposed to Python as a class and (de)serializable with serde. Fields marked
/// `#[serde(default)]` may be omitted from serialized input.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssertionTask {
/// Task identifier. The Python constructor stores it lowercased.
#[pyo3(get, set)]
pub id: String,
/// Optional context path.
// NOTE(review): presumably locates the value under test inside the evaluation context — confirm.
#[pyo3(get, set)]
#[serde(default)]
pub context_path: Option<String>,
/// Optional item-level context path.
// NOTE(review): exact semantics are not visible in this file — confirm with the evaluator.
#[pyo3(get, set)]
#[serde(default)]
pub item_context_path: Option<String>,
/// Comparison operator applied by this task.
#[pyo3(get, set)]
pub operator: ComparisonOperator,
/// Expected value, stored as JSON (exposed to Python through an explicit getter).
pub expected_value: Value,
/// Optional free-text description.
#[pyo3(get, set)]
#[serde(default)]
pub description: Option<String>,
/// Ids of tasks this task depends on; empty when omitted.
#[pyo3(get, set)]
#[serde(default)]
pub depends_on: Vec<String>,
/// Task kind; defaults to `EvaluationTaskType::Assertion` when absent from input.
#[serde(default = "default_assertion_task_type")]
#[pyo3(get)]
pub task_type: EvaluationTaskType,
/// Result attached after evaluation; left out of serialized output while `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<AssertionResult>,
/// Boolean flag; `false` when absent from input.
// NOTE(review): presumably marks the task as a conditional gate — confirm.
#[serde(default)]
pub condition: bool,
}
#[pymethods]
impl AssertionTask {
#[new]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (id, context_path, expected_value, operator, item_context_path=None, description=None, depends_on=None, condition=None))]
pub fn new(
id: String,
context_path: Option<String>,
expected_value: &Bound<'_, PyAny>,
operator: ComparisonOperator,
item_context_path: Option<String>,
description: Option<String>,
depends_on: Option<Vec<String>>,
condition: Option<bool>,
) -> Result<Self, TypeError> {
let expected_value = depythonize(expected_value)?;
let condition = condition.unwrap_or(false);
Ok(Self {
id: id.to_lowercase(),
context_path,
item_context_path,
operator,
expected_value,
description,
task_type: EvaluationTaskType::Assertion,
depends_on: depends_on.unwrap_or_default(),
result: None,
condition,
})
}
#[getter]
pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.expected_value)?;
Ok(py_value)
}
#[staticmethod]
pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
deserialize_from_path(path)
}
}
// NOTE(review): this inherent impl is empty and contributes nothing; it looks
// like a leftover placeholder for Rust-only helpers — confirm and consider removing.
impl AssertionTask {}
/// Read access to the task's fields plus result attachment, via `TaskAccessor`.
impl TaskAccessor for AssertionTask {
    /// Borrowed view of `context_path`, if set.
    fn context_path(&self) -> Option<&str> {
        self.context_path.as_deref()
    }

    /// Borrowed view of `item_context_path`, if set.
    fn item_context_path(&self) -> Option<&str> {
        self.item_context_path.as_deref()
    }

    /// The task id.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The comparison operator.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// The task kind.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// The expected JSON value.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Ids this task depends on.
    fn depends_on(&self) -> &[String] {
        self.depends_on.as_slice()
    }

    /// Stores `result`, overwriting any previous one.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
/// Convenience queries on a JSON-like value.
pub trait ValueExt {
/// Length of the value, or `None` when the value has no notion of length.
fn to_length(&self) -> Option<i64>;
/// The value as an `f64`, or `None` when it is not numeric.
fn as_numeric(&self) -> Option<f64>;
/// Whether the value counts as "true" in a boolean context.
fn is_truthy(&self) -> bool;
}
impl ValueExt for Value {
    /// Element count for arrays, Unicode scalar count for strings, entry count
    /// for objects; `None` for every other variant.
    fn to_length(&self) -> Option<i64> {
        let len = match self {
            Value::Array(items) => items.len(),
            Value::String(text) => text.chars().count(),
            Value::Object(entries) => entries.len(),
            _ => return None,
        };
        Some(len as i64)
    }

    /// `Some(f64)` for JSON numbers that convert to `f64`; `None` otherwise.
    fn as_numeric(&self) -> Option<f64> {
        if let Value::Number(num) = self {
            num.as_f64()
        } else {
            None
        }
    }

    /// Truthiness rules: `null` is false, booleans are themselves, numbers are
    /// false only when they convert to exactly `0.0`, and strings, arrays and
    /// objects are true when non-empty.
    fn is_truthy(&self) -> bool {
        match self {
            Value::Null => false,
            Value::Bool(flag) => *flag,
            // A number that does not convert to f64 is treated as truthy.
            Value::Number(num) => num.as_f64().map_or(true, |v| v != 0.0),
            Value::String(text) => !text.is_empty(),
            Value::Array(items) => !items.is_empty(),
            Value::Object(entries) => !entries.is_empty(),
        }
    }
}
/// An LLM-judge task: carries a prompt plus the comparison to apply.
///
/// Only `Serialize` is derived here; deserialization is provided by a
/// hand-written `Deserialize` impl elsewhere in this file.
#[pyclass]
#[derive(Debug, Serialize, Clone, PartialEq)]
pub struct LLMJudgeTask {
/// Task identifier. Constructors in this file store it lowercased.
#[pyo3(get, set)]
pub id: String,
/// Prompt associated with this task.
#[pyo3(get)]
pub prompt: Prompt,
/// Optional context path.
// NOTE(review): presumably locates the value under test inside the evaluation context — confirm.
#[pyo3(get)]
#[serde(default)]
pub context_path: Option<String>,
/// Expected value, stored as JSON (exposed to Python through an explicit getter).
pub expected_value: Value,
/// Comparison operator applied by this task.
#[pyo3(get)]
pub operator: ComparisonOperator,
/// Task kind; constructors in this file set it to `EvaluationTaskType::LLMJudge`.
#[pyo3(get)]
pub task_type: EvaluationTaskType,
/// Ids of tasks this task depends on.
#[pyo3(get, set)]
#[serde(default)]
pub depends_on: Vec<String>,
/// Retry budget; constructors in this file fill in `Some(3)` when none is given.
#[pyo3(get, set)]
#[serde(default)]
pub max_retries: Option<u32>,
/// Result attached after evaluation; left out of serialized output while `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<AssertionResult>,
/// Optional free-text description.
#[serde(default)]
pub description: Option<String>,
/// Boolean flag.
// NOTE(review): presumably marks the task as a conditional gate — confirm.
#[pyo3(get, set)]
#[serde(default)]
pub condition: bool,
}
/// How a prompt is supplied in a task configuration.
///
/// Untagged: serde tries the variants in declaration order, so a map with a
/// `path` string is tried as `Path` before falling back to an inline `Prompt`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum PromptConfig {
/// Reference to a prompt file on disk.
Path { path: String },
/// Prompt embedded directly in the configuration.
Inline(Box<Prompt>),
}
#[derive(Debug, Deserialize)]
struct LLMJudgeTaskConfig {
pub id: String,
pub prompt: PromptConfig,
pub expected_value: Value,
pub operator: ComparisonOperator,
pub context_path: Option<String>,
pub description: Option<String>,
pub depends_on: Vec<String>,
pub max_retries: Option<u32>,
pub condition: bool,
}
impl LLMJudgeTaskConfig {
pub fn into_task(self) -> Result<LLMJudgeTask, TypeError> {
let prompt = match self.prompt {
PromptConfig::Path { path } => {
Prompt::from_path(PathBuf::from(path)).inspect_err(|e| {
error!("Failed to deserialize Prompt from path: {}", e);
})?
}
PromptConfig::Inline(prompt) => *prompt,
};
Ok(LLMJudgeTask {
id: self.id.to_lowercase(),
prompt,
expected_value: self.expected_value,
operator: self.operator,
context_path: self.context_path,
description: self.description,
depends_on: self.depends_on,
max_retries: self.max_retries.or(Some(3)),
task_type: EvaluationTaskType::LLMJudge,
result: None,
condition: self.condition,
})
}
}
/// Deserialization mirror of `LLMJudgeTask`, with the same field names and an
/// already-resolved `Prompt`. Used as the `Full` variant of `LLMJudgeFormat`.
#[derive(Debug, Deserialize)]
struct LLMJudgeTaskInternal {
/// Task identifier (lowercased when converted into a task).
pub id: String,
/// Fully materialized prompt.
pub prompt: Prompt,
/// Optional context path.
pub context_path: Option<String>,
/// Expected value as JSON.
pub expected_value: Value,
/// Comparison operator.
pub operator: ComparisonOperator,
/// Task kind; required in this shape.
pub task_type: EvaluationTaskType,
/// Ids this task depends on.
pub depends_on: Vec<String>,
/// Retry budget (a default is applied on conversion when `None`).
pub max_retries: Option<u32>,
/// Previously attached result, if any.
pub result: Option<AssertionResult>,
/// Optional description.
pub description: Option<String>,
/// Boolean flag carried over unchanged.
pub condition: bool,
}
impl LLMJudgeTaskInternal {
pub fn into_task(self) -> LLMJudgeTask {
LLMJudgeTask {
id: self.id.to_lowercase(),
prompt: self.prompt,
context_path: self.context_path,
expected_value: self.expected_value,
operator: self.operator,
task_type: self.task_type,
depends_on: self.depends_on,
max_retries: self.max_retries.or(Some(3)),
result: self.result,
description: self.description,
condition: self.condition,
}
}
}
/// The two input shapes accepted when deserializing an `LLMJudgeTask`.
///
/// Untagged: serde tries `Full` first and falls back to `Generic` when the
/// input does not match it.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LLMJudgeFormat {
/// Complete task shape with a resolved prompt and explicit `task_type`.
Full(Box<LLMJudgeTaskInternal>),
/// Configuration shape whose prompt may be inline or a file path.
Generic(LLMJudgeTaskConfig),
}
impl<'de> Deserialize<'de> for LLMJudgeTask {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let format = LLMJudgeFormat::deserialize(deserializer)?;
match format {
LLMJudgeFormat::Generic(config) => config.into_task().map_err(serde::de::Error::custom),
LLMJudgeFormat::Full(internal) => Ok(internal.into_task()),
}
}
}
#[pymethods]
impl LLMJudgeTask {
#[new]
#[pyo3(signature = (id, prompt, expected_value, context_path,operator, description=None, depends_on=None, max_retries=None, condition=None))]
#[allow(clippy::too_many_arguments)]
pub fn new(
id: &str,
prompt: Prompt,
expected_value: &Bound<'_, PyAny>,
context_path: Option<String>,
operator: ComparisonOperator,
description: Option<String>,
depends_on: Option<Vec<String>>,
max_retries: Option<u32>,
condition: Option<bool>,
) -> Result<Self, TypeError> {
let expected_value = depythonize(expected_value)?;
Ok(Self {
id: id.to_lowercase(),
prompt,
expected_value,
operator,
task_type: EvaluationTaskType::LLMJudge,
depends_on: depends_on.unwrap_or_default(),
max_retries: max_retries.or(Some(3)),
context_path,
result: None,
description,
condition: condition.unwrap_or(false),
})
}
pub fn __str__(&self) -> String {
PyHelperFuncs::__str__(self)
}
#[getter]
pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.expected_value)?;
Ok(py_value)
}
#[staticmethod]
pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
deserialize_from_path(path)
}
}
impl LLMJudgeTask {
#[allow(clippy::too_many_arguments)]
pub fn new_rs(
id: &str,
prompt: Prompt,
expected_value: Value,
context_path: Option<String>,
operator: ComparisonOperator,
depends_on: Option<Vec<String>>,
max_retries: Option<u32>,
description: Option<String>,
condition: Option<bool>,
) -> Self {
Self {
id: id.to_lowercase(),
prompt,
expected_value,
operator,
task_type: EvaluationTaskType::LLMJudge,
depends_on: depends_on.unwrap_or_default(),
max_retries: max_retries.or(Some(3)),
context_path,
result: None,
description,
condition: condition.unwrap_or(false),
}
}
}
/// Read access to the task's fields plus result attachment, via `TaskAccessor`.
impl TaskAccessor for LLMJudgeTask {
    /// Borrowed view of `context_path`, if set.
    fn context_path(&self) -> Option<&str> {
        self.context_path.as_deref()
    }

    /// This task type has no item-level context path; always `None`.
    fn item_context_path(&self) -> Option<&str> {
        None
    }

    /// The task id.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The task kind.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// The comparison operator.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// The expected JSON value.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Ids this task depends on.
    fn depends_on(&self) -> &[String] {
        self.depends_on.as_slice()
    }

    /// Stores `result`, overwriting any previous one.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
/// Status value used when filtering spans; exposed to Python as an enum that
/// supports equality and integer comparison.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
/// `Ok` status.
Ok,
/// `Error` status.
Error,
/// `Unset` status.
Unset,
}
/// Newtype around a JSON `Value` so it can carry its own Python conversion
/// impls (`IntoPyObject` / `FromPyObject`) defined in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PyValueWrapper(pub Value);
impl<'py> IntoPyObject<'py> for PyValueWrapper {
type Target = PyAny;
type Output = Bound<'py, Self::Target>;
type Error = TypeError;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
pythonize(py, &self.0).map_err(TypeError::from)
}
}
/// Python → Rust conversion: a Python object is turned into a JSON value and
/// wrapped.
impl<'a, 'py> FromPyObject<'a, 'py> for PyValueWrapper {
    type Error = TypeError;

    /// Converts `ob` with `depythonize`; failures are reported as `TypeError`.
    fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
        let inner = depythonize::<Value>(&ob)?;
        Ok(Self(inner))
    }
}
/// Describes which spans to select. Filters can be nested through `And` / `Or`.
// NOTE(review): matching semantics live in the evaluator, not in this file;
// the per-variant notes below describe only the data each variant carries.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanFilter {
/// Carries a span name.
ByName { name: String },
/// Carries a name pattern string.
ByNamePattern { pattern: String },
/// Carries an attribute key.
WithAttribute { key: String },
/// Carries an attribute key together with a JSON value.
WithAttributeValue { key: String, value: PyValueWrapper },
/// Carries a span status.
WithStatus { status: SpanStatus },
/// Carries optional lower and upper duration bounds in milliseconds.
WithDuration {
min_ms: Option<f64>,
max_ms: Option<f64>,
},
/// Carries a list of span names.
Sequence { names: Vec<String> },
/// Carries a list of sub-filters combined as a conjunction.
And { filters: Vec<SpanFilter> },
/// Carries a list of sub-filters combined as a disjunction.
Or { filters: Vec<SpanFilter> },
}
#[pymethods]
impl SpanFilter {
#[staticmethod]
pub fn by_name(name: String) -> Self {
SpanFilter::ByName { name }
}
#[staticmethod]
pub fn by_name_pattern(pattern: String) -> Self {
SpanFilter::ByNamePattern { pattern }
}
#[staticmethod]
pub fn with_attribute(key: String) -> Self {
SpanFilter::WithAttribute { key }
}
#[staticmethod]
pub fn with_attribute_value(key: String, value: &Bound<'_, PyAny>) -> Result<Self, TypeError> {
let value = PyValueWrapper(depythonize(value)?);
Ok(SpanFilter::WithAttributeValue { key, value })
}
#[staticmethod]
pub fn with_status(status: SpanStatus) -> Self {
SpanFilter::WithStatus { status }
}
#[staticmethod]
#[pyo3(signature = (min_ms=None, max_ms=None))]
pub fn with_duration(min_ms: Option<f64>, max_ms: Option<f64>) -> Self {
SpanFilter::WithDuration { min_ms, max_ms }
}
#[staticmethod]
pub fn sequence(names: Vec<String>) -> Self {
SpanFilter::Sequence { names }
}
pub fn and_(&self, other: SpanFilter) -> Self {
match self {
SpanFilter::And { filters } => {
let mut new_filters = filters.clone();
new_filters.push(other);
SpanFilter::And {
filters: new_filters,
}
}
_ => SpanFilter::And {
filters: vec![self.clone(), other],
},
}
}
pub fn or_(&self, other: SpanFilter) -> Self {
match self {
SpanFilter::Or { filters } => {
let mut new_filters = filters.clone();
new_filters.push(other);
SpanFilter::Or {
filters: new_filters,
}
}
_ => SpanFilter::Or {
filters: vec![self.clone(), other],
},
}
}
}
/// Aggregation kinds selectable for `TraceAssertion::SpanAggregation`; exposed
/// to Python as an enum supporting equality and integer comparison.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationType {
Count,
Sum,
Average,
Min,
Max,
First,
Last,
}
/// Mode selector carried by `TraceAssertion::AttributeFilter`.
// NOTE(review): presumably `Any` passes when at least one response passes and
// `All` requires every response to pass — confirm against the evaluator.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiResponseMode {
Any,
All,
}
/// The task embedded in `TraceAssertion::AttributeFilter`: either a plain
/// assertion task or an agent assertion task.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AttributeFilterTask {
/// Wraps an `AssertionTask`.
Assertion(AssertionTask),
/// Wraps an `AgentAssertionTask`.
AgentAssertion(AgentAssertionTask),
}
#[pymethods]
impl AttributeFilterTask {
#[staticmethod]
pub fn assertion(task: AssertionTask) -> Self {
AttributeFilterTask::Assertion(task)
}
#[staticmethod]
pub fn agent_assertion(task: AgentAssertionTask) -> Self {
AttributeFilterTask::AgentAssertion(task)
}
}
/// Selects what a `TraceAssertionTask` extracts from a trace before the
/// comparison is applied.
// NOTE(review): extraction semantics live in the evaluator, not in this file;
// the per-variant notes below describe only the data each variant carries.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(clippy::large_enum_variant)]
pub enum TraceAssertion {
/// Carries a list of span names.
SpanSequence { span_names: Vec<String> },
/// Carries a list of span names.
SpanSet { span_names: Vec<String> },
/// Carries a span filter.
SpanCount { filter: SpanFilter },
/// Carries a span filter.
SpanExists { filter: SpanFilter },
/// Carries a span filter and an attribute key.
SpanAttribute {
filter: SpanFilter,
attribute_key: String,
},
/// Carries a span filter.
SpanDuration { filter: SpanFilter },
/// Carries a span filter, an attribute key and an aggregation kind.
SpanAggregation {
filter: SpanFilter,
attribute_key: String,
aggregation: AggregationType,
},
/// Carries no data.
TraceDuration {},
/// Carries no data.
TraceSpanCount {},
/// Carries no data.
TraceErrorCount {},
/// Carries no data.
TraceServiceCount {},
/// Carries no data.
TraceMaxDepth {},
/// Carries an attribute key.
TraceAttribute { attribute_key: String },
/// Carries a key, an embedded task and a multi-response mode.
AttributeFilter {
key: String,
task: AttributeFilterTask,
mode: MultiResponseMode,
},
}
/// Displays the assertion as compact JSON; a serialization failure yields an
/// empty string rather than a formatting error.
impl Display for TraceAssertion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let json = match serde_json::to_string(self) {
            Ok(text) => text,
            Err(_) => String::new(),
        };
        f.write_str(&json)
    }
}
/// Python-facing constructors, one per variant, plus JSON export.
#[pymethods]
impl TraceAssertion {
    /// Builds `SpanSequence`.
    #[staticmethod]
    pub fn span_sequence(span_names: Vec<String>) -> Self {
        Self::SpanSequence { span_names }
    }

    /// Builds `SpanSet`.
    #[staticmethod]
    pub fn span_set(span_names: Vec<String>) -> Self {
        Self::SpanSet { span_names }
    }

    /// Builds `SpanCount`.
    #[staticmethod]
    pub fn span_count(filter: SpanFilter) -> Self {
        Self::SpanCount { filter }
    }

    /// Builds `SpanExists`.
    #[staticmethod]
    pub fn span_exists(filter: SpanFilter) -> Self {
        Self::SpanExists { filter }
    }

    /// Builds `SpanAttribute`.
    #[staticmethod]
    pub fn span_attribute(filter: SpanFilter, attribute_key: String) -> Self {
        Self::SpanAttribute {
            filter,
            attribute_key,
        }
    }

    /// Builds `SpanDuration`.
    #[staticmethod]
    pub fn span_duration(filter: SpanFilter) -> Self {
        Self::SpanDuration { filter }
    }

    /// Builds `SpanAggregation`.
    #[staticmethod]
    pub fn span_aggregation(
        filter: SpanFilter,
        attribute_key: String,
        aggregation: AggregationType,
    ) -> Self {
        Self::SpanAggregation {
            filter,
            attribute_key,
            aggregation,
        }
    }

    /// Builds `TraceDuration`.
    #[staticmethod]
    pub fn trace_duration() -> Self {
        Self::TraceDuration {}
    }

    /// Builds `TraceSpanCount`.
    #[staticmethod]
    pub fn trace_span_count() -> Self {
        Self::TraceSpanCount {}
    }

    /// Builds `TraceErrorCount`.
    #[staticmethod]
    pub fn trace_error_count() -> Self {
        Self::TraceErrorCount {}
    }

    /// Builds `TraceServiceCount`.
    #[staticmethod]
    pub fn trace_service_count() -> Self {
        Self::TraceServiceCount {}
    }

    /// Builds `TraceMaxDepth`.
    #[staticmethod]
    pub fn trace_max_depth() -> Self {
        Self::TraceMaxDepth {}
    }

    /// Builds `TraceAttribute`.
    #[staticmethod]
    pub fn trace_attribute(attribute_key: String) -> Self {
        Self::TraceAttribute { attribute_key }
    }

    /// Builds `AttributeFilter`.
    #[staticmethod]
    pub fn attribute_filter(
        key: String,
        task: AttributeFilterTask,
        mode: MultiResponseMode,
    ) -> Self {
        Self::AttributeFilter { key, task, mode }
    }

    /// Compact JSON representation; an empty string when serialization fails.
    pub fn model_dump_json(&self) -> String {
        match serde_json::to_string(self) {
            Ok(json) => json,
            Err(_) => String::new(),
        }
    }
}
/// A trace assertion task: pairs a `TraceAssertion` with the comparison to
/// apply. Exposed to Python as a class and (de)serializable with serde.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TraceAssertionTask {
/// Task identifier. The Python constructor stores it lowercased.
#[pyo3(get, set)]
pub id: String,
/// What to extract from the trace.
#[pyo3(get, set)]
pub assertion: TraceAssertion,
/// Comparison operator applied by this task.
#[pyo3(get, set)]
pub operator: ComparisonOperator,
/// Expected value, stored as JSON (exposed to Python through an explicit getter).
pub expected_value: Value,
/// Optional free-text description.
#[pyo3(get, set)]
#[serde(default)]
pub description: Option<String>,
/// Ids of tasks this task depends on; empty when omitted.
#[pyo3(get, set)]
#[serde(default)]
pub depends_on: Vec<String>,
/// Task kind; defaults to `EvaluationTaskType::TraceAssertion` when absent from input.
#[serde(default = "default_trace_assertion_task_type")]
#[pyo3(get)]
pub task_type: EvaluationTaskType,
/// Result attached after evaluation; left out of serialized output while `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<AssertionResult>,
/// Boolean flag; `false` when absent from input.
// NOTE(review): presumably marks the task as a conditional gate — confirm.
#[pyo3(get, set)]
#[serde(default)]
pub condition: bool,
}
#[pymethods]
impl TraceAssertionTask {
#[new]
#[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None))]
pub fn new(
id: String,
assertion: TraceAssertion,
expected_value: &Bound<'_, PyAny>,
operator: ComparisonOperator,
description: Option<String>,
depends_on: Option<Vec<String>>,
condition: Option<bool>,
) -> Result<Self, TypeError> {
let expected_value = depythonize(expected_value)?;
Ok(Self {
id: id.to_lowercase(),
assertion,
operator,
expected_value,
description,
task_type: EvaluationTaskType::TraceAssertion,
depends_on: depends_on.unwrap_or_default(),
result: None,
condition: condition.unwrap_or(false),
})
}
pub fn __str__(&self) -> String {
PyHelperFuncs::__str__(self)
}
#[getter]
pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.expected_value)?;
Ok(py_value)
}
#[staticmethod]
pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
deserialize_from_path(path)
}
}
/// Read access to the task's fields plus result attachment, via `TaskAccessor`.
impl TaskAccessor for TraceAssertionTask {
    /// This task type has no context path; always `None`.
    fn context_path(&self) -> Option<&str> {
        None
    }

    /// This task type has no item-level context path; always `None`.
    fn item_context_path(&self) -> Option<&str> {
        None
    }

    /// The task id.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The comparison operator.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// The task kind.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// The expected JSON value.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Ids this task depends on.
    fn depends_on(&self) -> &[String] {
        self.depends_on.as_slice()
    }

    /// Stores `result`, overwriting any previous one.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
/// Record of one tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Name of the tool.
pub name: String,
/// Arguments passed to the tool, as JSON.
pub arguments: Value,
/// Result of the call as JSON, when one is recorded.
pub result: Option<Value>,
/// Identifier of the call, when one is recorded.
pub call_id: Option<String>,
}
/// Token counts, each optional; exposed to Python with read/write attributes.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenUsage {
/// Input token count, when known.
#[pyo3(get, set)]
pub input_tokens: Option<i64>,
/// Output token count, when known.
#[pyo3(get, set)]
pub output_tokens: Option<i64>,
/// Total token count, when known.
#[pyo3(get, set)]
pub total_tokens: Option<i64>,
}
#[pymethods]
impl TokenUsage {
#[new]
#[pyo3(signature = (input_tokens=None, output_tokens=None, total_tokens=None))]
pub fn new(
input_tokens: Option<i64>,
output_tokens: Option<i64>,
total_tokens: Option<i64>,
) -> Self {
Self {
input_tokens,
output_tokens,
total_tokens,
}
}
pub fn __str__(&self) -> String {
serde_json::to_string_pretty(self).unwrap_or_default()
}
}
/// Selects what an `AgentAssertionTask` checks or extracts before the
/// comparison is applied.
// NOTE(review): evaluation semantics live outside this file; the per-variant
// notes below describe only the data each variant carries.
#[pyclass(eq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentAssertion {
/// Carries a tool name.
ToolCalled { name: String },
/// Carries a tool name.
ToolNotCalled { name: String },
/// Carries a tool name and JSON arguments.
ToolCalledWithArgs {
name: String,
arguments: PyValueWrapper,
},
/// Carries a list of tool names.
ToolCallSequence { names: Vec<String> },
/// Carries an optional tool name.
ToolCallCount { name: Option<String> },
/// Carries a tool name and an argument key.
ToolArgument { name: String, argument_key: String },
/// Carries a tool name.
ToolResult { name: String },
/// Carries no data.
ResponseContent {},
/// Carries no data.
ResponseModel {},
/// Carries no data.
ResponseFinishReason {},
/// Carries no data.
ResponseInputTokens {},
/// Carries no data.
ResponseOutputTokens {},
/// Carries no data.
ResponseTotalTokens {},
/// Carries a path string.
ResponseField { path: String },
}
/// Displays the assertion as compact JSON; a serialization failure yields an
/// empty string rather than a formatting error.
impl Display for AgentAssertion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let json = match serde_json::to_string(self) {
            Ok(text) => text,
            Err(_) => String::new(),
        };
        f.write_str(&json)
    }
}
/// Python-facing constructors, one per variant, plus string rendering.
#[pymethods]
impl AgentAssertion {
    /// Builds `ToolCalled`.
    #[staticmethod]
    pub fn tool_called(name: &str) -> Self {
        Self::ToolCalled {
            name: name.to_owned(),
        }
    }

    /// Builds `ToolNotCalled`.
    #[staticmethod]
    pub fn tool_not_called(name: &str) -> Self {
        Self::ToolNotCalled {
            name: name.to_owned(),
        }
    }

    /// Builds `ToolCalledWithArgs`, converting the Python `arguments` into
    /// JSON first.
    ///
    /// # Errors
    ///
    /// Returns a `TypeError` when `arguments` cannot be converted.
    #[staticmethod]
    pub fn tool_called_with_args(
        name: &str,
        arguments: &Bound<'_, PyAny>,
    ) -> Result<Self, TypeError> {
        let json_args: Value = depythonize(arguments)?;
        Ok(Self::ToolCalledWithArgs {
            name: name.to_owned(),
            arguments: PyValueWrapper(json_args),
        })
    }

    /// Builds `ToolCallSequence`.
    #[staticmethod]
    pub fn tool_call_sequence(names: Vec<String>) -> Self {
        Self::ToolCallSequence { names }
    }

    /// Builds `ToolCallCount`; the tool name is optional.
    #[staticmethod]
    #[pyo3(signature = (name=None))]
    pub fn tool_call_count(name: Option<String>) -> Self {
        Self::ToolCallCount { name }
    }

    /// Builds `ToolArgument`.
    #[staticmethod]
    pub fn tool_argument(name: &str, argument_key: &str) -> Self {
        Self::ToolArgument {
            name: name.to_owned(),
            argument_key: argument_key.to_owned(),
        }
    }

    /// Builds `ToolResult`.
    #[staticmethod]
    pub fn tool_result(name: &str) -> Self {
        Self::ToolResult {
            name: name.to_owned(),
        }
    }

    /// Builds `ResponseContent`.
    #[staticmethod]
    pub fn response_content() -> Self {
        Self::ResponseContent {}
    }

    /// Builds `ResponseModel`.
    #[staticmethod]
    pub fn response_model() -> Self {
        Self::ResponseModel {}
    }

    /// Builds `ResponseFinishReason`.
    #[staticmethod]
    pub fn response_finish_reason() -> Self {
        Self::ResponseFinishReason {}
    }

    /// Builds `ResponseInputTokens`.
    #[staticmethod]
    pub fn response_input_tokens() -> Self {
        Self::ResponseInputTokens {}
    }

    /// Builds `ResponseOutputTokens`.
    #[staticmethod]
    pub fn response_output_tokens() -> Self {
        Self::ResponseOutputTokens {}
    }

    /// Builds `ResponseTotalTokens`.
    #[staticmethod]
    pub fn response_total_tokens() -> Self {
        Self::ResponseTotalTokens {}
    }

    /// Builds `ResponseField`.
    #[staticmethod]
    pub fn response_field(path: &str) -> Self {
        Self::ResponseField {
            path: path.to_owned(),
        }
    }

    /// Python `__str__`: pretty-printed JSON, or an empty string when
    /// serialization fails.
    pub fn __str__(&self) -> String {
        match serde_json::to_string_pretty(self) {
            Ok(json) => json,
            Err(_) => String::new(),
        }
    }
}
/// An agent assertion task: pairs an `AgentAssertion` with the comparison to
/// apply. Exposed to Python as a class and (de)serializable with serde.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AgentAssertionTask {
/// Task identifier. The Python constructor stores it lowercased.
#[pyo3(get, set)]
pub id: String,
/// What to check or extract.
#[pyo3(get, set)]
pub assertion: AgentAssertion,
/// Comparison operator applied by this task.
#[pyo3(get, set)]
pub operator: ComparisonOperator,
/// Expected value, stored as JSON (exposed to Python through an explicit getter).
pub expected_value: Value,
/// Optional free-text description.
#[pyo3(get, set)]
#[serde(default)]
pub description: Option<String>,
/// Ids of tasks this task depends on; empty when omitted.
#[pyo3(get, set)]
#[serde(default)]
pub depends_on: Vec<String>,
/// Task kind; defaults to `EvaluationTaskType::AgentAssertion` when absent from input.
#[serde(default = "default_agent_assertion_task_type")]
#[pyo3(get)]
pub task_type: EvaluationTaskType,
/// Result attached after evaluation; left out of serialized output while `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<AssertionResult>,
/// Boolean flag; `false` when absent from input.
// NOTE(review): presumably marks the task as a conditional gate — confirm.
#[pyo3(get, set)]
#[serde(default)]
pub condition: bool,
/// Optional provider; `None` when absent from input.
// NOTE(review): presumably selects the provider-specific response format — confirm.
#[pyo3(get, set)]
#[serde(default)]
pub provider: Option<Provider>,
}
#[pymethods]
impl AgentAssertionTask {
#[new]
#[pyo3(signature = (id, assertion, expected_value, operator, description=None, depends_on=None, condition=None, provider=None))]
#[allow(clippy::too_many_arguments)]
pub fn new(
id: String,
assertion: AgentAssertion,
expected_value: &Bound<'_, PyAny>,
operator: ComparisonOperator,
description: Option<String>,
depends_on: Option<Vec<String>>,
condition: Option<bool>,
provider: Option<Provider>,
) -> Result<Self, TypeError> {
let expected_value = depythonize(expected_value)?;
Ok(Self {
id: id.to_lowercase(),
assertion,
operator,
expected_value,
description,
task_type: EvaluationTaskType::AgentAssertion,
depends_on: depends_on.unwrap_or_default(),
result: None,
condition: condition.unwrap_or(false),
provider,
})
}
pub fn __str__(&self) -> String {
PyHelperFuncs::__str__(self)
}
pub fn model_dump_json(&self) -> String {
serde_json::to_string(self).unwrap_or_default()
}
#[getter]
pub fn get_expected_value<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
let py_value = pythonize(py, &self.expected_value)?;
Ok(py_value)
}
#[getter]
pub fn get_result<'py>(&self, py: Python<'py>) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
match &self.result {
Some(result) => {
let py_value = pythonize(py, result)?;
Ok(Some(py_value))
}
None => Ok(None),
}
}
#[staticmethod]
pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
deserialize_from_path(path)
}
}
/// Read access to the task's fields plus result attachment, via `TaskAccessor`.
impl TaskAccessor for AgentAssertionTask {
    /// This task type has no context path; always `None`.
    fn context_path(&self) -> Option<&str> {
        None
    }

    /// This task type has no item-level context path; always `None`.
    fn item_context_path(&self) -> Option<&str> {
        None
    }

    /// The task id.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The comparison operator.
    fn operator(&self) -> &ComparisonOperator {
        &self.operator
    }

    /// The task kind.
    fn task_type(&self) -> &EvaluationTaskType {
        &self.task_type
    }

    /// The expected JSON value.
    fn expected_value(&self) -> &Value {
        &self.expected_value
    }

    /// Ids this task depends on.
    fn depends_on(&self) -> &[String] {
        self.depends_on.as_slice()
    }

    /// Stores `result`, overwriting any previous one.
    fn add_result(&mut self, result: AssertionResult) {
        self.result = Some(result);
    }
}
/// Any one of the four task kinds defined in this file. Each payload is boxed.
#[derive(Debug, Clone)]
pub enum EvaluationTask {
/// A boxed `AssertionTask`.
Assertion(Box<AssertionTask>),
/// A boxed `LLMJudgeTask`.
LLMJudge(Box<LLMJudgeTask>),
/// A boxed `TraceAssertionTask`.
TraceAssertion(Box<TraceAssertionTask>),
/// A boxed `AgentAssertionTask`.
AgentAssertion(Box<AgentAssertionTask>),
}
impl TaskAccessor for EvaluationTask {
fn context_path(&self) -> Option<&str> {
match self {
EvaluationTask::Assertion(t) => t.context_path(),
EvaluationTask::LLMJudge(t) => t.context_path(),
EvaluationTask::TraceAssertion(t) => t.context_path(),
EvaluationTask::AgentAssertion(t) => t.context_path(),
}
}
fn item_context_path(&self) -> Option<&str> {
match self {
EvaluationTask::Assertion(t) => t.item_context_path(),
EvaluationTask::LLMJudge(t) => t.item_context_path(),
EvaluationTask::TraceAssertion(t) => t.item_context_path(),
EvaluationTask::AgentAssertion(t) => t.item_context_path(),
}
}
fn id(&self) -> &str {
match self {
EvaluationTask::Assertion(t) => t.id(),
EvaluationTask::LLMJudge(t) => t.id(),
EvaluationTask::TraceAssertion(t) => t.id(),
EvaluationTask::AgentAssertion(t) => t.id(),
}
}
fn task_type(&self) -> &EvaluationTaskType {
match self {
EvaluationTask::Assertion(t) => t.task_type(),
EvaluationTask::LLMJudge(t) => t.task_type(),
EvaluationTask::TraceAssertion(t) => t.task_type(),
EvaluationTask::AgentAssertion(t) => t.task_type(),
}
}
fn operator(&self) -> &ComparisonOperator {
match self {
EvaluationTask::Assertion(t) => t.operator(),
EvaluationTask::LLMJudge(t) => t.operator(),
EvaluationTask::TraceAssertion(t) => t.operator(),
EvaluationTask::AgentAssertion(t) => t.operator(),
}
}
fn expected_value(&self) -> &Value {
match self {
EvaluationTask::Assertion(t) => t.expected_value(),
EvaluationTask::LLMJudge(t) => t.expected_value(),
EvaluationTask::TraceAssertion(t) => t.expected_value(),
EvaluationTask::AgentAssertion(t) => t.expected_value(),
}
}
fn depends_on(&self) -> &[String] {
match self {
EvaluationTask::Assertion(t) => t.depends_on(),
EvaluationTask::LLMJudge(t) => t.depends_on(),
EvaluationTask::TraceAssertion(t) => t.depends_on(),
EvaluationTask::AgentAssertion(t) => t.depends_on(),
}
}
fn add_result(&mut self, result: AssertionResult) {
match self {
EvaluationTask::Assertion(t) => t.add_result(result),
EvaluationTask::LLMJudge(t) => t.add_result(result),
EvaluationTask::TraceAssertion(t) => t.add_result(result),
EvaluationTask::AgentAssertion(t) => t.add_result(result),
}
}
}
/// Builder that accumulates `EvaluationTask`s in insertion order.
pub struct EvaluationTasks(Vec<EvaluationTask>);
impl EvaluationTasks {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn add_task(mut self, task: impl Into<EvaluationTask>) -> Self {
self.0.push(task.into());
self
}
pub fn build(self) -> Vec<EvaluationTask> {
self.0
}
}
impl From<AssertionTask> for EvaluationTask {
fn from(task: AssertionTask) -> Self {
EvaluationTask::Assertion(Box::new(task))
}
}
impl From<LLMJudgeTask> for EvaluationTask {
fn from(task: LLMJudgeTask) -> Self {
EvaluationTask::LLMJudge(Box::new(task))
}
}
impl From<TraceAssertionTask> for EvaluationTask {
fn from(task: TraceAssertionTask) -> Self {
EvaluationTask::TraceAssertion(Box::new(task))
}
}
impl From<AgentAssertionTask> for EvaluationTask {
fn from(task: AgentAssertionTask) -> Self {
EvaluationTask::AgentAssertion(Box::new(task))
}
}
impl Default for EvaluationTasks {
fn default() -> Self {
Self::new()
}
}
/// Operators a task can use to compare an actual value with its expected
/// value. Exposed to Python as an enum supporting equality and integer
/// comparison.
///
/// The string form of each variant (see `as_str` / `FromStr` in this file) is
/// the variant name itself.
// NOTE(review): the evaluation logic for each operator is not in this file;
// consult the evaluator before relying on a particular operator's semantics.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComparisonOperator {
Equals,
NotEqual,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
Contains,
NotContains,
StartsWith,
EndsWith,
Matches,
HasLengthGreaterThan,
HasLengthLessThan,
HasLengthEqual,
HasLengthGreaterThanOrEqual,
HasLengthLessThanOrEqual,
IsNumeric,
IsString,
IsBoolean,
IsNull,
IsArray,
IsObject,
IsEmail,
IsUrl,
IsUuid,
IsIso8601,
IsJson,
MatchesRegex,
InRange,
NotInRange,
IsPositive,
IsNegative,
IsZero,
SequenceMatches,
ContainsAll,
ContainsAny,
ContainsNone,
IsEmpty,
IsNotEmpty,
HasUniqueItems,
IsAlphabetic,
IsAlphanumeric,
IsLowerCase,
IsUpperCase,
ContainsWord,
ApproximatelyEquals,
}
/// Displays the operator using its `as_str` name.
impl Display for ComparisonOperator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
/// Parses an operator from its exact, case-sensitive variant name
/// (e.g. `"GreaterThan"`).
impl FromStr for ComparisonOperator {
    type Err = TypeError;

    /// Maps `s` to the operator of the same name.
    ///
    /// # Errors
    ///
    /// Returns `TypeError::Error` naming the rejected input when `s` is not a
    /// known operator name. (This previously returned
    /// `TypeError::InvalidCompressionTypeError`, an unrelated variant that
    /// produced a misleading diagnostic.)
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let operator = match s {
            "Equals" => ComparisonOperator::Equals,
            "NotEqual" => ComparisonOperator::NotEqual,
            "GreaterThan" => ComparisonOperator::GreaterThan,
            "GreaterThanOrEqual" => ComparisonOperator::GreaterThanOrEqual,
            "LessThan" => ComparisonOperator::LessThan,
            "LessThanOrEqual" => ComparisonOperator::LessThanOrEqual,
            "Contains" => ComparisonOperator::Contains,
            "NotContains" => ComparisonOperator::NotContains,
            "StartsWith" => ComparisonOperator::StartsWith,
            "EndsWith" => ComparisonOperator::EndsWith,
            "Matches" => ComparisonOperator::Matches,
            "HasLengthEqual" => ComparisonOperator::HasLengthEqual,
            "HasLengthGreaterThan" => ComparisonOperator::HasLengthGreaterThan,
            "HasLengthLessThan" => ComparisonOperator::HasLengthLessThan,
            "HasLengthGreaterThanOrEqual" => ComparisonOperator::HasLengthGreaterThanOrEqual,
            "HasLengthLessThanOrEqual" => ComparisonOperator::HasLengthLessThanOrEqual,
            "IsNumeric" => ComparisonOperator::IsNumeric,
            "IsString" => ComparisonOperator::IsString,
            "IsBoolean" => ComparisonOperator::IsBoolean,
            "IsNull" => ComparisonOperator::IsNull,
            "IsArray" => ComparisonOperator::IsArray,
            "IsObject" => ComparisonOperator::IsObject,
            "IsEmail" => ComparisonOperator::IsEmail,
            "IsUrl" => ComparisonOperator::IsUrl,
            "IsUuid" => ComparisonOperator::IsUuid,
            "IsIso8601" => ComparisonOperator::IsIso8601,
            "IsJson" => ComparisonOperator::IsJson,
            "MatchesRegex" => ComparisonOperator::MatchesRegex,
            "InRange" => ComparisonOperator::InRange,
            "NotInRange" => ComparisonOperator::NotInRange,
            "IsPositive" => ComparisonOperator::IsPositive,
            "IsNegative" => ComparisonOperator::IsNegative,
            "IsZero" => ComparisonOperator::IsZero,
            "ContainsAll" => ComparisonOperator::ContainsAll,
            "ContainsAny" => ComparisonOperator::ContainsAny,
            "ContainsNone" => ComparisonOperator::ContainsNone,
            "IsEmpty" => ComparisonOperator::IsEmpty,
            "IsNotEmpty" => ComparisonOperator::IsNotEmpty,
            "HasUniqueItems" => ComparisonOperator::HasUniqueItems,
            "SequenceMatches" => ComparisonOperator::SequenceMatches,
            "IsAlphabetic" => ComparisonOperator::IsAlphabetic,
            "IsAlphanumeric" => ComparisonOperator::IsAlphanumeric,
            "IsLowerCase" => ComparisonOperator::IsLowerCase,
            "IsUpperCase" => ComparisonOperator::IsUpperCase,
            "ContainsWord" => ComparisonOperator::ContainsWord,
            "ApproximatelyEquals" => ComparisonOperator::ApproximatelyEquals,
            unknown => {
                // Report the offending input so the caller can see what was rejected.
                return Err(TypeError::Error(format!(
                    "Invalid comparison operator: '{}'",
                    unknown
                )));
            }
        };
        Ok(operator)
    }
}
impl ComparisonOperator {
pub fn as_str(&self) -> &str {
match self {
ComparisonOperator::Equals => "Equals",
ComparisonOperator::NotEqual => "NotEqual",
ComparisonOperator::GreaterThan => "GreaterThan",
ComparisonOperator::GreaterThanOrEqual => "GreaterThanOrEqual",
ComparisonOperator::LessThan => "LessThan",
ComparisonOperator::LessThanOrEqual => "LessThanOrEqual",
ComparisonOperator::Contains => "Contains",
ComparisonOperator::NotContains => "NotContains",
ComparisonOperator::StartsWith => "StartsWith",
ComparisonOperator::EndsWith => "EndsWith",
ComparisonOperator::Matches => "Matches",
ComparisonOperator::HasLengthEqual => "HasLengthEqual",
ComparisonOperator::HasLengthGreaterThan => "HasLengthGreaterThan",
ComparisonOperator::HasLengthLessThan => "HasLengthLessThan",
ComparisonOperator::HasLengthGreaterThanOrEqual => "HasLengthGreaterThanOrEqual",
ComparisonOperator::HasLengthLessThanOrEqual => "HasLengthLessThanOrEqual",
ComparisonOperator::IsNumeric => "IsNumeric",
ComparisonOperator::IsString => "IsString",
ComparisonOperator::IsBoolean => "IsBoolean",
ComparisonOperator::IsNull => "IsNull",
ComparisonOperator::IsArray => "IsArray",
ComparisonOperator::IsObject => "IsObject",
ComparisonOperator::IsEmail => "IsEmail",
ComparisonOperator::IsUrl => "IsUrl",
ComparisonOperator::IsUuid => "IsUuid",
ComparisonOperator::IsIso8601 => "IsIso8601",
ComparisonOperator::IsJson => "IsJson",
ComparisonOperator::MatchesRegex => "MatchesRegex",
ComparisonOperator::InRange => "InRange",
ComparisonOperator::NotInRange => "NotInRange",
ComparisonOperator::IsPositive => "IsPositive",
ComparisonOperator::IsNegative => "IsNegative",
ComparisonOperator::IsZero => "IsZero",
ComparisonOperator::ContainsAll => "ContainsAll",
ComparisonOperator::ContainsAny => "ContainsAny",
ComparisonOperator::ContainsNone => "ContainsNone",
ComparisonOperator::IsEmpty => "IsEmpty",
ComparisonOperator::IsNotEmpty => "IsNotEmpty",
ComparisonOperator::HasUniqueItems => "HasUniqueItems",
ComparisonOperator::SequenceMatches => "SequenceMatches",
ComparisonOperator::IsAlphabetic => "IsAlphabetic",
ComparisonOperator::IsAlphanumeric => "IsAlphanumeric",
ComparisonOperator::IsLowerCase => "IsLowerCase",
ComparisonOperator::IsUpperCase => "IsUpperCase",
ComparisonOperator::ContainsWord => "ContainsWord",
ComparisonOperator::ApproximatelyEquals => "ApproximatelyEquals",
}
}
}
/// A dynamically typed value used on either side of an assertion.
///
/// Floating-point and integer numbers are kept as separate variants, and
/// lists may nest arbitrarily because `List` holds further `AssertionValue`s.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssertionValue {
/// UTF-8 text.
String(String),
/// 64-bit floating-point number.
Number(f64),
/// 64-bit signed integer.
Integer(i64),
/// Boolean flag.
Boolean(bool),
/// Ordered sequence of nested values.
List(Vec<AssertionValue>),
/// Absence of a value.
Null(),
}
impl AssertionValue {
    /// Converts this value into the form that should be compared for the
    /// given operator.
    ///
    /// For the `HasLength*` operators a list is replaced by its element count
    /// and a string by its number of Unicode scalar values (`chars()`), both
    /// as `AssertionValue::Integer`. Every other value/operator combination is
    /// returned unchanged.
    pub fn to_actual(self, comparison: &ComparisonOperator) -> AssertionValue {
        match comparison {
            ComparisonOperator::HasLengthEqual
            | ComparisonOperator::HasLengthGreaterThan
            | ComparisonOperator::HasLengthLessThan
            | ComparisonOperator::HasLengthGreaterThanOrEqual
            | ComparisonOperator::HasLengthLessThanOrEqual => match self {
                AssertionValue::List(arr) => AssertionValue::Integer(arr.len() as i64),
                AssertionValue::String(s) => AssertionValue::Integer(s.chars().count() as i64),
                // Values without a length are passed through untouched.
                _ => self,
            },
            _ => self,
        }
    }

    /// Converts this value into the equivalent `serde_json::Value`.
    ///
    /// Lists are converted recursively. JSON cannot represent NaN or
    /// infinities, so a non-finite `Number` becomes `Value::Null` instead of
    /// panicking.
    pub fn to_serde_value(&self) -> Value {
        match self {
            AssertionValue::String(s) => Value::String(s.clone()),
            // `from_f64` returns `None` for NaN / +-infinity; fall back to null
            // rather than unwrapping.
            AssertionValue::Number(n) => {
                serde_json::Number::from_f64(*n).map_or(Value::Null, Value::Number)
            }
            AssertionValue::Integer(i) => Value::Number(serde_json::Number::from(*i)),
            AssertionValue::Boolean(b) => Value::Bool(*b),
            AssertionValue::List(arr) => {
                Value::Array(arr.iter().map(AssertionValue::to_serde_value).collect())
            }
            AssertionValue::Null() => Value::Null,
        }
    }
}
/// Converts a Python object into an [`AssertionValue`].
///
/// Supported inputs are `None`, `bool`, `str`, `int`, `float` and `list`
/// (converted element by element, recursively). The checks run in that order.
///
/// # Errors
///
/// Returns `TypeError::UnsupportedType` carrying the Python type name for any
/// other object, and propagates extraction failures and the first error raised
/// while converting a list element.
pub fn assertion_value_from_py(value: &Bound<'_, PyAny>) -> Result<AssertionValue, TypeError> {
    let converted = if value.is_none() {
        AssertionValue::Null()
    } else if value.is_instance_of::<PyBool>() {
        // Tested ahead of the integer branch so booleans are not captured as integers.
        AssertionValue::Boolean(value.extract()?)
    } else if value.is_instance_of::<PyString>() {
        AssertionValue::String(value.extract()?)
    } else if value.is_instance_of::<PyInt>() {
        AssertionValue::Integer(value.extract()?)
    } else if value.is_instance_of::<PyFloat>() {
        AssertionValue::Number(value.extract()?)
    } else if value.is_instance_of::<PyList>() {
        let py_list = value.cast::<PyList>()?;
        let mut elements = Vec::new();
        for element in py_list.iter() {
            elements.push(assertion_value_from_py(&element)?);
        }
        AssertionValue::List(elements)
    } else {
        let type_name = value.get_type().name()?.to_string();
        return Err(TypeError::UnsupportedType(type_name));
    };

    Ok(converted)
}
/// The kind of an evaluation task.
///
/// All six kinds can be parsed from and rendered to their variant name, but
/// the `TasksFile` deserializer in this file only builds task configs for
/// `Assertion`, `LLMJudge`, `TraceAssertion` and `AgentAssertion`; it rejects
/// the remaining kinds with an error.
// NOTE(review): `eq, eq_int` presumably expose equality (including against
// ints) to Python — confirm against the PyO3 version in use.
#[pyclass(eq, eq_int)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EvaluationTaskType {
/// Assertion task.
Assertion,
/// LLM-judge task.
LLMJudge,
/// Conditional task (not constructible from a tasks file in this module).
Conditional,
/// Human-validation task (not constructible from a tasks file in this module).
HumanValidation,
/// Trace assertion task.
TraceAssertion,
/// Agent assertion task.
AgentAssertion,
}
impl Display for EvaluationTaskType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let task_type_str = match self {
EvaluationTaskType::Assertion => "Assertion",
EvaluationTaskType::LLMJudge => "LLMJudge",
EvaluationTaskType::Conditional => "Conditional",
EvaluationTaskType::HumanValidation => "HumanValidation",
EvaluationTaskType::TraceAssertion => "TraceAssertion",
EvaluationTaskType::AgentAssertion => "AgentAssertion",
};
write!(f, "{}", task_type_str)
}
}
impl FromStr for EvaluationTaskType {
    type Err = TypeError;

    /// Parses a task type from its exact, case-sensitive variant name.
    ///
    /// # Errors
    ///
    /// Returns `TypeError::InvalidEvalType` containing the rejected input when
    /// the string is not one of the known names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "Assertion" => Self::Assertion,
            "LLMJudge" => Self::LLMJudge,
            "Conditional" => Self::Conditional,
            "HumanValidation" => Self::HumanValidation,
            "TraceAssertion" => Self::TraceAssertion,
            "AgentAssertion" => Self::AgentAssertion,
            unknown => return Err(TypeError::InvalidEvalType(unknown.to_string())),
        };
        Ok(parsed)
    }
}
impl EvaluationTaskType {
pub fn as_str(&self) -> &str {
match self {
EvaluationTaskType::Assertion => "Assertion",
EvaluationTaskType::LLMJudge => "LLMJudge",
EvaluationTaskType::Conditional => "Conditional",
EvaluationTaskType::HumanValidation => "HumanValidation",
EvaluationTaskType::TraceAssertion => "TraceAssertion",
EvaluationTaskType::AgentAssertion => "AgentAssertion",
}
}
}
/// An ordered collection of task configurations loaded from a file.
///
/// Exposed to Python as a sequence-like object: it supports iteration,
/// integer and slice indexing, and `len()`.
#[pyclass]
#[derive(Debug, Serialize)]
pub struct TasksFile {
/// The tasks, in the order they were read.
pub tasks: Vec<TaskConfig>,
// Position of the next task handed out by `__next__`; starts at 0 when the
// file is deserialized and only ever advances.
#[serde(default)]
index: usize,
}
#[pymethods]
impl TasksFile {
    /// Loads a tasks file from `path` via `deserialize_from_path`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported while reading or deserializing the file.
    #[staticmethod]
    pub fn from_path(path: PathBuf) -> Result<Self, TypeError> {
        deserialize_from_path(path)
    }

    /// Python iterator protocol: the object is its own iterator.
    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Returns the next task as a Python object, or `None` once every task
    /// has been handed out.
    ///
    /// The cursor is advanced only after the task was converted successfully.
    pub fn __next__<'py>(
        mut slf: PyRefMut<'py, Self>,
    ) -> Result<Option<Bound<'py, PyAny>>, TypeError> {
        let py = slf.py();
        let Some(next_task) = slf.tasks.get(slf.index).cloned() else {
            return Ok(None);
        };
        let item = next_task.into_bound_py_any(py)?;
        slf.index += 1;
        Ok(Some(item))
    }

    /// Indexes the collection with either an integer or a slice.
    ///
    /// Negative integers count from the end. A slice yields a Python list of
    /// the selected tasks.
    ///
    /// # Errors
    ///
    /// * `TypeError::IndexOutOfBounds` when an integer index is outside the
    ///   collection.
    /// * `TypeError::IndexOrSliceExpected` when `index` is neither an integer
    ///   nor a slice.
    fn __getitem__<'py>(
        &self,
        py: Python<'py>,
        index: &Bound<'py, PyAny>,
    ) -> Result<Bound<'py, PyAny>, TypeError> {
        let length = self.tasks.len();

        // Integer index, with Python-style wrap-around for negatives.
        if let Ok(requested) = index.extract::<isize>() {
            let signed_len = length as isize;
            let resolved = if requested < 0 {
                signed_len + requested
            } else {
                requested
            };
            if !(0..signed_len).contains(&resolved) {
                return Err(TypeError::IndexOutOfBounds {
                    index: requested,
                    length,
                });
            }
            return self.tasks[resolved as usize].clone().into_bound_py_any(py);
        }

        let Ok(slice) = index.cast::<PySlice>() else {
            return Err(TypeError::IndexOrSliceExpected);
        };

        // Walk the normalized slice bounds in the direction given by `step`.
        let bounds = slice.indices(length as isize)?;
        let (stop, step) = (bounds.stop, bounds.step);
        let keep_going = |pos: isize| {
            if step > 0 {
                pos < stop
            } else {
                step < 0 && pos > stop
            }
        };

        let selected = PyList::empty(py);
        let mut cursor = bounds.start;
        while keep_going(cursor) {
            let item = self.tasks[cursor as usize].clone().into_bound_py_any(py)?;
            selected.append(item)?;
            cursor += step;
        }
        Ok(selected.into_bound_py_any(py)?)
    }

    /// Number of tasks in the collection.
    fn __len__(&self) -> usize {
        self.tasks.len()
    }

    /// String representation produced by `PyHelperFuncs::__str__`.
    fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
/// A single task entry of a `TasksFile`, tagged by the kind of task it holds.
#[derive(Debug, Serialize, Clone)]
#[allow(clippy::large_enum_variant)]
pub enum TaskConfig {
/// An assertion task.
Assertion(AssertionTask),
/// An LLM-judge task; this is the only variant stored behind a `Box`.
// The explicit rename spells out the serialized variant name `LLMJudge`.
#[serde(rename = "LLMJudge")]
LLMJudge(Box<LLMJudgeTask>),
/// A trace assertion task.
TraceAssertion(TraceAssertionTask),
/// An agent assertion task.
AgentAssertion(AgentAssertionTask),
}
impl TaskConfig {
fn into_bound_py_any<'py>(self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TypeError> {
match self {
TaskConfig::Assertion(task) => Ok(task.into_bound_py_any(py)?),
TaskConfig::LLMJudge(task) => Ok(task.into_bound_py_any(py)?),
TaskConfig::TraceAssertion(task) => Ok(task.into_bound_py_any(py)?),
TaskConfig::AgentAssertion(task) => Ok(task.into_bound_py_any(py)?),
}
}
}
impl<'de> Deserialize<'de> for TasksFile {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum TasksFileRaw {
Direct(Vec<TaskConfigRaw>),
Wrapped { tasks: Vec<TaskConfigRaw> },
}
#[derive(Deserialize)]
struct TaskConfigRaw {
task_type: EvaluationTaskType,
#[serde(flatten)]
data: Value,
}
let raw = TasksFileRaw::deserialize(deserializer)?;
let raw_tasks = match raw {
TasksFileRaw::Direct(tasks) => tasks,
TasksFileRaw::Wrapped { tasks } => tasks,
};
let mut tasks = Vec::new();
for task_raw in raw_tasks {
let task_config = match task_raw.task_type {
EvaluationTaskType::Assertion => {
let mut task: AssertionTask =
serde_json::from_value(task_raw.data).map_err(|e| {
error!("Failed to deserialize AssertionTask: {}", e);
serde::de::Error::custom(e.to_string())
})?;
task.task_type = EvaluationTaskType::Assertion;
TaskConfig::Assertion(task)
}
EvaluationTaskType::LLMJudge => {
let mut task: LLMJudgeTask =
serde_json::from_value(task_raw.data).map_err(|e| {
error!("Failed to deserialize LLMJudgeTask: {}", e);
serde::de::Error::custom(e.to_string())
})?;
task.task_type = EvaluationTaskType::LLMJudge;
TaskConfig::LLMJudge(Box::new(task))
}
EvaluationTaskType::TraceAssertion => {
let mut task: TraceAssertionTask = serde_json::from_value(task_raw.data)
.map_err(|e| {
error!("Failed to deserialize TraceAssertionTask: {}", e);
serde::de::Error::custom(e.to_string())
})?;
task.task_type = EvaluationTaskType::TraceAssertion;
TaskConfig::TraceAssertion(task)
}
EvaluationTaskType::AgentAssertion => {
let mut task: AgentAssertionTask = serde_json::from_value(task_raw.data)
.map_err(|e| {
error!("Failed to deserialize AgentAssertionTask: {}", e);
serde::de::Error::custom(e.to_string())
})?;
task.task_type = EvaluationTaskType::AgentAssertion;
TaskConfig::AgentAssertion(task)
}
_ => {
return Err(serde::de::Error::custom(format!(
"Unknown task_type: {}",
task_raw.task_type
)))
}
};
tasks.push(task_config);
}
Ok(TasksFile { tasks, index: 0 })
}
}