Skip to main content

abu_agent/model/
chat.rs

1use std::sync::RwLock;
2use abu_base::chat::{ChatMessage, ChatRequest, ChatRequestBuilder, ChatResponse, UserMessage};
3use abu_provider::{anthropic::Anthropic, deepseek::DeepSeek, openai::OpenAi, ChatProvide, ProvideError};
4use abu_tool::{Tool, ToolDefinition};
5
6pub struct ChatModel<P> {
7    request: RwLock<ChatRequest>,
8    config: ChatConfig,
9    provider: P,
10}
11
/// Per-call sampling settings applied by a chat model.
///
/// `Default` leaves every setting unset (`None`), so the provider's own defaults
/// apply.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ChatConfig {
    /// Sampling temperature sent with each request; `None` omits it.
    pub temperature: Option<f64>,
}
16
17impl ChatModel<OpenAi> {
18    /// load `OPENAI_BASE_URL` and `OPENAI_API_KEY` in env
19    pub fn openai(model: impl Into<String>) -> Result<Self, ChatModelError> {
20        let openai = OpenAi::from_env()
21            .map_err(|e| ChatModelError::BuildOpenAi(e))?;
22        Ok(Self::new(openai, model))
23    }
24}
25
26impl ChatModel<DeepSeek> {
27    /// load `DEEPSEEK_BASE_URL` and `DEEPSEEK_API_KEY` in env
28    pub fn deepseek(model: impl Into<String>) -> Result<Self, ChatModelError> {
29        let deepseek = DeepSeek::from_env()
30            .map_err(|e| ChatModelError::BuildDeepSeek(e))?;
31        Ok(Self::new(deepseek, model))
32
33    }
34}
35
36impl ChatModel<Anthropic> {
37    /// load `ANTHROPIC_BASE_URL` and `ANTHROPIC_API_KEY` in env
38    pub fn anthropic(model: impl Into<String>) -> Result<Self, ChatModelError> {
39        let anthropic = Anthropic::from_env()
40            .map_err(|e| ChatModelError::BuildAnthropic(e))?;
41        Ok(Self::new(anthropic, model))
42    }
43}
44
45impl<P: ChatProvide> ChatModel<P> {
46    pub fn new(provider: P, model: impl Into<String>) -> Self {
47        let request = ChatRequestBuilder::default()
48            .model(model)
49            .build()
50            .expect("request just need model to build!");
51        Self {
52            request: RwLock::new(request),
53            config: ChatConfig::default(),
54            provider
55        }
56    }
57
58    pub fn set_config(&mut self, config: ChatConfig) {
59        self.config = config;
60    }
61
62    pub fn bind_tools<'a>(&'a mut self, tools: impl IntoIterator<Item = &'a Box<dyn Tool>>) {
63        let tool_defines: Vec<_> = tools.into_iter()
64            .map(|t| t.to_function_define())
65            .collect(); 
66        self.request.write().unwrap().tools = tool_defines;  
67    }
68
69    pub fn bind_tool_defines(&mut self, tools: impl Into<Vec<ToolDefinition>>) {
70        self.request.write().unwrap().tools = tools.into(); 
71    }
72
73    #[inline]
74    pub async fn chat(&self, messages: impl IntoChatMessages) -> Result<ChatResponse, ChatModelError> {
75        self.send(messages, &self.config, true).await
76    }
77
78    #[inline]
79    pub async fn chat_no_tools(&self, messages: impl IntoChatMessages) -> Result<ChatResponse, ChatModelError> {
80        self.send(messages, &self.config, false).await
81    }
82
83    async fn send(&self, messages: impl IntoChatMessages, config: &ChatConfig, with_tools: bool) -> Result<ChatResponse, ChatModelError> {
84        // set messages
85        let messages = messages.into_messages();
86        let mut request = self.request.write().unwrap();
87        request.messages = messages;
88        // set config
89        request.temperature = config.temperature;
90        // swap tools
91        let mut tools = vec![];
92        if !with_tools {
93            std::mem::swap(&mut tools, &mut request.tools); 
94        }
95
96        // send with provider
97        let response = self.provider
98            .chat(&request).await
99            .map_err(|e| ChatModelError::Provide(Box::new(e)))?;
100
101        // clear messages
102        request.messages.clear();
103        // reset config
104        request.temperature = self.config.temperature;
105        // recover tools
106        if !with_tools {
107            std::mem::swap(&mut tools, &mut request.tools); 
108        }
109
110        Ok(response)
111    
112    }
113}
114
115#[derive(Debug, thiserror::Error)]
116pub enum ChatModelError {
117    #[error("provide error: {0}")]
118    Provide(Box<dyn std::error::Error + 'static + Send + Sync>),
119
120    #[error("build openai provider: {0}")]
121    BuildOpenAi(ProvideError),
122
123    #[error("build deepseek provider: {0}")]
124    BuildDeepSeek(ProvideError),
125
126    #[error("build anthropic provider: {0}")]
127    BuildAnthropic(ProvideError),
128}
129
130pub trait IntoChatMessages {
131    fn into_messages(self) -> Vec<ChatMessage>;
132}
133
134impl IntoChatMessages for String {
135    #[inline]
136    fn into_messages(self) -> Vec<ChatMessage> {
137        vec![ChatMessage::user(self)]
138    }
139}
140
141impl IntoChatMessages for &String {
142    #[inline]
143    fn into_messages(self) -> Vec<ChatMessage> {
144        vec![ChatMessage::user(self)]
145    }
146}
147
148impl IntoChatMessages for &str {
149    #[inline]
150    fn into_messages(self) -> Vec<ChatMessage> {
151        vec![ChatMessage::user(self)]
152    }
153}
154
155impl IntoChatMessages for UserMessage {
156    #[inline]
157    fn into_messages(self) -> Vec<ChatMessage> {
158        vec![ChatMessage::User(self)]
159    }
160}
161
162impl IntoChatMessages for Vec<ChatMessage> {
163    #[inline]
164    fn into_messages(self) -> Vec<ChatMessage> {
165        self
166    }
167}
168
169impl IntoChatMessages for &[ChatMessage] {
170    #[inline]
171    fn into_messages(self) -> Vec<ChatMessage> {
172        self.to_vec()
173    }
174}
175
176impl IntoChatMessages for &Vec<ChatMessage> {
177    #[inline]
178    fn into_messages(self) -> Vec<ChatMessage> {
179        self.clone()
180    }
181}