Skip to main content

zeph_llm/
provider_dyn.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Object-safe adapter for [`LlmProvider`].
//!
//! [`LlmProviderDyn`] mirrors [`LlmProvider`], with its async methods returning
//! [`BoxFuture`] instead of `impl Future + Send`. A blanket implementation
//! over any `T: LlmProvider + std::fmt::Debug + Send + Sync + 'static` wires the
//! two traits together automatically.
//!
//! ## Usage
//!
//! Use [`LlmProvider`] as the *implementation* surface (concrete types, monomorphic
//! call sites). Use `Arc<dyn LlmProviderDyn>` as the *storage* type wherever runtime
//! polymorphism is required (router, cascade, dependency injection).
//!
//! Implementors never need to implement [`LlmProviderDyn`] directly — the blanket impl
//! handles it. Implement [`LlmProvider`] instead.
//!
//! ## Generic methods
//!
//! `LlmProvider::chat_typed<T: DeserializeOwned>` cannot be part of a dyn-safe trait
//! because it carries a generic type parameter. Use the free function
//! [`chat_typed_dyn`] instead when working with `dyn LlmProviderDyn`.
//!
//! ## Examples
//!
//! ```rust,no_run
//! use std::sync::Arc;
//! use zeph_llm::provider::{LlmProvider, Message, Role};
//! use zeph_llm::provider_dyn::LlmProviderDyn;
//! use zeph_llm::ollama::OllamaProvider;
//!
//! # async fn example() -> Result<(), zeph_llm::LlmError> {
//! let provider = OllamaProvider::new(
//!     "http://localhost:11434",
//!     "llama3.2".into(),
//!     "nomic-embed-text".into(),
//! );
//!
//! // Erase the concrete type for storage in a router or DI container.
//! let dyn_provider: Arc<dyn LlmProviderDyn> = Arc::new(provider);
//!
//! let messages = vec![Message::from_legacy(Role::User, "Hello!")];
//! let response = dyn_provider.chat(&messages).await?;
//! println!("{response}");
//! # Ok(())
//! # }
//! ```
50
51use futures::future::BoxFuture;
52use serde::de::DeserializeOwned;
53
54use crate::error::LlmError;
55use crate::provider::{
56    ChatExtras, ChatResponse, ChatStream, LlmProvider, Message, Role, ToolDefinition,
57    cached_schema, short_type_name,
58};
59
60mod private {
61    pub trait Sealed {}
62    impl<T: super::LlmProvider> Sealed for T {}
63}
64
65/// Object-safe shadow of [`LlmProvider`].
66///
67/// Sealed — only the blanket `impl<T: LlmProvider + Send + Sync + 'static>` exists.
68/// External crates cannot implement this trait directly; implement [`LlmProvider`] instead
69/// and the blanket impl wires everything up automatically.
70///
71/// All async methods return [`BoxFuture`] rather than `impl Future + Send`, making this
72/// trait dyn-compatible and usable behind `Arc<dyn LlmProviderDyn>`.
73pub trait LlmProviderDyn: private::Sealed + std::fmt::Debug + Send + Sync {
74    /// Report the model's context window size in tokens. `None` if unknown.
75    fn context_window(&self) -> Option<usize>;
76
77    /// Send messages to the LLM and return the assistant response.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the provider fails to communicate or the response is invalid.
82    fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>>;
83
84    /// Send messages and return a stream of response chunks.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if the provider fails to communicate or the response is invalid.
89    fn chat_stream<'a>(
90        &'a self,
91        messages: &'a [Message],
92    ) -> BoxFuture<'a, Result<ChatStream, LlmError>>;
93
94    /// Whether this provider supports native streaming.
95    fn supports_streaming(&self) -> bool;
96
97    /// Generate an embedding vector from text.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the provider does not support embeddings or the request fails.
102    fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>>;
103
104    /// Embed multiple texts in a single API call.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if any embedding fails.
109    fn embed_batch<'a>(
110        &'a self,
111        texts: &'a [&'a str],
112    ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>>;
113
114    /// Whether this provider supports embedding generation.
115    fn supports_embeddings(&self) -> bool;
116
117    /// Provider name for logging and identification.
118    fn name(&self) -> &str;
119
120    /// Model identifier string (e.g. `gpt-4o-mini`, `claude-sonnet-4-6`).
121    fn model_identifier(&self) -> &str;
122
123    /// Whether this provider supports image input (vision).
124    fn supports_vision(&self) -> bool;
125
126    /// Whether this provider supports native `tool_use` / function calling.
127    fn supports_tool_use(&self) -> bool;
128
129    /// Send messages with tool definitions, returning a structured response.
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if the provider fails to communicate or the response is invalid.
134    fn chat_with_tools<'a>(
135        &'a self,
136        messages: &'a [Message],
137        tools: &'a [ToolDefinition],
138    ) -> BoxFuture<'a, Result<ChatResponse, LlmError>>;
139
140    /// Return the cache usage from the last API call, if available.
141    /// Returns `(cache_creation_tokens, cache_read_tokens)`.
142    fn last_cache_usage(&self) -> Option<(u64, u64)>;
143
144    /// Return token counts from the last API call, if available.
145    /// Returns `(input_tokens, output_tokens)`.
146    fn last_usage(&self) -> Option<(u64, u64)>;
147
148    /// Return reasoning tokens from the last API call, if the provider reports them.
149    ///
150    /// Reasoning tokens are a **subset** of completion tokens (`OpenAI` o-series only).
151    /// Returns `None` for providers that do not expose reasoning token counts.
152    fn last_reasoning_tokens(&self) -> Option<u64> {
153        None
154    }
155
156    /// Return the compaction summary from the most recent API call, if available.
157    fn take_compaction_summary(&self) -> Option<String>;
158
159    /// Send messages and return the assistant response together with per-call extras.
160    ///
161    /// # Errors
162    ///
163    /// Same as [`chat`](Self::chat).
164    fn chat_with_extras<'a>(
165        &'a self,
166        messages: &'a [Message],
167    ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>>;
168
169    /// Return the request payload that will be sent to the provider, for debug dumps.
170    #[must_use]
171    fn debug_request_json(
172        &self,
173        messages: &[Message],
174        tools: &[ToolDefinition],
175        stream: bool,
176    ) -> serde_json::Value;
177
178    /// Return the list of model identifiers this provider can serve.
179    fn list_models(&self) -> Vec<String>;
180
181    /// Whether this provider supports native structured output.
182    fn supports_structured_output(&self) -> bool;
183}
184
185impl<T: LlmProvider + std::fmt::Debug + Send + Sync + 'static> LlmProviderDyn for T {
186    fn context_window(&self) -> Option<usize> {
187        LlmProvider::context_window(self)
188    }
189
190    fn chat<'a>(&'a self, messages: &'a [Message]) -> BoxFuture<'a, Result<String, LlmError>> {
191        Box::pin(LlmProvider::chat(self, messages))
192    }
193
194    fn chat_stream<'a>(
195        &'a self,
196        messages: &'a [Message],
197    ) -> BoxFuture<'a, Result<ChatStream, LlmError>> {
198        Box::pin(LlmProvider::chat_stream(self, messages))
199    }
200
201    fn supports_streaming(&self) -> bool {
202        LlmProvider::supports_streaming(self)
203    }
204
205    fn embed<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Vec<f32>, LlmError>> {
206        Box::pin(LlmProvider::embed(self, text))
207    }
208
209    fn embed_batch<'a>(
210        &'a self,
211        texts: &'a [&'a str],
212    ) -> BoxFuture<'a, Result<Vec<Vec<f32>>, LlmError>> {
213        Box::pin(LlmProvider::embed_batch(self, texts))
214    }
215
216    fn supports_embeddings(&self) -> bool {
217        LlmProvider::supports_embeddings(self)
218    }
219
220    fn name(&self) -> &str {
221        LlmProvider::name(self)
222    }
223
224    fn model_identifier(&self) -> &str {
225        LlmProvider::model_identifier(self)
226    }
227
228    fn supports_vision(&self) -> bool {
229        LlmProvider::supports_vision(self)
230    }
231
232    fn supports_tool_use(&self) -> bool {
233        LlmProvider::supports_tool_use(self)
234    }
235
236    fn chat_with_tools<'a>(
237        &'a self,
238        messages: &'a [Message],
239        tools: &'a [ToolDefinition],
240    ) -> BoxFuture<'a, Result<ChatResponse, LlmError>> {
241        Box::pin(LlmProvider::chat_with_tools(self, messages, tools))
242    }
243
244    fn last_cache_usage(&self) -> Option<(u64, u64)> {
245        LlmProvider::last_cache_usage(self)
246    }
247
248    fn last_usage(&self) -> Option<(u64, u64)> {
249        LlmProvider::last_usage(self)
250    }
251
252    fn take_compaction_summary(&self) -> Option<String> {
253        LlmProvider::take_compaction_summary(self)
254    }
255
256    fn chat_with_extras<'a>(
257        &'a self,
258        messages: &'a [Message],
259    ) -> BoxFuture<'a, Result<(String, ChatExtras), LlmError>> {
260        Box::pin(LlmProvider::chat_with_extras(self, messages))
261    }
262
263    fn debug_request_json(
264        &self,
265        messages: &[Message],
266        tools: &[ToolDefinition],
267        stream: bool,
268    ) -> serde_json::Value {
269        LlmProvider::debug_request_json(self, messages, tools, stream)
270    }
271
272    fn list_models(&self) -> Vec<String> {
273        LlmProvider::list_models(self)
274    }
275
276    fn supports_structured_output(&self) -> bool {
277        LlmProvider::supports_structured_output(self)
278    }
279}
280
281/// Send messages and parse the response into a typed value `T`.
282///
283/// This is the dyn-compatible equivalent of [`LlmProvider::chat_typed`]. Because
284/// `chat_typed` carries a generic type parameter, it cannot be part of a dyn-safe
285/// trait. Use this free function when working with `&dyn LlmProviderDyn` or
286/// `Arc<dyn LlmProviderDyn>`.
287///
288/// The default implementation injects the JSON schema into the system prompt and
289/// retries once on parse failure, matching the behaviour of the trait method.
290///
291/// # Errors
292///
293/// Returns [`LlmError::StructuredParse`] when the response cannot be parsed as `T`
294/// after one retry. Propagates any underlying [`LlmError`] from the provider.
295///
296/// # Examples
297///
298/// ```rust,no_run
299/// use std::sync::Arc;
300/// use schemars::JsonSchema;
301/// use serde::Deserialize;
302/// use zeph_llm::provider::{Message, Role};
303/// use zeph_llm::provider_dyn::{LlmProviderDyn, chat_typed_dyn};
304/// use zeph_llm::ollama::OllamaProvider;
305///
306/// #[derive(Debug, Deserialize, JsonSchema)]
307/// struct Answer {
308///     value: String,
309/// }
310///
311/// # async fn example() -> Result<(), zeph_llm::LlmError> {
312/// let provider = OllamaProvider::new(
313///     "http://localhost:11434",
314///     "llama3.2".into(),
315///     "nomic-embed-text".into(),
316/// );
317/// let dyn_provider: Arc<dyn LlmProviderDyn> = Arc::new(provider);
318/// let messages = vec![Message::from_legacy(Role::User, "What is 2+2?")];
319/// let answer: Answer = chat_typed_dyn(&*dyn_provider, &messages).await?;
320/// println!("{}", answer.value);
321/// # Ok(())
322/// # }
323/// ```
324pub async fn chat_typed_dyn<T, P>(provider: &P, messages: &[Message]) -> Result<T, LlmError>
325where
326    T: DeserializeOwned + schemars::JsonSchema + 'static,
327    P: ?Sized + LlmProviderDyn,
328{
329    let (_, schema_json) = cached_schema::<T>()?;
330    let type_name = short_type_name::<T>();
331
332    let instruction = format!(
333        "Respond with a valid JSON object matching this schema. \
334         Output ONLY the JSON, no markdown fences or extra text.\n\n\
335         Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
336    );
337
338    let mut augmented = messages.to_vec();
339    augmented.insert(0, Message::from_legacy(Role::System, instruction));
340
341    let raw = provider.chat(&augmented).await?;
342    let cleaned = strip_json_fences(&raw);
343    match serde_json::from_str::<T>(cleaned) {
344        Ok(val) => Ok(val),
345        Err(first_err) => {
346            augmented.push(Message::from_legacy(Role::Assistant, &raw));
347            augmented.push(Message::from_legacy(
348                Role::User,
349                format!(
350                    "Your response was not valid JSON. Error: {first_err}. \
351                     Please output ONLY valid JSON matching the schema."
352                ),
353            ));
354            let retry_raw = provider.chat(&augmented).await?;
355            let retry_cleaned = strip_json_fences(&retry_raw);
356            serde_json::from_str::<T>(retry_cleaned)
357                .map_err(|e| LlmError::StructuredParse(format!("parse failed after retry: {e}")))
358        }
359    }
360}
361
/// Remove surrounding markdown code fences from LLM output.
///
/// After trimming whitespace, every leading "```json" marker, then every leading
/// "```" marker, then every trailing "```" marker is removed, and the remainder is
/// trimmed again. The result is a sub-slice of `s`; nothing is allocated.
fn strip_json_fences(s: &str) -> &str {
    let mut body = s.trim();

    // Opening fences that carry the `json` language tag.
    while let Some(rest) = body.strip_prefix("```json") {
        body = rest;
    }
    // Opening fences without a language tag.
    while let Some(rest) = body.strip_prefix("```") {
        body = rest;
    }
    // Closing fences.
    while let Some(rest) = body.strip_suffix("```") {
        body = rest;
    }

    body.trim()
}
370
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::provider::{ChatStream, StreamChunk};

    /// Minimal provider that answers every chat with a fixed string.
    #[derive(Debug)]
    struct StubProvider {
        response: String,
    }

    impl LlmProvider for StubProvider {
        async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
            Ok(self.response.clone())
        }

        async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
            // Wrap the canned chat reply in a one-element stream.
            let text = <Self as LlmProvider>::chat(self, messages).await?;
            let chunk = StreamChunk::Content(text);
            Ok(Box::pin(tokio_stream::once(Ok(chunk))))
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn supports_embeddings(&self) -> bool {
            false
        }

        fn name(&self) -> &'static str {
            "stub"
        }
    }

    /// Type-erased stub that replies with `response`.
    fn stub(response: &str) -> Arc<dyn LlmProviderDyn> {
        Arc::new(StubProvider {
            response: response.into(),
        })
    }

    /// Single user message shared by the chat tests.
    fn user_prompt() -> Vec<Message> {
        vec![Message::from_legacy(Role::User, "test")]
    }

    #[tokio::test]
    async fn dyn_chat_works() {
        let provider = stub("hello");
        let reply = provider.chat(&user_prompt()).await.unwrap();
        assert_eq!(reply, "hello");
    }

    #[tokio::test]
    async fn dyn_embed_works() {
        let provider = stub("");
        let embedding = provider.embed("hello").await.unwrap();
        assert_eq!(embedding, vec![0.1_f32, 0.2, 0.3]);
    }

    #[test]
    fn dyn_sync_methods_forward_correctly() {
        let provider = stub("");
        assert_eq!(provider.name(), "stub");
        assert!(!provider.supports_streaming());
        assert!(!provider.supports_embeddings());
        assert!(provider.context_window().is_none());
        assert!(provider.last_cache_usage().is_none());
        assert!(provider.last_usage().is_none());
    }

    #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
    struct TestOutput {
        value: String,
    }

    #[tokio::test]
    async fn chat_typed_dyn_happy_path() {
        let provider = stub(r#"{"value": "hello"}"#);
        let parsed: TestOutput = chat_typed_dyn(&*provider, &user_prompt()).await.unwrap();
        let expected = TestOutput {
            value: "hello".into(),
        };
        assert_eq!(parsed, expected);
    }

    #[tokio::test]
    async fn chat_typed_dyn_strips_fences() {
        let provider = stub("```json\n{\"value\": \"fenced\"}\n```");
        let parsed: TestOutput = chat_typed_dyn(&*provider, &user_prompt()).await.unwrap();
        let expected = TestOutput {
            value: "fenced".into(),
        };
        assert_eq!(parsed, expected);
    }
}