use async_trait::async_trait;
use serde_json::{json, Value};
use crate::error::{CognisError, Result};
use crate::outputs::ChatGeneration;
use crate::runnables::base::Runnable;
use crate::runnables::config::RunnableConfig;
use super::base::OutputParser;
/// Extracts tool calls from an AI message and renders them as JSON values.
///
/// Construct with [`ToolCallOutputParser::new`] (or `Default`) and adjust the
/// output shape with the `first_only` / `without_id` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallOutputParser {
    /// Return only the first tool call (a single JSON object) instead of an array.
    pub first_tool_only: bool,
    /// Include each tool call's `"id"` key when the call carries an id.
    pub return_id: bool,
}
impl ToolCallOutputParser {
pub fn new() -> Self {
Self {
first_tool_only: false,
return_id: true,
}
}
pub fn first_only(mut self) -> Self {
self.first_tool_only = true;
self
}
pub fn without_id(mut self) -> Self {
self.return_id = false;
self
}
pub fn parse_chat_generation(&self, generation: &ChatGeneration) -> Result<Value> {
let tool_calls = match &generation.message {
crate::messages::Message::Ai(ai_msg) => &ai_msg.tool_calls,
_ => {
return if self.first_tool_only {
Err(CognisError::OutputParserError {
message: "Expected AIMessage in ChatGeneration for tool call parsing"
.into(),
observation: None,
llm_output: Some(generation.text.clone()),
})
} else {
Ok(json!([]))
}
}
};
if tool_calls.is_empty() {
return if self.first_tool_only {
Err(CognisError::OutputParserError {
message: "No tool calls found in AIMessage".into(),
observation: None,
llm_output: Some(generation.text.clone()),
})
} else {
Ok(json!([]))
};
}
let calls: Vec<Value> = tool_calls
.iter()
.map(|tc| {
let mut obj = json!({
"type": tc.name,
"args": tc.args,
});
if self.return_id {
if let Some(id) = &tc.id {
obj.as_object_mut().unwrap().insert("id".into(), json!(id));
}
}
obj
})
.collect();
if self.first_tool_only {
Ok(calls.into_iter().next().unwrap())
} else {
Ok(Value::Array(calls))
}
}
}
impl Default for ToolCallOutputParser {
fn default() -> Self {
Self::new()
}
}
impl OutputParser for ToolCallOutputParser {
    /// Always fails. This parser reads tool calls from a structured chat
    /// message, so raw text cannot be parsed.
    ///
    /// The rejected text is returned in the error's `llm_output` field.
    fn parse(&self, text: &str) -> Result<Value> {
        let rejected = Some(text.to_owned());
        Err(CognisError::OutputParserError {
            message: concat!(
                "ToolCallOutputParser requires ChatGeneration, not raw text. ",
                "Use parse_chat_generation() instead."
            )
            .into(),
            observation: None,
            llm_output: rejected,
        })
    }

    /// This parser contributes no format instructions.
    fn get_format_instructions(&self) -> Option<String> {
        None
    }

    /// Stable identifier for this parser type.
    fn parser_type(&self) -> &str {
        "tool_call_output_parser"
    }
}
#[async_trait]
impl Runnable for ToolCallOutputParser {
    fn name(&self) -> &str {
        "ToolCallOutputParser"
    }

    /// Parses tool calls from a JSON-encoded `ChatGeneration` or `AIMessage`.
    ///
    /// The input is first decoded as a `ChatGeneration`. If that fails, it is
    /// decoded as an `AIMessage` and wrapped in a new `ChatGeneration`. The
    /// config argument is not used.
    ///
    /// # Errors
    ///
    /// Returns `TypeMismatch` when the input decodes as neither type. Otherwise
    /// any error comes from `parse_chat_generation`.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        // `input` is cloned for the first attempt so it can still be moved
        // into the fallback decode.
        let generation = match serde_json::from_value::<ChatGeneration>(input.clone()) {
            Ok(decoded) => decoded,
            Err(_) => match serde_json::from_value::<crate::messages::AIMessage>(input) {
                Ok(message) => ChatGeneration::new(message),
                Err(_) => {
                    return Err(CognisError::TypeMismatch {
                        expected: "ChatGeneration or AIMessage".into(),
                        got: "unrecognized input".into(),
                    });
                }
            },
        };
        self.parse_chat_generation(&generation)
    }
}