Skip to main content

juncture_core/
llm.rs

1// LLM integration types and traits
2//
3// This module provides the foundational abstractions for LLM integration,
4// including the `ChatModel` trait, message types, and related configurations.
5//
6// # Design Principles
7//
8// - Unified abstraction: Single trait covering all LLM providers
9// - Streaming-first: Both invoke and stream are first-class operations
10// - Type-safe: Leverages Rust's type system for message and tool handling
11// - Zero-cost: Abstractions don't add runtime overhead
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use crate::state::{Message, Role};
17
18/// Re-export `BoxStream` for use in `ChatModel` trait
19pub use futures::stream::BoxStream;
20
21/// LLM invocation error types
22#[derive(Debug, thiserror::Error)]
23pub enum LlmError {
24    /// Authentication failed
25    #[error("authentication failed: {0}")]
26    AuthError(String),
27
28    /// Rate limited with optional retry-after duration
29    #[error("rate limited, retry after {retry_after:?}")]
30    RateLimited {
31        /// Optional duration to wait before retrying
32        retry_after: Option<std::time::Duration>,
33    },
34
35    /// Context length exceeded
36    #[error("context length exceeded: {used} tokens used, {limit} limit")]
37    ContextLengthExceeded {
38        /// Tokens used in request
39        used: u64,
40        /// Model's context window limit
41        limit: u64,
42    },
43
44    /// Network error during HTTP request
45    #[error("network error: {0}")]
46    NetworkError(String),
47
48    /// Invalid response from LLM provider
49    #[error("invalid response: {0}")]
50    InvalidResponse(String),
51
52    /// Requested model not found
53    #[error("model not found: {0}")]
54    ModelNotFound(String),
55
56    /// Content was filtered by provider
57    #[error("content filtered")]
58    ContentFiltered,
59
60    /// Request timeout
61    #[error("timeout after {0:?}")]
62    Timeout(std::time::Duration),
63
64    /// Other errors
65    #[error("llm error: {0}")]
66    Other(#[source] Box<dyn std::error::Error + Send + Sync>),
67}
68
69/// Options for LLM invocations
70///
71/// These options override default settings on the `ChatModel` instance
72/// for a single invocation.
73#[derive(Clone, Debug, Default)]
74pub struct CallOptions {
75    /// Sampling temperature (0.0 to 1.0)
76    pub temperature: Option<f32>,
77
78    /// Maximum tokens to generate
79    pub max_tokens: Option<u32>,
80
81    /// Sequences that will stop generation
82    pub stop_sequences: Option<Vec<String>>,
83
84    /// Nucleus sampling threshold (0.0 to 1.0)
85    pub top_p: Option<f32>,
86
87    /// Override the model name for this call
88    pub model_override: Option<String>,
89
90    /// Tool selection strategy
91    pub tool_choice: Option<ToolChoice>,
92
93    /// Response format for structured output
94    pub response_format: Option<ResponseFormat>,
95
96    /// Tags for streaming metadata and filtering.
97    ///
98    /// Tags are propagated into stream events as [`MessageStreamMetadata::tags`].
99    /// The `"nostream"` tag causes [`EventEmitter::should_emit`] to suppress
100    /// streaming events for this call.
101    pub tags: Vec<String>,
102}
103
/// Strategy that governs whether, and which, tools the model calls.
#[derive(Debug, Clone)]
pub enum ToolChoice {
    /// The model decides on its own whether to call tools.
    Auto,
    /// The model does not call any tools.
    None,
    /// The model must call at least one tool.
    Required,
    /// The model must call one particular tool.
    Specific {
        /// Name of the tool that has to be called.
        name: String,
    },
}
119
120/// Response format for structured output
121#[derive(Clone, Debug)]
122pub enum ResponseFormat {
123    /// JSON object (model outputs valid JSON)
124    JsonObject,
125    /// JSON Schema with strict validation
126    JsonSchema {
127        /// Name of the schema
128        name: String,
129        /// JSON Schema definition
130        schema: serde_json::Value,
131        /// Whether to use strict mode
132        strict: bool,
133    },
134}
135
136/// Tool definition for function calling
137#[derive(Clone, Debug, Serialize, Deserialize)]
138pub struct ToolDefinition {
139    /// Tool name
140    pub name: String,
141    /// Tool description
142    pub description: String,
143    /// JSON Schema for parameters
144    pub parameters: serde_json::Value,
145}
146
147/// Streaming message chunk
148///
149/// Represents incremental data from streaming LLM responses.
150/// Chunks must be accumulated to reconstruct the complete message.
151#[derive(Clone, Debug)]
152pub struct MessageChunk {
153    /// Message role (may be empty in early chunks)
154    pub role: Option<Role>,
155    /// Text content delta
156    pub content: String,
157    /// Tool call chunks (using `args_delta` field name from stream module)
158    pub tool_call_chunks: Vec<ToolCallChunk>,
159    /// Token usage (only in final chunk)
160    pub usage: Option<crate::state::TokenUsage>,
161}
162
163/// Streaming tool call chunk
164///
165/// Re-exported from `crate::stream` for LLM integration.
166/// Note: This struct uses `args_delta` as the field name (not `arguments`).
167/// Use the stream module's version for consistency.
168pub use crate::stream::ToolCallChunk;
169
170/// Unified `ChatModel` trait for all LLM providers
171///
172/// This trait provides a common interface for interacting with different
173/// LLM providers (`Anthropic`, `OpenAI`, `Ollama`, etc.).
174///
175/// # Type Parameters
176///
177/// * `'a` - Lifetime for borrowed data in streaming
178#[cfg_attr(target_family = "wasm", async_trait(?Send))]
179#[cfg_attr(not(target_family = "wasm"), async_trait)]
180pub trait ChatModel: Send + Sync + Clone + 'static {
181    /// Invoke the model with messages
182    ///
183    /// # Arguments
184    ///
185    /// * `messages` - Conversation history
186    /// * `options` - Optional call settings to override defaults
187    ///
188    /// # Returns
189    ///
190    /// The model's response as a complete message
191    async fn invoke(
192        &self,
193        messages: &[Message],
194        options: Option<&CallOptions>,
195    ) -> Result<Message, LlmError>;
196
197    /// Stream the model's response
198    ///
199    /// # Arguments
200    ///
201    /// * `messages` - Conversation history
202    /// * `options` - Optional call settings to override defaults
203    ///
204    /// # Returns
205    ///
206    /// A stream of message chunks that must be accumulated
207    async fn stream(
208        &self,
209        messages: &[Message],
210        options: Option<&CallOptions>,
211    ) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError>;
212
213    /// Bind tools to this model instance
214    ///
215    /// Returns a new instance with the tools registered for function calling.
216    ///
217    /// # Arguments
218    ///
219    /// * `tools` - List of tool definitions
220    #[must_use]
221    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self;
222
223    /// Convert to structured output model
224    ///
225    /// Returns a wrapper that forces the model to output structured JSON
226    /// matching type T's schema.
227    ///
228    /// # Type Parameters
229    ///
230    /// * `T` - Target type with JSON Schema support
231    #[must_use]
232    fn with_structured_output<T: JsonSchema + DeserializeOwned + Serialize>(
233        self,
234    ) -> StructuredOutputModel<Self, T>
235    where
236        Self: Sized;
237
238    /// Get the model name
239    fn model_name(&self) -> &str;
240}
241
242/// Trait for types with JSON Schema support
243pub trait JsonSchema: schemars::JsonSchema {}
244
245/// Blanket implementation for all `schemars::JsonSchema` types
246impl<T: schemars::JsonSchema> JsonSchema for T {}
247
248/// Marker for deserializable types
249pub trait DeserializeOwned: for<'de> Deserialize<'de> {}
250
251/// Blanket implementation for all deserializable types
252impl<T: for<'de> Deserialize<'de>> DeserializeOwned for T {}
253
/// Wrapper for structured output from LLMs
///
/// Uses function calling to force the model to output JSON matching
/// the schema of type T.
///
/// The type definition itself places no bounds on `M` or `T`; each `impl`
/// block states the bounds it needs. `T` is never stored, it is only tracked
/// through `PhantomData`.
pub struct StructuredOutputModel<M, T> {
    /// Inner model that performs the actual calls
    pub(crate) inner: M,
    /// Phantom data for target type
    pub(crate) _phantom: std::marker::PhantomData<T>,
}
267
268impl<M: Clone, T> Clone for StructuredOutputModel<M, T> {
269    fn clone(&self) -> Self {
270        Self {
271            inner: self.inner.clone(),
272            _phantom: std::marker::PhantomData,
273        }
274    }
275}
276
277impl<M, T> std::fmt::Debug for StructuredOutputModel<M, T>
278where
279    M: Clone,
280{
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        f.debug_struct("StructuredOutputModel")
283            .field("inner", &"<model>")
284            .field("_phantom", &self._phantom)
285            .finish()
286    }
287}
288
289#[cfg_attr(target_family = "wasm", async_trait(?Send))]
290#[cfg_attr(not(target_family = "wasm"), async_trait)]
291impl<M, T> ChatModel for StructuredOutputModel<M, T>
292where
293    M: ChatModel,
294    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync + 'static,
295{
296    async fn invoke(
297        &self,
298        messages: &[Message],
299        options: Option<&CallOptions>,
300    ) -> Result<Message, LlmError> {
301        // Create a virtual tool with T's schema
302        let schema = schemars::schema_for!(T);
303        let tool_def = ToolDefinition {
304            name: "structured_output".to_string(),
305            description: "Output structured data".to_string(),
306            parameters: serde_json::to_value(schema)
307                .map_err(|e| LlmError::InvalidResponse(e.to_string()))?,
308        };
309
310        // Force tool usage
311        #[allow(
312            clippy::manual_unwrap_or_default,
313            clippy::option_if_let_else,
314            reason = "project rules prohibit unwrap_or_default; match is explicit and readable"
315        )]
316        let mut opts = match options.cloned() {
317            Some(opts) => opts,
318            None => CallOptions::default(),
319        };
320        opts.tool_choice = Some(ToolChoice::Required);
321
322        // Call inner model with tool bound
323        let model_with_tool = self.inner.bind_tools(vec![tool_def]);
324        let response = model_with_tool.invoke(messages, Some(&opts)).await?;
325
326        // Extract tool call arguments and parse as T
327        if let Some(tool_call) = response.tool_calls.first() {
328            let _value: T = serde_json::from_value(tool_call.arguments.clone()).map_err(|e| {
329                LlmError::InvalidResponse(format!("Failed to parse structured output: {e}"))
330            })?;
331
332            // Return as JSON string in content
333            Ok(Message {
334                id: response.id,
335                role: Role::Ai,
336                content: crate::state::Content::Text(serde_json::to_string(&_value).map_err(
337                    |e| {
338                        LlmError::InvalidResponse(format!(
339                            "Failed to serialize structured output: {e}"
340                        ))
341                    },
342                )?),
343                tool_calls: vec![],
344                tool_call_id: None,
345                name: None,
346                usage: response.usage,
347            })
348        } else {
349            Err(LlmError::InvalidResponse(
350                "No tool call in response".to_string(),
351            ))
352        }
353    }
354
355    async fn stream(
356        &self,
357        messages: &[Message],
358        options: Option<&CallOptions>,
359    ) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError> {
360        self.inner.stream(messages, options).await
361    }
362
363    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
364        Self {
365            inner: self.inner.bind_tools(tools),
366            _phantom: std::marker::PhantomData,
367        }
368    }
369
370    fn with_structured_output<U: JsonSchema + DeserializeOwned + Serialize>(
371        self,
372    ) -> StructuredOutputModel<Self, U>
373    where
374        Self: Sized,
375    {
376        StructuredOutputModel {
377            inner: self,
378            _phantom: std::marker::PhantomData,
379        }
380    }
381
382    fn model_name(&self) -> &str {
383        self.inner.model_name()
384    }
385}
386
387// Rust guideline compliant 2026-05-20