use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::time::Duration;
use thiserror::Error;
use crate::llm::{Message, MessageChunk};
/// Pinned, heap-allocated, type-erased [`Stream`] yielding items of type `T`.
///
/// The trait object is required to be `Send` and may borrow data for the
/// lifetime `'a`.
pub type BoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + Send + 'a>>;
#[derive(Debug, Error)]
pub enum LlmError {
#[error("authentication failed: {0}")]
AuthError(String),
#[error("rate limited, retry after {retry_after:?}")]
RateLimited {
retry_after: Option<Duration>,
},
#[error("context length exceeded: {used} tokens used, {limit} limit")]
ContextLengthExceeded {
used: u64,
limit: u64,
},
#[cfg(any(feature = "anthropic", feature = "openai", feature = "ollama"))]
#[error("network error: {0}")]
NetworkError(#[from] reqwest::Error),
#[cfg(not(any(feature = "anthropic", feature = "openai", feature = "ollama")))]
#[error("network error: {0}")]
NetworkError(String),
#[error("invalid response: {0}")]
InvalidResponse(String),
#[error("model not found: {0}")]
ModelNotFound(String),
#[error("content filtered")]
ContentFiltered,
#[error("timeout after {0:?}")]
Timeout(Duration),
#[error("LLM error: {0}")]
Other(#[source] Box<dyn std::error::Error + Send + Sync>),
}
/// Selects how tools are chosen for a call.
///
/// The derived [`Default`] is [`ToolChoice::Auto`].
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Default variant.
    #[default]
    Auto,
    /// The `None` choice (distinct from `Option::None`; always refer to it as
    /// `ToolChoice::None`).
    None,
    /// The `Required` choice.
    Required,
    /// A specific tool, identified by name.
    Specific {
        /// Name of the chosen tool.
        name: String,
    },
}
/// Format requested for a model response.
#[derive(Debug, Clone)]
pub enum ResponseFormat {
    /// JSON-object format, with no further parameters.
    JsonObject,
    /// JSON-schema format, parameterised by a named schema.
    JsonSchema {
        /// Name attached to the schema.
        name: String,
        /// The schema itself, held as an arbitrary JSON value.
        schema: serde_json::Value,
        /// Strictness flag carried with the schema.
        // NOTE(review): how `strict` is interpreted is decided by the code
        // consuming this value, which is not visible here.
        strict: bool,
    },
}
/// Per-call options that may accompany a chat-model request.
///
/// Every field is optional; the derived [`Default`] leaves each `Option` as
/// `None` and `tags` empty.
#[derive(Debug, Clone, Default)]
pub struct CallOptions {
    /// Optional temperature value.
    pub temperature: Option<f32>,
    /// Optional maximum token count.
    pub max_tokens: Option<u32>,
    /// Optional list of stop sequences.
    pub stop_sequences: Option<Vec<String>>,
    /// Optional `top_p` value.
    pub top_p: Option<f32>,
    /// Optional model name to use in place of the default one.
    // NOTE(review): presumably overrides the model configured on the
    // implementation — confirm against the providers.
    pub model_override: Option<String>,
    /// Optional tool-choice setting.
    pub tool_choice: Option<ToolChoice>,
    /// Optional response-format setting.
    pub response_format: Option<ResponseFormat>,
    /// Free-form string tags attached to the call.
    pub tags: Vec<String>,
}
/// Serializable description of a tool that can be bound to a chat model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Tool parameters, held as an arbitrary JSON value.
    // NOTE(review): presumably a JSON Schema object describing the tool's
    // arguments — confirm against the provider integrations.
    pub parameters: serde_json::Value,
}
/// Common interface for chat models.
///
/// Implementors must be `Send + Sync + Clone + 'static`. The trait is expanded
/// with `async_trait`; on `wasm` targets the `?Send` form is selected, so the
/// futures returned by async methods are not required to be `Send` there.
// NOTE(review): the `Send + Sync` supertraits and the `Send` bound inside
// `BoxStream` still apply on wasm even though `?Send` is used — confirm this
// is intended.
#[cfg_attr(not(target_family = "wasm"), async_trait)]
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
pub trait ChatModel: Send + Sync + Clone + 'static {
    /// Runs the model on `messages` and resolves to a single [`Message`].
    ///
    /// `options` optionally supplies per-call settings.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] on failure.
    async fn invoke(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<Message, LlmError>;

    /// Runs the model on `messages` and returns a stream of results, each
    /// either a [`MessageChunk`] or an [`LlmError`].
    ///
    /// `options` optionally supplies per-call settings. The returned stream's
    /// lifetime is tied to the borrow of `self`.
    fn stream(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> BoxStream<'_, Result<MessageChunk, LlmError>>;

    /// Returns a value of the same model type built from `self` and `tools`.
    ///
    /// `self` is only borrowed, and the result is `#[must_use]`.
    #[must_use]
    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self;

    /// Returns the model's name.
    fn model_name(&self) -> &str;

    /// Consumes the model and wraps it in a
    /// [`StructuredOutputModel`](crate::llm::StructuredOutputModel)
    /// parameterised by the output type `T`.
    ///
    /// Only available when the `structured-output` feature is enabled.
    #[cfg(feature = "structured-output")]
    #[must_use]
    fn with_structured_output<T>(self) -> crate::llm::StructuredOutputModel<Self, T>
    where
        T: serde::de::DeserializeOwned + schemars::JsonSchema + Clone + Send + Sync + 'static,
        Self: Sized,
    {
        crate::llm::StructuredOutputModel::new(self)
    }
}