Skip to main content

ai_lib_core/drivers/
anthropic.rs

//! Anthropic Messages API driver — Anthropic-specific request/response format conversion.
2//!
3//! Anthropic Messages API driver. Handles the key differences from OpenAI:
4//! - System messages are a top-level `system` parameter, not part of `messages`.
5//! - Content uses typed blocks: `[{"type": "text", "text": "..."}]`.
6//! - Streaming uses `event: content_block_delta` with `delta.text`.
7//! - Response uses `content[0].text` instead of `choices[0].message.content`.
8//! - `max_tokens` is required, not optional.
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// Default `max_tokens` used by `build_request` when the caller passes `None`.
/// Anthropic requires this field, unlike OpenAI where it is optional.
const DEFAULT_MAX_TOKENS: u32 = 4096;

/// Anthropic Messages API driver.
#[derive(Debug)]
pub struct AnthropicDriver {
    // Identifier returned by `ProviderDriver::provider_id`.
    provider_id: String,
    // Capability set returned by `ProviderDriver::supported_capabilities`.
    capabilities: Vec<Capability>,
}
31
32impl AnthropicDriver {
33    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34        Self {
35            provider_id: provider_id.into(),
36            capabilities,
37        }
38    }
39
40    /// Extract system message and non-system messages separately.
41    /// Anthropic requires system as a top-level param, not in messages array.
42    fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43        let mut system_parts: Vec<String> = Vec::new();
44        let mut user_messages: Vec<Value> = Vec::new();
45
46        for m in messages {
47            match m.role {
48                MessageRole::System => {
49                    if let MessageContent::Text(ref s) = m.content {
50                        system_parts.push(s.clone());
51                    }
52                }
53                MessageRole::Tool => {
54                    // Anthropic: tool results are sent as user message with tool_result block
55                    if let (Some(ref id), MessageContent::Text(ref s)) =
56                        (&m.tool_call_id, &m.content)
57                    {
58                        user_messages.push(serde_json::json!({
59                            "role": "user",
60                            "content": [{ "type": "tool_result", "tool_use_id": id, "content": s }],
61                        }));
62                    }
63                }
64                _ => {
65                    let role = match m.role {
66                        MessageRole::User => "user",
67                        MessageRole::Assistant => "assistant",
68                        MessageRole::System => unreachable!(),
69                        MessageRole::Tool => unreachable!(),
70                    };
71                    let content = match &m.content {
72                        MessageContent::Text(s) => {
73                            serde_json::json!([{ "type": "text", "text": s }])
74                        }
75                        MessageContent::Blocks(_) => {
76                            serde_json::to_value(&m.content).unwrap_or(Value::Null)
77                        }
78                    };
79                    user_messages.push(serde_json::json!({
80                        "role": role,
81                        "content": content,
82                    }));
83                }
84            }
85        }
86
87        let system = if system_parts.is_empty() {
88            None
89        } else {
90            Some(system_parts.join("\n\n"))
91        };
92
93        (system, user_messages)
94    }
95}
96
97#[async_trait]
98impl ProviderDriver for AnthropicDriver {
99    fn provider_id(&self) -> &str {
100        &self.provider_id
101    }
102
103    fn api_style(&self) -> ApiStyle {
104        ApiStyle::AnthropicMessages
105    }
106
107    fn build_request(
108        &self,
109        messages: &[Message],
110        model: &str,
111        temperature: Option<f64>,
112        max_tokens: Option<u32>,
113        stream: bool,
114        extra: Option<&Value>,
115    ) -> Result<DriverRequest, Error> {
116        let (system, msgs) = Self::split_system_messages(messages);
117
118        let mut body = serde_json::json!({
119            "model": model,
120            "messages": msgs,
121            "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
122            "stream": stream,
123        });
124
125        if let Some(sys) = system {
126            body["system"] = Value::String(sys);
127        }
128        if let Some(t) = temperature {
129            body["temperature"] = serde_json::json!(t);
130        }
131        if let Some(Value::Object(map)) = extra {
132            for (k, v) in map {
133                body[k] = v.clone();
134            }
135        }
136
137        let mut headers = HashMap::new();
138        headers.insert("anthropic-version".into(), "2023-06-01".into());
139
140        Ok(DriverRequest {
141            url: String::new(),
142            method: "POST".into(),
143            headers,
144            body,
145            stream,
146        })
147    }
148
149    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
150        // Anthropic response: { content: [{type: "text", text: "..."}], stop_reason, usage }
151        let content = body
152            .pointer("/content/0/text")
153            .and_then(|v| v.as_str())
154            .map(String::from);
155
156        // Normalize stop_reason → finish_reason
157        let finish_reason = body
158            .get("stop_reason")
159            .and_then(|v| v.as_str())
160            .map(|r| match r {
161                "end_turn" => "stop".to_string(),
162                "max_tokens" => "length".to_string(),
163                "tool_use" => "tool_calls".to_string(),
164                other => other.to_string(),
165            });
166
167        let usage = body.get("usage").map(|u| UsageInfo {
168            prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
169            completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
170            total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
171                + u["output_tokens"].as_u64().unwrap_or(0),
172            reasoning_tokens: None,
173            cache_read_tokens: u.get("cache_read_input_tokens").and_then(|v| v.as_u64()),
174            cache_creation_tokens: u
175                .get("cache_creation_input_tokens")
176                .and_then(|v| v.as_u64()),
177        });
178
179        // Extract tool_use blocks from content array
180        let tool_calls: Vec<Value> = body
181            .get("content")
182            .and_then(|c| c.as_array())
183            .map(|arr| {
184                arr.iter()
185                    .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
186                    .cloned()
187                    .collect()
188            })
189            .unwrap_or_default();
190
191        Ok(DriverResponse {
192            content,
193            finish_reason,
194            usage,
195            tool_calls,
196            raw: body.clone(),
197        })
198    }
199
200    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
201        if data.trim().is_empty() {
202            return Ok(None);
203        }
204
205        let v: Value = serde_json::from_str(data).map_err(|e| {
206            Error::Protocol(ProtocolError::ValidationError(format!(
207                "Failed to parse Anthropic SSE: {}",
208                e
209            )))
210        })?;
211
212        let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
213
214        match event_type {
215            "content_block_delta" => {
216                if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
217                    if !text.is_empty() {
218                        return Ok(Some(StreamingEvent::PartialContentDelta {
219                            content: text.to_string(),
220                            sequence_id: v.get("index").and_then(|i| i.as_u64()),
221                        }));
222                    }
223                }
224                // Thinking delta support
225                if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
226                    return Ok(Some(StreamingEvent::ThinkingDelta {
227                        thinking: thinking.to_string(),
228                        tool_consideration: None,
229                    }));
230                }
231                Ok(None)
232            }
233            "message_delta" => {
234                let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
235                if let Some(r) = reason {
236                    return Ok(Some(StreamingEvent::StreamEnd {
237                        finish_reason: Some(match r {
238                            "end_turn" => "stop".to_string(),
239                            "max_tokens" => "length".to_string(),
240                            other => other.to_string(),
241                        }),
242                    }));
243                }
244                Ok(None)
245            }
246            "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
247                finish_reason: Some("stop".into()),
248            })),
249            "error" => {
250                let error = v.get("error").cloned().unwrap_or(Value::Null);
251                Ok(Some(StreamingEvent::StreamError {
252                    error,
253                    event_id: None,
254                }))
255            }
256            _ => Ok(None),
257        }
258    }
259
260    fn supported_capabilities(&self) -> &[Capability] {
261        &self.capabilities
262    }
263
264    fn is_stream_done(&self, _data: &str) -> bool {
265        // Anthropic signals done via event type, not a sentinel string.
266        // The `event: message_stop` is handled in parse_stream_event.
267        false
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// System messages must be lifted out of the messages array.
    #[test]
    fn test_system_message_extraction() {
        let input = vec![Message::system("You are helpful."), Message::user("Hi")];
        let (system, remaining) = AnthropicDriver::split_system_messages(&input);
        assert_eq!(system.as_deref(), Some("You are helpful."));
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0]["role"], "user");
    }

    /// The request body carries the model, the required max_tokens, and the
    /// mandatory anthropic-version header.
    #[test]
    fn test_anthropic_build_request() {
        let driver = AnthropicDriver::new("anthropic", vec![Capability::Text]);
        let request = driver
            .build_request(
                &[Message::user("Hello")],
                "claude-sonnet-4-20250514",
                None,
                Some(1024),
                false,
                None,
            )
            .unwrap();
        assert_eq!(request.body["model"], "claude-sonnet-4-20250514");
        assert_eq!(request.body["max_tokens"], 1024);
        assert!(request.headers.contains_key("anthropic-version"));
    }

    /// Text, normalized finish_reason, and summed usage come out of a
    /// standard response body.
    #[test]
    fn test_anthropic_parse_response() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });
        let parsed = driver.parse_response(&payload).unwrap();
        assert_eq!(parsed.content.as_deref(), Some("Hello!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    /// A content_block_delta SSE payload yields a PartialContentDelta.
    #[test]
    fn test_anthropic_parse_stream_delta() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let payload =
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;
        match driver.parse_stream_event(payload).unwrap() {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "Hi");
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    /// Anthropic's "tool_use" stop reason is normalized to "tool_calls".
    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });
        let parsed = driver.parse_response(&payload).unwrap();
        assert_eq!(parsed.finish_reason.as_deref(), Some("tool_calls"));
    }
}