1use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// `max_tokens` value sent by `build_request` when the caller passes `max_tokens: None`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
24
/// Provider driver that speaks the Anthropic Messages API style.
///
/// Holds only the provider identifier and the capability list it advertises; request
/// building and response/stream parsing read nothing else from `self`.
#[derive(Debug)]
pub struct AnthropicDriver {
    /// Identifier returned by `ProviderDriver::provider_id`.
    provider_id: String,
    /// Capabilities returned by `ProviderDriver::supported_capabilities`.
    capabilities: Vec<Capability>,
}
31
32impl AnthropicDriver {
33 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34 Self {
35 provider_id: provider_id.into(),
36 capabilities,
37 }
38 }
39
40 fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43 let mut system_parts: Vec<String> = Vec::new();
44 let mut user_messages: Vec<Value> = Vec::new();
45
46 for m in messages {
47 match m.role {
48 MessageRole::System => {
49 if let MessageContent::Text(ref s) = m.content {
50 system_parts.push(s.clone());
51 }
52 }
53 MessageRole::Tool => {
54 if let (Some(ref id), MessageContent::Text(ref s)) =
56 (&m.tool_call_id, &m.content)
57 {
58 user_messages.push(serde_json::json!({
59 "role": "user",
60 "content": [{ "type": "tool_result", "tool_use_id": id, "content": s }],
61 }));
62 }
63 }
64 _ => {
65 let role = match m.role {
66 MessageRole::User => "user",
67 MessageRole::Assistant => "assistant",
68 MessageRole::System => unreachable!(),
69 MessageRole::Tool => unreachable!(),
70 };
71 let content = match &m.content {
72 MessageContent::Text(s) => {
73 serde_json::json!([{ "type": "text", "text": s }])
74 }
75 MessageContent::Blocks(_) => {
76 serde_json::to_value(&m.content).unwrap_or(Value::Null)
77 }
78 };
79 user_messages.push(serde_json::json!({
80 "role": role,
81 "content": content,
82 }));
83 }
84 }
85 }
86
87 let system = if system_parts.is_empty() {
88 None
89 } else {
90 Some(system_parts.join("\n\n"))
91 };
92
93 (system, user_messages)
94 }
95}
96
97#[async_trait]
98impl ProviderDriver for AnthropicDriver {
99 fn provider_id(&self) -> &str {
100 &self.provider_id
101 }
102
103 fn api_style(&self) -> ApiStyle {
104 ApiStyle::AnthropicMessages
105 }
106
107 fn build_request(
108 &self,
109 messages: &[Message],
110 model: &str,
111 temperature: Option<f64>,
112 max_tokens: Option<u32>,
113 stream: bool,
114 extra: Option<&Value>,
115 ) -> Result<DriverRequest, Error> {
116 let (system, msgs) = Self::split_system_messages(messages);
117
118 let mut body = serde_json::json!({
119 "model": model,
120 "messages": msgs,
121 "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
122 "stream": stream,
123 });
124
125 if let Some(sys) = system {
126 body["system"] = Value::String(sys);
127 }
128 if let Some(t) = temperature {
129 body["temperature"] = serde_json::json!(t);
130 }
131 if let Some(Value::Object(map)) = extra {
132 for (k, v) in map {
133 body[k] = v.clone();
134 }
135 }
136
137 let mut headers = HashMap::new();
138 headers.insert("anthropic-version".into(), "2023-06-01".into());
139
140 Ok(DriverRequest {
141 url: String::new(),
142 method: "POST".into(),
143 headers,
144 body,
145 stream,
146 })
147 }
148
149 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
150 let content = body
152 .pointer("/content/0/text")
153 .and_then(|v| v.as_str())
154 .map(String::from);
155
156 let finish_reason = body
158 .get("stop_reason")
159 .and_then(|v| v.as_str())
160 .map(|r| match r {
161 "end_turn" => "stop".to_string(),
162 "max_tokens" => "length".to_string(),
163 "tool_use" => "tool_calls".to_string(),
164 other => other.to_string(),
165 });
166
167 let usage = body.get("usage").map(|u| UsageInfo {
168 prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
169 completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
170 total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
171 + u["output_tokens"].as_u64().unwrap_or(0),
172 reasoning_tokens: None,
173 cache_read_tokens: u.get("cache_read_input_tokens").and_then(|v| v.as_u64()),
174 cache_creation_tokens: u
175 .get("cache_creation_input_tokens")
176 .and_then(|v| v.as_u64()),
177 });
178
179 let tool_calls: Vec<Value> = body
181 .get("content")
182 .and_then(|c| c.as_array())
183 .map(|arr| {
184 arr.iter()
185 .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
186 .cloned()
187 .collect()
188 })
189 .unwrap_or_default();
190
191 Ok(DriverResponse {
192 content,
193 finish_reason,
194 usage,
195 tool_calls,
196 raw: body.clone(),
197 })
198 }
199
200 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
201 if data.trim().is_empty() {
202 return Ok(None);
203 }
204
205 let v: Value = serde_json::from_str(data).map_err(|e| {
206 Error::Protocol(ProtocolError::ValidationError(format!(
207 "Failed to parse Anthropic SSE: {}",
208 e
209 )))
210 })?;
211
212 let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
213
214 match event_type {
215 "content_block_delta" => {
216 if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
217 if !text.is_empty() {
218 return Ok(Some(StreamingEvent::PartialContentDelta {
219 content: text.to_string(),
220 sequence_id: v.get("index").and_then(|i| i.as_u64()),
221 }));
222 }
223 }
224 if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
226 return Ok(Some(StreamingEvent::ThinkingDelta {
227 thinking: thinking.to_string(),
228 tool_consideration: None,
229 }));
230 }
231 Ok(None)
232 }
233 "message_delta" => {
234 let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
235 if let Some(r) = reason {
236 return Ok(Some(StreamingEvent::StreamEnd {
237 finish_reason: Some(match r {
238 "end_turn" => "stop".to_string(),
239 "max_tokens" => "length".to_string(),
240 other => other.to_string(),
241 }),
242 }));
243 }
244 Ok(None)
245 }
246 "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
247 finish_reason: Some("stop".into()),
248 })),
249 "error" => {
250 let error = v.get("error").cloned().unwrap_or(Value::Null);
251 Ok(Some(StreamingEvent::StreamError {
252 error,
253 event_id: None,
254 }))
255 }
256 _ => Ok(None),
257 }
258 }
259
260 fn supported_capabilities(&self) -> &[Capability] {
261 &self.capabilities
262 }
263
264 fn is_stream_done(&self, _data: &str) -> bool {
265 false
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_message_extraction() {
        let msgs = vec![Message::system("You are helpful."), Message::user("Hi")];
        let (sys, user_msgs) = AnthropicDriver::split_system_messages(&msgs);
        assert_eq!(sys.as_deref(), Some("You are helpful."));
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0]["role"], "user");
    }

    #[test]
    fn test_multiple_system_messages_are_joined() {
        let msgs = vec![
            Message::system("Rule one."),
            Message::system("Rule two."),
            Message::user("Hi"),
        ];
        let (sys, user_msgs) = AnthropicDriver::split_system_messages(&msgs);
        assert_eq!(sys.as_deref(), Some("Rule one.\n\nRule two."));
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0]["content"][0]["type"], "text");
        assert_eq!(user_msgs[0]["content"][0]["text"], "Hi");
    }

    #[test]
    fn test_no_system_message_yields_none() {
        let msgs = vec![Message::user("Hi")];
        let (sys, user_msgs) = AnthropicDriver::split_system_messages(&msgs);
        assert!(sys.is_none());
        assert_eq!(user_msgs.len(), 1);
    }

    #[test]
    fn test_anthropic_build_request() {
        let driver = AnthropicDriver::new("anthropic", vec![Capability::Text]);
        let messages = vec![Message::user("Hello")];
        let req = driver
            .build_request(
                &messages,
                "claude-sonnet-4-20250514",
                None,
                Some(1024),
                false,
                None,
            )
            .unwrap();
        assert_eq!(req.body["max_tokens"], 1024);
        assert_eq!(req.body["model"], "claude-sonnet-4-20250514");
        assert!(req.headers.contains_key("anthropic-version"));
    }

    #[test]
    fn test_build_request_defaults_and_overrides() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let messages = vec![Message::system("Be brief."), Message::user("Hello")];
        let extra = serde_json::json!({ "top_k": 5 });
        let req = driver
            .build_request(&messages, "claude-test", Some(0.5), None, true, Some(&extra))
            .unwrap();
        assert_eq!(req.body["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(req.body["system"], "Be brief.");
        assert_eq!(req.body["temperature"], 0.5);
        assert_eq!(req.body["stream"], true);
        assert_eq!(req.body["top_k"], 5);
        assert_eq!(req.body["messages"].as_array().map(Vec::len), Some(1));
        assert!(req.stream);
    }

    #[test]
    fn test_anthropic_parse_response() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let body = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Hello!"));
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.unwrap().total_tokens, 15);
    }

    #[test]
    fn test_parse_response_usage_and_tool_calls() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Checking."},
                {"type": "tool_use", "id": "toolu_1", "name": "get_weather", "input": {"city": "Paris"}}
            ],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 7, "output_tokens": 3, "cache_read_input_tokens": 2}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Checking."));
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0]["name"], "get_weather");
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 7);
        assert_eq!(usage.completion_tokens, 3);
        assert_eq!(usage.total_tokens, 10);
        assert_eq!(usage.cache_read_tokens, Some(2));
        assert_eq!(usage.cache_creation_tokens, None);
    }

    #[test]
    fn test_anthropic_parse_stream_delta() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let data =
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;
        let event = driver.parse_stream_event(data).unwrap();
        match event {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "Hi");
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    #[test]
    fn test_parse_stream_blank_unknown_and_invalid() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        assert!(driver.parse_stream_event("   ").unwrap().is_none());
        assert!(driver.parse_stream_event(r#"{"type":"ping"}"#).unwrap().is_none());
        assert!(driver.parse_stream_event("not json").is_err());
    }

    #[test]
    fn test_parse_stream_thinking_delta() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"Let me think"}}"#;
        match driver.parse_stream_event(data).unwrap() {
            Some(StreamingEvent::ThinkingDelta { thinking, .. }) => {
                assert_eq!(thinking, "Let me think");
            }
            _ => panic!("Expected ThinkingDelta"),
        }
    }

    #[test]
    fn test_parse_stream_max_tokens_maps_to_length() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let data = r#"{"type":"message_delta","delta":{"stop_reason":"max_tokens"},"usage":{"output_tokens":5}}"#;
        match driver.parse_stream_event(data).unwrap() {
            Some(StreamingEvent::StreamEnd { finish_reason }) => {
                assert_eq!(finish_reason.as_deref(), Some("length"));
            }
            _ => panic!("Expected StreamEnd"),
        }
    }

    #[test]
    fn test_parse_stream_error_event() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let data =
            r#"{"type":"error","error":{"type":"overloaded_error","message":"Overloaded"}}"#;
        match driver.parse_stream_event(data).unwrap() {
            Some(StreamingEvent::StreamError { error, .. }) => {
                assert_eq!(error["type"], "overloaded_error");
            }
            _ => panic!("Expected StreamError"),
        }
    }

    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let body = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.finish_reason.as_deref(), Some("tool_calls"));
    }
}