use async_trait::async_trait;
use serde_json::Value;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::{
BaseChatModel, ChatStream, ModelProfile, ToolChoice,
};
use cognis_core::messages::{AIMessage, Message};
use cognis_core::outputs::{ChatGeneration, ChatResult};
use cognis_core::tools::ToolSchema;
/// Strategy used to coerce a chat model into producing structured output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StructuredOutputMethod {
    /// Bind a tool derived from the schema and force the model to call it.
    ToolCalling,
    /// Instruct the model via a system prompt to reply with raw JSON.
    JsonMode,
}

impl StructuredOutputMethod {
    /// Parse an optional method name. Only the exact string `"json_mode"`
    /// selects [`Self::JsonMode`]; anything else (including `None`) falls
    /// back to [`Self::ToolCalling`].
    pub fn from_str_or_default(s: Option<&str>) -> Self {
        if s == Some("json_mode") {
            Self::JsonMode
        } else {
            Self::ToolCalling
        }
    }
}
/// Wrapper around a chat model that rewrites replies so the message content
/// is JSON conforming to `schema`, using either tool calling or JSON mode.
pub struct StructuredOutputChatModel {
    // Wrapped model; for ToolCalling it is already bound to the output tool
    // (see `with_structured_output`).
    inner: Box<dyn BaseChatModel>,
    // JSON Schema the structured output should conform to.
    schema: Value,
    // Strategy used to obtain the structured output.
    method: StructuredOutputMethod,
    // Name of the tool whose call args carry the payload (from the schema's
    // "title", defaulting to "structured_output"); used in ToolCalling mode.
    tool_name: String,
    // When true, the original tool calls are copied onto the rewritten
    // AIMessage in ToolCalling mode.
    include_raw: bool,
}
impl StructuredOutputChatModel {
    /// Pull the structured payload out of the first generation's tool call
    /// whose name matches `self.tool_name`.
    ///
    /// Errors when there are no generations, when the first generation is not
    /// an AI message, or when no tool call with the expected name is present.
    fn extract_tool_call_output(&self, result: &ChatResult) -> Result<Value> {
        let generation = result
            .generations
            .first()
            .ok_or_else(|| CognisError::Other("No generations returned".into()))?;
        let ai_msg = if let Message::Ai(ai) = &generation.message {
            ai
        } else {
            return Err(CognisError::Other(
                "Expected AIMessage in generation".into(),
            ));
        };
        match ai_msg
            .tool_calls
            .iter()
            .find(|tc| tc.name == self.tool_name)
        {
            Some(tc) => serde_json::to_value(&tc.args).map_err(|e| {
                CognisError::Other(format!("Failed to serialize tool call args: {}", e))
            }),
            None => Err(CognisError::Other(format!(
                "No tool call found with name '{}'. Tool calls present: {:?}",
                self.tool_name,
                ai_msg
                    .tool_calls
                    .iter()
                    .map(|tc| &tc.name)
                    .collect::<Vec<_>>()
            ))),
        }
    }
}
#[async_trait]
impl BaseChatModel for StructuredOutputChatModel {
    /// Run the wrapped model and rewrite its reply so that the AI message
    /// content is the JSON-encoded structured output.
    async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult> {
        match self.method {
            StructuredOutputMethod::ToolCalling => {
                let raw = self.inner._generate(messages, stop).await?;
                let payload = self.extract_tool_call_output(&raw)?;
                let serialized = serde_json::to_string(&payload)
                    .map_err(|e| CognisError::Other(format!("JSON serialization error: {}", e)))?;
                let mut rewritten = AIMessage::new(&serialized);
                // Carry usage metadata and message id over from the original
                // reply; keep the raw tool calls only when requested.
                if let Some(Message::Ai(original)) = raw.generations.first().map(|g| &g.message) {
                    rewritten.usage_metadata = original.usage_metadata.clone();
                    rewritten.base.id = original.base.id.clone();
                    if self.include_raw {
                        rewritten.tool_calls = original.tool_calls.clone();
                    }
                }
                Ok(ChatResult {
                    generations: vec![ChatGeneration::new(rewritten)],
                    llm_output: raw.llm_output,
                })
            }
            StructuredOutputMethod::JsonMode => {
                let schema_str = serde_json::to_string_pretty(&self.schema)
                    .unwrap_or_else(|_| self.schema.to_string());
                let system_instruction = format!(
                    "Respond with valid JSON matching this schema:\n{}",
                    schema_str
                );
                // Prepend a system message carrying the schema instruction.
                let mut prompt = Vec::with_capacity(messages.len() + 1);
                prompt.push(Message::System(
                    cognis_core::messages::SystemMessage::new(&system_instruction),
                ));
                prompt.extend_from_slice(messages);
                let result = self.inner._generate(&prompt, stop).await?;
                // Best-effort validation: a non-empty reply must parse as
                // JSON; empty content is passed through untouched.
                if let Some(Message::Ai(ai)) = result.generations.first().map(|g| &g.message) {
                    let content = ai.base.content.text();
                    if !content.is_empty() {
                        serde_json::from_str::<Value>(&content).map_err(|e| {
                            CognisError::Other(format!(
                                "Model response is not valid JSON: {}. Content: {}",
                                e, content
                            ))
                        })?;
                    }
                }
                Ok(result)
            }
        }
    }

    fn llm_type(&self) -> &str {
        "structured_output"
    }

    /// Streaming is delegated unchanged to the wrapped model (the stream is
    /// not rewritten into structured output).
    async fn _stream(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatStream> {
        self.inner._stream(messages, stop).await
    }

    fn bind_tools(
        &self,
        tools: &[ToolSchema],
        tool_choice: Option<ToolChoice>,
    ) -> Result<Box<dyn BaseChatModel>> {
        self.inner.bind_tools(tools, tool_choice)
    }

    fn profile(&self) -> ModelProfile {
        self.inner.profile()
    }

    fn get_num_tokens_from_messages(&self, messages: &[Message]) -> usize {
        self.inner.get_num_tokens_from_messages(messages)
    }
}
pub fn with_structured_output(
model: Box<dyn BaseChatModel>,
schema: Value,
method: Option<&str>,
include_raw: bool,
) -> Result<Box<dyn BaseChatModel>> {
let method = StructuredOutputMethod::from_str_or_default(method);
let tool_name = schema
.get("title")
.and_then(|t| t.as_str())
.unwrap_or("structured_output")
.to_string();
let description = schema
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("Structured output tool")
.to_string();
let inner = match method {
StructuredOutputMethod::ToolCalling => {
let tool_schema = ToolSchema {
name: tool_name.clone(),
description,
parameters: Some(schema.clone()),
extras: None,
};
model.bind_tools(&[tool_schema], Some(ToolChoice::Tool(tool_name.clone())))?
}
StructuredOutputMethod::JsonMode => {
model
}
};
Ok(Box::new(StructuredOutputChatModel {
inner,
schema,
method,
tool_name,
include_raw,
}))
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::{HumanMessage, ToolCall};
    use serde_json::json;
    use std::collections::HashMap;

    /// Stub model that answers every `_generate` call with an empty-content
    /// AI message carrying a fixed list of tool calls.
    struct MockToolCallModel {
        tool_calls: Vec<ToolCall>,
        // Captured by `bind_tools`; NOTE(review): neither field is read back
        // by any test — likely dead_code warnings. Confirm intent.
        bound_tools: Vec<ToolSchema>,
        tool_choice: Option<ToolChoice>,
    }

    impl MockToolCallModel {
        fn new(tool_calls: Vec<ToolCall>) -> Self {
            Self {
                tool_calls,
                bound_tools: Vec::new(),
                tool_choice: None,
            }
        }
    }

    #[async_trait]
    impl BaseChatModel for MockToolCallModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            // Reply with empty text and the canned tool calls.
            let mut ai_message = AIMessage::new("");
            ai_message.tool_calls = self.tool_calls.clone();
            let generation = ChatGeneration::new(ai_message);
            Ok(ChatResult {
                generations: vec![generation],
                llm_output: None,
            })
        }
        fn llm_type(&self) -> &str {
            "mock"
        }
        // Returns a copy that records what was bound, so binding succeeds
        // without changing the canned reply.
        fn bind_tools(
            &self,
            tools: &[ToolSchema],
            tool_choice: Option<ToolChoice>,
        ) -> Result<Box<dyn BaseChatModel>> {
            Ok(Box::new(MockToolCallModel {
                tool_calls: self.tool_calls.clone(),
                bound_tools: tools.to_vec(),
                tool_choice,
            }))
        }
    }

    /// Stub model that replies with fixed text content and no tool calls.
    /// Used to exercise JSON mode.
    struct MockTextModel {
        content: String,
    }

    impl MockTextModel {
        fn new(content: &str) -> Self {
            Self {
                content: content.to_string(),
            }
        }
    }

    #[async_trait]
    impl BaseChatModel for MockTextModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            let ai_message = AIMessage::new(&self.content);
            let generation = ChatGeneration::new(ai_message);
            Ok(ChatResult {
                generations: vec![generation],
                llm_output: None,
            })
        }
        fn llm_type(&self) -> &str {
            "mock_text"
        }
    }

    // Tool-calling mode: the matching tool call's args become the message
    // content (as JSON), and with include_raw=false the tool calls are dropped.
    #[tokio::test]
    async fn test_structured_output_tool_calling_extracts_args() {
        let mut args = HashMap::new();
        args.insert("name".to_string(), json!("Alice"));
        args.insert("age".to_string(), json!(30));
        let mock = MockToolCallModel::new(vec![ToolCall {
            name: "Person".to_string(),
            args,
            id: Some("call_1".to_string()),
        }]);
        let schema = json!({
            "title": "Person",
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        });
        let structured =
            with_structured_output(Box::new(mock), schema, Some("tool_calling"), false).unwrap();
        let messages = vec![Message::Human(HumanMessage::new("Who is Alice?"))];
        let result = structured._generate(&messages, None).await.unwrap();
        assert_eq!(result.generations.len(), 1);
        if let Message::Ai(ref ai) = result.generations[0].message {
            let parsed: Value = serde_json::from_str(&ai.base.content.text()).unwrap();
            assert_eq!(parsed["name"], "Alice");
            assert_eq!(parsed["age"], 30);
            // include_raw=false: original tool calls are not carried over.
            assert!(ai.tool_calls.is_empty());
        } else {
            panic!("Expected AIMessage");
        }
    }

    // Tool-calling mode with include_raw=true: the original tool calls are
    // preserved on the rewritten message alongside the JSON content.
    #[tokio::test]
    async fn test_structured_output_tool_calling_with_include_raw() {
        let mut args = HashMap::new();
        args.insert("city".to_string(), json!("Paris"));
        let mock = MockToolCallModel::new(vec![ToolCall {
            name: "Location".to_string(),
            args,
            id: Some("call_2".to_string()),
        }]);
        let schema = json!({
            "title": "Location",
            "type": "object",
            "properties": {
                "city": {"type": "string"}
            }
        });
        let structured = with_structured_output(
            Box::new(mock),
            schema,
            Some("tool_calling"),
            true,
        )
        .unwrap();
        let messages = vec![Message::Human(HumanMessage::new("Where?"))];
        let result = structured._generate(&messages, None).await.unwrap();
        if let Message::Ai(ref ai) = result.generations[0].message {
            let parsed: Value = serde_json::from_str(&ai.base.content.text()).unwrap();
            assert_eq!(parsed["city"], "Paris");
            assert_eq!(ai.tool_calls.len(), 1);
            assert_eq!(ai.tool_calls[0].name, "Location");
        } else {
            panic!("Expected AIMessage");
        }
    }

    // JSON mode: the model's text content is passed through after validating
    // that it parses as JSON.
    #[tokio::test]
    async fn test_structured_output_json_mode() {
        let mock = MockTextModel::new(r#"{"name": "Bob", "age": 25}"#);
        let schema = json!({
            "title": "Person",
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            }
        });
        let structured =
            with_structured_output(Box::new(mock), schema, Some("json_mode"), false).unwrap();
        let messages = vec![Message::Human(HumanMessage::new("Tell me about Bob"))];
        let result = structured._generate(&messages, None).await.unwrap();
        if let Message::Ai(ref ai) = result.generations[0].message {
            let parsed: Value = serde_json::from_str(&ai.base.content.text()).unwrap();
            assert_eq!(parsed["name"], "Bob");
            assert_eq!(parsed["age"], 25);
        } else {
            panic!("Expected AIMessage");
        }
    }

    // Tool-calling mode errors when no tool call matches the expected name
    // (derived from the schema's "title").
    #[tokio::test]
    async fn test_structured_output_no_matching_tool_call_errors() {
        let mut args = HashMap::new();
        args.insert("x".to_string(), json!(1));
        let mock = MockToolCallModel::new(vec![ToolCall {
            name: "WrongTool".to_string(),
            args,
            id: None,
        }]);
        let schema = json!({
            "title": "ExpectedTool",
            "type": "object",
            "properties": {"x": {"type": "integer"}}
        });
        let structured =
            with_structured_output(Box::new(mock), schema, Some("tool_calling"), false).unwrap();
        let messages = vec![Message::Human(HumanMessage::new("test"))];
        let result = structured._generate(&messages, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("No tool call found with name 'ExpectedTool'"));
    }

    // JSON mode errors when the (non-empty) reply is not valid JSON.
    #[tokio::test]
    async fn test_structured_output_json_mode_invalid_json_errors() {
        let mock = MockTextModel::new("this is not json");
        let schema = json!({"title": "Test", "type": "object"});
        let structured =
            with_structured_output(Box::new(mock), schema, Some("json_mode"), false).unwrap();
        let messages = vec![Message::Human(HumanMessage::new("test"))];
        let result = structured._generate(&messages, None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not valid JSON"));
    }

    // Method parsing: only "json_mode" selects JsonMode; everything else
    // (including None) defaults to ToolCalling.
    #[test]
    fn test_structured_output_method_parsing() {
        assert_eq!(
            StructuredOutputMethod::from_str_or_default(None),
            StructuredOutputMethod::ToolCalling
        );
        assert_eq!(
            StructuredOutputMethod::from_str_or_default(Some("tool_calling")),
            StructuredOutputMethod::ToolCalling
        );
        assert_eq!(
            StructuredOutputMethod::from_str_or_default(Some("json_mode")),
            StructuredOutputMethod::JsonMode
        );
    }

    // A schema without a "title" falls back to the default tool name and
    // still constructs successfully.
    #[test]
    fn test_with_structured_output_default_tool_name() {
        let mock = MockToolCallModel::new(vec![]);
        let schema = json!({"type": "object"});
        let result = with_structured_output(Box::new(mock), schema, Some("tool_calling"), false);
        assert!(result.is_ok());
    }

    // The wrapper reports its own llm_type, not the inner model's.
    #[test]
    fn test_llm_type_returns_structured_output() {
        let mock = MockTextModel::new("");
        let schema = json!({"title": "Test", "type": "object"});
        let structured =
            with_structured_output(Box::new(mock), schema, Some("json_mode"), false).unwrap();
        assert_eq!(structured.llm_type(), "structured_output");
    }
}