1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10 DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS,
11 THINKING_BUDGET_NEW_MODELS, THINKING_BUDGET_OLD_MODELS,
12 DEFAULT_CONTENT_TIMEOUT_SECS, ANTHROPIC_API_VERSION,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19 StreamEvent, Usage,
20};
21
/// Provider that talks to `{base_url}/v1/messages`, either on Anthropic's own API or on a
/// compatible third-party endpoint.
22pub struct AnthropicProvider {
// Credential: sent as `x-api-key` to the official API, as a `Bearer` token otherwise.
23 api_key: String,
// Model identifier placed in the `model` field of every request body.
24 model: String,
// URL prefix for requests; also decides whether the endpoint counts as "official".
25 base_url: String,
// HTTP client configured in `with_headers` (timeouts and optional proxy).
26 client: reqwest::Client,
// Extra (name, value) headers added to every request after the standard ones.
27 extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32 pub fn new(api_key: String, model: String, base_url: String) -> Self {
33 Self::with_headers(api_key, model, base_url, None)
34 }
35
/// Returns the value of the first readable proxy environment variable.
///
/// Lookup order: `HTTPS_PROXY`, `https_proxy`, `HTTP_PROXY`, `http_proxy`. A variable that
/// is unset or not valid Unicode is skipped; `None` is returned when none can be read.
fn load_proxy_from_env() -> Option<String> {
    // Highest priority first.
    const PROXY_VARS: [&str; 4] = ["HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"];

    PROXY_VARS
        .iter()
        .find_map(|name| std::env::var(*name).ok())
}
44
45 pub fn with_headers(
46 api_key: String,
47 model: String,
48 base_url: String,
49 extra_headers: Option<std::collections::HashMap<String, String>>,
50 ) -> Self {
51 let mut client_builder = reqwest::Client::builder()
56 .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57 .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58 .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); if let Some(proxy_url) = Self::load_proxy_from_env()
62 && let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
63 log::info!("AnthropicProvider using proxy: {}", proxy_url);
64 client_builder = client_builder.proxy(proxy);
65 }
66
67 let client = client_builder
68 .build()
69 .unwrap_or_else(|_| reqwest::Client::new());
70 let extra_headers: Vec<(String, String)> = extra_headers
71 .map(|h| h.into_iter().collect())
72 .unwrap_or_default();
73 Self {
74 api_key,
75 model,
76 base_url,
77 client,
78 extra_headers,
79 }
80 }
81
82 fn is_official_anthropic(&self) -> bool {
85 self.base_url.contains("api.anthropic.com")
86 }
87
88 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
89 let filter_thinking = !self.is_official_anthropic();
93 log::debug!("convert_messages: filter_thinking={}, base_url={}", filter_thinking, self.base_url);
94
95 let mut thinking_count = 0;
97 for m in messages {
98 if let MessageContent::Blocks(blocks) = &m.content {
99 for b in blocks {
100 if matches!(b, ContentBlock::Thinking { .. }) {
101 thinking_count += 1;
102 }
103 }
104 }
105 }
106 if thinking_count > 0 {
107 log::debug!("convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}", thinking_count, messages.len(), filter_thinking);
108 }
109
110 messages
111 .iter()
112 .filter(|m| m.role != Role::System)
113 .map(|m| {
114 let role = match m.role {
115 Role::User | Role::Tool => "user",
116 Role::Assistant => "assistant",
117 Role::System => unreachable!(),
118 };
119
120 let content = match &m.content {
121 MessageContent::Text(text) => json!(text),
122 MessageContent::Blocks(blocks) => {
123 let converted: Vec<Value> = blocks
124 .iter()
125 .filter(|b| {
126 if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
128 return false;
129 }
130 true
131 })
132 .map(|b| match b {
133 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
134 ContentBlock::ToolUse { id, name, input } => {
135 json!({"type": "tool_use", "id": id, "name": name, "input": input})
136 }
137 ContentBlock::ToolResult { tool_use_id, content } => {
138 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
139 }
140 ContentBlock::Thinking { thinking, signature } => {
141 let mut obj = json!({"type": "thinking", "thinking": thinking});
142 if let Some(sig) = signature
144 && !sig.is_empty() {
145 obj["signature"] = json!(sig);
146 }
147 obj
148 }
149 ContentBlock::ServerToolUse { id, name, input } => {
150 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
151 }
152 ContentBlock::WebSearchResult { tool_use_id, content } => {
153 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
154 }
155 })
156 .collect();
157
158 if converted.is_empty() {
159 json!([])
160 } else {
161 json!(converted)
162 }
163 }
164 };
165
166 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
168 return json!(null);
169 }
170
171 json!({"role": role, "content": content})
172 })
173 .filter(|v| !v.is_null())
174 .collect()
175 }
176
177 fn convert_tools_with_caching(
179 &self,
180 tools: &[ToolDefinition],
181 enable_caching: bool,
182 ) -> Vec<Value> {
183 let mut converted: Vec<Value> = tools
184 .iter()
185 .map(|t| {
186 json!({
187 "name": t.name,
188 "description": t.description,
189 "input_schema": t.parameters,
190 })
191 })
192 .collect();
193
194 if enable_caching && !converted.is_empty() {
196 let last_idx = converted.len() - 1;
197 if let Some(obj) = converted[last_idx].as_object_mut() {
198 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
199 }
200 }
201
202 converted
203 }
204
205 fn build_body(&self, request: &ChatRequest) -> Value {
207 let mut body = json!({
208 "model": self.model,
209 "max_tokens": request.max_tokens,
210 "messages": self.convert_messages(&request.messages),
211 });
212
213 if request.enable_caching {
215 if let Some(system) = &request.system {
216 body["system"] = json!([
218 {
219 "type": "text",
220 "text": system,
221 "cache_control": {"type": "ephemeral"}
222 }
223 ]);
224 }
225 } else if let Some(system) = &request.system {
226 body["system"] = json!(system);
227 }
228
229 if !request.tools.is_empty() {
230 let tools = self.convert_tools_with_caching(
231 &request.tools,
232 request.enable_caching,
233 );
234 body["tools"] = json!(tools);
235 }
236
237 if !request.server_tools.is_empty() {
238 body["tools"] = json!(
239 body["tools"]
240 .as_array()
241 .map(|t| {
242 let mut tools = t.clone();
243 for st in &request.server_tools {
244 tools.push(serde_json::to_value(st).unwrap_or_default());
245 }
246 tools
247 })
248 .unwrap_or_else(|| request
249 .server_tools
250 .iter()
251 .map(|st| serde_json::to_value(st).unwrap_or_default())
252 .collect())
253 );
254 }
255
256 if request.think && self.is_official_anthropic() {
260 let config = thinking_config(&self.model);
261 log::debug!(
262 "Adding thinking config for model {}: {:?}",
263 self.model,
264 config
265 );
266 body["thinking"] = config;
267 } else if request.think {
268 log::debug!(
269 "Skipping thinking config for non-official API (model: {}, base_url: {})",
270 self.model,
271 self.base_url
272 );
273 }
274
275 body
276 }
277}
278
279fn thinking_config(model: &str) -> Value {
284 let m = model.to_lowercase();
285 let adaptive = m.contains("opus-4")
287 || m.contains("sonnet-4")
288 || m.contains("claude-4")
289 || m.contains("20250")
290 || m.contains("2025");
291 if adaptive {
292 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
293 } else {
294 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
296 }
297}
298
299#[async_trait]
300impl Provider for AnthropicProvider {
301 fn context_size(&self) -> Option<u32> {
302 context_window_for(&self.model)
303 }
304
305 fn model_name(&self) -> &str {
306 &self.model
307 }
308
309 fn clone_box(&self) -> Box<dyn Provider> {
310 Box::new(Self {
311 api_key: self.api_key.clone(),
312 model: self.model.clone(),
313 base_url: self.base_url.clone(),
314 client: reqwest::Client::new(),
315 extra_headers: self.extra_headers.clone(),
316 })
317 }
318
319 fn clone_arc(&self) -> Arc<dyn Provider> {
320 Arc::new(Self {
321 api_key: self.api_key.clone(),
322 model: self.model.clone(),
323 base_url: self.base_url.clone(),
324 client: reqwest::Client::new(),
325 extra_headers: self.extra_headers.clone(),
326 })
327 }
328
329 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
330 let body = self.build_body(&request);
331
332 let url = format!("{}/v1/messages", self.base_url);
333
334 crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
336
337 let mut req = self
338 .client
339 .post(&url)
340 .header("User-Agent", "curl/8.0")
341 .json(&body);
342
343 if self.is_official_anthropic() {
345 req = req
346 .header("x-api-key", &self.api_key)
347 .header("anthropic-version", ANTHROPIC_API_VERSION)
348 .header("anthropic-beta", "prompt-caching-2024-07-31");
349 } else {
350 req = req
351 .header("Authorization", format!("Bearer {}", self.api_key))
352 .header("anthropic-version", "2023-06-01")
354 .header("anthropic-beta", "prompt-caching-2024-07-31");
356 }
357
358 for (name, value) in &self.extra_headers {
360 req = req.header(name, value);
361 }
362
363 let response = req.send().await?;
364
365 let status = response.status();
366 let response_body: Value = response.json().await?;
367
368 crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
370
371 if !status.is_success() {
372 let err_msg = response_body["error"]["message"]
373 .as_str()
374 .unwrap_or("unknown error");
375 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
376 }
377
378 let stop_reason = match response_body["stop_reason"].as_str() {
379 Some("tool_use") => StopReason::ToolUse,
380 Some("max_tokens") => StopReason::MaxTokens,
381 _ => StopReason::EndTurn,
382 };
383
384 let content = response_body["content"]
385 .as_array()
386 .unwrap_or(&vec![])
387 .iter()
388 .filter_map(|block| match block["type"].as_str()? {
389 "text" => Some(ContentBlock::Text {
390 text: block["text"].as_str()?.to_string(),
391 }),
392 "tool_use" => Some(ContentBlock::ToolUse {
393 id: block["id"].as_str()?.to_string(),
394 name: block["name"].as_str()?.to_string(),
395 input: block["input"].clone(),
396 }),
397 "thinking" => Some(ContentBlock::Thinking {
398 thinking: block["thinking"].as_str()?.to_string(),
399 signature: block["signature"].as_str().map(String::from),
400 }),
401 "server_tool_use" => Some(ContentBlock::ServerToolUse {
402 id: block["id"].as_str()?.to_string(),
403 name: block["name"].as_str()?.to_string(),
404 input: block["input"].clone(),
405 }),
406 "web_search_tool_result" => {
407 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
408 let content = parse_web_search_content(&block["content"]);
409 Some(ContentBlock::WebSearchResult {
410 tool_use_id,
411 content,
412 })
413 }
414 _ => None,
415 })
416 .collect();
417
418 Ok(ChatResponse {
419 content,
420 stop_reason,
421 usage: parse_usage(&response_body["usage"]),
422 })
423 }
424
425 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
426 let mut body = self.build_body(&request);
427 body["stream"] = json!(true);
428
429 let url = format!("{}/v1/messages", self.base_url);
430
431 crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
433
434 let mut req = self
435 .client
436 .post(&url)
437 .header("User-Agent", "curl/8.0")
438 .json(&body);
439
440 if self.is_official_anthropic() {
442 req = req
443 .header("x-api-key", &self.api_key)
444 .header("anthropic-version", ANTHROPIC_API_VERSION)
445 .header("anthropic-beta", "prompt-caching-2024-07-31");
446 } else {
447 req = req
448 .header("Authorization", format!("Bearer {}", self.api_key))
449 .header("anthropic-version", "2023-06-01")
451 .header("anthropic-beta", "prompt-caching-2024-07-31");
453 }
454
455 for (name, value) in &self.extra_headers {
457 req = req.header(name, value);
458 }
459
460 let response = req.send().await?;
461
462 if !response.status().is_success() {
463 let status = response.status();
464 let text = response.text().await.unwrap_or_default();
465 anyhow::bail!("Anthropic API error ({}): {}", status, text);
466 }
467
468 let (tx, rx) = mpsc::channel(64);
469 tokio::spawn(async move {
470 let mut stream = response.bytes_stream();
471 let mut buffer = String::new();
472 let mut sent_first_byte = false;
473
474 let mut blocks: Vec<AssembledBlock> = Vec::new();
476 let mut stop_reason = StopReason::EndTurn;
477 let mut usage = Usage::default();
478
479 let mut last_content_time = std::time::Instant::now();
481 const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; while let Some(chunk) = stream.next().await {
484 let chunk = match chunk {
485 Ok(c) => c,
486 Err(e) => {
487 let error_msg = e.to_string();
489 let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
490 let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
491
492 let msg = if is_timeout {
493 format!("Stream timeout - the API took too long to respond: {}", error_msg)
494 } else if is_decode {
495 format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
496 } else {
497 format!("Stream read error: {}", error_msg)
498 };
499
500 if sent_first_byte && !blocks.is_empty() {
502 debug!("finalizing partial stream due to error");
503 let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
504 std::mem::take(&mut blocks),
505 stop_reason,
506 usage,
507 ))).await;
508 } else {
509 let _ = tx.send(StreamEvent::Error(msg)).await;
510 }
511 return;
512 }
513 };
514
515 if !sent_first_byte {
516 sent_first_byte = true;
517 let _ = tx.send(StreamEvent::FirstByte).await;
518 }
519
520 buffer.push_str(&String::from_utf8_lossy(&chunk));
521
522 let elapsed = last_content_time.elapsed().as_secs();
524 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
525 crate::debug::debug_log().stream_chunk("TIMEOUT_FORCE_FINALIZE",
526 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()));
527 let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
528 std::mem::take(&mut blocks),
529 stop_reason,
530 usage,
531 ))).await;
532 return;
533 }
534
535 while let Some(frame) = take_next_sse_frame(&mut buffer) {
536 if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time)
537 .await
538 {
539 return;
540 }
541 }
542 }
543
544 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
545 && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time).await
546 {
547 return;
548 }
549
550 if sent_first_byte {
551 debug!("stream ended without explicit message_stop; finalizing best-effort");
552 let _ = tx
553 .send(StreamEvent::Done(finalize_incomplete_stream(
554 std::mem::take(&mut blocks),
555 stop_reason,
556 usage,
557 )))
558 .await;
559 } else {
560 let _ = tx
561 .send(StreamEvent::Error(
562 "stream ended before any events were received".to_string(),
563 ))
564 .await;
565 }
566 });
567
568 Ok(rx)
569 }
570}
571
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank-line delimiter, either `"\n\n"` or `"\r\n\r\n"`.
/// The frame text is returned without the delimiter and both are removed from `buffer`.
/// Returns `None`, leaving `buffer` untouched, when no delimiter is present yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    // Earliest delimiter wins; `min_by_key` keeps the first candidate on a tie.
    let (frame_end, delimiter_len) = DELIMITERS
        .iter()
        .filter_map(|delimiter| buffer.find(*delimiter).map(|at| (at, delimiter.len())))
        .min_by_key(|&(at, _)| at)?;

    // Split the buffer into `frame | delimiter | rest` and keep only `rest` in place.
    let rest = buffer.split_off(frame_end + delimiter_len);
    buffer.truncate(frame_end);
    Some(std::mem::replace(buffer, rest))
}
592
/// Empties `buffer` and returns its remaining text as a final, unterminated frame.
///
/// Surrounding whitespace (including `\r` and `\n`) is trimmed. Returns `None` when
/// nothing but whitespace was left. `buffer` is always empty afterwards.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let trimmed = leftover.trim();
    if trimmed.is_empty() {
        return None;
    }
    Some(trimmed.to_owned())
}
598
/// Extracts the data payload of one SSE frame.
///
/// Every line starting with `data:` contributes its value, with a single optional space
/// after the colon removed. When a frame carries several `data:` lines, their values are
/// joined with `\n`, as the server-sent events format specifies, so a payload split across
/// lines is returned whole rather than cut off after its first line. A trailing `\r` on a
/// line is ignored.
///
/// Returns `None` when the frame has no `data:` line.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut payload: Option<String> = None;

    for line in frame.lines() {
        // `lines()` strips "\r\n" but not a lone "\r" at the very end of the frame.
        let line = line.trim_end_matches('\r');
        let Some(value) = line.strip_prefix("data:") else {
            continue;
        };
        let value = value.strip_prefix(' ').unwrap_or(value);

        match payload.as_mut() {
            Some(joined) => {
                joined.push('\n');
                joined.push_str(value);
            }
            None => payload = Some(value.to_string()),
        }
    }

    payload
}
612
613async fn handle_sse_frame(
614 frame: &str,
615 blocks: &mut Vec<AssembledBlock>,
616 stop_reason: &mut StopReason,
617 usage: &mut Usage,
618 tx: &mpsc::Sender<StreamEvent>,
619 last_content_time: &mut std::time::Instant,
620) -> bool {
621 let Some(data_line) = extract_sse_data_line(frame) else {
622 return false;
623 };
624
625 let evt: Value = match serde_json::from_str(&data_line) {
626 Ok(v) => v,
627 Err(_) => return false,
628 };
629
630 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
631}
632
633async fn handle_sse_event(
634 evt: Value,
635 blocks: &mut Vec<AssembledBlock>,
636 stop_reason: &mut StopReason,
637 usage: &mut Usage,
638 tx: &mpsc::Sender<StreamEvent>,
639 last_content_time: &mut std::time::Instant,
640) -> bool {
641 let evt_type = evt["type"].as_str().unwrap_or("");
642
643 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
645 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
646
647 if evt_type == "content_block_delta" {
649 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
650 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
651 log::debug!(
652 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
653 delta_type,
654 idx,
655 blocks.len(),
656 idx < blocks.len()
657 );
658 }
659
660 if evt_type != "ping" {
662 *last_content_time = std::time::Instant::now();
663 }
664
665 match evt_type {
666 "message_start" => {
667 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
672 debug!(
673 "message_start: usage_json={}",
674 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
675 );
676 debug!(
677 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
678 usage.input_tokens,
679 usage.output_tokens,
680 usage.cache_read_input_tokens,
681 usage.cache_creation_input_tokens
682 );
683 let _ = tx
685 .send(StreamEvent::Usage {
686 output_tokens: usage.output_tokens,
687 })
688 .await;
689 }
690 "content_block_start" => {
691 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
692 let block = &evt["content_block"];
693 let kind = block["type"].as_str().unwrap_or("");
694 while blocks.len() <= idx {
695 blocks.push(AssembledBlock::default());
696 }
697 match kind {
698 "text" => {
699 blocks[idx] = AssembledBlock::Text(String::new());
700 }
701 "thinking" => {
702 blocks[idx] = AssembledBlock::Thinking {
703 text: String::new(),
704 signature: None,
705 };
706 }
707 "tool_use" | "server_tool_use" => {
708 let id = block["id"].as_str().unwrap_or_default();
709 let name = block["name"].as_str().unwrap_or_default();
710 let is_server = kind == "server_tool_use";
711 blocks[idx] = if is_server {
712 AssembledBlock::ServerToolUse {
713 id: id.into(),
714 name: name.into(),
715 input_json: String::new(),
716 }
717 } else {
718 AssembledBlock::ToolUse {
719 id: id.into(),
720 name: name.into(),
721 input_json: String::new(),
722 }
723 };
724 let _ = tx
725 .send(StreamEvent::ToolUseStart {
726 id: id.into(),
727 name: name.into(),
728 })
729 .await;
730 }
731 "web_search_tool_result" => {
732 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
733 blocks[idx] = AssembledBlock::WebSearchResult {
734 tool_use_id,
735 content_json: String::new(),
736 };
737 }
738 _ => {}
739 }
740 }
741 "content_block_delta" => {
742 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
743 let delta = &evt["delta"];
744 let dt = delta["type"].as_str().unwrap_or("");
745 if idx >= blocks.len() {
746 return false;
747 }
748 match (dt, &mut blocks[idx]) {
749 ("text_delta", AssembledBlock::Text(buf)) => {
750 if let Some(t) = delta["text"].as_str() {
751 buf.push_str(t);
752 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
753 }
754 }
755 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
756 if let Some(t) = delta["thinking"].as_str() {
757 text.push_str(t);
758 log::debug!("Received thinking_delta: {} chars", t.len());
759 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
760 }
761 }
762 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
763 if let Some(s) = delta["signature"].as_str() {
764 signature.get_or_insert_with(String::new).push_str(s);
765 }
766 }
767 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
768 if let Some(p) = delta["partial_json"].as_str() {
769 input_json.push_str(p);
770 let _ = tx
771 .send(StreamEvent::ToolInputDelta {
772 bytes_so_far: input_json.len(),
773 })
774 .await;
775 }
776 }
777 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
778 if let Some(p) = delta["partial_json"].as_str() {
779 input_json.push_str(p);
780 let _ = tx
781 .send(StreamEvent::ToolInputDelta {
782 bytes_so_far: input_json.len(),
783 })
784 .await;
785 }
786 }
787 _ => {}
788 }
789 }
790 "message_delta" => {
791 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
792 *stop_reason = match sr {
793 "tool_use" => StopReason::ToolUse,
794 "max_tokens" => StopReason::MaxTokens,
795 _ => StopReason::EndTurn,
796 };
797 }
798 *usage = merge_usage(usage.clone(), &evt["usage"]);
802 debug!(
803 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
804 usage.input_tokens,
805 usage.output_tokens,
806 usage.cache_read_input_tokens,
807 usage.cache_creation_input_tokens
808 );
809 let _ = tx
811 .send(StreamEvent::Usage {
812 output_tokens: usage.output_tokens,
813 })
814 .await;
815 }
816 "message_stop" => {
817 debug!(
818 "Message completed: stop_reason={}, usage={}",
819 match *stop_reason {
820 StopReason::EndTurn => "end_turn",
821 StopReason::ToolUse => "tool_use",
822 StopReason::MaxTokens => "max_tokens",
823 },
824 usage.output_tokens
825 );
826 debug!(
827 "message_stop final usage: cache_read={}, cache_created={}",
828 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
829 );
830 let _ = tx
831 .send(StreamEvent::Done(finalize_incomplete_stream(
832 std::mem::take(blocks),
833 stop_reason.clone(),
834 usage.clone(),
835 )))
836 .await;
837 return true;
838 }
839 "error" => {
840 let msg = evt["error"]["message"]
841 .as_str()
842 .unwrap_or("unknown stream error")
843 .to_string();
844 let _ = tx.send(StreamEvent::Error(msg)).await;
845 return true;
846 }
847 _ => {}
848 }
849
850 false
851}
852
853fn build_chat_response(
854 blocks: Vec<AssembledBlock>,
855 stop_reason: StopReason,
856 usage: Usage,
857) -> ChatResponse {
858 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
859 ChatResponse {
860 content,
861 stop_reason,
862 usage,
863 }
864}
865
/// Builds the response delivered in a `StreamEvent::Done` from whatever has been
/// assembled so far.
///
/// This delegates directly to `build_chat_response`; it is used both for the regular
/// `message_stop` and for streams that end early (read error, content timeout, missing
/// final event).
866fn finalize_incomplete_stream(
867 blocks: Vec<AssembledBlock>,
868 stop_reason: StopReason,
869 usage: Usage,
870) -> ChatResponse {
871 build_chat_response(blocks, stop_reason, usage)
872}
873
/// A content block while it is being assembled from streaming events.
///
/// `finish` converts each variant into the corresponding `ContentBlock`.
874#[derive(Default)]
875enum AssembledBlock {
// Placeholder for an index that has not been started; `finish` drops it.
876 #[default]
877 Empty,
// Text accumulated from `text_delta` events.
878 Text(String),
// Thinking text from `thinking_delta` events and its signature from `signature_delta`.
879 Thinking {
880 text: String,
881 signature: Option<String>,
882 },
// Client tool call; `input_json` collects `input_json_delta` fragments, parsed in `finish`.
883 ToolUse {
884 id: String,
885 name: String,
886 input_json: String,
887 },
// Server tool call; assembled the same way as `ToolUse`.
888 ServerToolUse {
889 id: String,
890 name: String,
891 input_json: String,
892 },
// Web search result; `content_json` is JSON text parsed in `finish`, which falls back
// to an empty result list when the text is empty or invalid.
893 WebSearchResult {
894 tool_use_id: String,
895 content_json: String,
896 },
897}
898
899impl AssembledBlock {
900 fn finish(self) -> Option<ContentBlock> {
901 match self {
902 AssembledBlock::Empty => None,
903 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
904 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
905 thinking: text,
906 signature,
907 }),
908 AssembledBlock::ToolUse {
909 id,
910 name,
911 input_json,
912 } => {
913 let input: Value = if input_json.is_empty() {
914 json!({})
915 } else {
916 serde_json::from_str(&input_json).unwrap_or(json!({}))
917 };
918 Some(ContentBlock::ToolUse { id, name, input })
919 }
920 AssembledBlock::ServerToolUse {
921 id,
922 name,
923 input_json,
924 } => {
925 let input: Value = if input_json.is_empty() {
926 json!({})
927 } else {
928 serde_json::from_str(&input_json).unwrap_or(json!({}))
929 };
930 Some(ContentBlock::ServerToolUse { id, name, input })
931 }
932 AssembledBlock::WebSearchResult {
933 tool_use_id,
934 content_json,
935 } => {
936 let content: Value = if content_json.is_empty() {
937 json!({"results": []})
938 } else {
939 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
940 };
941 Some(ContentBlock::WebSearchResult {
942 tool_use_id,
943 content: parse_web_search_content(&content),
944 })
945 }
946 }
947 }
948}
949
950fn parse_usage(u: &Value) -> Usage {
953 Usage {
954 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
955 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
956 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
957 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
958 }
959}
960
961fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
965 let fresh = parse_usage(u);
966 if fresh.input_tokens > 0 {
967 acc.input_tokens = fresh.input_tokens;
968 }
969 if fresh.output_tokens > 0 {
970 acc.output_tokens = fresh.output_tokens;
971 }
972 if fresh.cache_creation_input_tokens > 0 {
973 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
974 }
975 if fresh.cache_read_input_tokens > 0 {
976 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
977 }
978 acc
979}
980
981fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
983 let results = value["results"]
984 .as_array()
985 .map(|arr| {
986 arr.iter()
987 .filter_map(|item| {
988 Some(crate::providers::WebSearchResultItem {
989 title: item["title"].as_str().map(String::from),
990 url: item["url"].as_str()?.to_string(),
991 encrypted_content: item["encrypted_content"].as_str().map(String::from),
992 snippet: item["snippet"].as_str().map(String::from),
993 })
994 })
995 .collect()
996 })
997 .unwrap_or_default();
998
999 crate::providers::WebSearchContent { results }
1000}
1001
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames separated by CRLF blank lines are split off one at a time.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut pending = String::from(concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        ));

        let start_frame = take_next_sse_frame(&mut pending).expect("first frame");
        assert!(start_frame.contains("message_start"));

        let stop_frame = take_next_sse_frame(&mut pending).expect("second frame");
        assert!(stop_frame.contains("message_stop"));

        assert!(pending.is_empty());
    }

    /// Text left in the buffer without a closing blank line is still returned.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut pending = String::from("data: {\"type\":\"message_stop\"}\r\n");

        let trailing = take_trailing_sse_frame(&mut pending).expect("trailing frame");

        assert_eq!(trailing, "data: {\"type\":\"message_stop\"}");
        assert!(pending.is_empty());
    }

    /// The space after `data:` is optional.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let with_space = extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r");
        assert_eq!(with_space, Some("{\"k\":1}".to_string()));

        let without_space = extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r");
        assert_eq!(without_space, Some("{\"k\":2}".to_string()));
    }

    /// Finalizing a partial stream keeps the text received so far.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let partial = vec![AssembledBlock::Text("partial reply".to_string())];

        let response = finalize_incomplete_stream(partial, StopReason::EndTurn, Usage::default());

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);

        let other = &response.content[0];
        let ContentBlock::Text { text } = other else {
            panic!("unexpected block: {other:?}")
        };
        assert_eq!(text, "partial reply");
    }
}