Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13    StreamEvent, Usage,
14};
15
16pub struct AnthropicProvider {
17    api_key: String,
18    model: String,
19    base_url: String,
20    client: reqwest::Client,
21    /// Extra headers from config
22    extra_headers: Vec<(String, String)>,
23}
24
25impl AnthropicProvider {
26    pub fn new(api_key: String, model: String, base_url: String) -> Self {
27        Self::with_headers(api_key, model, base_url, None)
28    }
29
30    pub fn with_headers(
31        api_key: String,
32        model: String,
33        base_url: String,
34        extra_headers: Option<std::collections::HashMap<String, String>>,
35    ) -> Self {
36        // Create client with better timeout handling for streaming
37        // - No total timeout (streaming responses can take a long time)
38        // - Connect timeout: 10 seconds
39        // - Read timeout per chunk: 60 seconds (for slow responses between chunks)
40        let client = reqwest::Client::builder()
41            .connect_timeout(std::time::Duration::from_secs(10))
42            .read_timeout(std::time::Duration::from_secs(60))
43            .timeout(std::time::Duration::from_secs(300)) // Total timeout for non-streaming
44            .build()
45            .unwrap_or_else(|_| reqwest::Client::new());
46        let extra_headers: Vec<(String, String)> = extra_headers
47            .map(|h| h.into_iter().collect())
48            .unwrap_or_default();
49        Self {
50            api_key,
51            model,
52            base_url,
53            client,
54            extra_headers,
55        }
56    }
57
58    /// Check if this is the official Anthropic API endpoint.
59    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
60    fn is_official_anthropic(&self) -> bool {
61        self.base_url.contains("api.anthropic.com")
62    }
63
64    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
65        // For non-official APIs (like glm-5/DashScope), filter out thinking blocks
66        // to prevent API from entering extended thinking mode and hanging
67        // This is necessary because some APIs don't properly support thinking blocks
68        let filter_thinking = !self.is_official_anthropic();
69        log::debug!("convert_messages: filter_thinking={}, base_url={}", filter_thinking, self.base_url);
70
71        // Count thinking blocks in original messages
72        let mut thinking_count = 0;
73        for m in messages {
74            if let MessageContent::Blocks(blocks) = &m.content {
75                for b in blocks {
76                    if matches!(b, ContentBlock::Thinking { .. }) {
77                        thinking_count += 1;
78                    }
79                }
80            }
81        }
82        if thinking_count > 0 {
83            log::debug!("convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}", thinking_count, messages.len(), filter_thinking);
84        }
85
86        messages
87            .iter()
88            .filter(|m| m.role != Role::System)
89            .map(|m| {
90                let role = match m.role {
91                    Role::User | Role::Tool => "user",
92                    Role::Assistant => "assistant",
93                    Role::System => unreachable!(),
94                };
95
96                let content = match &m.content {
97                    MessageContent::Text(text) => json!(text),
98                    MessageContent::Blocks(blocks) => {
99                        let converted: Vec<Value> = blocks
100                            .iter()
101                            .filter(|b| {
102                                // Skip thinking blocks for non-official APIs
103                                if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
104                                    return false;
105                                }
106                                true
107                            })
108                            .map(|b| match b {
109                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
110                                ContentBlock::ToolUse { id, name, input } => {
111                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
112                                }
113                                ContentBlock::ToolResult { tool_use_id, content } => {
114                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
115                                }
116                                ContentBlock::Thinking { thinking, signature } => {
117                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
118                                    // Only send non-empty signature
119                                    if let Some(sig) = signature {
120                                        if !sig.is_empty() {
121                                            obj["signature"] = json!(sig);
122                                        }
123                                    }
124                                    obj
125                                }
126                                ContentBlock::ServerToolUse { id, name, input } => {
127                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
128                                }
129                                ContentBlock::WebSearchResult { tool_use_id, content } => {
130                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
131                                }
132                            })
133                            .collect();
134
135                        if converted.is_empty() {
136                            json!([])
137                        } else {
138                            json!(converted)
139                        }
140                    }
141                };
142
143                // Skip messages with empty content
144                if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
145                    return json!(null);
146                }
147
148                json!({"role": role, "content": content})
149            })
150            .filter(|v| !v.is_null())
151            .collect()
152    }
153
154    /// Convert tools with caching control for Anthropic prompt caching.
155    fn convert_tools_with_caching(
156        &self,
157        tools: &[ToolDefinition],
158        enable_caching: bool,
159    ) -> Vec<Value> {
160        let mut converted: Vec<Value> = tools
161            .iter()
162            .map(|t| {
163                json!({
164                    "name": t.name,
165                    "description": t.description,
166                    "input_schema": t.parameters,
167                })
168            })
169            .collect();
170
171        // Add cache_control to the last tool definition for tools caching
172        if enable_caching && !converted.is_empty() {
173            let last_idx = converted.len() - 1;
174            if let Some(obj) = converted[last_idx].as_object_mut() {
175                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
176            }
177        }
178
179        converted
180    }
181
182    /// Build the base JSON body shared by streaming and non-streaming requests.
183    fn build_body(&self, request: &ChatRequest) -> Value {
184        let mut body = json!({
185            "model": self.model,
186            "max_tokens": request.max_tokens,
187            "messages": self.convert_messages(&request.messages),
188        });
189
190        // Add prompt caching for system prompt (Anthropic-specific)
191        if request.enable_caching {
192            if let Some(system) = &request.system {
193                // System prompt caching: add cache_control to enable caching
194                body["system"] = json!([
195                    {
196                        "type": "text",
197                        "text": system,
198                        "cache_control": {"type": "ephemeral"}
199                    }
200                ]);
201            }
202        } else if let Some(system) = &request.system {
203            body["system"] = json!(system);
204        }
205
206        if !request.tools.is_empty() {
207            let tools = self.convert_tools_with_caching(
208                &request.tools,
209                request.enable_caching,
210            );
211            body["tools"] = json!(tools);
212        }
213
214        if !request.server_tools.is_empty() {
215            body["tools"] = json!(
216                body["tools"]
217                    .as_array()
218                    .map(|t| {
219                        let mut tools = t.clone();
220                        for st in &request.server_tools {
221                            tools.push(serde_json::to_value(st).unwrap_or_default());
222                        }
223                        tools
224                    })
225                    .unwrap_or_else(|| request
226                        .server_tools
227                        .iter()
228                        .map(|st| serde_json::to_value(st).unwrap_or_default())
229                        .collect())
230            );
231        }
232
233        // Extended thinking (Anthropic-specific)
234        // Only enable thinking for official Anthropic API - third-party APIs like DashScope
235        // forcibly enable thinking mode which causes hanging issues
236        if request.think && self.is_official_anthropic() {
237            let config = thinking_config(&self.model);
238            log::debug!(
239                "Adding thinking config for model {}: {:?}",
240                self.model,
241                config
242            );
243            body["thinking"] = config;
244        } else if request.think {
245            log::debug!(
246                "Skipping thinking config for non-official API (model: {}, base_url: {})",
247                self.model,
248                self.base_url
249            );
250        }
251
252        body
253    }
254}
255
256/// Models that require the new `adaptive` thinking mode instead of the
257/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
258/// don't recognize the name we default to the legacy shape (which older
259/// models and most third-party gateways understand).
260fn thinking_config(model: &str) -> Value {
261    let m = model.to_lowercase();
262    // New models (2025+) use adaptive thinking
263    let adaptive = m.contains("opus-4")
264        || m.contains("sonnet-4")
265        || m.contains("claude-4")
266        || m.contains("20250")
267        || m.contains("2025");
268    if adaptive {
269        json!({"type": "enabled", "budget_tokens": 10000})
270    } else {
271         // to prevent hanging on long histories
272        json!({"type": "enabled", "budget_tokens": 5000})
273    }
274}
275
276#[async_trait]
277impl Provider for AnthropicProvider {
278    fn context_size(&self) -> Option<u32> {
279        context_window_for(&self.model)
280    }
281
282    fn model_name(&self) -> &str {
283        &self.model
284    }
285
286    fn clone_box(&self) -> Box<dyn Provider> {
287        Box::new(Self {
288            api_key: self.api_key.clone(),
289            model: self.model.clone(),
290            base_url: self.base_url.clone(),
291            client: reqwest::Client::new(),
292            extra_headers: self.extra_headers.clone(),
293        })
294    }
295
296    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
297        let body = self.build_body(&request);
298
299        let url = format!("{}/v1/messages", self.base_url);
300
301        // Debug: log request
302        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
303
304        let mut req = self
305            .client
306            .post(&url)
307            .header("User-Agent", "curl/8.0")
308            .json(&body);
309
310        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
311        if self.is_official_anthropic() {
312            req = req
313                .header("x-api-key", &self.api_key)
314                .header("anthropic-version", "2025-04-15")
315                .header("anthropic-beta", "prompt-caching-2024-07-31");
316        } else {
317            req = req
318                .header("Authorization", format!("Bearer {}", self.api_key))
319                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
320                .header("anthropic-version", "2023-06-01")
321                // Try enabling prompt caching for third-party APIs (DashScope may support this)
322                .header("anthropic-beta", "prompt-caching-2024-07-31");
323        }
324
325        // Add extra headers from config (all custom headers go here)
326        for (name, value) in &self.extra_headers {
327            req = req.header(name, value);
328        }
329
330        let response = req.send().await?;
331
332        let status = response.status();
333        let response_body: Value = response.json().await?;
334
335        // Debug: log response
336        crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
337
338        if !status.is_success() {
339            let err_msg = response_body["error"]["message"]
340                .as_str()
341                .unwrap_or("unknown error");
342            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
343        }
344
345        let stop_reason = match response_body["stop_reason"].as_str() {
346            Some("tool_use") => StopReason::ToolUse,
347            Some("max_tokens") => StopReason::MaxTokens,
348            _ => StopReason::EndTurn,
349        };
350
351        let content = response_body["content"]
352            .as_array()
353            .unwrap_or(&vec![])
354            .iter()
355            .filter_map(|block| match block["type"].as_str()? {
356                "text" => Some(ContentBlock::Text {
357                    text: block["text"].as_str()?.to_string(),
358                }),
359                "tool_use" => Some(ContentBlock::ToolUse {
360                    id: block["id"].as_str()?.to_string(),
361                    name: block["name"].as_str()?.to_string(),
362                    input: block["input"].clone(),
363                }),
364                "thinking" => Some(ContentBlock::Thinking {
365                    thinking: block["thinking"].as_str()?.to_string(),
366                    signature: block["signature"].as_str().map(String::from),
367                }),
368                "server_tool_use" => Some(ContentBlock::ServerToolUse {
369                    id: block["id"].as_str()?.to_string(),
370                    name: block["name"].as_str()?.to_string(),
371                    input: block["input"].clone(),
372                }),
373                "web_search_tool_result" => {
374                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
375                    let content = parse_web_search_content(&block["content"]);
376                    Some(ContentBlock::WebSearchResult {
377                        tool_use_id,
378                        content,
379                    })
380                }
381                _ => None,
382            })
383            .collect();
384
385        Ok(ChatResponse {
386            content,
387            stop_reason,
388            usage: parse_usage(&response_body["usage"]),
389        })
390    }
391
392    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
393        let mut body = self.build_body(&request);
394        body["stream"] = json!(true);
395
396        let url = format!("{}/v1/messages", self.base_url);
397
398        // Debug: log streaming request
399        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
400
401        let mut req = self
402            .client
403            .post(&url)
404            .header("User-Agent", "curl/8.0")
405            .json(&body);
406
407        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
408        if self.is_official_anthropic() {
409            req = req
410                .header("x-api-key", &self.api_key)
411                .header("anthropic-version", "2025-04-15")
412                .header("anthropic-beta", "prompt-caching-2024-07-31");
413        } else {
414            req = req
415                .header("Authorization", format!("Bearer {}", self.api_key))
416                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
417                .header("anthropic-version", "2023-06-01")
418                // Try enabling prompt caching for third-party APIs (DashScope may support this)
419                .header("anthropic-beta", "prompt-caching-2024-07-31");
420        }
421
422        // Add extra headers from config (all custom headers go here)
423        for (name, value) in &self.extra_headers {
424            req = req.header(name, value);
425        }
426
427        let response = req.send().await?;
428
429        if !response.status().is_success() {
430            let status = response.status();
431            let text = response.text().await.unwrap_or_default();
432            anyhow::bail!("Anthropic API error ({}): {}", status, text);
433        }
434
435        let (tx, rx) = mpsc::channel(64);
436        tokio::spawn(async move {
437            let mut stream = response.bytes_stream();
438            let mut buffer = String::new();
439            let mut sent_first_byte = false;
440
441            // In-flight block assembly: index → partial data
442            let mut blocks: Vec<AssembledBlock> = Vec::new();
443            let mut stop_reason = StopReason::EndTurn;
444            let mut usage = Usage::default();
445
446            // Timeout detection: track last meaningful event (non-ping)
447            let mut last_content_time = std::time::Instant::now();
448            const CONTENT_TIMEOUT_SECS: u64 = 300; // 5 minutes without content = timeout (for slow APIs like DashScope/glm-5)
449
450            while let Some(chunk) = stream.next().await {
451                let chunk = match chunk {
452                    Ok(c) => c,
453                    Err(e) => {
454                        // Log detailed error for debugging
455                        let error_msg = e.to_string();
456                        let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
457                        let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
458
459                        let msg = if is_timeout {
460                            format!("Stream timeout - the API took too long to respond: {}", error_msg)
461                        } else if is_decode {
462                            format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
463                        } else {
464                            format!("Stream read error: {}", error_msg)
465                        };
466
467                        // Try to finalize any partial content we have
468                        if sent_first_byte && !blocks.is_empty() {
469                            debug!("finalizing partial stream due to error");
470                            let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
471                                std::mem::take(&mut blocks),
472                                stop_reason,
473                                usage,
474                            ))).await;
475                        } else {
476                            let _ = tx.send(StreamEvent::Error(msg)).await;
477                        }
478                        return;
479                    }
480                };
481
482                if !sent_first_byte {
483                    sent_first_byte = true;
484                    let _ = tx.send(StreamEvent::FirstByte).await;
485                }
486
487                buffer.push_str(&String::from_utf8_lossy(&chunk));
488
489                // Check for timeout: if only ping events for too long, force finalize
490                let elapsed = last_content_time.elapsed().as_secs();
491                if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
492                    crate::debug::debug_log().stream_chunk("TIMEOUT_FORCE_FINALIZE",
493                        &format!("elapsed={}s, blocks={}", elapsed, blocks.len()));
494                    let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
495                        std::mem::take(&mut blocks),
496                        stop_reason,
497                        usage,
498                    ))).await;
499                    return;
500                }
501
502                while let Some(frame) = take_next_sse_frame(&mut buffer) {
503                    if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time)
504                        .await
505                    {
506                        return;
507                    }
508                }
509            }
510
511            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
512                && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time).await
513            {
514                return;
515            }
516
517            if sent_first_byte {
518                debug!("stream ended without explicit message_stop; finalizing best-effort");
519                let _ = tx
520                    .send(StreamEvent::Done(finalize_incomplete_stream(
521                        std::mem::take(&mut blocks),
522                        stop_reason,
523                        usage,
524                    )))
525                    .await;
526            } else {
527                let _ = tx
528                    .send(StreamEvent::Error(
529                        "stream ended before any events were received".to_string(),
530                    ))
531                    .await;
532            }
533        });
534
535        Ok(rx)
536    }
537}
538
/// Remove and return the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank line, written either as `"\n\n"` or
/// `"\r\n\r\n"`. The terminator is consumed but not included in the result.
/// Returns `None`, leaving `buffer` untouched, while no terminator has
/// arrived yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (end, terminator_len) = ["\n\n", "\r\n\r\n"]
        .iter()
        .filter_map(|sep| buffer.find(sep).map(|at| (at, sep.len())))
        .min_by_key(|&(at, _)| at)?;

    let frame = buffer[..end].to_owned();
    buffer.replace_range(..end + terminator_len, "");
    Some(frame)
}
559
/// Empty `buffer` after the byte stream has ended, returning its remaining
/// text as a final frame.
///
/// The text is trimmed of surrounding whitespace, including `\r`. Returns
/// `None` if only whitespace was left. `buffer` is always left empty.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let remainder = std::mem::take(buffer);
    let trimmed = remainder.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
565
/// Return the payload of the `data` field(s) of one SSE frame, or `None` if
/// the frame has no `data` line (e.g. a comment or an `event:`-only frame).
///
/// Follows the SSE field rules:
/// - One optional space after `data:` is stripped, so both `data: {...}`
///   (Anthropic) and `data:{...}` (DashScope) work.
/// - A bare `data` line contributes an empty value.
/// - When a frame has several `data` lines, their values are joined with `\n`
///   instead of keeping only the first.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut data: Option<String> = None;
    for line in frame.lines() {
        // `lines()` may leave a bare trailing '\r'; strip it.
        let line = line.trim_end_matches('\r');
        let value = if line == "data" {
            ""
        } else if let Some(rest) = line.strip_prefix("data:") {
            rest.strip_prefix(' ').unwrap_or(rest)
        } else {
            continue;
        };
        if let Some(acc) = data.as_mut() {
            acc.push('\n');
            acc.push_str(value);
        } else {
            data = Some(value.to_owned());
        }
    }
    data
}
579
580async fn handle_sse_frame(
581    frame: &str,
582    blocks: &mut Vec<AssembledBlock>,
583    stop_reason: &mut StopReason,
584    usage: &mut Usage,
585    tx: &mpsc::Sender<StreamEvent>,
586    last_content_time: &mut std::time::Instant,
587) -> bool {
588    let Some(data_line) = extract_sse_data_line(frame) else {
589        return false;
590    };
591
592    let evt: Value = match serde_json::from_str(&data_line) {
593        Ok(v) => v,
594        Err(_) => return false,
595    };
596
597    handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
598}
599
600async fn handle_sse_event(
601    evt: Value,
602    blocks: &mut Vec<AssembledBlock>,
603    stop_reason: &mut StopReason,
604    usage: &mut Usage,
605    tx: &mpsc::Sender<StreamEvent>,
606    last_content_time: &mut std::time::Instant,
607) -> bool {
608    let evt_type = evt["type"].as_str().unwrap_or("");
609
610    // Debug: log all SSE events for diagnosis (with full content for debugging)
611    let evt_json = serde_json::to_string(&evt).unwrap_or_default();
612    crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
613
614    // Log event handling for thinking_delta specifically
615    if evt_type == "content_block_delta" {
616        let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
617        let idx = evt["index"].as_u64().unwrap_or(0) as usize;
618        log::debug!(
619            "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
620            delta_type,
621            idx,
622            blocks.len(),
623            idx < blocks.len()
624        );
625    }
626
627    // Update last_content_time for non-ping events
628    if evt_type != "ping" {
629        *last_content_time = std::time::Instant::now();
630    }
631
632    match evt_type {
633        "message_start" => {
634            // Initial usage payload — `input_tokens` is final
635            // (they don't grow during streaming) but
636            // `output_tokens` starts near 0 and is updated by
637            // subsequent `message_delta` events.
638            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
639            debug!(
640                "message_start: usage_json={}",
641                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
642            );
643            debug!(
644                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
645                usage.input_tokens,
646                usage.output_tokens,
647                usage.cache_read_input_tokens,
648                usage.cache_creation_input_tokens
649            );
650            // Send real-time usage update
651            let _ = tx
652                .send(StreamEvent::Usage {
653                    output_tokens: usage.output_tokens,
654                })
655                .await;
656        }
657        "content_block_start" => {
658            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
659            let block = &evt["content_block"];
660            let kind = block["type"].as_str().unwrap_or("");
661            while blocks.len() <= idx {
662                blocks.push(AssembledBlock::default());
663            }
664            match kind {
665                "text" => {
666                    blocks[idx] = AssembledBlock::Text(String::new());
667                }
668                "thinking" => {
669                    blocks[idx] = AssembledBlock::Thinking {
670                        text: String::new(),
671                        signature: None,
672                    };
673                }
674                "tool_use" | "server_tool_use" => {
675                    let id = block["id"].as_str().unwrap_or_default();
676                    let name = block["name"].as_str().unwrap_or_default();
677                    let is_server = kind == "server_tool_use";
678                    blocks[idx] = if is_server {
679                        AssembledBlock::ServerToolUse {
680                            id: id.into(),
681                            name: name.into(),
682                            input_json: String::new(),
683                        }
684                    } else {
685                        AssembledBlock::ToolUse {
686                            id: id.into(),
687                            name: name.into(),
688                            input_json: String::new(),
689                        }
690                    };
691                    let _ = tx
692                        .send(StreamEvent::ToolUseStart {
693                            id: id.into(),
694                            name: name.into(),
695                        })
696                        .await;
697                }
698                "web_search_tool_result" => {
699                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
700                    blocks[idx] = AssembledBlock::WebSearchResult {
701                        tool_use_id,
702                        content_json: String::new(),
703                    };
704                }
705                _ => {}
706            }
707        }
708        "content_block_delta" => {
709            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
710            let delta = &evt["delta"];
711            let dt = delta["type"].as_str().unwrap_or("");
712            if idx >= blocks.len() {
713                return false;
714            }
715            match (dt, &mut blocks[idx]) {
716                ("text_delta", AssembledBlock::Text(buf)) => {
717                    if let Some(t) = delta["text"].as_str() {
718                        buf.push_str(t);
719                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
720                    }
721                }
722                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
723                    if let Some(t) = delta["thinking"].as_str() {
724                        text.push_str(t);
725                        log::debug!("Received thinking_delta: {} chars", t.len());
726                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
727                    }
728                }
729                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
730                    if let Some(s) = delta["signature"].as_str() {
731                        signature.get_or_insert_with(String::new).push_str(s);
732                    }
733                }
734                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
735                    if let Some(p) = delta["partial_json"].as_str() {
736                        input_json.push_str(p);
737                        let _ = tx
738                            .send(StreamEvent::ToolInputDelta {
739                                bytes_so_far: input_json.len(),
740                            })
741                            .await;
742                    }
743                }
744                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
745                    if let Some(p) = delta["partial_json"].as_str() {
746                        input_json.push_str(p);
747                        let _ = tx
748                            .send(StreamEvent::ToolInputDelta {
749                                bytes_so_far: input_json.len(),
750                            })
751                            .await;
752                    }
753                }
754                _ => {}
755            }
756        }
757        "message_delta" => {
758            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
759                *stop_reason = match sr {
760                    "tool_use" => StopReason::ToolUse,
761                    "max_tokens" => StopReason::MaxTokens,
762                    _ => StopReason::EndTurn,
763                };
764            }
765            // `usage` on message_delta reflects cumulative
766            // counts for the current message — most notably
767            // the final `output_tokens`.
768            *usage = merge_usage(usage.clone(), &evt["usage"]);
769            debug!(
770                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
771                usage.input_tokens,
772                usage.output_tokens,
773                usage.cache_read_input_tokens,
774                usage.cache_creation_input_tokens
775            );
776            // Send real-time usage update
777            let _ = tx
778                .send(StreamEvent::Usage {
779                    output_tokens: usage.output_tokens,
780                })
781                .await;
782        }
783        "message_stop" => {
784            debug!(
785                "Message completed: stop_reason={}, usage={}",
786                match *stop_reason {
787                    StopReason::EndTurn => "end_turn",
788                    StopReason::ToolUse => "tool_use",
789                    StopReason::MaxTokens => "max_tokens",
790                },
791                usage.output_tokens
792            );
793            debug!(
794                "message_stop final usage: cache_read={}, cache_created={}",
795                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
796            );
797            let _ = tx
798                .send(StreamEvent::Done(finalize_incomplete_stream(
799                    std::mem::take(blocks),
800                    stop_reason.clone(),
801                    usage.clone(),
802                )))
803                .await;
804            return true;
805        }
806        "error" => {
807            let msg = evt["error"]["message"]
808                .as_str()
809                .unwrap_or("unknown stream error")
810                .to_string();
811            let _ = tx.send(StreamEvent::Error(msg)).await;
812            return true;
813        }
814        _ => {}
815    }
816
817    false
818}
819
820fn build_chat_response(
821    blocks: Vec<AssembledBlock>,
822    stop_reason: StopReason,
823    usage: Usage,
824) -> ChatResponse {
825    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
826    ChatResponse {
827        content,
828        stop_reason,
829        usage,
830    }
831}
832
/// Build the final [`ChatResponse`] from whatever blocks have been accumulated.
///
/// Currently a direct delegation to [`build_chat_response`]; accumulated
/// content is kept as-is. Despite the name, this is also called on the normal
/// `message_stop` path, not only for streams that ended prematurely.
fn finalize_incomplete_stream(
    blocks: Vec<AssembledBlock>,
    stop_reason: StopReason,
    usage: Usage,
) -> ChatResponse {
    build_chat_response(blocks, stop_reason, usage)
}
840
/// A response content block being accumulated from streaming deltas.
///
/// Each variant holds the raw pieces received so far. `AssembledBlock::finish`
/// converts it into a final `ContentBlock` once the stream completes.
#[derive(Debug, Default)]
enum AssembledBlock {
    /// Placeholder carrying no content; it is the `Default` value and produces
    /// no output block when finished.
    #[default]
    Empty,
    /// Assistant text, concatenated from `text_delta` fragments.
    Text(String),
    /// Thinking text (from `thinking_delta`) plus its signature, which is
    /// concatenated from `signature_delta` fragments when present.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Tool call whose input JSON is concatenated from `input_json_delta`
    /// fragments and parsed only when the block is finished.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server tool call; its input JSON is accumulated from
    /// `input_json_delta` fragments exactly like `ToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search results for `tool_use_id`, kept as raw JSON text until the
    /// block is finished.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
865
866impl AssembledBlock {
867    fn finish(self) -> Option<ContentBlock> {
868        match self {
869            AssembledBlock::Empty => None,
870            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
871            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
872                thinking: text,
873                signature,
874            }),
875            AssembledBlock::ToolUse {
876                id,
877                name,
878                input_json,
879            } => {
880                let input: Value = if input_json.is_empty() {
881                    json!({})
882                } else {
883                    serde_json::from_str(&input_json).unwrap_or(json!({}))
884                };
885                Some(ContentBlock::ToolUse { id, name, input })
886            }
887            AssembledBlock::ServerToolUse {
888                id,
889                name,
890                input_json,
891            } => {
892                let input: Value = if input_json.is_empty() {
893                    json!({})
894                } else {
895                    serde_json::from_str(&input_json).unwrap_or(json!({}))
896                };
897                Some(ContentBlock::ServerToolUse { id, name, input })
898            }
899            AssembledBlock::WebSearchResult {
900                tool_use_id,
901                content_json,
902            } => {
903                let content: Value = if content_json.is_empty() {
904                    json!({"results": []})
905                } else {
906                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
907                };
908                Some(ContentBlock::WebSearchResult {
909                    tool_use_id,
910                    content: parse_web_search_content(&content),
911                })
912            }
913        }
914    }
915}
916
917/// Parse the provider's `usage` blob (non-streaming response) into our
918/// internal `Usage` struct. Missing fields default to 0.
919fn parse_usage(u: &Value) -> Usage {
920    Usage {
921        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
922        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
923        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
924        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
925    }
926}
927
928/// Merge a fresh usage payload into the accumulated one. Non-zero new values
929/// override prior ones — this matches the streaming protocol where
930/// `message_start` gives input counts and `message_delta` updates output.
931fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
932    let fresh = parse_usage(u);
933    if fresh.input_tokens > 0 {
934        acc.input_tokens = fresh.input_tokens;
935    }
936    if fresh.output_tokens > 0 {
937        acc.output_tokens = fresh.output_tokens;
938    }
939    if fresh.cache_creation_input_tokens > 0 {
940        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
941    }
942    if fresh.cache_read_input_tokens > 0 {
943        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
944    }
945    acc
946}
947
948/// Parse web search content from the API response.
949fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
950    let results = value["results"]
951        .as_array()
952        .map(|arr| {
953            arr.iter()
954                .filter_map(|item| {
955                    Some(crate::providers::WebSearchResultItem {
956                        title: item["title"].as_str().map(String::from),
957                        url: item["url"].as_str()?.to_string(),
958                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
959                        snippet: item["snippet"].as_str().map(String::from),
960                    })
961                })
962                .collect()
963        })
964        .unwrap_or_default();
965
966    crate::providers::WebSearchContent { results }
967}
968
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames separated by a blank CRLF line (`\r\n\r\n`) are split off one
    /// at a time, and the buffer is empty after the last frame is taken.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    /// A final frame with no blank-line terminator is still returned. Its
    /// trailing `\r\n` is stripped and the buffer is drained.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    /// `data:` lines are recognized with or without a space after the colon.
    /// A trailing `\r` is not included in the returned payload.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    /// Finalizing keeps content gathered so far. A partial text block becomes
    /// a single `ContentBlock::Text`, and the given stop reason is preserved.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}