use crate::types::{Message, ToolCall};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Tag describing the kind of an [`AgentStep`] (stored in `AgentStep::step_type`).
///
/// The `Display` impl renders every variant as a short lowercase label. The
/// two tool variants append the tool name (`tool_call:<name>`,
/// `tool_result:<name>`) and leave `call_id` out of the rendered text.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepType {
/// Displayed as `think`.
Think,
/// Displayed as `tool_call:<tool_name>`.
ToolCall {
/// Name of the tool; the only field that appears in the `Display` output.
tool_name: String,
/// Identifier of this call. Presumably used to pair the call with its
/// `ToolResult` — confirm against the code that produces these steps.
call_id: String,
},
/// Displayed as `tool_result:<tool_name>`.
ToolResult {
/// Name of the tool; the only field that appears in the `Display` output.
tool_name: String,
/// Identifier of the call this result belongs to (assumed to mirror
/// `ToolCall::call_id` — confirm with the producer).
call_id: String,
},
/// Displayed as `response`.
Response,
/// Displayed as `error`.
Error,
/// Displayed as `wait_for_human`.
WaitForHuman,
}
impl std::fmt::Display for StepType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StepType::Think => write!(f, "think"),
StepType::ToolCall { tool_name, .. } => write!(f, "tool_call:{}", tool_name),
StepType::ToolResult { tool_name, .. } => write!(f, "tool_result:{}", tool_name),
StepType::Response => write!(f, "response"),
StepType::Error => write!(f, "error"),
StepType::WaitForHuman => write!(f, "wait_for_human"),
}
}
}
/// Outcome recorded for an [`AgentStep`] (stored in `AgentStep::result`).
///
/// `AgentStep::new` initialises the result to [`StepResult::Continue`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StepResult {
/// No payload; the value every freshly constructed step starts with.
Continue,
/// Carries a response; reported by `is_done()` and exposed via `response()`.
Done {
/// Response text returned by `StepResult::response`.
response: String,
},
/// Carries tool calls; reported by `has_tool_calls()` and exposed via
/// `tool_calls()`.
ToolCalls {
/// The requested calls. The list may be empty; nothing in this file
/// enforces otherwise.
calls: Vec<ToolCall>,
},
/// Carries a prompt. No accessor for it is defined in this file; match on
/// the variant to read it.
WaitForHuman {
/// Prompt text (presumably shown to the human operator — confirm with callers).
prompt: String,
},
/// Carries an error message; reported by `is_error()` and exposed via
/// `error_message()`.
Error {
/// Message returned by `StepResult::error_message`.
message: String,
},
}
impl StepResult {
pub fn is_done(&self) -> bool {
matches!(self, StepResult::Done { .. })
}
pub fn has_tool_calls(&self) -> bool {
matches!(self, StepResult::ToolCalls { .. })
}
pub fn is_error(&self) -> bool {
matches!(self, StepResult::Error { .. })
}
pub fn response(&self) -> Option<&str> {
match self {
StepResult::Done { response } => Some(response),
_ => None,
}
}
pub fn tool_calls(&self) -> Option<&[ToolCall]> {
match self {
StepResult::ToolCalls { calls } => Some(calls),
_ => None,
}
}
pub fn error_message(&self) -> Option<&str> {
match self {
StepResult::Error { message } => Some(message),
_ => None,
}
}
}
/// Record of a single agent step: its kind, optional input/output payloads,
/// outcome, timing, token usage, messages and free-form metadata.
///
/// Build one with `AgentStep::new` and the `with_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStep {
/// Caller-supplied index of the step. The tests in this file start at 0;
/// nothing here enforces a particular numbering.
pub step_number: usize,
/// Kind of step.
pub step_type: StepType,
/// Optional JSON payload given to the step; `None` until `with_input` is called.
pub input: Option<serde_json::Value>,
/// Optional JSON payload produced by the step; `None` until `with_output` is called.
pub output: Option<serde_json::Value>,
/// Outcome of the step; `StepResult::Continue` for a freshly created step.
pub result: StepResult,
/// Duration of the step. Serialized through `duration_serde`, i.e. as a
/// whole number of milliseconds, so finer precision is not preserved.
#[serde(with = "duration_serde")]
pub duration: Duration,
/// Token counters, when available; `None` until `with_tokens` is called.
pub tokens: Option<TokenUsage>,
/// Messages attached to the step, in the order they were pushed via `add_message`.
pub messages: Vec<Message>,
/// Arbitrary key/value annotations. `#[serde(default)]` makes the field
/// optional on deserialization: when absent it becomes an empty map.
#[serde(default)]
pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
/// Token counters attached to an [`AgentStep`].
///
/// All three values are stored exactly as provided; nothing in this file
/// checks that `total_tokens` equals `prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TokenUsage {
/// Prompt-side token count.
pub prompt_tokens: u32,
/// Completion-side token count.
pub completion_tokens: u32,
/// Total token count as reported by the caller (not computed here).
pub total_tokens: u32,
}
impl AgentStep {
    /// Creates a step with the given number and kind and nothing else set:
    /// no input/output, a `Continue` result, zero duration, no token usage,
    /// no messages and empty metadata.
    pub fn new(step_number: usize, step_type: StepType) -> Self {
        Self {
            step_number,
            step_type,
            result: StepResult::Continue,
            duration: Duration::ZERO,
            input: None,
            output: None,
            tokens: None,
            messages: Vec::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Returns the step with `input` set, replacing any previous value.
    pub fn with_input(self, input: serde_json::Value) -> Self {
        Self {
            input: Some(input),
            ..self
        }
    }

    /// Returns the step with `output` set, replacing any previous value.
    pub fn with_output(self, output: serde_json::Value) -> Self {
        Self {
            output: Some(output),
            ..self
        }
    }

    /// Returns the step with its result replaced by `result`.
    pub fn with_result(self, result: StepResult) -> Self {
        Self { result, ..self }
    }

    /// Returns the step with its duration replaced by `duration`.
    pub fn with_duration(self, duration: Duration) -> Self {
        Self { duration, ..self }
    }

    /// Returns the step with token usage set, replacing any previous value.
    pub fn with_tokens(self, tokens: TokenUsage) -> Self {
        Self {
            tokens: Some(tokens),
            ..self
        }
    }

    /// Appends `message` to the step's message list. Unlike the `with_*`
    /// methods this mutates in place and returns nothing.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns the step with `key` mapped to `value` in the metadata; an
    /// existing entry under the same key is overwritten.
    pub fn with_metadata(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let mut step = self;
        step.metadata.insert(key.into(), value);
        step
    }
}
mod duration_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::Duration;
pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
duration.as_millis().serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let millis = u64::deserialize(deserializer)?;
Ok(Duration::from_millis(millis))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` output for payload-free and tool-call step kinds.
    #[test]
    fn test_step_type_display() {
        let calc_call = StepType::ToolCall {
            tool_name: String::from("calc"),
            call_id: String::from("1"),
        };
        let cases = [
            (StepType::Think, "think"),
            (calc_call, "tool_call:calc"),
            (StepType::Response, "response"),
        ];
        for (step_type, expected) in cases.iter() {
            assert_eq!(step_type.to_string(), *expected);
        }
    }

    /// Predicates and accessors on `StepResult` agree with the variant.
    #[test]
    fn test_step_result_methods() {
        let finished = StepResult::Done {
            response: String::from("Hello!"),
        };
        assert!(finished.is_done());
        assert_eq!(finished.response(), Some("Hello!"));

        let pending_calls = StepResult::ToolCalls { calls: Vec::new() };
        assert!(pending_calls.has_tool_calls());
        assert!(!pending_calls.is_done());

        let failure = StepResult::Error {
            message: String::from("Oops"),
        };
        assert!(failure.is_error());
        assert_eq!(failure.error_message(), Some("Oops"));
    }

    /// Chained `with_*` calls populate the corresponding fields.
    #[test]
    fn test_agent_step_builder() {
        let prompt = serde_json::json!({"prompt": "Hello"});
        let reply = serde_json::json!({"response": "Hi!"});
        let outcome = StepResult::Done {
            response: String::from("Hi!"),
        };

        let built = AgentStep::new(0, StepType::Think)
            .with_input(prompt)
            .with_output(reply)
            .with_result(outcome)
            .with_duration(Duration::from_millis(100));

        assert_eq!(built.step_number, 0);
        assert!(built.result.is_done());
        assert_eq!(built.duration.as_millis(), 100);
    }

    /// `with_tokens` stores the usage counters unchanged.
    #[test]
    fn test_token_usage() {
        let usage = TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };
        let built = AgentStep::new(0, StepType::Think).with_tokens(usage);
        assert_eq!(built.tokens.unwrap().total_tokens, 150);
    }
}