Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10    ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11    DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12    THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19    StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23    api_key: String,
24    model: String,
25    base_url: String,
26    client: reqwest::Client,
27    /// Extra headers from config
28    extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32    pub fn new(api_key: String, model: String, base_url: String) -> Self {
33        Self::with_headers(api_key, model, base_url, None)
34    }
35
36    /// Load proxy from environment variables
37    fn load_proxy_from_env() -> Option<String> {
38        std::env::var("HTTPS_PROXY")
39            .ok()
40            .or_else(|| std::env::var("https_proxy").ok())
41            .or_else(|| std::env::var("HTTP_PROXY").ok())
42            .or_else(|| std::env::var("http_proxy").ok())
43    }
44
45    pub fn with_headers(
46        api_key: String,
47        model: String,
48        base_url: String,
49        extra_headers: Option<std::collections::HashMap<String, String>>,
50    ) -> Self {
51        // Create client with better timeout handling for streaming
52        // - No total timeout (streaming responses can take a long time)
53        // - Connect timeout: DEFAULT_CONNECT_TIMEOUT_SECS seconds
54        // - Read timeout per chunk: DEFAULT_READ_TIMEOUT_SECS seconds (for slow responses between chunks)
55        let mut client_builder = reqwest::Client::builder()
56            .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57            .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58            .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); // Total timeout for non-streaming
59
60        // Add proxy from environment if available
61        if let Some(proxy_url) = Self::load_proxy_from_env()
62            && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63        {
64            log::info!("AnthropicProvider using proxy: {}", proxy_url);
65            client_builder = client_builder.proxy(proxy);
66        }
67
68        let client = client_builder
69            .build()
70            .unwrap_or_else(|_| reqwest::Client::new());
71        let extra_headers: Vec<(String, String)> = extra_headers
72            .map(|h| h.into_iter().collect())
73            .unwrap_or_default();
74        Self {
75            api_key,
76            model,
77            base_url,
78            client,
79            extra_headers,
80        }
81    }
82
83    /// Check if this is the official Anthropic API endpoint.
84    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
85    fn is_official_anthropic(&self) -> bool {
86        self.base_url.contains("api.anthropic.com")
87    }
88
89    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
90        // For non-official APIs (like glm-5/DashScope), filter out thinking blocks
91        // to prevent API from entering extended thinking mode and hanging
92        // This is necessary because some APIs don't properly support thinking blocks
93        let filter_thinking = !self.is_official_anthropic();
94        log::debug!(
95            "convert_messages: filter_thinking={}, base_url={}",
96            filter_thinking,
97            self.base_url
98        );
99
100        // Count thinking blocks in original messages
101        let mut thinking_count = 0;
102        for m in messages {
103            if let MessageContent::Blocks(blocks) = &m.content {
104                for b in blocks {
105                    if matches!(b, ContentBlock::Thinking { .. }) {
106                        thinking_count += 1;
107                    }
108                }
109            }
110        }
111        if thinking_count > 0 {
112            log::debug!(
113                "convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}",
114                thinking_count,
115                messages.len(),
116                filter_thinking
117            );
118        }
119
120        messages
121            .iter()
122            .filter(|m| m.role != Role::System)
123            .map(|m| {
124                let role = match m.role {
125                    Role::User | Role::Tool => "user",
126                    Role::Assistant => "assistant",
127                    Role::System => unreachable!(),
128                };
129
130                let content = match &m.content {
131                    MessageContent::Text(text) => json!(text),
132                    MessageContent::Blocks(blocks) => {
133                        let converted: Vec<Value> = blocks
134                            .iter()
135                            .filter(|b| {
136                                // Skip thinking blocks for non-official APIs
137                                if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
138                                    return false;
139                                }
140                                true
141                            })
142                            .map(|b| match b {
143                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
144                                ContentBlock::ToolUse { id, name, input } => {
145                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
146                                }
147                                ContentBlock::ToolResult { tool_use_id, content } => {
148                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
149                                }
150                                ContentBlock::Thinking { thinking, signature } => {
151                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
152                                    // Only send non-empty signature
153                                    if let Some(sig) = signature
154                                        && !sig.is_empty() {
155                                            obj["signature"] = json!(sig);
156                                        }
157                                    obj
158                                }
159                                ContentBlock::ServerToolUse { id, name, input } => {
160                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
161                                }
162                                ContentBlock::WebSearchResult { tool_use_id, content } => {
163                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
164                                }
165                            })
166                            .collect();
167
168                        if converted.is_empty() {
169                            json!([])
170                        } else {
171                            json!(converted)
172                        }
173                    }
174                };
175
176                // Skip messages with empty content
177                if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
178                    return json!(null);
179                }
180
181                json!({"role": role, "content": content})
182            })
183            .filter(|v| !v.is_null())
184            .collect()
185    }
186
187    /// Convert tools with caching control for Anthropic prompt caching.
188    fn convert_tools_with_caching(
189        &self,
190        tools: &[ToolDefinition],
191        enable_caching: bool,
192    ) -> Vec<Value> {
193        let mut converted: Vec<Value> = tools
194            .iter()
195            .map(|t| {
196                json!({
197                    "name": t.name,
198                    "description": t.description,
199                    "input_schema": t.parameters,
200                })
201            })
202            .collect();
203
204        // Add cache_control to the last tool definition for tools caching
205        if enable_caching && !converted.is_empty() {
206            let last_idx = converted.len() - 1;
207            if let Some(obj) = converted[last_idx].as_object_mut() {
208                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
209            }
210        }
211
212        converted
213    }
214
215    /// Build the base JSON body shared by streaming and non-streaming requests.
216    fn build_body(&self, request: &ChatRequest) -> Value {
217        let mut body = json!({
218            "model": self.model,
219            "max_tokens": request.max_tokens,
220            "messages": self.convert_messages(&request.messages),
221        });
222
223        // Add prompt caching for system prompt (Anthropic-specific)
224        if request.enable_caching {
225            if let Some(system) = &request.system {
226                // System prompt caching: add cache_control to enable caching
227                body["system"] = json!([
228                    {
229                        "type": "text",
230                        "text": system,
231                        "cache_control": {"type": "ephemeral"}
232                    }
233                ]);
234            }
235        } else if let Some(system) = &request.system {
236            body["system"] = json!(system);
237        }
238
239        if !request.tools.is_empty() {
240            let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
241            body["tools"] = json!(tools);
242        }
243
244        if !request.server_tools.is_empty() {
245            body["tools"] = json!(
246                body["tools"]
247                    .as_array()
248                    .map(|t| {
249                        let mut tools = t.clone();
250                        for st in &request.server_tools {
251                            tools.push(serde_json::to_value(st).unwrap_or_default());
252                        }
253                        tools
254                    })
255                    .unwrap_or_else(|| request
256                        .server_tools
257                        .iter()
258                        .map(|st| serde_json::to_value(st).unwrap_or_default())
259                        .collect())
260            );
261        }
262
263        // Extended thinking (Anthropic-specific)
264        // Only enable thinking for official Anthropic API - third-party APIs like DashScope
265        // forcibly enable thinking mode which causes hanging issues
266        if request.think && self.is_official_anthropic() {
267            let config = thinking_config(&self.model);
268            log::debug!(
269                "Adding thinking config for model {}: {:?}",
270                self.model,
271                config
272            );
273            body["thinking"] = config;
274        } else if request.think {
275            log::debug!(
276                "Skipping thinking config for non-official API (model: {}, base_url: {})",
277                self.model,
278                self.base_url
279            );
280        }
281
282        body
283    }
284}
285
286/// Models that require the new `adaptive` thinking mode instead of the
287/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
288/// don't recognize the name we default to the legacy shape (which older
289/// models and most third-party gateways understand).
290fn thinking_config(model: &str) -> Value {
291    let m = model.to_lowercase();
292    // New models (2025+) use adaptive thinking
293    let adaptive = m.contains("opus-4")
294        || m.contains("sonnet-4")
295        || m.contains("claude-4")
296        || m.contains("20250")
297        || m.contains("2025");
298    if adaptive {
299        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
300    } else {
301        // to prevent hanging on long histories
302        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
303    }
304}
305
306#[async_trait]
307impl Provider for AnthropicProvider {
308    fn context_size(&self) -> Option<u32> {
309        context_window_for(&self.model)
310    }
311
312    fn model_name(&self) -> &str {
313        &self.model
314    }
315
316    fn clone_box(&self) -> Box<dyn Provider> {
317        Box::new(Self {
318            api_key: self.api_key.clone(),
319            model: self.model.clone(),
320            base_url: self.base_url.clone(),
321            client: reqwest::Client::new(),
322            extra_headers: self.extra_headers.clone(),
323        })
324    }
325
326    fn clone_arc(&self) -> Arc<dyn Provider> {
327        Arc::new(Self {
328            api_key: self.api_key.clone(),
329            model: self.model.clone(),
330            base_url: self.base_url.clone(),
331            client: reqwest::Client::new(),
332            extra_headers: self.extra_headers.clone(),
333        })
334    }
335
336    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
337        let body = self.build_body(&request);
338
339        let url = format!("{}/v1/messages", self.base_url);
340
341        // Debug: log request
342        crate::debug::debug_log()
343            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
344
345        let mut req = self
346            .client
347            .post(&url)
348            .header("User-Agent", "curl/8.0")
349            .json(&body);
350
351        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
352        if self.is_official_anthropic() {
353            req = req
354                .header("x-api-key", &self.api_key)
355                .header("anthropic-version", ANTHROPIC_API_VERSION)
356                .header("anthropic-beta", "prompt-caching-2024-07-31");
357        } else {
358            req = req
359                .header("Authorization", format!("Bearer {}", self.api_key))
360                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
361                .header("anthropic-version", "2023-06-01")
362                // Try enabling prompt caching for third-party APIs (DashScope may support this)
363                .header("anthropic-beta", "prompt-caching-2024-07-31");
364        }
365
366        // Add extra headers from config (all custom headers go here)
367        for (name, value) in &self.extra_headers {
368            req = req.header(name, value);
369        }
370
371        let response = req.send().await?;
372
373        let status = response.status();
374        let response_body: Value = response.json().await?;
375
376        // Debug: log response
377        crate::debug::debug_log().api_response(
378            status.as_u16(),
379            &serde_json::to_string(&response_body).unwrap_or_default(),
380        );
381
382        if !status.is_success() {
383            let err_msg = response_body["error"]["message"]
384                .as_str()
385                .unwrap_or("unknown error");
386            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
387        }
388
389        let stop_reason = match response_body["stop_reason"].as_str() {
390            Some("tool_use") => StopReason::ToolUse,
391            Some("max_tokens") => StopReason::MaxTokens,
392            _ => StopReason::EndTurn,
393        };
394
395        let content = response_body["content"]
396            .as_array()
397            .unwrap_or(&vec![])
398            .iter()
399            .filter_map(|block| match block["type"].as_str()? {
400                "text" => Some(ContentBlock::Text {
401                    text: block["text"].as_str()?.to_string(),
402                }),
403                "tool_use" => Some(ContentBlock::ToolUse {
404                    id: block["id"].as_str()?.to_string(),
405                    name: block["name"].as_str()?.to_string(),
406                    input: block["input"].clone(),
407                }),
408                "thinking" => Some(ContentBlock::Thinking {
409                    thinking: block["thinking"].as_str()?.to_string(),
410                    signature: block["signature"].as_str().map(String::from),
411                }),
412                "server_tool_use" => Some(ContentBlock::ServerToolUse {
413                    id: block["id"].as_str()?.to_string(),
414                    name: block["name"].as_str()?.to_string(),
415                    input: block["input"].clone(),
416                }),
417                "web_search_tool_result" => {
418                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
419                    let content = parse_web_search_content(&block["content"]);
420                    Some(ContentBlock::WebSearchResult {
421                        tool_use_id,
422                        content,
423                    })
424                }
425                _ => None,
426            })
427            .collect();
428
429        Ok(ChatResponse {
430            content,
431            stop_reason,
432            usage: parse_usage(&response_body["usage"]),
433        })
434    }
435
436    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
437        let mut body = self.build_body(&request);
438        body["stream"] = json!(true);
439
440        let url = format!("{}/v1/messages", self.base_url);
441
442        // Debug: log streaming request
443        crate::debug::debug_log()
444            .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
445
446        let mut req = self
447            .client
448            .post(&url)
449            .header("User-Agent", "curl/8.0")
450            .json(&body);
451
452        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
453        if self.is_official_anthropic() {
454            req = req
455                .header("x-api-key", &self.api_key)
456                .header("anthropic-version", ANTHROPIC_API_VERSION)
457                .header("anthropic-beta", "prompt-caching-2024-07-31");
458        } else {
459            req = req
460                .header("Authorization", format!("Bearer {}", self.api_key))
461                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
462                .header("anthropic-version", "2023-06-01")
463                // Try enabling prompt caching for third-party APIs (DashScope may support this)
464                .header("anthropic-beta", "prompt-caching-2024-07-31");
465        }
466
467        // Add extra headers from config (all custom headers go here)
468        for (name, value) in &self.extra_headers {
469            req = req.header(name, value);
470        }
471
472        let response = req.send().await?;
473
474        if !response.status().is_success() {
475            let status = response.status();
476            let text = response.text().await.unwrap_or_default();
477            anyhow::bail!("Anthropic API error ({}): {}", status, text);
478        }
479
480        let (tx, rx) = mpsc::channel(64);
481        tokio::spawn(async move {
482            let mut stream = response.bytes_stream();
483            let mut buffer = String::new();
484            let mut sent_first_byte = false;
485
486            // In-flight block assembly: index → partial data
487            let mut blocks: Vec<AssembledBlock> = Vec::new();
488            let mut stop_reason = StopReason::EndTurn;
489            let mut usage = Usage::default();
490
491            // Timeout detection: track last meaningful event (non-ping)
492            let mut last_content_time = std::time::Instant::now();
493            const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; // 5 minutes without content = timeout (for slow APIs like DashScope/glm-5)
494
495            while let Some(chunk) = stream.next().await {
496                let chunk = match chunk {
497                    Ok(c) => c,
498                    Err(e) => {
499                        // Log detailed error for debugging
500                        let error_msg = e.to_string();
501                        let is_timeout =
502                            error_msg.contains("timeout") || error_msg.contains("timed out");
503                        let is_decode =
504                            error_msg.contains("decode") || error_msg.contains("decoding");
505
506                        let msg = if is_timeout {
507                            format!(
508                                "Stream timeout - the API took too long to respond: {}",
509                                error_msg
510                            )
511                        } else if is_decode {
512                            format!(
513                                "Stream decode error - possibly interrupted or corrupted response: {}",
514                                error_msg
515                            )
516                        } else {
517                            format!("Stream read error: {}", error_msg)
518                        };
519
520                        // Try to finalize any partial content we have
521                        if sent_first_byte && !blocks.is_empty() {
522                            debug!("finalizing partial stream due to error");
523                            let _ = tx
524                                .send(StreamEvent::Done(finalize_incomplete_stream(
525                                    std::mem::take(&mut blocks),
526                                    stop_reason,
527                                    usage,
528                                )))
529                                .await;
530                        } else {
531                            let _ = tx.send(StreamEvent::Error(msg)).await;
532                        }
533                        return;
534                    }
535                };
536
537                if !sent_first_byte {
538                    sent_first_byte = true;
539                    let _ = tx.send(StreamEvent::FirstByte).await;
540                }
541
542                buffer.push_str(&String::from_utf8_lossy(&chunk));
543
544                // Check for timeout: if only ping events for too long, force finalize
545                let elapsed = last_content_time.elapsed().as_secs();
546                if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
547                    crate::debug::debug_log().stream_chunk(
548                        "TIMEOUT_FORCE_FINALIZE",
549                        &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
550                    );
551                    let _ = tx
552                        .send(StreamEvent::Done(finalize_incomplete_stream(
553                            std::mem::take(&mut blocks),
554                            stop_reason,
555                            usage,
556                        )))
557                        .await;
558                    return;
559                }
560
561                while let Some(frame) = take_next_sse_frame(&mut buffer) {
562                    if handle_sse_frame(
563                        &frame,
564                        &mut blocks,
565                        &mut stop_reason,
566                        &mut usage,
567                        &tx,
568                        &mut last_content_time,
569                    )
570                    .await
571                    {
572                        return;
573                    }
574                }
575            }
576
577            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
578                && handle_sse_frame(
579                    &frame,
580                    &mut blocks,
581                    &mut stop_reason,
582                    &mut usage,
583                    &tx,
584                    &mut last_content_time,
585                )
586                .await
587            {
588                return;
589            }
590
591            if sent_first_byte {
592                debug!("stream ended without explicit message_stop; finalizing best-effort");
593                let _ = tx
594                    .send(StreamEvent::Done(finalize_incomplete_stream(
595                        std::mem::take(&mut blocks),
596                        stop_reason,
597                        usage,
598                    )))
599                    .await;
600            } else {
601                let _ = tx
602                    .send(StreamEvent::Error(
603                        "stream ended before any events were received".to_string(),
604                    ))
605                    .await;
606            }
607        });
608
609        Ok(rx)
610    }
611}
612
/// Remove and return the next complete SSE frame (event) from `buffer`.
///
/// A frame ends at the first blank line, i.e. two consecutive line terminators.
/// The SSE spec allows `\r\n`, `\n` or `\r` as a terminator, and servers or
/// proxies sometimes mix them. So any combination is accepted: `\n\n`,
/// `\r\n\r\n`, `\r\r`, `\r\n\n`, `\n\r\n`, and so on.
///
/// Returns `None`, leaving `buffer` untouched, when no complete frame is
/// buffered yet. The returned frame excludes the terminating blank line; the
/// frame and its terminators are drained from `buffer`.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    /// Length of the line terminator (`\r\n`, `\n` or `\r`) starting at `at`, if any.
    fn terminator_len(bytes: &[u8], at: usize) -> Option<usize> {
        match *bytes.get(at)? {
            b'\n' => Some(1),
            b'\r' if bytes.get(at + 1) == Some(&b'\n') => Some(2),
            b'\r' => Some(1),
            _ => None,
        }
    }

    let bytes = buffer.as_bytes();
    let mut pos = 0;
    // (end of frame text, total bytes to drain)
    let mut boundary = None;
    while pos < bytes.len() {
        match terminator_len(bytes, pos) {
            Some(first) => {
                if let Some(second) = terminator_len(bytes, pos + first) {
                    boundary = Some((pos, pos + first + second));
                    break;
                }
                pos += first;
            }
            None => pos += 1,
        }
    }

    // Both offsets sit next to ASCII terminator bytes, so they are char boundaries.
    let (frame_end, consumed) = boundary?;
    let frame = buffer[..frame_end].to_string();
    buffer.drain(..consumed);
    Some(frame)
}
633
/// Drain whatever is left in `buffer` after the byte stream has ended.
///
/// Surrounding whitespace (including stray `\r`/`\n`) is trimmed, and `buffer`
/// is always left empty. Returns `None` if only whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let trimmed = leftover.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
639
/// Return the `data` payload of one SSE frame, or `None` if it has no `data` field.
///
/// Both `data: value` (Anthropic) and `data:value` (DashScope) are accepted: a
/// single space after the colon is stripped, as the SSE spec prescribes. If a
/// frame has several `data` lines, their values are joined with `\n` as the
/// spec requires, so a JSON payload split across lines is reassembled intact.
/// Comment lines (`:` prefix) and other fields are ignored.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let values: Vec<&str> = frame
        .lines()
        .filter_map(|line| {
            let value = line.trim_end_matches('\r').strip_prefix("data:")?;
            Some(value.strip_prefix(' ').unwrap_or(value))
        })
        .collect();

    (!values.is_empty()).then(|| values.join("\n"))
}
653
654async fn handle_sse_frame(
655    frame: &str,
656    blocks: &mut Vec<AssembledBlock>,
657    stop_reason: &mut StopReason,
658    usage: &mut Usage,
659    tx: &mpsc::Sender<StreamEvent>,
660    last_content_time: &mut std::time::Instant,
661) -> bool {
662    let Some(data_line) = extract_sse_data_line(frame) else {
663        return false;
664    };
665
666    let evt: Value = match serde_json::from_str(&data_line) {
667        Ok(v) => v,
668        Err(_) => return false,
669    };
670
671    handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
672}
673
674async fn handle_sse_event(
675    evt: Value,
676    blocks: &mut Vec<AssembledBlock>,
677    stop_reason: &mut StopReason,
678    usage: &mut Usage,
679    tx: &mpsc::Sender<StreamEvent>,
680    last_content_time: &mut std::time::Instant,
681) -> bool {
682    let evt_type = evt["type"].as_str().unwrap_or("");
683
684    // Debug: log all SSE events for diagnosis (with full content for debugging)
685    let evt_json = serde_json::to_string(&evt).unwrap_or_default();
686    crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
687
688    // Log event handling for thinking_delta specifically
689    if evt_type == "content_block_delta" {
690        let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
691        let idx = evt["index"].as_u64().unwrap_or(0) as usize;
692        log::debug!(
693            "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
694            delta_type,
695            idx,
696            blocks.len(),
697            idx < blocks.len()
698        );
699    }
700
701    // Update last_content_time for non-ping events
702    if evt_type != "ping" {
703        *last_content_time = std::time::Instant::now();
704    }
705
706    match evt_type {
707        "message_start" => {
708            // Initial usage payload — `input_tokens` is final
709            // (they don't grow during streaming) but
710            // `output_tokens` starts near 0 and is updated by
711            // subsequent `message_delta` events.
712            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
713            debug!(
714                "message_start: usage_json={}",
715                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
716            );
717            debug!(
718                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
719                usage.input_tokens,
720                usage.output_tokens,
721                usage.cache_read_input_tokens,
722                usage.cache_creation_input_tokens
723            );
724            // Send real-time usage update
725            let _ = tx
726                .send(StreamEvent::Usage {
727                    output_tokens: usage.output_tokens,
728                })
729                .await;
730        }
731        "content_block_start" => {
732            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
733            let block = &evt["content_block"];
734            let kind = block["type"].as_str().unwrap_or("");
735            while blocks.len() <= idx {
736                blocks.push(AssembledBlock::default());
737            }
738            match kind {
739                "text" => {
740                    blocks[idx] = AssembledBlock::Text(String::new());
741                }
742                "thinking" => {
743                    blocks[idx] = AssembledBlock::Thinking {
744                        text: String::new(),
745                        signature: None,
746                    };
747                }
748                "tool_use" | "server_tool_use" => {
749                    let id = block["id"].as_str().unwrap_or_default();
750                    let name = block["name"].as_str().unwrap_or_default();
751                    let is_server = kind == "server_tool_use";
752                    blocks[idx] = if is_server {
753                        AssembledBlock::ServerToolUse {
754                            id: id.into(),
755                            name: name.into(),
756                            input_json: String::new(),
757                        }
758                    } else {
759                        AssembledBlock::ToolUse {
760                            id: id.into(),
761                            name: name.into(),
762                            input_json: String::new(),
763                        }
764                    };
765                    let _ = tx
766                        .send(StreamEvent::ToolUseStart {
767                            id: id.into(),
768                            name: name.into(),
769                        })
770                        .await;
771                }
772                "web_search_tool_result" => {
773                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
774                    blocks[idx] = AssembledBlock::WebSearchResult {
775                        tool_use_id,
776                        content_json: String::new(),
777                    };
778                }
779                _ => {}
780            }
781        }
782        "content_block_delta" => {
783            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
784            let delta = &evt["delta"];
785            let dt = delta["type"].as_str().unwrap_or("");
786            if idx >= blocks.len() {
787                return false;
788            }
789            match (dt, &mut blocks[idx]) {
790                ("text_delta", AssembledBlock::Text(buf)) => {
791                    if let Some(t) = delta["text"].as_str() {
792                        buf.push_str(t);
793                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
794                    }
795                }
796                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
797                    if let Some(t) = delta["thinking"].as_str() {
798                        text.push_str(t);
799                        log::debug!("Received thinking_delta: {} chars", t.len());
800                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
801                    }
802                }
803                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
804                    if let Some(s) = delta["signature"].as_str() {
805                        signature.get_or_insert_with(String::new).push_str(s);
806                    }
807                }
808                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
809                    if let Some(p) = delta["partial_json"].as_str() {
810                        input_json.push_str(p);
811                        let _ = tx
812                            .send(StreamEvent::ToolInputDelta {
813                                bytes_so_far: input_json.len(),
814                            })
815                            .await;
816                    }
817                }
818                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
819                    if let Some(p) = delta["partial_json"].as_str() {
820                        input_json.push_str(p);
821                        let _ = tx
822                            .send(StreamEvent::ToolInputDelta {
823                                bytes_so_far: input_json.len(),
824                            })
825                            .await;
826                    }
827                }
828                _ => {}
829            }
830        }
831        "content_block_stop" => {
832            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
833            if let Some(AssembledBlock::ToolUse {
834                id,
835                name,
836                input_json,
837            }) = blocks.get(idx)
838            {
839                let input: Value = if input_json.is_empty() {
840                    json!({})
841                } else {
842                    serde_json::from_str(input_json).unwrap_or(json!({}))
843                };
844                let _ = tx
845                    .send(StreamEvent::ToolInputComplete {
846                        id: id.clone(),
847                        name: name.clone(),
848                        input,
849                    })
850                    .await;
851            }
852        }
853        "message_delta" => {
854            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
855                *stop_reason = match sr {
856                    "tool_use" => StopReason::ToolUse,
857                    "max_tokens" => StopReason::MaxTokens,
858                    _ => StopReason::EndTurn,
859                };
860            }
861            // `usage` on message_delta reflects cumulative
862            // counts for the current message — most notably
863            // the final `output_tokens`.
864            *usage = merge_usage(usage.clone(), &evt["usage"]);
865            debug!(
866                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
867                usage.input_tokens,
868                usage.output_tokens,
869                usage.cache_read_input_tokens,
870                usage.cache_creation_input_tokens
871            );
872            // Send real-time usage update
873            let _ = tx
874                .send(StreamEvent::Usage {
875                    output_tokens: usage.output_tokens,
876                })
877                .await;
878        }
879        "message_stop" => {
880            debug!(
881                "Message completed: stop_reason={}, usage={}",
882                match *stop_reason {
883                    StopReason::EndTurn => "end_turn",
884                    StopReason::ToolUse => "tool_use",
885                    StopReason::MaxTokens => "max_tokens",
886                },
887                usage.output_tokens
888            );
889            debug!(
890                "message_stop final usage: cache_read={}, cache_created={}",
891                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
892            );
893            let _ = tx
894                .send(StreamEvent::Done(finalize_incomplete_stream(
895                    std::mem::take(blocks),
896                    stop_reason.clone(),
897                    usage.clone(),
898                )))
899                .await;
900            return true;
901        }
902        "error" => {
903            let msg = evt["error"]["message"]
904                .as_str()
905                .unwrap_or("unknown stream error")
906                .to_string();
907            let _ = tx.send(StreamEvent::Error(msg)).await;
908            return true;
909        }
910        _ => {}
911    }
912
913    false
914}
915
916fn build_chat_response(
917    blocks: Vec<AssembledBlock>,
918    stop_reason: StopReason,
919    usage: Usage,
920) -> ChatResponse {
921    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
922    ChatResponse {
923        content,
924        stop_reason,
925        usage,
926    }
927}
928
929fn finalize_incomplete_stream(
930    blocks: Vec<AssembledBlock>,
931    stop_reason: StopReason,
932    usage: Usage,
933) -> ChatResponse {
934    build_chat_response(blocks, stop_reason, usage)
935}
936
/// Per-index accumulator for one content block of a streamed message.
///
/// The stream handler installs a variant on `content_block_start`, appends the
/// matching `*_delta` payloads to it, and `AssembledBlock::finish` converts it
/// into a `ContentBlock` once assembly is done.
///
/// `Debug` is derived so blocks can be logged and shown in test failures.
#[derive(Debug, Default)]
enum AssembledBlock {
    /// Placeholder for an index with no recognised block: either padding pushed
    /// to reach a higher index, or a block type the handler does not track.
    /// `finish` drops it.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Extended-thinking text from `thinking_delta` events. `signature` stays
    /// `None` until the first `signature_delta` fragment arrives.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Client-side tool call (`tool_use`). `input_json` is the raw
    /// concatenation of `partial_json` fragments; it is parsed only when the
    /// block is finished.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server-executed tool call (`server_tool_use`), accumulated the same way
    /// as `ToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// `web_search_tool_result` block; `content_json` holds the raw JSON of the
    /// result payload.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
961
962impl AssembledBlock {
963    fn finish(self) -> Option<ContentBlock> {
964        match self {
965            AssembledBlock::Empty => None,
966            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
967            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
968                thinking: text,
969                signature,
970            }),
971            AssembledBlock::ToolUse {
972                id,
973                name,
974                input_json,
975            } => {
976                let input: Value = if input_json.is_empty() {
977                    json!({})
978                } else {
979                    serde_json::from_str(&input_json).unwrap_or(json!({}))
980                };
981                Some(ContentBlock::ToolUse { id, name, input })
982            }
983            AssembledBlock::ServerToolUse {
984                id,
985                name,
986                input_json,
987            } => {
988                let input: Value = if input_json.is_empty() {
989                    json!({})
990                } else {
991                    serde_json::from_str(&input_json).unwrap_or(json!({}))
992                };
993                Some(ContentBlock::ServerToolUse { id, name, input })
994            }
995            AssembledBlock::WebSearchResult {
996                tool_use_id,
997                content_json,
998            } => {
999                let content: Value = if content_json.is_empty() {
1000                    json!({"results": []})
1001                } else {
1002                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1003                };
1004                Some(ContentBlock::WebSearchResult {
1005                    tool_use_id,
1006                    content: parse_web_search_content(&content),
1007                })
1008            }
1009        }
1010    }
1011}
1012
1013/// Parse the provider's `usage` blob (non-streaming response) into our
1014/// internal `Usage` struct. Missing fields default to 0.
1015fn parse_usage(u: &Value) -> Usage {
1016    Usage {
1017        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1018        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1019        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1020        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1021    }
1022}
1023
1024/// Merge a fresh usage payload into the accumulated one. Non-zero new values
1025/// override prior ones — this matches the streaming protocol where
1026/// `message_start` gives input counts and `message_delta` updates output.
1027fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1028    let fresh = parse_usage(u);
1029    if fresh.input_tokens > 0 {
1030        acc.input_tokens = fresh.input_tokens;
1031    }
1032    if fresh.output_tokens > 0 {
1033        acc.output_tokens = fresh.output_tokens;
1034    }
1035    if fresh.cache_creation_input_tokens > 0 {
1036        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1037    }
1038    if fresh.cache_read_input_tokens > 0 {
1039        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1040    }
1041    acc
1042}
1043
1044/// Parse web search content from the API response.
1045fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1046    let results = value["results"]
1047        .as_array()
1048        .map(|arr| {
1049            arr.iter()
1050                .filter_map(|item| {
1051                    Some(crate::providers::WebSearchResultItem {
1052                        title: item["title"].as_str().map(String::from),
1053                        url: item["url"].as_str()?.to_string(),
1054                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
1055                        snippet: item["snippet"].as_str().map(String::from),
1056                    })
1057                })
1058                .collect()
1059        })
1060        .unwrap_or_default();
1061
1062    crate::providers::WebSearchContent { results }
1063}
1064
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn merge_usage_keeps_prior_values_when_fresh_ones_are_zero_or_missing() {
        let acc = Usage {
            input_tokens: 100,
            output_tokens: 1,
            cache_creation_input_tokens: 20,
            cache_read_input_tokens: 30,
        };
        // A payload that only updates the output count; an explicit 0 must
        // not clobber the accumulated input count.
        let merged = merge_usage(
            acc,
            &serde_json::json!({"output_tokens": 42, "input_tokens": 0}),
        );
        assert_eq!(merged.input_tokens, 100);
        assert_eq!(merged.output_tokens, 42);
        assert_eq!(merged.cache_creation_input_tokens, 20);
        assert_eq!(merged.cache_read_input_tokens, 30);
    }

    #[test]
    fn parse_web_search_content_skips_items_without_url() {
        let value = serde_json::json!({
            "results": [
                {"title": "Rust", "url": "https://www.rust-lang.org"},
                {"title": "no url"},
                {"url": "https://docs.rs", "snippet": "docs"}
            ]
        });
        let content = parse_web_search_content(&value);
        assert_eq!(content.results.len(), 2);
        assert_eq!(content.results[0].url, "https://www.rust-lang.org");
        assert_eq!(content.results[0].title.as_deref(), Some("Rust"));
        assert_eq!(content.results[1].url, "https://docs.rs");
        assert_eq!(content.results[1].snippet.as_deref(), Some("docs"));
        assert!(content.results[1].encrypted_content.is_none());
    }

    #[test]
    fn parse_web_search_content_tolerates_missing_results() {
        let content = parse_web_search_content(&serde_json::json!({}));
        assert!(content.results.is_empty());
    }

    #[test]
    fn finish_parses_tool_input_and_falls_back_on_malformed_json() {
        let well_formed = AssembledBlock::ToolUse {
            id: "toolu_1".to_string(),
            name: "read_file".to_string(),
            input_json: "{\"path\":\"a.rs\"}".to_string(),
        };
        match well_formed.finish() {
            Some(ContentBlock::ToolUse { id, name, input }) => {
                assert_eq!(id, "toolu_1");
                assert_eq!(name, "read_file");
                assert_eq!(input, serde_json::json!({"path": "a.rs"}));
            }
            other => panic!("unexpected block: {other:?}"),
        }

        // Truncated arguments (e.g. cut off mid-stream) degrade to `{}`.
        let truncated = AssembledBlock::ToolUse {
            id: "toolu_2".to_string(),
            name: "read_file".to_string(),
            input_json: "{\"path\":".to_string(),
        };
        match truncated.finish() {
            Some(ContentBlock::ToolUse { input, .. }) => {
                assert_eq!(input, serde_json::json!({}));
            }
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finish_drops_empty_placeholders() {
        assert!(AssembledBlock::Empty.finish().is_none());
    }
}