Skip to main content

// agent_base/llm/anthropic.rs

1use async_trait::async_trait;
2use eventsource_stream::Eventsource;
3use futures_core::Stream;
4use futures_util::StreamExt;
5use reqwest::Client;
6use serde_json::{json, Value};
7use std::pin::Pin;
8
9use crate::types::{AgentResult, AgentError, ChatMessage, ImageAttachment, ResponseFormat};
10use super::{LlmCapabilities, LlmClient, StreamChunk, UsageInfo};
11
12pub struct AnthropicClient {
13    api_key: String,
14    model: String,
15    base_url: String,
16    client: Client,
17}
18
19impl AnthropicClient {
20    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
21        Self {
22            api_key,
23            model,
24            base_url: base_url
25                .unwrap_or_else(|| "https://api.anthropic.com".to_string()),
26            client: Client::new(),
27        }
28    }
29
30    fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<Value>) {
31        let mut system_prompt: Option<String> = None;
32        let mut result: Vec<Value> = Vec::new();
33
34        for msg in messages {
35            match msg {
36                ChatMessage::System { content } => {
37                    system_prompt = Some(content.clone());
38                }
39                ChatMessage::User { content, images } => {
40                    let mut content_parts: Vec<Value> = Vec::new();
41                    content_parts.push(json!({"type": "text", "text": content}));
42                    for img in images {
43                        match img {
44                            ImageAttachment::Url { url, detail: _ } => {
45                                content_parts.push(json!({
46                                    "type": "image",
47                                    "source": {
48                                        "type": "url",
49                                        "url": url,
50                                    }
51                                }));
52                            }
53                            ImageAttachment::Base64 { data, media_type, detail: _ } => {
54                                let mime = media_type.as_deref().unwrap_or("image/jpeg");
55                                content_parts.push(json!({
56                                    "type": "image",
57                                    "source": {
58                                        "type": "base64",
59                                        "media_type": mime,
60                                        "data": data,
61                                    }
62                                }));
63                            }
64                        }
65                    }
66                    result.push(json!({
67                        "role": "user",
68                        "content": content_parts,
69                    }));
70                }
71                ChatMessage::Assistant { content, reasoning_content: _, tool_calls } => {
72                    let mut parts: Vec<Value> = Vec::new();
73                    if let Some(text) = content {
74                        if !text.is_empty() {
75                            parts.push(json!({"type": "text", "text": text}));
76                        }
77                    }
78                    if let Some(tc) = tool_calls {
79                        for t in tc {
80                            let input: Value = serde_json::from_str(&t.arguments)
81                                .unwrap_or(Value::Null);
82                            parts.push(json!({
83                                "type": "tool_use",
84                                "id": t.id,
85                                "name": t.name,
86                                "input": input,
87                            }));
88                        }
89                    }
90                    if !parts.is_empty() {
91                        result.push(json!({"role": "assistant", "content": parts}));
92                    }
93                }
94                ChatMessage::Tool { tool_call_id, content } => {
95                    result.push(json!({
96                        "role": "user",
97                        "content": [{
98                            "type": "tool_result",
99                            "tool_use_id": tool_call_id,
100                            "content": content,
101                        }]
102                    }));
103                }
104            }
105        }
106
107        (system_prompt, result)
108    }
109
110    fn convert_tools(tools: &[Value]) -> Vec<Value> {
111        tools
112            .iter()
113            .filter_map(|tool| {
114                let func = tool.get("function")?;
115                let name = func.get("name")?.as_str()?;
116                let description = func.get("description")
117                    .and_then(Value::as_str)
118                    .unwrap_or("");
119                let input_schema = func.get("parameters")
120                    .cloned()
121                    .unwrap_or_else(|| json!({"type": "object"}));
122                Some(json!({
123                    "name": name,
124                    "description": description,
125                    "input_schema": input_schema,
126                }))
127            })
128            .collect()
129    }
130
131    fn build_body(
132        messages: &[ChatMessage],
133        tools: &[Value],
134        model: &str,
135    ) -> Value {
136        let (system_prompt, anthropic_messages) = Self::convert_messages(messages);
137        let anthropic_tools = Self::convert_tools(tools);
138
139        let mut body = json!({
140            "model": model,
141            "max_tokens": 8192,
142            "messages": anthropic_messages,
143        });
144
145        if !anthropic_tools.is_empty() {
146            if let Some(obj) = body.as_object_mut() {
147                obj.insert("tools".to_string(), json!(anthropic_tools));
148            }
149        }
150
151        if let Some(system) = system_prompt {
152            if let Some(obj) = body.as_object_mut() {
153                obj.insert("system".to_string(), json!(system));
154            }
155        }
156
157        body
158    }
159
160    fn parse_sse(data_str: &str, event_type: &str) -> AgentResult<StreamChunk> {
161        if data_str.is_empty() {
162            return Ok(StreamChunk::Text(String::new()));
163        }
164
165        let data: Value = serde_json::from_str(data_str)
166            .map_err(|e| AgentError::json(format!("Anthropic SSE JSON: {e}")))?;
167
168        match event_type {
169            "message_start" => {
170                let input_tokens = data
171                    .get("message")
172                    .and_then(|m| m.get("usage"))
173                    .and_then(|u| u.get("input_tokens"))
174                    .and_then(Value::as_u64)
175                    .map(|v| v as u32);
176                let output_tokens = data
177                    .get("message")
178                    .and_then(|m| m.get("usage"))
179                    .and_then(|u| u.get("output_tokens"))
180                    .and_then(Value::as_u64)
181                    .map(|v| v as u32);
182                Ok(StreamChunk::Usage(UsageInfo {
183                    prompt_tokens: input_tokens,
184                    completion_tokens: output_tokens,
185                    total_tokens: None,
186                }))
187            }
188            "content_block_start" => {
189                let cb = data.get("content_block");
190                let idx = data.get("index").and_then(Value::as_u64).unwrap_or(0);
191                if let Some(cb) = cb {
192                    if cb.get("type").and_then(Value::as_str) == Some("tool_use") {
193                        let id = cb.get("id").and_then(Value::as_str).unwrap_or("").to_string();
194                        let name = cb.get("name").and_then(Value::as_str).unwrap_or("").to_string();
195                        return Ok(StreamChunk::ToolCall(json!({
196                            "delta": {
197                                "tool_calls": [{
198                                    "index": idx,
199                                    "id": if id.is_empty() { Value::Null } else { json!(id) },
200                                    "function": {
201                                        "name": name,
202                                        "arguments": "",
203                                    }
204                                }]
205                            }
206                        })));
207                    }
208                }
209                Ok(StreamChunk::Text(String::new()))
210            }
211            "content_block_delta" => {
212                let delta = data.get("delta");
213                let idx = data.get("index").and_then(Value::as_u64).unwrap_or(0);
214                if let Some(d) = delta {
215                    match d.get("type").and_then(Value::as_str) {
216                        Some("text_delta") => {
217                            let text = d.get("text").and_then(Value::as_str).unwrap_or("").to_string();
218                            Ok(StreamChunk::Text(text))
219                        }
220                        Some("input_json_delta") => {
221                            let partial = d.get("partial_json").and_then(Value::as_str).unwrap_or("").to_string();
222                            Ok(StreamChunk::ToolCall(json!({
223                                "delta": {
224                                    "tool_calls": [{
225                                        "index": idx,
226                                        "function": {
227                                            "arguments": partial,
228                                        }
229                                    }]
230                                }
231                            })))
232                        }
233                        Some("thinking_delta") => {
234                            let thinking = d.get("thinking").and_then(Value::as_str).unwrap_or("").to_string();
235                            Ok(StreamChunk::Thought(thinking))
236                        }
237                        _ => Ok(StreamChunk::Text(String::new())),
238                    }
239                } else {
240                    Ok(StreamChunk::Text(String::new()))
241                }
242            }
243            "content_block_stop" => Ok(StreamChunk::Text(String::new())),
244            "message_delta" => {
245                let output_tokens = data
246                    .get("usage")
247                    .and_then(|u| u.get("output_tokens"))
248                    .and_then(Value::as_u64)
249                    .map(|v| v as u32);
250                Ok(StreamChunk::Usage(UsageInfo {
251                    prompt_tokens: None,
252                    completion_tokens: output_tokens,
253                    total_tokens: None,
254                }))
255            }
256            "message_stop" => Ok(StreamChunk::Stop),
257            "ping" => Ok(StreamChunk::Text(String::new())),
258            _ => Ok(StreamChunk::Text(String::new())),
259        }
260    }
261}
262
263#[async_trait]
264impl LlmClient for AnthropicClient {
265    async fn chat(
266        &self,
267        messages: &[ChatMessage],
268        tools: &[Value],
269        _enable_thinking: Option<bool>,
270        _response_format: Option<&ResponseFormat>,
271    ) -> AgentResult<Value> {
272        let url = format!("{}/v1/messages", self.base_url);
273        let body = Self::build_body(messages, tools, &self.model);
274
275        let response = self
276            .client
277            .post(&url)
278            .header("x-api-key", &self.api_key)
279            .header("anthropic-version", "2023-06-01")
280            .header("Content-Type", "application/json")
281            .json(&body)
282            .send()
283            .await
284            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
285
286        let status = response.status();
287        let res_json: Value = response.json().await
288            .map_err(|e| AgentError::json(format!("Response JSON parse failed: {e}")))?;
289
290        if !status.is_success() {
291            let err_msg = res_json
292                .get("error")
293                .and_then(|e| e.get("message"))
294                .and_then(Value::as_str)
295                .unwrap_or("unknown error");
296            return Err(AgentError::LlmApi {
297                message: err_msg.to_string(),
298            });
299        }
300
301        Ok(res_json)
302    }
303
304    async fn chat_stream(
305        &self,
306        messages: &[ChatMessage],
307        tools: &[Value],
308        _enable_thinking: Option<bool>,
309        _response_format: Option<&ResponseFormat>,
310    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamChunk>> + Send>>> {
311        let url = format!("{}/v1/messages", self.base_url);
312        let mut body = Self::build_body(messages, tools, &self.model);
313
314        if let Some(obj) = body.as_object_mut() {
315            obj.insert("stream".to_string(), json!(true));
316        }
317
318        let response = self
319            .client
320            .post(&url)
321            .header("x-api-key", &self.api_key)
322            .header("anthropic-version", "2023-06-01")
323            .header("Content-Type", "application/json")
324            .json(&body)
325            .send()
326            .await
327            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;
328
329        if !response.status().is_success() {
330            let err_text = response.text().await
331                .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
332            return Err(AgentError::LlmApi { message: err_text });
333        }
334
335        let stream = response
336            .bytes_stream()
337            .eventsource()
338            .filter_map(|event| async move {
339                match event {
340                    Ok(ref ev) if ev.event == "error" => {
341                        let err_msg = ev.data.clone();
342                        Some(Err(AgentError::LlmApi { message: err_msg }))
343                    }
344                    Ok(ev) => {
345                        let event_type = if ev.event.is_empty() { "message_stop" } else { ev.event.as_str() };
346                        match Self::parse_sse(&ev.data, event_type) {
347                            Ok(chunk) => Some(Ok(chunk)),
348                            Err(e) => Some(Err(e)),
349                        }
350                    }
351                    Err(e) => Some(Err(AgentError::LlmStream(format!("SSE Stream error: {e}")))),
352                }
353            });
354
355        Ok(Box::pin(stream))
356    }
357
358    fn capabilities(&self) -> LlmCapabilities {
359        LlmCapabilities {
360            supports_streaming: true,
361            supports_tools: true,
362            supports_vision: true,
363            supports_thinking: true,
364            max_context_tokens: Some(200_000),
365            max_output_tokens: Some(8_192),
366        }
367    }
368}