1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10 ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11 DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12 THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19 StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23 api_key: String,
24 model: String,
25 base_url: String,
26 client: reqwest::Client,
27 extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32 pub fn new(api_key: String, model: String, base_url: String) -> Self {
33 Self::with_headers(api_key, model, base_url, None)
34 }
35
36 fn load_proxy_from_env() -> Option<String> {
38 std::env::var("HTTPS_PROXY")
39 .ok()
40 .or_else(|| std::env::var("https_proxy").ok())
41 .or_else(|| std::env::var("HTTP_PROXY").ok())
42 .or_else(|| std::env::var("http_proxy").ok())
43 }
44
45 pub fn with_headers(
46 api_key: String,
47 model: String,
48 base_url: String,
49 extra_headers: Option<std::collections::HashMap<String, String>>,
50 ) -> Self {
51 let mut client_builder = reqwest::Client::builder()
56 .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57 .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58 .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); if let Some(proxy_url) = Self::load_proxy_from_env()
62 && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63 {
64 log::info!("AnthropicProvider using proxy: {}", proxy_url);
65 client_builder = client_builder.proxy(proxy);
66 }
67
68 let client = client_builder
69 .build()
70 .unwrap_or_else(|_| reqwest::Client::new());
71 let extra_headers: Vec<(String, String)> = extra_headers
72 .map(|h| h.into_iter().collect())
73 .unwrap_or_default();
74 Self {
75 api_key,
76 model,
77 base_url,
78 client,
79 extra_headers,
80 }
81 }
82
83 fn is_official_anthropic(&self) -> bool {
86 self.base_url.contains("api.anthropic.com")
87 }
88
89 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
90 log::debug!(
97 "convert_messages: always filtering thinking blocks, base_url={}",
98 self.base_url
99 );
100
101 let mut thinking_count = 0;
103 for m in messages {
104 if let MessageContent::Blocks(blocks) = &m.content {
105 for b in blocks {
106 if matches!(b, ContentBlock::Thinking { .. }) {
107 thinking_count += 1;
108 }
109 }
110 }
111 }
112 if thinking_count > 0 {
113 log::debug!(
114 "convert_messages: Filtering {} thinking blocks from {} messages",
115 thinking_count,
116 messages.len()
117 );
118 }
119
120 messages
121 .iter()
122 .filter(|m| m.role != Role::System)
123 .map(|m| {
124 let role = match m.role {
125 Role::User | Role::Tool => "user",
126 Role::Assistant => "assistant",
127 Role::System => unreachable!(),
128 };
129
130 let content = match &m.content {
131 MessageContent::Text(text) => json!(text),
132 MessageContent::Blocks(blocks) => {
133 let converted: Vec<Value> = blocks
134 .iter()
135 .filter(|b| {
136 if matches!(b, ContentBlock::Thinking { .. }) {
138 return false;
139 }
140 true
141 })
142 .map(|b| match b {
143 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
144 ContentBlock::ToolUse { id, name, input } => {
145 json!({"type": "tool_use", "id": id, "name": name, "input": input})
146 }
147 ContentBlock::ToolResult { tool_use_id, content } => {
148 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
149 }
150 ContentBlock::Thinking { thinking, signature } => {
151 let mut obj = json!({"type": "thinking", "thinking": thinking});
152 if let Some(sig) = signature
154 && !sig.is_empty() {
155 obj["signature"] = json!(sig);
156 }
157 obj
158 }
159 ContentBlock::ServerToolUse { id, name, input } => {
160 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
161 }
162 ContentBlock::ServerToolResult { tool_use_id, content } => {
163 json!({"type": "server_tool_result", "tool_use_id": tool_use_id, "content": content})
164 }
165 ContentBlock::WebSearchResult { tool_use_id, content } => {
166 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
167 }
168 })
169 .collect();
170
171 if converted.is_empty() {
172 json!([])
173 } else {
174 json!(converted)
175 }
176 }
177 };
178
179 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
181 return json!(null);
182 }
183
184 json!({"role": role, "content": content})
185 })
186 .filter(|v| !v.is_null())
187 .collect()
188 }
189
190 fn convert_tools_with_caching(
193 &self,
194 tools: &[ToolDefinition],
195 enable_caching: bool,
196 ) -> Vec<Value> {
197 let mut converted: Vec<Value> = tools
198 .iter()
199 .map(|t| {
200 json!({
201 "name": t.name,
202 "description": t.description,
203 "input_schema": t.parameters,
204 })
205 })
206 .collect();
207
208 if enable_caching && self.is_official_anthropic() && !converted.is_empty() {
211 let last_idx = converted.len() - 1;
212 if let Some(obj) = converted[last_idx].as_object_mut() {
213 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
214 }
215 }
216
217 converted
218 }
219
220 fn build_body(&self, request: &ChatRequest) -> Value {
222 let mut body = json!({
223 "model": self.model,
224 "max_tokens": request.max_tokens,
225 "messages": self.convert_messages(&request.messages),
226 });
227
228 if request.enable_caching && self.is_official_anthropic() {
231 if let Some(system) = &request.system {
232 body["system"] = json!([
234 {
235 "type": "text",
236 "text": system,
237 "cache_control": {"type": "ephemeral"}
238 }
239 ]);
240 }
241 } else if let Some(system) = &request.system {
242 body["system"] = json!(system);
243 }
244
245 if !request.tools.is_empty() {
246 let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
247 body["tools"] = json!(tools);
248 }
249
250 if !request.server_tools.is_empty() {
251 body["tools"] = json!(
252 body["tools"]
253 .as_array()
254 .map(|t| {
255 let mut tools = t.clone();
256 for st in &request.server_tools {
257 tools.push(serde_json::to_value(st).unwrap_or_default());
258 }
259 tools
260 })
261 .unwrap_or_else(|| request
262 .server_tools
263 .iter()
264 .map(|st| serde_json::to_value(st).unwrap_or_default())
265 .collect())
266 );
267 }
268
269 if request.think && self.is_official_anthropic() {
273 let config = thinking_config(&self.model);
274 log::debug!(
275 "Adding thinking config for model {}: {:?}",
276 self.model,
277 config
278 );
279 body["thinking"] = config;
280 } else if request.think {
281 log::debug!(
282 "Skipping thinking config for non-official API (model: {}, base_url: {})",
283 self.model,
284 self.base_url
285 );
286 }
287
288 body
289 }
290}
291
292fn thinking_config(model: &str) -> Value {
297 let m = model.to_lowercase();
298 let adaptive = m.contains("opus-4")
300 || m.contains("sonnet-4")
301 || m.contains("claude-4")
302 || m.contains("20250")
303 || m.contains("2025");
304 if adaptive {
305 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
306 } else {
307 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
309 }
310}
311
312#[async_trait]
313impl Provider for AnthropicProvider {
314 fn context_size(&self) -> Option<u32> {
315 context_window_for(&self.model)
316 }
317
318 fn model_name(&self) -> &str {
319 &self.model
320 }
321
322 fn clone_box(&self) -> Box<dyn Provider> {
323 Box::new(Self {
324 api_key: self.api_key.clone(),
325 model: self.model.clone(),
326 base_url: self.base_url.clone(),
327 client: reqwest::Client::new(),
328 extra_headers: self.extra_headers.clone(),
329 })
330 }
331
332 fn clone_arc(&self) -> Arc<dyn Provider> {
333 Arc::new(Self {
334 api_key: self.api_key.clone(),
335 model: self.model.clone(),
336 base_url: self.base_url.clone(),
337 client: reqwest::Client::new(),
338 extra_headers: self.extra_headers.clone(),
339 })
340 }
341
342 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
343 let body = self.build_body(&request);
344
345 let url = format!("{}/v1/messages", self.base_url);
346
347 crate::debug::debug_log()
349 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
350
351 let mut req = self
352 .client
353 .post(&url)
354 .header("User-Agent", "curl/8.0")
355 .json(&body);
356
357 if self.is_official_anthropic() {
359 req = req
360 .header("x-api-key", &self.api_key)
361 .header("anthropic-version", ANTHROPIC_API_VERSION)
362 .header("anthropic-beta", "prompt-caching-2024-07-31");
363 } else {
364 req = req
365 .header("Authorization", format!("Bearer {}", self.api_key))
366 .header("anthropic-version", "2023-06-01")
368 .header("anthropic-beta", "prompt-caching-2024-07-31");
370 }
371
372 for (name, value) in &self.extra_headers {
374 req = req.header(name, value);
375 }
376
377 let response = req.send().await?;
378
379 let status = response.status();
380 let response_body: Value = response.json().await?;
381
382 crate::debug::debug_log().api_response(
384 status.as_u16(),
385 &serde_json::to_string(&response_body).unwrap_or_default(),
386 );
387
388 if !status.is_success() {
389 let err_msg = response_body["error"]["message"]
390 .as_str()
391 .unwrap_or("unknown error");
392 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
393 }
394
395 let stop_reason = match response_body["stop_reason"].as_str() {
396 Some("tool_use") => StopReason::ToolUse,
397 Some("max_tokens") => StopReason::MaxTokens,
398 _ => StopReason::EndTurn,
399 };
400
401 let content = response_body["content"]
402 .as_array()
403 .unwrap_or(&vec![])
404 .iter()
405 .filter_map(|block| match block["type"].as_str()? {
406 "text" => Some(ContentBlock::Text {
407 text: block["text"].as_str()?.to_string(),
408 }),
409 "tool_use" => Some(ContentBlock::ToolUse {
410 id: block["id"].as_str()?.to_string(),
411 name: block["name"].as_str()?.to_string(),
412 input: block["input"].clone(),
413 }),
414 "thinking" => Some(ContentBlock::Thinking {
415 thinking: block["thinking"].as_str()?.to_string(),
416 signature: block["signature"].as_str().map(String::from),
417 }),
418 "server_tool_use" => Some(ContentBlock::ServerToolUse {
419 id: block["id"].as_str()?.to_string(),
420 name: block["name"].as_str()?.to_string(),
421 input: block["input"].clone(),
422 }),
423 "web_search_tool_result" => {
424 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
425 let content = parse_web_search_content(&block["content"]);
426 Some(ContentBlock::WebSearchResult {
427 tool_use_id,
428 content,
429 })
430 }
431 _ => None,
432 })
433 .collect();
434
435 Ok(ChatResponse {
436 content,
437 stop_reason,
438 usage: parse_usage(&response_body["usage"]),
439 })
440 }
441
442 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
443 let mut body = self.build_body(&request);
444 body["stream"] = json!(true);
445
446 let url = format!("{}/v1/messages", self.base_url);
447
448 crate::debug::debug_log()
450 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
451
452 let mut req = self
453 .client
454 .post(&url)
455 .header("User-Agent", "curl/8.0")
456 .json(&body);
457
458 if self.is_official_anthropic() {
460 req = req
461 .header("x-api-key", &self.api_key)
462 .header("anthropic-version", ANTHROPIC_API_VERSION)
463 .header("anthropic-beta", "prompt-caching-2024-07-31");
464 } else {
465 req = req
466 .header("Authorization", format!("Bearer {}", self.api_key))
467 .header("anthropic-version", "2023-06-01")
469 .header("anthropic-beta", "prompt-caching-2024-07-31");
471 }
472
473 for (name, value) in &self.extra_headers {
475 req = req.header(name, value);
476 }
477
478 let response = req.send().await?;
479
480 if !response.status().is_success() {
481 let status = response.status();
482 let text = response.text().await.unwrap_or_default();
483 anyhow::bail!("Anthropic API error ({}): {}", status, text);
484 }
485
486 let (tx, rx) = mpsc::channel(64);
487 tokio::spawn(async move {
488 let mut stream = response.bytes_stream();
489 let mut buffer = String::new();
490 let mut sent_first_byte = false;
491
492 let mut blocks: Vec<AssembledBlock> = Vec::new();
494 let mut stop_reason = StopReason::EndTurn;
495 let mut usage = Usage::default();
496
497 let mut last_content_time = std::time::Instant::now();
499 const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; while let Some(chunk) = stream.next().await {
502 let chunk = match chunk {
503 Ok(c) => c,
504 Err(e) => {
505 let error_msg = e.to_string();
507 let is_timeout =
508 error_msg.contains("timeout") || error_msg.contains("timed out");
509 let is_decode =
510 error_msg.contains("decode") || error_msg.contains("decoding");
511
512 let msg = if is_timeout {
513 format!(
514 "Stream timeout - the API took too long to respond: {}",
515 error_msg
516 )
517 } else if is_decode {
518 format!(
519 "Stream decode error - possibly interrupted or corrupted response: {}",
520 error_msg
521 )
522 } else {
523 format!("Stream read error: {}", error_msg)
524 };
525
526 if sent_first_byte && !blocks.is_empty() {
528 debug!("finalizing partial stream due to error");
529 let _ = tx
530 .send(StreamEvent::Done(finalize_incomplete_stream(
531 std::mem::take(&mut blocks),
532 stop_reason,
533 usage,
534 )))
535 .await;
536 } else {
537 let _ = tx.send(StreamEvent::Error(msg)).await;
538 }
539 return;
540 }
541 };
542
543 if !sent_first_byte {
544 sent_first_byte = true;
545 let _ = tx.send(StreamEvent::FirstByte).await;
546 }
547
548 buffer.push_str(&String::from_utf8_lossy(&chunk));
549
550 let elapsed = last_content_time.elapsed().as_secs();
552 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
553 crate::debug::debug_log().stream_chunk(
554 "TIMEOUT_FORCE_FINALIZE",
555 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
556 );
557 let _ = tx
558 .send(StreamEvent::Done(finalize_incomplete_stream(
559 std::mem::take(&mut blocks),
560 stop_reason,
561 usage,
562 )))
563 .await;
564 return;
565 }
566
567 while let Some(frame) = take_next_sse_frame(&mut buffer) {
568 if handle_sse_frame(
569 &frame,
570 &mut blocks,
571 &mut stop_reason,
572 &mut usage,
573 &tx,
574 &mut last_content_time,
575 )
576 .await
577 {
578 return;
579 }
580 }
581 }
582
583 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
584 && handle_sse_frame(
585 &frame,
586 &mut blocks,
587 &mut stop_reason,
588 &mut usage,
589 &tx,
590 &mut last_content_time,
591 )
592 .await
593 {
594 return;
595 }
596
597 if sent_first_byte {
598 debug!("stream ended without explicit message_stop; finalizing best-effort");
599 let _ = tx
600 .send(StreamEvent::Done(finalize_incomplete_stream(
601 std::mem::take(&mut blocks),
602 stop_reason,
603 usage,
604 )))
605 .await;
606 } else {
607 let _ = tx
608 .send(StreamEvent::Error(
609 "stream ended before any events were received".to_string(),
610 ))
611 .await;
612 }
613 });
614
615 Ok(rx)
616 }
617}
618
/// Removes the earliest complete SSE frame from the front of `buffer`.
///
/// A frame ends at the first blank line, written as either `\n\n` or
/// `\r\n\r\n`. The frame text is returned without its terminator, and the
/// terminator is consumed from `buffer`. Returns `None` and leaves `buffer`
/// untouched when no complete frame is present yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (frame_end, terminator_len) = ["\n\n", "\r\n\r\n"]
        .into_iter()
        .filter_map(|terminator| buffer.find(terminator).map(|at| (at, terminator.len())))
        .min_by_key(|&(at, _)| at)?;

    let remainder = buffer.split_off(frame_end + terminator_len);
    let mut frame = std::mem::replace(buffer, remainder);
    frame.truncate(frame_end);
    Some(frame)
}
639
/// Drains whatever is left in `buffer` at end of stream as a final frame.
///
/// Leading and trailing whitespace (including `\r`) is trimmed. The buffer is
/// always emptied, and `None` is returned when nothing but whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    // `trim` already removes trailing '\r', so no separate CR handling is needed.
    let trimmed = leftover.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
645
/// Returns the `data` payload of one SSE frame, or `None` if it has no `data:` field.
///
/// Follows the SSE rule for the field value: a single space after `data:` is
/// optional and stripped. When a frame carries several `data:` lines, their
/// values are joined with `\n` (previously only the first line was kept, which
/// truncated multi-line payloads). Other fields such as `event:` and comment
/// lines are ignored, and a stray trailing `\r` on a line is removed.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut data: Option<String> = None;
    for line in frame.lines() {
        let line = line.trim_end_matches('\r');
        let Some(value) = line.strip_prefix("data:") else {
            continue;
        };
        let value = value.strip_prefix(' ').unwrap_or(value);
        match data.as_mut() {
            Some(joined) => {
                joined.push('\n');
                joined.push_str(value);
            }
            None => data = Some(value.to_string()),
        }
    }
    data
}
659
660async fn handle_sse_frame(
661 frame: &str,
662 blocks: &mut Vec<AssembledBlock>,
663 stop_reason: &mut StopReason,
664 usage: &mut Usage,
665 tx: &mpsc::Sender<StreamEvent>,
666 last_content_time: &mut std::time::Instant,
667) -> bool {
668 let Some(data_line) = extract_sse_data_line(frame) else {
669 return false;
670 };
671
672 let evt: Value = match serde_json::from_str(&data_line) {
673 Ok(v) => v,
674 Err(_) => return false,
675 };
676
677 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
678}
679
680async fn handle_sse_event(
681 evt: Value,
682 blocks: &mut Vec<AssembledBlock>,
683 stop_reason: &mut StopReason,
684 usage: &mut Usage,
685 tx: &mpsc::Sender<StreamEvent>,
686 last_content_time: &mut std::time::Instant,
687) -> bool {
688 let evt_type = evt["type"].as_str().unwrap_or("");
689
690 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
692 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
693
694 if evt_type == "content_block_delta" {
696 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
697 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
698 log::debug!(
699 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
700 delta_type,
701 idx,
702 blocks.len(),
703 idx < blocks.len()
704 );
705 }
706
707 if evt_type != "ping" {
709 *last_content_time = std::time::Instant::now();
710 }
711
712 match evt_type {
713 "message_start" => {
714 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
719 debug!(
720 "message_start: usage_json={}",
721 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
722 );
723 debug!(
724 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
725 usage.input_tokens,
726 usage.output_tokens,
727 usage.cache_read_input_tokens,
728 usage.cache_creation_input_tokens
729 );
730 let _ = tx
732 .send(StreamEvent::Usage {
733 output_tokens: usage.output_tokens,
734 })
735 .await;
736 }
737 "content_block_start" => {
738 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
739 let block = &evt["content_block"];
740 let kind = block["type"].as_str().unwrap_or("");
741 while blocks.len() <= idx {
742 blocks.push(AssembledBlock::default());
743 }
744 match kind {
745 "text" => {
746 blocks[idx] = AssembledBlock::Text(String::new());
747 }
748 "thinking" => {
749 blocks[idx] = AssembledBlock::Thinking {
750 text: String::new(),
751 signature: None,
752 };
753 }
754 "tool_use" | "server_tool_use" => {
755 let id = block["id"].as_str().unwrap_or_default();
756 let name = block["name"].as_str().unwrap_or_default();
757 let is_server = kind == "server_tool_use";
758 blocks[idx] = if is_server {
759 AssembledBlock::ServerToolUse {
760 id: id.into(),
761 name: name.into(),
762 input_json: String::new(),
763 }
764 } else {
765 AssembledBlock::ToolUse {
766 id: id.into(),
767 name: name.into(),
768 input_json: String::new(),
769 }
770 };
771 let _ = tx
772 .send(StreamEvent::ToolUseStart {
773 id: id.into(),
774 name: name.into(),
775 })
776 .await;
777 }
778 "web_search_tool_result" => {
779 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
780 blocks[idx] = AssembledBlock::WebSearchResult {
781 tool_use_id,
782 content_json: String::new(),
783 };
784 }
785 _ => {}
786 }
787 }
788 "content_block_delta" => {
789 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
790 let delta = &evt["delta"];
791 let dt = delta["type"].as_str().unwrap_or("");
792 if idx >= blocks.len() {
793 return false;
794 }
795 match (dt, &mut blocks[idx]) {
796 ("text_delta", AssembledBlock::Text(buf)) => {
797 if let Some(t) = delta["text"].as_str() {
798 buf.push_str(t);
799 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
800 }
801 }
802 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
803 if let Some(t) = delta["thinking"].as_str() {
804 text.push_str(t);
805 log::debug!("Received thinking_delta: {} chars", t.len());
806 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
807 }
808 }
809 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
810 if let Some(s) = delta["signature"].as_str() {
811 signature.get_or_insert_with(String::new).push_str(s);
812 }
813 }
814 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
815 if let Some(p) = delta["partial_json"].as_str() {
816 input_json.push_str(p);
817 let _ = tx
818 .send(StreamEvent::ToolInputDelta {
819 bytes_so_far: input_json.len(),
820 })
821 .await;
822 }
823 }
824 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
825 if let Some(p) = delta["partial_json"].as_str() {
826 input_json.push_str(p);
827 let _ = tx
828 .send(StreamEvent::ToolInputDelta {
829 bytes_so_far: input_json.len(),
830 })
831 .await;
832 }
833 }
834 _ => {}
835 }
836 }
837 "content_block_stop" => {
838 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
839 if let Some(AssembledBlock::ToolUse {
840 id,
841 name,
842 input_json,
843 }) = blocks.get(idx)
844 {
845 let input: Value = if input_json.is_empty() {
846 json!({})
847 } else {
848 serde_json::from_str(input_json).unwrap_or(json!({}))
849 };
850 let _ = tx
851 .send(StreamEvent::ToolInputComplete {
852 id: id.clone(),
853 name: name.clone(),
854 input,
855 })
856 .await;
857 }
858 }
859 "message_delta" => {
860 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
861 *stop_reason = match sr {
862 "tool_use" => StopReason::ToolUse,
863 "max_tokens" => StopReason::MaxTokens,
864 _ => StopReason::EndTurn,
865 };
866 }
867 *usage = merge_usage(usage.clone(), &evt["usage"]);
871 debug!(
872 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
873 usage.input_tokens,
874 usage.output_tokens,
875 usage.cache_read_input_tokens,
876 usage.cache_creation_input_tokens
877 );
878 let _ = tx
880 .send(StreamEvent::Usage {
881 output_tokens: usage.output_tokens,
882 })
883 .await;
884 }
885 "message_stop" => {
886 debug!(
887 "Message completed: stop_reason={}, usage={}",
888 match *stop_reason {
889 StopReason::EndTurn => "end_turn",
890 StopReason::ToolUse => "tool_use",
891 StopReason::MaxTokens => "max_tokens",
892 },
893 usage.output_tokens
894 );
895 debug!(
896 "message_stop final usage: cache_read={}, cache_created={}",
897 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
898 );
899 let _ = tx
900 .send(StreamEvent::Done(finalize_incomplete_stream(
901 std::mem::take(blocks),
902 stop_reason.clone(),
903 usage.clone(),
904 )))
905 .await;
906 return true;
907 }
908 "error" => {
909 let msg = evt["error"]["message"]
910 .as_str()
911 .unwrap_or("unknown stream error")
912 .to_string();
913 let _ = tx.send(StreamEvent::Error(msg)).await;
914 return true;
915 }
916 _ => {}
917 }
918
919 false
920}
921
922fn build_chat_response(
923 blocks: Vec<AssembledBlock>,
924 stop_reason: StopReason,
925 usage: Usage,
926) -> ChatResponse {
927 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
928 ChatResponse {
929 content,
930 stop_reason,
931 usage,
932 }
933}
934
/// Builds the final [`ChatResponse`] from whatever blocks have been assembled.
///
/// Called on a clean `message_stop` and whenever a stream ends early: a read
/// error after some content, the no-content timeout, or EOF without
/// `message_stop`. It delegates to `build_chat_response`, so partial blocks are
/// finished exactly like complete ones.
fn finalize_incomplete_stream(
    blocks: Vec<AssembledBlock>,
    stop_reason: StopReason,
    usage: Usage,
) -> ChatResponse {
    build_chat_response(blocks, stop_reason, usage)
}
942
/// A content block being accumulated from streaming events, one per event `index`.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index that has no recognized block (padding or unknown kind).
    #[default]
    Empty,
    /// Extended-thinking text plus its signature, which arrives in separate deltas.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Plain assistant text.
    Text(String),
    /// Client tool call; `input_json` collects raw `partial_json` fragments.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server-executed tool call; `input_json` collects raw `partial_json` fragments.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search result; `content_json` holds the raw JSON payload (may be empty).
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
967
968impl AssembledBlock {
969 fn finish(self) -> Option<ContentBlock> {
970 match self {
971 AssembledBlock::Empty => None,
972 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
973 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
974 thinking: text,
975 signature,
976 }),
977 AssembledBlock::ToolUse {
978 id,
979 name,
980 input_json,
981 } => {
982 let input: Value = if input_json.is_empty() {
983 json!({})
984 } else {
985 serde_json::from_str(&input_json).unwrap_or(json!({}))
986 };
987 Some(ContentBlock::ToolUse { id, name, input })
988 }
989 AssembledBlock::ServerToolUse {
990 id,
991 name,
992 input_json,
993 } => {
994 let input: Value = if input_json.is_empty() {
995 json!({})
996 } else {
997 serde_json::from_str(&input_json).unwrap_or(json!({}))
998 };
999 Some(ContentBlock::ServerToolUse { id, name, input })
1000 }
1001 AssembledBlock::WebSearchResult {
1002 tool_use_id,
1003 content_json,
1004 } => {
1005 let content: Value = if content_json.is_empty() {
1006 json!({"results": []})
1007 } else {
1008 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1009 };
1010 Some(ContentBlock::WebSearchResult {
1011 tool_use_id,
1012 content: parse_web_search_content(&content),
1013 })
1014 }
1015 }
1016 }
1017}
1018
1019fn parse_usage(u: &Value) -> Usage {
1022 Usage {
1023 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1024 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1025 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1026 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1027 }
1028}
1029
1030fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1034 let fresh = parse_usage(u);
1035 if fresh.input_tokens > 0 {
1036 acc.input_tokens = fresh.input_tokens;
1037 }
1038 if fresh.output_tokens > 0 {
1039 acc.output_tokens = fresh.output_tokens;
1040 }
1041 if fresh.cache_creation_input_tokens > 0 {
1042 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1043 }
1044 if fresh.cache_read_input_tokens > 0 {
1045 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1046 }
1047 acc
1048}
1049
1050fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1052 let results = value["results"]
1053 .as_array()
1054 .map(|arr| {
1055 arr.iter()
1056 .filter_map(|item| {
1057 Some(crate::providers::WebSearchResultItem {
1058 title: item["title"].as_str().map(String::from),
1059 url: item["url"].as_str()?.to_string(),
1060 encrypted_content: item["encrypted_content"].as_str().map(String::from),
1061 snippet: item["snippet"].as_str().map(String::from),
1062 })
1063 })
1064 .collect()
1065 })
1066 .unwrap_or_default();
1067
1068 crate::providers::WebSearchContent { results }
1069}
1070
#[cfg(test)]
mod tests {
    use super::*;

    /// CRLF-delimited frames come out one at a time and the buffer ends up empty.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = [
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n",
        ]
        .concat();

        let frames: Vec<String> =
            std::iter::from_fn(|| take_next_sse_frame(&mut buffer)).collect();

        assert_eq!(frames.len(), 2);
        assert!(frames[0].contains("message_start"));
        assert!(frames[1].contains("message_stop"));
        assert!(buffer.is_empty());
    }

    /// A final event without a blank-line terminator is still recovered.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = String::from("data: {\"type\":\"message_stop\"}\r\n");

        assert_eq!(
            take_trailing_sse_frame(&mut buffer).as_deref(),
            Some("data: {\"type\":\"message_stop\"}")
        );
        assert!(buffer.is_empty());
    }

    /// The space after `data:` is optional.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let cases = [
            ("event: x\r\ndata: {\"k\":1}\r", "{\"k\":1}"),
            ("event: x\r\ndata:{\"k\":2}\r", "{\"k\":2}"),
        ];
        for (frame, expected) in cases {
            assert_eq!(extract_sse_data_line(frame).as_deref(), Some(expected));
        }
    }

    /// Finalizing early keeps the text accumulated so far.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        let [ContentBlock::Text { text }] = response.content.as_slice() else {
            panic!("unexpected content: {:?}", response.content);
        };
        assert_eq!(text, "partial reply");
    }
}