use crate::{content::Content, duration::DurationMs, effect::Effect, error::OperatorError, id::*};
use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// The kind of event that caused an [`Operator`] to be invoked.
///
/// Variant names are serialized in `snake_case` (e.g. `SystemEvent` is written as
/// `"system_event"` in JSON). No `#[serde(tag = ...)]` is set, so serde's default
/// externally-tagged form applies: `Custom("x")` is written as `{"custom": "x"}`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TriggerType {
User,
Task,
Signal,
Schedule,
SystemEvent,
/// A trigger not covered by the built-in variants, identified by a free-form name.
Custom(String),
}
/// Input to a single [`Operator::execute`] call.
///
/// Use [`OperatorInput::new`] for the common case of no session, no config and `null`
/// metadata.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorInput {
/// The message content passed to the operator.
pub message: Content,
/// What caused this invocation.
pub trigger: TriggerType,
/// Session this invocation is associated with, if any.
pub session: Option<SessionId>,
/// Configuration for this invocation, if any.
pub config: Option<OperatorConfig>,
/// Free-form extra data. Deserializes to JSON `null` when the key is missing.
#[serde(default)]
pub metadata: serde_json::Value,
}
/// Optional limits and overrides for one operator invocation.
///
/// Every field is an `Option`; the derived [`Default`] leaves them all `None`.
/// How each limit is enforced is up to the [`Operator`] implementation.
#[non_exhaustive]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OperatorConfig {
/// Maximum number of turns, if set.
pub max_turns: Option<u32>,
/// Maximum cost, if set.
pub max_cost: Option<Decimal>,
/// Maximum duration, if set.
pub max_duration: Option<DurationMs>,
/// Model identifier to use, if set.
pub model: Option<String>,
/// Allow-list of operator names, if set.
/// Deserialization also accepts the alternate key `allowed_tools`.
#[serde(alias = "allowed_tools")]
pub allowed_operators: Option<Vec<String>>,
/// Extra text to add to the operator's instructions, if set
/// (NOTE(review): presumably appended to a system prompt — confirm with implementations).
pub system_addendum: Option<String>,
}
/// Why an operator invocation stopped.
///
/// Variant names are serialized in `snake_case`. Unit variants become plain strings
/// (e.g. `MaxTurns` is `"max_turns"` in JSON). The struct and newtype variants use
/// serde's default externally-tagged form, e.g.
/// `{"safety_stop": {"reason": "..."}}` or `{"custom": "..."}`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ExitReason {
Complete,
MaxTurns,
BudgetExhausted,
CircuitBreaker,
Timeout,
/// Halted by an interceptor; `reason` carries a description.
InterceptorHalt {
reason: String,
},
Error,
/// Stopped for safety; `reason` carries a description.
SafetyStop {
reason: String,
},
AwaitingApproval,
/// A stop reason not covered by the built-in variants, identified by a free-form name.
Custom(String),
}
/// Result of a single [`Operator::execute`] call.
///
/// Use [`OperatorOutput::new`] to build one with zeroed metadata and no effects.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorOutput {
    /// The message content produced by the operator.
    pub message: Content,
    /// Why the invocation stopped.
    pub exit_reason: ExitReason,
    /// Token, cost, turn and duration accounting for the invocation.
    ///
    /// Falls back to [`OperatorMetadata::default`] (all counters zero, no sub-dispatches)
    /// when the key is missing during deserialization, matching how `effects` is handled.
    #[serde(default)]
    pub metadata: OperatorMetadata,
    /// Effects emitted by the invocation; empty when the key is missing.
    #[serde(default)]
    pub effects: Vec<Effect>,
}
/// Accounting data attached to an [`OperatorOutput`].
///
/// No field carries `#[serde(default)]`, so every key is required when this struct is
/// deserialized on its own.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorMetadata {
/// Input-side token count.
pub tokens_in: u64,
/// Output-side token count.
pub tokens_out: u64,
/// Cost of the invocation.
pub cost: Decimal,
/// Number of turns used.
pub turns_used: u32,
/// One record per sub-dispatch.
/// Deserialization also accepts the alternate key `tools_called`.
#[serde(alias = "tools_called")]
pub sub_dispatches: Vec<SubDispatchRecord>,
/// Duration of the invocation.
pub duration: DurationMs,
}
/// Record of one sub-dispatch, collected in [`OperatorMetadata::sub_dispatches`].
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubDispatchRecord {
/// Name of the dispatched target.
pub name: String,
/// Duration of the sub-dispatch.
pub duration: DurationMs,
/// Whether the sub-dispatch succeeded.
pub success: bool,
}
impl Default for OperatorMetadata {
    /// Zeroed metadata: no tokens, zero cost, no turns, no sub-dispatches and zero
    /// duration.
    fn default() -> Self {
        Self {
            sub_dispatches: Vec::new(),
            cost: Decimal::ZERO,
            duration: DurationMs::ZERO,
            turns_used: 0,
            tokens_out: 0,
            tokens_in: 0,
        }
    }
}
impl OperatorInput {
    /// Creates an input from `message` and `trigger`.
    ///
    /// `session` and `config` start as `None` and `metadata` as JSON `null`. The public
    /// fields can be set afterwards.
    pub fn new(message: Content, trigger: TriggerType) -> Self {
        let (session, config) = (None, None);
        Self {
            trigger,
            message,
            metadata: serde_json::Value::Null,
            session,
            config,
        }
    }
}
impl OperatorOutput {
    /// Creates an output from `message` and `exit_reason`.
    ///
    /// Metadata is [`OperatorMetadata::default`] and the effect list is empty.
    pub fn new(message: Content, exit_reason: ExitReason) -> Self {
        let metadata = OperatorMetadata::default();
        let effects = Vec::new();
        Self {
            message,
            exit_reason,
            metadata,
            effects,
        }
    }
}
impl SubDispatchRecord {
    /// Creates a record from `name` (converted to an owned `String`), `duration` and
    /// `success`.
    pub fn new(name: impl Into<String>, duration: DurationMs, success: bool) -> Self {
        let name: String = name.into();
        Self {
            success,
            duration,
            name,
        }
    }
}
/// Description of a tool: its name, a human-readable description, a schema for its
/// input, and a parallel-safety flag.
///
/// Derives `PartialEq` (every field supports it) so whole values can be compared,
/// consistent with the other comparable types in this module.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolMetadata {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema for the tool's input, stored as arbitrary JSON (e.g. a JSON Schema object).
    pub input_schema: serde_json::Value,
    /// Flag marking the tool as safe to run in parallel.
    pub parallel_safe: bool,
}
impl ToolMetadata {
    /// Creates tool metadata, converting `name` and `description` into owned strings.
    /// `input_schema` and `parallel_safe` are stored as given.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: serde_json::Value,
        parallel_safe: bool,
    ) -> Self {
        // Convert in the same order as the field declarations: name first, then description.
        let (name, description) = (name.into(), description.into());
        Self {
            parallel_safe,
            input_schema,
            description,
            name,
        }
    }
}
/// An executable unit that turns an [`OperatorInput`] into an [`OperatorOutput`].
///
/// Implementors must be `Send + Sync`. The async method is expanded by `#[async_trait]`.
#[async_trait]
pub trait Operator: Send + Sync {
/// Runs the operator on `input`.
///
/// # Errors
///
/// Returns an [`OperatorError`] when the invocation fails.
async fn execute(&self, input: OperatorInput) -> Result<OperatorOutput, OperatorError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_metadata_construction() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            },
            "required": ["query"]
        });
        let meta = ToolMetadata::new("search", "Search the web", schema.clone(), true);
        assert_eq!(meta.name, "search");
        assert_eq!(meta.description, "Search the web");
        assert_eq!(meta.input_schema, schema);
        assert!(meta.parallel_safe);
    }

    #[test]
    fn tool_metadata_serde_roundtrip() {
        let schema = serde_json::json!({"type": "object"});
        let meta = ToolMetadata::new(
            "code_exec",
            "Execute code in a sandbox",
            schema.clone(),
            false,
        );
        let json = serde_json::to_string(&meta).expect("serialize");
        let back: ToolMetadata = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.name, "code_exec");
        assert_eq!(back.description, "Execute code in a sandbox");
        // The schema is the most structured field; make sure it survives the roundtrip too.
        assert_eq!(back.input_schema, schema);
        assert!(!back.parallel_safe);
    }

    #[test]
    fn trigger_type_uses_snake_case_names() {
        let json = serde_json::to_string(&TriggerType::SystemEvent).expect("serialize");
        assert_eq!(json, "\"system_event\"");
        let custom: TriggerType =
            serde_json::from_str(r#"{"custom":"webhook"}"#).expect("deserialize");
        assert_eq!(custom, TriggerType::Custom("webhook".to_string()));
    }

    #[test]
    fn exit_reason_struct_variant_roundtrip() {
        let reason = ExitReason::InterceptorHalt {
            reason: "blocked".to_string(),
        };
        let json = serde_json::to_value(&reason).expect("serialize");
        assert_eq!(
            json,
            serde_json::json!({"interceptor_halt": {"reason": "blocked"}})
        );
        let back: ExitReason = serde_json::from_value(json).expect("deserialize");
        assert_eq!(back, reason);
    }

    #[test]
    fn operator_config_accepts_allowed_tools_alias() {
        let cfg: OperatorConfig =
            serde_json::from_str(r#"{"allowed_tools": ["search"]}"#).expect("deserialize");
        assert_eq!(cfg.allowed_operators, Some(vec!["search".to_string()]));
        assert!(cfg.max_turns.is_none());
        assert!(cfg.model.is_none());
    }

    #[test]
    fn operator_metadata_default_is_zeroed() {
        let meta = OperatorMetadata::default();
        assert_eq!(meta.tokens_in, 0);
        assert_eq!(meta.tokens_out, 0);
        assert_eq!(meta.cost, Decimal::ZERO);
        assert_eq!(meta.turns_used, 0);
        assert!(meta.sub_dispatches.is_empty());
    }

    #[test]
    fn sub_dispatch_record_construction() {
        let rec = SubDispatchRecord::new("search", DurationMs::ZERO, true);
        assert_eq!(rec.name, "search");
        assert!(rec.success);
    }
}