use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::state::{Message, Role};
pub use futures::stream::BoxStream;
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
#[error("authentication failed: {0}")]
AuthError(String),
#[error("rate limited, retry after {retry_after:?}")]
RateLimited {
retry_after: Option<std::time::Duration>,
},
#[error("context length exceeded: {used} tokens used, {limit} limit")]
ContextLengthExceeded {
used: u64,
limit: u64,
},
#[error("network error: {0}")]
NetworkError(String),
#[error("invalid response: {0}")]
InvalidResponse(String),
#[error("model not found: {0}")]
ModelNotFound(String),
#[error("content filtered")]
ContentFiltered,
#[error("timeout after {0:?}")]
Timeout(std::time::Duration),
#[error("llm error: {0}")]
Other(#[source] Box<dyn std::error::Error + Send + Sync>),
}
/// Optional per-call settings accepted by [`ChatModel::invoke`] and
/// [`ChatModel::stream`].
///
/// `CallOptions::default()` leaves every field unset (`None` everywhere and an
/// empty `tags` list). How each field is honoured is up to the concrete
/// `ChatModel` implementation.
#[derive(Clone, Debug, Default)]
pub struct CallOptions {
    /// Sampling temperature override.
    pub temperature: Option<f32>,
    /// Token budget for the reply.
    pub max_tokens: Option<u32>,
    /// Stop sequences for generation.
    pub stop_sequences: Option<Vec<String>>,
    /// Nucleus-sampling (`top_p`) override.
    pub top_p: Option<f32>,
    /// Model identifier to use for this call instead of the configured one.
    pub model_override: Option<String>,
    /// Constraint on tool calling for this request.
    pub tool_choice: Option<ToolChoice>,
    /// Requested output format.
    pub response_format: Option<ResponseFormat>,
    /// Free-form tags; their consumer is not visible in this module.
    pub tags: Vec<String>,
}
/// How the model is allowed or required to call tools for a request.
///
/// Variant meanings below are inferred from their names; the provider adapter
/// decides the actual behaviour.
#[derive(Clone, Debug)]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool.
    Auto,
    /// Ask the model not to call tools.
    None,
    /// Ask the model to call some tool (used to force structured output).
    Required,
    /// Ask the model to call the tool called `name`.
    Specific { name: String },
}
/// Requested shape of the model's textual output.
#[derive(Clone, Debug)]
pub enum ResponseFormat {
    /// Any JSON object.
    JsonObject,
    /// JSON conforming to `schema`, identified by `name`; `strict` is passed
    /// through to the implementation.
    JsonSchema {
        name: String,
        schema: serde_json::Value,
        strict: bool,
    },
}
/// A tool the model may call, as passed to [`ChatModel::bind_tools`].
///
/// Field order is significant for the derived serde output and is kept as-is.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Identifier the model uses when calling the tool.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Argument description as JSON (a JSON Schema in the structured-output
    /// path of this module).
    pub parameters: serde_json::Value,
}
/// One incremental item of a reply produced by [`ChatModel::stream`].
#[derive(Clone, Debug)]
pub struct MessageChunk {
    /// Role of the message, when this chunk carries it.
    pub role: Option<Role>,
    /// Text delta contained in this chunk (may be empty).
    pub content: String,
    /// Partial tool-call fragments contained in this chunk.
    pub tool_call_chunks: Vec<ToolCallChunk>,
    /// Token usage, when reported in this chunk.
    pub usage: Option<crate::state::TokenUsage>,
}
pub use crate::stream::ToolCallChunk;
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait ChatModel: Send + Sync + Clone + 'static {
async fn invoke(
&self,
messages: &[Message],
options: Option<&CallOptions>,
) -> Result<Message, LlmError>;
async fn stream(
&self,
messages: &[Message],
options: Option<&CallOptions>,
) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError>;
#[must_use]
fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self;
#[must_use]
fn with_structured_output<T: JsonSchema + DeserializeOwned + Serialize>(
self,
) -> StructuredOutputModel<Self, T>
where
Self: Sized;
fn model_name(&self) -> &str;
}
/// Local alias for [`schemars::JsonSchema`], so bounds in this module can be
/// written as `T: JsonSchema`. Every `schemars::JsonSchema` type implements it.
pub trait JsonSchema: schemars::JsonSchema {}

impl<S> JsonSchema for S where S: schemars::JsonSchema {}
/// Types that can be deserialized for any input lifetime, i.e. without
/// borrowing from the input. This is the local equivalent of
/// `serde::de::DeserializeOwned`; it is blanket-implemented for every such type.
pub trait DeserializeOwned: for<'de> Deserialize<'de> {}

impl<D> DeserializeOwned for D where D: for<'de> Deserialize<'de> {}
/// Wrapper that makes the inner model `M` answer with data matching the JSON
/// Schema of `T`.
///
/// `T` is only a type-level marker; no `T` value is stored. The struct itself
/// places no bounds on `M` or `T`. Each impl states what it needs (e.g.
/// `M: Clone` for `Clone`), which is the idiomatic place for trait bounds.
pub struct StructuredOutputModel<M, T> {
    /// The wrapped model.
    pub(crate) inner: M,
    /// Marker carrying the target output type `T`.
    pub(crate) _phantom: std::marker::PhantomData<T>,
}
/// Hand-written `Clone` so that cloning only needs `M: Clone`.
///
/// `#[derive(Clone)]` would also require `T: Clone`, even though `T` is only a
/// phantom marker.
impl<M, T> Clone for StructuredOutputModel<M, T>
where
    M: Clone,
{
    fn clone(&self) -> Self {
        // `PhantomData<T>` is `Copy` for every `T`, so the marker is copied out.
        let marker = self._phantom;
        Self {
            inner: M::clone(&self.inner),
            _phantom: marker,
        }
    }
}
/// Opaque `Debug` output.
///
/// The wrapped model is shown as the placeholder string `"<model>"`, so neither
/// `M` nor `T` has to implement `Debug`.
impl<M, T> std::fmt::Debug for StructuredOutputModel<M, T>
where
    M: Clone,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "<model>";
        let mut builder = f.debug_struct("StructuredOutputModel");
        builder.field("inner", &PLACEHOLDER);
        builder.field("_phantom", &self._phantom);
        builder.finish()
    }
}
/// [`ChatModel`] implementation that forces the wrapped model to answer through
/// a single `structured_output` tool whose parameters are the JSON Schema of `T`.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl<M, T> ChatModel for StructuredOutputModel<M, T>
where
    M: ChatModel,
    T: JsonSchema + DeserializeOwned + Serialize + Send + Sync + 'static,
{
    /// Invokes the inner model with only the `structured_output` tool bound and
    /// `tool_choice` forced to [`ToolChoice::Required`], then validates the
    /// first tool call's arguments by deserializing them into `T`.
    ///
    /// On success the result is an AI [`Message`] whose text content is the `T`
    /// value re-serialized to JSON. Any fields `T` does not model are therefore
    /// dropped. The message's `id` and `usage` are taken from the inner response.
    ///
    /// # Errors
    /// - Errors from the inner model are propagated unchanged.
    /// - [`LlmError::InvalidResponse`] if the schema cannot be converted to JSON,
    ///   the response has no tool call, or the arguments cannot be
    ///   deserialized into / re-serialized from `T`.
    async fn invoke(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<Message, LlmError> {
        // Name of the synthetic tool the model is forced to call.
        const TOOL_NAME: &str = "structured_output";

        let schema = schemars::schema_for!(T);
        let tool_def = ToolDefinition {
            name: TOOL_NAME.to_string(),
            description: "Output structured data".to_string(),
            parameters: serde_json::to_value(schema)
                .map_err(|e| LlmError::InvalidResponse(e.to_string()))?,
        };

        #[allow(
            clippy::manual_unwrap_or_default,
            clippy::option_if_let_else,
            reason = "project rules prohibit unwrap_or_default; match is explicit and readable"
        )]
        let mut opts = match options.cloned() {
            Some(opts) => opts,
            None => CallOptions::default(),
        };
        opts.tool_choice = Some(ToolChoice::Required);

        // NOTE(review): only the structured-output tool is bound here. Whether
        // tools bound earlier via `bind_tools` stay available depends on the
        // inner model's `bind_tools` semantics (replace vs. append) — confirm.
        let model_with_tool = self.inner.bind_tools(vec![tool_def]);
        let response = model_with_tool.invoke(messages, Some(&opts)).await?;

        let Some(tool_call) = response.tool_calls.first() else {
            return Err(LlmError::InvalidResponse(
                "No tool call in response".to_string(),
            ));
        };

        // Deserializing into `T` is the schema validation step.
        let value: T = serde_json::from_value(tool_call.arguments.clone()).map_err(|e| {
            LlmError::InvalidResponse(format!("Failed to parse structured output: {e}"))
        })?;
        let json = serde_json::to_string(&value).map_err(|e| {
            LlmError::InvalidResponse(format!("Failed to serialize structured output: {e}"))
        })?;

        Ok(Message {
            id: response.id,
            role: Role::Ai,
            content: crate::state::Content::Text(json),
            tool_calls: vec![],
            tool_call_id: None,
            name: None,
            usage: response.usage,
        })
    }

    /// Streams directly from the inner model.
    ///
    /// NOTE(review): unlike `invoke`, this neither binds the structured-output
    /// tool nor validates anything against `T`. Callers receive the inner
    /// model's raw stream.
    async fn stream(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<BoxStream<'_, Result<MessageChunk, LlmError>>, LlmError> {
        self.inner.stream(messages, options).await
    }

    /// Binds `tools` on the inner model and re-wraps it, keeping the target
    /// type `T`.
    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
        Self {
            inner: self.inner.bind_tools(tools),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Wraps this structured-output model in another layer targeting `U`.
    fn with_structured_output<U: JsonSchema + DeserializeOwned + Serialize>(
        self,
    ) -> StructuredOutputModel<Self, U>
    where
        Self: Sized,
    {
        StructuredOutputModel {
            inner: self,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Delegates to the inner model's name.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
}