1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use std::sync::Arc;
7use tokio::sync::mpsc;
8
9use crate::constants::{
10 ANTHROPIC_API_VERSION, DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_CONTENT_TIMEOUT_SECS,
11 DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS, THINKING_BUDGET_NEW_MODELS,
12 THINKING_BUDGET_OLD_MODELS,
13};
14use crate::models::context_window_for;
15use crate::tools::ToolDefinition;
16
17use super::{
18 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
19 StreamEvent, Usage,
20};
21
/// Chat provider that talks to an Anthropic-style `/v1/messages` HTTP endpoint.
pub struct AnthropicProvider {
    /// Credential sent as `x-api-key` or as a `Bearer` token, depending on the endpoint.
    api_key: String,
    /// Model identifier placed in the `model` field of every request body.
    model: String,
    /// Endpoint root; `/v1/messages` is appended to it when building the request URL.
    base_url: String,
    /// HTTP client used to send this provider's requests.
    client: reqwest::Client,
    /// Additional `(name, value)` headers appended to every outgoing request.
    extra_headers: Vec<(String, String)>,
}
30
31impl AnthropicProvider {
    /// Creates a provider with no extra request headers.
    ///
    /// Delegates to `with_headers`, passing `None` for the extra headers.
    pub fn new(api_key: String, model: String, base_url: String) -> Self {
        Self::with_headers(api_key, model, base_url, None)
    }
35
36 fn load_proxy_from_env() -> Option<String> {
38 std::env::var("HTTPS_PROXY")
39 .ok()
40 .or_else(|| std::env::var("https_proxy").ok())
41 .or_else(|| std::env::var("HTTP_PROXY").ok())
42 .or_else(|| std::env::var("http_proxy").ok())
43 }
44
45 pub fn with_headers(
46 api_key: String,
47 model: String,
48 base_url: String,
49 extra_headers: Option<std::collections::HashMap<String, String>>,
50 ) -> Self {
51 let mut client_builder = reqwest::Client::builder()
56 .connect_timeout(std::time::Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
57 .read_timeout(std::time::Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
58 .timeout(std::time::Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS)); if let Some(proxy_url) = Self::load_proxy_from_env()
62 && let Ok(proxy) = reqwest::Proxy::all(&proxy_url)
63 {
64 log::info!("AnthropicProvider using proxy: {}", proxy_url);
65 client_builder = client_builder.proxy(proxy);
66 }
67
68 let client = client_builder
69 .build()
70 .unwrap_or_else(|_| reqwest::Client::new());
71 let extra_headers: Vec<(String, String)> = extra_headers
72 .map(|h| h.into_iter().collect())
73 .unwrap_or_default();
74 Self {
75 api_key,
76 model,
77 base_url,
78 client,
79 extra_headers,
80 }
81 }
82
    /// Returns `true` when `base_url` contains the substring `api.anthropic.com`.
    ///
    /// This is a plain substring test, not a parsed-host comparison.
    fn is_official_anthropic(&self) -> bool {
        self.base_url.contains("api.anthropic.com")
    }
88
89 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
90 let filter_thinking = !self.is_official_anthropic();
94 log::debug!(
95 "convert_messages: filter_thinking={}, base_url={}",
96 filter_thinking,
97 self.base_url
98 );
99
100 let mut thinking_count = 0;
102 for m in messages {
103 if let MessageContent::Blocks(blocks) = &m.content {
104 for b in blocks {
105 if matches!(b, ContentBlock::Thinking { .. }) {
106 thinking_count += 1;
107 }
108 }
109 }
110 }
111 if thinking_count > 0 {
112 log::debug!(
113 "convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}",
114 thinking_count,
115 messages.len(),
116 filter_thinking
117 );
118 }
119
120 messages
121 .iter()
122 .filter(|m| m.role != Role::System)
123 .map(|m| {
124 let role = match m.role {
125 Role::User | Role::Tool => "user",
126 Role::Assistant => "assistant",
127 Role::System => unreachable!(),
128 };
129
130 let content = match &m.content {
131 MessageContent::Text(text) => json!(text),
132 MessageContent::Blocks(blocks) => {
133 let converted: Vec<Value> = blocks
134 .iter()
135 .filter(|b| {
136 if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
138 return false;
139 }
140 true
141 })
142 .map(|b| match b {
143 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
144 ContentBlock::ToolUse { id, name, input } => {
145 json!({"type": "tool_use", "id": id, "name": name, "input": input})
146 }
147 ContentBlock::ToolResult { tool_use_id, content } => {
148 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
149 }
150 ContentBlock::Thinking { thinking, signature } => {
151 let mut obj = json!({"type": "thinking", "thinking": thinking});
152 if let Some(sig) = signature
154 && !sig.is_empty() {
155 obj["signature"] = json!(sig);
156 }
157 obj
158 }
159 ContentBlock::ServerToolUse { id, name, input } => {
160 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
161 }
162 ContentBlock::WebSearchResult { tool_use_id, content } => {
163 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
164 }
165 })
166 .collect();
167
168 if converted.is_empty() {
169 json!([])
170 } else {
171 json!(converted)
172 }
173 }
174 };
175
176 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
178 return json!(null);
179 }
180
181 json!({"role": role, "content": content})
182 })
183 .filter(|v| !v.is_null())
184 .collect()
185 }
186
187 fn convert_tools_with_caching(
189 &self,
190 tools: &[ToolDefinition],
191 enable_caching: bool,
192 ) -> Vec<Value> {
193 let mut converted: Vec<Value> = tools
194 .iter()
195 .map(|t| {
196 json!({
197 "name": t.name,
198 "description": t.description,
199 "input_schema": t.parameters,
200 })
201 })
202 .collect();
203
204 if enable_caching && !converted.is_empty() {
206 let last_idx = converted.len() - 1;
207 if let Some(obj) = converted[last_idx].as_object_mut() {
208 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
209 }
210 }
211
212 converted
213 }
214
215 fn build_body(&self, request: &ChatRequest) -> Value {
217 let mut body = json!({
218 "model": self.model,
219 "max_tokens": request.max_tokens,
220 "messages": self.convert_messages(&request.messages),
221 });
222
223 if request.enable_caching {
225 if let Some(system) = &request.system {
226 body["system"] = json!([
228 {
229 "type": "text",
230 "text": system,
231 "cache_control": {"type": "ephemeral"}
232 }
233 ]);
234 }
235 } else if let Some(system) = &request.system {
236 body["system"] = json!(system);
237 }
238
239 if !request.tools.is_empty() {
240 let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching);
241 body["tools"] = json!(tools);
242 }
243
244 if !request.server_tools.is_empty() {
245 body["tools"] = json!(
246 body["tools"]
247 .as_array()
248 .map(|t| {
249 let mut tools = t.clone();
250 for st in &request.server_tools {
251 tools.push(serde_json::to_value(st).unwrap_or_default());
252 }
253 tools
254 })
255 .unwrap_or_else(|| request
256 .server_tools
257 .iter()
258 .map(|st| serde_json::to_value(st).unwrap_or_default())
259 .collect())
260 );
261 }
262
263 if request.think && self.is_official_anthropic() {
267 let config = thinking_config(&self.model);
268 log::debug!(
269 "Adding thinking config for model {}: {:?}",
270 self.model,
271 config
272 );
273 body["thinking"] = config;
274 } else if request.think {
275 log::debug!(
276 "Skipping thinking config for non-official API (model: {}, base_url: {})",
277 self.model,
278 self.base_url
279 );
280 }
281
282 body
283 }
284}
285
286fn thinking_config(model: &str) -> Value {
291 let m = model.to_lowercase();
292 let adaptive = m.contains("opus-4")
294 || m.contains("sonnet-4")
295 || m.contains("claude-4")
296 || m.contains("20250")
297 || m.contains("2025");
298 if adaptive {
299 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_NEW_MODELS})
300 } else {
301 json!({"type": "enabled", "budget_tokens": THINKING_BUDGET_OLD_MODELS})
303 }
304}
305
306#[async_trait]
307impl Provider for AnthropicProvider {
    /// Returns the context window for the configured model, as reported by
    /// `context_window_for`, or `None` when that lookup yields nothing.
    fn context_size(&self) -> Option<u32> {
        context_window_for(&self.model)
    }
311
    /// Returns the model identifier this provider was constructed with.
    fn model_name(&self) -> &str {
        &self.model
    }
315
316 fn clone_box(&self) -> Box<dyn Provider> {
317 Box::new(Self {
318 api_key: self.api_key.clone(),
319 model: self.model.clone(),
320 base_url: self.base_url.clone(),
321 client: reqwest::Client::new(),
322 extra_headers: self.extra_headers.clone(),
323 })
324 }
325
326 fn clone_arc(&self) -> Arc<dyn Provider> {
327 Arc::new(Self {
328 api_key: self.api_key.clone(),
329 model: self.model.clone(),
330 base_url: self.base_url.clone(),
331 client: reqwest::Client::new(),
332 extra_headers: self.extra_headers.clone(),
333 })
334 }
335
336 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
337 let body = self.build_body(&request);
338
339 let url = format!("{}/v1/messages", self.base_url);
340
341 crate::debug::debug_log()
343 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
344
345 let mut req = self
346 .client
347 .post(&url)
348 .header("User-Agent", "curl/8.0")
349 .json(&body);
350
351 if self.is_official_anthropic() {
353 req = req
354 .header("x-api-key", &self.api_key)
355 .header("anthropic-version", ANTHROPIC_API_VERSION)
356 .header("anthropic-beta", "prompt-caching-2024-07-31");
357 } else {
358 req = req
359 .header("Authorization", format!("Bearer {}", self.api_key))
360 .header("anthropic-version", "2023-06-01")
362 .header("anthropic-beta", "prompt-caching-2024-07-31");
364 }
365
366 for (name, value) in &self.extra_headers {
368 req = req.header(name, value);
369 }
370
371 let response = req.send().await?;
372
373 let status = response.status();
374 let response_body: Value = response.json().await?;
375
376 crate::debug::debug_log().api_response(
378 status.as_u16(),
379 &serde_json::to_string(&response_body).unwrap_or_default(),
380 );
381
382 if !status.is_success() {
383 let err_msg = response_body["error"]["message"]
384 .as_str()
385 .unwrap_or("unknown error");
386 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
387 }
388
389 let stop_reason = match response_body["stop_reason"].as_str() {
390 Some("tool_use") => StopReason::ToolUse,
391 Some("max_tokens") => StopReason::MaxTokens,
392 _ => StopReason::EndTurn,
393 };
394
395 let content = response_body["content"]
396 .as_array()
397 .unwrap_or(&vec![])
398 .iter()
399 .filter_map(|block| match block["type"].as_str()? {
400 "text" => Some(ContentBlock::Text {
401 text: block["text"].as_str()?.to_string(),
402 }),
403 "tool_use" => Some(ContentBlock::ToolUse {
404 id: block["id"].as_str()?.to_string(),
405 name: block["name"].as_str()?.to_string(),
406 input: block["input"].clone(),
407 }),
408 "thinking" => Some(ContentBlock::Thinking {
409 thinking: block["thinking"].as_str()?.to_string(),
410 signature: block["signature"].as_str().map(String::from),
411 }),
412 "server_tool_use" => Some(ContentBlock::ServerToolUse {
413 id: block["id"].as_str()?.to_string(),
414 name: block["name"].as_str()?.to_string(),
415 input: block["input"].clone(),
416 }),
417 "web_search_tool_result" => {
418 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
419 let content = parse_web_search_content(&block["content"]);
420 Some(ContentBlock::WebSearchResult {
421 tool_use_id,
422 content,
423 })
424 }
425 _ => None,
426 })
427 .collect();
428
429 Ok(ChatResponse {
430 content,
431 stop_reason,
432 usage: parse_usage(&response_body["usage"]),
433 })
434 }
435
436 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
437 let mut body = self.build_body(&request);
438 body["stream"] = json!(true);
439
440 let url = format!("{}/v1/messages", self.base_url);
441
442 crate::debug::debug_log()
444 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
445
446 let mut req = self
447 .client
448 .post(&url)
449 .header("User-Agent", "curl/8.0")
450 .json(&body);
451
452 if self.is_official_anthropic() {
454 req = req
455 .header("x-api-key", &self.api_key)
456 .header("anthropic-version", ANTHROPIC_API_VERSION)
457 .header("anthropic-beta", "prompt-caching-2024-07-31");
458 } else {
459 req = req
460 .header("Authorization", format!("Bearer {}", self.api_key))
461 .header("anthropic-version", "2023-06-01")
463 .header("anthropic-beta", "prompt-caching-2024-07-31");
465 }
466
467 for (name, value) in &self.extra_headers {
469 req = req.header(name, value);
470 }
471
472 let response = req.send().await?;
473
474 if !response.status().is_success() {
475 let status = response.status();
476 let text = response.text().await.unwrap_or_default();
477 anyhow::bail!("Anthropic API error ({}): {}", status, text);
478 }
479
480 let (tx, rx) = mpsc::channel(64);
481 tokio::spawn(async move {
482 let mut stream = response.bytes_stream();
483 let mut buffer = String::new();
484 let mut sent_first_byte = false;
485
486 let mut blocks: Vec<AssembledBlock> = Vec::new();
488 let mut stop_reason = StopReason::EndTurn;
489 let mut usage = Usage::default();
490
491 let mut last_content_time = std::time::Instant::now();
493 const CONTENT_TIMEOUT_SECS: u64 = DEFAULT_CONTENT_TIMEOUT_SECS; while let Some(chunk) = stream.next().await {
496 let chunk = match chunk {
497 Ok(c) => c,
498 Err(e) => {
499 let error_msg = e.to_string();
501 let is_timeout =
502 error_msg.contains("timeout") || error_msg.contains("timed out");
503 let is_decode =
504 error_msg.contains("decode") || error_msg.contains("decoding");
505
506 let msg = if is_timeout {
507 format!(
508 "Stream timeout - the API took too long to respond: {}",
509 error_msg
510 )
511 } else if is_decode {
512 format!(
513 "Stream decode error - possibly interrupted or corrupted response: {}",
514 error_msg
515 )
516 } else {
517 format!("Stream read error: {}", error_msg)
518 };
519
520 if sent_first_byte && !blocks.is_empty() {
522 debug!("finalizing partial stream due to error");
523 let _ = tx
524 .send(StreamEvent::Done(finalize_incomplete_stream(
525 std::mem::take(&mut blocks),
526 stop_reason,
527 usage,
528 )))
529 .await;
530 } else {
531 let _ = tx.send(StreamEvent::Error(msg)).await;
532 }
533 return;
534 }
535 };
536
537 if !sent_first_byte {
538 sent_first_byte = true;
539 let _ = tx.send(StreamEvent::FirstByte).await;
540 }
541
542 buffer.push_str(&String::from_utf8_lossy(&chunk));
543
544 let elapsed = last_content_time.elapsed().as_secs();
546 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
547 crate::debug::debug_log().stream_chunk(
548 "TIMEOUT_FORCE_FINALIZE",
549 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()),
550 );
551 let _ = tx
552 .send(StreamEvent::Done(finalize_incomplete_stream(
553 std::mem::take(&mut blocks),
554 stop_reason,
555 usage,
556 )))
557 .await;
558 return;
559 }
560
561 while let Some(frame) = take_next_sse_frame(&mut buffer) {
562 if handle_sse_frame(
563 &frame,
564 &mut blocks,
565 &mut stop_reason,
566 &mut usage,
567 &tx,
568 &mut last_content_time,
569 )
570 .await
571 {
572 return;
573 }
574 }
575 }
576
577 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
578 && handle_sse_frame(
579 &frame,
580 &mut blocks,
581 &mut stop_reason,
582 &mut usage,
583 &tx,
584 &mut last_content_time,
585 )
586 .await
587 {
588 return;
589 }
590
591 if sent_first_byte {
592 debug!("stream ended without explicit message_stop; finalizing best-effort");
593 let _ = tx
594 .send(StreamEvent::Done(finalize_incomplete_stream(
595 std::mem::take(&mut blocks),
596 stop_reason,
597 usage,
598 )))
599 .await;
600 } else {
601 let _ = tx
602 .send(StreamEvent::Error(
603 "stream ended before any events were received".to_string(),
604 ))
605 .await;
606 }
607 });
608
609 Ok(rx)
610 }
611}
612
/// Removes and returns the next complete SSE frame from the front of `buffer`.
///
/// A frame ends at the earliest blank-line delimiter, either `"\n\n"` or
/// `"\r\n\r\n"`. The delimiter is consumed but not included in the returned
/// frame. Returns `None`, leaving `buffer` untouched, when no delimiter is
/// present yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    let (frame_end, delim_len) = DELIMITERS
        .iter()
        .filter_map(|delim| buffer.find(*delim).map(|at| (at, delim.len())))
        .min_by_key(|&(at, _)| at)?;

    // Split off everything after the delimiter, cut the delimiter itself away,
    // then swap so the caller's buffer holds the remainder.
    let remainder = buffer.split_off(frame_end + delim_len);
    buffer.truncate(frame_end);
    Some(std::mem::replace(buffer, remainder))
}
633
/// Drains `buffer` and returns its whitespace-trimmed contents as a final frame.
///
/// Used for data left over when the stream ends without a closing blank line.
/// `buffer` is always left empty; `None` is returned when nothing but
/// whitespace remained.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    // `trim` already strips '\r' along with all other surrounding whitespace.
    let trimmed = leftover.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
639
/// Returns the payload of the first `data:` line in an SSE frame.
///
/// A single space directly after `data:` is optional and is not part of the
/// payload; a trailing `'\r'` on the line is ignored. Lines with any other
/// field name are skipped. Returns `None` when the frame has no `data:` line.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    frame.lines().find_map(|raw_line| {
        let payload = raw_line.trim_end_matches('\r').strip_prefix("data:")?;
        Some(payload.strip_prefix(' ').unwrap_or(payload).to_owned())
    })
}
653
654async fn handle_sse_frame(
655 frame: &str,
656 blocks: &mut Vec<AssembledBlock>,
657 stop_reason: &mut StopReason,
658 usage: &mut Usage,
659 tx: &mpsc::Sender<StreamEvent>,
660 last_content_time: &mut std::time::Instant,
661) -> bool {
662 let Some(data_line) = extract_sse_data_line(frame) else {
663 return false;
664 };
665
666 let evt: Value = match serde_json::from_str(&data_line) {
667 Ok(v) => v,
668 Err(_) => return false,
669 };
670
671 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
672}
673
674async fn handle_sse_event(
675 evt: Value,
676 blocks: &mut Vec<AssembledBlock>,
677 stop_reason: &mut StopReason,
678 usage: &mut Usage,
679 tx: &mpsc::Sender<StreamEvent>,
680 last_content_time: &mut std::time::Instant,
681) -> bool {
682 let evt_type = evt["type"].as_str().unwrap_or("");
683
684 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
686 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
687
688 if evt_type == "content_block_delta" {
690 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
691 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
692 log::debug!(
693 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
694 delta_type,
695 idx,
696 blocks.len(),
697 idx < blocks.len()
698 );
699 }
700
701 if evt_type != "ping" {
703 *last_content_time = std::time::Instant::now();
704 }
705
706 match evt_type {
707 "message_start" => {
708 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
713 debug!(
714 "message_start: usage_json={}",
715 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
716 );
717 debug!(
718 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
719 usage.input_tokens,
720 usage.output_tokens,
721 usage.cache_read_input_tokens,
722 usage.cache_creation_input_tokens
723 );
724 let _ = tx
726 .send(StreamEvent::Usage {
727 output_tokens: usage.output_tokens,
728 })
729 .await;
730 }
731 "content_block_start" => {
732 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
733 let block = &evt["content_block"];
734 let kind = block["type"].as_str().unwrap_or("");
735 while blocks.len() <= idx {
736 blocks.push(AssembledBlock::default());
737 }
738 match kind {
739 "text" => {
740 blocks[idx] = AssembledBlock::Text(String::new());
741 }
742 "thinking" => {
743 blocks[idx] = AssembledBlock::Thinking {
744 text: String::new(),
745 signature: None,
746 };
747 }
748 "tool_use" | "server_tool_use" => {
749 let id = block["id"].as_str().unwrap_or_default();
750 let name = block["name"].as_str().unwrap_or_default();
751 let is_server = kind == "server_tool_use";
752 blocks[idx] = if is_server {
753 AssembledBlock::ServerToolUse {
754 id: id.into(),
755 name: name.into(),
756 input_json: String::new(),
757 }
758 } else {
759 AssembledBlock::ToolUse {
760 id: id.into(),
761 name: name.into(),
762 input_json: String::new(),
763 }
764 };
765 let _ = tx
766 .send(StreamEvent::ToolUseStart {
767 id: id.into(),
768 name: name.into(),
769 })
770 .await;
771 }
772 "web_search_tool_result" => {
773 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
774 blocks[idx] = AssembledBlock::WebSearchResult {
775 tool_use_id,
776 content_json: String::new(),
777 };
778 }
779 _ => {}
780 }
781 }
782 "content_block_delta" => {
783 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
784 let delta = &evt["delta"];
785 let dt = delta["type"].as_str().unwrap_or("");
786 if idx >= blocks.len() {
787 return false;
788 }
789 match (dt, &mut blocks[idx]) {
790 ("text_delta", AssembledBlock::Text(buf)) => {
791 if let Some(t) = delta["text"].as_str() {
792 buf.push_str(t);
793 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
794 }
795 }
796 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
797 if let Some(t) = delta["thinking"].as_str() {
798 text.push_str(t);
799 log::debug!("Received thinking_delta: {} chars", t.len());
800 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
801 }
802 }
803 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
804 if let Some(s) = delta["signature"].as_str() {
805 signature.get_or_insert_with(String::new).push_str(s);
806 }
807 }
808 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
809 if let Some(p) = delta["partial_json"].as_str() {
810 input_json.push_str(p);
811 let _ = tx
812 .send(StreamEvent::ToolInputDelta {
813 bytes_so_far: input_json.len(),
814 })
815 .await;
816 }
817 }
818 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
819 if let Some(p) = delta["partial_json"].as_str() {
820 input_json.push_str(p);
821 let _ = tx
822 .send(StreamEvent::ToolInputDelta {
823 bytes_so_far: input_json.len(),
824 })
825 .await;
826 }
827 }
828 _ => {}
829 }
830 }
831 "content_block_stop" => {
832 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
833 if let Some(AssembledBlock::ToolUse {
834 id,
835 name,
836 input_json,
837 }) = blocks.get(idx)
838 {
839 let input: Value = if input_json.is_empty() {
840 json!({})
841 } else {
842 serde_json::from_str(input_json).unwrap_or(json!({}))
843 };
844 let _ = tx
845 .send(StreamEvent::ToolInputComplete {
846 id: id.clone(),
847 name: name.clone(),
848 input,
849 })
850 .await;
851 }
852 }
853 "message_delta" => {
854 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
855 *stop_reason = match sr {
856 "tool_use" => StopReason::ToolUse,
857 "max_tokens" => StopReason::MaxTokens,
858 _ => StopReason::EndTurn,
859 };
860 }
861 *usage = merge_usage(usage.clone(), &evt["usage"]);
865 debug!(
866 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
867 usage.input_tokens,
868 usage.output_tokens,
869 usage.cache_read_input_tokens,
870 usage.cache_creation_input_tokens
871 );
872 let _ = tx
874 .send(StreamEvent::Usage {
875 output_tokens: usage.output_tokens,
876 })
877 .await;
878 }
879 "message_stop" => {
880 debug!(
881 "Message completed: stop_reason={}, usage={}",
882 match *stop_reason {
883 StopReason::EndTurn => "end_turn",
884 StopReason::ToolUse => "tool_use",
885 StopReason::MaxTokens => "max_tokens",
886 },
887 usage.output_tokens
888 );
889 debug!(
890 "message_stop final usage: cache_read={}, cache_created={}",
891 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
892 );
893 let _ = tx
894 .send(StreamEvent::Done(finalize_incomplete_stream(
895 std::mem::take(blocks),
896 stop_reason.clone(),
897 usage.clone(),
898 )))
899 .await;
900 return true;
901 }
902 "error" => {
903 let msg = evt["error"]["message"]
904 .as_str()
905 .unwrap_or("unknown stream error")
906 .to_string();
907 let _ = tx.send(StreamEvent::Error(msg)).await;
908 return true;
909 }
910 _ => {}
911 }
912
913 false
914}
915
916fn build_chat_response(
917 blocks: Vec<AssembledBlock>,
918 stop_reason: StopReason,
919 usage: Usage,
920) -> ChatResponse {
921 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
922 ChatResponse {
923 content,
924 stop_reason,
925 usage,
926 }
927}
928
/// Builds a `ChatResponse` from whatever blocks have been assembled so far.
///
/// Currently a direct delegation to `build_chat_response`; the separate name
/// marks call sites where the stream may have ended early.
fn finalize_incomplete_stream(
    blocks: Vec<AssembledBlock>,
    stop_reason: StopReason,
    usage: Usage,
) -> ChatResponse {
    build_chat_response(blocks, stop_reason, usage)
}
936
/// A content block being accumulated from streaming events.
///
/// Converted into a `ContentBlock` by `finish`.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index with no recognized block; `finish` yields `None`.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Thinking text plus an optional signature accumulated from deltas.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Client-side tool call; `input_json` holds the raw JSON text received so far.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server-side tool call; `input_json` holds the raw JSON text received so far.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search result; `content_json` is raw JSON text parsed in `finish`.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
961
962impl AssembledBlock {
963 fn finish(self) -> Option<ContentBlock> {
964 match self {
965 AssembledBlock::Empty => None,
966 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
967 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
968 thinking: text,
969 signature,
970 }),
971 AssembledBlock::ToolUse {
972 id,
973 name,
974 input_json,
975 } => {
976 let input: Value = if input_json.is_empty() {
977 json!({})
978 } else {
979 serde_json::from_str(&input_json).unwrap_or(json!({}))
980 };
981 Some(ContentBlock::ToolUse { id, name, input })
982 }
983 AssembledBlock::ServerToolUse {
984 id,
985 name,
986 input_json,
987 } => {
988 let input: Value = if input_json.is_empty() {
989 json!({})
990 } else {
991 serde_json::from_str(&input_json).unwrap_or(json!({}))
992 };
993 Some(ContentBlock::ServerToolUse { id, name, input })
994 }
995 AssembledBlock::WebSearchResult {
996 tool_use_id,
997 content_json,
998 } => {
999 let content: Value = if content_json.is_empty() {
1000 json!({"results": []})
1001 } else {
1002 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
1003 };
1004 Some(ContentBlock::WebSearchResult {
1005 tool_use_id,
1006 content: parse_web_search_content(&content),
1007 })
1008 }
1009 }
1010 }
1011}
1012
1013fn parse_usage(u: &Value) -> Usage {
1016 Usage {
1017 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
1018 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
1019 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
1020 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
1021 }
1022}
1023
1024fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
1028 let fresh = parse_usage(u);
1029 if fresh.input_tokens > 0 {
1030 acc.input_tokens = fresh.input_tokens;
1031 }
1032 if fresh.output_tokens > 0 {
1033 acc.output_tokens = fresh.output_tokens;
1034 }
1035 if fresh.cache_creation_input_tokens > 0 {
1036 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
1037 }
1038 if fresh.cache_read_input_tokens > 0 {
1039 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
1040 }
1041 acc
1042}
1043
1044fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
1046 let results = value["results"]
1047 .as_array()
1048 .map(|arr| {
1049 arr.iter()
1050 .filter_map(|item| {
1051 Some(crate::providers::WebSearchResultItem {
1052 title: item["title"].as_str().map(String::from),
1053 url: item["url"].as_str()?.to_string(),
1054 encrypted_content: item["encrypted_content"].as_str().map(String::from),
1055 snippet: item["snippet"].as_str().map(String::from),
1056 })
1057 })
1058 .collect()
1059 })
1060 .unwrap_or_default();
1061
1062 crate::providers::WebSearchContent { results }
1063}
1064
// Unit tests for the SSE framing helpers and partial-stream finalization.
#[cfg(test)]
mod tests {
    use super::*;

    // Frames separated by "\r\n\r\n" are split one at a time and fully consumed.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    // Data left without a closing blank line is still returned, trimmed.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    // Both "data: x" and "data:x" are accepted; a trailing '\r' is ignored.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    // Finalizing an unfinished stream keeps the text accumulated so far.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}