Skip to main content

matrixcode_core/providers/
anthropic.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13    StreamEvent, Usage,
14};
15
/// Chat provider that talks to an Anthropic-style `/v1/messages` HTTP endpoint.
pub struct AnthropicProvider {
    /// Credential sent either as `x-api-key` or as a `Bearer` token, depending on the endpoint.
    api_key: String,
    /// Model identifier placed in the request body's `model` field.
    model: String,
    /// Base URL to which `/v1/messages` is appended.
    base_url: String,
    /// HTTP client used for the requests issued by this provider.
    client: reqwest::Client,
    /// Extra headers from config
    extra_headers: Vec<(String, String)>,
}
24
25impl AnthropicProvider {
    /// Creates a provider without any extra headers.
    ///
    /// Delegates to `with_headers`, passing `None` for `extra_headers`.
    pub fn new(api_key: String, model: String, base_url: String) -> Self {
        Self::with_headers(api_key, model, base_url, None)
    }
29
30    pub fn with_headers(
31        api_key: String,
32        model: String,
33        base_url: String,
34        extra_headers: Option<std::collections::HashMap<String, String>>,
35    ) -> Self {
36        let client = reqwest::Client::builder()
37            .timeout(std::time::Duration::from_secs(120))
38            .connect_timeout(std::time::Duration::from_secs(10))
39            .build()
40            .unwrap_or_else(|_| reqwest::Client::new());
41        let extra_headers: Vec<(String, String)> = extra_headers
42            .map(|h| h.into_iter().collect())
43            .unwrap_or_default();
44        Self {
45            api_key,
46            model,
47            base_url,
48            client,
49            extra_headers,
50        }
51    }
52
    /// Check if this is the official Anthropic API endpoint.
    /// Non-official endpoints typically use Bearer auth (OpenAI-compatible).
    ///
    /// This is a plain substring test on `base_url`: any URL containing
    /// `api.anthropic.com` anywhere (not only as its host) counts as official.
    fn is_official_anthropic(&self) -> bool {
        self.base_url.contains("api.anthropic.com")
    }
58
59    fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
60        messages
61            .iter()
62            .filter(|m| m.role != Role::System)
63            .map(|m| {
64                let role = match m.role {
65                    Role::User | Role::Tool => "user",
66                    Role::Assistant => "assistant",
67                    Role::System => unreachable!(),
68                };
69
70                let content = match &m.content {
71                    MessageContent::Text(text) => json!(text),
72                    MessageContent::Blocks(blocks) => {
73                        let converted: Vec<Value> = blocks
74                            .iter()
75                            .map(|b| match b {
76                                ContentBlock::Text { text } => json!({"type": "text", "text": text}),
77                                ContentBlock::ToolUse { id, name, input } => {
78                                    json!({"type": "tool_use", "id": id, "name": name, "input": input})
79                                }
80                                ContentBlock::ToolResult { tool_use_id, content } => {
81                                    json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
82                                }
83                                ContentBlock::Thinking { thinking, signature } => {
84                                    let mut obj = json!({"type": "thinking", "thinking": thinking});
85                                    if let Some(sig) = signature {
86                                        obj["signature"] = json!(sig);
87                                    }
88                                    obj
89                                }
90                                ContentBlock::ServerToolUse { id, name, input } => {
91                                    json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
92                                }
93                                ContentBlock::WebSearchResult { tool_use_id, content } => {
94                                    json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
95                                }
96                            })
97                            .collect();
98                        json!(converted)
99                    }
100                };
101
102                json!({"role": role, "content": content})
103            })
104            .collect()
105    }
106
107    /// Convert tools with caching control for Anthropic prompt caching.
108    fn convert_tools_with_caching(
109        &self,
110        tools: &[ToolDefinition],
111        enable_caching: bool,
112    ) -> Vec<Value> {
113        let mut converted: Vec<Value> = tools
114            .iter()
115            .map(|t| {
116                json!({
117                    "name": t.name,
118                    "description": t.description,
119                    "input_schema": t.parameters,
120                })
121            })
122            .collect();
123
124        // Add cache_control to the last tool definition for tools caching
125        if enable_caching && !converted.is_empty() {
126            let last_idx = converted.len() - 1;
127            if let Some(obj) = converted[last_idx].as_object_mut() {
128                obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
129            }
130        }
131
132        converted
133    }
134
135    /// Build the base JSON body shared by streaming and non-streaming requests.
136    fn build_body(&self, request: &ChatRequest) -> Value {
137        let mut body = json!({
138            "model": self.model,
139            "max_tokens": request.max_tokens,
140            "messages": self.convert_messages(&request.messages),
141        });
142
143        // Add prompt caching for system prompt (Anthropic-specific)
144        if request.enable_caching {
145            if let Some(system) = &request.system {
146                // System prompt caching: add cache_control to enable caching
147                body["system"] = json!([
148                    {
149                        "type": "text",
150                        "text": system,
151                        "cache_control": {"type": "ephemeral"}
152                    }
153                ]);
154            }
155        } else if let Some(system) = &request.system {
156            body["system"] = json!(system);
157        }
158
159        if !request.tools.is_empty() {
160            let tools = self.convert_tools_with_caching(
161                &request.tools,
162                request.enable_caching,
163            );
164            body["tools"] = json!(tools);
165        }
166
167        if !request.server_tools.is_empty() {
168            body["tools"] = json!(
169                body["tools"]
170                    .as_array()
171                    .map(|t| {
172                        let mut tools = t.clone();
173                        for st in &request.server_tools {
174                            tools.push(serde_json::to_value(st).unwrap_or_default());
175                        }
176                        tools
177                    })
178                    .unwrap_or_else(|| request
179                        .server_tools
180                        .iter()
181                        .map(|st| serde_json::to_value(st).unwrap_or_default())
182                        .collect())
183            );
184        }
185
186        // Extended thinking (Anthropic-specific)
187        if request.think {
188            let config = thinking_config(&self.model);
189            log::debug!(
190                "Adding thinking config for model {}: {:?}",
191                self.model,
192                config
193            );
194            body["thinking"] = config;
195        }
196
197        body
198    }
199}
200
201/// Models that require the new `adaptive` thinking mode instead of the
202/// legacy `enabled`+`budget_tokens` form. Conservative allow-list: if we
203/// don't recognize the name we default to the legacy shape (which older
204/// models and most third-party gateways understand).
205fn thinking_config(model: &str) -> Value {
206    let m = model.to_lowercase();
207    // New models (2025+) use adaptive thinking
208    let adaptive = m.contains("opus-4")
209        || m.contains("sonnet-4")
210        || m.contains("claude-4")
211        || m.contains("20250")
212        || m.contains("2025");
213    if adaptive {
214        json!({"type": "enabled", "budget_tokens": 10000})
215    } else {
216        json!({"type": "enabled", "budget_tokens": 5000})
217    }
218}
219
220#[async_trait]
221impl Provider for AnthropicProvider {
    /// Returns the context window size for the configured model, exactly as
    /// reported by `context_window_for` (including `None`).
    fn context_size(&self) -> Option<u32> {
        context_window_for(&self.model)
    }
225
226    fn clone_box(&self) -> Box<dyn Provider> {
227        Box::new(Self {
228            api_key: self.api_key.clone(),
229            model: self.model.clone(),
230            base_url: self.base_url.clone(),
231            client: reqwest::Client::new(),
232            extra_headers: self.extra_headers.clone(),
233        })
234    }
235
236    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
237        let body = self.build_body(&request);
238
239        let url = format!("{}/v1/messages", self.base_url);
240        let mut req = self
241            .client
242            .post(&url)
243            .header("User-Agent", "curl/8.0")
244            .json(&body);
245
246        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
247        if self.is_official_anthropic() {
248            req = req
249                .header("x-api-key", &self.api_key)
250                .header("anthropic-version", "2025-04-15")
251                .header("anthropic-beta", "prompt-caching-2024-07-31");
252        } else {
253            req = req.header("Authorization", format!("Bearer {}", self.api_key));
254        }
255
256        // Add extra headers from config (all custom headers go here)
257        for (name, value) in &self.extra_headers {
258            req = req.header(name, value);
259        }
260
261        let response = req.send().await?;
262
263        let status = response.status();
264        let response_body: Value = response.json().await?;
265
266        if !status.is_success() {
267            let err_msg = response_body["error"]["message"]
268                .as_str()
269                .unwrap_or("unknown error");
270            anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
271        }
272
273        let stop_reason = match response_body["stop_reason"].as_str() {
274            Some("tool_use") => StopReason::ToolUse,
275            Some("max_tokens") => StopReason::MaxTokens,
276            _ => StopReason::EndTurn,
277        };
278
279        let content = response_body["content"]
280            .as_array()
281            .unwrap_or(&vec![])
282            .iter()
283            .filter_map(|block| match block["type"].as_str()? {
284                "text" => Some(ContentBlock::Text {
285                    text: block["text"].as_str()?.to_string(),
286                }),
287                "tool_use" => Some(ContentBlock::ToolUse {
288                    id: block["id"].as_str()?.to_string(),
289                    name: block["name"].as_str()?.to_string(),
290                    input: block["input"].clone(),
291                }),
292                "thinking" => Some(ContentBlock::Thinking {
293                    thinking: block["thinking"].as_str()?.to_string(),
294                    signature: block["signature"].as_str().map(String::from),
295                }),
296                "server_tool_use" => Some(ContentBlock::ServerToolUse {
297                    id: block["id"].as_str()?.to_string(),
298                    name: block["name"].as_str()?.to_string(),
299                    input: block["input"].clone(),
300                }),
301                "web_search_tool_result" => {
302                    let tool_use_id = block["tool_use_id"].as_str()?.to_string();
303                    let content = parse_web_search_content(&block["content"]);
304                    Some(ContentBlock::WebSearchResult {
305                        tool_use_id,
306                        content,
307                    })
308                }
309                _ => None,
310            })
311            .collect();
312
313        Ok(ChatResponse {
314            content,
315            stop_reason,
316            usage: parse_usage(&response_body["usage"]),
317        })
318    }
319
320    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
321        let mut body = self.build_body(&request);
322        body["stream"] = json!(true);
323
324        let url = format!("{}/v1/messages", self.base_url);
325        let mut req = self
326            .client
327            .post(&url)
328            .header("User-Agent", "curl/8.0")
329            .json(&body);
330
331        // Auth: official Anthropic API uses x-api-key, others use Bearer (OpenAI-compatible)
332        if self.is_official_anthropic() {
333            req = req
334                .header("x-api-key", &self.api_key)
335                .header("anthropic-version", "2025-04-15")
336                .header("anthropic-beta", "prompt-caching-2024-07-31");
337        } else {
338            req = req.header("Authorization", format!("Bearer {}", self.api_key));
339        }
340
341        // Add extra headers from config (all custom headers go here)
342        for (name, value) in &self.extra_headers {
343            req = req.header(name, value);
344        }
345
346        let response = req.send().await?;
347
348        if !response.status().is_success() {
349            let status = response.status();
350            let text = response.text().await.unwrap_or_default();
351            anyhow::bail!("Anthropic API error ({}): {}", status, text);
352        }
353
354        let (tx, rx) = mpsc::channel(64);
355        tokio::spawn(async move {
356            let mut stream = response.bytes_stream();
357            let mut buffer = String::new();
358            let mut sent_first_byte = false;
359
360            // In-flight block assembly: index → partial data
361            let mut blocks: Vec<AssembledBlock> = Vec::new();
362            let mut stop_reason = StopReason::EndTurn;
363            let mut usage = Usage::default();
364
365            while let Some(chunk) = stream.next().await {
366                let chunk = match chunk {
367                    Ok(c) => c,
368                    Err(e) => {
369                        let _ = tx
370                            .send(StreamEvent::Error(format!("stream read error: {}", e)))
371                            .await;
372                        return;
373                    }
374                };
375
376                if !sent_first_byte {
377                    sent_first_byte = true;
378                    let _ = tx.send(StreamEvent::FirstByte).await;
379                }
380
381                buffer.push_str(&String::from_utf8_lossy(&chunk));
382
383                while let Some(frame) = take_next_sse_frame(&mut buffer) {
384                    if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx)
385                        .await
386                    {
387                        return;
388                    }
389                }
390            }
391
392            if let Some(frame) = take_trailing_sse_frame(&mut buffer)
393                && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx).await
394            {
395                return;
396            }
397
398            if sent_first_byte {
399                debug!("stream ended without explicit message_stop; finalizing best-effort");
400                let _ = tx
401                    .send(StreamEvent::Done(finalize_incomplete_stream(
402                        std::mem::take(&mut blocks),
403                        stop_reason,
404                        usage,
405                    )))
406                    .await;
407            } else {
408                let _ = tx
409                    .send(StreamEvent::Error(
410                        "stream ended before any events were received".to_string(),
411                    ))
412                    .await;
413            }
414        });
415
416        Ok(rx)
417    }
418}
419
/// Removes and returns the next complete SSE frame from the front of `buffer`.
///
/// A frame ends at the earliest blank-line delimiter, either `"\n\n"` or
/// `"\r\n\r\n"`. The delimiter is consumed but not returned. Yields `None`
/// and leaves `buffer` untouched when no delimiter is present.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    // Pick whichever delimiter occurs first in the buffer.
    let (frame_end, delim_len) = DELIMITERS
        .iter()
        .filter_map(|delim| buffer.find(*delim).map(|pos| (pos, delim.len())))
        .min_by_key(|&(pos, _)| pos)?;

    // Split into frame | delimiter | remainder, keeping only the remainder in `buffer`.
    let remainder = buffer.split_off(frame_end + delim_len);
    buffer.truncate(frame_end);
    Some(std::mem::replace(buffer, remainder))
}
440
/// Drains whatever is left in `buffer` and returns it as a final frame.
///
/// Surrounding whitespace (including a trailing `\r` or `\n`) is stripped.
/// The buffer is always emptied; `None` is returned when nothing but
/// whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_string())
}
446
/// Extracts the `data` payload from one SSE frame.
///
/// Every line starting with `data:` contributes its value, with a single
/// optional space after the colon removed (so both `data: x` and `data:x`
/// yield `x`). When a frame contains several `data` lines their values are
/// joined with `\n`, as the server-sent events format prescribes, instead of
/// silently dropping everything after the first line. Other fields
/// (`event:`, `id:`, comments) are ignored.
///
/// Returns `None` when the frame has no `data` line at all.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let values: Vec<&str> = frame
        .lines()
        .filter_map(|line| {
            // `lines()` removes "\r\n"; a lone trailing '\r' is stripped here.
            let rest = line.trim_end_matches('\r').strip_prefix("data:")?;
            Some(rest.strip_prefix(' ').unwrap_or(rest))
        })
        .collect();

    if values.is_empty() {
        None
    } else {
        Some(values.join("\n"))
    }
}
460
461async fn handle_sse_frame(
462    frame: &str,
463    blocks: &mut Vec<AssembledBlock>,
464    stop_reason: &mut StopReason,
465    usage: &mut Usage,
466    tx: &mpsc::Sender<StreamEvent>,
467) -> bool {
468    let Some(data_line) = extract_sse_data_line(frame) else {
469        return false;
470    };
471
472    let evt: Value = match serde_json::from_str(&data_line) {
473        Ok(v) => v,
474        Err(_) => return false,
475    };
476
477    handle_sse_event(evt, blocks, stop_reason, usage, tx).await
478}
479
480async fn handle_sse_event(
481    evt: Value,
482    blocks: &mut Vec<AssembledBlock>,
483    stop_reason: &mut StopReason,
484    usage: &mut Usage,
485    tx: &mpsc::Sender<StreamEvent>,
486) -> bool {
487    match evt["type"].as_str().unwrap_or("") {
488        "message_start" => {
489            // Initial usage payload — `input_tokens` is final
490            // (they don't grow during streaming) but
491            // `output_tokens` starts near 0 and is updated by
492            // subsequent `message_delta` events.
493            *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
494            debug!(
495                "message_start: usage_json={}",
496                serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
497            );
498            debug!(
499                "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
500                usage.input_tokens,
501                usage.output_tokens,
502                usage.cache_read_input_tokens,
503                usage.cache_creation_input_tokens
504            );
505            // Send real-time usage update
506            let _ = tx
507                .send(StreamEvent::Usage {
508                    output_tokens: usage.output_tokens,
509                })
510                .await;
511        }
512        "content_block_start" => {
513            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
514            let block = &evt["content_block"];
515            let kind = block["type"].as_str().unwrap_or("");
516            while blocks.len() <= idx {
517                blocks.push(AssembledBlock::default());
518            }
519            match kind {
520                "text" => {
521                    blocks[idx] = AssembledBlock::Text(String::new());
522                }
523                "thinking" => {
524                    blocks[idx] = AssembledBlock::Thinking {
525                        text: String::new(),
526                        signature: None,
527                    };
528                }
529                "tool_use" | "server_tool_use" => {
530                    let id = block["id"].as_str().unwrap_or_default();
531                    let name = block["name"].as_str().unwrap_or_default();
532                    let is_server = kind == "server_tool_use";
533                    blocks[idx] = if is_server {
534                        AssembledBlock::ServerToolUse {
535                            id: id.into(),
536                            name: name.into(),
537                            input_json: String::new(),
538                        }
539                    } else {
540                        AssembledBlock::ToolUse {
541                            id: id.into(),
542                            name: name.into(),
543                            input_json: String::new(),
544                        }
545                    };
546                    let _ = tx
547                        .send(StreamEvent::ToolUseStart {
548                            id: id.into(),
549                            name: name.into(),
550                        })
551                        .await;
552                }
553                "web_search_tool_result" => {
554                    let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
555                    blocks[idx] = AssembledBlock::WebSearchResult {
556                        tool_use_id,
557                        content_json: String::new(),
558                    };
559                }
560                _ => {}
561            }
562        }
563        "content_block_delta" => {
564            let idx = evt["index"].as_u64().unwrap_or(0) as usize;
565            let delta = &evt["delta"];
566            let dt = delta["type"].as_str().unwrap_or("");
567            if idx >= blocks.len() {
568                return false;
569            }
570            match (dt, &mut blocks[idx]) {
571                ("text_delta", AssembledBlock::Text(buf)) => {
572                    if let Some(t) = delta["text"].as_str() {
573                        buf.push_str(t);
574                        let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
575                    }
576                }
577                ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
578                    if let Some(t) = delta["thinking"].as_str() {
579                        text.push_str(t);
580                        log::debug!("Received thinking_delta: {} chars", t.len());
581                        let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
582                    }
583                }
584                ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
585                    if let Some(s) = delta["signature"].as_str() {
586                        signature.get_or_insert_with(String::new).push_str(s);
587                    }
588                }
589                ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
590                    if let Some(p) = delta["partial_json"].as_str() {
591                        input_json.push_str(p);
592                        let _ = tx
593                            .send(StreamEvent::ToolInputDelta {
594                                bytes_so_far: input_json.len(),
595                            })
596                            .await;
597                    }
598                }
599                ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
600                    if let Some(p) = delta["partial_json"].as_str() {
601                        input_json.push_str(p);
602                        let _ = tx
603                            .send(StreamEvent::ToolInputDelta {
604                                bytes_so_far: input_json.len(),
605                            })
606                            .await;
607                    }
608                }
609                _ => {}
610            }
611        }
612        "message_delta" => {
613            if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
614                *stop_reason = match sr {
615                    "tool_use" => StopReason::ToolUse,
616                    "max_tokens" => StopReason::MaxTokens,
617                    _ => StopReason::EndTurn,
618                };
619            }
620            // `usage` on message_delta reflects cumulative
621            // counts for the current message — most notably
622            // the final `output_tokens`.
623            *usage = merge_usage(usage.clone(), &evt["usage"]);
624            debug!(
625                "message_delta: input={}, output={}, cache_read={}, cache_created={}",
626                usage.input_tokens,
627                usage.output_tokens,
628                usage.cache_read_input_tokens,
629                usage.cache_creation_input_tokens
630            );
631            // Send real-time usage update
632            let _ = tx
633                .send(StreamEvent::Usage {
634                    output_tokens: usage.output_tokens,
635                })
636                .await;
637        }
638        "message_stop" => {
639            debug!(
640                "Message completed: stop_reason={}, usage={}",
641                match *stop_reason {
642                    StopReason::EndTurn => "end_turn",
643                    StopReason::ToolUse => "tool_use",
644                    StopReason::MaxTokens => "max_tokens",
645                },
646                usage.output_tokens
647            );
648            debug!(
649                "message_stop final usage: cache_read={}, cache_created={}",
650                usage.cache_read_input_tokens, usage.cache_creation_input_tokens
651            );
652            let _ = tx
653                .send(StreamEvent::Done(finalize_incomplete_stream(
654                    std::mem::take(blocks),
655                    stop_reason.clone(),
656                    usage.clone(),
657                )))
658                .await;
659            return true;
660        }
661        "error" => {
662            let msg = evt["error"]["message"]
663                .as_str()
664                .unwrap_or("unknown stream error")
665                .to_string();
666            let _ = tx.send(StreamEvent::Error(msg)).await;
667            return true;
668        }
669        _ => {}
670    }
671
672    false
673}
674
675fn build_chat_response(
676    blocks: Vec<AssembledBlock>,
677    stop_reason: StopReason,
678    usage: Usage,
679) -> ChatResponse {
680    let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
681    ChatResponse {
682        content,
683        stop_reason,
684        usage,
685    }
686}
687
/// Turns whatever has been accumulated so far into a `ChatResponse`.
///
/// Currently a direct pass-through to `build_chat_response`; it is used both
/// when a stream ends normally and when it ends without a terminating event.
fn finalize_incomplete_stream(
    blocks: Vec<AssembledBlock>,
    stop_reason: StopReason,
    usage: Usage,
) -> ChatResponse {
    build_chat_response(blocks, stop_reason, usage)
}
695
/// A content block that is still being assembled from streaming events.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index that has not been started; produces no content.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Thinking text plus its optional signature, accumulated from
    /// `thinking_delta` and `signature_delta` events.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Tool call whose input arrives as JSON fragments (`input_json_delta`).
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server-side tool call; input is accumulated the same way as `ToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search result block; `content_json` holds the raw JSON text of its content.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
720
721impl AssembledBlock {
722    fn finish(self) -> Option<ContentBlock> {
723        match self {
724            AssembledBlock::Empty => None,
725            AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
726            AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
727                thinking: text,
728                signature,
729            }),
730            AssembledBlock::ToolUse {
731                id,
732                name,
733                input_json,
734            } => {
735                let input: Value = if input_json.is_empty() {
736                    json!({})
737                } else {
738                    serde_json::from_str(&input_json).unwrap_or(json!({}))
739                };
740                Some(ContentBlock::ToolUse { id, name, input })
741            }
742            AssembledBlock::ServerToolUse {
743                id,
744                name,
745                input_json,
746            } => {
747                let input: Value = if input_json.is_empty() {
748                    json!({})
749                } else {
750                    serde_json::from_str(&input_json).unwrap_or(json!({}))
751                };
752                Some(ContentBlock::ServerToolUse { id, name, input })
753            }
754            AssembledBlock::WebSearchResult {
755                tool_use_id,
756                content_json,
757            } => {
758                let content: Value = if content_json.is_empty() {
759                    json!({"results": []})
760                } else {
761                    serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
762                };
763                Some(ContentBlock::WebSearchResult {
764                    tool_use_id,
765                    content: parse_web_search_content(&content),
766                })
767            }
768        }
769    }
770}
771
772/// Parse the provider's `usage` blob (non-streaming response) into our
773/// internal `Usage` struct. Missing fields default to 0.
774fn parse_usage(u: &Value) -> Usage {
775    Usage {
776        input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
777        output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
778        cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
779        cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
780    }
781}
782
783/// Merge a fresh usage payload into the accumulated one. Non-zero new values
784/// override prior ones — this matches the streaming protocol where
785/// `message_start` gives input counts and `message_delta` updates output.
786fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
787    let fresh = parse_usage(u);
788    if fresh.input_tokens > 0 {
789        acc.input_tokens = fresh.input_tokens;
790    }
791    if fresh.output_tokens > 0 {
792        acc.output_tokens = fresh.output_tokens;
793    }
794    if fresh.cache_creation_input_tokens > 0 {
795        acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
796    }
797    if fresh.cache_read_input_tokens > 0 {
798        acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
799    }
800    acc
801}
802
803/// Parse web search content from the API response.
804fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
805    let results = value["results"]
806        .as_array()
807        .map(|arr| {
808            arr.iter()
809                .filter_map(|item| {
810                    Some(crate::providers::WebSearchResultItem {
811                        title: item["title"].as_str().map(String::from),
812                        url: item["url"].as_str()?.to_string(),
813                        encrypted_content: item["encrypted_content"].as_str().map(String::from),
814                        snippet: item["snippet"].as_str().map(String::from),
815                    })
816                })
817                .collect()
818        })
819        .unwrap_or_default();
820
821    crate::providers::WebSearchContent { results }
822}
823
/// Unit tests for the SSE framing helpers and stream finalization.
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames separated by `\r\n\r\n` are split correctly and fully consumed.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    /// A final event lacking the blank-line terminator is still returned,
    /// with its trailing line ending removed.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    /// Both `data: value` and `data:value` are recognized.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    /// Partially accumulated text survives finalization of a truncated stream.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}