// limit_llm/client.rs

1use crate::error::LlmError;
2use crate::providers::{LlmProvider, ProviderResponseChunk};
3
4use crate::types::{Message, Tool, Usage};
5use async_stream::stream;
6use futures::{Stream, StreamExt};
7use reqwest::Client;
8use serde_json::Value;
9use std::boxed::Box;
10use std::pin::Pin;
11use std::time::Duration;
12use tracing::{debug, error, info, instrument, trace, warn};
13
14pub struct AnthropicClient {
15    api_key: String,
16    client: Client,
17    base_url: String,
18    model: String,
19    max_tokens: u32,
20}
21
/// A single `data` payload pulled out of the SSE byte stream by `parse_sse_line`.
#[derive(Debug)]
struct SseEvent {
    /// Value of the `data` field with surrounding whitespace trimmed.
    data: String,
}
26
27impl Clone for AnthropicClient {
28    fn clone(&self) -> Self {
29        Self {
30            api_key: self.api_key.clone(),
31            client: self.client.clone(),
32            base_url: self.base_url.clone(),
33            model: self.model.clone(),
34            max_tokens: self.max_tokens,
35        }
36    }
37}
38
39#[async_trait::async_trait]
40impl LlmProvider for AnthropicClient {
41    async fn send(
42        &self,
43        messages: Vec<Message>,
44        tools: Vec<Tool>,
45    ) -> Result<
46        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
47        LlmError,
48    > {
49        Ok(self.send(messages, tools).await)
50    }
51
52    fn provider_name(&self) -> &str {
53        "anthropic"
54    }
55
56    fn model_name(&self) -> &str {
57        &self.model
58    }
59
60    fn clone_box(&self) -> Box<dyn LlmProvider> {
61        Box::new(self.clone())
62    }
63}
64
65impl AnthropicClient {
66    pub fn new(
67        api_key: String,
68        base_url: Option<&str>,
69        timeout: u64,
70        model: &str,
71        max_tokens: u32,
72    ) -> Self {
73        let client = Client::builder()
74            .timeout(Duration::from_secs(timeout))
75            .connect_timeout(Duration::from_secs(30))
76            .build()
77            .expect("Failed to build HTTP client");
78
79        Self {
80            api_key,
81            client,
82            base_url: base_url
83                .unwrap_or("https://api.anthropic.com/v1/messages")
84                .to_string(),
85            model: model.to_string(),
86            max_tokens,
87        }
88    }
89
90    #[instrument(skip(self, messages, tools))]
91    pub async fn send(
92        &self,
93        messages: Vec<Message>,
94        tools: Vec<Tool>,
95    ) -> Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>> {
96        let api_key = self.api_key.clone();
97        let base_url = self.base_url.clone();
98        let model = self.model.clone();
99        let max_tokens = self.max_tokens;
100        let messages_cloned = messages.clone();
101        let tools_cloned = tools.clone();
102        let client_clone = self.client.clone();
103
104        Box::pin(stream! {
105            info!("API request: model={}, max_tokens={}", self.model, self.max_tokens);
106
107            let request_body = match build_request_body(&messages_cloned, &tools_cloned, &model, max_tokens) {
108                Ok(body) => body,
109                Err(e) => {
110                    error!("API error: {}", e);
111                    yield Err(e);
112                    return;
113                }
114            };
115
116            for attempt in 0..3 {
117                let delay = Duration::from_secs(2_u64.pow(attempt));
118
119                match do_request(&client_clone, &api_key, &base_url, &request_body).await {
120                    Ok(mut stream) => {
121                        while let Some(chunk) = stream.next().await {
122                            yield chunk;
123                        }
124                        return;
125                    }
126                    Err(e) => {
127                        if attempt == 2 {
128                            error!("API error: {}", e);
129                            yield Err(e);
130                            return;
131                        }
132                        warn!("API retry: attempt={}, delay_ms={}", attempt, delay.as_millis());
133                        tokio::time::sleep(delay).await;
134                    }
135                }
136            }
137        })
138    }
139}
140
141#[instrument(skip_all)]
142#[allow(clippy::type_complexity)]
143async fn do_request(
144    client: &Client,
145    api_key: &str,
146    base_url: &str,
147    request_body: &Value,
148) -> Result<
149    Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + 'static>>,
150    LlmError,
151> {
152    let response = client
153        .post(base_url)
154        .header("x-api-key", api_key)
155        .header("anthropic-version", "2023-06-01")
156        .header("content-type", "application/json")
157        .json(request_body)
158        .send()
159        .await
160        .map_err(|e| LlmError::NetworkError(e.to_string()))?;
161
162    let status = response.status();
163    debug!("API response received: status={}", status.as_u16());
164
165    if status.is_client_error() || status.is_server_error() {
166        let error_text = response
167            .text()
168            .await
169            .unwrap_or_else(|_| "Unknown error".to_string());
170
171        if status.as_u16() == 429 {
172            error!("API error: Rate limited");
173            return Err(LlmError::ApiError(format!("Rate limited: {}", error_text)));
174        }
175        error!("API error: HTTP {}: {}", status, error_text);
176        return Err(LlmError::ApiError(format!(
177            "HTTP {}: {}",
178            status, error_text
179        )));
180    }
181
182    let byte_stream = response.bytes_stream();
183    let stream = parse_sse_stream(byte_stream);
184    Ok(stream)
185}
186
187fn build_request_body(
188    messages: &[Message],
189    tools: &[Tool],
190    model: &str,
191    max_tokens: u32,
192) -> Result<Value, LlmError> {
193    let cache_count = messages
194        .iter()
195        .filter(|m| m.cache_control.is_some())
196        .count();
197    if cache_count > 0 {
198        debug!(
199            "Anthropic request has {} messages with cache_control",
200            cache_count
201        );
202        for m in messages.iter().filter(|m| m.cache_control.is_some()) {
203            if let Some(cc) = &m.cache_control {
204                debug!(
205                    "  - role={:?}, type={}, ttl={:?}",
206                    m.role, cc.cache_type, cc.ttl
207                );
208            }
209        }
210    }
211
212    let mut request = serde_json::json!({
213        "model": model,
214        "max_tokens": max_tokens,
215        "messages": messages,
216        "stream": true
217    });
218
219    if !tools.is_empty() {
220        request["tools"] = serde_json::to_value(tools)
221            .map_err(|e| LlmError::ApiError(format!("Failed to serialize tools: {}", e)))?;
222    }
223
224    Ok(request)
225}
226/// Parse potentially incomplete JSON during streaming.
227/// Returns empty object if parsing fails.
228fn parse_partial_json(json: &str) -> serde_json::Value {
229    if json.trim().is_empty() {
230        return serde_json::json!({});
231    }
232
233    // Try standard parsing first
234    if let Ok(value) = serde_json::from_str::<serde_json::Value>(json) {
235        return value;
236    }
237
238    // If parsing fails, return empty object
239    // (In future, could use partial-json crate for better handling)
240    serde_json::json!({})
241}
242
243fn parse_sse_stream(
244    byte_stream: impl Stream<Item = reqwest::Result<bytes::Bytes>> + Send + Unpin + 'static,
245) -> Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + 'static>> {
246    Box::pin(stream! {
247            let mut buffer = String::new();
248            let mut tool_calls_by_id: std::collections::HashMap<u64, (String, String)> = std::collections::HashMap::new();
249
250            let mut tool_partial_json: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
251
252            let mut lines = byte_stream
253                .map(|chunk| chunk.map_err(|e| LlmError::NetworkError(e.to_string())));
254
255            while let Some(chunk_result) = lines.next().await {
256                let chunk = match chunk_result {
257                    Ok(c) => c,
258                    Err(e) => {
259                        yield Err(e);
260                        continue;
261                    }
262                };
263
264                let text = String::from_utf8_lossy(&chunk);
265
266                buffer.push_str(&text);
267
268                while let Some(event) = parse_sse_line(&mut buffer) {
269
270                    if event.data == "[DONE]" {
271                        return;
272                    }
273
274                    if let Ok(parsed) = serde_json::from_str::<Value>(&event.data) {
275                        trace!("SSE: {}", &event.data.chars().take(200).collect::<String>());
276                        let chunk_type = parsed.get("type").and_then(|v| v.as_str()).unwrap_or("");
277
278                        match chunk_type {
279                            "content_block_delta" => {
280                                if let Some(delta) = parsed.get("delta") {
281                                    // Handle text deltas
282                                    if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
283                                        yield Ok(ProviderResponseChunk::ContentDelta(text.to_string()));
284                                    }
285
286                                    // Handle tool argument deltas (input_json_delta)
287                                    let delta_type = delta.get("type").and_then(|v| v.as_str());
288                                    if delta_type == Some("input_json_delta") {
289                                        if let Some(partial_json) = delta.get("partial_json").and_then(|v| v.as_str()) {
290                                            // Get tool call index
291                                            if let Some(index) = parsed.get("index").and_then(|v| v.as_u64()) {
292                                                // Accumulate partial JSON
293    tool_partial_json.entry(index)
294                                                    .or_default()
295                                                    .push_str(partial_json);
296
297                                                // Look up tool call metadata and parse accumulated JSON
298                                                if let Some((id, name)) = tool_calls_by_id.get(&index) {
299                                                    let accumulated = tool_partial_json.get(&index).unwrap();
300                                                    let args = parse_partial_json(accumulated);
301
302                                                    yield Ok(ProviderResponseChunk::ToolCallDelta {
303                                                        id: id.clone(),
304                                                        name: name.clone(),
305                                                        arguments: args,
306                                                    });
307                                                }
308                                            }
309                                        }
310                                    }
311                                }
312                            }
313                            "content_block_start" => {
314                                if let Some(content_block) = parsed.get("content_block") {
315                                    let block_type = content_block.get("type").and_then(|v| v.as_str());
316                                    if block_type == Some("tool_use") {
317                                        let id = content_block.get("id")
318                                            .and_then(|v| v.as_str())
319                                            .unwrap_or("")
320                                            .to_string();
321                                        let name = content_block.get("name")
322                                            .and_then(|v| v.as_str())
323                                            .unwrap_or("")
324                                            .to_string();
325
326                                        // Track tool call by index
327                                        if let Some(index) = parsed.get("index").and_then(|v| v.as_u64()) {
328                                            tool_calls_by_id.insert(index, (id.clone(), name.clone()));
329                                        }
330
331                                        yield Ok(ProviderResponseChunk::ToolCallDelta {
332                                            id,
333                                            name,
334                                            arguments: serde_json::json!({}),
335                                        });
336                                    }
337                                }
338                            }
339                            "content_block_stop" => {
340                                // Tool call completed
341                            }
342                            "message_delta" => {
343                                if let Some(delta) = parsed.get("delta") {
344                                    if let Some(stop_reason) = delta.get("stop_reason").and_then(|v| v.as_str()) {
345                                        debug!("stop_reason: {}", stop_reason);
346                                        if stop_reason == "end_turn" || stop_reason == "tool_use" {
347                                            if let Some(usage) = parsed.get("usage") {
348                                                if let Ok(usage_obj) = serde_json::from_value::<Usage>(usage.clone()) {
349                                                    if usage_obj.cache_read_tokens > 0 || usage_obj.cache_write_tokens > 0 {
350                                                        debug!(
351                                                            "Anthropic cache tokens: read={}, write={}",
352                                                            usage_obj.cache_read_tokens, usage_obj.cache_write_tokens
353                                                        );
354                                                    }
355                                                    yield Ok(ProviderResponseChunk::Done(usage_obj));
356                                                    return;
357                                                }
358                                            }
359                                        }
360                                    }
361                                }
362                            }
363                            _ => {
364                                debug!("Unknown chunk_type: {}", chunk_type);
365                            }
366                        }
367                    }
368                }
369            }
370        })
371}
372
373fn parse_sse_line(buffer: &mut String) -> Option<SseEvent> {
374    loop {
375        let newline_pos = buffer.find('\n')?;
376        let line = buffer[..newline_pos].trim().to_string();
377        *buffer = buffer[newline_pos + 1..].to_string();
378
379        // Skip empty lines and comments
380        if line.is_empty() || line.starts_with(':') {
381            continue;
382        }
383
384        // Skip event: lines (we only care about data)
385        if line.starts_with("event:") {
386            continue;
387        }
388
389        // Parse data: lines
390        if let Some(data_pos) = line.find("data: ") {
391            let data = line[data_pos + 6..].trim();
392            return Some(SseEvent {
393                data: data.to_string(),
394            });
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    /// Model name used by every client built in these tests.
    const TEST_MODEL: &str = "claude-3-5-sonnet-20241022";

    /// Builds a client that POSTs to `<server_url>/v1/messages`, the path the mocks listen on.
    fn test_client(server_url: &str) -> AnthropicClient {
        let mut client =
            AnthropicClient::new("test-key".to_string(), None, 300, TEST_MODEL, 4096);
        client.base_url = format!("{}/v1/messages", server_url);
        client
    }

    /// A conversation consisting of one user message with no tool or cache metadata.
    fn user_messages(text: &str) -> Vec<Message> {
        vec![Message {
            role: crate::types::Role::User,
            content: Some(text.to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }]
    }

    #[tokio::test]
    async fn test_streaming() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\" world\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let client = test_client(&server.url());
        let stream = client.send(user_messages("Hello"), vec![]).await;
        let chunks: Vec<_> = stream.collect().await;

        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], Ok(ProviderResponseChunk::ContentDelta(t)) if t == "Hello"));
        assert!(matches!(&chunks[1], Ok(ProviderResponseChunk::ContentDelta(t)) if t == " world"));
        assert!(matches!(&chunks[2], Ok(ProviderResponseChunk::Done(_))));

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_retry_on_429() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(429)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error":{"type":"rate_limit_error","message":"Rate limited"}}"#)
            .expect(2)
            .create_async()
            .await;

        let success_mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":10,\"output_tokens\":5}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .expect(1)
            .create_async()
            .await;

        let client = test_client(&server.url());
        let stream = client.send(user_messages("Hello"), vec![]).await;
        let chunks: Vec<_> = stream.collect().await;

        // After two 429s the third attempt succeeds and its content is streamed.
        assert!(matches!(
            chunks.first(),
            Some(Ok(ProviderResponseChunk::ContentDelta(t))) if t == "Hello"
        ));

        mock.assert_async().await;
        success_mock.assert_async().await;
    }

    /// The server delays the body by 500 ms. That is far below the 300 s client timeout, so the
    /// response must still be streamed normally.
    #[tokio::test]
    async fn test_slow_response_within_timeout() {
        let mut server = Server::new_async().await;
        let _mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                // Sleep to simulate a slow response.
                std::thread::sleep(std::time::Duration::from_millis(500));
                w.write_all(
                    b"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}\n\n",
                )?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let client = test_client(&server.url());
        let stream = client.send(user_messages("Hello"), vec![]).await;
        let chunks: Vec<_> = stream.collect().await;

        assert!(matches!(
            chunks.first(),
            Some(Ok(ProviderResponseChunk::ContentDelta(t))) if t == "Hello"
        ));
    }

    #[tokio::test]
    async fn test_tool_call_streaming() {
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_chunked_body(|w| {
                w.write_all(b"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"test_tool\"}}\n\n")?;
                // Tool input arrives as an `input_json_delta`, the delta type the API sends.
                w.write_all(b"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"arg\\\":\\\"value\\\"}\"}}\n\n")?;
                w.write_all(b"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"input_tokens\":15,\"output_tokens\":20}}\n\n")?;
                Ok::<(), std::io::Error>(())
            })
            .create_async()
            .await;

        let tools = vec![Tool {
            tool_type: "function".to_string(),
            function: crate::types::ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        }];

        let client = test_client(&server.url());
        let stream = client.send(user_messages("Use test_tool"), tools).await;
        let chunks: Vec<_> = stream.collect().await;

        let tool_deltas: Vec<_> = chunks
            .iter()
            .filter_map(|chunk| match chunk {
                Ok(ProviderResponseChunk::ToolCallDelta {
                    id,
                    name,
                    arguments,
                }) => Some((id.as_str(), name.as_str(), arguments)),
                _ => None,
            })
            .collect();

        // One delta when the tool_use block starts, then one per input_json_delta.
        assert_eq!(tool_deltas.len(), 2);
        assert!(tool_deltas
            .iter()
            .all(|(id, name, _)| *id == "toolu_123" && *name == "test_tool"));
        assert_eq!(tool_deltas[0].2, &serde_json::json!({}));
        assert_eq!(tool_deltas[1].2, &serde_json::json!({"arg": "value"}));
        assert!(matches!(
            chunks.last(),
            Some(Ok(ProviderResponseChunk::Done(_)))
        ));

        mock.assert_async().await;
    }

    #[test]
    fn test_parse_sse_line() {
        let mut buffer = String::from("data: {\"type\":\"test\"}\n\nother data");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_some());
        assert_eq!(event.unwrap().data, "{\"type\":\"test\"}");
        assert_eq!(buffer, "\nother data");
    }

    #[test]
    fn test_parse_sse_line_empty() {
        let mut buffer = String::from("\n\ndata: test");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_none());
        assert_eq!(buffer, "data: test");
    }

    #[test]
    fn test_parse_sse_line_comment() {
        let mut buffer = String::from(": comment\n\ndata: test");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_none());
    }

    #[test]
    fn test_parse_sse_line_zai_format() {
        let mut buffer = String::from("event: content_block_start\ndata: {\"type\":\"test\"}\n\n");
        let event = parse_sse_line(&mut buffer);
        assert!(event.is_some());
        assert_eq!(event.unwrap().data, "{\"type\":\"test\"}");
    }
}