use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use super::{ChatCompletion, LlmError, Message, ToolChoice, ToolDefinition};
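
/// A pinned, boxed stream of [`ChatCompletion`] chunks, as returned by
/// [`BaseChatModel::invoke_stream`].
///
/// A minimal consumption sketch (assumes `futures::StreamExt` is in scope for
/// `next()`; the loop body is illustrative, not prescribed by this module):
///
/// ```ignore
/// use futures::StreamExt;
///
/// async fn drain(mut stream: ChatStream) -> Result<(), LlmError> {
///     while let Some(chunk) = stream.next().await {
///         let _completion = chunk?;
///         // handle each incremental ChatCompletion here
///     }
///     Ok(())
/// }
/// ```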
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatCompletion, LlmError>> + Send>>;
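
/// Common interface implemented by every chat-capable model backend.
///
/// A sketch of caller-side use (hypothetical glue code; message construction
/// and stream consumption are elided):
///
/// ```ignore
/// if model.supports_streaming() {
///     let stream = model.invoke_stream(messages, None, None).await?;
///     // ... consume the chunks incrementally (see [`ChatStream`]) ...
/// } else {
///     let completion = model.invoke(messages, None, None).await?;
///     // ... use the full completion ...
/// }
/// ```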
#[async_trait]
pub trait BaseChatModel: Send + Sync {
    /// Identifier of the concrete model, e.g. `"gpt-4o"`.
    fn model(&self) -> &str;

    /// Name of the backing provider, e.g. `"openai"`.
    fn provider(&self) -> &str;

    /// The model's context window in tokens, or `None` when unknown.
    fn context_window(&self) -> Option<u64> {
        None
    }

    /// Sends `messages` to the model and waits for the full completion.
    ///
    /// `tools` and `tool_choice` are forwarded to the provider when tool
    /// calling is requested.
    async fn invoke(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatCompletion, LlmError>;

    /// Like [`invoke`](Self::invoke), but yields the completion incrementally
    /// as a [`ChatStream`].
    async fn invoke_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<ToolDefinition>>,
        tool_choice: Option<ToolChoice>,
    ) -> Result<ChatStream, LlmError>;

    /// Estimates how many tokens `messages` will consume.
    ///
    /// The default implementation is a rough heuristic: it totals the
    /// character count of all message content and divides by four, since
    /// English text averages about four characters per token (a
    /// 4,000-character conversation estimates to ~1,000 tokens). Providers
    /// with a real tokenizer should override this.
    async fn count_tokens(&self, messages: &[Message]) -> u64 {
        let total_chars: usize = messages
            .iter()
            .map(|m| match m {
                Message::User(u) => u
                    .content
                    .iter()
                    // Non-text parts (e.g. images) have no measurable
                    // character length, so charge a flat placeholder of 10.
                    .map(|c| c.as_text().map(|t| t.len()).unwrap_or(10))
                    .sum(),
                Message::Assistant(a) => {
                    a.content.as_ref().map(|c| c.len()).unwrap_or(0)
                        + a.tool_calls
                            .iter()
                            .map(|tc| tc.function.arguments.len())
                            .sum::<usize>()
                }
                Message::System(s) => s.content.len(),
                Message::Developer(d) => d.content.len(),
                Message::Tool(t) => t.content.len(),
            })
            .sum();
        (total_chars / 4) as u64
    }

    /// Whether the model supports tool (function) calling.
    fn supports_tools(&self) -> bool {
        true
    }

    /// Whether the model can stream responses via
    /// [`invoke_stream`](Self::invoke_stream).
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Whether the model accepts image (vision) inputs.
    fn supports_vision(&self) -> bool {
        false
    }
}
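
/// Convenience entry point that fans out to the provider-specific builders.
///
/// A sketch of intended use (assumes the provider builders expose a `build()`
/// finisher, which is not shown in this module):
///
/// ```ignore
/// let model = ModelBuilder::openai("gpt-4o").build()?;
/// let completion = model.invoke(messages, None, None).await?;
/// ```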
pub struct ModelBuilder;

impl ModelBuilder {
    /// Starts an OpenAI chat model builder preconfigured with `model`.
    #[cfg(feature = "openai")]
    pub fn openai(model: impl Into<String>) -> super::openai::ChatOpenAIBuilder {
        super::ChatOpenAI::builder().model(model)
    }

    /// Starts an Anthropic chat model builder preconfigured with `model`.
    #[cfg(feature = "anthropic")]
    pub fn anthropic(model: impl Into<String>) -> super::anthropic::ChatAnthropicBuilder {
        super::ChatAnthropic::builder().model(model)
    }

    /// Starts a Google chat model builder preconfigured with `model`.
    #[cfg(feature = "google")]
    pub fn google(model: impl Into<String>) -> super::google::ChatGoogleBuilder {
        super::ChatGoogle::builder().model(model)
    }
}