1use anyhow::Result;
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use log::debug;
5use serde_json::{Value, json};
6use tokio::sync::mpsc;
7
8use crate::tools::ToolDefinition;
9
10use super::{
11 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
12 StreamEvent, Usage,
13};
14
15pub struct AnthropicProvider {
16 api_key: String,
17 model: String,
18 base_url: String,
19 client: reqwest::Client,
20 is_dashscope: bool,
22}
23
24impl AnthropicProvider {
25 pub fn new(api_key: String, model: String, base_url: String) -> Self {
26 let is_dashscope = base_url.contains("dashscope.aliyuncs.com");
27 let client = reqwest::Client::builder()
28 .danger_accept_invalid_certs(true)
29 .build()
30 .unwrap_or_else(|_| reqwest::Client::new());
31 Self {
32 api_key,
33 model,
34 base_url,
35 client,
36 is_dashscope,
37 }
38 }
39
40 fn convert_messages(&self, messages: &[Message]) -> Vec<Value> {
41 messages
42 .iter()
43 .filter(|m| m.role != Role::System)
44 .map(|m| {
45 let role = match m.role {
46 Role::User | Role::Tool => "user",
47 Role::Assistant => "assistant",
48 Role::System => unreachable!(),
49 };
50
51 let content = match &m.content {
52 MessageContent::Text(text) => json!(text),
53 MessageContent::Blocks(blocks) => {
54 let converted: Vec<Value> = blocks
55 .iter()
56 .map(|b| match b {
57 ContentBlock::Text { text } => json!({"type": "text", "text": text}),
58 ContentBlock::ToolUse { id, name, input } => {
59 json!({"type": "tool_use", "id": id, "name": name, "input": input})
60 }
61 ContentBlock::ToolResult { tool_use_id, content } => {
62 json!({"type": "tool_result", "tool_use_id": tool_use_id, "content": content})
63 }
64 ContentBlock::Thinking { thinking, signature } => {
65 let mut obj = json!({"type": "thinking", "thinking": thinking});
66 if let Some(sig) = signature {
67 obj["signature"] = json!(sig);
68 }
69 obj
70 }
71 ContentBlock::ServerToolUse { id, name, input } => {
72 json!({"type": "server_tool_use", "id": id, "name": name, "input": input})
73 }
74 ContentBlock::WebSearchResult { tool_use_id, content } => {
75 json!({"type": "web_search_tool_result", "tool_use_id": tool_use_id, "content": content})
76 }
77 })
78 .collect();
79 json!(converted)
80 }
81 };
82
83 json!({"role": role, "content": content})
84 })
85 .collect()
86 }
87
88 fn convert_tools_with_caching(&self, tools: &[ToolDefinition], enable_caching: bool) -> Vec<Value> {
90 let mut converted: Vec<Value> = tools
91 .iter()
92 .map(|t| {
93 json!({
94 "name": t.name,
95 "description": t.description,
96 "input_schema": t.parameters,
97 })
98 })
99 .collect();
100
101 if enable_caching && !converted.is_empty() {
103 let last_idx = converted.len() - 1;
104 if let Some(obj) = converted[last_idx].as_object_mut() {
105 obj.insert("cache_control".to_string(), json!({"type": "ephemeral"}));
106 }
107 }
108
109 converted
110 }
111
112 fn build_body(&self, request: &ChatRequest) -> Value {
114 let mut body = json!({
115 "model": self.model,
116 "max_tokens": request.max_tokens,
117 "messages": self.convert_messages(&request.messages),
118 });
119
120 if request.enable_caching && !self.is_dashscope {
122 if let Some(system) = &request.system {
123 body["system"] = json!([
125 {
126 "type": "text",
127 "text": system,
128 "cache_control": {"type": "ephemeral"}
129 }
130 ]);
131 }
132 } else if let Some(system) = &request.system {
133 body["system"] = json!(system);
134 }
135
136 if !request.tools.is_empty() {
137 let tools = self.convert_tools_with_caching(&request.tools, request.enable_caching && !self.is_dashscope);
138 body["tools"] = json!(tools);
139 }
140
141 if !request.server_tools.is_empty() {
142 body["tools"] = json!(body["tools"]
143 .as_array()
144 .map(|t| {
145 let mut tools = t.clone();
146 for st in &request.server_tools {
147 tools.push(serde_json::to_value(st).unwrap_or_default());
148 }
149 tools
150 })
151 .unwrap_or_else(|| request.server_tools.iter().map(|st| serde_json::to_value(st).unwrap_or_default()).collect()));
152 }
153
154 if request.think && !self.is_dashscope {
156 let config = thinking_config(&self.model);
157 log::debug!("Adding thinking config for model {}: {:?}", self.model, config);
158 body["thinking"] = config;
159 } else if !request.think {
160 log::debug!("Thinking disabled by request.think=false");
161 } else if self.is_dashscope {
162 log::debug!("Thinking disabled for DashScope");
163 }
164
165 body
166 }
167}
168
169fn thinking_config(model: &str) -> Value {
174 let m = model.to_lowercase();
175 let adaptive = m.contains("opus-4") || m.contains("sonnet-4") || m.contains("claude-4")
177 || m.contains("20250") || m.contains("2025");
178 if adaptive {
179 json!({"type": "enabled", "budget_tokens": 10000})
180 } else {
181 json!({"type": "enabled", "budget_tokens": 5000})
182 }
183}
184
185#[async_trait]
186impl Provider for AnthropicProvider {
187 fn context_size(&self) -> Option<u32> {
188 context_window_for(&self.model)
189 }
190
191 fn clone_box(&self) -> Box<dyn Provider> {
192 Box::new(Self {
193 api_key: self.api_key.clone(),
194 model: self.model.clone(),
195 base_url: self.base_url.clone(),
196 client: reqwest::Client::new(),
197 is_dashscope: self.is_dashscope,
198 })
199 }
200
201 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
202 let body = self.build_body(&request);
203
204 let url = format!("{}/v1/messages", self.base_url);
205 let mut req = self
206 .client
207 .post(&url)
208 .header("User-Agent", "curl/8.0")
209 .json(&body);
210
211 if self.is_dashscope {
213 req = req.header("Authorization", format!("Bearer {}", self.api_key));
214 } else {
215 req = req.header("x-api-key", &self.api_key)
216 .header("anthropic-version", "2025-04-15")
217 .header("anthropic-beta", "prompt-caching-2024-07-31");
218 }
219
220 let response = req.send().await?;
221
222 let status = response.status();
223 let response_body: Value = response.json().await?;
224
225 if !status.is_success() {
226 let err_msg = response_body["error"]["message"]
227 .as_str()
228 .unwrap_or("unknown error");
229 anyhow::bail!("Anthropic API error ({}): {}", status, err_msg);
230 }
231
232 let stop_reason = match response_body["stop_reason"].as_str() {
233 Some("tool_use") => StopReason::ToolUse,
234 Some("max_tokens") => StopReason::MaxTokens,
235 _ => StopReason::EndTurn,
236 };
237
238 let content = response_body["content"]
239 .as_array()
240 .unwrap_or(&vec![])
241 .iter()
242 .filter_map(|block| match block["type"].as_str()? {
243 "text" => Some(ContentBlock::Text {
244 text: block["text"].as_str()?.to_string(),
245 }),
246 "tool_use" => Some(ContentBlock::ToolUse {
247 id: block["id"].as_str()?.to_string(),
248 name: block["name"].as_str()?.to_string(),
249 input: block["input"].clone(),
250 }),
251 "thinking" => Some(ContentBlock::Thinking {
252 thinking: block["thinking"].as_str()?.to_string(),
253 signature: block["signature"].as_str().map(String::from),
254 }),
255 "server_tool_use" => Some(ContentBlock::ServerToolUse {
256 id: block["id"].as_str()?.to_string(),
257 name: block["name"].as_str()?.to_string(),
258 input: block["input"].clone(),
259 }),
260 "web_search_tool_result" => {
261 let tool_use_id = block["tool_use_id"].as_str()?.to_string();
262 let content = parse_web_search_content(&block["content"]);
263 Some(ContentBlock::WebSearchResult {
264 tool_use_id,
265 content,
266 })
267 }
268 _ => None,
269 })
270 .collect();
271
272 Ok(ChatResponse {
273 content,
274 stop_reason,
275 usage: parse_usage(&response_body["usage"]),
276 })
277 }
278
279 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
280 let mut body = self.build_body(&request);
281 body["stream"] = json!(true);
282
283 let url = format!("{}/v1/messages", self.base_url);
284 let mut req = self
285 .client
286 .post(&url)
287 .header("User-Agent", "curl/8.0")
288 .json(&body);
289
290 if self.is_dashscope {
292 req = req
293 .header("Authorization", format!("Bearer {}", self.api_key))
294 .header("X-DashScope-SSE", "enable");
295 } else {
296 req = req.header("x-api-key", &self.api_key)
297 .header("anthropic-version", "2025-04-15")
298 .header("anthropic-beta", "prompt-caching-2024-07-31");
299 }
300
301 let response = req.send().await?;
302
303 if !response.status().is_success() {
304 let status = response.status();
305 let text = response.text().await.unwrap_or_default();
306 anyhow::bail!("Anthropic API error ({}): {}", status, text);
307 }
308
309 let (tx, rx) = mpsc::channel(64);
310 tokio::spawn(async move {
311 let mut stream = response.bytes_stream();
312 let mut buffer = String::new();
313 let mut sent_first_byte = false;
314
315 let mut blocks: Vec<AssembledBlock> = Vec::new();
317 let mut stop_reason = StopReason::EndTurn;
318 let mut usage = Usage::default();
319
320 while let Some(chunk) = stream.next().await {
321 let chunk = match chunk {
322 Ok(c) => c,
323 Err(e) => {
324 let _ = tx
325 .send(StreamEvent::Error(format!("stream read error: {}", e)))
326 .await;
327 return;
328 }
329 };
330
331 if !sent_first_byte {
332 sent_first_byte = true;
333 let _ = tx.send(StreamEvent::FirstByte).await;
334 }
335
336 buffer.push_str(&String::from_utf8_lossy(&chunk));
337
338 while let Some(frame) = take_next_sse_frame(&mut buffer) {
339 if handle_sse_frame(
340 &frame,
341 &mut blocks,
342 &mut stop_reason,
343 &mut usage,
344 &tx,
345 )
346 .await
347 {
348 return;
349 }
350 }
351 }
352
353 if let Some(frame) = take_trailing_sse_frame(&mut buffer)
354 && handle_sse_frame(&frame, &mut blocks, &mut stop_reason, &mut usage, &tx).await {
355 return;
356 }
357
358 if sent_first_byte {
359 debug!(
360 "stream ended without explicit message_stop; finalizing best-effort"
361 );
362 let _ = tx
363 .send(StreamEvent::Done(finalize_incomplete_stream(
364 std::mem::take(&mut blocks),
365 stop_reason,
366 usage,
367 )))
368 .await;
369 } else {
370 let _ = tx
371 .send(StreamEvent::Error(
372 "stream ended before any events were received".to_string(),
373 ))
374 .await;
375 }
376 });
377
378 Ok(rx)
379 }
380}
381
/// Removes the first complete SSE frame from the front of `buffer`.
///
/// A frame is the text before the first blank line; the blank-line delimiter
/// is consumed too. Both `\n\n` and `\r\n\r\n` are recognised, and whichever
/// occurs first wins.
///
/// Returns `None`, leaving `buffer` untouched, when no full frame has arrived yet.
fn take_next_sse_frame(buffer: &mut String) -> Option<String> {
    let (frame_len, delimiter_len) = ["\n\n", "\r\n\r\n"]
        .iter()
        .filter_map(|&delimiter| buffer.find(delimiter).map(|at| (at, delimiter.len())))
        .min_by_key(|&(at, _)| at)?;

    // Split after the delimiter, keep the tail in `buffer`, then cut the
    // delimiter off the detached head.
    let remainder = buffer.split_off(frame_len + delimiter_len);
    let mut frame = std::mem::replace(buffer, remainder);
    frame.truncate(frame_len);
    Some(frame)
}
402
/// Empties `buffer` once the byte stream has ended and returns its contents as
/// a final frame.
///
/// This covers a server that closes the connection without a terminating blank
/// line. Surrounding whitespace is trimmed, and `None` is returned when nothing
/// else remains. `buffer` is always left empty.
fn take_trailing_sse_frame(buffer: &mut String) -> Option<String> {
    let leftover = std::mem::take(buffer);
    // `trim` already strips trailing `\r` (it is whitespace), so CRLF remnants vanish here.
    let frame = leftover.trim();
    (!frame.is_empty()).then(|| frame.to_owned())
}
412
/// Returns the data payload of an SSE frame, or `None` if it has no `data:` field.
///
/// As the SSE specification requires, a single space after `data:` is dropped.
/// Multiple `data:` lines are joined with `\n`; previously only the first was
/// kept, which truncated multi-line payloads. A stray trailing `\r` on a line
/// is ignored.
fn extract_sse_data_line(frame: &str) -> Option<String> {
    let mut data: Option<String> = None;
    for line in frame.lines() {
        let line = line.trim_end_matches('\r');
        let Some(value) = line.strip_prefix("data:") else {
            continue;
        };
        let value = value.strip_prefix(' ').unwrap_or(value);
        match data.as_mut() {
            Some(buf) => {
                buf.push('\n');
                buf.push_str(value);
            }
            None => data = Some(value.to_string()),
        }
    }
    data
}
426
427async fn handle_sse_frame(
428 frame: &str,
429 blocks: &mut Vec<AssembledBlock>,
430 stop_reason: &mut StopReason,
431 usage: &mut Usage,
432 tx: &mpsc::Sender<StreamEvent>,
433) -> bool {
434 let Some(data_line) = extract_sse_data_line(frame) else {
435 return false;
436 };
437
438 let evt: Value = match serde_json::from_str(&data_line) {
439 Ok(v) => v,
440 Err(_) => return false,
441 };
442
443 handle_sse_event(evt, blocks, stop_reason, usage, tx).await
444}
445
446async fn handle_sse_event(
447 evt: Value,
448 blocks: &mut Vec<AssembledBlock>,
449 stop_reason: &mut StopReason,
450 usage: &mut Usage,
451 tx: &mpsc::Sender<StreamEvent>,
452) -> bool {
453 match evt["type"].as_str().unwrap_or("") {
454 "message_start" => {
455 *usage = merge_usage(usage.clone(), &evt["message"]["usage"]);
460 debug!(
461 "message_start: usage_json={}",
462 serde_json::to_string(&evt["message"]["usage"]).unwrap_or_default()
463 );
464 debug!(
465 "message_start parsed: input={}, output={}, cache_read={}, cache_created={}",
466 usage.input_tokens, usage.output_tokens,
467 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
468 );
469 let _ = tx.send(StreamEvent::Usage { output_tokens: usage.output_tokens }).await;
471 }
472 "content_block_start" => {
473 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
474 let block = &evt["content_block"];
475 let kind = block["type"].as_str().unwrap_or("");
476 while blocks.len() <= idx {
477 blocks.push(AssembledBlock::default());
478 }
479 match kind {
480 "text" => {
481 blocks[idx] = AssembledBlock::Text(String::new());
482 }
483 "thinking" => {
484 blocks[idx] = AssembledBlock::Thinking {
485 text: String::new(),
486 signature: None,
487 };
488 }
489 "tool_use" => {
490 let id = block["id"].as_str().unwrap_or("").to_string();
491 let name = block["name"].as_str().unwrap_or("").to_string();
492 blocks[idx] = AssembledBlock::ToolUse {
493 id: id.clone(),
494 name: name.clone(),
495 input_json: String::new(),
496 };
497 let _ = tx.send(StreamEvent::ToolUseStart { id, name }).await;
498 }
499 "server_tool_use" => {
500 let id = block["id"].as_str().unwrap_or("").to_string();
501 let name = block["name"].as_str().unwrap_or("").to_string();
502 blocks[idx] = AssembledBlock::ServerToolUse {
503 id: id.clone(),
504 name: name.clone(),
505 input_json: String::new(),
506 };
507 let _ = tx.send(StreamEvent::ToolUseStart { id, name }).await;
508 }
509 "web_search_tool_result" => {
510 let tool_use_id = block["tool_use_id"].as_str().unwrap_or("").to_string();
511 blocks[idx] = AssembledBlock::WebSearchResult {
512 tool_use_id,
513 content_json: String::new(),
514 };
515 }
516 _ => {}
517 }
518 }
519 "content_block_delta" => {
520 let idx = evt["index"].as_u64().unwrap_or(0) as usize;
521 let delta = &evt["delta"];
522 let dt = delta["type"].as_str().unwrap_or("");
523 if idx >= blocks.len() {
524 return false;
525 }
526 match (dt, &mut blocks[idx]) {
527 ("text_delta", AssembledBlock::Text(buf)) => {
528 if let Some(t) = delta["text"].as_str() {
529 buf.push_str(t);
530 let _ = tx.send(StreamEvent::TextDelta(t.to_string())).await;
531 }
532 }
533 ("thinking_delta", AssembledBlock::Thinking { text, .. }) => {
534 if let Some(t) = delta["thinking"].as_str() {
535 text.push_str(t);
536 log::debug!("Received thinking_delta: {} chars", t.len());
537 let _ = tx.send(StreamEvent::ThinkingDelta(t.to_string())).await;
538 }
539 }
540 ("signature_delta", AssembledBlock::Thinking { signature, .. }) => {
541 if let Some(s) = delta["signature"].as_str() {
542 signature.get_or_insert_with(String::new).push_str(s);
543 }
544 }
545 ("input_json_delta", AssembledBlock::ToolUse { input_json, .. }) => {
546 if let Some(p) = delta["partial_json"].as_str() {
547 input_json.push_str(p);
548 let _ = tx
549 .send(StreamEvent::ToolInputDelta {
550 bytes_so_far: input_json.len(),
551 })
552 .await;
553 }
554 }
555 ("input_json_delta", AssembledBlock::ServerToolUse { input_json, .. }) => {
556 if let Some(p) = delta["partial_json"].as_str() {
557 input_json.push_str(p);
558 let _ = tx
559 .send(StreamEvent::ToolInputDelta {
560 bytes_so_far: input_json.len(),
561 })
562 .await;
563 }
564 }
565 _ => {}
566 }
567 }
568 "message_delta" => {
569 if let Some(sr) = evt["delta"]["stop_reason"].as_str() {
570 *stop_reason = match sr {
571 "tool_use" => StopReason::ToolUse,
572 "max_tokens" => StopReason::MaxTokens,
573 _ => StopReason::EndTurn,
574 };
575 }
576 *usage = merge_usage(usage.clone(), &evt["usage"]);
580 debug!(
581 "message_delta: input={}, output={}, cache_read={}, cache_created={}",
582 usage.input_tokens, usage.output_tokens,
583 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
584 );
585 let _ = tx.send(StreamEvent::Usage { output_tokens: usage.output_tokens }).await;
587 }
588 "message_stop" => {
589 debug!(
590 "Message completed: stop_reason={}, usage={}",
591 match *stop_reason {
592 StopReason::EndTurn => "end_turn",
593 StopReason::ToolUse => "tool_use",
594 StopReason::MaxTokens => "max_tokens",
595 },
596 usage.output_tokens
597 );
598 debug!(
599 "message_stop final usage: cache_read={}, cache_created={}",
600 usage.cache_read_input_tokens, usage.cache_creation_input_tokens
601 );
602 let _ = tx
603 .send(StreamEvent::Done(finalize_incomplete_stream(
604 std::mem::take(blocks),
605 stop_reason.clone(),
606 usage.clone(),
607 )))
608 .await;
609 return true;
610 }
611 "error" => {
612 let msg = evt["error"]["message"]
613 .as_str()
614 .unwrap_or("unknown stream error")
615 .to_string();
616 let _ = tx.send(StreamEvent::Error(msg)).await;
617 return true;
618 }
619 _ => {}
620 }
621
622 false
623}
624
625fn build_chat_response(
626 blocks: Vec<AssembledBlock>,
627 stop_reason: StopReason,
628 usage: Usage,
629) -> ChatResponse {
630 let content: Vec<ContentBlock> = blocks.into_iter().filter_map(|b| b.finish()).collect();
631 ChatResponse {
632 content,
633 stop_reason,
634 usage,
635 }
636}
637
638fn finalize_incomplete_stream(
639 blocks: Vec<AssembledBlock>,
640 stop_reason: StopReason,
641 usage: Usage,
642) -> ChatResponse {
643 build_chat_response(blocks, stop_reason, usage)
644}
645
/// Content block being rebuilt from streaming events, one per stream `index`.
#[derive(Default)]
enum AssembledBlock {
    /// Placeholder for an index without a recognised `content_block_start`;
    /// produces no output.
    #[default]
    Empty,
    /// Text accumulated from `text_delta` events.
    Text(String),
    /// Text from `thinking_delta` plus the concatenated `signature_delta` pieces.
    Thinking { text: String, signature: Option<String> },
    /// Client tool call; `input_json` is the raw concatenation of `input_json_delta` fragments.
    ToolUse { id: String, name: String, input_json: String },
    /// Server-side tool call, accumulated the same way as `ToolUse`.
    ServerToolUse { id: String, name: String, input_json: String },
    /// Web-search results kept as raw JSON text until the block is finished.
    WebSearchResult { tool_use_id: String, content_json: String },
}
670
671impl AssembledBlock {
672 fn finish(self) -> Option<ContentBlock> {
673 match self {
674 AssembledBlock::Empty => None,
675 AssembledBlock::Text(text) => Some(ContentBlock::Text { text }),
676 AssembledBlock::Thinking { text, signature } => Some(ContentBlock::Thinking {
677 thinking: text,
678 signature,
679 }),
680 AssembledBlock::ToolUse {
681 id,
682 name,
683 input_json,
684 } => {
685 let input: Value = if input_json.is_empty() {
686 json!({})
687 } else {
688 serde_json::from_str(&input_json).unwrap_or(json!({}))
689 };
690 Some(ContentBlock::ToolUse { id, name, input })
691 }
692 AssembledBlock::ServerToolUse {
693 id,
694 name,
695 input_json,
696 } => {
697 let input: Value = if input_json.is_empty() {
698 json!({})
699 } else {
700 serde_json::from_str(&input_json).unwrap_or(json!({}))
701 };
702 Some(ContentBlock::ServerToolUse { id, name, input })
703 }
704 AssembledBlock::WebSearchResult {
705 tool_use_id,
706 content_json,
707 } => {
708 let content: Value = if content_json.is_empty() {
709 json!({"results": []})
710 } else {
711 serde_json::from_str(&content_json).unwrap_or(json!({"results": []}))
712 };
713 Some(ContentBlock::WebSearchResult {
714 tool_use_id,
715 content: parse_web_search_content(&content),
716 })
717 }
718 }
719 }
720}
721
722fn parse_usage(u: &Value) -> Usage {
725 Usage {
726 input_tokens: u["input_tokens"].as_u64().unwrap_or(0) as u32,
727 output_tokens: u["output_tokens"].as_u64().unwrap_or(0) as u32,
728 cache_creation_input_tokens: u["cache_creation_input_tokens"].as_u64().unwrap_or(0) as u32,
729 cache_read_input_tokens: u["cache_read_input_tokens"].as_u64().unwrap_or(0) as u32,
730 }
731}
732
733fn merge_usage(mut acc: Usage, u: &Value) -> Usage {
737 let fresh = parse_usage(u);
738 if fresh.input_tokens > 0 {
739 acc.input_tokens = fresh.input_tokens;
740 }
741 if fresh.output_tokens > 0 {
742 acc.output_tokens = fresh.output_tokens;
743 }
744 if fresh.cache_creation_input_tokens > 0 {
745 acc.cache_creation_input_tokens = fresh.cache_creation_input_tokens;
746 }
747 if fresh.cache_read_input_tokens > 0 {
748 acc.cache_read_input_tokens = fresh.cache_read_input_tokens;
749 }
750 acc
751}
752
753fn parse_web_search_content(value: &serde_json::Value) -> crate::providers::WebSearchContent {
755 let results = value["results"]
756 .as_array()
757 .map(|arr| {
758 arr.iter()
759 .filter_map(|item| {
760 Some(crate::providers::WebSearchResultItem {
761 title: item["title"].as_str().map(String::from),
762 url: item["url"].as_str()?.to_string(),
763 encrypted_content: item["encrypted_content"].as_str().map(String::from),
764 snippet: item["snippet"].as_str().map(String::from),
765 })
766 })
767 .collect()
768 })
769 .unwrap_or_default();
770
771 crate::providers::WebSearchContent { results }
772}
773
/// Best-effort context-window size, in tokens, for `model`.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable overrides the
/// lookup. Otherwise the lower-cased model id is matched by substring, and the
/// first matching rule wins. Unknown models fall back to 128 000. Never returns
/// `None`.
fn context_window_for(model: &str) -> Option<u32> {
    if let Some(n) = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0)
    {
        return Some(n);
    }

    fn any_of(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    let m = model.to_ascii_lowercase();
    let tokens = if any_of(&m, &["[1m]", "opus-4-7", "opus-4.7"]) {
        1_000_000
    } else if any_of(
        &m,
        &["claude-3", "claude-4", "claude-opus", "claude-sonnet", "claude-haiku"],
    ) {
        // `claude-haiku` catches Haiku 4.x ids (e.g. `claude-haiku-4-5`), which
        // matched none of the other Claude patterns and fell through to 128k.
        200_000
    } else if any_of(&m, &["claude-2", "claude-instant"]) {
        100_000
    } else if any_of(&m, &["kimi", "deepseek", "glm"]) {
        // Checked before `qwen` so ids such as `deepseek-...-qwen-...` stay at 128k.
        128_000
    } else if m.contains("qwen") {
        if any_of(&m, &["qwen-max", "qwen2.5"]) {
            128_000
        } else {
            32_000
        }
    } else {
        128_000
    };
    Some(tokens)
}
827
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn take_next_sse_frame_supports_crlf_delimiter() {
        let mut buffer = concat!(
            "event: message_start\r\n",
            "data: {\"type\":\"message_start\"}\r\n\r\n",
            "data: {\"type\":\"message_stop\"}\r\n\r\n"
        )
        .to_string();

        let first = take_next_sse_frame(&mut buffer).expect("first frame");
        assert!(first.contains("message_start"));

        let second = take_next_sse_frame(&mut buffer).expect("second frame");
        assert!(second.contains("message_stop"));
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_next_sse_frame_supports_lf_delimiter() {
        let mut buffer = "data: {\"type\":\"ping\"}\n\ndata: partial".to_string();
        let frame = take_next_sse_frame(&mut buffer).expect("lf frame");
        assert_eq!(frame, "data: {\"type\":\"ping\"}");
        assert_eq!(buffer, "data: partial");
    }

    #[test]
    fn take_next_sse_frame_leaves_incomplete_frame_buffered() {
        let mut buffer = "data: {\"type\":\"ping\"}\r\n".to_string();
        assert!(take_next_sse_frame(&mut buffer).is_none());
        assert_eq!(buffer, "data: {\"type\":\"ping\"}\r\n");
    }

    #[test]
    fn take_trailing_sse_frame_returns_unterminated_event() {
        let mut buffer = "data: {\"type\":\"message_stop\"}\r\n".to_string();
        let frame = take_trailing_sse_frame(&mut buffer).expect("trailing frame");
        assert_eq!(frame, "data: {\"type\":\"message_stop\"}");
        assert!(buffer.is_empty());
    }

    #[test]
    fn take_trailing_sse_frame_ignores_whitespace_only_buffer() {
        let mut buffer = "\r\n \n".to_string();
        assert!(take_trailing_sse_frame(&mut buffer).is_none());
        assert!(buffer.is_empty());
    }

    #[test]
    fn extract_sse_data_line_supports_optional_space() {
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata: {\"k\":1}\r"),
            Some("{\"k\":1}".to_string())
        );
        assert_eq!(
            extract_sse_data_line("event: x\r\ndata:{\"k\":2}\r"),
            Some("{\"k\":2}".to_string())
        );
    }

    #[test]
    fn extract_sse_data_line_ignores_frames_without_data() {
        assert_eq!(extract_sse_data_line("event: ping\r\n: keep-alive"), None);
    }

    #[test]
    fn finalize_incomplete_stream_keeps_accumulated_content() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Text("partial reply".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );

        assert_eq!(response.stop_reason, StopReason::EndTurn);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::Text { text } => assert_eq!(text, "partial reply"),
            other => panic!("unexpected block: {other:?}"),
        }
    }

    #[test]
    fn finalize_incomplete_stream_drops_empty_placeholders() {
        let response = finalize_incomplete_stream(
            vec![AssembledBlock::Empty, AssembledBlock::Text("x".to_string())],
            StopReason::EndTurn,
            Usage::default(),
        );
        assert_eq!(response.content.len(), 1);
    }

    #[test]
    fn merge_usage_keeps_counts_missing_from_update() {
        let acc = Usage {
            input_tokens: 10,
            cache_read_input_tokens: 3,
            ..Usage::default()
        };
        let merged = merge_usage(acc, &json!({"output_tokens": 5}));
        assert_eq!(merged.input_tokens, 10);
        assert_eq!(merged.output_tokens, 5);
        assert_eq!(merged.cache_read_input_tokens, 3);
        assert_eq!(merged.cache_creation_input_tokens, 0);
    }

    #[test]
    fn thinking_config_picks_budget_by_model_family() {
        assert_eq!(thinking_config("claude-sonnet-4-20250514")["budget_tokens"], 10000);
        assert_eq!(thinking_config("claude-3-7-sonnet-20250219")["budget_tokens"], 10000);
        assert_eq!(thinking_config("claude-3-5-haiku-20241022")["budget_tokens"], 5000);
        assert_eq!(thinking_config("claude-3-5-haiku-20241022")["type"], "enabled");
    }
}
885