use std::collections::HashMap;
use serde_json::{json, Value};
use cognis_core::agents::{AgentAction, AgentFinish};
use cognis_core::error::Result;
use cognis_core::messages::Message;
/// Outcome of interpreting one AI message in a tool-calling agent loop,
/// as produced by `parse_ai_message_to_agent_output`.
#[derive(Clone, Debug)]
pub enum AgentOutput {
    /// The model asked for one or more tool invocations to be run next.
    Actions(Vec<AgentAction>),
    /// The model produced a final answer and requested no tools.
    Finish(AgentFinish),
}
/// Interprets a raw AI message (as JSON) into the next step of an agent loop.
///
/// The message may carry a `tool_calls` array and a string `content`:
///
/// * If `tool_calls` is missing, not an array, or empty, the agent is done:
///   [`AgentOutput::Finish`] is returned with `content` stored under the `"output"` key
///   (`""` when `content` is missing or not a string).
/// * Otherwise each element of `tool_calls` becomes an [`AgentAction`], preserving array
///   order. A call's `name` falls back to `"unknown"` when missing or not a string, and its
///   `args` falls back to an empty JSON object when missing.
///
/// # Errors
///
/// Every path currently returns `Ok`; the `Result` return type is kept for API stability.
pub fn parse_ai_message_to_agent_output(ai_message: &Value) -> Result<AgentOutput> {
    // Borrow the calls instead of cloning the whole array: they are only read below.
    let tool_calls: &[Value] = ai_message
        .get("tool_calls")
        .and_then(Value::as_array)
        .map(Vec::as_slice)
        .unwrap_or_default();

    if tool_calls.is_empty() {
        let content = ai_message
            .get("content")
            .and_then(Value::as_str)
            .unwrap_or("")
            .to_string();
        let mut return_values = HashMap::new();
        return_values.insert("output".to_string(), Value::String(content));
        return Ok(AgentOutput::Finish(AgentFinish::new(return_values, "")));
    }

    // NOTE(review): each call's `id` is not carried into `AgentAction`. Presumably
    // `AgentAction` has no field for it — confirm; without it, tool results cannot be
    // correlated back to the call that produced them.
    let actions = tool_calls
        .iter()
        .map(|tool_call| {
            let tool_name = tool_call
                .get("name")
                .and_then(Value::as_str)
                .unwrap_or("unknown")
                .to_string();
            // Lazily build the `{}` fallback so the common path does not allocate a map.
            let tool_input = tool_call
                .get("args")
                .cloned()
                .unwrap_or_else(|| json!({}));
            // `Value`'s `Display` writes compact JSON, identical to `serde_json::to_string`.
            let log = format!(
                "Calling tool `{}` with args: {}",
                tool_name, tool_input
            );
            AgentAction::new(tool_name, tool_input, log)
        })
        .collect();

    Ok(AgentOutput::Actions(actions))
}
/// Turns each `(action, observation)` step into a tool message, preserving step order.
///
/// The observation text becomes the message content. The action's `tool` name is passed
/// as the second argument to `Message::tool`.
///
/// NOTE(review): if `Message::tool`'s second parameter is a tool-call id rather than a
/// tool name, passing `action.tool` here is likely wrong — confirm its signature.
pub fn format_to_tool_messages(intermediate_steps: &[(AgentAction, String)]) -> Vec<Message> {
    intermediate_steps
        .iter()
        .map(|(step_action, step_observation)| {
            Message::tool(step_observation.as_str(), &step_action.tool)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Asserts that `output` is a `Finish` whose `"output"` value equals `expected`.
    fn assert_finish_output(output: AgentOutput, expected: &str) {
        match output {
            AgentOutput::Finish(finish) => {
                assert_eq!(
                    finish.return_values.get("output"),
                    Some(&Value::String(expected.into()))
                );
            }
            _ => panic!("Expected Finish"),
        }
    }

    #[test]
    fn test_parse_ai_message_no_tool_calls() {
        let msg = json!({
            "content": "The answer is 42",
            "tool_calls": []
        });
        assert_finish_output(
            parse_ai_message_to_agent_output(&msg).unwrap(),
            "The answer is 42",
        );
    }

    #[test]
    fn test_parse_ai_message_missing_tool_calls_key() {
        // No `tool_calls` key at all is treated the same as an empty list.
        let msg = json!({ "content": "done" });
        assert_finish_output(parse_ai_message_to_agent_output(&msg).unwrap(), "done");
    }

    #[test]
    fn test_parse_ai_message_null_tool_calls() {
        // A non-array `tool_calls` (here `null`) is also treated as "no tool calls".
        let msg = json!({ "content": "done", "tool_calls": null });
        assert_finish_output(parse_ai_message_to_agent_output(&msg).unwrap(), "done");
    }

    #[test]
    fn test_parse_ai_message_missing_content() {
        // Missing `content` yields an empty output string rather than an error.
        let msg = json!({});
        assert_finish_output(parse_ai_message_to_agent_output(&msg).unwrap(), "");
    }

    #[test]
    fn test_parse_ai_message_with_tool_calls() {
        let msg = json!({
            "content": "",
            "tool_calls": [
                {
                    "name": "search",
                    "args": {"query": "rust lang"},
                    "id": "call_123"
                }
            ]
        });
        match parse_ai_message_to_agent_output(&msg).unwrap() {
            AgentOutput::Actions(actions) => {
                assert_eq!(actions.len(), 1);
                assert_eq!(actions[0].tool, "search");
                assert_eq!(actions[0].tool_input, json!({"query": "rust lang"}));
            }
            _ => panic!("Expected Actions"),
        }
    }

    #[test]
    fn test_parse_ai_message_tool_call_defaults() {
        // A tool call without `name` or `args` falls back to "unknown" and `{}`.
        let msg = json!({ "tool_calls": [ {} ] });
        match parse_ai_message_to_agent_output(&msg).unwrap() {
            AgentOutput::Actions(actions) => {
                assert_eq!(actions.len(), 1);
                assert_eq!(actions[0].tool, "unknown");
                assert_eq!(actions[0].tool_input, json!({}));
            }
            _ => panic!("Expected Actions"),
        }
    }

    #[test]
    fn test_parse_ai_message_multiple_tool_calls() {
        let msg = json!({
            "content": "",
            "tool_calls": [
                {"name": "search", "args": {"q": "a"}, "id": "1"},
                {"name": "calc", "args": {"expr": "2+2"}, "id": "2"}
            ]
        });
        match parse_ai_message_to_agent_output(&msg).unwrap() {
            AgentOutput::Actions(actions) => {
                assert_eq!(actions.len(), 2);
                assert_eq!(actions[0].tool, "search");
                assert_eq!(actions[0].tool_input, json!({"q": "a"}));
                assert_eq!(actions[1].tool, "calc");
                assert_eq!(actions[1].tool_input, json!({"expr": "2+2"}));
            }
            _ => panic!("Expected Actions"),
        }
    }

    #[test]
    fn test_format_to_tool_messages() {
        let steps = vec![(
            AgentAction::new("search", json!({"q": "test"}), "Calling search"),
            "Search result: found it".to_string(),
        )];
        let messages = format_to_tool_messages(&steps);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content().text(), "Search result: found it");
    }

    #[test]
    fn test_format_to_tool_messages_empty() {
        assert!(format_to_tool_messages(&[]).is_empty());
    }
}