1use futures::future::BoxFuture;
52use serde::de::DeserializeOwned;
53
54use crate::error::LlmError;
55use crate::provider::{
56 ChatExtras, ChatResponse, ChatStream, LlmProvider, Message, Role, ToolDefinition,
57 cached_schema, short_type_name,
58};
59
60mod private {
61 pub trait Sealed {}
62 impl<T: super::LlmProvider> Sealed for T {}
63}
64
65pub trait LlmProviderDyn: private::Sealed + std::fmt::Debug + Send + Sync {
74 fn context_window(&self) -> Option<usize>;
76
77 fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>>;
83
84 fn chat_stream<'a>(
90 &'a self,
91 messages: &'a [Message],
92 ) -> BoxFuture<'a, Result<ChatStream, LlmError>>;
93
94 fn supports_streaming(&self) -> bool;
96
97 fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>>;
103
104 fn embed_batch<'a>(
110 &'a self,
111 texts: &'a [&'a str],
112 ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>>;
113
114 fn supports_embeddings(&self) -> bool;
116
117 fn name(&self) -> &str;
119
120 fn model_identifier(&self) -> &str;
122
123 fn supports_vision(&self) -> bool;
125
126 fn supports_tool_use(&self) -> bool;
128
129 fn chat_with_tools<'a>(
135 &'a self,
136 messages: &'a [Message],
137 tools: &'a [ToolDefinition],
138 ) -> BoxFuture<'a, Result<ChatResponse, LlmError>>;
139
140 fn last_cache_usage(&self) -> Option<(u64, u64)>;
143
144 fn last_usage(&self) -> Option<(u64, u64)>;
147
148 fn last_reasoning_tokens(&self) -> Option<u64> {
153 None
154 }
155
156 fn take_compaction_summary(&self) -> Option<String>;
158
159 fn chat_with_extras<'a>(
165 &'a self,
166 messages: &'a [Message],
167 ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>>;
168
169 #[must_use]
171 fn debug_request_json(
172 &self,
173 messages: &[Message],
174 tools: &[ToolDefinition],
175 stream: bool,
176 ) -> serde_json::Value;
177
178 fn list_models(&self) -> Vec<String>;
180
181 fn supports_structured_output(&self) -> bool;
183}
184
185impl<T: LlmProvider + std::fmt::Debug + Send + Sync + 'static> LlmProviderDyn for T {
186 fn context_window(&self) -> Option<usize> {
187 LlmProvider::context_window(self)
188 }
189
190 fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>> {
191 Box::pin(LlmProvider::chat(self, messages))
192 }
193
194 fn chat_stream<'a>(
195 &'a self,
196 messages: &'a [Message],
197 ) -> BoxFuture<'a, Result<ChatStream, LlmError>> {
198 Box::pin(LlmProvider::chat_stream(self, messages))
199 }
200
201 fn supports_streaming(&self) -> bool {
202 LlmProvider::supports_streaming(self)
203 }
204
205 fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>> {
206 Box::pin(LlmProvider::embed(self, text))
207 }
208
209 fn embed_batch<'a>(
210 &'a self,
211 texts: &'a [&'a str],
212 ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>> {
213 Box::pin(LlmProvider::embed_batch(self, texts))
214 }
215
216 fn supports_embeddings(&self) -> bool {
217 LlmProvider::supports_embeddings(self)
218 }
219
220 fn name(&self) -> &str {
221 LlmProvider::name(self)
222 }
223
224 fn model_identifier(&self) -> &str {
225 LlmProvider::model_identifier(self)
226 }
227
228 fn supports_vision(&self) -> bool {
229 LlmProvider::supports_vision(self)
230 }
231
232 fn supports_tool_use(&self) -> bool {
233 LlmProvider::supports_tool_use(self)
234 }
235
236 fn chat_with_tools<'a>(
237 &'a self,
238 messages: &'a [Message],
239 tools: &'a [ToolDefinition],
240 ) -> BoxFuture<'a, Result<ChatResponse, LlmError>> {
241 Box::pin(LlmProvider::chat_with_tools(self, messages, tools))
242 }
243
244 fn last_cache_usage(&self) -> Option<(u64, u64)> {
245 LlmProvider::last_cache_usage(self)
246 }
247
248 fn last_usage(&self) -> Option<(u64, u64)> {
249 LlmProvider::last_usage(self)
250 }
251
252 fn take_compaction_summary(&self) -> Option<String> {
253 LlmProvider::take_compaction_summary(self)
254 }
255
256 fn chat_with_extras<'a>(
257 &'a self,
258 messages: &'a [Message],
259 ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>> {
260 Box::pin(LlmProvider::chat_with_extras(self, messages))
261 }
262
263 fn debug_request_json(
264 &self,
265 messages: &[Message],
266 tools: &[ToolDefinition],
267 stream: bool,
268 ) -> serde_json::Value {
269 LlmProvider::debug_request_json(self, messages, tools, stream)
270 }
271
272 fn list_models(&self) -> Vec<String> {
273 LlmProvider::list_models(self)
274 }
275
276 fn supports_structured_output(&self) -> bool {
277 LlmProvider::supports_structured_output(self)
278 }
279}
280
281pub async fn chat_typed_dyn<T, P>(provider: &P, messages: &[Message]) -> Result<T, LlmError>
325where
326 T: DeserializeOwned + schemars::JsonSchema + 'static,
327 P: ?Sized + LlmProviderDyn,
328{
329 let (_, schema_json) = cached_schema::<T>()?;
330 let type_name = short_type_name::<T>();
331
332 let instruction = format!(
333 "Respond with a valid JSON object matching this schema. \
334 Output ONLY the JSON, no markdown fences or extra text.\n\n\
335 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
336 );
337
338 let mut augmented = messages.to_vec();
339 augmented.insert(0, Message::from_legacy(Role::System, instruction));
340
341 let raw = provider.chat(&augmented).await?;
342 let cleaned = strip_json_fences(&raw);
343 match serde_json::from_str::<T>(cleaned) {
344 Ok(val) => Ok(val),
345 Err(first_err) => {
346 augmented.push(Message::from_legacy(Role::Assistant, &raw));
347 augmented.push(Message::from_legacy(
348 Role::User,
349 format!(
350 "Your response was not valid JSON. Error: {first_err}. \
351 Please output ONLY valid JSON matching the schema."
352 ),
353 ));
354 let retry_raw = provider.chat(&augmented).await?;
355 let retry_cleaned = strip_json_fences(&retry_raw);
356 serde_json::from_str::<T>(retry_cleaned)
357 .map_err(|e| LlmError::StructuredParse(format!("parse failed after retry: {e}")))
358 }
359 }
360}
361
/// Removes a surrounding Markdown code fence from `s` and returns the trimmed payload.
///
/// Surrounding whitespace is trimmed, leading "```json" / "```" markers and trailing "```"
/// markers are removed, and the result is trimmed again. When an opening fence was removed,
/// a `json` language tag directly after it is dropped regardless of ASCII case, so
/// "```JSON" and "```Json" fences are handled like "```json".
///
/// Input without an opening fence is never inspected for a tag. The returned slice borrows
/// from `s`; nothing is allocated.
fn strip_json_fences(s: &str) -> &str {
    let trimmed = s.trim();
    let unfenced = trimmed
        .trim_start_matches("```json")
        .trim_start_matches("```");

    // A shorter slice means an opening fence was stripped. Only then may a leftover tag such
    // as "JSON" follow; `get(..4)` is `None` for short input or a non-char-boundary, which
    // keeps the slicing below panic-free.
    let body = if unfenced.len() == trimmed.len() {
        unfenced
    } else {
        match unfenced.get(..4) {
            Some(tag) if tag.eq_ignore_ascii_case("json") => &unfenced[4..],
            _ => unfenced,
        }
    };

    body.trim_end_matches("```").trim()
}
370
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::provider::{ChatStream, StreamChunk};

    /// Provider stub that answers every chat request with a fixed string.
    #[derive(Debug)]
    struct StubProvider {
        response: String,
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Qualified call: both provider traits are in scope and share method names.
            let text = <Self as LlmProvider>::chat(self, messages).await?;
            let chunk = StreamChunk::Content(text);
            Ok(Box::pin(tokio_stream::once(Ok(chunk))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Wraps a [`StubProvider`] with the given canned reply as a trait object.
    fn stub(response: &str) -> Arc<dyn LlmProviderDyn> {
        Arc::new(StubProvider {
            response: response.to_owned(),
        })
    }

    /// One-message user conversation shared by the chat tests.
    fn user_prompt() -> Vec<Message> {
        vec![Message::from_legacy(Role::User, "test")]
    }

    #[tokio::test]
    async fn dyn_chat_works() {
        let provider = stub("hello");
        let msgs = user_prompt();
        let reply = provider.chat(&msgs).await.unwrap();
        assert_eq!(reply, "hello");
    }

    #[tokio::test]
    async fn dyn_embed_works() {
        let provider = stub("");
        let embedding = provider.embed("hello").await.unwrap();
        assert_eq!(embedding, vec![0.1_f32, 0.2, 0.3]);
    }

    #[test]
    fn dyn_sync_methods_forward_correctly() {
        let provider = stub("");
        assert_eq!(provider.name(), "stub");
        assert!(!provider.supports_streaming());
        assert!(!provider.supports_embeddings());
        assert!(provider.context_window().is_none());
        assert!(provider.last_cache_usage().is_none());
        assert!(provider.last_usage().is_none());
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    #[tokio::test]
    async fn chat_typed_dyn_happy_path() {
        let provider = stub(r#"{"value": "hello"}"#);
        let msgs = user_prompt();
        let parsed: TestOutput = chat_typed_dyn(&*provider, &msgs).await.unwrap();
        let expected = TestOutput {
            value: "hello".into(),
        };
        assert_eq!(parsed, expected);
    }

    #[tokio::test]
    async fn chat_typed_dyn_strips_fences() {
        let provider = stub("```json\n{\"value\": \"fenced\"}\n```");
        let msgs = user_prompt();
        let parsed: TestOutput = chat_typed_dyn(&*provider, &msgs).await.unwrap();
        let expected = TestOutput {
            value: "fenced".into(),
        };
        assert_eq!(parsed, expected);
    }
}