use async_trait::async_trait;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use tokio_stream::Stream;
use traitclaw_core::traits::provider::{ModelInfo, Provider};
use traitclaw_core::types::completion::{CompletionRequest, CompletionResponse};
use traitclaw_core::types::message::{Message, MessageRole};
use traitclaw_core::types::stream::StreamEvent;
use traitclaw_core::{Error, Result};
/// A [`Provider`] wrapper that asks the wrapped provider to reply with JSON
/// matching the JSON Schema generated for `T`.
///
/// The wrapper only adds a schema instruction to outgoing requests; responses
/// are handed back unchanged. Use [`StructuredOutputProvider::parse`] to turn
/// a completion response into a `T`.
///
/// The struct itself declares no trait bounds on `T`: the `impl` blocks state
/// the bounds they need, so the type can be named without repeating them.
pub struct StructuredOutputProvider<T> {
    /// Wrapped provider that performs the actual completion and stream calls.
    inner: Box<dyn Provider>,
    /// JSON Schema for `T`, serialized once at construction time.
    schema: serde_json::Value,
    /// Ties the wrapper to `T` without storing a value of that type.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> StructuredOutputProvider<T>
where
    T: JsonSchema + DeserializeOwned + Send + Sync + 'static,
{
    /// Wraps `inner` so that every request carries the JSON Schema of `T`.
    ///
    /// The schema is generated once with `schemars::schema_for!` and, when it
    /// serializes to a JSON object, its top-level `title` entry is set to
    /// `schema_name` (overwriting an existing `title` entry, if any).
    #[must_use]
    pub fn new(inner: impl Provider + 'static, schema_name: impl Into<String>) -> Self {
        let schema_name = schema_name.into();
        let root_schema = schemars::schema_for!(T);
        // Best effort: a schema that fails to serialize degrades to `null`
        // instead of making construction fallible.
        let mut schema_value =
            serde_json::to_value(&root_schema).unwrap_or(serde_json::Value::Null);
        if let Some(obj) = schema_value.as_object_mut() {
            obj.insert("title".to_string(), serde_json::Value::String(schema_name));
        }
        Self {
            inner: Box::new(inner),
            schema: schema_value,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Deserializes the text content of `response` into a `T`.
    ///
    /// Surrounding whitespace and one enclosing Markdown code fence (three
    /// backticks, optionally followed by a language tag such as `json`) are
    /// removed before parsing, because models frequently wrap JSON that way
    /// even when told not to.
    ///
    /// # Errors
    ///
    /// Returns a provider error when the response contains tool calls instead
    /// of text, or when the text cannot be deserialized into `T`.
    pub fn parse(response: &CompletionResponse) -> Result<T> {
        use traitclaw_core::types::completion::ResponseContent;
        match &response.content {
            ResponseContent::Text(text) => serde_json::from_str(Self::strip_code_fence(text))
                .map_err(|e| Error::provider(format!("Structured output parse error: {e}"))),
            ResponseContent::ToolCalls(_) => Err(Error::provider(
                "Expected structured text response, got tool calls",
            )),
        }
    }

    /// Returns `text` trimmed, without one enclosing Markdown code fence.
    ///
    /// Text that is not wrapped in a fence is returned trimmed and otherwise
    /// untouched. Valid JSON never starts with a backtick, so input that
    /// parsed before this step still parses to the same value.
    fn strip_code_fence(text: &str) -> &str {
        let trimmed = text.trim();
        trimmed
            .strip_prefix("```")
            .and_then(|rest| rest.strip_suffix("```"))
            .map_or(trimmed, |fenced| {
                // Everything up to the first newline is the fence's info
                // string (e.g. `json`); the payload starts on the next line.
                fenced
                    .split_once('\n')
                    .map_or(fenced, |(_info, body)| body)
                    .trim()
            })
    }

    /// Adds the schema instruction to `request` and returns the modified request.
    ///
    /// The instruction is appended to the first system message when one
    /// exists; otherwise a new system message holding it is inserted at the
    /// front of the message list.
    fn inject_schema_instruction(&self, mut request: CompletionRequest) -> CompletionRequest {
        let schema_str = serde_json::to_string_pretty(&self.schema).unwrap_or_default();
        let instruction = format!(
            "You MUST respond with valid JSON matching this exact schema. \
             Do not add any text, explanation, or markdown. \
             Only output the raw JSON object.\n\nSchema:\n{schema_str}"
        );

        // One pass: locate the first system message, if there is one.
        let first_system = request
            .messages
            .iter_mut()
            .find(|m| matches!(m.role, MessageRole::System));

        if let Some(system) = first_system {
            system.content.push_str("\n\n");
            system.content.push_str(&instruction);
        } else {
            request.messages.insert(
                0,
                Message {
                    role: MessageRole::System,
                    content: instruction,
                    tool_call_id: None,
                },
            );
        }
        request
    }
}
/// Delegates every call to the wrapped provider; requests get the schema
/// instruction added before they are forwarded.
#[async_trait]
impl<T> Provider for StructuredOutputProvider<T>
where
    T: JsonSchema + DeserializeOwned + Send + Sync + 'static,
{
    /// Runs a completion on the wrapped provider with the schema instruction
    /// added to `request`, returning the wrapped provider's result as is.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        self.inner
            .complete(self.inject_schema_instruction(request))
            .await
    }

    /// Opens a stream on the wrapped provider with the schema instruction
    /// added to `request`; the resulting stream is returned without changes.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        self.inner
            .stream(self.inject_schema_instruction(request))
            .await
    }

    /// Reports the model info of the wrapped provider.
    fn model_info(&self) -> &ModelInfo {
        let wrapped = &self.inner;
        wrapped.model_info()
    }
}