use futures::future::BoxFuture;
use serde::de::DeserializeOwned;
use crate::error::LlmError;
use crate::provider::{
ChatExtras, ChatResponse, ChatStream, LlmProvider, Message, Role, ToolDefinition,
cached_schema, short_type_name,
};
/// Sealing machinery: prevents downstream crates from implementing
/// [`LlmProviderDyn`] directly — the only way in is the blanket impl below.
mod private {
    /// Marker supertrait; the blanket impl limits it to `LlmProvider` types.
    pub trait Sealed {}
    impl<T: super::LlmProvider> Sealed for T {}
}
/// Object-safe counterpart of [`LlmProvider`].
///
/// Async trait methods are exposed as [`BoxFuture`]s so the trait can be
/// used as `dyn LlmProviderDyn` (e.g. behind `Arc`). The sealed
/// `private::Sealed` supertrait ensures the only implementors are the
/// blanket impl over concrete `LlmProvider` types.
pub trait LlmProviderDyn: private::Sealed + std::fmt::Debug + Send + Sync {
    /// Context window size in tokens, if the provider reports one.
    fn context_window(&self) -> Option<usize>;
    /// Send `messages` and resolve to the assistant's text reply.
    fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>>;
    /// Like [`Self::chat`], but resolves to a stream of reply chunks.
    fn chat_stream<'a>(
        &'a self,
        messages: &'a [Message],
    ) -> BoxFuture<'a, Result<ChatStream, LlmError>>;
    /// Whether streaming replies are supported.
    fn supports_streaming(&self) -> bool;
    /// Embed a single text into a float vector.
    fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>>;
    /// Embed several texts in one call; one output vector per input.
    fn embed_batch<'a>(
        &'a self,
        texts: &'a [&'a str],
    ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>>;
    /// Whether embeddings are supported.
    fn supports_embeddings(&self) -> bool;
    /// Short provider name (suitable for logging).
    fn name(&self) -> &str;
    /// Identifier of the underlying model.
    fn model_identifier(&self) -> &str;
    /// Whether image inputs are supported.
    fn supports_vision(&self) -> bool;
    /// Whether tool/function calling is supported.
    fn supports_tool_use(&self) -> bool;
    /// Chat with a set of tool definitions the model may invoke.
    fn chat_with_tools<'a>(
        &'a self,
        messages: &'a [Message],
        tools: &'a [ToolDefinition],
    ) -> BoxFuture<'a, Result<ChatResponse, LlmError>>;
    /// Cache-related token counts from the last call, if tracked.
    /// NOTE(review): the pair's field meaning is defined by
    /// `LlmProvider::last_cache_usage` — confirm there.
    fn last_cache_usage(&self) -> Option<(u64, u64)>;
    /// Token counts from the last call, if tracked (presumably
    /// (input, output) — see `LlmProvider::last_usage`).
    fn last_usage(&self) -> Option<(u64, u64)>;
    /// Remove and return a pending compaction summary, if any
    /// (semantics defined by `LlmProvider::take_compaction_summary`).
    fn take_compaction_summary(&self) -> Option<String>;
    /// Chat resolving to the reply text plus provider-specific extras.
    fn chat_with_extras<'a>(
        &'a self,
        messages: &'a [Message],
    ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>>;
    /// Build the request JSON that would be sent for these inputs,
    /// for inspection/debugging; makes no network call here.
    #[must_use]
    fn debug_request_json(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        stream: bool,
    ) -> serde_json::Value;
    /// List model identifiers known to this provider.
    fn list_models(&self) -> Vec<String>;
    /// Whether native structured output is supported.
    fn supports_structured_output(&self) -> bool;
}
/// Blanket adapter: every concrete [`LlmProvider`] automatically gets the
/// object-safe [`LlmProviderDyn`] surface. Async methods wrap the provider
/// future in `Box::pin` to produce the `BoxFuture` the dyn trait requires;
/// sync methods forward verbatim. Fully-qualified `LlmProvider::` calls
/// avoid infinite recursion into these same dyn methods.
impl<T: LlmProvider + std::fmt::Debug + Send + Sync + 'static> LlmProviderDyn for T {
    fn context_window(&self) -> Option<usize> {
        LlmProvider::context_window(self)
    }
    fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>> {
        Box::pin(LlmProvider::chat(self, messages))
    }
    fn chat_stream<'a>(
        &'a self,
        messages: &'a [Message],
    ) -> BoxFuture<'a, Result<ChatStream, LlmError>> {
        Box::pin(LlmProvider::chat_stream(self, messages))
    }
    fn supports_streaming(&self) -> bool {
        LlmProvider::supports_streaming(self)
    }
    fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>> {
        Box::pin(LlmProvider::embed(self, text))
    }
    fn embed_batch<'a>(
        &'a self,
        texts: &'a [&'a str],
    ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>> {
        Box::pin(LlmProvider::embed_batch(self, texts))
    }
    fn supports_embeddings(&self) -> bool {
        LlmProvider::supports_embeddings(self)
    }
    fn name(&self) -> &str {
        LlmProvider::name(self)
    }
    fn model_identifier(&self) -> &str {
        LlmProvider::model_identifier(self)
    }
    fn supports_vision(&self) -> bool {
        LlmProvider::supports_vision(self)
    }
    fn supports_tool_use(&self) -> bool {
        LlmProvider::supports_tool_use(self)
    }
    fn chat_with_tools<'a>(
        &'a self,
        messages: &'a [Message],
        tools: &'a [ToolDefinition],
    ) -> BoxFuture<'a, Result<ChatResponse, LlmError>> {
        Box::pin(LlmProvider::chat_with_tools(self, messages, tools))
    }
    fn last_cache_usage(&self) -> Option<(u64, u64)> {
        LlmProvider::last_cache_usage(self)
    }
    fn last_usage(&self) -> Option<(u64, u64)> {
        LlmProvider::last_usage(self)
    }
    fn take_compaction_summary(&self) -> Option<String> {
        LlmProvider::take_compaction_summary(self)
    }
    fn chat_with_extras<'a>(
        &'a self,
        messages: &'a [Message],
    ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>> {
        Box::pin(LlmProvider::chat_with_extras(self, messages))
    }
    fn debug_request_json(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        stream: bool,
    ) -> serde_json::Value {
        LlmProvider::debug_request_json(self, messages, tools, stream)
    }
    fn list_models(&self) -> Vec<String> {
        LlmProvider::list_models(self)
    }
    fn supports_structured_output(&self) -> bool {
        LlmProvider::supports_structured_output(self)
    }
}
/// Ask `provider` for a reply that deserializes into `T`, with one
/// corrective retry.
///
/// A system message embedding the JSON schema for `T` is prepended to
/// `messages`. If the first reply fails to parse, the raw reply and the
/// parse error are appended to the conversation and the provider is asked
/// exactly once more.
///
/// # Errors
/// Propagates any provider error, and returns
/// [`LlmError::StructuredParse`] when the retry reply still fails to parse.
pub async fn chat_typed_dyn<T, P>(provider: &P, messages: &[Message]) -> Result<T, LlmError>
where
    T: DeserializeOwned + schemars::JsonSchema + 'static,
    P: ?Sized + LlmProviderDyn,
{
    let (_, schema_json) = cached_schema::<T>()?;
    let type_name = short_type_name::<T>();
    let instruction = format!(
        "Respond with a valid JSON object matching this schema. \
         Output ONLY the JSON, no markdown fences or extra text.\n\n\
         Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
    );

    let mut convo = messages.to_vec();
    convo.insert(0, Message::from_legacy(Role::System, instruction));

    // First attempt: parse after stripping any markdown fences.
    let raw = provider.chat(&convo).await?;
    let first_err = match serde_json::from_str::<T>(strip_json_fences(&raw)) {
        Ok(parsed) => return Ok(parsed),
        Err(e) => e,
    };

    // Single corrective retry: show the model its own output and the error.
    convo.push(Message::from_legacy(Role::Assistant, &raw));
    convo.push(Message::from_legacy(
        Role::User,
        format!(
            "Your response was not valid JSON. Error: {first_err}. \
             Please output ONLY valid JSON matching the schema."
        ),
    ));
    let retry_raw = provider.chat(&convo).await?;
    serde_json::from_str::<T>(strip_json_fences(&retry_raw))
        .map_err(|e| LlmError::StructuredParse(format!("parse failed after retry: {e}")))
}
/// Strip surrounding Markdown code fences (```json … ``` or ``` … ```)
/// from `s` and return the trimmed interior. Purely borrowing; never
/// allocates.
fn strip_json_fences(s: &str) -> &str {
    let inner = s.trim();
    let inner = inner.trim_start_matches("```json");
    let inner = inner.trim_start_matches("```");
    let inner = inner.trim_end_matches("```");
    inner.trim()
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::provider::{ChatStream, StreamChunk};

    /// Minimal provider that echoes a canned response; exercises the
    /// dyn-dispatch wrappers without any network traffic.
    #[derive(Debug)]
    struct StubProvider {
        response: String,
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }
        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            let response = LlmProvider::chat(self, messages).await?;
            Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
                response,
            )))))
        }
        fn supports_streaming(&self) -> bool {
            false
        }
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Ok(vec![0.1, 0.2, 0.3])
        }
        fn supports_embeddings(&self) -> bool {
            false
        }
        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Build a type-erased stub that always replies with `response`.
    fn stub(response: &str) -> Arc<dyn LlmProviderDyn> {
        Arc::new(StubProvider {
            response: response.into(),
        })
    }

    #[tokio::test]
    async fn dyn_chat_works() {
        let provider = stub("hello");
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        assert_eq!(provider.chat(&msgs).await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn dyn_embed_works() {
        let provider = stub("");
        assert_eq!(provider.embed("hello").await.unwrap(), vec![0.1_f32, 0.2, 0.3]);
    }

    #[test]
    fn dyn_sync_methods_forward_correctly() {
        let provider = stub("");
        assert_eq!(provider.name(), "stub");
        assert!(!provider.supports_streaming());
        assert!(!provider.supports_embeddings());
        assert!(provider.context_window().is_none());
        assert!(provider.last_cache_usage().is_none());
        assert!(provider.last_usage().is_none());
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    #[tokio::test]
    async fn chat_typed_dyn_happy_path() {
        let provider = stub(r#"{"value": "hello"}"#);
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let parsed: TestOutput = chat_typed_dyn(&*provider, &msgs).await.unwrap();
        assert_eq!(
            parsed,
            TestOutput {
                value: "hello".into()
            }
        );
    }

    #[tokio::test]
    async fn chat_typed_dyn_strips_fences() {
        let provider = stub("```json\n{\"value\": \"fenced\"}\n```");
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let parsed: TestOutput = chat_typed_dyn(&*provider, &msgs).await.unwrap();
        assert_eq!(
            parsed,
            TestOutput {
                value: "fenced".into()
            }
        );
    }
}