1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13 StreamEvent, Usage,
14};
15
16pub struct AnthropicProvider {
17 api_key: String,
18 model: String,
19 base_url: String,
20 client: reqwest::Client,
21 extra_headers: Vec<(String, String)>,
23}
24
25impl AnthropicProvider {
26 pub fn new(api_key: String, model: String, base_url: String) -> Self {
27 Self::with_headers(api_key, model, base_url, None)
28 }
29
30 pub fn with_headers(
31 api_key: String,
32 model: String,
33 base_url: String,
34 extra_headers: Option<std::collections::HashMap<String, String>>,
35 ) -> Self {
36 let client = reqwest::Client::builder()
41 .connect_timeout(std::time::Duration::from_secs(10))
42 .read_timeout(std::time::Duration::from_secs(60))
43 .timeout(std::time::Duration::from_secs(300)) .build()
45 .unwrap_or_else(|_| reqwest::Client::new());
46 let extra_headers: Vec<(String, String)> = extra_headers
47 .map(|h| h.into_iter().collect())
48 .unwrap_or_default();
49 Self {
50 api_key,
51 model,
52 base_url,
53 client,
54 extra_headers,
55 }
56 }
57
58 fn is_official_anthropic(&self) -> bool {
61 self.base_url.contains("api.anthropic.com")
62 }
63
64 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
65 let filter_thinking = !self.is_official_anthropic();
69 log::debug!("convert_messages: filter_thinking={}, base_url={}", filter_thinking, self.base_url);
70
71 let mut thinking_count = 0;
73 for m in messages {
74 if let MessageContent::Blocks(blocks) = &m.content {
75 for b in blocks {
76 if matches!(b, ContentBlock::Thinking { .. }) {
77 thinking_count += 1;
78 }
79 }
80 }
81 }
82 if thinking_count > 0 {
83 log::debug!("convert_messages: Found {} thinking blocks in {} messages, filter_thinking={}", thinking_count, messages.len(), filter_thinking);
84 }
85
86 messages
87 .iter()
88 .filter(|m| m.role != Role::System)
89 .map(|m| {
90 let role = match m.role {
91 Role::User | Role::Tool => "user",
92 Role::Assistant => "assistant",
93 Role::System => unreachable!(),
94 };
95
96 let content = match &m.content {
97 MessageContent::Text(text) => json!(text),
98 MessageContent::Blocks(blocks) => {
99 let converted: Vec<Value> = blocks
100 .iter()
101 .filter(|b| {
102 if filter_thinking && matches!(b, ContentBlock::Thinking { .. }) {
104 return false;
105 }
106 true
107 })
108 .map(|b| match b {
109 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
110 ContentBlock::ToolUse { id, name, input } => {
111 json!({"type": "tool_use", "id": id, "name": name, "input": input})
112 }
113 ContentBlock::ToolResult { tool_use_id, content } => {
114 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
115 }
116 ContentBlock::Thinking { thinking, signature } => {
117 let mut obj = json!({"type": "thinking", "thinking": thinking});
118 if let Some(sig) = signature {
120 if !sig.is_empty() {
121 obj["signature"] = json!(sig);
122 }
123 }
124 obj
125 }
126 ContentBlock::ServerToolUse { id, name, input } => {
127 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
128 }
129 ContentBlock::WebSearchResult { tool_use_id, content } => {
130 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
131 }
132 })
133 .collect();
134
135 if converted.is_empty() {
136 json!([])
137 } else {
138 json!(converted)
139 }
140 }
141 };
142
143 if content.is_array() && content.as_array().map(|a| a.is_empty()).unwrap_or(false) {
145 return json!(null);
146 }
147
148 json!({"role": role, "content": content})
149 })
150 .filter(|v| !v.is_null())
151 .collect()
152 }
153
154 fn convert_tools_with_caching(
156 &self,
157 tools: &[ToolDefinition],
158 enable_caching: bool,
159 ) -> Vec<Value> {
160 let mut converted: Vec<Value> = tools
161 .iter()
162 .map(|t| {
163 json!({
164 "name": t.name,
165 "description": t.description,
166 "input_schema": t.parameters,
167 })
168 })
169 .collect();
170
171 if enable_caching && !converted.is_empty() {
173 let last_idx = converted.len() - 1;
174 if let Some(obj) = converted[last_idx].as_object_mut() {
175 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
176 }
177 }
178
179 converted
180 }
181
182 fn build_body(&self, request: &ChatRequest) -> Value {
184 let mut body = json!({
185 "model": self.model,
186 "max_tokens": request.max_tokens,
187 "messages": self.convert_messages(&request.messages),
188 });
189
190 if request.enable_caching {
192 if let Some(system) = &request.system {
193 body["system"] = json!([
195 {
196 "type": "text",
197 "text": system,
198 "cache_control": {"type": "ephemeral"}
199 }
200 ]);
201 }
202 } else if let Some(system) = &request.system {
203 body["system"] = json!(system);
204 }
205
206 if !request.tools.is_empty() {
207 let tools = self.convert_tools_with_caching(
208 &request.tools,
209 request.enable_caching,
210 );
211 body["tools"] = json!(tools);
212 }
213
214 if !request.server_tools.is_empty() {
215 body["tools"] = json!(
216 body["tools"]
217 .as_array()
218 .map(|t| {
219 let mut tools = t.clone();
220 for st in &request.server_tools {
221 tools.push(serde_json::to_value(st).unwrap_or_default());
222 }
223 tools
224 })
225 .unwrap_or_else(|| request
226 .server_tools
227 .iter()
228 .map(|st| serde_json::to_value(st).unwrap_or_default())
229 .collect())
230 );
231 }
232
233 if request.think && self.is_official_anthropic() {
237 let config = thinking_config(&self.model);
238 log::debug!(
239 "Adding thinking config for model {}: {:?}",
240 self.model,
241 config
242 );
243 body["thinking"] = config;
244 } else if request.think {
245 log::debug!(
246 "Skipping thinking config for non-official API (model: {}, base_url: {})",
247 self.model,
248 self.base_url
249 );
250 }
251
252 body
253 }
254}
255
256fn thinking_config(model: &str) -> Value {
261 let m = model.to_lowercase();
262 let adaptive = m.contains("opus-4")
264 || m.contains("sonnet-4")
265 || m.contains("claude-4")
266 || m.contains("20250")
267 || m.contains("2025");
268 if adaptive {
269 json!({"type": "enabled", "budget_tokens": 10000})
270 } else {
271 json!({"type": "enabled", "budget_tokens": 5000})
273 }
274}
275
276#[async_trait]
277impl Provider for AnthropicProvider {
278 fn context_size(&self) -> Option<u32> {
279 context_window_for(&self.model)
280 }
281
282 fn model_name(&self) -> &str {
283 &self.model
284 }
285
286 fn clone_box(&self) -> Box<dyn Provider> {
287 Box::new(Self {
288 api_key: self.api_key.clone(),
289 model: self.model.clone(),
290 base_url: self.base_url.clone(),
291 client: reqwest::Client::new(),
292 extra_headers: self.extra_headers.clone(),
293 })
294 }
295
296 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
297 let body = self.build_body(&request);
298
299 let url = format!("{}/v1/messages", self.base_url);
300
301 crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
303
304 let mut req = self
305 .client
306 .post(&url)
307 .header("User-Agent", "curl/8.0")
308 .json(&body);
309
310 if self.is_official_anthropic() {
312 req = req
313 .header("x-api-key", &self.api_key)
314 .header("anthropic-version", "2025-04-15")
315 .header("anthropic-beta", "prompt-caching-2024-07-31");
316 } else {
317 req = req
318 .header("Authorization", format!("Bearer {}", self.api_key))
319 .header("anthropic-version", "2023-06-01")
321 .header("anthropic-beta", "prompt-caching-2024-07-31");
323 }
324
325 for (name, value) in &self.extra_headers {
327 req = req.header(name, value);
328 }
329
330 let response = req.send().await?;
331
332 let status = response.status();
333 let response_body: Value = response.json().await?;
334
335 crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
337
338 if !status.is_success() {
339 let err_msg = response_body["error"]["message"]
340 .as_str()
341 .unwrap_or("unknown error");
342 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
343 }
344
345 let stop_reason = match response_body["stop_reason"].as_str() {
346 Some("tool_use") => StopReason::ToolUse,
347 Some("max_tokens") => StopReason::MaxTokens,
348 _ => StopReason::EndTurn,
349 };
350
351 let content = response_body["content"]
352 .as_array()
353 .unwrap_or(&vec![])
354 .iter()
355 .filter_map(|block| match block["type"].as_str()? {
356 "text" => Some(ContentBlock::Text {
357 text: block["text"].as_str()?.to_string(),
358 }),
359 "tool_use" => Some(ContentBlock::ToolUse {
360 id: block["id"].as_str()?.to_string(),
361 name: block["name"].as_str()?.to_string(),
362 input: block["input"].clone(),
363 }),
364 "thinking" => Some(ContentBlock::Thinking {
365 thinking: block["thinking"].as_str()?.to_string(),
366 signature: block["signature"].as_str().map(String::from),
367 }),
368 "server_tool_use" => Some(ContentBlock::ServerToolUse {
369 id: block["id"].as_str()?.to_string(),
370 name: block["name"].as_str()?.to_string(),
371 input: block["input"].clone(),
372 }),
373 "web_search_tool_result" => {
374 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
375 let content = parse_web_search_content(&block["content"]);
376 Some(ContentBlock::WebSearchResult {
377 tool_use_id,
378 content,
379 })
380 }
381 _ => None,
382 })
383 .collect();
384
385 Ok(ChatResponse {
386 content,
387 stop_reason,
388 usage: parse_usage(&response_body["usage"]),
389 })
390 }
391
392 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
393 let mut body = self.build_body(&request);
394 body["stream"] = json!(true);
395
396 let url = format!("{}/v1/messages", self.base_url);
397
398 crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
400
401 let mut req = self
402 .client
403 .post(&url)
404 .header("User-Agent", "curl/8.0")
405 .json(&body);
406
407 if self.is_official_anthropic() {
409 req = req
410 .header("x-api-key", &self.api_key)
411 .header("anthropic-version", "2025-04-15")
412 .header("anthropic-beta", "prompt-caching-2024-07-31");
413 } else {
414 req = req
415 .header("Authorization", format!("Bearer {}", self.api_key))
416 .header("anthropic-version", "2023-06-01")
418 .header("anthropic-beta", "prompt-caching-2024-07-31");
420 }
421
422 for (name, value) in &self.extra_headers {
424 req = req.header(name, value);
425 }
426
427 let response = req.send().await?;
428
429 if !response.status().is_success() {
430 let status = response.status();
431 let text = response.text().await.unwrap_or_default();
432 anyhow::bail!("Anthropic API error ({}): {}", status, text);
433 }
434
435 let (tx, rx) = mpsc::channel(64);
436 tokio::spawn(async move {
437 let mut stream = response.bytes_stream();
438 let mut buffer = String::new();
439 let mut sent_first_byte = false;
440
441 let mut blocks: Vec<AssembledBlock> = Vec::new();
443 let mut stop_reason = StopReason::EndTurn;
444 let mut usage = Usage::default();
445
446 let mut last_content_time = std::time::Instant::now();
448 const CONTENT_TIMEOUT_SECS: u64 = 300; while let Some(chunk) = stream.next().await {
451 let chunk = match chunk {
452 Ok(c) => c,
453 Err(e) => {
454 let error_msg = e.to_string();
456 let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
457 let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
458
459 let msg = if is_timeout {
460 format!("Stream timeout - the API took too long to respond: {}", error_msg)
461 } else if is_decode {
462 format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
463 } else {
464 format!("Stream read error: {}", error_msg)
465 };
466
467 if sent_first_byte && !blocks.is_empty() {
469 debug!("finalizing partial stream due to error");
470 let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
471 std::mem::take(&mut blocks),
472 stop_reason,
473 usage,
474 ))).await;
475 } else {
476 let _ = tx.send(StreamEvent::Error(msg)).await;
477 }
478 return;
479 }
480 };
481
482 if !sent_first_byte {
483 sent_first_byte = true;
484 let _ = tx.send(StreamEvent::FirstByte).await;
485 }
486
487 buffer.push_str(&String::from_utf8_lossy(&chunk));
488
489 let elapsed = last_content_time.elapsed().as_secs();
491 if elapsed > CONTENT_TIMEOUT_SECS && !blocks.is_empty() {
492 crate::debug::debug_log().stream_chunk("TIMEOUT_FORCE_FINALIZE",
493 &format!("elapsed={}s, blocks={}", elapsed, blocks.len()));
494 let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
495 std::mem::take(&mut blocks),
496 stop_reason,
497 usage,
498 ))).await;
499 return;
500 }
501
502 while let Some(frame) = take_next_sse_frame(&mut buffer) {
503 if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time)
504 .await
505 {
506 return;
507 }
508 }
509 }
510
511 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
512 && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx, &mut last_content_time).await
513 {
514 return;
515 }
516
517 if sent_first_byte {
518 debug!("stream ended without explicit message_stop; finalizing best-effort");
519 let _ = tx
520 .send(StreamEvent::Done(finalize_incomplete_stream(
521 std::mem::take(&mut blocks),
522 stop_reason,
523 usage,
524 )))
525 .await;
526 } else {
527 let _ = tx
528 .send(StreamEvent::Error(
529 "stream ended before any events were received".to_string(),
530 ))
531 .await;
532 }
533 });
534
535 Ok(rx)
536 }
537}
538
/// Removes the first complete SSE frame from `buffer` and returns it without its
/// terminating blank line.
///
/// Both `\n\n` and `\r\n\r\n` terminate a frame; whichever occurs first wins. Returns
/// `None`, leaving `buffer` untouched, when no terminator has arrived yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const TERMINATORS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    let (end, terminator_len) = TERMINATORS
        .iter()
        .filter_map(|&t| buffer.find(t).map(|at| (at, t.len())))
        .min_by_key(|&(at, _)| at)?;

    let remainder = buffer.split_off(end + terminator_len);
    let mut frame = std::mem::replace(buffer, remainder);
    frame.truncate(end);
    Some(frame)
}
559
/// Empties `buffer` and returns its whitespace-trimmed contents as a final, unterminated
/// SSE frame, or `None` if nothing but whitespace was left.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let trimmed = buffer.trim();
    let leftover = if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    };
    buffer.clear();
    leftover
}
565
/// Returns the `data` payload of an SSE frame, or `None` if the frame has no `data` field.
///
/// Follows the SSE field rules: one optional space after `data:` is stripped, a bare
/// `data` line contributes an empty value, and multiple `data` lines are joined with
/// `\n`. Lines may end in `\r`.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let parts: Vec<&str> = frame
        .lines()
        .map(|line| line.trim_end_matches('\r'))
        .filter_map(|line| {
            if line == "data" {
                Some("")
            } else {
                line.strip_prefix("data:")
                    .map(|rest| rest.strip_prefix(' ').unwrap_or(rest))
            }
        })
        .collect();

    if parts.is_empty() {
        None
    } else {
        Some(parts.join("\n"))
    }
}
579
580async fn handle_sse_frame(
581 frame: &str,
582 blocks: &mut Vec<AssembledBlock>,
583 stop_reason: &mut StopReason,
584 usage: &mut Usage,
585 tx: &mpsc::Sender<StreamEvent>,
586 last_content_time: &mut std::time::Instant,
587) -> bool {
588 let Some(data_line) = extract_sse_data_line(frame) else {
589 return false;
590 };
591
592 let evt: Value = match serde_json::from_str(&data_line) {
593 Ok(v) => v,
594 Err(_) => return false,
595 };
596
597 handle_sse_event(evt, blocks, stop_reason, usage, tx, last_content_time).await
598}
599
600async fn handle_sse_event(
601 evt: Value,
602 blocks: &mut Vec<AssembledBlock>,
603 stop_reason: &mut StopReason,
604 usage: &mut Usage,
605 tx: &mpsc::Sender<StreamEvent>,
606 last_content_time: &mut std::time::Instant,
607) -> bool {
608 let evt_type = evt["type"].as_str().unwrap_or("");
609
610 let evt_json = serde_json::to_string(&evt).unwrap_or_default();
612 crate::debug::debug_log().stream_chunk(evt_type, &evt_json);
613
614 if evt_type == "content_block_delta" {
616 let delta_type = evt["delta"]["type"].as_str().unwrap_or("");
617 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
618 log::debug!(
619 "content_block_delta: type={}, idx={}, blocks_len={}, has_block={}",
620 delta_type,
621 idx,
622 blocks.len(),
623 idx < blocks.len()
624 );
625 }
626
627 if evt_type != "ping" {
629 *last_content_time = std::time::Instant::now();
630 }
631
632 match evt_type {
633 "message_start" => {
634 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
639 debug!(
640 "message_start: usage_json={}",
641 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
642 );
643 debug!(
644 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
645 usage.input_tokens,
646 usage.output_tokens,
647 usage.cache_read_input_tokens,
648 usage.cache_creation_input_tokens
649 );
650 let _ = tx
652 .send(StreamEvent::Usage {
653 output_tokens: usage.output_tokens,
654 })
655 .await;
656 }
657 "content_block_start" => {
658 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
659 let block = &evt["content_block"];
660 let kind = block["type"].as_str().unwrap_or("");
661 while blocks.len() <= idx {
662 blocks.push(AssembledBlock::default());
663 }
664 match kind {
665 "text" => {
666 blocks[idx] = AssembledBlock::Text(String::new());
667 }
668 "thinking" => {
669 blocks[idx] = AssembledBlock::Thinking {
670 text: String::new(),
671 signature: None,
672 };
673 }
674 "tool_use" | "server_tool_use" => {
675 let id = block["id"].as_str().unwrap_or_default();
676 let name = block["name"].as_str().unwrap_or_default();
677 let is_server = kind == "server_tool_use";
678 blocks[idx] = if is_server {
679 AssembledBlock::ServerToolUse {
680 id: id.into(),
681 name: name.into(),
682 input_json: String::new(),
683 }
684 } else {
685 AssembledBlock::ToolUse {
686 id: id.into(),
687 name: name.into(),
688 input_json: String::new(),
689 }
690 };
691 let _ = tx
692 .send(StreamEvent::ToolUseStart {
693 id: id.into(),
694 name: name.into(),
695 })
696 .await;
697 }
698 "web_search_tool_result" => {
699 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
700 blocks[idx] = AssembledBlock::WebSearchResult {
701 tool_use_id,
702 content_json: String::new(),
703 };
704 }
705 _ => {}
706 }
707 }
708 "content_block_delta" => {
709 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
710 let delta = &evt["delta"];
711 let dt = delta["type"].as_str().unwrap_or("");
712 if idx >= blocks.len() {
713 return false;
714 }
715 match (dt, &mut blocks[idx]) {
716 ("text_delta", AssembledBlock::Text(buf)) => {
717 if let Some(t) = delta["text"].as_str() {
718 buf.push_str(t);
719 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
720 }
721 }
722 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
723 if let Some(t) = delta["thinking"].as_str() {
724 text.push_str(t);
725 log::debug!("Received thinking_delta: {} chars", t.len());
726 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
727 }
728 }
729 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
730 if let Some(s) = delta["signature"].as_str() {
731 signature.get_or_insert_with(String::new).push_str(s);
732 }
733 }
734 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
735 if let Some(p) = delta["partial_json"].as_str() {
736 input_json.push_str(p);
737 let _ = tx
738 .send(StreamEvent::ToolInputDelta {
739 bytes_so_far: input_json.len(),
740 })
741 .await;
742 }
743 }
744 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
745 if let Some(p) = delta["partial_json"].as_str() {
746 input_json.push_str(p);
747 let _ = tx
748 .send(StreamEvent::ToolInputDelta {
749 bytes_so_far: input_json.len(),
750 })
751 .await;
752 }
753 }
754 _ => {}
755 }
756 }
757 "message_delta" => {
758 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
759 *stop_reason = match sr {
760 "tool_use" => StopReason::ToolUse,
761 "max_tokens" => StopReason::MaxTokens,
762 _ => StopReason::EndTurn,
763 };
764 }
765 *usage = merge_usage(usage.clone(), &evt["usage"]);
769 debug!(
770 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
771 usage.input_tokens,
772 usage.output_tokens,
773 usage.cache_read_input_tokens,
774 usage.cache_creation_input_tokens
775 );
776 let _ = tx
778 .send(StreamEvent::Usage {
779 output_tokens: usage.output_tokens,
780 })
781 .await;
782 }
783 "message_stop" => {
784 debug!(
785 "Message completed: stop_reason={}, usage={}",
786 match *stop_reason {
787 StopReason::EndTurn => "end_turn",
788 StopReason::ToolUse => "tool_use",
789 StopReason::MaxTokens => "max_tokens",
790 },
791 usage.output_tokens
792 );
793 debug!(
794 "message_stop final usage: cache_read={}, cache_created={}",
795 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
796 );
797 let _ = tx
798 .send(StreamEvent::Done(finalize_incomplete_stream(
799 std::mem::take(blocks),
800 stop_reason.clone(),
801 usage.clone(),
802 )))
803 .await;
804 return true;
805 }
806 "error" => {
807 let msg = evt["error"]["message"]
808 .as_str()
809 .unwrap_or("unknown stream error")
810 .to_string();
811 let _ = tx.send(StreamEvent::Error(msg)).await;
812 return true;
813 }
814 _ => {}
815 }
816
817 false
818}
819
820fn build_chat_response(
821 blocks: Vec<AssembledBlock>,
822 stop_reason: StopReason,
823 usage: Usage,
824) -> ChatResponse {
825 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
826 ChatResponse {
827 content,
828 stop_reason,
829 usage,
830 }
831}
832
833fn finalize_incomplete_stream(
834 blocks: Vec<AssembledBlock>,
835 stop_reason: StopReason,
836 usage: Usage,
837) -> ChatResponse {
838 build_chat_response(blocks, stop_reason, usage)
839}
840
/// Per-index accumulator for a content block while an SSE stream is being read.
///
/// `content_block_start` replaces the slot at the event's index with the matching
/// variant, `content_block_delta` events append to it, and `finish` converts it into a
/// `ContentBlock`. `Empty` (the default) fills indices that have not been started and
/// is dropped when the response is built.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for a slot that no `content_block_start` has filled.
    #[default]
    Empty,
    /// Concatenated `text_delta` fragments.
    Text(String),
    /// Concatenated `thinking_delta` text plus any `signature_delta` fragments.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Client tool call; `input_json` collects raw `input_json_delta` fragments.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Server-executed tool call; `input_json` collects raw `input_json_delta` fragments.
    ServerToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// Web search result for `tool_use_id`, kept as raw JSON text until finished.
    WebSearchResult {
        tool_use_id: String,
        content_json: String,
    },
}
865
866impl AssembledBlock {
867 fn finish(self) -> Option<ContentBlock> {
868 match self {
869 AssembledBlock::Empty => None,
870 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
871 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
872 thinking: text,
873 signature,
874 }),
875 AssembledBlock::ToolUse {
876 id,
877 name,
878 input_json,
879 } => {
880 let input: Value = if input_json.is_empty() {
881 json!({})
882 } else {
883 serde_json::from_str(&input_json).unwrap_or(json!({}))
884 };
885 Some(ContentBlock::ToolUse { id, name, input })
886 }
887 AssembledBlock::ServerToolUse {
888 id,
889 name,
890 input_json,
891 } => {
892 let input: Value = if input_json.is_empty() {
893 json!({})
894 } else {
895 serde_json::from_str(&input_json).unwrap_or(json!({}))
896 };
897 Some(ContentBlock::ServerToolUse { id, name, input })
898 }
899 AssembledBlock::WebSearchResult {
900 tool_use_id,
901 content_json,
902 } => {
903 let content: Value = if content_json.is_empty() {
904 json!({"results": []})
905 } else {
906 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
907 };
908 Some(ContentBlock::WebSearchResult {
909 tool_use_id,
910 content: parse_web_search_content(&content),
911 })
912 }
913 }
914 }
915}
916
917fn parse_usage(u: &Value) -> Usage {
920 Usage {
921 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
922 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
923 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
924 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
925 }
926}
927
928fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
932 let fresh = parse_usage(u);
933 if fresh.input_tokens > 0 {
934 acc.input_tokens = fresh.input_tokens;
935 }
936 if fresh.output_tokens > 0 {
937 acc.output_tokens = fresh.output_tokens;
938 }
939 if fresh.cache_creation_input_tokens > 0 {
940 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
941 }
942 if fresh.cache_read_input_tokens > 0 {
943 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
944 }
945 acc
946}
947
948fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
950 let results = value["results"]
951 .as_array()
952 .map(|arr| {
953 arr.iter()
954 .filter_map(|item| {
955 Some(crate::providers::WebSearchResultItem {
956 title: item["title"].as_str().map(String::from),
957 url: item["url"].as_str()?.to_string(),
958 encrypted_content: item["encrypted_content"].as_str().map(String::from),
959 snippet: item["snippet"].as_str().map(String::from),
960 })
961 })
962 .collect()
963 })
964 .unwrap_or_default();
965
966 crate::providers::WebSearchContent { results }
967}
968
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames terminated by a CRLF blank line are split off one at a time.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = String::from(concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        ));

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        let second = take_next_sse_frame(&mut buffer).expect("second frame");

        assert!(first.contains("message_start"));
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    /// Leftover bytes without a terminator are still returned as a final frame.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = String::from("data: {\"type\":\"message_stop\"}\r\n");

        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");

        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    /// Both `data: value` and `data:value` are recognised.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let cases = [
            ("event: x\r\ndata: {\"k\":1}\r", "{\"k\":1}"),
            ("event: x\r\ndata:{\"k\":2}\r", "{\"k\":2}"),
        ];
        for (frame, expected) in cases {
            assert_eq!(extract_sse_data_line(frame), Some(expected.to_string()));
        }
    }

    /// A partially streamed text block survives finalization.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let blocks = vec![AssembledBlock::Text("partial reply".to_string())];

        let response = finalize_incomplete_stream(blocks, StopReason::EndTurn, Usage::default());

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}
1025}