Skip to main content

camel_component_llm/provider/
mod.rs

1//! LlmProvider trait and request/response types.
2//!
3//! These types are Camel-shaped (not siumai-shaped).
4//! The siumai adapter translates between these and siumai's types.
5
6pub mod mock;
7#[cfg(any(feature = "openai", feature = "ollama", feature = "all-providers"))]
8pub mod siumai_adapter;
9
10use async_trait::async_trait;
11use futures::stream::BoxStream;
12
13use crate::error::LlmError;
14
15/// Provider abstraction for LLM chat and embedding operations.
16///
17/// Camel-shaped trait (not siumai-shaped). The siumai adapter
18/// translates between these types and siumai's types.
19#[async_trait]
20pub trait LlmProvider: Send + Sync {
21    /// Unique provider identifier (e.g., "openai", "ollama").
22    fn id(&self) -> &str;
23
24    /// Default model to use when none is specified in a request.
25    fn default_model(&self) -> &str;
26
27    /// Stream chat completions for the given request.
28    fn chat_stream(&self, req: ChatRequest) -> BoxStream<'static, Result<ChatEvent, LlmError>>;
29
30    /// Generate embeddings for the given inputs.
31    async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, LlmError>;
32
33    /// Whether this provider supports embedding operations.
34    fn supports_embed(&self) -> bool {
35        true
36    }
37}
38
39/// An event in a chat completion stream.
40#[derive(Debug, Clone)]
41#[non_exhaustive]
42pub enum ChatEvent {
43    /// A partial text delta from the stream.
44    Delta {
45        /// The text chunk received.
46        text: String,
47    },
48    /// The stream has finished.
49    Finished {
50        /// Usage statistics, if available.
51        usage: Option<LlmUsage>,
52        /// The model that generated the response.
53        model: Option<String>,
54        /// Why the stream finished.
55        finish_reason: Option<FinishReason>,
56        /// Provider-specific metadata.
57        metadata: serde_json::Map<String, serde_json::Value>,
58    },
59    /// An intermediate tool call emitted during streaming.
60    ToolCall {
61        /// Tool call identifier.
62        id: String,
63        /// Name of the tool to invoke.
64        name: String,
65        /// JSON arguments for the tool.
66        arguments: String,
67    },
68}
69
/// Token accounting for a single LLM operation.
///
/// All counters are zero in the `Default` value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LlmUsage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the completion.
    pub completion_tokens: u32,
    /// Total number of tokens used.
    pub total_tokens: u32,
}
80
/// Reason why a chat stream finished.
///
/// Implements `PartialEq`/`Eq` so finish reasons can be compared directly
/// (for example with `==` or `assert_eq!`), like the other small value
/// types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// The model stopped naturally.
    Stop,
    /// The response was truncated because the max-token limit was reached.
    Length,
    /// The model requested a tool call.
    ToolCall,
    /// The response was filtered by content policy.
    ContentFilter,
    /// Any other reason; the string carries the reason text.
    Other(String),
}
95
96/// Definition of a tool the model may call.
97#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
98pub struct ToolDefinition {
99    /// Name of the tool (e.g., "get_weather").
100    pub name: String,
101    /// Description of what the tool does.
102    pub description: String,
103    /// JSON Schema parameters for the tool.
104    pub parameters: serde_json::Map<String, serde_json::Value>,
105}
106
107/// Controls how the model chooses which tool to call.
108#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
109pub enum ToolChoice {
110    /// Model may decide to call zero or more tools.
111    Auto,
112    /// Model must not call any tool.
113    None,
114    /// Model must call the named tool.
115    Specific(String),
116}
117
118/// A tool call emitted by the model as part of an assistant message.
119#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
120pub struct EmittedToolCall {
121    /// Unique identifier for this tool call.
122    pub id: String,
123    /// Name of the tool to invoke.
124    pub name: String,
125    /// JSON string of arguments for the tool.
126    pub arguments: String,
127}
128
129/// A request to a chat completion model.
130#[derive(Debug, Clone)]
131pub struct ChatRequest {
132    /// Model identifier (e.g., "gpt-4o", "llama-3-70b").
133    pub model: String,
134    /// Conversation messages.
135    pub messages: Vec<ChatMessage>,
136    /// Sampling temperature (0.0–2.0).
137    pub temperature: Option<f64>,
138    /// Maximum tokens to generate.
139    pub max_tokens: Option<u32>,
140    /// Sequences that stop generation.
141    pub stop: Option<Vec<String>>,
142    /// System prompt override.
143    pub system_prompt: Option<String>,
144    /// Tools available for the model to call.
145    pub tools: Vec<ToolDefinition>,
146    /// Controls tool selection behaviour.
147    pub tool_choice: Option<ToolChoice>,
148    /// Additional provider-specific parameters.
149    pub extra: serde_json::Map<String, serde_json::Value>,
150}
151
152impl ChatRequest {
153    /// Create a new chat request with the given model and messages.
154    pub fn new(model: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
155        Self {
156            model: model.into(),
157            messages,
158            temperature: None,
159            max_tokens: None,
160            stop: None,
161            system_prompt: None,
162            tools: Vec::new(),
163            tool_choice: None,
164            extra: serde_json::Map::new(),
165        }
166    }
167}
168
169/// A single message in a chat conversation.
170#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
171pub struct ChatMessage {
172    /// Role of the message author.
173    pub role: ChatRole,
174    /// Content of the message.
175    pub content: String,
176    /// Tool calls made by the assistant, if any.
177    pub tool_calls: Option<Vec<EmittedToolCall>>,
178}
179
180impl ChatMessage {
181    /// Create a new user message.
182    pub fn user(content: impl Into<String>) -> Self {
183        Self {
184            role: ChatRole::User,
185            content: content.into(),
186            tool_calls: None,
187        }
188    }
189}
190
191/// Role of a chat message author.
192#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
193#[non_exhaustive]
194pub enum ChatRole {
195    /// System instruction.
196    System,
197    /// User message.
198    User,
199    /// Assistant response.
200    Assistant,
201    /// Tool result message carrying the output of a tool call.
202    Tool {
203        /// The identifier of the tool call this result belongs to.
204        tool_call_id: String,
205    },
206}
207
208/// A request to generate embeddings.
209#[derive(Debug, Clone)]
210pub struct EmbedRequest {
211    /// Model identifier.
212    pub model: String,
213    /// Input texts to embed.
214    pub inputs: Vec<String>,
215    /// Additional provider-specific parameters.
216    pub extra: serde_json::Map<String, serde_json::Value>,
217}
218
219impl EmbedRequest {
220    /// Create a new embedding request.
221    pub fn new(model: impl Into<String>, inputs: Vec<String>) -> Self {
222        Self {
223            model: model.into(),
224            inputs,
225            extra: serde_json::Map::new(),
226        }
227    }
228}
229
230/// Response from an embedding operation.
231#[derive(Debug, Clone)]
232pub struct EmbedResponse {
233    /// Generated embedding vectors.
234    pub embeddings: Vec<Vec<f32>>,
235    /// Usage statistics, if available.
236    pub usage: Option<LlmUsage>,
237    /// Model used.
238    pub model: String,
239    /// Provider-specific metadata.
240    pub metadata: serde_json::Map<String, serde_json::Value>,
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChatRequest::new` stores the model name and the supplied messages.
    #[test]
    fn chat_request_builder() {
        let messages = vec![ChatMessage::user("hello")];
        let req = ChatRequest::new("gpt-4o", messages);

        assert_eq!(req.model, "gpt-4o");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, ChatRole::User);
    }

    /// A request can carry tool definitions together with a tool choice.
    #[test]
    fn chat_request_accepts_tools() {
        let weather_tool = ToolDefinition {
            name: "get_weather".into(),
            description: "Get weather for a city".into(),
            parameters: serde_json::Map::new(),
        };
        // Start from the constructor's values and override only the
        // tool-related fields.
        let req = ChatRequest {
            tools: vec![weather_tool],
            tool_choice: Some(ToolChoice::Auto),
            ..ChatRequest::new("gpt-4o", vec![ChatMessage::user("what's the weather?")])
        };

        assert_eq!(req.tools.len(), 1);
        assert_eq!(req.tools[0].name, "get_weather");
        assert_eq!(req.tool_choice, Some(ToolChoice::Auto));
    }

    /// A tool-role message exposes the id of the call it answers.
    #[test]
    fn tool_message_carries_tool_call_id() {
        let role = ChatRole::Tool {
            tool_call_id: "call_123".into(),
        };
        let msg = ChatMessage {
            role,
            content: "weather result".into(),
            tool_calls: None,
        };

        if let ChatRole::Tool { tool_call_id } = &msg.role {
            assert_eq!(tool_call_id, "call_123");
        } else {
            panic!("expected Tool role");
        }
    }

    /// An assistant message can carry the tool calls it made earlier.
    #[test]
    fn assistant_message_carries_prior_tool_calls() {
        let prior_call = EmittedToolCall {
            id: "call_123".into(),
            name: "get_weather".into(),
            arguments: r#"{"city":"London"}"#.into(),
        };
        let msg = ChatMessage {
            role: ChatRole::Assistant,
            content: "I'll check the weather".into(),
            tool_calls: Some(vec![prior_call]),
        };

        let calls = msg.tool_calls.as_ref().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
    }
}