1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::models::context_window_for;
9use crate::tools::ToolDefinition;
10
11use super::{
12 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
13 StreamEvent, Usage,
14};
15
16pub struct AnthropicProvider {
17 api_key: String,
18 model: String,
19 base_url: String,
20 client: reqwest::Client,
21 extra_headers: Vec<(String, String)>,
23}
24
25impl AnthropicProvider {
26 pub fn new(api_key: String, model: String, base_url: String) -> Self {
27 Self::with_headers(api_key, model, base_url, None)
28 }
29
30 pub fn with_headers(
31 api_key: String,
32 model: String,
33 base_url: String,
34 extra_headers: Option<std::collections::HashMap<String, String>>,
35 ) -> Self {
36 let client = reqwest::Client::builder()
41 .connect_timeout(std::time::Duration::from_secs(10))
42 .read_timeout(std::time::Duration::from_secs(60))
43 .timeout(std::time::Duration::from_secs(300)) .build()
45 .unwrap_or_else(|_| reqwest::Client::new());
46 let extra_headers: Vec<(String, String)> = extra_headers
47 .map(|h| h.into_iter().collect())
48 .unwrap_or_default();
49 Self {
50 api_key,
51 model,
52 base_url,
53 client,
54 extra_headers,
55 }
56 }
57
58 fn is_official_anthropic(&self) -> bool {
61 self.base_url.contains("api.anthropic.com")
62 }
63
64 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
65 messages
66 .iter()
67 .filter(|m| m.role != Role::System)
68 .map(|m| {
69 let role = match m.role {
70 Role::User | Role::Tool => "user",
71 Role::Assistant => "assistant",
72 Role::System => unreachable!(),
73 };
74
75 let content = match &m.content {
76 MessageContent::Text(text) => json!(text),
77 MessageContent::Blocks(blocks) => {
78 let converted: Vec<Value> = blocks
79 .iter()
80 .map(|b| match b {
81 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
82 ContentBlock::ToolUse { id, name, input } => {
83 json!({"type": "tool_use", "id": id, "name": name, "input": input})
84 }
85 ContentBlock::ToolResult { tool_use_id, content } => {
86 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
87 }
88 ContentBlock::Thinking { thinking, signature } => {
89 let mut obj = json!({"type": "thinking", "thinking": thinking});
90 if let Some(sig) = signature {
91 obj["signature"] = json!(sig);
92 }
93 obj
94 }
95 ContentBlock::ServerToolUse { id, name, input } => {
96 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
97 }
98 ContentBlock::WebSearchResult { tool_use_id, content } => {
99 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
100 }
101 })
102 .collect();
103 json!(converted)
104 }
105 };
106
107 json!({"role": role, "content": content})
108 })
109 .collect()
110 }
111
112 fn convert_tools_with_caching(
114 &self,
115 tools: &[ToolDefinition],
116 enable_caching: bool,
117 ) -> Vec<Value> {
118 let mut converted: Vec<Value> = tools
119 .iter()
120 .map(|t| {
121 json!({
122 "name": t.name,
123 "description": t.description,
124 "input_schema": t.parameters,
125 })
126 })
127 .collect();
128
129 if enable_caching && !converted.is_empty() {
131 let last_idx = converted.len() - 1;
132 if let Some(obj) = converted[last_idx].as_object_mut() {
133 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
134 }
135 }
136
137 converted
138 }
139
140 fn build_body(&self, request: &ChatRequest) -> Value {
142 let mut body = json!({
143 "model": self.model,
144 "max_tokens": request.max_tokens,
145 "messages": self.convert_messages(&request.messages),
146 });
147
148 if request.enable_caching {
150 if let Some(system) = &request.system {
151 body["system"] = json!([
153 {
154 "type": "text",
155 "text": system,
156 "cache_control": {"type": "ephemeral"}
157 }
158 ]);
159 }
160 } else if let Some(system) = &request.system {
161 body["system"] = json!(system);
162 }
163
164 if !request.tools.is_empty() {
165 let tools = self.convert_tools_with_caching(
166 &request.tools,
167 request.enable_caching,
168 );
169 body["tools"] = json!(tools);
170 }
171
172 if !request.server_tools.is_empty() {
173 body["tools"] = json!(
174 body["tools"]
175 .as_array()
176 .map(|t| {
177 let mut tools = t.clone();
178 for st in &request.server_tools {
179 tools.push(serde_json::to_value(st).unwrap_or_default());
180 }
181 tools
182 })
183 .unwrap_or_else(|| request
184 .server_tools
185 .iter()
186 .map(|st| serde_json::to_value(st).unwrap_or_default())
187 .collect())
188 );
189 }
190
191 if request.think {
193 let config = thinking_config(&self.model);
194 log::debug!(
195 "Adding thinking config for model {}: {:?}",
196 self.model,
197 config
198 );
199 body["thinking"] = config;
200 }
201
202 body
203 }
204}
205
206fn thinking_config(model: &str) -> Value {
211 let m = model.to_lowercase();
212 let adaptive = m.contains("opus-4")
214 || m.contains("sonnet-4")
215 || m.contains("claude-4")
216 || m.contains("20250")
217 || m.contains("2025");
218 if adaptive {
219 json!({"type": "enabled", "budget_tokens": 10000})
220 } else {
221 json!({"type": "enabled", "budget_tokens": 5000})
222 }
223}
224
225#[async_trait]
226impl Provider for AnthropicProvider {
227 fn context_size(&self) -> Option<u32> {
228 context_window_for(&self.model)
229 }
230
231 fn clone_box(&self) -> Box<dyn Provider> {
232 Box::new(Self {
233 api_key: self.api_key.clone(),
234 model: self.model.clone(),
235 base_url: self.base_url.clone(),
236 client: reqwest::Client::new(),
237 extra_headers: self.extra_headers.clone(),
238 })
239 }
240
241 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
242 let body = self.build_body(&request);
243
244 let url = format!("{}/v1/messages", self.base_url);
245 let mut req = self
246 .client
247 .post(&url)
248 .header("User-Agent", "curl/8.0")
249 .json(&body);
250
251 if self.is_official_anthropic() {
253 req = req
254 .header("x-api-key", &self.api_key)
255 .header("anthropic-version", "2025-04-15")
256 .header("anthropic-beta", "prompt-caching-2024-07-31");
257 } else {
258 req = req.header("Authorization", format!("Bearer {}", self.api_key));
259 }
260
261 for (name, value) in &self.extra_headers {
263 req = req.header(name, value);
264 }
265
266 let response = req.send().await?;
267
268 let status = response.status();
269 let response_body: Value = response.json().await?;
270
271 if !status.is_success() {
272 let err_msg = response_body["error"]["message"]
273 .as_str()
274 .unwrap_or("unknown error");
275 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
276 }
277
278 let stop_reason = match response_body["stop_reason"].as_str() {
279 Some("tool_use") => StopReason::ToolUse,
280 Some("max_tokens") => StopReason::MaxTokens,
281 _ => StopReason::EndTurn,
282 };
283
284 let content = response_body["content"]
285 .as_array()
286 .unwrap_or(&vec![])
287 .iter()
288 .filter_map(|block| match block["type"].as_str()? {
289 "text" => Some(ContentBlock::Text {
290 text: block["text"].as_str()?.to_string(),
291 }),
292 "tool_use" => Some(ContentBlock::ToolUse {
293 id: block["id"].as_str()?.to_string(),
294 name: block["name"].as_str()?.to_string(),
295 input: block["input"].clone(),
296 }),
297 "thinking" => Some(ContentBlock::Thinking {
298 thinking: block["thinking"].as_str()?.to_string(),
299 signature: block["signature"].as_str().map(String::from),
300 }),
301 "server_tool_use" => Some(ContentBlock::ServerToolUse {
302 id: block["id"].as_str()?.to_string(),
303 name: block["name"].as_str()?.to_string(),
304 input: block["input"].clone(),
305 }),
306 "web_search_tool_result" => {
307 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
308 let content = parse_web_search_content(&block["content"]);
309 Some(ContentBlock::WebSearchResult {
310 tool_use_id,
311 content,
312 })
313 }
314 _ => None,
315 })
316 .collect();
317
318 Ok(ChatResponse {
319 content,
320 stop_reason,
321 usage: parse_usage(&response_body["usage"]),
322 })
323 }
324
325 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
326 let mut body = self.build_body(&request);
327 body["stream"] = json!(true);
328
329 let url = format!("{}/v1/messages", self.base_url);
330 let mut req = self
331 .client
332 .post(&url)
333 .header("User-Agent", "curl/8.0")
334 .json(&body);
335
336 if self.is_official_anthropic() {
338 req = req
339 .header("x-api-key", &self.api_key)
340 .header("anthropic-version", "2025-04-15")
341 .header("anthropic-beta", "prompt-caching-2024-07-31");
342 } else {
343 req = req.header("Authorization", format!("Bearer {}", self.api_key));
344 }
345
346 for (name, value) in &self.extra_headers {
348 req = req.header(name, value);
349 }
350
351 let response = req.send().await?;
352
353 if !response.status().is_success() {
354 let status = response.status();
355 let text = response.text().await.unwrap_or_default();
356 anyhow::bail!("Anthropic API error ({}): {}", status, text);
357 }
358
359 let (tx, rx) = mpsc::channel(64);
360 tokio::spawn(async move {
361 let mut stream = response.bytes_stream();
362 let mut buffer = String::new();
363 let mut sent_first_byte = false;
364
365 let mut blocks: Vec<AssembledBlock> = Vec::new();
367 let mut stop_reason = StopReason::EndTurn;
368 let mut usage = Usage::default();
369
370 while let Some(chunk) = stream.next().await {
371 let chunk = match chunk {
372 Ok(c) => c,
373 Err(e) => {
374 let error_msg = e.to_string();
376 let is_timeout = error_msg.contains("timeout") || error_msg.contains("timed out");
377 let is_decode = error_msg.contains("decode") || error_msg.contains("decoding");
378
379 let msg = if is_timeout {
380 format!("Stream timeout - the API took too long to respond: {}", error_msg)
381 } else if is_decode {
382 format!("Stream decode error - possibly interrupted or corrupted response: {}", error_msg)
383 } else {
384 format!("Stream read error: {}", error_msg)
385 };
386
387 if sent_first_byte && !blocks.is_empty() {
389 debug!("finalizing partial stream due to error");
390 let _ = tx.send(StreamEvent::Done(finalize_incomplete_stream(
391 std::mem::take(&mut blocks),
392 stop_reason,
393 usage,
394 ))).await;
395 } else {
396 let _ = tx.send(StreamEvent::Error(msg)).await;
397 }
398 return;
399 }
400 };
401
402 if !sent_first_byte {
403 sent_first_byte = true;
404 let _ = tx.send(StreamEvent::FirstByte).await;
405 }
406
407 buffer.push_str(&String::from_utf8_lossy(&chunk));
408
409 while let Some(frame) = take_next_sse_frame(&mut buffer) {
410 if handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx)
411 .await
412 {
413 return;
414 }
415 }
416 }
417
418 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
419 && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx).await
420 {
421 return;
422 }
423
424 if sent_first_byte {
425 debug!("stream ended without explicit message_stop; finalizing best-effort");
426 let _ = tx
427 .send(StreamEvent::Done(finalize_incomplete_stream(
428 std::mem::take(&mut blocks),
429 stop_reason,
430 usage,
431 )))
432 .await;
433 } else {
434 let _ = tx
435 .send(StreamEvent::Error(
436 "stream ended before any events were received".to_string(),
437 ))
438 .await;
439 }
440 });
441
442 Ok(rx)
443 }
444}
445
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank-line delimiter, either `"\n\n"` or
/// `"\r\n\r\n"`. The delimiter is consumed but not included in the returned
/// frame. When no delimiter has arrived yet, `None` is returned and `buffer`
/// is left untouched.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    const DELIMITERS: [&str; 2] = ["\n\n", "\r\n\r\n"];

    let (frame_end, delimiter_len) = DELIMITERS
        .into_iter()
        .filter_map(|delimiter| buffer.find(delimiter).map(|at| (at, delimiter.len())))
        .min_by_key(|&(at, _)| at)?;

    let remainder = buffer.split_off(frame_end + delimiter_len);
    let mut frame = std::mem::replace(buffer, remainder);
    frame.truncate(frame_end);
    Some(frame)
}
466
/// Empties `buffer` and returns its contents as a final, unterminated frame.
///
/// Surrounding whitespace, including stray `\r`, is trimmed. Returns `None`
/// when only whitespace remained. Either way, `buffer` is empty afterwards.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    let trimmed = leftover.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
472
/// Returns the `data` payload of one SSE frame, or `None` if it has no `data` field.
///
/// Follows the SSE field rules:
/// - a single space after `data:` is dropped;
/// - when a frame carries several `data:` lines, their values are joined with `\n`
///   (previously only the first line was returned);
/// - a trailing `\r` from CRLF line endings is ignored.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let values: Vec<&str> = frame
        .lines()
        .filter_map(|line| {
            let value = line.trim_end_matches('\r').strip_prefix("data:")?;
            Some(value.strip_prefix(' ').unwrap_or(value))
        })
        .collect();

    if values.is_empty() {
        None
    } else {
        Some(values.join("\n"))
    }
}
486
487async fn handle_sse_frame(
488 frame: &str,
489 blocks: &mut Vec<AssembledBlock>,
490 stop_reason: &mut StopReason,
491 usage: &mut Usage,
492 tx: &mpsc::Sender<StreamEvent>,
493) -> bool {
494 let Some(data_line) = extract_sse_data_line(frame) else {
495 return false;
496 };
497
498 let evt: Value = match serde_json::from_str(&data_line) {
499 Ok(v) => v,
500 Err(_) => return false,
501 };
502
503 handle_sse_event(evt, blocks, stop_reason, usage, tx).await
504}
505
506async fn handle_sse_event(
507 evt: Value,
508 blocks: &mut Vec<AssembledBlock>,
509 stop_reason: &mut StopReason,
510 usage: &mut Usage,
511 tx: &mpsc::Sender<StreamEvent>,
512) -> bool {
513 match evt["type"].as_str().unwrap_or("") {
514 "message_start" => {
515 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
520 debug!(
521 "message_start: usage_json={}",
522 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
523 );
524 debug!(
525 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
526 usage.input_tokens,
527 usage.output_tokens,
528 usage.cache_read_input_tokens,
529 usage.cache_creation_input_tokens
530 );
531 let _ = tx
533 .send(StreamEvent::Usage {
534 output_tokens: usage.output_tokens,
535 })
536 .await;
537 }
538 "content_block_start" => {
539 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
540 let block = &evt["content_block"];
541 let kind = block["type"].as_str().unwrap_or("");
542 while blocks.len() <= idx {
543 blocks.push(AssembledBlock::default());
544 }
545 match kind {
546 "text" => {
547 blocks[idx] = AssembledBlock::Text(String::new());
548 }
549 "thinking" => {
550 blocks[idx] = AssembledBlock::Thinking {
551 text: String::new(),
552 signature: None,
553 };
554 }
555 "tool_use" | "server_tool_use" => {
556 let id = block["id"].as_str().unwrap_or_default();
557 let name = block["name"].as_str().unwrap_or_default();
558 let is_server = kind == "server_tool_use";
559 blocks[idx] = if is_server {
560 AssembledBlock::ServerToolUse {
561 id: id.into(),
562 name: name.into(),
563 input_json: String::new(),
564 }
565 } else {
566 AssembledBlock::ToolUse {
567 id: id.into(),
568 name: name.into(),
569 input_json: String::new(),
570 }
571 };
572 let _ = tx
573 .send(StreamEvent::ToolUseStart {
574 id: id.into(),
575 name: name.into(),
576 })
577 .await;
578 }
579 "web_search_tool_result" => {
580 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
581 blocks[idx] = AssembledBlock::WebSearchResult {
582 tool_use_id,
583 content_json: String::new(),
584 };
585 }
586 _ => {}
587 }
588 }
589 "content_block_delta" => {
590 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
591 let delta = &evt["delta"];
592 let dt = delta["type"].as_str().unwrap_or("");
593 if idx >= blocks.len() {
594 return false;
595 }
596 match (dt, &mut blocks[idx]) {
597 ("text_delta", AssembledBlock::Text(buf)) => {
598 if let Some(t) = delta["text"].as_str() {
599 buf.push_str(t);
600 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
601 }
602 }
603 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
604 if let Some(t) = delta["thinking"].as_str() {
605 text.push_str(t);
606 log::debug!("Received thinking_delta: {} chars", t.len());
607 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
608 }
609 }
610 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
611 if let Some(s) = delta["signature"].as_str() {
612 signature.get_or_insert_with(String::new).push_str(s);
613 }
614 }
615 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
616 if let Some(p) = delta["partial_json"].as_str() {
617 input_json.push_str(p);
618 let _ = tx
619 .send(StreamEvent::ToolInputDelta {
620 bytes_so_far: input_json.len(),
621 })
622 .await;
623 }
624 }
625 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
626 if let Some(p) = delta["partial_json"].as_str() {
627 input_json.push_str(p);
628 let _ = tx
629 .send(StreamEvent::ToolInputDelta {
630 bytes_so_far: input_json.len(),
631 })
632 .await;
633 }
634 }
635 _ => {}
636 }
637 }
638 "message_delta" => {
639 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
640 *stop_reason = match sr {
641 "tool_use" => StopReason::ToolUse,
642 "max_tokens" => StopReason::MaxTokens,
643 _ => StopReason::EndTurn,
644 };
645 }
646 *usage = merge_usage(usage.clone(), &evt["usage"]);
650 debug!(
651 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
652 usage.input_tokens,
653 usage.output_tokens,
654 usage.cache_read_input_tokens,
655 usage.cache_creation_input_tokens
656 );
657 let _ = tx
659 .send(StreamEvent::Usage {
660 output_tokens: usage.output_tokens,
661 })
662 .await;
663 }
664 "message_stop" => {
665 debug!(
666 "Message completed: stop_reason={}, usage={}",
667 match *stop_reason {
668 StopReason::EndTurn => "end_turn",
669 StopReason::ToolUse => "tool_use",
670 StopReason::MaxTokens => "max_tokens",
671 },
672 usage.output_tokens
673 );
674 debug!(
675 "message_stop final usage: cache_read={}, cache_created={}",
676 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
677 );
678 let _ = tx
679 .send(StreamEvent::Done(finalize_incomplete_stream(
680 std::mem::take(blocks),
681 stop_reason.clone(),
682 usage.clone(),
683 )))
684 .await;
685 return true;
686 }
687 "error" => {
688 let msg = evt["error"]["message"]
689 .as_str()
690 .unwrap_or("unknown stream error")
691 .to_string();
692 let _ = tx.send(StreamEvent::Error(msg)).await;
693 return true;
694 }
695 _ => {}
696 }
697
698 false
699}
700
701fn build_chat_response(
702 blocks: Vec<AssembledBlock>,
703 stop_reason: StopReason,
704 usage: Usage,
705) -> ChatResponse {
706 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
707 ChatResponse {
708 content,
709 stop_reason,
710 usage,
711 }
712}
713
714fn finalize_incomplete_stream(
715 blocks: Vec<AssembledBlock>,
716 stop_reason: StopReason,
717 usage: Usage,
718) -> ChatResponse {
719 build_chat_response(blocks, stop_reason, usage)
720}
721
/// A content block being rebuilt from streaming events.
///
/// Tool inputs and web search results are kept as raw JSON text while
/// streaming and parsed only when the block is finished.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index that has not received a recognised
    /// `content_block_start`.
    #[default]
    Empty,
    /// Assistant text accumulated from `text_delta` events.
    Text(String),
    /// Thinking text plus its signature, if one was streamed.
    Thinking { text: String, signature: Option<String> },
    /// Client tool call; `input_json` collects `input_json_delta` fragments.
    ToolUse { id: String, name: String, input_json: String },
    /// `server_tool_use` call; accumulated the same way as `ToolUse`.
    ServerToolUse { id: String, name: String, input_json: String },
    /// `web_search_tool_result` content as raw JSON text.
    WebSearchResult { tool_use_id: String, content_json: String },
}
746
747impl AssembledBlock {
748 fn finish(self) -> Option<ContentBlock> {
749 match self {
750 AssembledBlock::Empty => None,
751 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
752 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
753 thinking: text,
754 signature,
755 }),
756 AssembledBlock::ToolUse {
757 id,
758 name,
759 input_json,
760 } => {
761 let input: Value = if input_json.is_empty() {
762 json!({})
763 } else {
764 serde_json::from_str(&input_json).unwrap_or(json!({}))
765 };
766 Some(ContentBlock::ToolUse { id, name, input })
767 }
768 AssembledBlock::ServerToolUse {
769 id,
770 name,
771 input_json,
772 } => {
773 let input: Value = if input_json.is_empty() {
774 json!({})
775 } else {
776 serde_json::from_str(&input_json).unwrap_or(json!({}))
777 };
778 Some(ContentBlock::ServerToolUse { id, name, input })
779 }
780 AssembledBlock::WebSearchResult {
781 tool_use_id,
782 content_json,
783 } => {
784 let content: Value = if content_json.is_empty() {
785 json!({"results": []})
786 } else {
787 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
788 };
789 Some(ContentBlock::WebSearchResult {
790 tool_use_id,
791 content: parse_web_search_content(&content),
792 })
793 }
794 }
795 }
796}
797
798fn parse_usage(u: &Value) -> Usage {
801 Usage {
802 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
803 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
804 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
805 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
806 }
807}
808
809fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
813 let fresh = parse_usage(u);
814 if fresh.input_tokens > 0 {
815 acc.input_tokens = fresh.input_tokens;
816 }
817 if fresh.output_tokens > 0 {
818 acc.output_tokens = fresh.output_tokens;
819 }
820 if fresh.cache_creation_input_tokens > 0 {
821 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
822 }
823 if fresh.cache_read_input_tokens > 0 {
824 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
825 }
826 acc
827}
828
829fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
831 let results = value["results"]
832 .as_array()
833 .map(|arr| {
834 arr.iter()
835 .filter_map(|item| {
836 Some(crate::providers::WebSearchResultItem {
837 title: item["title"].as_str().map(String::from),
838 url: item["url"].as_str()?.to_string(),
839 encrypted_content: item["encrypted_content"].as_str().map(String::from),
840 snippet: item["snippet"].as_str().map(String::from),
841 })
842 })
843 .collect()
844 })
845 .unwrap_or_default();
846
847 crate::providers::WebSearchContent { results }
848}
849
#[cfg(test)]
mod tests {
    use super::*;

    /// CRLF-delimited frames come out one at a time and are fully consumed.
    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut pending = String::new();
        pending.push_str("event: message_start\r\n");
        pending.push_str("data: {\"type\":\"message_start\"}\r\n\r\n");
        pending.push_str("data: {\"type\":\"message_stop\"}\r\n\r\n");

        let start = take_next_sse_frame(&mut pending).expect("first frame");
        let stop = take_next_sse_frame(&mut pending).expect("second frame");

        assert!(start.contains("message_start"));
        assert!(stop.contains("message_stop"));
        assert!(pending.is_empty());
    }

    /// Data left without a closing blank line is still recovered at EOF.
    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut pending = String::from("data: {\"type\":\"message_stop\"}\r\n");
        let frame = take_trailing_sse_frame(&mut pending);
        assert_eq!(frame.as_deref(), Some("data: {\"type\":\"message_stop\"}"));
        assert!(pending.is_empty());
    }

    /// `data:` is recognised with and without the optional space.
    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        let cases = [
            ("event: x\r\ndata: {\"k\":1}\r", "{\"k\":1}"),
            ("event: x\r\ndata:{\"k\":2}\r", "{\"k\":2}"),
        ];
        for (frame, expected) in cases {
            assert_eq!(extract_sse_data_line(frame).as_deref(), Some(expected));
        }
    }

    /// Finalizing an interrupted stream keeps the text streamed so far.
    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        let [only] = response.content.as_slice() else {
            panic!("expected exactly one block, got {}", response.content.len());
        };
        match only {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }
}
906}