1use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use crate::state::{Message, Role};
17
18pub use futures::stream::BoxStream;
20
21#[derive(Debug, thiserror::Error)]
23pub enum LlmError {
24 #[error("authentication failed: {0}")]
26 AuthError(String),
27
28 #[error("rate limited, retry after {retry_after:?}")]
30 RateLimited {
31 retry_after: Option<std::time::Duration>,
33 },
34
35 #[error("context length exceeded: {used} tokens used, {limit} limit")]
37 ContextLengthExceeded {
38 used: u64,
40 limit: u64,
42 },
43
44 #[error("network error: {0}")]
46 NetworkError(String),
47
48 #[error("invalid response: {0}")]
50 InvalidResponse(String),
51
52 #[error("model not found: {0}")]
54 ModelNotFound(String),
55
56 #[error("content filtered")]
58 ContentFiltered,
59
60 #[error("timeout after {0:?}")]
62 Timeout(std::time::Duration),
63
64 #[error("llm error: {0}")]
66 Other(#[source] Box<dyn std::error::Error + Send + Sync>),
67}
68
69#[derive(Clone, Debug, Default)]
74pub struct CallOptions {
75 pub temperature: Option<f32>,
77
78 pub max_tokens: Option<u32>,
80
81 pub stop_sequences: Option<Vec<String>>,
83
84 pub top_p: Option<f32>,
86
87 pub model_override: Option<String>,
89
90 pub tool_choice: Option<ToolChoice>,
92
93 pub response_format: Option<ResponseFormat>,
95
96 pub tags: Vec<String>,
102}
103
/// Tool-selection policy carried in [`CallOptions::tool_choice`].
// NOTE(review): the exact provider-side meaning of each variant is decided by
// the `ChatModel` implementor and is not visible in this file.
#[derive(Debug, Clone)]
pub enum ToolChoice {
    /// Presumably lets the model decide whether to call a tool.
    Auto,
    /// Presumably forbids tool calls.
    None,
    /// Presumably requires some tool call.
    Required,
    /// Selects one tool by name.
    Specific {
        /// Name of the selected tool.
        name: String,
    },
}
119
120#[derive(Clone, Debug)]
122pub enum ResponseFormat {
123 JsonObject,
125 JsonSchema {
127 name: String,
129 schema: serde_json::Value,
131 strict: bool,
133 },
134}
135
136#[derive(Clone, Debug, Serialize, Deserialize)]
138pub struct ToolDefinition {
139 pub name: String,
141 pub description: String,
143 pub parameters: serde_json::Value,
145}
146
147#[derive(Clone, Debug)]
152pub struct MessageChunk {
153 pub role: Option<Role>,
155 pub content: String,
157 pub tool_call_chunks: Vec<ToolCallChunk>,
159 pub usage: Option<crate::state::TokenUsage>,
161}
162
163pub use crate::stream::ToolCallChunk;
169
170#[cfg_attr(target_family = "wasm", async_trait(?Send))]
179#[cfg_attr(not(target_family = "wasm"), async_trait)]
180pub trait ChatModel: Send + Sync + Clone + 'static {
181 async fn invoke(
192 &self,
193 messages: &[Message],
194 options: Option<&CallOptions>,
195 ) -> Result<Message, LlmError>;
196
197 async fn stream(
208 &self,
209 messages: &[Message],
210 options: Option<&CallOptions>,
211 ) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError>;
212
213 #[must_use]
221 fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self;
222
223 #[must_use]
232 fn with_structured_output<T: JsonSchema + DeserializeOwned + Serialize>(
233 self,
234 ) -> StructuredOutputModel<Self, T>
235 where
236 Self: Sized;
237
238 fn model_name(&self) -> &str;
240}
241
242pub trait JsonSchema: schemars::JsonSchema {}
244
245impl<T: schemars::JsonSchema> JsonSchema for T {}
247
248pub trait DeserializeOwned: for<'de> Deserialize<'de> {}
250
251impl<T: for<'de> Deserialize<'de>> DeserializeOwned for T {}
253
/// Wrapper around a model `M` that targets structured output of type `T`.
///
/// `T` is only a type-level marker; no value of `T` is stored.
///
/// The struct itself places no bounds on `M`: each `impl` block states the
/// bounds it needs (for example `M: Clone` on the `Clone` impl), so a bound
/// here would only be repeated at every use site.
pub struct StructuredOutputModel<M, T> {
    /// The wrapped model.
    pub(crate) inner: M,
    /// Ties the wrapper to its target type `T` without storing one.
    pub(crate) _phantom: std::marker::PhantomData<T>,
}
267
268impl<M: Clone, T> Clone for StructuredOutputModel<M, T> {
269 fn clone(&self) -> Self {
270 Self {
271 inner: self.inner.clone(),
272 _phantom: std::marker::PhantomData,
273 }
274 }
275}
276
277impl<M, T> std::fmt::Debug for StructuredOutputModel<M, T>
278where
279 M: Clone,
280{
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 f.debug_struct("StructuredOutputModel")
283 .field("inner", &"<model>")
284 .field("_phantom", &self._phantom)
285 .finish()
286 }
287}
288
289#[cfg_attr(target_family = "wasm", async_trait(?Send))]
290#[cfg_attr(not(target_family = "wasm"), async_trait)]
291impl<M, T> ChatModel for StructuredOutputModel<M, T>
292where
293 M: ChatModel,
294 T: JsonSchema + DeserializeOwned + Serialize + Send + Sync + 'static,
295{
296 async fn invoke(
297 &self,
298 messages: &[Message],
299 options: Option<&CallOptions>,
300 ) -> Result<Message, LlmError> {
301 let schema = schemars::schema_for!(T);
303 let tool_def = ToolDefinition {
304 name: "structured_output".to_string(),
305 description: "Output structured data".to_string(),
306 parameters: serde_json::to_value(schema)
307 .map_err(|e| LlmError::InvalidResponse(e.to_string()))?,
308 };
309
310 #[allow(
312 clippy::manual_unwrap_or_default,
313 clippy::option_if_let_else,
314 reason = "project rules prohibit unwrap_or_default; match is explicit and readable"
315 )]
316 let mut opts = match options.cloned() {
317 Some(opts) => opts,
318 None => CallOptions::default(),
319 };
320 opts.tool_choice = Some(ToolChoice::Required);
321
322 let model_with_tool = self.inner.bind_tools(vec![tool_def]);
324 let response = model_with_tool.invoke(messages, Some(&opts)).await?;
325
326 if let Some(tool_call) = response.tool_calls.first() {
328 let _value: T = serde_json::from_value(tool_call.arguments.clone()).map_err(|e| {
329 LlmError::InvalidResponse(format!("Failed to parse structured output: {e}"))
330 })?;
331
332 Ok(Message {
334 id: response.id,
335 role: Role::Ai,
336 content: crate::state::Content::Text(serde_json::to_string(&_value).map_err(
337 |e| {
338 LlmError::InvalidResponse(format!(
339 "Failed to serialize structured output: {e}"
340 ))
341 },
342 )?),
343 tool_calls: vec![],
344 tool_call_id: None,
345 name: None,
346 usage: response.usage,
347 })
348 } else {
349 Err(LlmError::InvalidResponse(
350 "No tool call in response".to_string(),
351 ))
352 }
353 }
354
355 async fn stream(
356 &self,
357 messages: &[Message],
358 options: Option<&CallOptions>,
359 ) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError> {
360 self.inner.stream(messages, options).await
361 }
362
363 fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
364 Self {
365 inner: self.inner.bind_tools(tools),
366 _phantom: std::marker::PhantomData,
367 }
368 }
369
370 fn with_structured_output<U: JsonSchema + DeserializeOwned + Serialize>(
371 self,
372 ) -> StructuredOutputModel<Self, U>
373 where
374 Self: Sized,
375 {
376 StructuredOutputModel {
377 inner: self,
378 _phantom: std::marker::PhantomData,
379 }
380 }
381
382 fn model_name(&self) -> &str {
383 self.inner.model_name()
384 }
385}
386
387