Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10    DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS,
11    THINKING_BUDGET_NEW_MODELS, THINKING_BUDGET_OLD_MODELS,
12    DEFAULT_CONTENT_TIMEOUT_SECS, ANTHROPIC_API_VERSION,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19    StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23    api_key: String,
24    model: String,
25    base_url: String,
26    client: reqwest::Client,
27    /// Extra headers from config
28    extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32    pub fn new(api_key: String, model: String, base_url: String) -> Self {
33        Self::with_headers(api_key, model, base_url, None)
34    }
35
36    /// Load proxy from environment variables
37    fn load_proxy_from_env() -> Option<String> {
38        std::env::var("HTTPS_PROXY")
39            .ok()
40            .or_else(|| std::env::var("https_proxy").ok())
41            .or_else(|| std::env::var("HTTP_PROXY").ok())
42            .or_else(|| std::env::var("http_proxy").ok())
43    }
44
45    pub fn with_headers(
46        api_key: String,
47        model: String,
48        base_url: String,
49        extra_headers: Option<std::collections::HashMap<String, String>>,
50    ) -> Self {
51        // Create client with better timeout handling for streaming
52        // - No total timeout (streaming responses can take a long time)
53        // - Connect timeout: DEFAULT_CONNECT_TIMEOUT_SECS seconds
54        // - Read timeout per chunk: DEFAULT_READ_TIMEOUT_SECS seconds (for slow responses between chunks)
55        let mut client_builder = reqwest::Client::builder()
56            .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57            .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58            .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); // Total timeout for non-streaming
59
60        // Add proxy from environment if available
61        if let Some(proxy_url) = Self::load_proxy_from_env()
62            && let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
63                log::info!("AnthropicProvider using proxy: {}", proxy_url);
64                client_builder = client_builder.proxy(proxy);
65            }
66
67        let client = client_builder
68            .build()
69            .unwrap_or_else(|_| reqwest::Client::new());
70        let extra_headers: Vec<(String, String)> = extra_headers
71            .map(|h| h.into_iter().collect())
72            .unwrap_or_default();
73        Self {
74            api_key,
75            model,
76            base_url,
77            client,
78            extra_headers,
79        }
80    }
81
82    /// Check if this is the official Anthropic API endpoint.
83    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
84    fn is_official_anthropic(&self) -> bool {
85        self.base_url.contains("api.anthropic.com")
86    }
87
88    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
89        // For non-official APIs (like glm-5/DashScope), filter out thinking blocks
90        // to prevent API from entering extended thinking mode and hanging
91        // This is necessary because some APIs don't properly support thinking blocks
92        let filter_thinking = !self.is_official_anthropic();
93        log::debug!("convert_messages: filter_thinking={}, base_url={}", filter_thinking, self.base_url);
94
95        // Count thinking blocks in original messages
96        let mut thinking_count = 0;
97        for m in messages {
98            if let MessageContent::Blocks(blocks) = &m.content {
99                for b in blocks {
100                    if matches!(b, ContentBlock::Thinking { .. }) {
101                        thinking_count += 1;
102                    }
103                }
104            }
105        }
106        if thinking_count > 0 {
107            log::debug!("convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}", thinking_count, messages.len(), filter_thinking);
108        }
109
110        messages
111            .iter()
112            .filter(|m| m.role != Role::System)
113            .map(|m| {
114                let role = match m.role {
115                    Role::User | Role::Tool => "user",
116                    Role::Assistant => "assistant",
117                    Role::System => unreachable!(),
118                };
119
120                let content = match &m.content {
121                    MessageContent::Text(text) => json!(text),
122                    MessageContent::Blocks(blocks) => {
123                        let converted: Vec<Value> = blocks
124                            .iter()
125                            .filter(|b| {
126                                // Skip thinking blocks for non-official APIs
127                                if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
128                                    return false;
129                                }
130                                true
131                            })
132                            .map(|b| match b {
133                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
134                                ContentBlock::ToolUse { id, name, input } => {
135                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
136                                }
137                                ContentBlock::ToolResult { tool_use_id, content } => {
138                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
139                                }
140                                ContentBlock::Thinking { thinking, signature } => {
141                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
142                                    // Only send non-empty signature
143                                    if let Some(sig) = signature
144                                        && !sig.is_empty() {
145                                            obj["signature"] = json!(sig);
146                                        }
147                                    obj
148                                }
149                                ContentBlock::ServerToolUse { id, name, input } => {
150                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
151                                }
152                                ContentBlock::WebSearchResult { tool_use_id, content } => {
153                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
154                                }
155                            })
156                            .collect();
157
158                        if converted.is_empty() {
159                            json!([])
160                        } else {
161                            json!(converted)
162                        }
163                    }
164                };
165
166                // Skip messages with empty content
167                if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
168                    return json!(null);
169                }
170
171                json!({"role": role, "content": content})
172            })
173            .filter(|v| !v.is_null())
174            .collect()
175    }
176
177    /// Convert tools with caching control for Anthropic prompt caching.
178    fn convert_tools_with_caching(
179        &self,
180        tools: &[ToolDefinition],
181        enable_caching: bool,
182    ) -> Vec<Value> {
183        let mut converted: Vec<Value> = tools
184            .iter()
185            .map(|t| {
186                json!({
187                    "name": t.name,
188                    "description": t.description,
189                    "input_schema": t.parameters,
190                })
191            })
192            .collect();
193
194        // Add cache_control to the last tool definition for tools caching
195        if enable_caching && !converted.is_empty() {
196            let last_idx = converted.len() - 1;
197            if let Some(obj) = converted[last_idx].as_object_mut() {
198                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
199            }
200        }
201
202        converted
203    }
204
205    /// Build the base JSON body shared by streaming and non-streaming requests.
206    fn build_body(&self, request: &ChatRequest) -> Value {
207        let mut body = json!({
208            "model": self.model,
209            "max_tokens": request.max_tokens,
210            "messages": self.convert_messages(&request.messages),
211        });
212
213        // Add prompt caching for system prompt (Anthropic-specific)
214        if request.enable_caching {
215            if let Some(system) = &request.system {
216                // System prompt caching: add cache_control to enable caching
217                body["system"] = json!([
218                    {
219                        "type": "text",
220                        "text": system,
221                        "cache_control": {"type": "ephemeral"}
222                    }
223                ]);
224            }
225        } else if let Some(system) = &request.system {
226            body["system"] = json!(system);
227        }
228
229        if !request.tools.is_empty() {
230            let tools = self.convert_tools_with_caching(
231                &request.tools,
232                request.enable_caching,
233            );
234            body["tools"] = json!(tools);
235        }
236
237        if !request.server_tools.is_empty() {
238            body["tools"] = json!(
239                body["tools"]
240                    .as_array()
241                    .map(|t| {
242                        let mut tools = t.clone();
243                        for st in &request.server_tools {
244                            tools.push(serde_json::to_value(st).unwrap_or_default());
245                        }
246                        tools
247                    })
248                    .unwrap_or_else(|| request
249                        .server_tools
250                        .iter()
251                        .map(|st| serde_json::to_value(st).unwrap_or_default())
252                        .collect())
253            );
254        }
255
256        // Extended thinking (Anthropic-specific)
257        // Only enable thinking for official Anthropic API - third-party APIs like DashScope
258        // forcibly enable thinking mode which causes hanging issues
259        if request.think && self.is_official_anthropic() {
260            let config = thinking_config(&self.model);
261            log::debug!(
262                "Adding thinking config for model {}: {:?}",
263                self.model,
264                config
265            );
266            body["thinking"] = config;
267        } else if request.think {
268            log::debug!(
269                "Skipping thinking config for non-official API (model: {}, base_url: {})",
270                self.model,
271                self.base_url
272            );
273        }
274
275        body
276    }
277}
278
279/// Models that require the new `adaptive` thinking mode instead of the
280/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
281/// don't recognize the name we default to the legacy shape (which older
282/// models and most third-party gateways understand).
283fn thinking_config(model: &str) -> Value {
284    let m = model.to_lowercase();
285    // New models (2025+) use adaptive thinking
286    let adaptive = m.contains("opus-4")
287        || m.contains("sonnet-4")
288        || m.contains("claude-4")
289        || m.contains("20250")
290        || m.contains("2025");
291    if adaptive {
292        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
293    } else {
294         // to prevent hanging on long histories
295        json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
296    }
297}
298
299#[async_trait]
300impl Provider for AnthropicProvider {
301    fn context_size(&self) -> Option<u32> {
302        context_window_for(&self.model)
303    }
304
305    fn model_name(&self) -> &str {
306        &self.model
307    }
308
309    fn clone_box(&self) -> Box<dyn Provider> {
310        Box::new(Self {
311            api_key: self.api_key.clone(),
312            model: self.model.clone(),
313            base_url: self.base_url.clone(),
314            client: reqwest::Client::new(),
315            extra_headers: self.extra_headers.clone(),
316        })
317    }
318
319    fn clone_arc(&self) -> Arc<dyn Provider> {
320        Arc::new(Self {
321            api_key: self.api_key.clone(),
322            model: self.model.clone(),
323            base_url: self.base_url.clone(),
324            client: reqwest::Client::new(),
325            extra_headers: self.extra_headers.clone(),
326        })
327    }
328
329    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
330        let body = self.build_body(&request);
331
332        let url = format!("{}/v1/messages", self.base_url);
333
334        // Debug: log request
335        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
336
337        let mut req = self
338            .client
339            .post(&url)
340            .header("User-Agent", "curl/8.0")
341            .json(&body);
342
343        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
344        if self.is_official_anthropic() {
345            req = req
346                .header("x-api-key", &self.api_key)
347                .header("anthropic-version", ANTHROPIC_API_VERSION)
348                .header("anthropic-beta", "prompt-caching-2024-07-31");
349        } else {
350            req = req
351                .header("Authorization", format!("Bearer {}", self.api_key))
352                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
353                .header("anthropic-version", "2023-06-01")
354                // Try enabling prompt caching for third-party APIs (DashScope may support this)
355                .header("anthropic-beta", "prompt-caching-2024-07-31");
356        }
357
358        // Add extra headers from config (all custom headers go here)
359        for (name, value) in &self.extra_headers {
360            req = req.header(name, value);
361        }
362
363        let response = req.send().await?;
364
365        let status = response.status();
366        let response_body: Value = response.json().await?;
367
368        // Debug: log response
369        crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
370
371        if !status.is_success() {
372            let err_msg = response_body["error"]["message"]
373                .as_str()
374                .unwrap_or("unknown error");
375            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
376        }
377
378        let stop_reason = match response_body["stop_reason"].as_str() {
379            Some("tool_use") => StopReason::ToolUse,
380            Some("max_tokens") => StopReason::MaxTokens,
381            _ => StopReason::EndTurn,
382        };
383
384        let content = response_body["content"]
385            .as_array()
386            .unwrap_or(&vec![])
387            .iter()
388            .filter_map(|block| match block["type"].as_str()? {
389                "text" => Some(ContentBlock::Text {
390                    text: block["text"].as_str()?.to_string(),
391                }),
392                "tool_use" => Some(ContentBlock::ToolUse {
393                    id: block["id"].as_str()?.to_string(),
394                    name: block["name"].as_str()?.to_string(),
395                    input: block["input"].clone(),
396                }),
397                "thinking" => Some(ContentBlock::Thinking {
398                    thinking: block["thinking"].as_str()?.to_string(),
399                    signature: block["signature"].as_str().map(String::from),
400                }),
401                "server_tool_use" => Some(ContentBlock::ServerToolUse {
402                    id: block["id"].as_str()?.to_string(),
403                    name: block["name"].as_str()?.to_string(),
404                    input: block["input"].clone(),
405                }),
406                "web_search_tool_result" => {
407                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
408                    let content = parse_web_search_content(&block["content"]);
409                    Some(ContentBlock::WebSearchResult {
410                        tool_use_id,
411                        content,
412                    })
413                }
414                _ => None,
415            })
416            .collect();
417
418        Ok(ChatResponse {
419            content,
420            stop_reason,
421            usage: parse_usage(&response_body["usage"]),
422        })
423    }
424
425    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
426        let mut body = self.build_body(&request);
427        body["stream"] = json!(true);
428
429        let url = format!("{}/v1/messages", self.base_url);
430
431        // Debug: log streaming request
432        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
433
434        let mut req = self
435            .client
436            .post(&url)
437            .header("User-Agent", "curl/8.0")
438            .json(&body);
439
440        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
441        if self.is_official_anthropic() {
442            req = req
443                .header("x-api-key", &self.api_key)
444                .header("anthropic-version", ANTHROPIC_API_VERSION)
445                .header("anthropic-beta", "prompt-caching-2024-07-31");
446        } else {
447            req = req
448                .header("Authorization", format!("Bearer {}", self.api_key))
449                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
450                .header("anthropic-version", "2023-06-01")
451                // Try enabling prompt caching for third-party APIs (DashScope may support this)
452                .header("anthropic-beta", "prompt-caching-2024-07-31");
453        }
454
455        // Add extra headers from config (all custom headers go here)
456        for (name, value) in &self.extra_headers {
457            req = req.header(name, value);
458        }
459
460        let response = req.send().await?;
461
462        if !response.status().is_success() {
463            let status = response.status();
464            let text = response.text().await.unwrap_or_default();
465            anyhow::bail!("Anthropic API error ({}): {}", status, text);
466        }
467
468        let (tx, rx) = mpsc::channel(64);
469        tokio::spawn(async move {
470            let mut stream = response.bytes_stream();
471            let mut buffer = String::new();
472            let mut sent_first_byte = false;
473
474            // In-flight block assembly: index → partial data
475            let mut blocks: Vec<AssembledBlock> = Vec::new();
476            let mut stop_reason = StopReason::EndTurn;
477            let mut usage = Usage::default();
478
479            // Timeout detection: track last meaningful event (non-ping)
480            let mut last_content_time = std::time::Instant::now();
481            const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; // 5 minutes without content = timeout (for slow APIs like DashScope/glm-5)
482
483            while let Some(chunk) = stream.next().await {
484                let chunk = match chunk {
485                    Ok(c) => c,
486                    Err(e) => {
487                        // Log detailed error for debugging
488                        let error_msg = e.to_string();
489                        let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
490                        let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
491
492                        let msg = if is_timeout {
493                            format!("Stream timeout - the API took too long to respond: {}", error_msg)
494                        } else if is_decode {
495                            format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
496                        } else {
497                            format!("Stream read error: {}", error_msg)
498                        };
499
500                        // Try to finalize any partial content we have
501                        if sent_first_byte && !blocks.is_empty() {
502                            debug!("finalizing partial stream due to error");
503                            let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
504                                std::mem::take(&mut blocks),
505                                stop_reason,
506                                usage,
507                            ))).await;
508                        } else {
509                            let _ = tx.send(StreamEvent::Error(msg)).await;
510                        }
511                        return;
512                    }
513                };
514
515                if !sent_first_byte {
516                    sent_first_byte = true;
517                    let _ = tx.send(StreamEvent::FirstByte).await;
518                }
519
520                buffer.push_str(&String::from_utf8_lossy(&chunk));
521
522                // Check for timeout: if only ping events for too long, force finalize
523                let elapsed = last_content_time.elapsed().as_secs();
524                if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
525                    crate::debug::debug_log().stream_chunk("TIMEOUT_FORCE_FINALIZE",
526                        &format!("elapsed={}s, blocks={}", elapsed, blocks.len()));
527                    let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
528                        std::mem::take(&mut blocks),
529                        stop_reason,
530                        usage,
531                    ))).await;
532                    return;
533                }
534
535                while let Some(frame) = take_next_sse_frame(&mut buffer) {
536                    if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time)
537                        .await
538                    {
539                        return;
540                    }
541                }
542            }
543
544            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
545                && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time).await
546            {
547                return;
548            }
549
550            if sent_first_byte {
551                debug!("stream ended without explicit message_stop; finalizing best-effort");
552                let _ = tx
553                    .send(StreamEvent::Done(finalize_incomplete_stream(
554                        std::mem::take(&mut blocks),
555                        stop_reason,
556                        usage,
557                    )))
558                    .await;
559            } else {
560                let _ = tx
561                    .send(StreamEvent::Error(
562                        "stream ended before any events were received".to_string(),
563                    ))
564                    .await;
565            }
566        });
567
568        Ok(rx)
569    }
570}
571
/// Remove and return the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank-line delimiter, either `"\n\n"` or
/// `"\r\n\r\n"`. The returned frame excludes the delimiter, and both are
/// removed from `buffer`. Returns `None`, leaving `buffer` untouched, when no
/// delimiter has arrived yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (end, delimiter_len) = ["\n\n", "\r\n\r\n"]
        .into_iter()
        .filter_map(|delimiter| buffer.find(delimiter).map(|at| (at, delimiter.len())))
        .min_by_key(|&(at, _)| at)?;

    let frame = buffer[..end].to_owned();
    buffer.replace_range(..end + delimiter_len, "");
    Some(frame)
}
592
/// Flush `buffer` at end of stream, returning its trimmed remainder as a last frame.
///
/// `buffer` is always left empty. Returns `None` when only whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let remainder = std::mem::take(buffer);
    let frame = remainder.trim().trim_end_matches('\r');
    (!frame.is_empty()).then(|| frame.to_owned())
}
598
/// Return the `data` payload of one SSE frame, or `None` if it has no `data` field.
///
/// Accepts both `"data: x"` (Anthropic) and `"data:x"` (DashScope); as in the
/// SSE spec, at most one space after the colon is stripped. When a frame has
/// several `data:` lines, their values are joined with `'\n'`, per the SSE
/// event-stream interpretation rules. A single-line frame yields exactly that line.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut data: Option<String> = None;
    for line in frame.lines() {
        let line = line.trim_end_matches('\r');
        let Some(value) = line.strip_prefix("data:") else {
            continue;
        };
        let value = value.strip_prefix(' ').unwrap_or(value);
        match data.as_mut() {
            Some(joined) => {
                joined.push('\n');
                joined.push_str(value);
            }
            None => data = Some(value.to_string()),
        }
    }
    data
}
612
613async fn handle_sse_frame(
614    frame: &str,
615    blocks: &mut Vec<AssembledBlock>,
616    stop_reason: &mut StopReason,
617    usage: &mut Usage,
618    tx: &mpsc::Sender<StreamEvent>,
619    last_content_time: &mut std::time::Instant,
620) -> bool {
621    let Some(data_line) = extract_sse_data_line(frame) else {
622        return false;
623    };
624
625    let evt: Value = match serde_json::from_str(&data_line) {
626        Ok(v) => v,
627        Err(_) => return false,
628    };
629
630    handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
631}
632
633async fn handle_sse_event(
634    evt: Value,
635    blocks: &mut Vec<AssembledBlock>,
636    stop_reason: &mut StopReason,
637    usage: &mut Usage,
638    tx: &mpsc::Sender<StreamEvent>,
639    last_content_time: &mut std::time::Instant,
640) -> bool {
641    let evt_type = evt["type"].as_str().unwrap_or("");
642
643    // Debug: log all SSE events for diagnosis (with full content for debugging)
644    let evt_json = serde_json::to_string(&evt).unwrap_or_default();
645    crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
646
647    // Log event handling for thinking_delta specifically
648    if evt_type == "content_block_delta" {
649        let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
650        let idx = evt["index"].as_u64().unwrap_or(0) as usize;
651        log::debug!(
652            "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
653            delta_type,
654            idx,
655            blocks.len(),
656            idx < blocks.len()
657        );
658    }
659
660    // Update last_content_time for non-ping events
661    if evt_type != "ping" {
662        *last_content_time = std::time::Instant::now();
663    }
664
665    match evt_type {
666        "message_start" => {
667            // Initial usage payload — `input_tokens` is final
668            // (they don't grow during streaming) but
669            // `output_tokens` starts near 0 and is updated by
670            // subsequent `message_delta` events.
671            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
672            debug!(
673                "message_start: usage_json={}",
674                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
675            );
676            debug!(
677                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
678                usage.input_tokens,
679                usage.output_tokens,
680                usage.cache_read_input_tokens,
681                usage.cache_creation_input_tokens
682            );
683            // Send real-time usage update
684            let _ = tx
685                .send(StreamEvent::Usage {
686                    output_tokens: usage.output_tokens,
687                })
688                .await;
689        }
690        "content_block_start" => {
691            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
692            let block = &evt["content_block"];
693            let kind = block["type"].as_str().unwrap_or("");
694            while blocks.len() <= idx {
695                blocks.push(AssembledBlock::default());
696            }
697            match kind {
698                "text" => {
699                    blocks[idx] = AssembledBlock::Text(String::new());
700                }
701                "thinking" => {
702                    blocks[idx] = AssembledBlock::Thinking {
703                        text: String::new(),
704                        signature: None,
705                    };
706                }
707                "tool_use" | "server_tool_use" => {
708                    let id = block["id"].as_str().unwrap_or_default();
709                    let name = block["name"].as_str().unwrap_or_default();
710                    let is_server = kind == "server_tool_use";
711                    blocks[idx] = if is_server {
712                        AssembledBlock::ServerToolUse {
713                            id: id.into(),
714                            name: name.into(),
715                            input_json: String::new(),
716                        }
717                    } else {
718                        AssembledBlock::ToolUse {
719                            id: id.into(),
720                            name: name.into(),
721                            input_json: String::new(),
722                        }
723                    };
724                    let _ = tx
725                        .send(StreamEvent::ToolUseStart {
726                            id: id.into(),
727                            name: name.into(),
728                        })
729                        .await;
730                }
731                "web_search_tool_result" => {
732                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
733                    blocks[idx] = AssembledBlock::WebSearchResult {
734                        tool_use_id,
735                        content_json: String::new(),
736                    };
737                }
738                _ => {}
739            }
740        }
741        "content_block_delta" => {
742            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
743            let delta = &evt["delta"];
744            let dt = delta["type"].as_str().unwrap_or("");
745            if idx >= blocks.len() {
746                return false;
747            }
748            match (dt, &mut blocks[idx]) {
749                ("text_delta", AssembledBlock::Text(buf)) => {
750                    if let Some(t) = delta["text"].as_str() {
751                        buf.push_str(t);
752                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
753                    }
754                }
755                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
756                    if let Some(t) = delta["thinking"].as_str() {
757                        text.push_str(t);
758                        log::debug!("Received thinking_delta: {} chars", t.len());
759                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
760                    }
761                }
762                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
763                    if let Some(s) = delta["signature"].as_str() {
764                        signature.get_or_insert_with(String::new).push_str(s);
765                    }
766                }
767                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
768                    if let Some(p) = delta["partial_json"].as_str() {
769                        input_json.push_str(p);
770                        let _ = tx
771                            .send(StreamEvent::ToolInputDelta {
772                                bytes_so_far: input_json.len(),
773                            })
774                            .await;
775                    }
776                }
777                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
778                    if let Some(p) = delta["partial_json"].as_str() {
779                        input_json.push_str(p);
780                        let _ = tx
781                            .send(StreamEvent::ToolInputDelta {
782                                bytes_so_far: input_json.len(),
783                            })
784                            .await;
785                    }
786                }
787                _ => {}
788            }
789        }
790        "message_delta" => {
791            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
792                *stop_reason = match sr {
793                    "tool_use" => StopReason::ToolUse,
794                    "max_tokens" => StopReason::MaxTokens,
795                    _ => StopReason::EndTurn,
796                };
797            }
798            // `usage` on message_delta reflects cumulative
799            // counts for the current message — most notably
800            // the final `output_tokens`.
801            *usage = merge_usage(usage.clone(), &evt["usage"]);
802            debug!(
803                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
804                usage.input_tokens,
805                usage.output_tokens,
806                usage.cache_read_input_tokens,
807                usage.cache_creation_input_tokens
808            );
809            // Send real-time usage update
810            let _ = tx
811                .send(StreamEvent::Usage {
812                    output_tokens: usage.output_tokens,
813                })
814                .await;
815        }
816        "message_stop" => {
817            debug!(
818                "Message completed: stop_reason={}, usage={}",
819                match *stop_reason {
820                    StopReason::EndTurn => "end_turn",
821                    StopReason::ToolUse => "tool_use",
822                    StopReason::MaxTokens => "max_tokens",
823                },
824                usage.output_tokens
825            );
826            debug!(
827                "message_stop final usage: cache_read={}, cache_created={}",
828                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
829            );
830            let _ = tx
831                .send(StreamEvent::Done(finalize_incomplete_stream(
832                    std::mem::take(blocks),
833                    stop_reason.clone(),
834                    usage.clone(),
835                )))
836                .await;
837            return true;
838        }
839        "error" => {
840            let msg = evt["error"]["message"]
841                .as_str()
842                .unwrap_or("unknown stream error")
843                .to_string();
844            let _ = tx.send(StreamEvent::Error(msg)).await;
845            return true;
846        }
847        _ => {}
848    }
849
850    false
851}
852
853fn build_chat_response(
854    blocks: Vec<AssembledBlock>,
855    stop_reason: StopReason,
856    usage: Usage,
857) -> ChatResponse {
858    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
859    ChatResponse {
860        content,
861        stop_reason,
862        usage,
863    }
864}
865
866fn finalize_incomplete_stream(
867    blocks: Vec<AssembledBlock>,
868    stop_reason: StopReason,
869    usage: Usage,
870) -> ChatResponse {
871    build_chat_response(blocks, stop_reason, usage)
872}
873
/// A content block that is still being assembled from streaming events.
///
/// Variants are grown in place by `content_block_delta` events and converted
/// into `ContentBlock`s by `AssembledBlock::finish`.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder with no content (the `Default` variant); finishes to `None`.
    #[default]
    Empty,
    /// Assistant text, appended from `text_delta` events.
    Text(String),
    /// Thinking output; finishes to `ContentBlock::Thinking`.
    Thinking {
        /// Thinking text, appended from `thinking_delta` events.
        text: String,
        /// Signature appended from `signature_delta` events; stays `None`
        /// until the first such event arrives.
        signature: Option<String>,
    },
    /// Client tool call; finishes to `ContentBlock::ToolUse`.
    ToolUse {
        id: String,
        name: String,
        /// Raw JSON text accumulated from `input_json_delta` fragments;
        /// parsed only when the block is finished.
        input_json: String,
    },
    /// Like `ToolUse`, but finishes to `ContentBlock::ServerToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        /// Raw JSON text accumulated from `input_json_delta` fragments.
        input_json: String,
    },
    /// Web search result tied to the tool call identified by `tool_use_id`.
    WebSearchResult {
        tool_use_id: String,
        /// Raw JSON text for the result payload.
        /// NOTE(review): no `content_block_delta` arm appends to this field,
        /// and the block's `content` is not read when the variant is created,
        /// so it appears to always stay empty. Confirm against the stream handler.
        content_json: String,
    },
}
898
899impl AssembledBlock {
900    fn finish(self) -> Option<ContentBlock> {
901        match self {
902            AssembledBlock::Empty => None,
903            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
904            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
905                thinking: text,
906                signature,
907            }),
908            AssembledBlock::ToolUse {
909                id,
910                name,
911                input_json,
912            } => {
913                let input: Value = if input_json.is_empty() {
914                    json!({})
915                } else {
916                    serde_json::from_str(&input_json).unwrap_or(json!({}))
917                };
918                Some(ContentBlock::ToolUse { id, name, input })
919            }
920            AssembledBlock::ServerToolUse {
921                id,
922                name,
923                input_json,
924            } => {
925                let input: Value = if input_json.is_empty() {
926                    json!({})
927                } else {
928                    serde_json::from_str(&input_json).unwrap_or(json!({}))
929                };
930                Some(ContentBlock::ServerToolUse { id, name, input })
931            }
932            AssembledBlock::WebSearchResult {
933                tool_use_id,
934                content_json,
935            } => {
936                let content: Value = if content_json.is_empty() {
937                    json!({"results": []})
938                } else {
939                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
940                };
941                Some(ContentBlock::WebSearchResult {
942                    tool_use_id,
943                    content: parse_web_search_content(&content),
944                })
945            }
946        }
947    }
948}
949
950/// Parse the provider's `usage` blob (non-streaming response) into our
951/// internal `Usage` struct. Missing fields default to 0.
952fn parse_usage(u: &Value) -> Usage {
953    Usage {
954        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
955        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
956        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
957        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
958    }
959}
960
961/// Merge a fresh usage payload into the accumulated one. Non-zero new values
962/// override prior ones — this matches the streaming protocol where
963/// `message_start` gives input counts and `message_delta` updates output.
964fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
965    let fresh = parse_usage(u);
966    if fresh.input_tokens > 0 {
967        acc.input_tokens = fresh.input_tokens;
968    }
969    if fresh.output_tokens > 0 {
970        acc.output_tokens = fresh.output_tokens;
971    }
972    if fresh.cache_creation_input_tokens > 0 {
973        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
974    }
975    if fresh.cache_read_input_tokens > 0 {
976        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
977    }
978    acc
979}
980
981/// Parse web search content from the API response.
982fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
983    let results = value["results"]
984        .as_array()
985        .map(|arr| {
986            arr.iter()
987                .filter_map(|item| {
988                    Some(crate::providers::WebSearchResultItem {
989                        title: item["title"].as_str().map(String::from),
990                        url: item["url"].as_str()?.to_string(),
991                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
992                        snippet: item["snippet"].as_str().map(String::from),
993                    })
994                })
995                .collect()
996        })
997        .unwrap_or_default();
998
999    crate::providers::WebSearchContent { results }
1000}
1001
// Unit tests for the SSE framing helpers and for stream finalization.
#[cfg(test)]
mod tests {
    use super::*;

    // Two CRLF-delimited frames are returned one at a time, in order, and the
    // buffer is fully consumed afterwards.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    // A final frame with no blank-line terminator is still returned, with its
    // trailing CRLF stripped, and the buffer is emptied.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    // The `data:` payload is extracted whether or not a space follows the
    // colon, and the trailing `\r` is not included.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    // Finalizing keeps text that had already been accumulated, along with the
    // given stop reason.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}