1use async_trait::async_trait;
4use futures::Stream;
5use std::pin::Pin;
6
7use super::{ChatCompletion, LlmError, Message, ToolChoice, ToolDefinition};
8
9pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatCompletion, LlmError>> + Send>>;
11
12#[async_trait]
14pub trait BaseChatModel: Send + Sync {
15 fn model(&self) -> &str;
17
18 fn provider(&self) -> &str;
20
21 fn context_window(&self) -> Option<u64> {
23 None
24 }
25
26 async fn invoke(
28 &self,
29 messages: Vec<Message>,
30 tools: Option<Vec<ToolDefinition>>,
31 tool_choice: Option<ToolChoice>,
32 ) -> Result<ChatCompletion, LlmError>;
33
34 async fn invoke_stream(
36 &self,
37 messages: Vec<Message>,
38 tools: Option<Vec<ToolDefinition>>,
39 tool_choice: Option<ToolChoice>,
40 ) -> Result<ChatStream, LlmError>;
41
42 async fn count_tokens(&self, messages: &[Message]) -> u64 {
44 let total_chars: usize = messages
46 .iter()
47 .map(|m| match m {
48 Message::User(u) => u
49 .content
50 .iter()
51 .map(|c| c.as_text().map(|t| t.len()).unwrap_or(10))
52 .sum(),
53 Message::Assistant(a) => {
54 a.content.as_ref().map(|c| c.len()).unwrap_or(0)
55 + a.tool_calls
56 .iter()
57 .map(|tc| tc.function.arguments.len())
58 .sum::<usize>()
59 }
60 Message::System(s) => s.content.len(),
61 Message::Developer(d) => d.content.len(),
62 Message::Tool(t) => t.content.len(),
63 })
64 .sum();
65 (total_chars / 4) as u64
66 }
67
68 fn supports_tools(&self) -> bool {
70 true
71 }
72
73 fn supports_streaming(&self) -> bool {
75 true
76 }
77
78 fn supports_vision(&self) -> bool {
80 false
81 }
82}
83
/// Namespace of shortcuts for starting a provider-specific chat-model builder.
///
/// Each constructor is compiled only when the matching Cargo feature
/// (`openai`, `anthropic`, `google`) is enabled. It returns that provider's
/// builder with the model already set.
#[derive(Debug, Clone, Copy)]
pub struct ModelBuilder;

impl ModelBuilder {
    /// Starts an OpenAI chat-model builder with `model` preset.
    ///
    /// Equivalent to `ChatOpenAI::builder().model(model)`.
    #[cfg(feature = "openai")]
    #[must_use = "the returned builder has no effect unless it is used"]
    pub fn openai(model: impl Into<String>) -> super::openai::ChatOpenAIBuilder {
        super::ChatOpenAI::builder().model(model)
    }

    /// Starts an Anthropic chat-model builder with `model` preset.
    ///
    /// Equivalent to `ChatAnthropic::builder().model(model)`.
    #[cfg(feature = "anthropic")]
    #[must_use = "the returned builder has no effect unless it is used"]
    pub fn anthropic(model: impl Into<String>) -> super::anthropic::ChatAnthropicBuilder {
        super::ChatAnthropic::builder().model(model)
    }

    /// Starts a Google chat-model builder with `model` preset.
    ///
    /// Equivalent to `ChatGoogle::builder().model(model)`.
    #[cfg(feature = "google")]
    #[must_use = "the returned builder has no effect unless it is used"]
    pub fn google(model: impl Into<String>) -> super::google::ChatGoogleBuilder {
        super::ChatGoogle::builder().model(model)
    }
}