1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10 ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11 DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12 THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19 StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23 api_key: String,
24 model: String,
25 base_url: String,
26 client: reqwest::Client,
27 extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32 pub fn new(api_key: String, model: String, base_url: String) -> Self {
33 Self::with_headers(api_key, model, base_url, None)
34 }
35
36 fn load_proxy_from_env() -> Option<String> {
38 std::env::var("HTTPS_PROXY")
39 .ok()
40 .or_else(|| std::env::var("https_proxy").ok())
41 .or_else(|| std::env::var("HTTP_PROXY").ok())
42 .or_else(|| std::env::var("http_proxy").ok())
43 }
44
45 pub fn with_headers(
46 api_key: String,
47 model: String,
48 base_url: String,
49 extra_headers: Option<std::collections::HashMap<String, String>>,
50 ) -> Self {
51 let mut client_builder = reqwest::Client::builder()
56 .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57 .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58 .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); if let Some(proxy_url) = Self::load_proxy_from_env()
62 && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63 {
64 log::info!("AnthropicProvider using proxy: {}", proxy_url);
65 client_builder = client_builder.proxy(proxy);
66 }
67
68 let client = client_builder
69 .build()
70 .unwrap_or_else(|_| reqwest::Client::new());
71 let extra_headers: Vec<(String, String)> = extra_headers
72 .map(|h| h.into_iter().collect())
73 .unwrap_or_default();
74 Self {
75 api_key,
76 model,
77 base_url,
78 client,
79 extra_headers,
80 }
81 }
82
83 fn is_official_anthropic(&self) -> bool {
86 self.base_url.contains("api.anthropic.com")
87 }
88
89 fn filter_trailing_thinking(messages: Vec<Value>) -> Vec<Value> {
94 if messages.is_empty() {
95 return messages;
96 }
97
98 let last_idx = messages.len() - 1;
100 let last_msg = &messages[last_idx];
101
102 if last_msg.get("role").and_then(|r| r.as_str()) != Some("assistant") {
104 return messages;
105 }
106
107 let content = last_msg.get("content");
108 if let Some(content_array) = content.and_then(|c| c.as_array()) {
109 if content_array.is_empty() {
110 return messages;
111 }
112
113 let last_block = content_array.last();
115 if let Some(block) = last_block {
116 if block.get("type").and_then(|t| t.as_str()) == Some("thinking") {
117 let mut last_valid_idx: i64 = content_array.len() as i64 - 1;
120 while last_valid_idx >= 0 {
121 let b = &content_array[last_valid_idx as usize];
122 if b.get("type").and_then(|t| t.as_str()) != Some("thinking") {
123 break;
124 }
125 last_valid_idx -= 1;
126 }
127
128 let filtered_content: Vec<Value> = if last_valid_idx < 0 {
130 vec![json!({"type": "text", "text": "[No message content]"})]
131 } else {
132 content_array[..=last_valid_idx as usize].to_vec()
133 };
134
135 let mut result = messages;
137 result[last_idx] = json!({
138 "role": "assistant",
139 "content": filtered_content
140 });
141 return result;
142 }
143 }
144 }
145
146 messages
147 }
148
149 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
150 log::debug!(
155 "convert_messages: keeping thinking blocks for continuity, base_url={}",
156 self.base_url
157 );
158
159 let converted: Vec<Value> = messages
160 .iter()
161 .filter(|m| m.role != Role::System)
162 .map(|m| {
163 let role = match m.role {
164 Role::User | Role::Tool => "user",
165 Role::Assistant => "assistant",
166 Role::System => unreachable!(),
167 };
168
169 let content = match &m.content {
170 MessageContent::Text(text) => json!(text),
171 MessageContent::Blocks(blocks) => {
172 let converted: Vec<Value> = blocks
173 .iter()
174 .map(|b| match b {
175 ContentBlock::Thinking { thinking, signature } => {
177 let mut obj = json!({"type": "thinking", "thinking": thinking});
178 if let Some(sig) = signature {
179 if !sig.is_empty() {
180 obj["signature"] = json!(sig);
181 }
182 }
183 obj
184 }
185 ContentBlock::Text { text } => {
186 json!({"type": "text", "text": text})
187 }
188 ContentBlock::ToolUse { id, name, input } => {
189 json!({"type": "tool_use", "id": id, "name": name, "input": input})
190 }
191 ContentBlock::ToolResult { tool_use_id, content } => {
192 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
193 }
194 ContentBlock::ServerToolUse { id, name, input } => {
195 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
196 }
197 ContentBlock::ServerToolResult { tool_use_id, content } => {
198 json!({"type": "server_tool_result", "tool_use_id": tool_use_id, "content": content})
199 }
200 ContentBlock::WebSearchResult { tool_use_id, content } => {
201 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
202 }
203 })
204 .collect();
205
206 if converted.is_empty() {
207 json!([])
208 } else {
209 json!(converted)
210 }
211 }
212 };
213
214 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
216 return json!(null);
217 }
218
219 json!({"role": role, "content": content})
220 })
221 .filter(|v| !v.is_null())
222 .collect();
223
224 Self::filter_trailing_thinking(converted)
226 }
227
228 fn convert_tools_with_caching(
231 &self,
232 tools: &[ToolDefinition],
233 enable_caching: bool,
234 ) -> Vec<Value> {
235 let mut converted: Vec<Value> = tools
236 .iter()
237 .map(|t| {
238 json!({
239 "name": t.name,
240 "description": t.description,
241 "input_schema": t.parameters,
242 })
243 })
244 .collect();
245
246 if enable_caching && self.is_official_anthropic() && !converted.is_empty() {
249 let last_idx = converted.len() - 1;
250 if let Some(obj) = converted[last_idx].as_object_mut() {
251 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
252 }
253 }
254
255 converted
256 }
257
258 fn build_body(&self, request: &ChatRequest) -> Value {
260 let mut body = json!({
261 "model": self.model,
262 "max_tokens": request.max_tokens,
263 "messages": self.convert_messages(&request.messages),
264 });
265
266 if request.enable_caching && self.is_official_anthropic() {
269 if let Some(system) = &request.system {
270 body["system"] = json!([
272 {
273 "type": "text",
274 "text": system,
275 "cache_control": {"type": "ephemeral"}
276 }
277 ]);
278 }
279 } else if let Some(system) = &request.system {
280 body["system"] = json!(system);
281 }
282
283 if !request.tools.is_empty() {
284 let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
285 body["tools"] = json!(tools);
286 }
287
288 if !request.server_tools.is_empty() {
289 body["tools"] = json!(
290 body["tools"]
291 .as_array()
292 .map(|t| {
293 let mut tools = t.clone();
294 for st in &request.server_tools {
295 tools.push(serde_json::to_value(st).unwrap_or_default());
296 }
297 tools
298 })
299 .unwrap_or_else(|| request
300 .server_tools
301 .iter()
302 .map(|st| serde_json::to_value(st).unwrap_or_default())
303 .collect())
304 );
305 }
306
307 if request.think && self.is_official_anthropic() {
311 let config = thinking_config(&self.model);
312 log::debug!(
313 "Adding thinking config for model {}: {:?}",
314 self.model,
315 config
316 );
317 body["thinking"] = config;
318 } else if request.think {
319 log::debug!(
320 "Skipping thinking config for non-official API (model: {}, base_url: {})",
321 self.model,
322 self.base_url
323 );
324 }
325
326 body
327 }
328}
329
330fn thinking_config(model: &str) -> Value {
335 let m = model.to_lowercase();
336 let adaptive = m.contains("opus-4")
338 || m.contains("sonnet-4")
339 || m.contains("claude-4")
340 || m.contains("20250")
341 || m.contains("2025");
342 if adaptive {
343 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
344 } else {
345 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
347 }
348}
349
350#[async_trait]
351impl Provider for AnthropicProvider {
352 fn context_size(&self) -> Option<u32> {
353 context_window_for(&self.model)
354 }
355
356 fn model_name(&self) -> &str {
357 &self.model
358 }
359
360 fn clone_box(&self) -> Box<dyn Provider> {
361 Box::new(Self {
362 api_key: self.api_key.clone(),
363 model: self.model.clone(),
364 base_url: self.base_url.clone(),
365 client: reqwest::Client::new(),
366 extra_headers: self.extra_headers.clone(),
367 })
368 }
369
370 fn clone_arc(&self) -> Arc<dyn Provider> {
371 Arc::new(Self {
372 api_key: self.api_key.clone(),
373 model: self.model.clone(),
374 base_url: self.base_url.clone(),
375 client: reqwest::Client::new(),
376 extra_headers: self.extra_headers.clone(),
377 })
378 }
379
380 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
381 let body = self.build_body(&request);
382
383 let url = format!("{}/v1/messages", self.base_url);
384
385 crate::debug::debug_log()
387 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
388
389 let mut req = self
390 .client
391 .post(&url)
392 .header("User-Agent", "curl/8.0")
393 .json(&body);
394
395 if self.is_official_anthropic() {
397 req = req
398 .header("x-api-key", &self.api_key)
399 .header("anthropic-version", ANTHROPIC_API_VERSION)
400 .header("anthropic-beta", "prompt-caching-2024-07-31");
401 } else {
402 req = req
403 .header("Authorization", format!("Bearer {}", self.api_key))
404 .header("anthropic-version", "2023-06-01")
406 .header("anthropic-beta", "prompt-caching-2024-07-31");
408 }
409
410 for (name, value) in &self.extra_headers {
412 req = req.header(name, value);
413 }
414
415 let response = req.send().await?;
416
417 let status = response.status();
418 let response_body: Value = response.json().await?;
419
420 crate::debug::debug_log().api_response(
422 status.as_u16(),
423 &serde_json::to_string(&response_body).unwrap_or_default(),
424 );
425
426 if !status.is_success() {
427 let err_msg = response_body["error"]["message"]
428 .as_str()
429 .unwrap_or("unknown error");
430 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
431 }
432
433 let stop_reason = match response_body["stop_reason"].as_str() {
434 Some("tool_use") => StopReason::ToolUse,
435 Some("max_tokens") => StopReason::MaxTokens,
436 _ => StopReason::EndTurn,
437 };
438
439 let content = response_body["content"]
440 .as_array()
441 .unwrap_or(&vec![])
442 .iter()
443 .filter_map(|block| match block["type"].as_str()? {
444 "text" => Some(ContentBlock::Text {
445 text: block["text"].as_str()?.to_string(),
446 }),
447 "tool_use" => Some(ContentBlock::ToolUse {
448 id: block["id"].as_str()?.to_string(),
449 name: block["name"].as_str()?.to_string(),
450 input: block["input"].clone(),
451 }),
452 "thinking" => Some(ContentBlock::Thinking {
453 thinking: block["thinking"].as_str()?.to_string(),
454 signature: block["signature"].as_str().map(String::from),
455 }),
456 "server_tool_use" => Some(ContentBlock::ServerToolUse {
457 id: block["id"].as_str()?.to_string(),
458 name: block["name"].as_str()?.to_string(),
459 input: block["input"].clone(),
460 }),
461 "web_search_tool_result" => {
462 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
463 let content = parse_web_search_content(&block["content"]);
464 Some(ContentBlock::WebSearchResult {
465 tool_use_id,
466 content,
467 })
468 }
469 _ => None,
470 })
471 .collect();
472
473 Ok(ChatResponse {
474 content,
475 stop_reason,
476 usage: parse_usage(&response_body["usage"]),
477 })
478 }
479
480 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
481 let mut body = self.build_body(&request);
482 body["stream"] = json!(true);
483
484 let url = format!("{}/v1/messages", self.base_url);
485
486 crate::debug::debug_log()
488 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
489
490 let mut req = self
491 .client
492 .post(&url)
493 .header("User-Agent", "curl/8.0")
494 .json(&body);
495
496 if self.is_official_anthropic() {
498 req = req
499 .header("x-api-key", &self.api_key)
500 .header("anthropic-version", ANTHROPIC_API_VERSION)
501 .header("anthropic-beta", "prompt-caching-2024-07-31");
502 } else {
503 req = req
504 .header("Authorization", format!("Bearer {}", self.api_key))
505 .header("anthropic-version", "2023-06-01")
507 .header("anthropic-beta", "prompt-caching-2024-07-31");
509 }
510
511 for (name, value) in &self.extra_headers {
513 req = req.header(name, value);
514 }
515
516 let response = req.send().await?;
517
518 if !response.status().is_success() {
519 let status = response.status();
520 let text = response.text().await.unwrap_or_default();
521 anyhow::bail!("Anthropic API error ({}): {}", status, text);
522 }
523
524 let (tx, rx) = mpsc::channel(64);
525 tokio::spawn(async move {
526 let mut stream = response.bytes_stream();
527 let mut buffer = String::new();
528 let mut sent_first_byte = false;
529
530 let mut blocks: Vec<AssembledBlock> = Vec::new();
532 let mut stop_reason = StopReason::EndTurn;
533 let mut usage = Usage::default();
534
535 let mut last_content_time = std::time::Instant::now();
537 const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; while let Some(chunk) = stream.next().await {
540 let chunk = match chunk {
541 Ok(c) => c,
542 Err(e) => {
543 let error_msg = e.to_string();
545 let is_timeout =
546 error_msg.contains("timeout") || error_msg.contains("timed out");
547 let is_decode =
548 error_msg.contains("decode") || error_msg.contains("decoding");
549
550 let msg = if is_timeout {
551 format!(
552 "Stream timeout - the API took too long to respond: {}",
553 error_msg
554 )
555 } else if is_decode {
556 format!(
557 "Stream decode error - possibly interrupted or corrupted response: {}",
558 error_msg
559 )
560 } else {
561 format!("Stream read error: {}", error_msg)
562 };
563
564 if sent_first_byte && !blocks.is_empty() {
566 debug!("finalizing partial stream due to error");
567 let _ = tx
568 .send(StreamEvent::Done(finalize_incomplete_stream(
569 std::mem::take(&mut blocks),
570 stop_reason,
571 usage,
572 )))
573 .await;
574 } else {
575 let _ = tx.send(StreamEvent::Error(msg)).await;
576 }
577 return;
578 }
579 };
580
581 if !sent_first_byte {
582 sent_first_byte = true;
583 let _ = tx.send(StreamEvent::FirstByte).await;
584 }
585
586 buffer.push_str(&String::from_utf8_lossy(&chunk));
587
588 let elapsed = last_content_time.elapsed().as_secs();
590 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
591 crate::debug::debug_log().stream_chunk(
592 "TIMEOUT_FORCE_FINALIZE",
593 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
594 );
595 let _ = tx
596 .send(StreamEvent::Done(finalize_incomplete_stream(
597 std::mem::take(&mut blocks),
598 stop_reason,
599 usage,
600 )))
601 .await;
602 return;
603 }
604
605 while let Some(frame) = take_next_sse_frame(&mut buffer) {
606 if handle_sse_frame(
607 &frame,
608 &mut blocks,
609 &mut stop_reason,
610 &mut usage,
611 &tx,
612 &mut last_content_time,
613 )
614 .await
615 {
616 return;
617 }
618 }
619 }
620
621 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
622 && handle_sse_frame(
623 &frame,
624 &mut blocks,
625 &mut stop_reason,
626 &mut usage,
627 &tx,
628 &mut last_content_time,
629 )
630 .await
631 {
632 return;
633 }
634
635 if sent_first_byte {
636 debug!("stream ended without explicit message_stop; finalizing best-effort");
637 let _ = tx
638 .send(StreamEvent::Done(finalize_incomplete_stream(
639 std::mem::take(&mut blocks),
640 stop_reason,
641 usage,
642 )))
643 .await;
644 } else {
645 let _ = tx
646 .send(StreamEvent::Error(
647 "stream ended before any events were received".to_string(),
648 ))
649 .await;
650 }
651 });
652
653 Ok(rx)
654 }
655}
656
/// Removes the first complete SSE frame from `buffer` and returns it without its
/// terminator.
///
/// A frame ends at a blank line, written as either `\n\n` or `\r\n\r\n`; whichever
/// appears first wins. Returns `None`, leaving `buffer` untouched, when no
/// complete frame has been buffered yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    let (frame_len, delimiter_len) = DELIMITERS
        .into_iter()
        .filter_map(|delim| buffer.find(delim).map(|at| (at, delim.len())))
        .min_by_key(|&(at, _)| at)?;

    let mut frame: String = buffer.drain(..frame_len + delimiter_len).collect();
    frame.truncate(frame_len);
    Some(frame)
}
677
/// Drains whatever is left in `buffer` once the byte stream has ended.
///
/// The leftover text is trimmed of surrounding whitespace (including `\r`).
/// Returns `None` if nothing remains, otherwise the text as one final,
/// unterminated frame. `buffer` is always empty afterwards.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_owned())
}
683
/// Returns the `data` payload of an SSE frame, or `None` if the frame has no
/// `data:` field.
///
/// Following the SSE event-stream format, a single space after the colon is
/// optional, and a payload split over several `data:` lines is rejoined with `\n`.
/// The original code kept only the first line. A trailing `\r` on each line is
/// ignored, as are other fields (`event:`, `id:`, `:` comments).
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let parts: Vec<&str> = frame
        .lines()
        .filter_map(|line| line.trim_end_matches('\r').strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect();

    (!parts.is_empty()).then(|| parts.join("\n"))
}
697
698async fn handle_sse_frame(
699 frame: &str,
700 blocks: &mut Vec<AssembledBlock>,
701 stop_reason: &mut StopReason,
702 usage: &mut Usage,
703 tx: &mpsc::Sender<StreamEvent>,
704 last_content_time: &mut std::time::Instant,
705) -> bool {
706 let Some(data_line) = extract_sse_data_line(frame) else {
707 return false;
708 };
709
710 let evt: Value = match serde_json::from_str(&data_line) {
711 Ok(v) => v,
712 Err(_) => return false,
713 };
714
715 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
716}
717
718async fn handle_sse_event(
719 evt: Value,
720 blocks: &mut Vec<AssembledBlock>,
721 stop_reason: &mut StopReason,
722 usage: &mut Usage,
723 tx: &mpsc::Sender<StreamEvent>,
724 last_content_time: &mut std::time::Instant,
725) -> bool {
726 let evt_type = evt["type"].as_str().unwrap_or("");
727
728 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
730 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
731
732 if evt_type == "content_block_delta" {
734 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
735 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
736 log::debug!(
737 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
738 delta_type,
739 idx,
740 blocks.len(),
741 idx < blocks.len()
742 );
743 }
744
745 if evt_type != "ping" {
747 *last_content_time = std::time::Instant::now();
748 }
749
750 match evt_type {
751 "message_start" => {
752 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
757 debug!(
758 "message_start: usage_json={}",
759 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
760 );
761 debug!(
762 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
763 usage.input_tokens,
764 usage.output_tokens,
765 usage.cache_read_input_tokens,
766 usage.cache_creation_input_tokens
767 );
768 let _ = tx
770 .send(StreamEvent::Usage {
771 output_tokens: usage.output_tokens,
772 })
773 .await;
774 }
775 "content_block_start" => {
776 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
777 let block = &evt["content_block"];
778 let kind = block["type"].as_str().unwrap_or("");
779 while blocks.len() <= idx {
780 blocks.push(AssembledBlock::default());
781 }
782 match kind {
783 "text" => {
784 blocks[idx] = AssembledBlock::Text(String::new());
785 }
786 "thinking" => {
787 blocks[idx] = AssembledBlock::Thinking {
788 text: String::new(),
789 signature: None,
790 };
791 }
792 "tool_use" | "server_tool_use" => {
793 let id = block["id"].as_str().unwrap_or_default();
794 let name = block["name"].as_str().unwrap_or_default();
795 let is_server = kind == "server_tool_use";
796 blocks[idx] = if is_server {
797 AssembledBlock::ServerToolUse {
798 id: id.into(),
799 name: name.into(),
800 input_json: String::new(),
801 }
802 } else {
803 AssembledBlock::ToolUse {
804 id: id.into(),
805 name: name.into(),
806 input_json: String::new(),
807 }
808 };
809 let _ = tx
810 .send(StreamEvent::ToolUseStart {
811 id: id.into(),
812 name: name.into(),
813 })
814 .await;
815 }
816 "web_search_tool_result" => {
817 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
818 blocks[idx] = AssembledBlock::WebSearchResult {
819 tool_use_id,
820 content_json: String::new(),
821 };
822 }
823 _ => {}
824 }
825 }
826 "content_block_delta" => {
827 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
828 let delta = &evt["delta"];
829 let dt = delta["type"].as_str().unwrap_or("");
830 if idx >= blocks.len() {
831 return false;
832 }
833 match (dt, &mut blocks[idx]) {
834 ("text_delta", AssembledBlock::Text(buf)) => {
835 if let Some(t) = delta["text"].as_str() {
836 buf.push_str(t);
837 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
838 }
839 }
840 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
841 if let Some(t) = delta["thinking"].as_str() {
842 text.push_str(t);
843 log::debug!("Received thinking_delta: {} chars", t.len());
844 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
845 }
846 }
847 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
848 if let Some(s) = delta["signature"].as_str() {
849 signature.get_or_insert_with(String::new).push_str(s);
850 }
851 }
852 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
853 if let Some(p) = delta["partial_json"].as_str() {
854 input_json.push_str(p);
855 let _ = tx
856 .send(StreamEvent::ToolInputDelta {
857 bytes_so_far: input_json.len(),
858 })
859 .await;
860 }
861 }
862 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
863 if let Some(p) = delta["partial_json"].as_str() {
864 input_json.push_str(p);
865 let _ = tx
866 .send(StreamEvent::ToolInputDelta {
867 bytes_so_far: input_json.len(),
868 })
869 .await;
870 }
871 }
872 _ => {}
873 }
874 }
875 "content_block_stop" => {
876 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
877 if let Some(AssembledBlock::ToolUse {
878 id,
879 name,
880 input_json,
881 }) = blocks.get(idx)
882 {
883 let input: Value = if input_json.is_empty() {
884 json!({})
885 } else {
886 serde_json::from_str(input_json).unwrap_or(json!({}))
887 };
888 let _ = tx
889 .send(StreamEvent::ToolInputComplete {
890 id: id.clone(),
891 name: name.clone(),
892 input,
893 })
894 .await;
895 }
896 }
897 "message_delta" => {
898 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
899 *stop_reason = match sr {
900 "tool_use" => StopReason::ToolUse,
901 "max_tokens" => StopReason::MaxTokens,
902 _ => StopReason::EndTurn,
903 };
904 }
905 *usage = merge_usage(usage.clone(), &evt["usage"]);
909 debug!(
910 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
911 usage.input_tokens,
912 usage.output_tokens,
913 usage.cache_read_input_tokens,
914 usage.cache_creation_input_tokens
915 );
916 let _ = tx
918 .send(StreamEvent::Usage {
919 output_tokens: usage.output_tokens,
920 })
921 .await;
922 }
923 "message_stop" => {
924 debug!(
925 "Message completed: stop_reason={}, usage={}",
926 match *stop_reason {
927 StopReason::EndTurn => "end_turn",
928 StopReason::ToolUse => "tool_use",
929 StopReason::MaxTokens => "max_tokens",
930 },
931 usage.output_tokens
932 );
933 debug!(
934 "message_stop final usage: cache_read={}, cache_created={}",
935 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
936 );
937 let _ = tx
938 .send(StreamEvent::Done(finalize_incomplete_stream(
939 std::mem::take(blocks),
940 stop_reason.clone(),
941 usage.clone(),
942 )))
943 .await;
944 return true;
945 }
946 "error" => {
947 let msg = evt["error"]["message"]
948 .as_str()
949 .unwrap_or("unknown stream error")
950 .to_string();
951 let _ = tx.send(StreamEvent::Error(msg)).await;
952 return true;
953 }
954 _ => {}
955 }
956
957 false
958}
959
960fn build_chat_response(
961 blocks: Vec<AssembledBlock>,
962 stop_reason: StopReason,
963 usage: Usage,
964) -> ChatResponse {
965 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
966 ChatResponse {
967 content,
968 stop_reason,
969 usage,
970 }
971}
972
973fn finalize_incomplete_stream(
974 blocks: Vec<AssembledBlock>,
975 stop_reason: StopReason,
976 usage: Usage,
977) -> ChatResponse {
978 build_chat_response(blocks, stop_reason, usage)
979}
980
/// Per-index accumulator for a content block while it is being streamed.
///
/// `content_block_start` selects the variant, `content_block_delta` appends text
/// or raw JSON fragments, and `finish` converts the result into a `ContentBlock`.
#[derive(Debug, Default)]
enum AssembledBlock {
    /// Placeholder for an index that has no assembled content; dropped on finish.
    #[default]
    Empty,
    /// Accumulated `text_delta` text.
    Text(String),
    /// Accumulated `thinking_delta` text plus concatenated `signature_delta` pieces.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// A `tool_use` block; `input_json` is the concatenated `input_json_delta` text.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// A `server_tool_use` block; `input_json` is accumulated as for `ToolUse`.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// A `web_search_tool_result` block; `content_json` is the JSON text of its
    /// result payload (empty if none was captured).
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
1005
1006impl AssembledBlock {
1007 fn finish(self) -> Option<ContentBlock> {
1008 match self {
1009 AssembledBlock::Empty => None,
1010 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
1011 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
1012 thinking: text,
1013 signature,
1014 }),
1015 AssembledBlock::ToolUse {
1016 id,
1017 name,
1018 input_json,
1019 } => {
1020 let input: Value = if input_json.is_empty() {
1021 json!({})
1022 } else {
1023 serde_json::from_str(&input_json).unwrap_or(json!({}))
1024 };
1025 Some(ContentBlock::ToolUse { id, name, input })
1026 }
1027 AssembledBlock::ServerToolUse {
1028 id,
1029 name,
1030 input_json,
1031 } => {
1032 let input: Value = if input_json.is_empty() {
1033 json!({})
1034 } else {
1035 serde_json::from_str(&input_json).unwrap_or(json!({}))
1036 };
1037 Some(ContentBlock::ServerToolUse { id, name, input })
1038 }
1039 AssembledBlock::WebSearchResult {
1040 tool_use_id,
1041 content_json,
1042 } => {
1043 let content: Value = if content_json.is_empty() {
1044 json!({"results": []})
1045 } else {
1046 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1047 };
1048 Some(ContentBlock::WebSearchResult {
1049 tool_use_id,
1050 content: parse_web_search_content(&content),
1051 })
1052 }
1053 }
1054 }
1055}
1056
1057fn parse_usage(u: &Value) -> Usage {
1060 Usage {
1061 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1062 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1063 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1064 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1065 }
1066}
1067
1068fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1072 let fresh = parse_usage(u);
1073 if fresh.input_tokens > 0 {
1074 acc.input_tokens = fresh.input_tokens;
1075 }
1076 if fresh.output_tokens > 0 {
1077 acc.output_tokens = fresh.output_tokens;
1078 }
1079 if fresh.cache_creation_input_tokens > 0 {
1080 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1081 }
1082 if fresh.cache_read_input_tokens > 0 {
1083 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1084 }
1085 acc
1086}
1087
1088fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1090 let results = value["results"]
1091 .as_array()
1092 .map(|arr| {
1093 arr.iter()
1094 .filter_map(|item| {
1095 Some(crate::providers::WebSearchResultItem {
1096 title: item["title"].as_str().map(String::from),
1097 url: item["url"].as_str()?.to_string(),
1098 encrypted_content: item["encrypted_content"].as_str().map(String::from),
1099 snippet: item["snippet"].as_str().map(String::from),
1100 })
1101 })
1102 .collect()
1103 })
1104 .unwrap_or_default();
1105
1106 crate::providers::WebSearchContent { results }
1107}
1108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_next_sse_frame_supports_lf_and_keeps_incomplete_tail() {
        let mut buffer = "data: {\"type\":\"ping\"}\n\ndata: {\"type\":".to_string();

        let frame = take_next_sse_frame(&mut buffer).expect("complete frame");
        assert_eq!(frame, "data: {\"type\":\"ping\"}");

        // An unterminated frame stays buffered until more bytes arrive.
        assert!(take_next_sse_frame(&mut buffer).is_none());
        assert_eq!(buffer, "data: {\"type\":");
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_ignores_whitespace_only_buffer() {
        let mut buffer = " \r\n\t".to_string();
        assert!(take_trailing_sse_frame(&mut buffer).is_none());
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    #[test]
    fn extract_sse_data_line_returns_none_without_data() {
        assert_eq!(extract_sse_data_line("event: ping\r\n: keep-alive"), None);
    }

    #[test]
    fn filter_trailing_thinking_strips_trailing_thinking_blocks() {
        let messages = vec![
            json!({"role": "user", "content": "hi"}),
            json!({"role": "assistant", "content": [
                {"type": "text", "text": "answer"},
                {"type": "thinking", "thinking": "hmm"}
            ]}),
        ];

        let filtered = AnthropicProvider::filter_trailing_thinking(messages);

        assert_eq!(filtered[0], json!({"role": "user", "content": "hi"}));
        assert_eq!(
            filtered[1]["content"],
            json!([{"type": "text", "text": "answer"}])
        );
    }

    #[test]
    fn filter_trailing_thinking_substitutes_placeholder_when_only_thinking() {
        let messages = vec![json!({
            "role": "assistant",
            "content": [{"type": "thinking", "thinking": "hmm"}]
        })];

        let filtered = AnthropicProvider::filter_trailing_thinking(messages);

        assert_eq!(
            filtered[0]["content"],
            json!([{"type": "text", "text": "[No message content]"}])
        );
    }

    #[test]
    fn merge_usage_keeps_previous_values_for_missing_fields() {
        let acc = Usage {
            input_tokens: 12,
            ..Usage::default()
        };

        let merged = merge_usage(acc, &json!({"output_tokens": 7}));

        assert_eq!(merged.input_tokens, 12);
        assert_eq!(merged.output_tokens, 7);
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finalize_incomplete_stream_drops_empty_placeholders() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Empty, AssembledBlock::Text("x".to_string())],
            StopReason::ToolUse,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::ToolUse);
        assert_eq!(response.content.len(), 1);
    }
}