Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13    StreamEvent, Usage,
14};
15
/// Provider that talks to an Anthropic-style Messages endpoint
/// (`{base_url}/v1/messages`).
pub struct AnthropicProvider {
    /// Credential sent either as `x-api-key` or as a Bearer token,
    /// depending on the endpoint.
    api_key: String,
    /// Model identifier placed in the request body's `model` field.
    model: String,
    /// Endpoint root; `/v1/messages` is appended when building request URLs.
    base_url: String,
    /// HTTP client used to send requests.
    client: reqwest::Client,
    /// Extra headers from config
    extra_headers: Vec<(String, String)>,
}
24
25impl AnthropicProvider {
    /// Creates a provider without any extra headers.
    ///
    /// Equivalent to calling [`Self::with_headers`] with `None`.
    pub fn new(api_key: String, model: String, base_url: String) -> Self {
        Self::with_headers(api_key, model, base_url, None)
    }
29
30    pub fn with_headers(
31        api_key: String,
32        model: String,
33        base_url: String,
34        extra_headers: Option<std::collections::HashMap<String, String>>,
35    ) -> Self {
36        // Create client with better timeout handling for streaming
37        // - No total timeout (streaming responses can take a long time)
38        // - Connect timeout: 10 seconds
39        // - Read timeout per chunk: 60 seconds (for slow responses between chunks)
40        let client = reqwest::Client::builder()
41            .connect_timeout(std::time::Duration::from_secs(10))
42            .read_timeout(std::time::Duration::from_secs(60))
43            .timeout(std::time::Duration::from_secs(300)) // Total timeout for non-streaming
44            .build()
45            .unwrap_or_else(|_| reqwest::Client::new());
46        let extra_headers: Vec<(String, String)> = extra_headers
47            .map(|h| h.into_iter().collect())
48            .unwrap_or_default();
49        Self {
50            api_key,
51            model,
52            base_url,
53            client,
54            extra_headers,
55        }
56    }
57
58    /// Check if this is the official Anthropic API endpoint.
59    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
60    fn is_official_anthropic(&self) -> bool {
61        self.base_url.contains("api.anthropic.com")
62    }
63
64    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
65        // For non-official APIs (like glm-5/DashScope), filter out thinking blocks
66        // to prevent API from entering extended thinking mode and hanging
67        // This is necessary because some APIs don't properly support thinking blocks
68        let filter_thinking = !self.is_official_anthropic();
69        log::debug!("convert_messages: filter_thinking={}, base_url={}", filter_thinking, self.base_url);
70
71        // Count thinking blocks in original messages
72        let mut thinking_count = 0;
73        for m in messages {
74            if let MessageContent::Blocks(blocks) = &m.content {
75                for b in blocks {
76                    if matches!(b, ContentBlock::Thinking { .. }) {
77                        thinking_count += 1;
78                    }
79                }
80            }
81        }
82        if thinking_count > 0 {
83            log::debug!("convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}", thinking_count, messages.len(), filter_thinking);
84        }
85
86        messages
87            .iter()
88            .filter(|m| m.role != Role::System)
89            .map(|m| {
90                let role = match m.role {
91                    Role::User | Role::Tool => "user",
92                    Role::Assistant => "assistant",
93                    Role::System => unreachable!(),
94                };
95
96                let content = match &m.content {
97                    MessageContent::Text(text) => json!(text),
98                    MessageContent::Blocks(blocks) => {
99                        let converted: Vec<Value> = blocks
100                            .iter()
101                            .filter(|b| {
102                                // Skip thinking blocks for non-official APIs
103                                if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
104                                    return false;
105                                }
106                                true
107                            })
108                            .map(|b| match b {
109                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
110                                ContentBlock::ToolUse { id, name, input } => {
111                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
112                                }
113                                ContentBlock::ToolResult { tool_use_id, content } => {
114                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
115                                }
116                                ContentBlock::Thinking { thinking, signature } => {
117                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
118                                    // Only send non-empty signature
119                                    if let Some(sig) = signature {
120                                        if !sig.is_empty() {
121                                            obj["signature"] = json!(sig);
122                                        }
123                                    }
124                                    obj
125                                }
126                                ContentBlock::ServerToolUse { id, name, input } => {
127                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
128                                }
129                                ContentBlock::WebSearchResult { tool_use_id, content } => {
130                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
131                                }
132                            })
133                            .collect();
134
135                        if converted.is_empty() {
136                            json!([])
137                        } else {
138                            json!(converted)
139                        }
140                    }
141                };
142
143                // Skip messages with empty content
144                if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
145                    return json!(null);
146                }
147
148                json!({"role": role, "content": content})
149            })
150            .filter(|v| !v.is_null())
151            .collect()
152    }
153
154    /// Convert tools with caching control for Anthropic prompt caching.
155    fn convert_tools_with_caching(
156        &self,
157        tools: &[ToolDefinition],
158        enable_caching: bool,
159    ) -> Vec<Value> {
160        let mut converted: Vec<Value> = tools
161            .iter()
162            .map(|t| {
163                json!({
164                    "name": t.name,
165                    "description": t.description,
166                    "input_schema": t.parameters,
167                })
168            })
169            .collect();
170
171        // Add cache_control to the last tool definition for tools caching
172        if enable_caching && !converted.is_empty() {
173            let last_idx = converted.len() - 1;
174            if let Some(obj) = converted[last_idx].as_object_mut() {
175                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
176            }
177        }
178
179        converted
180    }
181
182    /// Build the base JSON body shared by streaming and non-streaming requests.
183    fn build_body(&self, request: &ChatRequest) -> Value {
184        let mut body = json!({
185            "model": self.model,
186            "max_tokens": request.max_tokens,
187            "messages": self.convert_messages(&request.messages),
188        });
189
190        // Add prompt caching for system prompt (Anthropic-specific)
191        if request.enable_caching {
192            if let Some(system) = &request.system {
193                // System prompt caching: add cache_control to enable caching
194                body["system"] = json!([
195                    {
196                        "type": "text",
197                        "text": system,
198                        "cache_control": {"type": "ephemeral"}
199                    }
200                ]);
201            }
202        } else if let Some(system) = &request.system {
203            body["system"] = json!(system);
204        }
205
206        if !request.tools.is_empty() {
207            let tools = self.convert_tools_with_caching(
208                &request.tools,
209                request.enable_caching,
210            );
211            body["tools"] = json!(tools);
212        }
213
214        if !request.server_tools.is_empty() {
215            body["tools"] = json!(
216                body["tools"]
217                    .as_array()
218                    .map(|t| {
219                        let mut tools = t.clone();
220                        for st in &request.server_tools {
221                            tools.push(serde_json::to_value(st).unwrap_or_default());
222                        }
223                        tools
224                    })
225                    .unwrap_or_else(|| request
226                        .server_tools
227                        .iter()
228                        .map(|st| serde_json::to_value(st).unwrap_or_default())
229                        .collect())
230            );
231        }
232
233        // Extended thinking (Anthropic-specific)
234        // Only enable thinking for official Anthropic API - third-party APIs like DashScope
235        // forcibly enable thinking mode which causes hanging issues
236        if request.think && self.is_official_anthropic() {
237            let config = thinking_config(&self.model);
238            log::debug!(
239                "Adding thinking config for model {}: {:?}",
240                self.model,
241                config
242            );
243            body["thinking"] = config;
244        } else if request.think {
245            log::debug!(
246                "Skipping thinking config for non-official API (model: {}, base_url: {})",
247                self.model,
248                self.base_url
249            );
250        }
251
252        body
253    }
254}
255
256/// Models that require the new `adaptive` thinking mode instead of the
257/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
258/// don't recognize the name we default to the legacy shape (which older
259/// models and most third-party gateways understand).
260fn thinking_config(model: &str) -> Value {
261    let m = model.to_lowercase();
262    // New models (2025+) use adaptive thinking
263    let adaptive = m.contains("opus-4")
264        || m.contains("sonnet-4")
265        || m.contains("claude-4")
266        || m.contains("20250")
267        || m.contains("2025");
268    if adaptive {
269        json!({"type": "enabled", "budget_tokens": 10000})
270    } else {
271         // to prevent hanging on long histories
272        json!({"type": "enabled", "budget_tokens": 5000})
273    }
274}
275
276#[async_trait]
277impl Provider for AnthropicProvider {
    /// Returns the context window for the configured model, as reported by
    /// `context_window_for`; `None` when that lookup yields nothing.
    fn context_size(&self) -> Option<u32> {
        context_window_for(&self.model)
    }
281
282    fn clone_box(&self) -> Box<dyn Provider> {
283        Box::new(Self {
284            api_key: self.api_key.clone(),
285            model: self.model.clone(),
286            base_url: self.base_url.clone(),
287            client: reqwest::Client::new(),
288            extra_headers: self.extra_headers.clone(),
289        })
290    }
291
292    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
293        let body = self.build_body(&request);
294
295        let url = format!("{}/v1/messages", self.base_url);
296
297        // Debug: log request
298        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
299
300        let mut req = self
301            .client
302            .post(&url)
303            .header("User-Agent", "curl/8.0")
304            .json(&body);
305
306        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
307        if self.is_official_anthropic() {
308            req = req
309                .header("x-api-key", &self.api_key)
310                .header("anthropic-version", "2025-04-15")
311                .header("anthropic-beta", "prompt-caching-2024-07-31");
312        } else {
313            req = req
314                .header("Authorization", format!("Bearer {}", self.api_key))
315                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
316                .header("anthropic-version", "2023-06-01")
317                // Try enabling prompt caching for third-party APIs (DashScope may support this)
318                .header("anthropic-beta", "prompt-caching-2024-07-31");
319        }
320
321        // Add extra headers from config (all custom headers go here)
322        for (name, value) in &self.extra_headers {
323            req = req.header(name, value);
324        }
325
326        let response = req.send().await?;
327
328        let status = response.status();
329        let response_body: Value = response.json().await?;
330
331        // Debug: log response
332        crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
333
334        if !status.is_success() {
335            let err_msg = response_body["error"]["message"]
336                .as_str()
337                .unwrap_or("unknown error");
338            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
339        }
340
341        let stop_reason = match response_body["stop_reason"].as_str() {
342            Some("tool_use") => StopReason::ToolUse,
343            Some("max_tokens") => StopReason::MaxTokens,
344            _ => StopReason::EndTurn,
345        };
346
347        let content = response_body["content"]
348            .as_array()
349            .unwrap_or(&vec![])
350            .iter()
351            .filter_map(|block| match block["type"].as_str()? {
352                "text" => Some(ContentBlock::Text {
353                    text: block["text"].as_str()?.to_string(),
354                }),
355                "tool_use" => Some(ContentBlock::ToolUse {
356                    id: block["id"].as_str()?.to_string(),
357                    name: block["name"].as_str()?.to_string(),
358                    input: block["input"].clone(),
359                }),
360                "thinking" => Some(ContentBlock::Thinking {
361                    thinking: block["thinking"].as_str()?.to_string(),
362                    signature: block["signature"].as_str().map(String::from),
363                }),
364                "server_tool_use" => Some(ContentBlock::ServerToolUse {
365                    id: block["id"].as_str()?.to_string(),
366                    name: block["name"].as_str()?.to_string(),
367                    input: block["input"].clone(),
368                }),
369                "web_search_tool_result" => {
370                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
371                    let content = parse_web_search_content(&block["content"]);
372                    Some(ContentBlock::WebSearchResult {
373                        tool_use_id,
374                        content,
375                    })
376                }
377                _ => None,
378            })
379            .collect();
380
381        Ok(ChatResponse {
382            content,
383            stop_reason,
384            usage: parse_usage(&response_body["usage"]),
385        })
386    }
387
388    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
389        let mut body = self.build_body(&request);
390        body["stream"] = json!(true);
391
392        let url = format!("{}/v1/messages", self.base_url);
393
394        // Debug: log streaming request
395        crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
396
397        let mut req = self
398            .client
399            .post(&url)
400            .header("User-Agent", "curl/8.0")
401            .json(&body);
402
403        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
404        if self.is_official_anthropic() {
405            req = req
406                .header("x-api-key", &self.api_key)
407                .header("anthropic-version", "2025-04-15")
408                .header("anthropic-beta", "prompt-caching-2024-07-31");
409        } else {
410            req = req
411                .header("Authorization", format!("Bearer {}", self.api_key))
412                // Add anthropic-version for third-party APIs to ensure correct protocol behavior
413                .header("anthropic-version", "2023-06-01")
414                // Try enabling prompt caching for third-party APIs (DashScope may support this)
415                .header("anthropic-beta", "prompt-caching-2024-07-31");
416        }
417
418        // Add extra headers from config (all custom headers go here)
419        for (name, value) in &self.extra_headers {
420            req = req.header(name, value);
421        }
422
423        let response = req.send().await?;
424
425        if !response.status().is_success() {
426            let status = response.status();
427            let text = response.text().await.unwrap_or_default();
428            anyhow::bail!("Anthropic API error ({}): {}", status, text);
429        }
430
431        let (tx, rx) = mpsc::channel(64);
432        tokio::spawn(async move {
433            let mut stream = response.bytes_stream();
434            let mut buffer = String::new();
435            let mut sent_first_byte = false;
436
437            // In-flight block assembly: index → partial data
438            let mut blocks: Vec<AssembledBlock> = Vec::new();
439            let mut stop_reason = StopReason::EndTurn;
440            let mut usage = Usage::default();
441
442            // Timeout detection: track last meaningful event (non-ping)
443            let mut last_content_time = std::time::Instant::now();
444            const CONTENT_TIMEOUT_SECS: u64 = 120; // 120 seconds without content = timeout (increased for slow APIs like DashScope/glm-5)
445
446            while let Some(chunk) = stream.next().await {
447                let chunk = match chunk {
448                    Ok(c) => c,
449                    Err(e) => {
450                        // Log detailed error for debugging
451                        let error_msg = e.to_string();
452                        let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
453                        let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
454
455                        let msg = if is_timeout {
456                            format!("Stream timeout - the API took too long to respond: {}", error_msg)
457                        } else if is_decode {
458                            format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
459                        } else {
460                            format!("Stream read error: {}", error_msg)
461                        };
462
463                        // Try to finalize any partial content we have
464                        if sent_first_byte && !blocks.is_empty() {
465                            debug!("finalizing partial stream due to error");
466                            let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
467                                std::mem::take(&mut blocks),
468                                stop_reason,
469                                usage,
470                            ))).await;
471                        } else {
472                            let _ = tx.send(StreamEvent::Error(msg)).await;
473                        }
474                        return;
475                    }
476                };
477
478                if !sent_first_byte {
479                    sent_first_byte = true;
480                    let _ = tx.send(StreamEvent::FirstByte).await;
481                }
482
483                buffer.push_str(&String::from_utf8_lossy(&chunk));
484
485                // Check for timeout: if only ping events for too long, force finalize
486                let elapsed = last_content_time.elapsed().as_secs();
487                if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
488                    crate::debug::debug_log().stream_chunk("TIMEOUT_FORCE_FINALIZE",
489                        &format!("elapsed={}s, blocks={}", elapsed, blocks.len()));
490                    let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
491                        std::mem::take(&mut blocks),
492                        stop_reason,
493                        usage,
494                    ))).await;
495                    return;
496                }
497
498                while let Some(frame) = take_next_sse_frame(&mut buffer) {
499                    if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time)
500                        .await
501                    {
502                        return;
503                    }
504                }
505            }
506
507            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
508                && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time).await
509            {
510                return;
511            }
512
513            if sent_first_byte {
514                debug!("stream ended without explicit message_stop; finalizing best-effort");
515                let _ = tx
516                    .send(StreamEvent::Done(finalize_incomplete_stream(
517                        std::mem::take(&mut blocks),
518                        stop_reason,
519                        usage,
520                    )))
521                    .await;
522            } else {
523                let _ = tx
524                    .send(StreamEvent::Error(
525                        "stream ended before any events were received".to_string(),
526                    ))
527                    .await;
528            }
529        });
530
531        Ok(rx)
532    }
533}
534
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank-line separator, either `\n\n` or
/// `\r\n\r\n`. The separator is consumed but not returned; everything after
/// it stays in `buffer`. Returns `None` (leaving `buffer` untouched) when no
/// separator is present yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (frame_end, separator_len) = match (buffer.find("\n\n"), buffer.find("\r\n\r\n")) {
        (Some(lf), Some(crlf)) if crlf < lf => (crlf, 4usize),
        (Some(lf), _) => (lf, 2usize),
        (None, Some(crlf)) => (crlf, 4usize),
        (None, None) => return None,
    };

    // Keep the remainder, cut the separator off the frame, then swap so the
    // caller's buffer holds the remainder and the frame is handed back.
    let remainder = buffer.split_off(frame_end + separator_len);
    buffer.truncate(frame_end);
    Some(std::mem::replace(buffer, remainder))
}
555
/// Drains whatever is left in `buffer` and returns it as a final frame.
///
/// Surrounding whitespace (including `\r` and `\n`) is trimmed. `buffer` is
/// always left empty; `None` is returned when nothing but whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_owned())
}
561
/// Extracts the `data` payload of one SSE frame.
///
/// Both `data: value` (Anthropic) and `data:value` (DashScope) are accepted;
/// at most one space after the colon is stripped. When a frame carries
/// several `data:` lines, their values are joined with `\n`, as the
/// server-sent events format prescribes, so a payload split over multiple
/// lines is reassembled instead of being truncated to its first line.
///
/// Returns `None` when the frame contains no `data:` line.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let values: Vec<&str> = frame
        .lines()
        .filter_map(|line| {
            let value = line.trim_end_matches('\r').strip_prefix("data:")?;
            Some(value.strip_prefix(' ').unwrap_or(value))
        })
        .collect();

    if values.is_empty() {
        None
    } else {
        Some(values.join("\n"))
    }
}
575
576async fn handle_sse_frame(
577    frame: &str,
578    blocks: &mut Vec<AssembledBlock>,
579    stop_reason: &mut StopReason,
580    usage: &mut Usage,
581    tx: &mpsc::Sender<StreamEvent>,
582    last_content_time: &mut std::time::Instant,
583) -> bool {
584    let Some(data_line) = extract_sse_data_line(frame) else {
585        return false;
586    };
587
588    let evt: Value = match serde_json::from_str(&data_line) {
589        Ok(v) => v,
590        Err(_) => return false,
591    };
592
593    handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
594}
595
596async fn handle_sse_event(
597    evt: Value,
598    blocks: &mut Vec<AssembledBlock>,
599    stop_reason: &mut StopReason,
600    usage: &mut Usage,
601    tx: &mpsc::Sender<StreamEvent>,
602    last_content_time: &mut std::time::Instant,
603) -> bool {
604    let evt_type = evt["type"].as_str().unwrap_or("");
605
606    // Debug: log all SSE events for diagnosis (with full content for debugging)
607    let evt_json = serde_json::to_string(&evt).unwrap_or_default();
608    crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
609
610    // Log event handling for thinking_delta specifically
611    if evt_type == "content_block_delta" {
612        let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
613        let idx = evt["index"].as_u64().unwrap_or(0) as usize;
614        log::debug!(
615            "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
616            delta_type,
617            idx,
618            blocks.len(),
619            idx < blocks.len()
620        );
621    }
622
623    // Update last_content_time for non-ping events
624    if evt_type != "ping" {
625        *last_content_time = std::time::Instant::now();
626    }
627
628    match evt_type {
629        "message_start" => {
630            // Initial usage payload — `input_tokens` is final
631            // (they don't grow during streaming) but
632            // `output_tokens` starts near 0 and is updated by
633            // subsequent `message_delta` events.
634            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
635            debug!(
636                "message_start: usage_json={}",
637                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
638            );
639            debug!(
640                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
641                usage.input_tokens,
642                usage.output_tokens,
643                usage.cache_read_input_tokens,
644                usage.cache_creation_input_tokens
645            );
646            // Send real-time usage update
647            let _ = tx
648                .send(StreamEvent::Usage {
649                    output_tokens: usage.output_tokens,
650                })
651                .await;
652        }
653        "content_block_start" => {
654            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
655            let block = &evt["content_block"];
656            let kind = block["type"].as_str().unwrap_or("");
657            while blocks.len() <= idx {
658                blocks.push(AssembledBlock::default());
659            }
660            match kind {
661                "text" => {
662                    blocks[idx] = AssembledBlock::Text(String::new());
663                }
664                "thinking" => {
665                    blocks[idx] = AssembledBlock::Thinking {
666                        text: String::new(),
667                        signature: None,
668                    };
669                }
670                "tool_use" | "server_tool_use" => {
671                    let id = block["id"].as_str().unwrap_or_default();
672                    let name = block["name"].as_str().unwrap_or_default();
673                    let is_server = kind == "server_tool_use";
674                    blocks[idx] = if is_server {
675                        AssembledBlock::ServerToolUse {
676                            id: id.into(),
677                            name: name.into(),
678                            input_json: String::new(),
679                        }
680                    } else {
681                        AssembledBlock::ToolUse {
682                            id: id.into(),
683                            name: name.into(),
684                            input_json: String::new(),
685                        }
686                    };
687                    let _ = tx
688                        .send(StreamEvent::ToolUseStart {
689                            id: id.into(),
690                            name: name.into(),
691                        })
692                        .await;
693                }
694                "web_search_tool_result" => {
695                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
696                    blocks[idx] = AssembledBlock::WebSearchResult {
697                        tool_use_id,
698                        content_json: String::new(),
699                    };
700                }
701                _ => {}
702            }
703        }
704        "content_block_delta" => {
705            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
706            let delta = &evt["delta"];
707            let dt = delta["type"].as_str().unwrap_or("");
708            if idx >= blocks.len() {
709                return false;
710            }
711            match (dt, &mut blocks[idx]) {
712                ("text_delta", AssembledBlock::Text(buf)) => {
713                    if let Some(t) = delta["text"].as_str() {
714                        buf.push_str(t);
715                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
716                    }
717                }
718                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
719                    if let Some(t) = delta["thinking"].as_str() {
720                        text.push_str(t);
721                        log::debug!("Received thinking_delta: {} chars", t.len());
722                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
723                    }
724                }
725                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
726                    if let Some(s) = delta["signature"].as_str() {
727                        signature.get_or_insert_with(String::new).push_str(s);
728                    }
729                }
730                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
731                    if let Some(p) = delta["partial_json"].as_str() {
732                        input_json.push_str(p);
733                        let _ = tx
734                            .send(StreamEvent::ToolInputDelta {
735                                bytes_so_far: input_json.len(),
736                            })
737                            .await;
738                    }
739                }
740                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
741                    if let Some(p) = delta["partial_json"].as_str() {
742                        input_json.push_str(p);
743                        let _ = tx
744                            .send(StreamEvent::ToolInputDelta {
745                                bytes_so_far: input_json.len(),
746                            })
747                            .await;
748                    }
749                }
750                _ => {}
751            }
752        }
753        "message_delta" => {
754            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
755                *stop_reason = match sr {
756                    "tool_use" => StopReason::ToolUse,
757                    "max_tokens" => StopReason::MaxTokens,
758                    _ => StopReason::EndTurn,
759                };
760            }
761            // `usage` on message_delta reflects cumulative
762            // counts for the current message — most notably
763            // the final `output_tokens`.
764            *usage = merge_usage(usage.clone(), &evt["usage"]);
765            debug!(
766                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
767                usage.input_tokens,
768                usage.output_tokens,
769                usage.cache_read_input_tokens,
770                usage.cache_creation_input_tokens
771            );
772            // Send real-time usage update
773            let _ = tx
774                .send(StreamEvent::Usage {
775                    output_tokens: usage.output_tokens,
776                })
777                .await;
778        }
779        "message_stop" => {
780            debug!(
781                "Message completed: stop_reason={}, usage={}",
782                match *stop_reason {
783                    StopReason::EndTurn => "end_turn",
784                    StopReason::ToolUse => "tool_use",
785                    StopReason::MaxTokens => "max_tokens",
786                },
787                usage.output_tokens
788            );
789            debug!(
790                "message_stop final usage: cache_read={}, cache_created={}",
791                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
792            );
793            let _ = tx
794                .send(StreamEvent::Done(finalize_incomplete_stream(
795                    std::mem::take(blocks),
796                    stop_reason.clone(),
797                    usage.clone(),
798                )))
799                .await;
800            return true;
801        }
802        "error" => {
803            let msg = evt["error"]["message"]
804                .as_str()
805                .unwrap_or("unknown stream error")
806                .to_string();
807            let _ = tx.send(StreamEvent::Error(msg)).await;
808            return true;
809        }
810        _ => {}
811    }
812
813    false
814}
815
816fn build_chat_response(
817    blocks: Vec<AssembledBlock>,
818    stop_reason: StopReason,
819    usage: Usage,
820) -> ChatResponse {
821    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
822    ChatResponse {
823        content,
824        stop_reason,
825        usage,
826    }
827}
828
829fn finalize_incomplete_stream(
830    blocks: Vec<AssembledBlock>,
831    stop_reason: StopReason,
832    usage: Usage,
833) -> ChatResponse {
834    build_chat_response(blocks, stop_reason, usage)
835}
836
/// A content block in the process of being assembled from stream deltas.
///
/// Each variant buffers the raw pieces received so far; `finish` converts the
/// buffered state into a `ContentBlock` once the stream is done.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder with no content. This is the `Default` value, and `finish`
    /// yields `None` for it.
    #[default]
    Empty,
    /// Plain text accumulated so far.
    Text(String),
    /// A thinking block.
    Thinking {
        /// Thinking text, extended by `thinking_delta` events.
        text: String,
        /// Signature, created on the first `signature_delta` event and
        /// extended by later ones; `None` if none was received.
        signature: Option<String>,
    },
    /// A tool call whose input arrives as JSON fragments.
    ToolUse {
        id: String,
        name: String,
        /// Concatenated `partial_json` fragments; parsed only in `finish`.
        input_json: String,
    },
    /// Same shape as `ToolUse`, but finished as `ContentBlock::ServerToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        /// Concatenated `partial_json` fragments; parsed only in `finish`.
        input_json: String,
    },
    /// A web search result tied to the tool call that produced it.
    WebSearchResult {
        tool_use_id: String,
        /// Raw JSON of the result content; parsed only in `finish`.
        content_json: String,
    },
}
861
862impl AssembledBlock {
863    fn finish(self) -> Option<ContentBlock> {
864        match self {
865            AssembledBlock::Empty => None,
866            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
867            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
868                thinking: text,
869                signature,
870            }),
871            AssembledBlock::ToolUse {
872                id,
873                name,
874                input_json,
875            } => {
876                let input: Value = if input_json.is_empty() {
877                    json!({})
878                } else {
879                    serde_json::from_str(&input_json).unwrap_or(json!({}))
880                };
881                Some(ContentBlock::ToolUse { id, name, input })
882            }
883            AssembledBlock::ServerToolUse {
884                id,
885                name,
886                input_json,
887            } => {
888                let input: Value = if input_json.is_empty() {
889                    json!({})
890                } else {
891                    serde_json::from_str(&input_json).unwrap_or(json!({}))
892                };
893                Some(ContentBlock::ServerToolUse { id, name, input })
894            }
895            AssembledBlock::WebSearchResult {
896                tool_use_id,
897                content_json,
898            } => {
899                let content: Value = if content_json.is_empty() {
900                    json!({"results": []})
901                } else {
902                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
903                };
904                Some(ContentBlock::WebSearchResult {
905                    tool_use_id,
906                    content: parse_web_search_content(&content),
907                })
908            }
909        }
910    }
911}
912
913/// Parse the provider's `usage` blob (non-streaming response) into our
914/// internal `Usage` struct. Missing fields default to 0.
915fn parse_usage(u: &Value) -> Usage {
916    Usage {
917        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
918        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
919        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
920        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
921    }
922}
923
924/// Merge a fresh usage payload into the accumulated one. Non-zero new values
925/// override prior ones — this matches the streaming protocol where
926/// `message_start` gives input counts and `message_delta` updates output.
927fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
928    let fresh = parse_usage(u);
929    if fresh.input_tokens > 0 {
930        acc.input_tokens = fresh.input_tokens;
931    }
932    if fresh.output_tokens > 0 {
933        acc.output_tokens = fresh.output_tokens;
934    }
935    if fresh.cache_creation_input_tokens > 0 {
936        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
937    }
938    if fresh.cache_read_input_tokens > 0 {
939        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
940    }
941    acc
942}
943
944/// Parse web search content from the API response.
945fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
946    let results = value["results"]
947        .as_array()
948        .map(|arr| {
949            arr.iter()
950                .filter_map(|item| {
951                    Some(crate::providers::WebSearchResultItem {
952                        title: item["title"].as_str().map(String::from),
953                        url: item["url"].as_str()?.to_string(),
954                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
955                        snippet: item["snippet"].as_str().map(String::from),
956                    })
957                })
958                .collect()
959        })
960        .unwrap_or_default();
961
962    crate::providers::WebSearchContent { results }
963}
964
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames terminated by a blank CRLF line are split off one at a time,
    /// leaving the buffer empty once both have been taken.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut pending = String::from(concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        ));

        let start_frame = take_next_sse_frame(&mut pending).expect("first frame");
        assert!(start_frame.contains("message_start"));

        let stop_frame = take_next_sse_frame(&mut pending).expect("second frame");
        assert!(stop_frame.contains("message_stop"));

        assert!(pending.is_empty());
    }

    /// An event that was never followed by a blank line is still returned,
    /// without its trailing line ending, and the buffer is drained.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut pending = String::from("data: {\"type\":\"message_stop\"}\r\n");

        let frame = take_trailing_sse_frame(&mut pending).expect("trailing frame");

        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(pending.is_empty());
    }

    /// The payload is extracted whether or not a space follows `data:`.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let cases = [
            ("event: x\r\ndata: {\"k\":1}\r", "{\"k\":1}"),
            ("event: x\r\ndata:{\"k\":2}\r", "{\"k\":2}"),
        ];

        for (frame, expected) in cases {
            assert_eq!(extract_sse_data_line(frame), Some(expected.to_string()));
        }
    }

    /// Text accumulated before finalization survives into the response.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let accumulated = vec![AssembledBlock::Text("partial reply".to_string())];

        let response =
            finalize_incomplete_stream(accumulated, StopReason::EndTurn, Usage::default());

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}