use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Describes one tool call: the tool identifier, the input for it, and a log string.
///
/// The type can be serialized and deserialized through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentAction {
    /// Identifier of the tool (presumably the name used to look the tool up — confirm
    /// against callers).
    pub tool: String,
    /// Input handed to the tool, either plain text or a JSON value.
    pub tool_input: ToolInput,
    /// Free-form log text attached to this action; its content is chosen by the producer.
    pub log: String,
}
/// Input payload for a tool: either a plain string or an arbitrary JSON value.
///
/// Because of `#[serde(untagged)]` no variant tag is written, and on deserialization
/// serde tries the variants in declaration order. String data therefore becomes
/// [`ToolInput::String`]; anything else falls through to [`ToolInput::Object`].
/// Do not reorder the variants: the order is part of the wire behaviour.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolInput {
    /// Plain text input.
    String(String),
    /// Structured input carried as a JSON value.
    Object(serde_json::Value),
}
impl Default for ToolInput {
fn default() -> Self {
ToolInput::String(String::new())
}
}
impl std::fmt::Display for ToolInput {
    /// Writes the input as text.
    ///
    /// * [`ToolInput::String`] is written verbatim.
    /// * [`ToolInput::Object`] is written as compact JSON.
    ///
    /// # Errors
    ///
    /// Returns `std::fmt::Error` if the underlying writer fails or the JSON value
    /// cannot be rendered, rather than silently printing an empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ToolInput::String(s) => f.write_str(s),
            // `serde_json::Value` implements `Display` as compact JSON, so no intermediate
            // `String` is needed. Going through `write!` (instead of `v.fmt(f)`) starts from
            // default format flags, so a caller's `{:#}` does not switch to pretty-printing.
            ToolInput::Object(v) => write!(f, "{}", v),
        }
    }
}
/// A finish record: a map of named JSON return values plus a log string.
///
/// The type can be serialized and deserialized through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFinish {
    /// Named return values; [`AgentFinish::new`] stores its result under the `"output"` key.
    pub return_values: HashMap<String, serde_json::Value>,
    /// Free-form log text attached to this record.
    pub log: String,
}
impl AgentFinish {
pub fn new(output: String, log: String) -> Self {
let mut return_values = HashMap::new();
return_values.insert("output".to_string(), serde_json::Value::String(output));
Self { return_values, log }
}
pub fn output(&self) -> Option<&str> {
self.return_values.get("output").and_then(|v| v.as_str())
}
}
/// Pairs an [`AgentAction`] with an observation string.
///
/// The type can be serialized and deserialized through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStep {
    /// The action this step refers to.
    pub action: AgentAction,
    /// Text recorded for the action (presumably the tool's result — confirm against callers).
    pub observation: String,
}
impl AgentStep {
    /// Creates a step from an action and its observation text, storing both unchanged.
    pub fn new(action: AgentAction, observation: String) -> Self {
        Self { action, observation }
    }
}
/// One of three outcomes: a single action, a batch of actions, or a finish record.
#[derive(Debug, Clone)]
pub enum AgentOutput {
    /// Exactly one action.
    Action(AgentAction),
    /// Zero or more actions carried together in a vector.
    Actions(Vec<AgentAction>),
    /// A finish record.
    Finish(AgentFinish),
}
impl AgentOutput {
    /// Returns `true` only for the [`AgentOutput::Finish`] variant.
    pub fn is_finish(&self) -> bool {
        self.finish().is_some()
    }

    /// Returns `true` for both [`AgentOutput::Action`] and [`AgentOutput::Actions`].
    pub fn is_action(&self) -> bool {
        match self {
            AgentOutput::Action(_) | AgentOutput::Actions(_) => true,
            AgentOutput::Finish(_) => false,
        }
    }

    /// Borrows the action of the single-action variant.
    ///
    /// Returns `None` for [`AgentOutput::Actions`] (regardless of how many actions it
    /// holds) and for [`AgentOutput::Finish`].
    pub fn action(&self) -> Option<&AgentAction> {
        if let AgentOutput::Action(single) = self {
            Some(single)
        } else {
            None
        }
    }

    /// Collects references to every action carried by this output.
    ///
    /// A single action yields a one-element vector, a batch yields its elements in
    /// order, and a finish yields an empty vector.
    pub fn actions(&self) -> Vec<&AgentAction> {
        match self {
            AgentOutput::Action(single) => vec![single],
            AgentOutput::Actions(batch) => batch.iter().collect(),
            AgentOutput::Finish(_) => Vec::new(),
        }
    }

    /// Borrows the finish record, or returns `None` for either action variant.
    pub fn finish(&self) -> Option<&AgentFinish> {
        if let AgentOutput::Finish(done) = self {
            Some(done)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`AgentAction`] with a plain-string input and the fixed log `"test"`.
    fn create_action(tool: &str, input: &str) -> AgentAction {
        AgentAction {
            tool: String::from(tool),
            tool_input: ToolInput::String(String::from(input)),
            log: String::from("test"),
        }
    }

    /// A single action counts as an action, not a finish, and exposes one action.
    #[test]
    fn test_agent_output_single_action() {
        let output = AgentOutput::Action(create_action("calculator", "1+2"));

        assert!(output.is_action());
        assert!(!output.is_finish());
        assert_eq!(output.actions().len(), 1);
    }

    /// A batch counts as an action, exposes all its members, and has no single `action()`.
    #[test]
    fn test_agent_output_multiple_actions() {
        let output = AgentOutput::Actions(vec![
            create_action("calculator", "1+2"),
            create_action("datetime", "now"),
        ]);

        assert!(output.is_action());
        assert!(!output.is_finish());
        assert_eq!(output.actions().len(), 2);
        assert!(output.action().is_none());
    }

    /// A finish is not an action, carries no actions, and exposes its finish record.
    #[test]
    fn test_agent_output_finish() {
        let output =
            AgentOutput::Finish(AgentFinish::new(String::from("answer"), String::from("log")));

        assert!(!output.is_action());
        assert!(output.is_finish());
        assert!(output.actions().is_empty());
        assert!(output.finish().is_some());
    }

    /// `AgentFinish::output` returns the string passed to `AgentFinish::new`.
    #[test]
    fn test_agent_finish_output() {
        let finish = AgentFinish::new(String::from("the answer is 42"), String::new());

        assert_eq!(finish.output(), Some("the answer is 42"));
    }
}