1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10 ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11 DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12 THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19 StreamEvent, Usage,
20};
21
22pub struct AnthropicProvider {
23 api_key: String,
24 model: String,
25 base_url: String,
26 client: reqwest::Client,
27 extra_headers: Vec<(String, String)>,
29}
30
31impl AnthropicProvider {
32 pub fn new(api_key: String, model: String, base_url: String) -> Self {
33 Self::with_headers(api_key, model, base_url, None)
34 }
35
36 fn load_proxy_from_env() -> Option<String> {
38 std::env::var("HTTPS_PROXY")
39 .ok()
40 .or_else(|| std::env::var("https_proxy").ok())
41 .or_else(|| std::env::var("HTTP_PROXY").ok())
42 .or_else(|| std::env::var("http_proxy").ok())
43 }
44
45 pub fn with_headers(
46 api_key: String,
47 model: String,
48 base_url: String,
49 extra_headers: Option<std::collections::HashMap<String, String>>,
50 ) -> Self {
51 let mut client_builder = reqwest::Client::builder()
56 .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57 .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58 .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); if let Some(proxy_url) = Self::load_proxy_from_env()
62 && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63 {
64 log::info!("AnthropicProvider using proxy: {}", proxy_url);
65 client_builder = client_builder.proxy(proxy);
66 }
67
68 let client = client_builder
69 .build()
70 .unwrap_or_else(|_| reqwest::Client::new());
71 let extra_headers: Vec<(String, String)> = extra_headers
72 .map(|h| h.into_iter().collect())
73 .unwrap_or_default();
74 Self {
75 api_key,
76 model,
77 base_url,
78 client,
79 extra_headers,
80 }
81 }
82
83 fn is_official_anthropic(&self) -> bool {
86 self.base_url.contains("api.anthropic.com")
87 }
88
89 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
90 log::debug!(
97 "convert_messages: always filtering thinking blocks, base_url={}",
98 self.base_url
99 );
100
101 let mut thinking_count = 0;
103 for m in messages {
104 if let MessageContent::Blocks(blocks) = &m.content {
105 for b in blocks {
106 if matches!(b, ContentBlock::Thinking { .. }) {
107 thinking_count += 1;
108 }
109 }
110 }
111 }
112 if thinking_count > 0 {
113 log::debug!(
114 "convert_messages: Filtering {} thinking blocks from {} messages",
115 thinking_count,
116 messages.len()
117 );
118 }
119
120 messages
121 .iter()
122 .filter(|m| m.role != Role::System)
123 .map(|m| {
124 let role = match m.role {
125 Role::User | Role::Tool => "user",
126 Role::Assistant => "assistant",
127 Role::System => unreachable!(),
128 };
129
130 let content = match &m.content {
131 MessageContent::Text(text) => json!(text),
132 MessageContent::Blocks(blocks) => {
133 let converted: Vec<Value> = blocks
134 .iter()
135 .filter(|b| {
136 if matches!(b, ContentBlock::Thinking { .. }) {
138 return false;
139 }
140 true
141 })
142 .map(|b| match b {
143 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
144 ContentBlock::ToolUse { id, name, input } => {
145 json!({"type": "tool_use", "id": id, "name": name, "input": input})
146 }
147 ContentBlock::ToolResult { tool_use_id, content } => {
148 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
149 }
150 ContentBlock::Thinking { thinking, signature } => {
151 let mut obj = json!({"type": "thinking", "thinking": thinking});
152 if let Some(sig) = signature
154 && !sig.is_empty() {
155 obj["signature"] = json!(sig);
156 }
157 obj
158 }
159 ContentBlock::ServerToolUse { id, name, input } => {
160 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
161 }
162 ContentBlock::ServerToolResult { tool_use_id, content } => {
163 json!({"type": "server_tool_result", "tool_use_id": tool_use_id, "content": content})
164 }
165 ContentBlock::WebSearchResult { tool_use_id, content } => {
166 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
167 }
168 })
169 .collect();
170
171 if converted.is_empty() {
172 json!([])
173 } else {
174 json!(converted)
175 }
176 }
177 };
178
179 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
181 return json!(null);
182 }
183
184 json!({"role": role, "content": content})
185 })
186 .filter(|v| !v.is_null())
187 .collect()
188 }
189
190 fn convert_tools_with_caching(
192 &self,
193 tools: &[ToolDefinition],
194 enable_caching: bool,
195 ) -> Vec<Value> {
196 let mut converted: Vec<Value> = tools
197 .iter()
198 .map(|t| {
199 json!({
200 "name": t.name,
201 "description": t.description,
202 "input_schema": t.parameters,
203 })
204 })
205 .collect();
206
207 if enable_caching && !converted.is_empty() {
209 let last_idx = converted.len() - 1;
210 if let Some(obj) = converted[last_idx].as_object_mut() {
211 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
212 }
213 }
214
215 converted
216 }
217
218 fn build_body(&self, request: &ChatRequest) -> Value {
220 let mut body = json!({
221 "model": self.model,
222 "max_tokens": request.max_tokens,
223 "messages": self.convert_messages(&request.messages),
224 });
225
226 if request.enable_caching {
228 if let Some(system) = &request.system {
229 body["system"] = json!([
231 {
232 "type": "text",
233 "text": system,
234 "cache_control": {"type": "ephemeral"}
235 }
236 ]);
237 }
238 } else if let Some(system) = &request.system {
239 body["system"] = json!(system);
240 }
241
242 if !request.tools.is_empty() {
243 let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
244 body["tools"] = json!(tools);
245 }
246
247 if !request.server_tools.is_empty() {
248 body["tools"] = json!(
249 body["tools"]
250 .as_array()
251 .map(|t| {
252 let mut tools = t.clone();
253 for st in &request.server_tools {
254 tools.push(serde_json::to_value(st).unwrap_or_default());
255 }
256 tools
257 })
258 .unwrap_or_else(|| request
259 .server_tools
260 .iter()
261 .map(|st| serde_json::to_value(st).unwrap_or_default())
262 .collect())
263 );
264 }
265
266 if request.think && self.is_official_anthropic() {
270 let config = thinking_config(&self.model);
271 log::debug!(
272 "Adding thinking config for model {}: {:?}",
273 self.model,
274 config
275 );
276 body["thinking"] = config;
277 } else if request.think {
278 log::debug!(
279 "Skipping thinking config for non-official API (model: {}, base_url: {})",
280 self.model,
281 self.base_url
282 );
283 }
284
285 body
286 }
287}
288
289fn thinking_config(model: &str) -> Value {
294 let m = model.to_lowercase();
295 let adaptive = m.contains("opus-4")
297 || m.contains("sonnet-4")
298 || m.contains("claude-4")
299 || m.contains("20250")
300 || m.contains("2025");
301 if adaptive {
302 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
303 } else {
304 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
306 }
307}
308
309#[async_trait]
310impl Provider for AnthropicProvider {
311 fn context_size(&self) -> Option<u32> {
312 context_window_for(&self.model)
313 }
314
315 fn model_name(&self) -> &str {
316 &self.model
317 }
318
319 fn clone_box(&self) -> Box<dyn Provider> {
320 Box::new(Self {
321 api_key: self.api_key.clone(),
322 model: self.model.clone(),
323 base_url: self.base_url.clone(),
324 client: reqwest::Client::new(),
325 extra_headers: self.extra_headers.clone(),
326 })
327 }
328
329 fn clone_arc(&self) -> Arc<dyn Provider> {
330 Arc::new(Self {
331 api_key: self.api_key.clone(),
332 model: self.model.clone(),
333 base_url: self.base_url.clone(),
334 client: reqwest::Client::new(),
335 extra_headers: self.extra_headers.clone(),
336 })
337 }
338
339 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
340 let body = self.build_body(&request);
341
342 let url = format!("{}/v1/messages", self.base_url);
343
344 crate::debug::debug_log()
346 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
347
348 let mut req = self
349 .client
350 .post(&url)
351 .header("User-Agent", "curl/8.0")
352 .json(&body);
353
354 if self.is_official_anthropic() {
356 req = req
357 .header("x-api-key", &self.api_key)
358 .header("anthropic-version", ANTHROPIC_API_VERSION)
359 .header("anthropic-beta", "prompt-caching-2024-07-31");
360 } else {
361 req = req
362 .header("Authorization", format!("Bearer {}", self.api_key))
363 .header("anthropic-version", "2023-06-01")
365 .header("anthropic-beta", "prompt-caching-2024-07-31");
367 }
368
369 for (name, value) in &self.extra_headers {
371 req = req.header(name, value);
372 }
373
374 let response = req.send().await?;
375
376 let status = response.status();
377 let response_body: Value = response.json().await?;
378
379 crate::debug::debug_log().api_response(
381 status.as_u16(),
382 &serde_json::to_string(&response_body).unwrap_or_default(),
383 );
384
385 if !status.is_success() {
386 let err_msg = response_body["error"]["message"]
387 .as_str()
388 .unwrap_or("unknown error");
389 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
390 }
391
392 let stop_reason = match response_body["stop_reason"].as_str() {
393 Some("tool_use") => StopReason::ToolUse,
394 Some("max_tokens") => StopReason::MaxTokens,
395 _ => StopReason::EndTurn,
396 };
397
398 let content = response_body["content"]
399 .as_array()
400 .unwrap_or(&vec![])
401 .iter()
402 .filter_map(|block| match block["type"].as_str()? {
403 "text" => Some(ContentBlock::Text {
404 text: block["text"].as_str()?.to_string(),
405 }),
406 "tool_use" => Some(ContentBlock::ToolUse {
407 id: block["id"].as_str()?.to_string(),
408 name: block["name"].as_str()?.to_string(),
409 input: block["input"].clone(),
410 }),
411 "thinking" => Some(ContentBlock::Thinking {
412 thinking: block["thinking"].as_str()?.to_string(),
413 signature: block["signature"].as_str().map(String::from),
414 }),
415 "server_tool_use" => Some(ContentBlock::ServerToolUse {
416 id: block["id"].as_str()?.to_string(),
417 name: block["name"].as_str()?.to_string(),
418 input: block["input"].clone(),
419 }),
420 "web_search_tool_result" => {
421 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
422 let content = parse_web_search_content(&block["content"]);
423 Some(ContentBlock::WebSearchResult {
424 tool_use_id,
425 content,
426 })
427 }
428 _ => None,
429 })
430 .collect();
431
432 Ok(ChatResponse {
433 content,
434 stop_reason,
435 usage: parse_usage(&response_body["usage"]),
436 })
437 }
438
439 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
440 let mut body = self.build_body(&request);
441 body["stream"] = json!(true);
442
443 let url = format!("{}/v1/messages", self.base_url);
444
445 crate::debug::debug_log()
447 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
448
449 let mut req = self
450 .client
451 .post(&url)
452 .header("User-Agent", "curl/8.0")
453 .json(&body);
454
455 if self.is_official_anthropic() {
457 req = req
458 .header("x-api-key", &self.api_key)
459 .header("anthropic-version", ANTHROPIC_API_VERSION)
460 .header("anthropic-beta", "prompt-caching-2024-07-31");
461 } else {
462 req = req
463 .header("Authorization", format!("Bearer {}", self.api_key))
464 .header("anthropic-version", "2023-06-01")
466 .header("anthropic-beta", "prompt-caching-2024-07-31");
468 }
469
470 for (name, value) in &self.extra_headers {
472 req = req.header(name, value);
473 }
474
475 let response = req.send().await?;
476
477 if !response.status().is_success() {
478 let status = response.status();
479 let text = response.text().await.unwrap_or_default();
480 anyhow::bail!("Anthropic API error ({}): {}", status, text);
481 }
482
483 let (tx, rx) = mpsc::channel(64);
484 tokio::spawn(async move {
485 let mut stream = response.bytes_stream();
486 let mut buffer = String::new();
487 let mut sent_first_byte = false;
488
489 let mut blocks: Vec<AssembledBlock> = Vec::new();
491 let mut stop_reason = StopReason::EndTurn;
492 let mut usage = Usage::default();
493
494 let mut last_content_time = std::time::Instant::now();
496 const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; while let Some(chunk) = stream.next().await {
499 let chunk = match chunk {
500 Ok(c) => c,
501 Err(e) => {
502 let error_msg = e.to_string();
504 let is_timeout =
505 error_msg.contains("timeout") || error_msg.contains("timed out");
506 let is_decode =
507 error_msg.contains("decode") || error_msg.contains("decoding");
508
509 let msg = if is_timeout {
510 format!(
511 "Stream timeout - the API took too long to respond: {}",
512 error_msg
513 )
514 } else if is_decode {
515 format!(
516 "Stream decode error - possibly interrupted or corrupted response: {}",
517 error_msg
518 )
519 } else {
520 format!("Stream read error: {}", error_msg)
521 };
522
523 if sent_first_byte && !blocks.is_empty() {
525 debug!("finalizing partial stream due to error");
526 let _ = tx
527 .send(StreamEvent::Done(finalize_incomplete_stream(
528 std::mem::take(&mut blocks),
529 stop_reason,
530 usage,
531 )))
532 .await;
533 } else {
534 let _ = tx.send(StreamEvent::Error(msg)).await;
535 }
536 return;
537 }
538 };
539
540 if !sent_first_byte {
541 sent_first_byte = true;
542 let _ = tx.send(StreamEvent::FirstByte).await;
543 }
544
545 buffer.push_str(&String::from_utf8_lossy(&chunk));
546
547 let elapsed = last_content_time.elapsed().as_secs();
549 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
550 crate::debug::debug_log().stream_chunk(
551 "TIMEOUT_FORCE_FINALIZE",
552 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
553 );
554 let _ = tx
555 .send(StreamEvent::Done(finalize_incomplete_stream(
556 std::mem::take(&mut blocks),
557 stop_reason,
558 usage,
559 )))
560 .await;
561 return;
562 }
563
564 while let Some(frame) = take_next_sse_frame(&mut buffer) {
565 if handle_sse_frame(
566 &frame,
567 &mut blocks,
568 &mut stop_reason,
569 &mut usage,
570 &tx,
571 &mut last_content_time,
572 )
573 .await
574 {
575 return;
576 }
577 }
578 }
579
580 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
581 && handle_sse_frame(
582 &frame,
583 &mut blocks,
584 &mut stop_reason,
585 &mut usage,
586 &tx,
587 &mut last_content_time,
588 )
589 .await
590 {
591 return;
592 }
593
594 if sent_first_byte {
595 debug!("stream ended without explicit message_stop; finalizing best-effort");
596 let _ = tx
597 .send(StreamEvent::Done(finalize_incomplete_stream(
598 std::mem::take(&mut blocks),
599 stop_reason,
600 usage,
601 )))
602 .await;
603 } else {
604 let _ = tx
605 .send(StreamEvent::Error(
606 "stream ended before any events were received".to_string(),
607 ))
608 .await;
609 }
610 });
611
612 Ok(rx)
613 }
614}
615
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank line, written either as `\n\n` or
/// `\r\n\r\n`. The delimiter is consumed but not included in the returned
/// frame. Returns `None`, leaving `buffer` untouched, when no complete frame
/// is buffered yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (start, width) = ["\n\n", "\r\n\r\n"]
        .iter()
        .filter_map(|&delim| buffer.find(delim).map(|at| (at, delim.len())))
        .min_by_key(|&(at, _)| at)?;

    let frame = buffer[..start].to_owned();
    buffer.replace_range(..start + width, "");
    Some(frame)
}
636
/// Drains whatever is left in `buffer` once the stream has ended, returning
/// it as a final frame unless it is only whitespace. `buffer` is always left
/// empty.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    // `trim` already strips '\r' along with all other surrounding whitespace.
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_owned())
}
642
/// Returns the `data` payload of an SSE frame, or `None` if it has no
/// `data:` line.
///
/// Per the SSE format, a single space after the colon is dropped. When a
/// frame carries several `data:` lines, their values are joined with `\n`
/// rather than keeping only the first. A trailing `\r` on any line is
/// ignored.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let parts: Vec<&str> = frame
        .lines()
        .map(|line| line.trim_end_matches('\r'))
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect();

    (!parts.is_empty()).then(|| parts.join("\n"))
}
656
657async fn handle_sse_frame(
658 frame: &str,
659 blocks: &mut Vec<AssembledBlock>,
660 stop_reason: &mut StopReason,
661 usage: &mut Usage,
662 tx: &mpsc::Sender<StreamEvent>,
663 last_content_time: &mut std::time::Instant,
664) -> bool {
665 let Some(data_line) = extract_sse_data_line(frame) else {
666 return false;
667 };
668
669 let evt: Value = match serde_json::from_str(&data_line) {
670 Ok(v) => v,
671 Err(_) => return false,
672 };
673
674 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
675}
676
677async fn handle_sse_event(
678 evt: Value,
679 blocks: &mut Vec<AssembledBlock>,
680 stop_reason: &mut StopReason,
681 usage: &mut Usage,
682 tx: &mpsc::Sender<StreamEvent>,
683 last_content_time: &mut std::time::Instant,
684) -> bool {
685 let evt_type = evt["type"].as_str().unwrap_or("");
686
687 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
689 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
690
691 if evt_type == "content_block_delta" {
693 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
694 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
695 log::debug!(
696 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
697 delta_type,
698 idx,
699 blocks.len(),
700 idx < blocks.len()
701 );
702 }
703
704 if evt_type != "ping" {
706 *last_content_time = std::time::Instant::now();
707 }
708
709 match evt_type {
710 "message_start" => {
711 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
716 debug!(
717 "message_start: usage_json={}",
718 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
719 );
720 debug!(
721 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
722 usage.input_tokens,
723 usage.output_tokens,
724 usage.cache_read_input_tokens,
725 usage.cache_creation_input_tokens
726 );
727 let _ = tx
729 .send(StreamEvent::Usage {
730 output_tokens: usage.output_tokens,
731 })
732 .await;
733 }
734 "content_block_start" => {
735 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
736 let block = &evt["content_block"];
737 let kind = block["type"].as_str().unwrap_or("");
738 while blocks.len() <= idx {
739 blocks.push(AssembledBlock::default());
740 }
741 match kind {
742 "text" => {
743 blocks[idx] = AssembledBlock::Text(String::new());
744 }
745 "thinking" => {
746 blocks[idx] = AssembledBlock::Thinking {
747 text: String::new(),
748 signature: None,
749 };
750 }
751 "tool_use" | "server_tool_use" => {
752 let id = block["id"].as_str().unwrap_or_default();
753 let name = block["name"].as_str().unwrap_or_default();
754 let is_server = kind == "server_tool_use";
755 blocks[idx] = if is_server {
756 AssembledBlock::ServerToolUse {
757 id: id.into(),
758 name: name.into(),
759 input_json: String::new(),
760 }
761 } else {
762 AssembledBlock::ToolUse {
763 id: id.into(),
764 name: name.into(),
765 input_json: String::new(),
766 }
767 };
768 let _ = tx
769 .send(StreamEvent::ToolUseStart {
770 id: id.into(),
771 name: name.into(),
772 })
773 .await;
774 }
775 "web_search_tool_result" => {
776 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
777 blocks[idx] = AssembledBlock::WebSearchResult {
778 tool_use_id,
779 content_json: String::new(),
780 };
781 }
782 _ => {}
783 }
784 }
785 "content_block_delta" => {
786 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
787 let delta = &evt["delta"];
788 let dt = delta["type"].as_str().unwrap_or("");
789 if idx >= blocks.len() {
790 return false;
791 }
792 match (dt, &mut blocks[idx]) {
793 ("text_delta", AssembledBlock::Text(buf)) => {
794 if let Some(t) = delta["text"].as_str() {
795 buf.push_str(t);
796 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
797 }
798 }
799 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
800 if let Some(t) = delta["thinking"].as_str() {
801 text.push_str(t);
802 log::debug!("Received thinking_delta: {} chars", t.len());
803 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
804 }
805 }
806 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
807 if let Some(s) = delta["signature"].as_str() {
808 signature.get_or_insert_with(String::new).push_str(s);
809 }
810 }
811 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
812 if let Some(p) = delta["partial_json"].as_str() {
813 input_json.push_str(p);
814 let _ = tx
815 .send(StreamEvent::ToolInputDelta {
816 bytes_so_far: input_json.len(),
817 })
818 .await;
819 }
820 }
821 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
822 if let Some(p) = delta["partial_json"].as_str() {
823 input_json.push_str(p);
824 let _ = tx
825 .send(StreamEvent::ToolInputDelta {
826 bytes_so_far: input_json.len(),
827 })
828 .await;
829 }
830 }
831 _ => {}
832 }
833 }
834 "content_block_stop" => {
835 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
836 if let Some(AssembledBlock::ToolUse {
837 id,
838 name,
839 input_json,
840 }) = blocks.get(idx)
841 {
842 let input: Value = if input_json.is_empty() {
843 json!({})
844 } else {
845 serde_json::from_str(input_json).unwrap_or(json!({}))
846 };
847 let _ = tx
848 .send(StreamEvent::ToolInputComplete {
849 id: id.clone(),
850 name: name.clone(),
851 input,
852 })
853 .await;
854 }
855 }
856 "message_delta" => {
857 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
858 *stop_reason = match sr {
859 "tool_use" => StopReason::ToolUse,
860 "max_tokens" => StopReason::MaxTokens,
861 _ => StopReason::EndTurn,
862 };
863 }
864 *usage = merge_usage(usage.clone(), &evt["usage"]);
868 debug!(
869 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
870 usage.input_tokens,
871 usage.output_tokens,
872 usage.cache_read_input_tokens,
873 usage.cache_creation_input_tokens
874 );
875 let _ = tx
877 .send(StreamEvent::Usage {
878 output_tokens: usage.output_tokens,
879 })
880 .await;
881 }
882 "message_stop" => {
883 debug!(
884 "Message completed: stop_reason={}, usage={}",
885 match *stop_reason {
886 StopReason::EndTurn => "end_turn",
887 StopReason::ToolUse => "tool_use",
888 StopReason::MaxTokens => "max_tokens",
889 },
890 usage.output_tokens
891 );
892 debug!(
893 "message_stop final usage: cache_read={}, cache_created={}",
894 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
895 );
896 let _ = tx
897 .send(StreamEvent::Done(finalize_incomplete_stream(
898 std::mem::take(blocks),
899 stop_reason.clone(),
900 usage.clone(),
901 )))
902 .await;
903 return true;
904 }
905 "error" => {
906 let msg = evt["error"]["message"]
907 .as_str()
908 .unwrap_or("unknown stream error")
909 .to_string();
910 let _ = tx.send(StreamEvent::Error(msg)).await;
911 return true;
912 }
913 _ => {}
914 }
915
916 false
917}
918
919fn build_chat_response(
920 blocks: Vec<AssembledBlock>,
921 stop_reason: StopReason,
922 usage: Usage,
923) -> ChatResponse {
924 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
925 ChatResponse {
926 content,
927 stop_reason,
928 usage,
929 }
930}
931
932fn finalize_incomplete_stream(
933 blocks: Vec<AssembledBlock>,
934 stop_reason: StopReason,
935 usage: Usage,
936) -> ChatResponse {
937 build_chat_response(blocks, stop_reason, usage)
938}
939
/// A content block being rebuilt from streaming events. Blocks are stored at
/// the position given by the events' `index` field.
#[derive(Default)]
enum AssembledBlock {
    /// Slot not (yet) opened by a recognised `content_block_start`; it is
    /// dropped when the response is built.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Text from `thinking_delta` plus any data from `signature_delta`.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Client-side tool call; `input_json` is the concatenated `partial_json`.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Tool call from a `server_tool_use` block, assembled the same way.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search result whose raw JSON text is parsed when finished.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
964
965impl AssembledBlock {
966 fn finish(self) -> Option<ContentBlock> {
967 match self {
968 AssembledBlock::Empty => None,
969 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
970 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
971 thinking: text,
972 signature,
973 }),
974 AssembledBlock::ToolUse {
975 id,
976 name,
977 input_json,
978 } => {
979 let input: Value = if input_json.is_empty() {
980 json!({})
981 } else {
982 serde_json::from_str(&input_json).unwrap_or(json!({}))
983 };
984 Some(ContentBlock::ToolUse { id, name, input })
985 }
986 AssembledBlock::ServerToolUse {
987 id,
988 name,
989 input_json,
990 } => {
991 let input: Value = if input_json.is_empty() {
992 json!({})
993 } else {
994 serde_json::from_str(&input_json).unwrap_or(json!({}))
995 };
996 Some(ContentBlock::ServerToolUse { id, name, input })
997 }
998 AssembledBlock::WebSearchResult {
999 tool_use_id,
1000 content_json,
1001 } => {
1002 let content: Value = if content_json.is_empty() {
1003 json!({"results": []})
1004 } else {
1005 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1006 };
1007 Some(ContentBlock::WebSearchResult {
1008 tool_use_id,
1009 content: parse_web_search_content(&content),
1010 })
1011 }
1012 }
1013 }
1014}
1015
1016fn parse_usage(u: &Value) -> Usage {
1019 Usage {
1020 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1021 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1022 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1023 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1024 }
1025}
1026
1027fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1031 let fresh = parse_usage(u);
1032 if fresh.input_tokens > 0 {
1033 acc.input_tokens = fresh.input_tokens;
1034 }
1035 if fresh.output_tokens > 0 {
1036 acc.output_tokens = fresh.output_tokens;
1037 }
1038 if fresh.cache_creation_input_tokens > 0 {
1039 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1040 }
1041 if fresh.cache_read_input_tokens > 0 {
1042 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1043 }
1044 acc
1045}
1046
1047fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1049 let results = value["results"]
1050 .as_array()
1051 .map(|arr| {
1052 arr.iter()
1053 .filter_map(|item| {
1054 Some(crate::providers::WebSearchResultItem {
1055 title: item["title"].as_str().map(String::from),
1056 url: item["url"].as_str()?.to_string(),
1057 encrypted_content: item["encrypted_content"].as_str().map(String::from),
1058 snippet: item["snippet"].as_str().map(String::from),
1059 })
1060 })
1061 .collect()
1062 })
1063 .unwrap_or_default();
1064
1065 crate::providers::WebSearchContent { results }
1066}
1067
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = String::from(
            "event: message_start\r\n\
             data: {\"type\":\"message_start\"}\r\n\r\n\
             data: {\"type\":\"message_stop\"}\r\n\r\n",
        );

        let frames: Vec<String> =
            std::iter::from_fn(|| take_next_sse_frame(&mut buffer)).collect();

        assert_eq!(frames.len(), 2);
        assert!(frames[0].contains("message_start"));
        assert!(frames[1].contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = String::from("data: {\"type\":\"message_stop\"}\r\n");
        assert_eq!(
            take_trailing_sse_frame(&mut buffer).as_deref(),
            Some("data: {\"type\":\"message_stop\"}")
        );
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let cases = [
            ("event: x\r\ndata: {\"k\":1}\r", "{\"k\":1}"),
            ("event: x\r\ndata:{\"k\":2}\r", "{\"k\":2}"),
        ];
        for (frame, expected) in cases {
            assert_eq!(extract_sse_data_line(frame).as_deref(), Some(expected));
        }
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text(String::from("partial reply"))],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        let [ContentBlock::Text { text }] = response.content.as_slice() else {
            panic!("unexpected content: {:?}", response.content);
        };
        assert_eq!(text, "partial reply");
    }
}