// agent_io/llm/base.rs
//! Base trait for LLM implementations

use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

use super::{ChatCompletion, LlmError, Message, ToolChoice, ToolDefinition};

/// Boxed, `Send` stream of [`ChatCompletion`] results, as returned by
/// [`BaseChatModel::invoke_stream`]. Each item is a `Result` so providers can
/// surface mid-stream failures as [`LlmError`] without tearing down the stream type.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatCompletion, LlmError>> + Send>>;
11
12/// Base trait for chat model implementations
13#[async_trait]
14pub trait BaseChatModel: Send + Sync {
15    /// Get the model name
16    fn model(&self) -> &str;
17
18    /// Get the provider name
19    fn provider(&self) -> &str;
20
21    /// Get the context window size (max input tokens)
22    fn context_window(&self) -> Option<u64> {
23        None
24    }
25
26    /// Invoke the model with messages
27    async fn invoke(
28        &self,
29        messages: Vec<Message>,
30        tools: Option<Vec<ToolDefinition>>,
31        tool_choice: Option<ToolChoice>,
32    ) -> Result<ChatCompletion, LlmError>;
33
34    /// Invoke the model with streaming response
35    async fn invoke_stream(
36        &self,
37        messages: Vec<Message>,
38        tools: Option<Vec<ToolDefinition>>,
39        tool_choice: Option<ToolChoice>,
40    ) -> Result<ChatStream, LlmError>;
41
42    /// Count tokens in messages (approximate)
43    async fn count_tokens(&self, messages: &[Message]) -> u64 {
44        // Default approximation: ~4 chars per token
45        let total_chars: usize = messages
46            .iter()
47            .map(|m| match m {
48                Message::User(u) => u
49                    .content
50                    .iter()
51                    .map(|c| c.as_text().map(|t| t.len()).unwrap_or(10))
52                    .sum(),
53                Message::Assistant(a) => {
54                    a.content.as_ref().map(|c| c.len()).unwrap_or(0)
55                        + a.tool_calls
56                            .iter()
57                            .map(|tc| tc.function.arguments.len())
58                            .sum::<usize>()
59                }
60                Message::System(s) => s.content.len(),
61                Message::Developer(d) => d.content.len(),
62                Message::Tool(t) => t.content.len(),
63            })
64            .sum();
65        (total_chars / 4) as u64
66    }
67
68    /// Check if the model supports tools
69    fn supports_tools(&self) -> bool {
70        true
71    }
72
73    /// Check if the model supports streaming
74    fn supports_streaming(&self) -> bool {
75        true
76    }
77
78    /// Check if the model supports vision
79    fn supports_vision(&self) -> bool {
80        false
81    }
82}
83
/// Convenience entry points for constructing provider-specific model builders.
///
/// Each method is gated on its provider's cargo feature and simply seeds the
/// corresponding builder with the given model name.
pub struct ModelBuilder;

impl ModelBuilder {
    /// Start an OpenAI chat-model builder preset with `model`.
    #[cfg(feature = "openai")]
    pub fn openai(model: impl Into<String>) -> super::openai::ChatOpenAIBuilder {
        let builder = super::ChatOpenAI::builder();
        builder.model(model)
    }

    /// Start an Anthropic chat-model builder preset with `model`.
    #[cfg(feature = "anthropic")]
    pub fn anthropic(model: impl Into<String>) -> super::anthropic::ChatAnthropicBuilder {
        let builder = super::ChatAnthropic::builder();
        builder.model(model)
    }

    /// Start a Google chat-model builder preset with `model`.
    #[cfg(feature = "google")]
    pub fn google(model: impl Into<String>) -> super::google::ChatGoogleBuilder {
        let builder = super::ChatGoogle::builder();
        builder.model(model)
    }
}
105}