Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13    StreamEvent, Usage,
14};
15
16pub struct AnthropicProvider {
17    api_key: String,
18    model: String,
19    base_url: String,
20    client: reqwest::Client,
21    /// Extra headers from config
22    extra_headers: Vec<(String, String)>,
23}
24
25impl AnthropicProvider {
26    pub fn new(api_key: String, model: String, base_url: String) -> Self {
27        Self::with_headers(api_key, model, base_url, None)
28    }
29
30    pub fn with_headers(
31        api_key: String,
32        model: String,
33        base_url: String,
34        extra_headers: Option<std::collections::HashMap<String, String>>,
35    ) -> Self {
36        // Create client with better timeout handling for streaming
37        // - No total timeout (streaming responses can take a long time)
38        // - Connect timeout: 10 seconds
39        // - Read timeout per chunk: 60 seconds (for slow responses between chunks)
40        let client = reqwest::Client::builder()
41            .connect_timeout(std::time::Duration::from_secs(10))
42            .read_timeout(std::time::Duration::from_secs(60))
43            .timeout(std::time::Duration::from_secs(300)) // Total timeout for non-streaming
44            .build()
45            .unwrap_or_else(|_| reqwest::Client::new());
46        let extra_headers: Vec<(String, String)> = extra_headers
47            .map(|h| h.into_iter().collect())
48            .unwrap_or_default();
49        Self {
50            api_key,
51            model,
52            base_url,
53            client,
54            extra_headers,
55        }
56    }
57
58    /// Check if this is the official Anthropic API endpoint.
59    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
60    fn is_official_anthropic(&self) -> bool {
61        self.base_url.contains("api.anthropic.com")
62    }
63
64    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
65        messages
66            .iter()
67            .filter(|m| m.role != Role::System)
68            .map(|m| {
69                let role = match m.role {
70                    Role::User | Role::Tool => "user",
71                    Role::Assistant => "assistant",
72                    Role::System => unreachable!(),
73                };
74
75                let content = match &m.content {
76                    MessageContent::Text(text) => json!(text),
77                    MessageContent::Blocks(blocks) => {
78                        let converted: Vec<Value> = blocks
79                            .iter()
80                            .map(|b| match b {
81                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
82                                ContentBlock::ToolUse { id, name, input } => {
83                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
84                                }
85                                ContentBlock::ToolResult { tool_use_id, content } => {
86                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
87                                }
88                                ContentBlock::Thinking { thinking, signature } => {
89                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
90                                    if let Some(sig) = signature {
91                                        obj["signature"] = json!(sig);
92                                    }
93                                    obj
94                                }
95                                ContentBlock::ServerToolUse { id, name, input } => {
96                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
97                                }
98                                ContentBlock::WebSearchResult { tool_use_id, content } => {
99                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
100                                }
101                            })
102                            .collect();
103                        json!(converted)
104                    }
105                };
106
107                json!({"role": role, "content": content})
108            })
109            .collect()
110    }
111
112    /// Convert tools with caching control for Anthropic prompt caching.
113    fn convert_tools_with_caching(
114        &self,
115        tools: &[ToolDefinition],
116        enable_caching: bool,
117    ) -> Vec<Value> {
118        let mut converted: Vec<Value> = tools
119            .iter()
120            .map(|t| {
121                json!({
122                    "name": t.name,
123                    "description": t.description,
124                    "input_schema": t.parameters,
125                })
126            })
127            .collect();
128
129        // Add cache_control to the last tool definition for tools caching
130        if enable_caching && !converted.is_empty() {
131            let last_idx = converted.len() - 1;
132            if let Some(obj) = converted[last_idx].as_object_mut() {
133                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
134            }
135        }
136
137        converted
138    }
139
140    /// Build the base JSON body shared by streaming and non-streaming requests.
141    fn build_body(&self, request: &ChatRequest) -> Value {
142        let mut body = json!({
143            "model": self.model,
144            "max_tokens": request.max_tokens,
145            "messages": self.convert_messages(&request.messages),
146        });
147
148        // Add prompt caching for system prompt (Anthropic-specific)
149        if request.enable_caching {
150            if let Some(system) = &request.system {
151                // System prompt caching: add cache_control to enable caching
152                body["system"] = json!([
153                    {
154                        "type": "text",
155                        "text": system,
156                        "cache_control": {"type": "ephemeral"}
157                    }
158                ]);
159            }
160        } else if let Some(system) = &request.system {
161            body["system"] = json!(system);
162        }
163
164        if !request.tools.is_empty() {
165            let tools = self.convert_tools_with_caching(
166                &request.tools,
167                request.enable_caching,
168            );
169            body["tools"] = json!(tools);
170        }
171
172        if !request.server_tools.is_empty() {
173            body["tools"] = json!(
174                body["tools"]
175                    .as_array()
176                    .map(|t| {
177                        let mut tools = t.clone();
178                        for st in &request.server_tools {
179                            tools.push(serde_json::to_value(st).unwrap_or_default());
180                        }
181                        tools
182                    })
183                    .unwrap_or_else(|| request
184                        .server_tools
185                        .iter()
186                        .map(|st| serde_json::to_value(st).unwrap_or_default())
187                        .collect())
188            );
189        }
190
191        // Extended thinking (Anthropic-specific)
192        if request.think {
193            let config = thinking_config(&self.model);
194            log::debug!(
195                "Adding thinking config for model {}: {:?}",
196                self.model,
197                config
198            );
199            body["thinking"] = config;
200        }
201
202        body
203    }
204}
205
206/// Models that require the new `adaptive` thinking mode instead of the
207/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
208/// don't recognize the name we default to the legacy shape (which older
209/// models and most third-party gateways understand).
210fn thinking_config(model: &str) -> Value {
211    let m = model.to_lowercase();
212    // New models (2025+) use adaptive thinking
213    let adaptive = m.contains("opus-4")
214        || m.contains("sonnet-4")
215        || m.contains("claude-4")
216        || m.contains("20250")
217        || m.contains("2025");
218    if adaptive {
219        json!({"type": "enabled", "budget_tokens": 10000})
220    } else {
221        json!({"type": "enabled", "budget_tokens": 5000})
222    }
223}
224
225#[async_trait]
226impl Provider for AnthropicProvider {
227    fn context_size(&self) -> Option<u32> {
228        context_window_for(&self.model)
229    }
230
231    fn clone_box(&self) -> Box<dyn Provider> {
232        Box::new(Self {
233            api_key: self.api_key.clone(),
234            model: self.model.clone(),
235            base_url: self.base_url.clone(),
236            client: reqwest::Client::new(),
237            extra_headers: self.extra_headers.clone(),
238        })
239    }
240
241    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
242        let body = self.build_body(&request);
243
244        let url = format!("{}/v1/messages", self.base_url);
245        let mut req = self
246            .client
247            .post(&url)
248            .header("User-Agent", "curl/8.0")
249            .json(&body);
250
251        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
252        if self.is_official_anthropic() {
253            req = req
254                .header("x-api-key", &self.api_key)
255                .header("anthropic-version", "2025-04-15")
256                .header("anthropic-beta", "prompt-caching-2024-07-31");
257        } else {
258            req = req.header("Authorization", format!("Bearer {}", self.api_key));
259        }
260
261        // Add extra headers from config (all custom headers go here)
262        for (name, value) in &self.extra_headers {
263            req = req.header(name, value);
264        }
265
266        let response = req.send().await?;
267
268        let status = response.status();
269        let response_body: Value = response.json().await?;
270
271        if !status.is_success() {
272            let err_msg = response_body["error"]["message"]
273                .as_str()
274                .unwrap_or("unknown error");
275            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
276        }
277
278        let stop_reason = match response_body["stop_reason"].as_str() {
279            Some("tool_use") => StopReason::ToolUse,
280            Some("max_tokens") => StopReason::MaxTokens,
281            _ => StopReason::EndTurn,
282        };
283
284        let content = response_body["content"]
285            .as_array()
286            .unwrap_or(&vec![])
287            .iter()
288            .filter_map(|block| match block["type"].as_str()? {
289                "text" => Some(ContentBlock::Text {
290                    text: block["text"].as_str()?.to_string(),
291                }),
292                "tool_use" => Some(ContentBlock::ToolUse {
293                    id: block["id"].as_str()?.to_string(),
294                    name: block["name"].as_str()?.to_string(),
295                    input: block["input"].clone(),
296                }),
297                "thinking" => Some(ContentBlock::Thinking {
298                    thinking: block["thinking"].as_str()?.to_string(),
299                    signature: block["signature"].as_str().map(String::from),
300                }),
301                "server_tool_use" => Some(ContentBlock::ServerToolUse {
302                    id: block["id"].as_str()?.to_string(),
303                    name: block["name"].as_str()?.to_string(),
304                    input: block["input"].clone(),
305                }),
306                "web_search_tool_result" => {
307                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
308                    let content = parse_web_search_content(&block["content"]);
309                    Some(ContentBlock::WebSearchResult {
310                        tool_use_id,
311                        content,
312                    })
313                }
314                _ => None,
315            })
316            .collect();
317
318        Ok(ChatResponse {
319            content,
320            stop_reason,
321            usage: parse_usage(&response_body["usage"]),
322        })
323    }
324
325    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
326        let mut body = self.build_body(&request);
327        body["stream"] = json!(true);
328
329        let url = format!("{}/v1/messages", self.base_url);
330        let mut req = self
331            .client
332            .post(&url)
333            .header("User-Agent", "curl/8.0")
334            .json(&body);
335
336        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
337        if self.is_official_anthropic() {
338            req = req
339                .header("x-api-key", &self.api_key)
340                .header("anthropic-version", "2025-04-15")
341                .header("anthropic-beta", "prompt-caching-2024-07-31");
342        } else {
343            req = req.header("Authorization", format!("Bearer {}", self.api_key));
344        }
345
346        // Add extra headers from config (all custom headers go here)
347        for (name, value) in &self.extra_headers {
348            req = req.header(name, value);
349        }
350
351        let response = req.send().await?;
352
353        if !response.status().is_success() {
354            let status = response.status();
355            let text = response.text().await.unwrap_or_default();
356            anyhow::bail!("Anthropic API error ({}): {}", status, text);
357        }
358
359        let (tx, rx) = mpsc::channel(64);
360        tokio::spawn(async move {
361            let mut stream = response.bytes_stream();
362            let mut buffer = String::new();
363            let mut sent_first_byte = false;
364
365            // In-flight block assembly: index → partial data
366            let mut blocks: Vec<AssembledBlock> = Vec::new();
367            let mut stop_reason = StopReason::EndTurn;
368            let mut usage = Usage::default();
369
370            while let Some(chunk) = stream.next().await {
371                let chunk = match chunk {
372                    Ok(c) => c,
373                    Err(e) => {
374                        // Log detailed error for debugging
375                        let error_msg = e.to_string();
376                        let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
377                        let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
378
379                        let msg = if is_timeout {
380                            format!("Stream timeout - the API took too long to respond: {}", error_msg)
381                        } else if is_decode {
382                            format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
383                        } else {
384                            format!("Stream read error: {}", error_msg)
385                        };
386
387                        // Try to finalize any partial content we have
388                        if sent_first_byte && !blocks.is_empty() {
389                            debug!("finalizing partial stream due to error");
390                            let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
391                                std::mem::take(&mut blocks),
392                                stop_reason,
393                                usage,
394                            ))).await;
395                        } else {
396                            let _ = tx.send(StreamEvent::Error(msg)).await;
397                        }
398                        return;
399                    }
400                };
401
402                if !sent_first_byte {
403                    sent_first_byte = true;
404                    let _ = tx.send(StreamEvent::FirstByte).await;
405                }
406
407                buffer.push_str(&String::from_utf8_lossy(&chunk));
408
409                while let Some(frame) = take_next_sse_frame(&mut buffer) {
410                    if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx)
411                        .await
412                    {
413                        return;
414                    }
415                }
416            }
417
418            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
419                && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx).await
420            {
421                return;
422            }
423
424            if sent_first_byte {
425                debug!("stream ended without explicit message_stop; finalizing best-effort");
426                let _ = tx
427                    .send(StreamEvent::Done(finalize_incomplete_stream(
428                        std::mem::take(&mut blocks),
429                        stop_reason,
430                        usage,
431                    )))
432                    .await;
433            } else {
434                let _ = tx
435                    .send(StreamEvent::Error(
436                        "stream ended before any events were received".to_string(),
437                    ))
438                    .await;
439            }
440        });
441
442        Ok(rx)
443    }
444}
445
/// Remove and return the first complete SSE frame from `buffer`.
///
/// A frame ends at the first blank line, written either as `\n\n` or `\r\n\r\n`.
/// The frame text (without its delimiter) is returned, and the frame plus its
/// delimiter are drained from `buffer`. Returns `None` and leaves `buffer`
/// untouched when no complete frame is present yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let lf_end = buffer.find("\n\n").map(|at| (at, 2usize));
    let crlf_end = buffer.find("\r\n\r\n").map(|at| (at, 4usize));
    // The two delimiters can never start at the same offset ('\n' vs '\r'), so
    // simply taking the earliest match is unambiguous.
    let (frame_len, delim_len) = lf_end
        .into_iter()
        .chain(crlf_end)
        .min_by_key(|&(at, _)| at)?;

    let frame = buffer[..frame_len].to_owned();
    buffer.replace_range(..frame_len + delim_len, "");
    Some(frame)
}
466
/// Drain whatever is left in `buffer` once the byte stream has ended.
///
/// Returns the leftover text with surrounding whitespace (including `\r`) trimmed
/// as one final frame, or `None` if nothing but whitespace remained. `buffer` is
/// always left empty.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    // `trim` already strips '\r', since it is Unicode whitespace.
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_string())
}
472
/// Extract the payload of an SSE frame's `data` field.
///
/// Following the SSE specification:
/// - one optional space after `data:` is stripped, so both `data: {..}` (Anthropic)
///   and `data:{..}` (DashScope) work;
/// - multiple `data` lines in one frame are joined with `\n` (previously only the
///   first line was returned, truncating multi-line payloads).
///
/// Returns `None` when the frame has no `data` line, e.g. comment-only keep-alives.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut payload: Option<String> = None;
    for line in frame.lines() {
        let line = line.trim_end_matches('\r');
        let Some(rest) = line.strip_prefix("data:") else {
            continue;
        };
        let value = rest.strip_prefix(' ').unwrap_or(rest);
        match payload.as_mut() {
            Some(acc) => {
                acc.push('\n');
                acc.push_str(value);
            }
            None => payload = Some(value.to_string()),
        }
    }
    payload
}
486
487async fn handle_sse_frame(
488    frame: &str,
489    blocks: &mut Vec<AssembledBlock>,
490    stop_reason: &mut StopReason,
491    usage: &mut Usage,
492    tx: &mpsc::Sender<StreamEvent>,
493) -> bool {
494    let Some(data_line) = extract_sse_data_line(frame) else {
495        return false;
496    };
497
498    let evt: Value = match serde_json::from_str(&data_line) {
499        Ok(v) => v,
500        Err(_) => return false,
501    };
502
503    handle_sse_event(evt, blocks, stop_reason, usage, tx).await
504}
505
506async fn handle_sse_event(
507    evt: Value,
508    blocks: &mut Vec<AssembledBlock>,
509    stop_reason: &mut StopReason,
510    usage: &mut Usage,
511    tx: &mpsc::Sender<StreamEvent>,
512) -> bool {
513    match evt["type"].as_str().unwrap_or("") {
514        "message_start" => {
515            // Initial usage payload — `input_tokens` is final
516            // (they don't grow during streaming) but
517            // `output_tokens` starts near 0 and is updated by
518            // subsequent `message_delta` events.
519            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
520            debug!(
521                "message_start: usage_json={}",
522                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
523            );
524            debug!(
525                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
526                usage.input_tokens,
527                usage.output_tokens,
528                usage.cache_read_input_tokens,
529                usage.cache_creation_input_tokens
530            );
531            // Send real-time usage update
532            let _ = tx
533                .send(StreamEvent::Usage {
534                    output_tokens: usage.output_tokens,
535                })
536                .await;
537        }
538        "content_block_start" => {
539            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
540            let block = &evt["content_block"];
541            let kind = block["type"].as_str().unwrap_or("");
542            while blocks.len() <= idx {
543                blocks.push(AssembledBlock::default());
544            }
545            match kind {
546                "text" => {
547                    blocks[idx] = AssembledBlock::Text(String::new());
548                }
549                "thinking" => {
550                    blocks[idx] = AssembledBlock::Thinking {
551                        text: String::new(),
552                        signature: None,
553                    };
554                }
555                "tool_use" | "server_tool_use" => {
556                    let id = block["id"].as_str().unwrap_or_default();
557                    let name = block["name"].as_str().unwrap_or_default();
558                    let is_server = kind == "server_tool_use";
559                    blocks[idx] = if is_server {
560                        AssembledBlock::ServerToolUse {
561                            id: id.into(),
562                            name: name.into(),
563                            input_json: String::new(),
564                        }
565                    } else {
566                        AssembledBlock::ToolUse {
567                            id: id.into(),
568                            name: name.into(),
569                            input_json: String::new(),
570                        }
571                    };
572                    let _ = tx
573                        .send(StreamEvent::ToolUseStart {
574                            id: id.into(),
575                            name: name.into(),
576                        })
577                        .await;
578                }
579                "web_search_tool_result" => {
580                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
581                    blocks[idx] = AssembledBlock::WebSearchResult {
582                        tool_use_id,
583                        content_json: String::new(),
584                    };
585                }
586                _ => {}
587            }
588        }
589        "content_block_delta" => {
590            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
591            let delta = &evt["delta"];
592            let dt = delta["type"].as_str().unwrap_or("");
593            if idx >= blocks.len() {
594                return false;
595            }
596            match (dt, &mut blocks[idx]) {
597                ("text_delta", AssembledBlock::Text(buf)) => {
598                    if let Some(t) = delta["text"].as_str() {
599                        buf.push_str(t);
600                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
601                    }
602                }
603                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
604                    if let Some(t) = delta["thinking"].as_str() {
605                        text.push_str(t);
606                        log::debug!("Received thinking_delta: {} chars", t.len());
607                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
608                    }
609                }
610                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
611                    if let Some(s) = delta["signature"].as_str() {
612                        signature.get_or_insert_with(String::new).push_str(s);
613                    }
614                }
615                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
616                    if let Some(p) = delta["partial_json"].as_str() {
617                        input_json.push_str(p);
618                        let _ = tx
619                            .send(StreamEvent::ToolInputDelta {
620                                bytes_so_far: input_json.len(),
621                            })
622                            .await;
623                    }
624                }
625                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
626                    if let Some(p) = delta["partial_json"].as_str() {
627                        input_json.push_str(p);
628                        let _ = tx
629                            .send(StreamEvent::ToolInputDelta {
630                                bytes_so_far: input_json.len(),
631                            })
632                            .await;
633                    }
634                }
635                _ => {}
636            }
637        }
638        "message_delta" => {
639            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
640                *stop_reason = match sr {
641                    "tool_use" => StopReason::ToolUse,
642                    "max_tokens" => StopReason::MaxTokens,
643                    _ => StopReason::EndTurn,
644                };
645            }
646            // `usage` on message_delta reflects cumulative
647            // counts for the current message — most notably
648            // the final `output_tokens`.
649            *usage = merge_usage(usage.clone(), &evt["usage"]);
650            debug!(
651                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
652                usage.input_tokens,
653                usage.output_tokens,
654                usage.cache_read_input_tokens,
655                usage.cache_creation_input_tokens
656            );
657            // Send real-time usage update
658            let _ = tx
659                .send(StreamEvent::Usage {
660                    output_tokens: usage.output_tokens,
661                })
662                .await;
663        }
664        "message_stop" => {
665            debug!(
666                "Message completed: stop_reason={}, usage={}",
667                match *stop_reason {
668                    StopReason::EndTurn => "end_turn",
669                    StopReason::ToolUse => "tool_use",
670                    StopReason::MaxTokens => "max_tokens",
671                },
672                usage.output_tokens
673            );
674            debug!(
675                "message_stop final usage: cache_read={}, cache_created={}",
676                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
677            );
678            let _ = tx
679                .send(StreamEvent::Done(finalize_incomplete_stream(
680                    std::mem::take(blocks),
681                    stop_reason.clone(),
682                    usage.clone(),
683                )))
684                .await;
685            return true;
686        }
687        "error" => {
688            let msg = evt["error"]["message"]
689                .as_str()
690                .unwrap_or("unknown stream error")
691                .to_string();
692            let _ = tx.send(StreamEvent::Error(msg)).await;
693            return true;
694        }
695        _ => {}
696    }
697
698    false
699}
700
701fn build_chat_response(
702    blocks: Vec<AssembledBlock>,
703    stop_reason: StopReason,
704    usage: Usage,
705) -> ChatResponse {
706    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
707    ChatResponse {
708        content,
709        stop_reason,
710        usage,
711    }
712}
713
714fn finalize_incomplete_stream(
715    blocks: Vec<AssembledBlock>,
716    stop_reason: StopReason,
717    usage: Usage,
718) -> ChatResponse {
719    build_chat_response(blocks, stop_reason, usage)
720}
721
/// A content block under construction from streaming events, stored at the
/// event's `index`. JSON-bearing variants accumulate raw text that is parsed
/// once, in `finish`.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index not yet started or of an unrecognized block type.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Thinking text from `thinking_delta` plus the concatenated `signature_delta`.
    Thinking { text: String, signature: Option<String> },
    /// Client tool call; `input_json` is the concatenation of `partial_json` fragments.
    ToolUse { id: String, name: String, input_json: String },
    /// Server-executed tool call; its input is accumulated the same way as `ToolUse`.
    ServerToolUse { id: String, name: String, input_json: String },
    /// Web search result; `content_json` holds serialized result JSON (empty if none).
    WebSearchResult { tool_use_id: String, content_json: String },
}
746
747impl AssembledBlock {
748    fn finish(self) -> Option<ContentBlock> {
749        match self {
750            AssembledBlock::Empty => None,
751            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
752            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
753                thinking: text,
754                signature,
755            }),
756            AssembledBlock::ToolUse {
757                id,
758                name,
759                input_json,
760            } => {
761                let input: Value = if input_json.is_empty() {
762                    json!({})
763                } else {
764                    serde_json::from_str(&input_json).unwrap_or(json!({}))
765                };
766                Some(ContentBlock::ToolUse { id, name, input })
767            }
768            AssembledBlock::ServerToolUse {
769                id,
770                name,
771                input_json,
772            } => {
773                let input: Value = if input_json.is_empty() {
774                    json!({})
775                } else {
776                    serde_json::from_str(&input_json).unwrap_or(json!({}))
777                };
778                Some(ContentBlock::ServerToolUse { id, name, input })
779            }
780            AssembledBlock::WebSearchResult {
781                tool_use_id,
782                content_json,
783            } => {
784                let content: Value = if content_json.is_empty() {
785                    json!({"results": []})
786                } else {
787                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
788                };
789                Some(ContentBlock::WebSearchResult {
790                    tool_use_id,
791                    content: parse_web_search_content(&content),
792                })
793            }
794        }
795    }
796}
797
798/// Parse the provider's `usage` blob (non-streaming response) into our
799/// internal `Usage` struct. Missing fields default to 0.
800fn parse_usage(u: &Value) -> Usage {
801    Usage {
802        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
803        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
804        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
805        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
806    }
807}
808
809/// Merge a fresh usage payload into the accumulated one. Non-zero new values
810/// override prior ones — this matches the streaming protocol where
811/// `message_start` gives input counts and `message_delta` updates output.
812fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
813    let fresh = parse_usage(u);
814    if fresh.input_tokens > 0 {
815        acc.input_tokens = fresh.input_tokens;
816    }
817    if fresh.output_tokens > 0 {
818        acc.output_tokens = fresh.output_tokens;
819    }
820    if fresh.cache_creation_input_tokens > 0 {
821        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
822    }
823    if fresh.cache_read_input_tokens > 0 {
824        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
825    }
826    acc
827}
828
829/// Parse web search content from the API response.
830fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
831    let results = value["results"]
832        .as_array()
833        .map(|arr| {
834            arr.iter()
835                .filter_map(|item| {
836                    Some(crate::providers::WebSearchResultItem {
837                        title: item["title"].as_str().map(String::from),
838                        url: item["url"].as_str()?.to_string(),
839                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
840                        snippet: item["snippet"].as_str().map(String::from),
841                    })
842                })
843                .collect()
844        })
845        .unwrap_or_default();
846
847    crate::providers::WebSearchContent { results }
848}
849
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_next_sse_frame_supports_lf_and_keeps_incomplete_remainder() {
        let mut buffer = "data: {\"a\":1}\n\ndata: {\"b\"".to_string();

        let frame = take_next_sse_frame(&mut buffer).expect("complete frame");
        assert_eq!(frame, "data: {\"a\":1}");
        assert_eq!(buffer, "data: {\"b\"");

        // No terminating blank line yet: nothing is taken and the buffer is untouched.
        assert!(take_next_sse_frame(&mut buffer).is_none());
        assert_eq!(buffer, "data: {\"b\"");
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_ignores_whitespace_only_buffer() {
        let mut buffer = " \r\n\r\n".to_string();
        assert!(take_trailing_sse_frame(&mut buffer).is_none());
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    #[test]
    fn extract_sse_data_line_ignores_frames_without_data() {
        assert_eq!(extract_sse_data_line(": keep-alive\nevent: ping"), None);
    }

    #[test]
    fn merge_usage_keeps_prior_values_when_fresh_ones_are_zero() {
        let acc = Usage {
            input_tokens: 42,
            cache_read_input_tokens: 7,
            ..Usage::default()
        };
        let merged = merge_usage(acc, &json!({"input_tokens": 0, "output_tokens": 5}));
        assert_eq!(merged.input_tokens, 42);
        assert_eq!(merged.output_tokens, 5);
        assert_eq!(merged.cache_read_input_tokens, 7);
        assert_eq!(merged.cache_creation_input_tokens, 0);
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finalize_incomplete_stream_drops_empty_slots() {
        let response = finalize_incomplete_stream(
            vec![
                AssembledBlock::Empty,
                AssembledBlock::Text("x".to_string()),
            ],
            StopReason::ToolUse,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::ToolUse);
        assert_eq!(response.content.len(), 1);
    }
}