use std::sync::Arc;
use crate::types::{Invocation, RecordedToolCall, TurnRecord};
/// Granularity at which an invocation is evaluated.
///
/// Serializes through serde with snake_case variant names (for example
/// `Trace` becomes `"trace"`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EvaluationLevel {
/// Level reported by [`ExtractedInput::Tool`]: a single recorded tool call.
Tool,
/// Level reported by [`ExtractedInput::Trace`]: a whole [`Invocation`].
Trace,
/// Level reported by [`ExtractedInput::Session`]: a group of turns.
Session,
}
/// One unit of evaluation input produced by a [`TraceExtractor`].
///
/// Each variant corresponds to exactly one [`EvaluationLevel`]; see
/// [`ExtractedInput::level`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ExtractedInput {
/// A single recorded tool call together with the index of its turn.
Tool {
/// Value of the originating turn's `turn_index` field.
turn_index: usize,
/// Owned copy of the recorded call.
call: RecordedToolCall,
},
/// A complete invocation; boxed so the enum itself stays small.
Trace(Box<Invocation>),
/// A group of turns evaluated together.
Session {
/// Owned copies of the turns in this group.
turns: Vec<TurnRecord>,
},
}
impl ExtractedInput {
#[must_use]
pub fn level(&self) -> EvaluationLevel {
match self {
Self::Tool { .. } => EvaluationLevel::Tool,
Self::Trace(_) => EvaluationLevel::Trace,
Self::Session { .. } => EvaluationLevel::Session,
}
}
}
/// Strategy for turning an [`Invocation`] into evaluation inputs.
///
/// Implementors must be `Send + Sync`.
pub trait TraceExtractor: Send + Sync {
/// Builds the inputs for `inv` at the requested `level`.
///
/// Returns zero or more [`ExtractedInput`] values; how the invocation is
/// partitioned is up to the implementor.
fn extract(&self, inv: &Invocation, level: EvaluationLevel) -> Vec<ExtractedInput>;
}
/// Stateless extractor: one input per recorded tool call at tool level, and a
/// single whole-invocation input at trace and session level.
#[derive(Debug, Default, Clone, Copy)]
pub struct ToolLevelExtractor;
impl TraceExtractor for ToolLevelExtractor {
    /// Extracts inputs from `inv` without any grouping logic.
    ///
    /// * `Tool` - one [`ExtractedInput::Tool`] per recorded call, in turn
    ///   order and then call order, tagged with the owning turn's index.
    /// * `Trace` - a single input holding a clone of the whole invocation.
    /// * `Session` - a single input holding a clone of every turn.
    fn extract(&self, inv: &Invocation, level: EvaluationLevel) -> Vec<ExtractedInput> {
        match level {
            EvaluationLevel::Tool => {
                let mut inputs = Vec::new();
                for turn in inv.turns.iter() {
                    for recorded in turn.tool_calls.iter() {
                        inputs.push(ExtractedInput::Tool {
                            turn_index: turn.turn_index,
                            call: recorded.clone(),
                        });
                    }
                }
                inputs
            }
            EvaluationLevel::Trace => {
                let whole = Box::new(inv.clone());
                vec![ExtractedInput::Trace(whole)]
            }
            EvaluationLevel::Session => {
                let turns = inv.turns.clone();
                vec![ExtractedInput::Session { turns }]
            }
        }
    }
}
/// Tool name that [`SwarmExtractor`] treats as a handoff by default.
const TRANSFER_TO_AGENT_TOOL: &str = "transfer_to_agent";
/// Extractor that, at session level, splits an invocation's turns into groups
/// ending at each turn that calls the configured handoff tool.
#[derive(Debug, Clone)]
pub struct SwarmExtractor {
/// Name of the tool call that marks the end of a group of turns.
handoff_tool: String,
}
impl Default for SwarmExtractor {
    /// Builds an extractor whose handoff tool is `TRANSFER_TO_AGENT_TOOL`.
    fn default() -> Self {
        let handoff_tool = String::from(TRANSFER_TO_AGENT_TOOL);
        Self { handoff_tool }
    }
}
impl SwarmExtractor {
    /// Creates an extractor with the default handoff tool name.
    ///
    /// Equivalent to [`SwarmExtractor::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the extractor with its handoff tool name replaced by
    /// `tool_name` (builder style: consumes and returns `self`).
    #[must_use]
    pub fn with_handoff_tool(mut self, tool_name: impl Into<String>) -> Self {
        let replacement: String = tool_name.into();
        self.handoff_tool = replacement;
        self
    }

    /// Name of the tool call currently treated as a handoff.
    #[must_use]
    pub fn handoff_tool(&self) -> &str {
        self.handoff_tool.as_str()
    }
}
impl TraceExtractor for SwarmExtractor {
fn extract(&self, inv: &Invocation, level: EvaluationLevel) -> Vec<ExtractedInput> {
match level {
EvaluationLevel::Tool => ToolLevelExtractor.extract(inv, level),
EvaluationLevel::Trace => vec![ExtractedInput::Trace(Box::new(inv.clone()))],
EvaluationLevel::Session => {
let mut out: Vec<ExtractedInput> = Vec::new();
let mut cohort: Vec<crate::types::TurnRecord> = Vec::new();
for turn in &inv.turns {
cohort.push(turn.clone());
let fires_handoff = turn
.tool_calls
.iter()
.any(|call| call.name == self.handoff_tool);
if fires_handoff && !cohort.is_empty() {
out.push(ExtractedInput::Session {
turns: std::mem::take(&mut cohort),
});
}
}
if !cohort.is_empty() {
out.push(ExtractedInput::Session { turns: cohort });
}
if out.is_empty() && !inv.turns.is_empty() {
out.push(ExtractedInput::Session {
turns: inv.turns.clone(),
});
}
out
}
}
}
}
/// Extractor that, at session level, groups consecutive turns sharing the
/// same node key.
#[derive(Clone)]
pub struct GraphExtractor {
/// Function computing the grouping key for a turn; held in an `Arc` so the
/// extractor can be cloned.
node_key: Arc<dyn Fn(&crate::types::TurnRecord) -> String + Send + Sync>,
}
impl std::fmt::Debug for GraphExtractor {
    /// Formats the extractor as an opaque struct.
    ///
    /// The key function cannot be printed, so no fields are listed and the
    /// output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("GraphExtractor");
        opaque.finish_non_exhaustive()
    }
}
impl GraphExtractor {
    /// Creates an extractor keyed on each turn's
    /// `assistant_message.model_id`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_node_key(|turn| turn.assistant_message.model_id.clone())
    }

    /// Creates an extractor that uses `key` to compute each turn's node key.
    ///
    /// The closure must be `Send + Sync + 'static` because it is stored
    /// behind a shared `Arc` trait object.
    #[must_use]
    pub fn with_node_key<F>(key: F) -> Self
    where
        F: Fn(&crate::types::TurnRecord) -> String + Send + Sync + 'static,
    {
        let node_key: Arc<dyn Fn(&crate::types::TurnRecord) -> String + Send + Sync> =
            Arc::new(key);
        Self { node_key }
    }
}
impl Default for GraphExtractor {
fn default() -> Self {
Self::new()
}
}
impl TraceExtractor for GraphExtractor {
fn extract(&self, inv: &Invocation, level: EvaluationLevel) -> Vec<ExtractedInput> {
use crate::types::TurnRecord;
match level {
EvaluationLevel::Tool => ToolLevelExtractor.extract(inv, level),
EvaluationLevel::Trace => vec![ExtractedInput::Trace(Box::new(inv.clone()))],
EvaluationLevel::Session => {
let mut out: Vec<ExtractedInput> = Vec::new();
let mut cohort: Vec<TurnRecord> = Vec::new();
let mut current_key: Option<String> = None;
for turn in &inv.turns {
let key = (self.node_key)(turn);
match ¤t_key {
Some(k) if k == &key => {
cohort.push(turn.clone());
}
_ => {
if !cohort.is_empty() {
out.push(ExtractedInput::Session {
turns: std::mem::take(&mut cohort),
});
}
current_key = Some(key);
cohort.push(turn.clone());
}
}
}
if !cohort.is_empty() {
out.push(ExtractedInput::Session { turns: cohort });
}
out
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `EvaluationLevel` serializes to its snake_case name and parses back to
    /// the same variant.
    #[test]
    fn evaluation_level_serde_round_trip() {
        let encoded = serde_json::to_string(&EvaluationLevel::Trace).unwrap();
        assert_eq!(encoded, "\"trace\"");

        let decoded: EvaluationLevel = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, EvaluationLevel::Trace);
    }

    /// `ExtractedInput::level` reports the level matching the variant.
    #[test]
    fn extracted_input_level_matches_variant() {
        let tool_input = ExtractedInput::Tool {
            turn_index: 0,
            call: RecordedToolCall {
                id: "id".into(),
                name: "n".into(),
                arguments: serde_json::Value::Null,
            },
        };
        assert_eq!(tool_input.level(), EvaluationLevel::Tool);

        let session_input = ExtractedInput::Session { turns: vec![] };
        assert_eq!(session_input.level(), EvaluationLevel::Session);
    }
}