1use crate::error::LlmError;
2use crate::providers::{LlmProvider, ProviderResponseChunk};
3
4use crate::types::{Message, Tool, Usage};
5use async_stream::stream;
6use futures::{Stream, StreamExt};
7use reqwest::Client;
8use serde_json::Value;
9use std::boxed::Box;
10use std::pin::Pin;
11use std::time::Duration;
12use tracing::{debug, error, info, instrument, trace, warn};
13
14pub struct AnthropicClient {
15 api_key: String,
16 client: Client,
17 base_url: String,
18 model: String,
19 max_tokens: u32,
20}
21
22#[derive(Debug)]
23struct SseEvent {
24 data: String,
25}
26
27impl Clone for AnthropicClient {
28 fn clone(&self) -> Self {
29 Self {
30 api_key: self.api_key.clone(),
31 client: self.client.clone(),
32 base_url: self.base_url.clone(),
33 model: self.model.clone(),
34 max_tokens: self.max_tokens,
35 }
36 }
37}
38
39#[async_trait::async_trait]
40impl LlmProvider for AnthropicClient {
41 async fn send(
42 &self,
43 messages: Vec<Message>,
44 tools: Vec<Tool>,
45 ) -> Result<
46 Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
47 LlmError,
48 > {
49 Ok(self.send(messages, tools).await)
50 }
51
52 fn provider_name(&self) -> &str {
53 "anthropic"
54 }
55
56 fn model_name(&self) -> &str {
57 &self.model
58 }
59
60 fn clone_box(&self) -> Box<dyn LlmProvider> {
61 Box::new(self.clone())
62 }
63}
64
65impl AnthropicClient {
66 pub fn new(
67 api_key: String,
68 base_url: Option<&str>,
69 timeout: u64,
70 model: &str,
71 max_tokens: u32,
72 ) -> Self {
73 let client = Client::builder()
74 .timeout(Duration::from_secs(timeout))
75 .connect_timeout(Duration::from_secs(30))
76 .build()
77 .expect("Failed to build HTTP client");
78
79 Self {
80 api_key,
81 client,
82 base_url: base_url
83 .unwrap_or("https://api.anthropic.com/v1/messages")
84 .to_string(),
85 model: model.to_string(),
86 max_tokens,
87 }
88 }
89
90 #[instrument(skip(self, messages, tools))]
91 pub async fn send(
92 &self,
93 messages: Vec<Message>,
94 tools: Vec<Tool>,
95 ) -> Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>> {
96 let api_key = self.api_key.clone();
97 let base_url = self.base_url.clone();
98 let model = self.model.clone();
99 let max_tokens = self.max_tokens;
100 let messages_cloned = messages.clone();
101 let tools_cloned = tools.clone();
102 let client_clone = self.client.clone();
103
104 Box::pin(stream! {
105 info!("API request: model={}, max_tokens={}", self.model, self.max_tokens);
106
107 let request_body = match build_request_body(&messages_cloned, &tools_cloned, &model, max_tokens) {
108 Ok(body) => body,
109 Err(e) => {
110 error!("API error: {}", e);
111 yield Err(e);
112 return;
113 }
114 };
115
116 for attempt in 0..3 {
117 let delay = Duration::from_secs(2_u64.pow(attempt));
118
119 match do_request(&client_clone, &api_key, &base_url, &request_body).await {
120 Ok(mut stream) => {
121 while let Some(chunk) = stream.next().await {
122 yield chunk;
123 }
124 return;
125 }
126 Err(e) => {
127 if attempt == 2 {
128 error!("API error: {}", e);
129 yield Err(e);
130 return;
131 }
132 warn!("API retry: attempt={}, delay_ms={}", attempt, delay.as_millis());
133 tokio::time::sleep(delay).await;
134 }
135 }
136 }
137 })
138 }
139}
140
141#[instrument(skip_all)]
142#[allow(clippy::type_complexity)]
143async fn do_request(
144 client: &Client,
145 api_key: &str,
146 base_url: &str,
147 request_body: &Value,
148) -> Result<
149 Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + 'static>>,
150 LlmError,
151> {
152 let response = client
153 .post(base_url)
154 .header("x-api-key", api_key)
155 .header("anthropic-version", "2023-06-01")
156 .header("content-type", "application/json")
157 .json(request_body)
158 .send()
159 .await
160 .map_err(|e| LlmError::NetworkError(e.to_string()))?;
161
162 let status = response.status();
163 debug!("API response received: status={}", status.as_u16());
164
165 if status.is_client_error() || status.is_server_error() {
166 let error_text = response
167 .text()
168 .await
169 .unwrap_or_else(|_| "Unknown error".to_string());
170
171 if status.as_u16() == 429 {
172 error!("API error: Rate limited");
173 return Err(LlmError::ApiError(format!("Rate limited: {}", error_text)));
174 }
175 error!("API error: HTTP {}: {}", status, error_text);
176 return Err(LlmError::ApiError(format!(
177 "HTTP {}: {}",
178 status, error_text
179 )));
180 }
181
182 let byte_stream = response.bytes_stream();
183 let stream = parse_sse_stream(byte_stream);
184 Ok(stream)
185}
186
187fn build_request_body(
188 messages: &[Message],
189 tools: &[Tool],
190 model: &str,
191 max_tokens: u32,
192) -> Result<Value, LlmError> {
193 let cache_count = messages
194 .iter()
195 .filter(|m| m.cache_control.is_some())
196 .count();
197 if cache_count > 0 {
198 debug!(
199 "Anthropic request has {} messages with cache_control",
200 cache_count
201 );
202 for m in messages.iter().filter(|m| m.cache_control.is_some()) {
203 if let Some(cc) = &m.cache_control {
204 debug!(
205 " - role={:?}, type={}, ttl={:?}",
206 m.role, cc.cache_type, cc.ttl
207 );
208 }
209 }
210 }
211
212 let mut request = serde_json::json!({
213 "model": model,
214 "max_tokens": max_tokens,
215 "messages": messages,
216 "stream": true
217 });
218
219 if !tools.is_empty() {
220 request["tools"] = serde_json::to_value(tools)
221 .map_err(|e| LlmError::ApiError(format!("Failed to serialize tools: {}", e)))?;
222 }
223
224 Ok(request)
225}
226fn parse_partial_json(json: &str) -> serde_json::Value {
229 if json.trim().is_empty() {
230 return serde_json::json!({});
231 }
232
233 if let Ok(value) = serde_json::from_str::<serde_json::Value>(json) {
235 return value;
236 }
237
238 serde_json::json!({})
241}
242
243fn parse_sse_stream(
244 byte_stream: impl Stream<Item = reqwest::Result<bytes::Bytes>> + Send + Unpin + 'static,
245) -> Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + 'static>> {
246 Box::pin(stream! {
247 let mut buffer = String::new();
248 let mut tool_calls_by_id: std::collections::HashMap<u64, (String, String)> = std::collections::HashMap::new();
249
250 let mut tool_partial_json: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
251
252 let mut lines = byte_stream
253 .map(|chunk| chunk.map_err(|e| LlmError::NetworkError(e.to_string())));
254
255 while let Some(chunk_result) = lines.next().await {
256 let chunk = match chunk_result {
257 Ok(c) => c,
258 Err(e) => {
259 yield Err(e);
260 continue;
261 }
262 };
263
264 let text = String::from_utf8_lossy(&chunk);
265
266 buffer.push_str(&text);
267
268 while let Some(event) = parse_sse_line(&mut buffer) {
269
270 if event.data == "[DONE]" {
271 return;
272 }
273
274 if let Ok(parsed) = serde_json::from_str::<Value>(&event.data) {
275 trace!("SSE: {}", &event.data.chars().take(200).collect::<String>());
276 let chunk_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
277
278 match chunk_type {
279 "content_block_delta" => {
280 if let Some(delta) = parsed.get("delta") {
281 if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
283 yield Ok(ProviderResponseChunk::ContentDelta(text.to_string()));
284 }
285
286 let delta_type = delta.get("type").and_then(|v| v.as_str());
288 if delta_type == Some("input_json_delta") {
289 if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) {
290 if let Some(index) = parsed.get("index").and_then(|v| v.as_u64()) {
292 tool_partial_json.entry(index)
294 .or_default()
295 .push_str(partial_json);
296
297 if let Some((id, name)) = tool_calls_by_id.get(&index) {
299 let accumulated = tool_partial_json.get(&index).unwrap();
300 let args = parse_partial_json(accumulated);
301
302 yield Ok(ProviderResponseChunk::ToolCallDelta {
303 id: id.clone(),
304 name: name.clone(),
305 arguments: args,
306 });
307 }
308 }
309 }
310 }
311 }
312 }
313 "content_block_start" => {
314 if let Some(content_block) = parsed.get("content_block") {
315 let block_type = content_block.get("type").and_then(|v| v.as_str());
316 if block_type == Some("tool_use") {
317 let id = content_block.get("id")
318 .and_then(|v| v.as_str())
319 .unwrap_or("")
320 .to_string();
321 let name = content_block.get("name")
322 .and_then(|v| v.as_str())
323 .unwrap_or("")
324 .to_string();
325
326 if let Some(index) = parsed.get("index").and_then(|v| v.as_u64()) {
328 tool_calls_by_id.insert(index, (id.clone(), name.clone()));
329 }
330
331 yield Ok(ProviderResponseChunk::ToolCallDelta {
332 id,
333 name,
334 arguments: serde_json::json!({}),
335 });
336 }
337 }
338 }
339 "content_block_stop" => {
340 }
342 "message_delta" => {
343 if let Some(delta) = parsed.get("delta") {
344 if let Some(stop_reason) = delta.get("stop_reason").and_then(|v| v.as_str()) {
345 debug!("stop_reason: {}", stop_reason);
346 if stop_reason == "end_turn" || stop_reason == "tool_use" {
347 if let Some(usage) = parsed.get("usage") {
348 if let Ok(usage_obj) = serde_json::from_value::<Usage>(usage.clone()) {
349 if usage_obj.cache_read_tokens > 0 || usage_obj.cache_write_tokens > 0 {
350 debug!(
351 "Anthropic cache tokens: read={}, write={}",
352 usage_obj.cache_read_tokens, usage_obj.cache_write_tokens
353 );
354 }
355 yield Ok(ProviderResponseChunk::Done(usage_obj));
356 return;
357 }
358 }
359 }
360 }
361 }
362 }
363 _ => {
364 debug!("Unknown chunk_type: {}", chunk_type);
365 }
366 }
367 }
368 }
369 }
370 })
371}
372
373fn parse_sse_line(buffer: &mut String) -> Option<SseEvent> {
374 loop {
375 let newline_pos = buffer.find('\n')?;
376 let line = buffer[..newline_pos].trim().to_string();
377 *buffer = buffer[newline_pos + 1..].to_string();
378
379 if line.is_empty() || line.starts_with(':') {
381 continue;
382 }
383
384 if line.starts_with("event:") {
386 continue;
387 }
388
389 if let Some(data_pos) = line.find("data: ") {
391 let data = line[data_pos + 6..].trim();
392 return Some(SseEvent {
393 data: data.to_string(),
394 });
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    const TEST_MODEL: &str = "claude-3-5-sonnet-20241022";

    /// Builds a client whose endpoint is `<server_url>/v1/messages`.
    fn client_for(server_url: &str) -> AnthropicClient {
        let base_url = format!("{}/v1/messages", server_url);
        AnthropicClient::new(
            "test-key".to_string(),
            Some(base_url.as_str()),
            300,
            TEST_MODEL,
            4096,
        )
    }

    /// A single user message with the given text and no tool or cache metadata.
    fn user_message(text: &str) -> Message {
        Message {
            role: crate::types::Role::User,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Text of every successful `ContentDelta` chunk, in stream order.
    fn content_deltas(chunks: &[Result<ProviderResponseChunk, LlmError>]) -> Vec<&str> {
        chunks
            .iter()
            .filter_map(|chunk| match chunk {
                Ok(ProviderResponseChunk::ContentDelta(text)) => Some(text.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Whether the final chunk is a successful `Done`.
    fn ends_with_done(chunks: &[Result<ProviderResponseChunk, LlmError>]) -> bool {
        matches!(chunks.last(), Some(Ok(ProviderResponseChunk::Done(_))))
    }

    #[tokio::test]
    async fn test_streaming() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\" world\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let client = client_for(&server.url());
        let chunks: Vec<_> = client
            .send(vec![user_message("Hello")], vec![])
            .await
            .collect()
            .await;

        assert_eq!(content_deltas(&chunks), vec!["Hello", " world"]);
        assert!(ends_with_done(&chunks));

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_retry_on_429() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(429)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":{"type":"rate_limit_error","message":"Rate limited"}}"#)
            .expect(2)
            .create_async()
            .await;

        let success_mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .expect(1)
            .create_async()
            .await;

        let client = client_for(&server.url());
        let chunks: Vec<_> = client
            .send(vec![user_message("Hello")], vec![])
            .await
            .collect()
            .await;

        // Two rate-limited attempts are retried; the third attempt's stream is delivered.
        assert_eq!(content_deltas(&chunks), vec!["Hello"]);
        assert!(ends_with_done(&chunks));

        mock.assert_async().await;
        success_mock.assert_async().await;
    }

    /// The body arrives after a 500 ms delay, well inside the 300 s client timeout, and must
    /// still be delivered in full.
    #[tokio::test]
    async fn test_slow_body_within_timeout() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                std::thread::sleep(std::time::Duration::from_millis(500));
                w.write_all(
                    b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n",
                )?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let client = client_for(&server.url());
        let chunks: Vec<_> = client
            .send(vec![user_message("Hello")], vec![])
            .await
            .collect()
            .await;

        assert_eq!(content_deltas(&chunks), vec!["Hello"]);
    }

    #[tokio::test]
    async fn test_tool_call_streaming() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"test_tool\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"arg\\\":\\\"value\\\"}\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"input_tokens\":15,\"output_tokens\":20}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let tools = vec![Tool {
            tool_type: "function".to_string(),
            function: crate::types::ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        }];

        let client = client_for(&server.url());
        let chunks: Vec<_> = client
            .send(vec![user_message("Use test_tool")], tools)
            .await
            .collect()
            .await;

        let tool_deltas: Vec<(&str, &str, &serde_json::Value)> = chunks
            .iter()
            .filter_map(|chunk| match chunk {
                Ok(ProviderResponseChunk::ToolCallDelta { id, name, arguments }) => {
                    Some((id.as_str(), name.as_str(), arguments))
                }
                _ => None,
            })
            .collect();

        // One delta from content_block_start (empty args), one from the input_json_delta.
        assert_eq!(tool_deltas.len(), 2);
        assert!(tool_deltas
            .iter()
            .all(|(id, name, _)| *id == "toolu_123" && *name == "test_tool"));
        assert_eq!(tool_deltas[0].2, &serde_json::json!({}));
        assert_eq!(tool_deltas[1].2, &serde_json::json!({"arg": "value"}));
        assert!(ends_with_done(&chunks));

        mock.assert_async().await;
    }

    #[test]
    fn test_parse_sse_line() {
        let mut buffer = String::from("data: {\"type\":\"test\"}\n\nother data");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_some());
        assert_eq!(event.unwrap().data, "{\"type\":\"test\"}");
        assert_eq!(buffer, "\nother data");
    }

    #[test]
    fn test_parse_sse_line_empty() {
        let mut buffer = String::from("\n\ndata: test");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_none());
        assert_eq!(buffer, "data: test");
    }

    #[test]
    fn test_parse_sse_line_comment() {
        let mut buffer = String::from(": comment\n\ndata: test");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_none());
    }

    #[test]
    fn test_parse_sse_line_zai_format() {
        let mut buffer = String::from("event: content_block_start\ndata: {\"type\":\"test\"}\n\n");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_some());
        assert_eq!(event.unwrap().data, "{\"type\":\"test\"}");
    }
}