use std::marker::PhantomData;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use crate::llm::{
BoxStream, CallOptions, ChatModel, Content, LlmError, Message, MessageChunk, ToolChoice,
ToolDefinition,
};
/// A [`ChatModel`] wrapper that checks replies from the wrapped model can be
/// deserialized into `T`.
///
/// The wrapper carries a tool definition whose parameters are the JSON Schema
/// of `T`. Depending on `use_tool_based`, `invoke` either binds that tool and
/// forces the model to call it, or only parses the reply's text content.
#[derive(Clone, Debug)]
pub struct StructuredOutputModel<M, T>
where
    M: ChatModel,
    T: DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static,
{
    /// The wrapped chat model that actually produces replies.
    inner: M,
    /// When `true`, `invoke` binds `tool_definition` and forces its use.
    use_tool_based: bool,
    /// Name of the extraction tool; also used as the forced tool choice.
    tool_name: String,
    /// Tool definition whose parameters describe the schema of `T`.
    tool_definition: ToolDefinition,
    /// Ties the wrapper to its output type without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<M: ChatModel, T: DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static>
    StructuredOutputModel<M, T>
{
    /// Wraps `inner` so that its replies are validated against `T`.
    ///
    /// The extraction tool is named after `std::any::type_name::<T>()` (see
    /// [`Self::tool_name_for`]) and its parameters are the JSON Schema that
    /// `schemars` generates for `T`. If that schema cannot be converted into
    /// a JSON value, a bare `{"type": "object"}` schema is used instead.
    ///
    /// Tool-based extraction is enabled by default; use
    /// [`Self::with_tool_based_extraction`] to switch it off.
    #[must_use]
    pub fn new(inner: M) -> Self {
        let type_name = std::any::type_name::<T>();
        let tool_name = Self::tool_name_for(type_name);

        let schema = schemars::schema_for!(T);
        // Best-effort: a permissive object schema is better than failing construction.
        let parameters =
            serde_json::to_value(&schema).unwrap_or_else(|_| serde_json::json!({"type": "object"}));

        let tool_definition = ToolDefinition {
            name: tool_name.clone(),
            description: format!(
                "Extract structured data conforming to the schema for {type_name}"
            ),
            parameters,
        };

        Self {
            inner,
            use_tool_based: true,
            tool_name,
            tool_definition,
            _phantom: PhantomData,
        }
    }

    /// Turns a Rust type name into an identifier usable as a tool name.
    ///
    /// `::` path separators collapse to a single `_`, spaces are dropped, and
    /// every remaining character that is not an ASCII letter, digit, `_` or
    /// `-` becomes `_`. The result is prefixed with `extract_`.
    ///
    /// Type names of tuples and arrays contain characters such as `(`, `[`
    /// and `;`, which would otherwise end up verbatim in the tool name.
    /// LLM providers typically restrict tool names to `[A-Za-z0-9_-]`.
    ///
    /// NOTE(review): type names are fully qualified, so the result can be
    /// long. Some providers also cap the tool-name length; no truncation is
    /// applied here. Confirm against the providers in use.
    fn tool_name_for(type_name: &str) -> String {
        let sanitized: String = type_name
            .replace("::", "_")
            .chars()
            .filter(|ch| *ch != ' ')
            .map(|ch| {
                if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
                    ch
                } else {
                    '_'
                }
            })
            .collect();
        format!("extract_{sanitized}")
    }

    /// Enables or disables tool-based extraction and returns the updated wrapper.
    #[must_use]
    pub const fn with_tool_based_extraction(mut self, enabled: bool) -> Self {
        self.use_tool_based = enabled;
        self
    }

    /// Returns a reference to the wrapped model.
    #[must_use]
    #[allow(
        clippy::missing_const_for_fn,
        reason = "Cannot be const in current Rust version"
    )]
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Deserializes a `T` out of `message`.
    ///
    /// The arguments of the first tool call are tried first. If there is no
    /// tool call, or its arguments do not deserialize into `T`, the text
    /// content of the message is parsed as JSON instead.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidResponse`] when the fallback text parse
    /// fails, or when the message content is multipart.
    pub fn extract(&self, message: &Message) -> Result<T, LlmError> {
        // An empty tool-call list already yields `Err` below, so no separate
        // emptiness check is needed before falling back to the text content.
        Self::extract_from_tool_call(message).or_else(|_| Self::extract_from_text(message))
    }

    /// Deserializes a `T` from the arguments of the first tool call in `message`.
    ///
    /// Only the first tool call is inspected, and its name is not checked.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidResponse`] when the message has no tool
    /// calls or the arguments do not deserialize into `T`.
    fn extract_from_tool_call(message: &Message) -> Result<T, LlmError> {
        let Some(tool_call) = message.tool_calls.first() else {
            return Err(LlmError::InvalidResponse(
                "No tool calls found in response for tool-based extraction".to_string(),
            ));
        };

        // `from_value` consumes its input, hence the clone of the arguments.
        serde_json::from_value(tool_call.arguments.clone()).map_err(|e| {
            LlmError::InvalidResponse(format!(
                "Failed to parse tool call arguments as structured output: {e}"
            ))
        })
    }

    /// Deserializes a `T` by parsing the text content of `message` as JSON.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidResponse`] when the content is multipart or
    /// the text is not valid JSON for `T`. The parse error message includes
    /// the offending content.
    fn extract_from_text(message: &Message) -> Result<T, LlmError> {
        let content = match &message.content {
            Content::Text(text) => text,
            Content::MultiPart(_) => {
                return Err(LlmError::InvalidResponse(
                    "Cannot extract structured output from multipart content".to_string(),
                ));
            }
        };

        serde_json::from_str(content).map_err(|e| {
            LlmError::InvalidResponse(format!(
                "Failed to parse structured output: {e}\nContent: {content}"
            ))
        })
    }
}
/// Builds a wrapper around `M::default()`.
///
/// The `'static` bound on `T` is spelled out so this impl states the same
/// requirements as the struct definition and the other impl blocks, instead
/// of relying on the bound being implied by the struct.
impl<M, T> Default for StructuredOutputModel<M, T>
where
    M: ChatModel + Default,
    T: DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static,
{
    /// Equivalent to `StructuredOutputModel::new(M::default())`.
    fn default() -> Self {
        Self::new(M::default())
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl<M: ChatModel, T: DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static> ChatModel
    for StructuredOutputModel<M, T>
{
    /// Invokes the wrapped model and returns its reply only if a `T` can be
    /// deserialized from it.
    ///
    /// With tool-based extraction enabled, the extraction tool is bound to
    /// the inner model and the call options are overridden to force that
    /// tool. The reply is accepted when its first tool call deserializes into
    /// `T`; otherwise the text content must parse as `T`.
    ///
    /// With tool-based extraction disabled, the inner model is invoked with
    /// the caller's options unchanged and only the text content is checked.
    ///
    /// The parsed value itself is discarded; the raw reply is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner model, and returns the text-parse
    /// error when no `T` can be extracted from the reply.
    async fn invoke(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<Message, LlmError> {
        // Text-only mode: pass the call through and validate the text content.
        if !self.use_tool_based {
            let reply = self.inner.invoke(messages, options).await?;
            Self::extract_from_text(&reply)?;
            return Ok(reply);
        }

        let bound = self.inner.bind_tools(vec![self.tool_definition.clone()]);

        // Keep the caller's options but override the tool choice so the model
        // is asked to call the extraction tool.
        let mut forced = options.cloned().unwrap_or_default();
        forced.tool_choice = Some(ToolChoice::Specific {
            name: self.tool_name.clone(),
        });

        let reply = bound.invoke(messages, Some(&forced)).await?;

        let tool_payload_parses =
            !reply.tool_calls.is_empty() && Self::extract_from_tool_call(&reply).is_ok();
        if !tool_payload_parses {
            // No usable tool call: the text content must carry the structured output.
            Self::extract_from_text(&reply)?;
        }
        Ok(reply)
    }

    /// Streams from the wrapped model.
    ///
    /// This delegates directly to the inner model: the extraction tool is not
    /// bound, no tool choice is forced, and chunks are not validated.
    fn stream(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> BoxStream<'_, Result<MessageChunk, LlmError>> {
        self.inner.stream(messages, options)
    }

    /// Returns a copy of this wrapper whose inner model has `tools` bound.
    ///
    /// The extraction settings (mode, tool name, tool definition) are carried
    /// over unchanged.
    ///
    /// NOTE(review): `invoke` calls `bind_tools` on the inner model again in
    /// tool-based mode. Whether that replaces or extends the tools bound here
    /// depends on `M`'s implementation; confirm.
    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
        Self {
            inner: self.inner.bind_tools(tools),
            use_tool_based: self.use_tool_based,
            tool_name: self.tool_name.clone(),
            tool_definition: self.tool_definition.clone(),
            _phantom: PhantomData,
        }
    }

    /// Returns the name reported by the wrapped model.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{MockChatModel, ToolCall};
    use futures::stream::StreamExt;
    use serde::Deserialize;
    use serde_json::json;

    /// Minimal structured-output target shared by every test in this module.
    #[derive(Debug, Clone, Deserialize, JsonSchema)]
    struct WeatherReport {
        temperature: f64,
        conditions: String,
    }

    /// The single-message conversation each test sends to the model.
    fn weather_question() -> Vec<Message> {
        vec![Message::human("What's the weather?")]
    }

    /// Builds a one-element tool-call list from the given id, name and arguments.
    fn single_tool_call(id: &str, name: &str, arguments: serde_json::Value) -> Vec<ToolCall> {
        vec![ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments,
        }]
    }

    /// Asserts that `report` carries the expected temperature and conditions.
    fn assert_report(report: &WeatherReport, temperature: f64, conditions: &str) {
        assert!((report.temperature - temperature).abs() < f64::EPSILON);
        assert_eq!(report.conditions, conditions);
    }

    /// A well-formed tool call is accepted by `invoke` and parsed by `extract`.
    #[tokio::test]
    async fn test_tool_based_extraction_success() {
        let tool_calls = single_tool_call(
            "call_extract",
            "weather_tool",
            json!({"temperature": 22.5, "conditions": "sunny"}),
        );
        let base = MockChatModel::new("gpt-4")
            .with_response("")
            .with_tool_calls(tool_calls);
        let model = StructuredOutputModel::<_, WeatherReport>::new(base);
        let messages = weather_question();

        let response = model.invoke(&messages, None).await.unwrap();
        assert!(!response.tool_calls.is_empty());

        let extracted: WeatherReport = model.extract(&response).unwrap();
        assert_report(&extracted, 22.5, "sunny");
    }

    /// Without any tool call, the JSON text content is used instead.
    #[tokio::test]
    async fn test_text_based_extraction_fallback() {
        let base = MockChatModel::new("gpt-4")
            .with_response(r#"{"temperature": 18.0, "conditions": "cloudy"}"#);
        let model = StructuredOutputModel::<_, WeatherReport>::new(base);
        let messages = weather_question();

        let response = model.invoke(&messages, None).await.unwrap();

        let extracted: WeatherReport = model.extract(&response).unwrap();
        assert_report(&extracted, 18.0, "cloudy");
    }

    /// With tool-based extraction switched off, the text content is parsed.
    #[tokio::test]
    async fn test_disabled_tool_based_extraction() {
        let base = MockChatModel::new("gpt-4")
            .with_response(r#"{"temperature": 25.0, "conditions": "hot"}"#);
        let model =
            StructuredOutputModel::<_, WeatherReport>::new(base).with_tool_based_extraction(false);
        let messages = weather_question();

        let response = model.invoke(&messages, None).await.unwrap();

        let extracted: WeatherReport = model.extract(&response).unwrap();
        assert_report(&extracted, 25.0, "hot");
    }

    /// A tool call whose arguments do not match `T` falls back to the text content.
    #[tokio::test]
    async fn test_invalid_tool_call_falls_back_to_text() {
        let tool_calls = single_tool_call(
            "call_bad",
            "structured_output",
            json!({"temperature": "not_a_number", "conditions": 42}),
        );
        let base = MockChatModel::new("gpt-4")
            .with_response(r#"{"temperature": 30.0, "conditions": "warm"}"#)
            .with_tool_calls(tool_calls);
        let model = StructuredOutputModel::<_, WeatherReport>::new(base);
        let messages = weather_question();

        let response = model.invoke(&messages, None).await.unwrap();

        let extracted: WeatherReport = model.extract(&response).unwrap();
        assert_report(&extracted, 30.0, "warm");
    }

    /// Streaming yields at least one chunk with non-empty content.
    #[tokio::test]
    async fn test_stream_returns_chunks() {
        let base = MockChatModel::new("gpt-4")
            .with_response(r#"{"temperature": 21.0, "conditions": "rainy"}"#);
        let model = StructuredOutputModel::<_, WeatherReport>::new(base);
        let messages = weather_question();

        let mut stream = model.stream(&messages, None);
        let first = stream.next().await;
        assert!(first.is_some());

        let chunk = first.unwrap().unwrap();
        assert!(!chunk.content.is_empty());
    }

    /// Streaming a tool-call reply yields a chunk that is either empty of
    /// text or carries tool-call chunks.
    #[tokio::test]
    async fn test_stream_with_tool_based_extraction() {
        let tool_calls = single_tool_call(
            "call_stream",
            "weather_tool",
            json!({"temperature": 19.5, "conditions": "windy"}),
        );
        let base = MockChatModel::new("gpt-4")
            .with_response("")
            .with_tool_calls(tool_calls);
        let model = StructuredOutputModel::<_, WeatherReport>::new(base);
        let messages = weather_question();

        let mut stream = model.stream(&messages, None);
        let first = stream.next().await;
        assert!(first.is_some());

        let chunk = first.unwrap().unwrap();
        assert!(chunk.content.is_empty() || !chunk.tool_call_chunks.is_empty());
    }
}