Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10    ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11    DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12    THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19    StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23    api_key: String,
24    model: String,
25    base_url: String,
26    client: reqwest::Client,
27    /// Extra headers from config
28    extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32    pub fn new(api_key: String, model: String, base_url: String) -> Self {
33        Self::with_headers(api_key, model, base_url, None)
34    }
35
36    /// Load proxy from environment variables
37    fn load_proxy_from_env() -> Option<String> {
38        std::env::var("HTTPS_PROXY")
39            .ok()
40            .or_else(|| std::env::var("https_proxy").ok())
41            .or_else(|| std::env::var("HTTP_PROXY").ok())
42            .or_else(|| std::env::var("http_proxy").ok())
43    }
44
45    pub fn with_headers(
46        api_key: String,
47        model: String,
48        base_url: String,
49        extra_headers: Option<std::collections::HashMap<String, String>>,
50    ) -> Self {
51        // Create client with better timeout handling for streaming
52        // - No total timeout (streaming responses can take a long time)
53        // - Connect timeout: DEFAULT_CONNECT_TIMEOUT_SECS seconds
54        // - Read timeout per chunk: DEFAULT_READ_TIMEOUT_SECS seconds (for slow responses between chunks)
55        let mut client_builder = reqwest::Client::builder()
56            .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57            .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58            .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); // Total timeout for non-streaming
59
60        // Add proxy from environment if available
61        if let Some(proxy_url) = Self::load_proxy_from_env()
62            && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63        {
64            log::info!("AnthropicProvider using proxy: {}", proxy_url);
65            client_builder = client_builder.proxy(proxy);
66        }
67
68        let client = client_builder
69            .build()
70            .unwrap_or_else(|_| reqwest::Client::new());
71        let extra_headers: Vec<(String, String)> = extra_headers
72            .map(|h| h.into_iter().collect())
73            .unwrap_or_default();
74        Self {
75            api_key,
76            model,
77            base_url,
78            client,
79            extra_headers,
80        }
81    }
82
83    /// Check if this is the official Anthropic API endpoint.
84    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
85    fn is_official_anthropic(&self) -> bool {
86        self.base_url.contains("api.anthropic.com")
87    }
88
89    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
90        // Always filter out thinking blocks from message history.
91        // Thinking blocks should NOT be sent back to the API because:
92        // 1. They consume significant tokens without adding value
93        // 2. They can cause the model to repeat similar thinking patterns
94        // 3. Anthropic's thinking is meant for user visibility, not for context
95        // Note: The current turn's thinking is handled separately by the API
96        log::debug!(
97            "convert_messages: always filtering thinking blocks, base_url={}",
98            self.base_url
99        );
100
101        // Count thinking blocks in original messages
102        let mut thinking_count = 0;
103        for m in messages {
104            if let MessageContent::Blocks(blocks) = &m.content {
105                for b in blocks {
106                    if matches!(b, ContentBlock::Thinking { .. }) {
107                        thinking_count += 1;
108                    }
109                }
110            }
111        }
112        if thinking_count > 0 {
113            log::debug!(
114                "convert_messages: Filtering {} thinking blocks from {} messages",
115                thinking_count,
116                messages.len()
117            );
118        }
119
120        messages
121            .iter()
122            .filter(|m| m.role != Role::System)
123            .map(|m| {
124                let role = match m.role {
125                    Role::User | Role::Tool => "user",
126                    Role::Assistant => "assistant",
127                    Role::System => unreachable!(),
128                };
129
130                let content = match &m.content {
131                    MessageContent::Text(text) => json!(text),
132                    MessageContent::Blocks(blocks) => {
133                        let converted: Vec<Value> = blocks
134                            .iter()
135                            .filter(|b| {
136                                // Always skip thinking blocks - they should not be sent back
137                                if matches!(b, ContentBlock::Thinking { .. }) {
138                                    return false;
139                                }
140                                true
141                            })
142                            .map(|b| match b {
143                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
144                                ContentBlock::ToolUse { id, name, input } => {
145                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
146                                }
147                                ContentBlock::ToolResult { tool_use_id, content } => {
148                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
149                                }
150                                ContentBlock::Thinking { thinking, signature } => {
151                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
152                                    // Only send non-empty signature
153                                    if let Some(sig) = signature
154                                        && !sig.is_empty() {
155                                            obj["signature"] = json!(sig);
156                                        }
157                                    obj
158                                }
159                                ContentBlock::ServerToolUse { id, name, input } => {
160                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
161                                }
162                                ContentBlock::ServerToolResult { tool_use_id, content } => {
163                                    json!({"type": "server_tool_result", "tool_use_id": tool_use_id, "content": content})
164                                }
165                                ContentBlock::WebSearchResult { tool_use_id, content } => {
166                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
167                                }
168                            })
169                            .collect();
170
171                        if converted.is_empty() {
172                            json!([])
173                        } else {
174                            json!(converted)
175                        }
176                    }
177                };
178
179                // Skip messages with empty content
180                if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
181                    return json!(null);
182                }
183
184                json!({"role": role, "content": content})
185            })
186            .filter(|v| !v.is_null())
187            .collect()
188    }
189
190    /// Convert tools with caching control for Anthropic prompt caching.
191    fn convert_tools_with_caching(
192        &self,
193        tools: &[ToolDefinition],
194        enable_caching: bool,
195    ) -> Vec<Value> {
196        let mut converted: Vec<Value> = tools
197            .iter()
198            .map(|t| {
199                json!({
200                    "name": t.name,
201                    "description": t.description,
202                    "input_schema": t.parameters,
203                })
204            })
205            .collect();
206
207        // Add cache_control to the last tool definition for tools caching
208        if enable_caching && !converted.is_empty() {
209            let last_idx = converted.len() - 1;
210            if let Some(obj) = converted[last_idx].as_object_mut() {
211                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
212            }
213        }
214
215        converted
216    }
217
218    /// Build the base JSON body shared by streaming and non-streaming requests.
219    fn build_body(&self, request: &ChatRequest) -> Value {
220        let mut body = json!({
221            "model": self.model,
222            "max_tokens": request.max_tokens,
223            "messages": self.convert_messages(&request.messages),
224        });
225
226        // Add prompt caching for system prompt (Anthropic-specific)
227        if request.enable_caching {
228            if let Some(system) = &request.system {
229                // System prompt caching: add cache_control to enable caching
230                body["system"] = json!([
231                    {
232                        "type": "text",
233                        "text": system,
234                        "cache_control": {"type": "ephemeral"}
235                    }
236                ]);
237            }
238        } else if let Some(system) = &request.system {
239            body["system"] = json!(system);
240        }
241
242        if !request.tools.is_empty() {
243            let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
244            body["tools"] = json!(tools);
245        }
246
247        if !request.server_tools.is_empty() {
248            body["tools"] = json!(
249                body["tools"]
250                    .as_array()
251                    .map(|t| {
252                        let mut tools = t.clone();
253                        for st in &request.server_tools {
254                            tools.push(serde_json::to_value(st).unwrap_or_default());
255                        }
256                        tools
257                    })
258                    .unwrap_or_else(|| request
259                        .server_tools
260                        .iter()
261                        .map(|st| serde_json::to_value(st).unwrap_or_default())
262                        .collect())
263            );
264        }
265
266        // Extended thinking (Anthropic-specific)
267        // Only enable thinking for official Anthropic API - third-party APIs like DashScope
268        // forcibly enable thinking mode which causes hanging issues
269        if request.think && self.is_official_anthropic() {
270            let config = thinking_config(&self.model);
271            log::debug!(
272                "Adding thinking config for model {}: {:?}",
273                self.model,
274                config
275            );
276            body["thinking"] = config;
277        } else if request.think {
278            log::debug!(
279                "Skipping thinking config for non-official API (model: {}, base_url: {})",
280                self.model,
281                self.base_url
282            );
283        }
284
285        body
286    }
287}
288
289/// Models that require the new `adaptive` thinking mode instead of the
290/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
291/// don't recognize the name we default to the legacy shape (which older
292/// models and most third-party gateways understand).
293fn thinking_config(model: &str) -> Value {
294    let m = model.to_lowercase();
295    // New models (2025+) use adaptive thinking
296    let adaptive = m.contains("opus-4")
297        || m.contains("sonnet-4")
298        || m.contains("claude-4")
299        || m.contains("20250")
300        || m.contains("2025");
301    if adaptive {
302        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
303    } else {
304        // to prevent hanging on long histories
305        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
306    }
307}
308
309#[async_trait]
310impl Provider for AnthropicProvider {
311    fn context_size(&self) -> Option<u32> {
312        context_window_for(&self.model)
313    }
314
315    fn model_name(&self) -> &str {
316        &self.model
317    }
318
319    fn clone_box(&self) -> Box<dyn Provider> {
320        Box::new(Self {
321            api_key: self.api_key.clone(),
322            model: self.model.clone(),
323            base_url: self.base_url.clone(),
324            client: reqwest::Client::new(),
325            extra_headers: self.extra_headers.clone(),
326        })
327    }
328
329    fn clone_arc(&self) -> Arc<dyn Provider> {
330        Arc::new(Self {
331            api_key: self.api_key.clone(),
332            model: self.model.clone(),
333            base_url: self.base_url.clone(),
334            client: reqwest::Client::new(),
335            extra_headers: self.extra_headers.clone(),
336        })
337    }
338
339    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
340        let body = self.build_body(&request);
341
342        let url = format!("{}/v1/messages", self.base_url);
343
344        // Debug: log request
345        crate::debug::debug_log()
346            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
347
348        let mut req = self
349            .client
350            .post(&url)
351            .header("User-Agent", "curl/8.0")
352            .json(&body);
353
354        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
355        if self.is_official_anthropic() {
356            req = req
357                .header("x-api-key", &self.api_key)
358                .header("anthropic-version", ANTHROPIC_API_VERSION)
359                .header("anthropic-beta", "prompt-caching-2024-07-31");
360        } else {
361            req = req
362                .header("Authorization", format!("Bearer {}", self.api_key))
363                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
364                .header("anthropic-version", "2023-06-01")
365                // Try enabling prompt caching for third-party APIs (DashScope may support this)
366                .header("anthropic-beta", "prompt-caching-2024-07-31");
367        }
368
369        // Add extra headers from config (all custom headers go here)
370        for (name, value) in &self.extra_headers {
371            req = req.header(name, value);
372        }
373
374        let response = req.send().await?;
375
376        let status = response.status();
377        let response_body: Value = response.json().await?;
378
379        // Debug: log response
380        crate::debug::debug_log().api_response(
381            status.as_u16(),
382            &serde_json::to_string(&response_body).unwrap_or_default(),
383        );
384
385        if !status.is_success() {
386            let err_msg = response_body["error"]["message"]
387                .as_str()
388                .unwrap_or("unknown error");
389            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
390        }
391
392        let stop_reason = match response_body["stop_reason"].as_str() {
393            Some("tool_use") => StopReason::ToolUse,
394            Some("max_tokens") => StopReason::MaxTokens,
395            _ => StopReason::EndTurn,
396        };
397
398        let content = response_body["content"]
399            .as_array()
400            .unwrap_or(&vec![])
401            .iter()
402            .filter_map(|block| match block["type"].as_str()? {
403                "text" => Some(ContentBlock::Text {
404                    text: block["text"].as_str()?.to_string(),
405                }),
406                "tool_use" => Some(ContentBlock::ToolUse {
407                    id: block["id"].as_str()?.to_string(),
408                    name: block["name"].as_str()?.to_string(),
409                    input: block["input"].clone(),
410                }),
411                "thinking" => Some(ContentBlock::Thinking {
412                    thinking: block["thinking"].as_str()?.to_string(),
413                    signature: block["signature"].as_str().map(String::from),
414                }),
415                "server_tool_use" => Some(ContentBlock::ServerToolUse {
416                    id: block["id"].as_str()?.to_string(),
417                    name: block["name"].as_str()?.to_string(),
418                    input: block["input"].clone(),
419                }),
420                "web_search_tool_result" => {
421                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
422                    let content = parse_web_search_content(&block["content"]);
423                    Some(ContentBlock::WebSearchResult {
424                        tool_use_id,
425                        content,
426                    })
427                }
428                _ => None,
429            })
430            .collect();
431
432        Ok(ChatResponse {
433            content,
434            stop_reason,
435            usage: parse_usage(&response_body["usage"]),
436        })
437    }
438
439    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
440        let mut body = self.build_body(&request);
441        body["stream"] = json!(true);
442
443        let url = format!("{}/v1/messages", self.base_url);
444
445        // Debug: log streaming request
446        crate::debug::debug_log()
447            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
448
449        let mut req = self
450            .client
451            .post(&url)
452            .header("User-Agent", "curl/8.0")
453            .json(&body);
454
455        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
456        if self.is_official_anthropic() {
457            req = req
458                .header("x-api-key", &self.api_key)
459                .header("anthropic-version", ANTHROPIC_API_VERSION)
460                .header("anthropic-beta", "prompt-caching-2024-07-31");
461        } else {
462            req = req
463                .header("Authorization", format!("Bearer {}", self.api_key))
464                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
465                .header("anthropic-version", "2023-06-01")
466                // Try enabling prompt caching for third-party APIs (DashScope may support this)
467                .header("anthropic-beta", "prompt-caching-2024-07-31");
468        }
469
470        // Add extra headers from config (all custom headers go here)
471        for (name, value) in &self.extra_headers {
472            req = req.header(name, value);
473        }
474
475        let response = req.send().await?;
476
477        if !response.status().is_success() {
478            let status = response.status();
479            let text = response.text().await.unwrap_or_default();
480            anyhow::bail!("Anthropic API error ({}): {}", status, text);
481        }
482
483        let (tx, rx) = mpsc::channel(64);
484        tokio::spawn(async move {
485            let mut stream = response.bytes_stream();
486            let mut buffer = String::new();
487            let mut sent_first_byte = false;
488
489            // In-flight block assembly: index → partial data
490            let mut blocks: Vec<AssembledBlock> = Vec::new();
491            let mut stop_reason = StopReason::EndTurn;
492            let mut usage = Usage::default();
493
494            // Timeout detection: track last meaningful event (non-ping)
495            let mut last_content_time = std::time::Instant::now();
496            const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; // 5 minutes without content = timeout (for slow APIs like DashScope/glm-5)
497
498            while let Some(chunk) = stream.next().await {
499                let chunk = match chunk {
500                    Ok(c) => c,
501                    Err(e) => {
502                        // Log detailed error for debugging
503                        let error_msg = e.to_string();
504                        let is_timeout =
505                            error_msg.contains("timeout") || error_msg.contains("timed out");
506                        let is_decode =
507                            error_msg.contains("decode") || error_msg.contains("decoding");
508
509                        let msg = if is_timeout {
510                            format!(
511                                "Stream timeout - the API took too long to respond: {}",
512                                error_msg
513                            )
514                        } else if is_decode {
515                            format!(
516                                "Stream decode error - possibly interrupted or corrupted response: {}",
517                                error_msg
518                            )
519                        } else {
520                            format!("Stream read error: {}", error_msg)
521                        };
522
523                        // Try to finalize any partial content we have
524                        if sent_first_byte && !blocks.is_empty() {
525                            debug!("finalizing partial stream due to error");
526                            let _ = tx
527                                .send(StreamEvent::Done(finalize_incomplete_stream(
528                                    std::mem::take(&mut blocks),
529                                    stop_reason,
530                                    usage,
531                                )))
532                                .await;
533                        } else {
534                            let _ = tx.send(StreamEvent::Error(msg)).await;
535                        }
536                        return;
537                    }
538                };
539
540                if !sent_first_byte {
541                    sent_first_byte = true;
542                    let _ = tx.send(StreamEvent::FirstByte).await;
543                }
544
545                buffer.push_str(&String::from_utf8_lossy(&chunk));
546
547                // Check for timeout: if only ping events for too long, force finalize
548                let elapsed = last_content_time.elapsed().as_secs();
549                if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
550                    crate::debug::debug_log().stream_chunk(
551                        "TIMEOUT_FORCE_FINALIZE",
552                        &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
553                    );
554                    let _ = tx
555                        .send(StreamEvent::Done(finalize_incomplete_stream(
556                            std::mem::take(&mut blocks),
557                            stop_reason,
558                            usage,
559                        )))
560                        .await;
561                    return;
562                }
563
564                while let Some(frame) = take_next_sse_frame(&mut buffer) {
565                    if handle_sse_frame(
566                        &frame,
567                        &mut blocks,
568                        &mut stop_reason,
569                        &mut usage,
570                        &tx,
571                        &mut last_content_time,
572                    )
573                    .await
574                    {
575                        return;
576                    }
577                }
578            }
579
580            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
581                && handle_sse_frame(
582                    &frame,
583                    &mut blocks,
584                    &mut stop_reason,
585                    &mut usage,
586                    &tx,
587                    &mut last_content_time,
588                )
589                .await
590            {
591                return;
592            }
593
594            if sent_first_byte {
595                debug!("stream ended without explicit message_stop; finalizing best-effort");
596                let _ = tx
597                    .send(StreamEvent::Done(finalize_incomplete_stream(
598                        std::mem::take(&mut blocks),
599                        stop_reason,
600                        usage,
601                    )))
602                    .await;
603            } else {
604                let _ = tx
605                    .send(StreamEvent::Error(
606                        "stream ended before any events were received".to_string(),
607                    ))
608                    .await;
609            }
610        });
611
612        Ok(rx)
613    }
614}
615
/// Pop the next complete SSE frame off the front of `buffer`.
///
/// A frame ends at the first blank line, written either as `\n\n` or `\r\n\r\n`. The frame
/// text (without its terminator) is returned, and both frame and terminator are removed
/// from `buffer`. Returns `None`, leaving `buffer` untouched, while no complete frame has
/// arrived yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const TERMINATORS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    // The earliest terminator wins. `min_by_key` keeps the first of equal candidates,
    // so the LF form would take precedence on a tie.
    let (end, terminator_len) = TERMINATORS
        .iter()
        .filter_map(|&terminator| buffer.find(terminator).map(|at| (at, terminator.len())))
        .min_by_key(|&(at, _)| at)?;

    let frame = buffer[..end].to_owned();
    buffer.replace_range(..end + terminator_len, "");
    Some(frame)
}
636
/// Drain whatever remains in `buffer` after the byte stream ends and return it as a
/// final frame, trimmed of surrounding whitespace. Returns `None` when only whitespace was
/// left. `buffer` is always emptied.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    // `trim` already removes a trailing `\r` along with all other whitespace.
    let remainder = buffer.trim().to_owned();
    buffer.clear();
    (!remainder.is_empty()).then_some(remainder)
}
642
/// Extract the `data` payload of one SSE frame (text between blank-line delimiters).
///
/// Follows the server-sent-events field rules:
/// - every `data:` line contributes its value, with one optional space after the colon
///   removed, so both `data: ` (Anthropic) and `data:` (DashScope) spellings work;
/// - several `data` lines in one frame are joined with `\n`.
///
/// Returns `None` when the frame has no `data` line (e.g. only `event:` lines or `:`
/// comments).
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut data: Option<String> = None;
    for line in frame.lines() {
        let line = line.trim_end_matches('\r');
        let Some(value) = line.strip_prefix("data:") else {
            continue;
        };
        let value = value.strip_prefix(' ').unwrap_or(value);
        match data.as_mut() {
            Some(buf) => {
                buf.push('\n');
                buf.push_str(value);
            }
            None => data = Some(value.to_string()),
        }
    }
    data
}
656
657async fn handle_sse_frame(
658    frame: &str,
659    blocks: &mut Vec<AssembledBlock>,
660    stop_reason: &mut StopReason,
661    usage: &mut Usage,
662    tx: &mpsc::Sender<StreamEvent>,
663    last_content_time: &mut std::time::Instant,
664) -> bool {
665    let Some(data_line) = extract_sse_data_line(frame) else {
666        return false;
667    };
668
669    let evt: Value = match serde_json::from_str(&data_line) {
670        Ok(v) => v,
671        Err(_) => return false,
672    };
673
674    handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
675}
676
677async fn handle_sse_event(
678    evt: Value,
679    blocks: &mut Vec<AssembledBlock>,
680    stop_reason: &mut StopReason,
681    usage: &mut Usage,
682    tx: &mpsc::Sender<StreamEvent>,
683    last_content_time: &mut std::time::Instant,
684) -> bool {
685    let evt_type = evt["type"].as_str().unwrap_or("");
686
687    // Debug: log all SSE events for diagnosis (with full content for debugging)
688    let evt_json = serde_json::to_string(&evt).unwrap_or_default();
689    crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
690
691    // Log event handling for thinking_delta specifically
692    if evt_type == "content_block_delta" {
693        let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
694        let idx = evt["index"].as_u64().unwrap_or(0) as usize;
695        log::debug!(
696            "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
697            delta_type,
698            idx,
699            blocks.len(),
700            idx < blocks.len()
701        );
702    }
703
704    // Update last_content_time for non-ping events
705    if evt_type != "ping" {
706        *last_content_time = std::time::Instant::now();
707    }
708
709    match evt_type {
710        "message_start" => {
711            // Initial usage payload — `input_tokens` is final
712            // (they don't grow during streaming) but
713            // `output_tokens` starts near 0 and is updated by
714            // subsequent `message_delta` events.
715            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
716            debug!(
717                "message_start: usage_json={}",
718                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
719            );
720            debug!(
721                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
722                usage.input_tokens,
723                usage.output_tokens,
724                usage.cache_read_input_tokens,
725                usage.cache_creation_input_tokens
726            );
727            // Send real-time usage update
728            let _ = tx
729                .send(StreamEvent::Usage {
730                    output_tokens: usage.output_tokens,
731                })
732                .await;
733        }
734        "content_block_start" => {
735            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
736            let block = &evt["content_block"];
737            let kind = block["type"].as_str().unwrap_or("");
738            while blocks.len() <= idx {
739                blocks.push(AssembledBlock::default());
740            }
741            match kind {
742                "text" => {
743                    blocks[idx] = AssembledBlock::Text(String::new());
744                }
745                "thinking" => {
746                    blocks[idx] = AssembledBlock::Thinking {
747                        text: String::new(),
748                        signature: None,
749                    };
750                }
751                "tool_use" | "server_tool_use" => {
752                    let id = block["id"].as_str().unwrap_or_default();
753                    let name = block["name"].as_str().unwrap_or_default();
754                    let is_server = kind == "server_tool_use";
755                    blocks[idx] = if is_server {
756                        AssembledBlock::ServerToolUse {
757                            id: id.into(),
758                            name: name.into(),
759                            input_json: String::new(),
760                        }
761                    } else {
762                        AssembledBlock::ToolUse {
763                            id: id.into(),
764                            name: name.into(),
765                            input_json: String::new(),
766                        }
767                    };
768                    let _ = tx
769                        .send(StreamEvent::ToolUseStart {
770                            id: id.into(),
771                            name: name.into(),
772                        })
773                        .await;
774                }
775                "web_search_tool_result" => {
776                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
777                    blocks[idx] = AssembledBlock::WebSearchResult {
778                        tool_use_id,
779                        content_json: String::new(),
780                    };
781                }
782                _ => {}
783            }
784        }
785        "content_block_delta" => {
786            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
787            let delta = &evt["delta"];
788            let dt = delta["type"].as_str().unwrap_or("");
789            if idx >= blocks.len() {
790                return false;
791            }
792            match (dt, &mut blocks[idx]) {
793                ("text_delta", AssembledBlock::Text(buf)) => {
794                    if let Some(t) = delta["text"].as_str() {
795                        buf.push_str(t);
796                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
797                    }
798                }
799                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
800                    if let Some(t) = delta["thinking"].as_str() {
801                        text.push_str(t);
802                        log::debug!("Received thinking_delta: {} chars", t.len());
803                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
804                    }
805                }
806                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
807                    if let Some(s) = delta["signature"].as_str() {
808                        signature.get_or_insert_with(String::new).push_str(s);
809                    }
810                }
811                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
812                    if let Some(p) = delta["partial_json"].as_str() {
813                        input_json.push_str(p);
814                        let _ = tx
815                            .send(StreamEvent::ToolInputDelta {
816                                bytes_so_far: input_json.len(),
817                            })
818                            .await;
819                    }
820                }
821                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
822                    if let Some(p) = delta["partial_json"].as_str() {
823                        input_json.push_str(p);
824                        let _ = tx
825                            .send(StreamEvent::ToolInputDelta {
826                                bytes_so_far: input_json.len(),
827                            })
828                            .await;
829                    }
830                }
831                _ => {}
832            }
833        }
834        "content_block_stop" => {
835            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
836            if let Some(AssembledBlock::ToolUse {
837                id,
838                name,
839                input_json,
840            }) = blocks.get(idx)
841            {
842                let input: Value = if input_json.is_empty() {
843                    json!({})
844                } else {
845                    serde_json::from_str(input_json).unwrap_or(json!({}))
846                };
847                let _ = tx
848                    .send(StreamEvent::ToolInputComplete {
849                        id: id.clone(),
850                        name: name.clone(),
851                        input,
852                    })
853                    .await;
854            }
855        }
856        "message_delta" => {
857            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
858                *stop_reason = match sr {
859                    "tool_use" => StopReason::ToolUse,
860                    "max_tokens" => StopReason::MaxTokens,
861                    _ => StopReason::EndTurn,
862                };
863            }
864            // `usage` on message_delta reflects cumulative
865            // counts for the current message — most notably
866            // the final `output_tokens`.
867            *usage = merge_usage(usage.clone(), &evt["usage"]);
868            debug!(
869                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
870                usage.input_tokens,
871                usage.output_tokens,
872                usage.cache_read_input_tokens,
873                usage.cache_creation_input_tokens
874            );
875            // Send real-time usage update
876            let _ = tx
877                .send(StreamEvent::Usage {
878                    output_tokens: usage.output_tokens,
879                })
880                .await;
881        }
882        "message_stop" => {
883            debug!(
884                "Message completed: stop_reason={}, usage={}",
885                match *stop_reason {
886                    StopReason::EndTurn => "end_turn",
887                    StopReason::ToolUse => "tool_use",
888                    StopReason::MaxTokens => "max_tokens",
889                },
890                usage.output_tokens
891            );
892            debug!(
893                "message_stop final usage: cache_read={}, cache_created={}",
894                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
895            );
896            let _ = tx
897                .send(StreamEvent::Done(finalize_incomplete_stream(
898                    std::mem::take(blocks),
899                    stop_reason.clone(),
900                    usage.clone(),
901                )))
902                .await;
903            return true;
904        }
905        "error" => {
906            let msg = evt["error"]["message"]
907                .as_str()
908                .unwrap_or("unknown stream error")
909                .to_string();
910            let _ = tx.send(StreamEvent::Error(msg)).await;
911            return true;
912        }
913        _ => {}
914    }
915
916    false
917}
918
919fn build_chat_response(
920    blocks: Vec<AssembledBlock>,
921    stop_reason: StopReason,
922    usage: Usage,
923) -> ChatResponse {
924    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
925    ChatResponse {
926        content,
927        stop_reason,
928        usage,
929    }
930}
931
932fn finalize_incomplete_stream(
933    blocks: Vec<AssembledBlock>,
934    stop_reason: StopReason,
935    usage: Usage,
936) -> ChatResponse {
937    build_chat_response(blocks, stop_reason, usage)
938}
939
/// Accumulation buffer for a single content block of a streaming response.
///
/// The stream handler keeps one slot per event `index`:
/// - `content_block_start` installs the matching variant.
/// - `content_block_delta` appends to its buffers.
/// - [`AssembledBlock::finish`] converts it into a `ContentBlock` at the end.
#[derive(Default)]
enum AssembledBlock {
    /// Slot that was padded in but never started, or whose block type is not
    /// tracked. New slots start out in this state.
    #[default]
    Empty,
    /// Plain text, built from `text_delta` fragments.
    Text(String),
    /// Extended-thinking block.
    ///
    /// `text` is built from `thinking_delta` fragments. `signature` is built
    /// from `signature_delta` fragments and stays `None` if none arrived.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Client-side `tool_use` call. `input_json` collects the raw
    /// `partial_json` fragments, unparsed until the block is finished.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// `server_tool_use` call, buffered exactly like [`AssembledBlock::ToolUse`].
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// `web_search_tool_result` block for the server tool call `tool_use_id`.
    ///
    /// `content_json` is meant to hold the result's raw JSON.
    /// NOTE(review): the visible stream handling initialises it empty and no
    /// delta arm appends to it, so the result content seems to be lost.
    /// Confirm against the start event's `content` field.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
964
965impl AssembledBlock {
966    fn finish(self) -> Option<ContentBlock> {
967        match self {
968            AssembledBlock::Empty => None,
969            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
970            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
971                thinking: text,
972                signature,
973            }),
974            AssembledBlock::ToolUse {
975                id,
976                name,
977                input_json,
978            } => {
979                let input: Value = if input_json.is_empty() {
980                    json!({})
981                } else {
982                    serde_json::from_str(&input_json).unwrap_or(json!({}))
983                };
984                Some(ContentBlock::ToolUse { id, name, input })
985            }
986            AssembledBlock::ServerToolUse {
987                id,
988                name,
989                input_json,
990            } => {
991                let input: Value = if input_json.is_empty() {
992                    json!({})
993                } else {
994                    serde_json::from_str(&input_json).unwrap_or(json!({}))
995                };
996                Some(ContentBlock::ServerToolUse { id, name, input })
997            }
998            AssembledBlock::WebSearchResult {
999                tool_use_id,
1000                content_json,
1001            } => {
1002                let content: Value = if content_json.is_empty() {
1003                    json!({"results": []})
1004                } else {
1005                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1006                };
1007                Some(ContentBlock::WebSearchResult {
1008                    tool_use_id,
1009                    content: parse_web_search_content(&content),
1010                })
1011            }
1012        }
1013    }
1014}
1015
1016/// Parse the provider's `usage` blob (non-streaming response) into our
1017/// internal `Usage` struct. Missing fields default to 0.
1018fn parse_usage(u: &Value) -> Usage {
1019    Usage {
1020        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1021        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1022        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1023        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1024    }
1025}
1026
1027/// Merge a fresh usage payload into the accumulated one. Non-zero new values
1028/// override prior ones — this matches the streaming protocol where
1029/// `message_start` gives input counts and `message_delta` updates output.
1030fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1031    let fresh = parse_usage(u);
1032    if fresh.input_tokens > 0 {
1033        acc.input_tokens = fresh.input_tokens;
1034    }
1035    if fresh.output_tokens > 0 {
1036        acc.output_tokens = fresh.output_tokens;
1037    }
1038    if fresh.cache_creation_input_tokens > 0 {
1039        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1040    }
1041    if fresh.cache_read_input_tokens > 0 {
1042        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1043    }
1044    acc
1045}
1046
1047/// Parse web search content from the API response.
1048fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1049    let results = value["results"]
1050        .as_array()
1051        .map(|arr| {
1052            arr.iter()
1053                .filter_map(|item| {
1054                    Some(crate::providers::WebSearchResultItem {
1055                        title: item["title"].as_str().map(String::from),
1056                        url: item["url"].as_str()?.to_string(),
1057                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
1058                        snippet: item["snippet"].as_str().map(String::from),
1059                    })
1060                })
1061                .collect()
1062        })
1063        .unwrap_or_default();
1064
1065    crate::providers::WebSearchContent { results }
1066}
1067
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn build_chat_response_drops_empty_slots() {
        let response = build_chat_response(
            vec![
                AssembledBlock::Empty,
                AssembledBlock::Text("hello".to_string()),
                AssembledBlock::Empty,
            ],
            StopReason::EndTurn,
            Usage::default(),
        );
        assert_eq!(response.content.len(), 1);
    }

    #[test]
    fn parse_usage_defaults_missing_fields_to_zero() {
        let usage = parse_usage(&json!({"input_tokens": 12, "cache_read_input_tokens": 3}));
        assert_eq!(usage.input_tokens, 12);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.cache_creation_input_tokens, 0);
        assert_eq!(usage.cache_read_input_tokens, 3);
    }

    #[test]
    fn merge_usage_keeps_prior_values_when_fresh_is_zero() {
        let started = merge_usage(
            Usage::default(),
            &json!({"input_tokens": 100, "cache_creation_input_tokens": 7}),
        );
        let merged = merge_usage(started, &json!({"input_tokens": 0, "output_tokens": 42}));
        assert_eq!(merged.input_tokens, 100);
        assert_eq!(merged.output_tokens, 42);
        assert_eq!(merged.cache_creation_input_tokens, 7);
        assert_eq!(merged.cache_read_input_tokens, 0);
    }

    #[test]
    fn finish_tool_use_falls_back_to_empty_object_on_bad_json() {
        let block = AssembledBlock::ToolUse {
            id: "toolu_1".to_string(),
            name: "read_file".to_string(),
            input_json: "{\"path\":".to_string(),
        };
        match block.finish() {
            Some(ContentBlock::ToolUse { id, name, input }) => {
                assert_eq!(id, "toolu_1");
                assert_eq!(name, "read_file");
                assert_eq!(input, json!({}));
            }
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finish_tool_use_parses_complete_json() {
        let block = AssembledBlock::ToolUse {
            id: "toolu_2".to_string(),
            name: "read_file".to_string(),
            input_json: "{\"path\":\"a.txt\"}".to_string(),
        };
        match block.finish() {
            Some(ContentBlock::ToolUse { input, .. }) => {
                assert_eq!(input, json!({"path": "a.txt"}));
            }
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finish_empty_block_is_dropped() {
        assert!(AssembledBlock::Empty.finish().is_none());
    }

    #[test]
    fn parse_web_search_content_skips_items_without_url() {
        let content = parse_web_search_content(&json!({
            "results": [
                {"title": "A", "url": "https://a.example"},
                {"title": "no url"}
            ]
        }));
        assert_eq!(content.results.len(), 1);
        assert_eq!(content.results[0].url, "https://a.example");
        assert_eq!(content.results[0].title.as_deref(), Some("A"));
        assert!(content.results[0].snippet.is_none());
    }
}