1use crate::auth::unmark_anthropic_oauth_bearer_token;
7use crate::error::{Error, Result};
8use crate::http::client::Client;
9use crate::model::{
10 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
11 ThinkingLevel, ToolCall, Usage, UserContent,
12};
13use crate::models::CompatConfig;
14use crate::provider::{CacheRetention, Context, Provider, StreamOptions, ToolDef};
15use crate::provider_metadata::canonical_provider_id;
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::fs;
23use std::pin::Pin;
24
/// Default endpoint that `AnthropicProvider::new` configures as the base URL.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// `max_tokens` used when the caller's stream options do not specify one.
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Marker string used to recognise Anthropic OAuth access tokens.
const ANTHROPIC_OAUTH_TOKEN_PREFIX: &str = "sk-ant-oat";
/// Default `anthropic-beta` flags for bearer-token requests
/// (overridable through the `PI_ANTHROPIC_BETA_FLAGS` environment variable).
const ANTHROPIC_OAUTH_BETA_FLAGS: &str = "claude-code-20250219,oauth-2025-04-20";
/// Default `anthropic-beta` flag added when cache retention is enabled
/// (overridable through the `PI_ANTHROPIC_CACHE_BETA_FLAG` environment variable).
const ANTHROPIC_CACHE_BETA_FLAG: &str = "prompt-caching-2024-07-31";
/// Name of the environment variable that overrides the Kimi share directory.
const KIMI_SHARE_DIR_ENV_KEY: &str = "KIMI_SHARE_DIR";
40
41fn anthropic_oauth_beta_flags() -> String {
42 std::env::var("PI_ANTHROPIC_BETA_FLAGS")
43 .ok()
44 .filter(|v| !v.is_empty())
45 .unwrap_or_else(|| ANTHROPIC_OAUTH_BETA_FLAGS.to_string())
46}
47
48fn anthropic_cache_beta_flag() -> String {
49 std::env::var("PI_ANTHROPIC_CACHE_BETA_FLAG")
50 .ok()
51 .filter(|v| !v.is_empty())
52 .unwrap_or_else(|| ANTHROPIC_CACHE_BETA_FLAG.to_string())
53}
54
55#[inline]
56fn is_anthropic_oauth_token(token: &str) -> bool {
57 token.contains(ANTHROPIC_OAUTH_TOKEN_PREFIX)
58}
59
60#[inline]
61fn is_anthropic_provider(provider: &str) -> bool {
62 canonical_provider_id(provider).unwrap_or(provider) == "anthropic"
63}
64
65#[inline]
66fn is_anthropic_bearer_token(provider: &str, token: &str) -> bool {
67 if !is_anthropic_provider(provider) {
68 return false;
69 }
70 let token = token.trim();
71 if token.is_empty() {
72 return false;
73 }
74
75 if is_anthropic_oauth_token(token) {
77 return true;
78 }
79
80 !token.starts_with("sk-ant-")
82}
83
84#[inline]
85fn is_kimi_coding_provider(provider: &str) -> bool {
86 canonical_provider_id(provider).unwrap_or(provider) == "kimi-for-coding"
87}
88
89#[inline]
90fn is_kimi_oauth_token(provider: &str, token: &str) -> bool {
91 is_kimi_coding_provider(provider) && !token.starts_with("sk-")
92}
93
/// Extracts the credential from an `Authorization: Bearer <token>` header value.
///
/// The value must consist of exactly two whitespace-separated fields, the first of
/// which is `bearer` (case-insensitive). Anything else yields `None`.
fn bearer_token_from_authorization_header(value: &str) -> Option<String> {
    let mut fields = value.split_whitespace();
    // `split_whitespace` never yields empty or whitespace-padded fields, so the
    // credential needs no further trimming or emptiness check.
    match (fields.next(), fields.next(), fields.next()) {
        (Some(scheme), Some(credential), None) if scheme.eq_ignore_ascii_case("bearer") => {
            Some(credential.to_string())
        }
        _ => None,
    }
}
107
108fn authorization_override(
109 options: &StreamOptions,
110 compat: Option<&CompatConfig>,
111) -> Option<String> {
112 super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
113 .or_else(|| {
114 compat
115 .and_then(|compat| compat.custom_headers.as_ref())
116 .and_then(|headers| {
117 super::first_non_empty_header_value_case_insensitive(
118 headers,
119 &["authorization"],
120 )
121 })
122 })
123}
124
125fn x_api_key_override(options: &StreamOptions, compat: Option<&CompatConfig>) -> Option<String> {
126 super::first_non_empty_header_value_case_insensitive(&options.headers, &["x-api-key"]).or_else(
127 || {
128 compat
129 .and_then(|compat| compat.custom_headers.as_ref())
130 .and_then(|headers| {
131 super::first_non_empty_header_value_case_insensitive(headers, &["x-api-key"])
132 })
133 },
134 )
135}
136
/// Produces an ASCII-only header value from `value`, or `fallback` when nothing usable remains.
///
/// * A non-blank, pure-ASCII `value` is returned unchanged (including any padding).
/// * Otherwise non-ASCII characters are dropped and the remainder is trimmed.
/// * If the remainder is empty, `fallback` is returned.
fn sanitize_ascii_header_value(value: &str, fallback: &str) -> String {
    let is_blank = value.trim().is_empty();
    if value.is_ascii() && !is_blank {
        return value.to_owned();
    }

    let ascii_only: String = value.chars().filter(char::is_ascii).collect();
    match ascii_only.trim() {
        "" => fallback.to_owned(),
        kept => kept.to_owned(),
    }
}
153
/// Resolves the user's home directory through the supplied environment lookup.
///
/// Sources are tried in order: `HOME`, `USERPROFILE`, then `HOMEDRIVE` combined with
/// `HOMEPATH`. Values are trimmed and blank values count as unset. Returns `None`
/// when no source yields a directory.
fn home_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
where
    F: Fn(&str) -> Option<String>,
{
    use std::path::PathBuf;

    // Reads one variable, trimming it and treating blank values as missing.
    let read = |key: &str| -> Option<String> {
        let raw = env_lookup(key)?;
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    };

    if let Some(home) = read("HOME") {
        return Some(PathBuf::from(home));
    }
    if let Some(profile) = read("USERPROFILE") {
        return Some(PathBuf::from(profile));
    }

    let drive = read("HOMEDRIVE")?;
    let path = read("HOMEPATH")?;
    if path.starts_with(['\\', '/']) {
        // A rooted HOMEPATH is appended textually so the drive prefix is kept as-is.
        return Some(PathBuf::from(format!("{drive}{path}")));
    }
    let mut combined = PathBuf::from(drive);
    combined.push(path);
    Some(combined)
}
184
185fn home_dir() -> Option<std::path::PathBuf> {
186 home_dir_with_env_lookup(|key| std::env::var(key).ok())
187}
188
189fn kimi_share_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
190where
191 F: Fn(&str) -> Option<String>,
192{
193 env_lookup(KIMI_SHARE_DIR_ENV_KEY)
194 .map(|value| value.trim().to_string())
195 .filter(|value| !value.is_empty())
196 .map(std::path::PathBuf::from)
197 .or_else(|| home_dir_with_env_lookup(env_lookup).map(|home| home.join(".kimi")))
198}
199
200fn kimi_share_dir() -> Option<std::path::PathBuf> {
201 kimi_share_dir_with_env_lookup(|key| std::env::var(key).ok())
202}
203
204fn kimi_device_id_paths() -> Option<(std::path::PathBuf, std::path::PathBuf)> {
205 let primary = kimi_share_dir()?.join("device_id");
206 let legacy = home_dir().map_or_else(
207 || primary.clone(),
208 |home| home.join(".pi").join("agent").join("kimi-device-id"),
209 );
210 Some((primary, legacy))
211}
212
213fn kimi_device_id() -> String {
214 static DEVICE_ID: std::sync::OnceLock<String> = std::sync::OnceLock::new();
215 DEVICE_ID
216 .get_or_init(|| {
217 let generated = uuid::Uuid::new_v4().simple().to_string();
218 let Some((primary, legacy)) = kimi_device_id_paths() else {
219 return generated;
220 };
221
222 for path in [&primary, &legacy] {
223 if let Ok(existing) = fs::read_to_string(path) {
224 let existing = existing.trim();
225 if !existing.is_empty() {
226 return existing.to_string();
227 }
228 }
229 }
230
231 if let Some(parent) = primary.parent() {
232 let _ = fs::create_dir_all(parent);
233 }
234
235 let mut options = fs::OpenOptions::new();
236 options.write(true).create_new(true);
237
238 #[cfg(unix)]
239 {
240 use std::os::unix::fs::OpenOptionsExt;
241 options.mode(0o600);
242 }
243
244 if let Ok(mut file) = options.open(&primary) {
245 use std::io::Write;
246 let _ = file.write_all(generated.as_bytes());
247 }
248
249 generated
250 })
251 .clone()
252}
253
254fn kimi_common_headers() -> Vec<(String, String)> {
255 let device_name = std::env::var("HOSTNAME")
256 .ok()
257 .or_else(|| std::env::var("COMPUTERNAME").ok())
258 .unwrap_or_else(|| "unknown".to_string());
259 let device_model = format!("{} {}", std::env::consts::OS, std::env::consts::ARCH);
260 let os_version = std::env::consts::OS.to_string();
261
262 vec![
263 (
264 "X-Msh-Platform".to_string(),
265 sanitize_ascii_header_value("kimi_cli", "unknown"),
266 ),
267 (
268 "X-Msh-Version".to_string(),
269 sanitize_ascii_header_value(env!("CARGO_PKG_VERSION"), "unknown"),
270 ),
271 (
272 "X-Msh-Device-Name".to_string(),
273 sanitize_ascii_header_value(&device_name, "unknown"),
274 ),
275 (
276 "X-Msh-Device-Model".to_string(),
277 sanitize_ascii_header_value(&device_model, "unknown"),
278 ),
279 (
280 "X-Msh-Os-Version".to_string(),
281 sanitize_ascii_header_value(&os_version, "unknown"),
282 ),
283 (
284 "X-Msh-Device-Id".to_string(),
285 sanitize_ascii_header_value(&kimi_device_id(), "unknown"),
286 ),
287 ]
288}
289
/// Streaming provider that talks to an Anthropic-style `messages` endpoint.
pub struct AnthropicProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Model identifier placed in the request body.
    model: String,
    /// URL that requests are POSTed to.
    base_url: String,
    /// Provider name reported by `name()` and used when classifying credentials.
    provider: String,
    /// Optional compatibility settings; its `custom_headers` are applied to requests.
    compat: Option<CompatConfig>,
}
302
303impl AnthropicProvider {
304 pub fn new(model: impl Into<String>) -> Self {
306 Self {
307 client: Client::new(),
308 model: model.into(),
309 base_url: ANTHROPIC_API_URL.to_string(),
310 provider: "anthropic".to_string(),
311 compat: None,
312 }
313 }
314
315 #[must_use]
317 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
318 self.provider = provider.into();
319 self
320 }
321
322 #[must_use]
324 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
325 self.base_url = base_url.into();
326 self
327 }
328
329 #[must_use]
331 pub fn with_client(mut self, client: Client) -> Self {
332 self.client = client;
333 self
334 }
335
336 #[must_use]
341 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
342 self.compat = compat;
343 self
344 }
345
346 pub fn build_request<'a>(
348 &'a self,
349 context: &'a Context<'_>,
350 options: &StreamOptions,
351 ) -> AnthropicRequest<'a> {
352 let messages = context
353 .messages
354 .iter()
355 .map(convert_message_to_anthropic)
356 .collect();
357
358 let tools: Option<Vec<AnthropicTool<'_>>> = if context.tools.is_empty() {
359 None
360 } else {
361 Some(
362 context
363 .tools
364 .iter()
365 .map(convert_tool_to_anthropic)
366 .collect(),
367 )
368 };
369
370 let thinking = options.thinking_level.and_then(|level| {
372 if level == ThinkingLevel::Off {
373 None
374 } else {
375 let budget = options.thinking_budgets.as_ref().map_or_else(
376 || level.default_budget(),
377 |b| match level {
378 ThinkingLevel::Off => 0,
379 ThinkingLevel::Minimal => b.minimal,
380 ThinkingLevel::Low => b.low,
381 ThinkingLevel::Medium => b.medium,
382 ThinkingLevel::High => b.high,
383 ThinkingLevel::XHigh => b.xhigh,
384 },
385 );
386 Some(AnthropicThinking {
387 r#type: "enabled",
388 budget_tokens: budget,
389 })
390 }
391 });
392
393 let mut max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
394 if let Some(t) = &thinking {
395 if max_tokens <= t.budget_tokens {
396 max_tokens = t.budget_tokens + 4096;
397 }
398 }
399
400 let temperature = if thinking.is_some() {
401 Some(1.0)
402 } else {
403 options.temperature
404 };
405
406 AnthropicRequest {
407 model: &self.model,
408 messages,
409 system: context.system_prompt.as_deref(),
410 max_tokens,
411 temperature,
412 tools,
413 stream: true,
414 thinking,
415 }
416 }
417}
418
#[async_trait]
impl Provider for AnthropicProvider {
    /// Returns the configured provider name.
    fn name(&self) -> &str {
        &self.provider
    }

    /// Returns the identifier of the wire API this provider speaks.
    fn api(&self) -> &'static str {
        "anthropic-messages"
    }

    /// Returns the model identifier sent with each request.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Sends a streaming request and returns the resulting stream of `StreamEvent`s.
    ///
    /// The method builds the request body, decides how to authenticate, attaches
    /// provider-specific and custom headers, sends the request, and wraps the SSE
    /// response in a stream driven by `StreamState`.
    ///
    /// # Errors
    ///
    /// * No credentials: neither an `authorization`/`x-api-key` header override, nor
    ///   `options.api_key`, nor the `ANTHROPIC_API_KEY` environment variable is set.
    /// * The request body cannot be serialized or the request cannot be sent.
    /// * The response status is outside `200..300` (the body is included in the error).
    ///
    /// Errors that occur while reading the stream are yielded as stream items.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let request_body = self.build_request(context, options);
        let authorization_override = authorization_override(options, self.compat.as_ref());
        let x_api_key_override = x_api_key_override(options, self.compat.as_ref());
        // Credential classification; drives the extra headers and beta flags below.
        let mut anthropic_bearer_token = false;
        let mut kimi_oauth_token = false;

        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream")
            .header("anthropic-version", ANTHROPIC_API_VERSION);

        if let Some(authorization_override) = authorization_override {
            // An explicit Authorization header takes precedence over the API key. Only
            // the token is classified here; no auth header is set in this branch
            // (presumably it is attached when custom/option headers are applied below).
            if let Some(bearer_token) =
                bearer_token_from_authorization_header(&authorization_override)
            {
                anthropic_bearer_token = is_anthropic_bearer_token(&self.provider, &bearer_token);
                kimi_oauth_token = is_kimi_oauth_token(&self.provider, &bearer_token);
            }
        } else if x_api_key_override.is_none() {
            // No header override at all: fall back to the API key from the options or
            // the environment.
            let raw_auth_value = options
                .api_key
                .clone()
                .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
                .ok_or_else(|| {
                    Error::provider(
                        self.name(),
                        "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var.",
                    )
                })?;
            // For Anthropic, a value that `unmark_anthropic_oauth_bearer_token` accepts
            // is always sent as a bearer token, using the unmarked value it returns.
            let forced_bearer_token = if is_anthropic_provider(&self.provider) {
                unmark_anthropic_oauth_bearer_token(&raw_auth_value).map(ToString::to_string)
            } else {
                None
            };
            let force_bearer = forced_bearer_token.is_some();
            let auth_value = forced_bearer_token.unwrap_or(raw_auth_value);

            anthropic_bearer_token =
                force_bearer || is_anthropic_bearer_token(&self.provider, &auth_value);
            kimi_oauth_token = is_kimi_oauth_token(&self.provider, &auth_value);

            if anthropic_bearer_token || kimi_oauth_token {
                request = request.header("Authorization", format!("Bearer {auth_value}"));
            } else {
                request = request.header("X-API-Key", &auth_value);
            }
        }

        // Client-identification headers that accompany bearer / OAuth credentials.
        if anthropic_bearer_token {
            request = request
                .header("anthropic-dangerous-direct-browser-access", "true")
                .header("x-app", "cli")
                .header(
                    "user-agent",
                    format!(
                        "pi_agent_rust/{} (external, cli)",
                        env!("CARGO_PKG_VERSION")
                    ),
                );
        } else if kimi_oauth_token {
            request = request.header(
                "user-agent",
                format!(
                    "pi_agent_rust/{} (kimi-oauth, cli)",
                    env!("CARGO_PKG_VERSION")
                ),
            );
            for (name, value) in kimi_common_headers() {
                request = request.header(name, value);
            }
        }

        // All beta flags are sent in one comma-separated `anthropic-beta` header.
        let mut beta_flags: Vec<String> = Vec::new();
        if anthropic_bearer_token {
            beta_flags.push(anthropic_oauth_beta_flags());
        }
        if options.cache_retention != CacheRetention::None {
            beta_flags.push(anthropic_cache_beta_flag());
        }
        if !beta_flags.is_empty() {
            request = request.header("anthropic-beta", beta_flags.join(","));
        }

        // Compat-level custom headers are applied first ...
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                request = super::apply_headers_ignoring_blank_auth_overrides(
                    request,
                    custom_headers,
                    &["authorization", "x-api-key"],
                );
            }
        }

        // ... then the per-request headers from the stream options.
        request = super::apply_headers_ignoring_blank_auth_overrides(
            request,
            &options.headers,
            &["authorization", "x-api-key"],
        );

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Surface the response body in the error; if it cannot be read, say so.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                self.name(),
                format!("Anthropic API error (HTTP {status}): {body}"),
            ));
        }

        let event_source = SseStream::new(response.bytes_stream());

        // Owned copies are moved into the stream state, which outlives `&self`.
        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                // Once a terminal item has been yielded the stream ends.
                if state.done {
                    return None;
                }
                // Keep reading SSE messages until one produces an item to yield.
                loop {
                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // Any successfully read message resets the transient-error streak.
                            state.transient_error_count = 0;
                            if msg.event == "ping" {
                                // Keep-alive messages carry nothing to process.
                            } else {
                                match state.process_event(&msg.data) {
                                    Ok(Some(event)) => {
                                        if matches!(
                                            &event,
                                            StreamEvent::Done { .. } | StreamEvent::Error { .. }
                                        ) {
                                            state.done = true;
                                        }
                                        return Some((Ok(event), state));
                                    }
                                    // The event only updated internal state; read the next one.
                                    Ok(None) => {}
                                    Err(e) => {
                                        state.done = true;
                                        return Some((Err(e), state));
                                    }
                                }
                            }
                        }
                        Some(Err(e)) => {
                            // These I/O error kinds are retried up to this many times in a
                            // row; any other kind, or a longer streak, ends the stream.
                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
                            if e.kind() == std::io::ErrorKind::WriteZero
                                || e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut
                            {
                                state.transient_error_count += 1;
                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
                                    tracing::warn!(
                                        kind = ?e.kind(),
                                        count = state.transient_error_count,
                                        "Transient error in SSE stream, continuing"
                                    );
                                    continue;
                                }
                                tracing::warn!(
                                    kind = ?e.kind(),
                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
                                     consecutive attempts, treating as fatal"
                                );
                            }
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // The byte stream ended before a terminal event was produced:
                            // emit `Done` with whatever has been accumulated so far.
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
642
/// Accumulates one streamed tool call until its content block is closed.
struct ToolAccum {
    /// Tool call id from the `content_block_start` event (empty if absent).
    id: String,
    /// Tool name from the `content_block_start` event (empty if absent).
    name: String,
    /// Concatenation of the `partial_json` fragments received so far.
    json: String,
}
652
/// Mutable state threaded through the SSE-to-`StreamEvent` conversion.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Source of parsed SSE messages.
    event_source: SseStream<S>,
    /// Assistant message assembled incrementally from the stream.
    partial: AssistantMessage,
    /// In-flight tool calls, keyed by the content block index from the stream.
    tool_accums: HashMap<u32, ToolAccum>,
    /// Set once a terminal item has been yielded; the stream then ends.
    done: bool,
    /// Number of consecutive transient read errors seen; reset on every good message.
    transient_error_count: usize,
}
664
665impl<S> StreamState<S>
666where
667 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
668{
669 const fn recompute_total_tokens(&mut self) {
670 self.partial.usage.total_tokens = self
671 .partial
672 .usage
673 .input
674 .saturating_add(self.partial.usage.output)
675 .saturating_add(self.partial.usage.cache_read)
676 .saturating_add(self.partial.usage.cache_write);
677 }
678
679 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
680 Self {
681 event_source,
682 partial: AssistantMessage {
683 content: Vec::new(),
684 api,
685 provider,
686 model,
687 usage: Usage::default(),
688 stop_reason: StopReason::Stop,
689 error_message: None,
690 timestamp: chrono::Utc::now().timestamp_millis(),
691 },
692 tool_accums: HashMap::new(),
693 done: false,
694 transient_error_count: 0,
695 }
696 }
697
698 #[allow(clippy::too_many_lines)]
699 fn process_event(&mut self, data: &str) -> Result<Option<StreamEvent>> {
700 let event: AnthropicStreamEvent = serde_json::from_str(data)
701 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
702
703 match event {
704 AnthropicStreamEvent::MessageStart { message } => {
705 Ok(Some(self.handle_message_start(message)))
706 }
707 AnthropicStreamEvent::ContentBlockStart {
708 index,
709 content_block,
710 } => Ok(Some(self.handle_content_block_start(index, content_block))),
711 AnthropicStreamEvent::ContentBlockDelta { index, delta } => {
712 Ok(self.handle_content_block_delta(index, delta))
713 }
714 AnthropicStreamEvent::ContentBlockStop { index } => {
715 Ok(self.handle_content_block_stop(index))
716 }
717 AnthropicStreamEvent::MessageDelta { delta, usage } => {
718 self.handle_message_delta(&delta, usage);
719 Ok(None)
720 }
721 AnthropicStreamEvent::MessageStop => {
722 let reason = self.partial.stop_reason;
723 Ok(Some(StreamEvent::Done {
724 reason,
725 message: std::mem::take(&mut self.partial),
726 }))
727 }
728 AnthropicStreamEvent::Error { error } => {
729 self.partial.stop_reason = StopReason::Error;
730 self.partial.error_message = Some(error.message);
731 Ok(Some(StreamEvent::Error {
732 reason: StopReason::Error,
733 error: std::mem::take(&mut self.partial),
734 }))
735 }
736 AnthropicStreamEvent::Ping => Ok(None),
737 }
738 }
739
740 fn handle_message_start(&mut self, message: AnthropicMessageStart) -> StreamEvent {
741 if let Some(usage) = message.usage {
742 self.partial.usage.input = usage.input;
743 self.partial.usage.cache_read = usage.cache_read.unwrap_or(0);
744 self.partial.usage.cache_write = usage.cache_write.unwrap_or(0);
745 self.recompute_total_tokens();
746 }
747 StreamEvent::Start {
748 partial: self.partial.clone(),
749 }
750 }
751
752 fn handle_content_block_start(
753 &mut self,
754 index: u32,
755 content_block: AnthropicContentBlock,
756 ) -> StreamEvent {
757 let content_index = index as usize;
758
759 match content_block {
760 AnthropicContentBlock::Text => {
761 self.partial
762 .content
763 .push(ContentBlock::Text(TextContent::new("")));
764 StreamEvent::TextStart { content_index }
765 }
766 AnthropicContentBlock::Thinking => {
767 self.partial
768 .content
769 .push(ContentBlock::Thinking(ThinkingContent {
770 thinking: String::new(),
771 thinking_signature: None,
772 }));
773 StreamEvent::ThinkingStart { content_index }
774 }
775 AnthropicContentBlock::ToolUse { id, name } => {
776 let id = id.unwrap_or_default();
777 let name = name.unwrap_or_default();
778 self.tool_accums.insert(
779 index,
780 ToolAccum {
781 id: id.clone(),
782 name: name.clone(),
783 json: String::new(),
784 },
785 );
786 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
787 id,
788 name,
789 arguments: serde_json::Value::Null,
790 thought_signature: None,
791 }));
792 StreamEvent::ToolCallStart { content_index }
793 }
794 }
795 }
796
797 fn handle_content_block_delta(
798 &mut self,
799 index: u32,
800 delta: AnthropicDelta,
801 ) -> Option<StreamEvent> {
802 let idx = index as usize;
803
804 match delta {
805 AnthropicDelta::TextDelta { text } => {
806 if let Some(text) = text {
807 if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(idx) {
808 t.text.push_str(&text);
809 }
810 Some(StreamEvent::TextDelta {
811 content_index: idx,
812 delta: text,
813 })
814 } else {
815 None
816 }
817 }
818 AnthropicDelta::ThinkingDelta { thinking } => {
819 if let Some(thinking) = thinking {
820 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
821 t.thinking.push_str(&thinking);
822 }
823 Some(StreamEvent::ThinkingDelta {
824 content_index: idx,
825 delta: thinking,
826 })
827 } else {
828 None
829 }
830 }
831 AnthropicDelta::InputJsonDelta { partial_json } => {
832 if let Some(partial_json) = partial_json {
833 if let Some(accum) = self.tool_accums.get_mut(&index) {
834 accum.json.push_str(&partial_json);
835 }
836 Some(StreamEvent::ToolCallDelta {
837 content_index: idx,
838 delta: partial_json,
839 })
840 } else {
841 None
842 }
843 }
844 AnthropicDelta::SignatureDelta { signature } => {
845 if let Some(sig) = signature {
849 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
850 t.thinking_signature = Some(sig);
851 }
852 }
853 None
854 }
855 }
856 }
857
858 fn handle_content_block_stop(&mut self, index: u32) -> Option<StreamEvent> {
859 let idx = index as usize;
860
861 match self.partial.content.get_mut(idx) {
862 Some(ContentBlock::Text(t)) => {
863 let content = t.text.clone();
867 Some(StreamEvent::TextEnd {
868 content_index: idx,
869 content,
870 })
871 }
872 Some(ContentBlock::Thinking(t)) => {
873 let content = t.thinking.clone();
875 Some(StreamEvent::ThinkingEnd {
876 content_index: idx,
877 content,
878 })
879 }
880 Some(ContentBlock::ToolCall(tc)) => {
881 if let Some(accum) = self.tool_accums.remove(&index) {
882 let arguments: serde_json::Value = match serde_json::from_str(&accum.json) {
883 Ok(args) => args,
884 Err(e) => {
885 tracing::warn!(
886 error = %e,
887 raw = %accum.json,
888 "Failed to parse tool arguments as JSON"
889 );
890 serde_json::Value::Null
891 }
892 };
893 let tool_call = ToolCall {
894 id: accum.id,
895 name: accum.name,
896 arguments: arguments.clone(),
897 thought_signature: None,
898 };
899 tc.arguments = arguments;
900
901 Some(StreamEvent::ToolCallEnd {
902 content_index: idx,
903 tool_call,
904 })
905 } else {
906 None
907 }
908 }
909 _ => None,
910 }
911 }
912
913 #[allow(clippy::missing_const_for_fn)]
914 fn handle_message_delta(
915 &mut self,
916 delta: &AnthropicMessageDelta,
917 usage: Option<AnthropicDeltaUsage>,
918 ) {
919 if let Some(stop_reason) = delta.stop_reason {
920 self.partial.stop_reason = match stop_reason {
921 AnthropicStopReason::MaxTokens => StopReason::Length,
922 AnthropicStopReason::ToolUse => StopReason::ToolUse,
923 AnthropicStopReason::EndTurn | AnthropicStopReason::StopSequence => {
924 StopReason::Stop
925 }
926 };
927 }
928
929 if let Some(u) = usage {
930 self.partial.usage.output = u.output_tokens;
931 self.recompute_total_tokens();
932 }
933 }
934}
935
/// JSON request body produced by `AnthropicProvider::build_request`.
///
/// Optional fields are left out of the serialized body when they are `None`.
#[derive(Debug, Serialize)]
pub struct AnthropicRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation history in wire format.
    messages: Vec<AnthropicMessage<'a>>,
    /// System prompt, if the context provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<&'a str>,
    /// Token limit for the response.
    max_tokens: u32,
    /// Sampling temperature, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions; absent when the context has no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool<'a>>>,
    /// Whether a streaming response is requested.
    stream: bool,
    /// Thinking configuration; absent when thinking is off.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking: Option<AnthropicThinking>,
}
955
/// Thinking configuration sent in the request body.
#[derive(Debug, Serialize)]
struct AnthropicThinking {
    /// Serialized as `type`; `build_request` always sets it to `"enabled"`.
    r#type: &'static str,
    /// Token budget granted to thinking.
    budget_tokens: u32,
}
961
/// One conversation message in wire format.
#[derive(Debug, Serialize)]
struct AnthropicMessage<'a> {
    /// Either `"user"` or `"assistant"`.
    role: &'static str,
    /// Content blocks that make up the message.
    content: Vec<AnthropicContent<'a>>,
}
967
/// A message content block, serialized with a snake_case `type` tag
/// (`text`, `thinking`, `image`, `tool_use`, `tool_result`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent<'a> {
    /// Plain text.
    Text {
        text: &'a str,
    },
    /// A thinking block together with its signature.
    Thinking {
        thinking: &'a str,
        signature: &'a str,
    },
    /// An inline image.
    Image {
        source: AnthropicImageSource<'a>,
    },
    /// A tool invocation made by the assistant.
    ToolUse {
        id: &'a str,
        name: &'a str,
        input: &'a serde_json::Value,
    },
    /// The result of a tool invocation.
    ToolResult {
        /// Id of the tool call this result answers.
        tool_use_id: &'a str,
        content: Vec<AnthropicToolResultContent<'a>>,
        /// Omitted from the serialized block when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
}
993
/// Image payload embedded in a content block.
#[derive(Debug, Serialize)]
struct AnthropicImageSource<'a> {
    /// Serialized as `type`; the converters in this file always use `"base64"`.
    r#type: &'static str,
    /// MIME type of the image.
    media_type: &'a str,
    /// Image data as stored on the source content block.
    data: &'a str,
}
1000
/// Content allowed inside a tool result, serialized with a snake_case `type` tag.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicToolResultContent<'a> {
    /// Plain text output.
    Text { text: &'a str },
    /// Image output.
    Image { source: AnthropicImageSource<'a> },
}
1007
/// Tool definition in wire format.
#[derive(Debug, Serialize)]
struct AnthropicTool<'a> {
    /// Tool name.
    name: &'a str,
    /// Description of the tool.
    description: &'a str,
    /// The tool's `parameters` value, sent as its input schema.
    input_schema: &'a serde_json::Value,
}
1014
/// One event of the streaming response, selected by its snake_case `type` tag
/// (`message_start`, `content_block_start`, `content_block_delta`,
/// `content_block_stop`, `message_delta`, `message_stop`, `error`, `ping`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicStreamEvent {
    /// Start of the response message.
    MessageStart {
        message: AnthropicMessageStart,
    },
    /// A new content block begins at `index`.
    ContentBlockStart {
        index: u32,
        content_block: AnthropicContentBlock,
    },
    /// Incremental update for the content block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: AnthropicDelta,
    },
    /// The content block at `index` is complete.
    ContentBlockStop {
        index: u32,
    },
    /// Message-level update: stop reason and, optionally, usage.
    MessageDelta {
        delta: AnthropicMessageDelta,
        #[serde(default)]
        usage: Option<AnthropicDeltaUsage>,
    },
    /// The message is complete.
    MessageStop,
    /// An error reported inside the stream.
    Error {
        error: AnthropicError,
    },
    /// Carries no data; ignored by the stream state.
    Ping,
}
1047
/// Payload of a `message_start` event; only the usage information is read.
#[derive(Debug, Deserialize)]
struct AnthropicMessageStart {
    /// Input-side token usage; `None` when the field is absent.
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}
1053
/// Input-side token usage reported with `message_start`.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct AnthropicUsage {
    /// Wire field `input_tokens`.
    #[serde(rename = "input_tokens")]
    input: u64,
    /// Wire field `cache_read_input_tokens`; `None` when absent.
    #[serde(default, rename = "cache_read_input_tokens")]
    cache_read: Option<u64>,
    /// Wire field `cache_creation_input_tokens`; `None` when absent.
    #[serde(default, rename = "cache_creation_input_tokens")]
    cache_write: Option<u64>,
}
1066
/// Usage reported with a `message_delta` event.
#[derive(Debug, Deserialize)]
struct AnthropicDeltaUsage {
    /// Output token count; stored as-is (not added) into the partial message's usage.
    output_tokens: u64,
}
1071
/// Kind of content block announced by `content_block_start`, selected by its
/// snake_case `type` tag (`text`, `thinking`, `tool_use`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContentBlock {
    /// A text block.
    Text,
    /// A thinking block.
    Thinking,
    /// A tool invocation; missing `id`/`name` deserialize as `None`.
    ToolUse {
        #[serde(default)]
        id: Option<String>,
        #[serde(default)]
        name: Option<String>,
    },
}
1087
/// Incremental payload of a `content_block_delta` event, selected by its snake_case
/// `type` tag (`text_delta`, `thinking_delta`, `input_json_delta`, `signature_delta`).
/// Every payload field is optional and deserializes as `None` when absent.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::enum_variant_names)]
enum AnthropicDelta {
    /// More text for a text block.
    TextDelta {
        #[serde(default)]
        text: Option<String>,
    },
    /// More text for a thinking block.
    ThinkingDelta {
        #[serde(default)]
        thinking: Option<String>,
    },
    /// Another fragment of a tool call's JSON input.
    InputJsonDelta {
        #[serde(default)]
        partial_json: Option<String>,
    },
    /// Signature for a thinking block.
    SignatureDelta {
        #[serde(default)]
        signature: Option<String>,
    },
}
1114
/// Stop reasons understood by this provider, deserialized from snake_case strings
/// (`end_turn`, `max_tokens`, `tool_use`, `stop_sequence`).
// NOTE(review): there is no catch-all variant, so any other stop reason string
// fails deserialization of the enclosing event — confirm that this is intended.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum AnthropicStopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    StopSequence,
}
1126
/// Payload of a `message_delta` event; only the stop reason is read.
#[derive(Debug, Deserialize)]
struct AnthropicMessageDelta {
    /// Stop reason, `None` when the field is absent.
    #[serde(default)]
    stop_reason: Option<AnthropicStopReason>,
}
1132
/// Error object carried by an in-stream `error` event.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    /// Message stored as the partial message's `error_message`.
    message: String,
}
1137
1138fn convert_message_to_anthropic(message: &Message) -> AnthropicMessage<'_> {
1143 match message {
1144 Message::User(user) => AnthropicMessage {
1145 role: "user",
1146 content: convert_user_content(&user.content),
1147 },
1148 Message::Custom(custom) => AnthropicMessage {
1149 role: "user",
1150 content: vec![AnthropicContent::Text {
1151 text: &custom.content,
1152 }],
1153 },
1154 Message::Assistant(assistant) => AnthropicMessage {
1155 role: "assistant",
1156 content: assistant
1157 .content
1158 .iter()
1159 .filter_map(convert_content_block_to_anthropic)
1160 .collect(),
1161 },
1162 Message::ToolResult(result) => AnthropicMessage {
1163 role: "user",
1164 content: vec![AnthropicContent::ToolResult {
1165 tool_use_id: &result.tool_call_id,
1166 content: result
1167 .content
1168 .iter()
1169 .filter_map(|block| match block {
1170 ContentBlock::Text(t) => {
1171 Some(AnthropicToolResultContent::Text { text: &t.text })
1172 }
1173 ContentBlock::Image(img) => Some(AnthropicToolResultContent::Image {
1174 source: AnthropicImageSource {
1175 r#type: "base64",
1176 media_type: &img.mime_type,
1177 data: &img.data,
1178 },
1179 }),
1180 _ => None,
1181 })
1182 .collect(),
1183 is_error: if result.is_error { Some(true) } else { None },
1184 }],
1185 },
1186 }
1187}
1188
1189fn convert_user_content(content: &UserContent) -> Vec<AnthropicContent<'_>> {
1190 match content {
1191 UserContent::Text(text) => vec![AnthropicContent::Text { text }],
1192 UserContent::Blocks(blocks) => blocks
1193 .iter()
1194 .filter_map(|block| match block {
1195 ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
1196 ContentBlock::Image(img) => Some(AnthropicContent::Image {
1197 source: AnthropicImageSource {
1198 r#type: "base64",
1199 media_type: &img.mime_type,
1200 data: &img.data,
1201 },
1202 }),
1203 _ => None,
1204 })
1205 .collect(),
1206 }
1207}
1208
1209fn convert_content_block_to_anthropic(block: &ContentBlock) -> Option<AnthropicContent<'_>> {
1210 match block {
1211 ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
1212 ContentBlock::ToolCall(tc) => Some(AnthropicContent::ToolUse {
1213 id: &tc.id,
1214 name: &tc.name,
1215 input: &tc.arguments,
1216 }),
1217 ContentBlock::Thinking(t) => {
1221 t.thinking_signature
1222 .as_ref()
1223 .map(|sig| AnthropicContent::Thinking {
1224 thinking: &t.thinking,
1225 signature: sig,
1226 })
1227 }
1228 ContentBlock::Image(_) => None,
1229 }
1230}
1231
1232fn convert_tool_to_anthropic(tool: &ToolDef) -> AnthropicTool<'_> {
1233 AnthropicTool {
1234 name: &tool.name,
1235 description: &tool.description,
1236 input_schema: &tool.parameters,
1237 }
1238}
1239
1240#[cfg(test)]
1245mod tests {
1246 use super::*;
1247 use asupersync::runtime::RuntimeBuilder;
1248 use futures::{StreamExt, stream};
1249 use serde::{Deserialize, Serialize};
1250 use serde_json::Value;
1251 use serde_json::json;
1252 use std::collections::HashMap;
1253 use std::io::{Read, Write};
1254 use std::net::TcpListener;
1255 use std::path::PathBuf;
1256 use std::sync::mpsc;
1257 use std::time::Duration;
1258
1259 #[test]
1260 fn home_dir_lookup_falls_back_to_userprofile() {
1261 let home = home_dir_with_env_lookup(|key| match key {
1262 "USERPROFILE" => Some("C:\\Users\\Ada".to_string()),
1263 _ => None,
1264 });
1265
1266 assert_eq!(home, Some(PathBuf::from("C:\\Users\\Ada")));
1267 }
1268
1269 #[test]
1270 fn home_dir_lookup_falls_back_to_homedrive_homepath() {
1271 let home = home_dir_with_env_lookup(|key| match key {
1272 "HOMEDRIVE" => Some("D:".to_string()),
1273 "HOMEPATH" => Some("\\Users\\Grace".to_string()),
1274 _ => None,
1275 });
1276
1277 assert_eq!(home, Some(PathBuf::from("D:\\Users\\Grace")));
1278 }
1279
/// A plain-text user message converts to a `user` role entry with one content block.
#[test]
fn test_convert_user_text_message() {
    let user = crate::model::UserMessage {
        content: UserContent::Text(String::from("Hello")),
        timestamp: 0,
    };
    let message = Message::User(user);

    let anthropic_message = convert_message_to_anthropic(&message);

    assert_eq!(anthropic_message.content.len(), 1);
    assert_eq!(anthropic_message.role, "user");
}
1291
/// Pins the default token budget reported for each thinking level.
#[test]
fn test_thinking_budget() {
    let expected_budgets = [
        (ThinkingLevel::Minimal, 1024),
        (ThinkingLevel::Low, 2048),
        (ThinkingLevel::Medium, 8192),
        (ThinkingLevel::High, 16384),
    ];

    for (level, budget) in expected_budgets {
        assert_eq!(level.default_budget(), budget);
    }
}
1299
/// Verifies that `build_request` carries the system prompt, the tool definition
/// and the thinking configuration from the context/options into the request.
#[test]
fn test_build_request_includes_system_tools_and_thinking() {
    // Shared between the tool definition and the final assertion so the two
    // cannot drift apart.
    let echo_schema = json!({
        "type": "object",
        "properties": {
            "text": { "type": "string" }
        },
        "required": ["text"]
    });

    let provider = AnthropicProvider::new("claude-test");
    let context = Context {
        system_prompt: Some("System prompt".to_string().into()),
        messages: vec![Message::User(crate::model::UserMessage {
            content: UserContent::Text("Ping".to_string()),
            timestamp: 0,
        })]
        .into(),
        tools: vec![ToolDef {
            name: "echo".to_string(),
            description: "Echo a string.".to_string(),
            parameters: echo_schema.clone(),
        }]
        .into(),
    };
    let options = StreamOptions {
        max_tokens: Some(128),
        temperature: Some(0.2),
        thinking_level: Some(ThinkingLevel::Medium),
        thinking_budgets: Some(crate::provider::ThinkingBudgets {
            minimal: 1024,
            low: 2048,
            medium: 9000,
            high: 16384,
            xhigh: 32768,
        }),
        ..Default::default()
    };

    let request = provider.build_request(&context, &options);
    assert_eq!(request.model, "claude-test");
    assert_eq!(request.system, Some("System prompt"));
    // NOTE(review): the options ask for 0.2 but the request carries 1.0 —
    // presumably forced because thinking is enabled; confirm in `build_request`.
    assert_eq!(request.temperature, Some(1.0));
    assert!(request.stream);
    // NOTE(review): 13_096 looks like the 9000-token thinking budget plus 4096
    // rather than the requested 128 — confirm against `build_request`.
    assert_eq!(request.max_tokens, 13_096);

    // The custom `medium` budget must win over the level's default.
    let thinking = request.thinking.expect("thinking config");
    assert_eq!(thinking.r#type, "enabled");
    assert_eq!(thinking.budget_tokens, 9000);

    assert_eq!(request.messages.len(), 1);
    assert_eq!(request.messages[0].role, "user");
    assert_eq!(request.messages[0].content.len(), 1);
    match &request.messages[0].content[0] {
        AnthropicContent::Text { text } => assert_eq!(*text, "Ping"),
        _ => panic!("expected the user message to convert to a single text block"),
    }

    let tools = request.tools.expect("tools");
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].name, "echo");
    assert_eq!(tools[0].description, "Echo a string.");
    assert_eq!(*tools[0].input_schema, echo_schema);
}
1371
/// With a default context and default options the request carries no system
/// prompt, tools or thinking config, and falls back to `DEFAULT_MAX_TOKENS`.
#[test]
fn test_build_request_omits_optional_fields_by_default() {
    let empty_context = Context::default();
    let default_options = StreamOptions::default();
    let provider = AnthropicProvider::new("claude-test");

    let request = provider.build_request(&empty_context, &default_options);

    assert_eq!(request.model, "claude-test");
    assert!(request.stream);
    assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
    assert!(request.system.is_none());
    assert!(request.tools.is_none());
    assert!(request.thinking.is_none());
}
1386
/// Feeds a thinking block, a tool-use block and a text block through the
/// stream state and checks the exact sequence of emitted `StreamEvent`s.
#[test]
#[allow(clippy::too_many_lines)]
fn test_stream_parses_thinking_and_tool_call_events() {
    let events = vec![
        json!({
            "type": "message_start",
            "message": { "usage": { "input_tokens": 3 } }
        }),
        json!({
            "type": "content_block_start",
            "index": 0,
            "content_block": { "type": "thinking" }
        }),
        json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": { "type": "thinking_delta", "thinking": "step 1" }
        }),
        json!({
            "type": "content_block_stop",
            "index": 0
        }),
        json!({
            "type": "content_block_start",
            "index": 1,
            "content_block": { "type": "tool_use", "id": "tool_123", "name": "search" }
        }),
        // The tool arguments arrive as two partial JSON fragments.
        json!({
            "type": "content_block_delta",
            "index": 1,
            "delta": { "type": "input_json_delta", "partial_json": "{\"q\":\"ru" }
        }),
        json!({
            "type": "content_block_delta",
            "index": 1,
            "delta": { "type": "input_json_delta", "partial_json": "st\"}" }
        }),
        json!({
            "type": "content_block_stop",
            "index": 1
        }),
        json!({
            "type": "content_block_start",
            "index": 2,
            "content_block": { "type": "text" }
        }),
        json!({
            "type": "content_block_delta",
            "index": 2,
            "delta": { "type": "text_delta", "text": "done" }
        }),
        json!({
            "type": "content_block_stop",
            "index": 2
        }),
        json!({
            "type": "message_delta",
            "delta": { "stop_reason": "tool_use" },
            "usage": { "output_tokens": 5 }
        }),
        json!({
            "type": "message_stop"
        }),
    ];

    let out = collect_events(&events);
    assert_eq!(out.len(), 12, "expected full stream event sequence");

    assert!(matches!(&out[0], StreamEvent::Start { .. }));
    assert!(matches!(
        &out[1],
        StreamEvent::ThinkingStart {
            content_index: 0,
            ..
        }
    ));
    assert!(matches!(
        &out[2],
        StreamEvent::ThinkingDelta {
            content_index: 0,
            delta,
            ..
        } if delta == "step 1"
    ));
    assert!(matches!(
        &out[3],
        StreamEvent::ThinkingEnd {
            content_index: 0,
            content,
            ..
        } if content == "step 1"
    ));
    assert!(matches!(
        &out[4],
        StreamEvent::ToolCallStart {
            content_index: 1,
            ..
        }
    ));
    assert!(matches!(
        &out[5],
        StreamEvent::ToolCallDelta {
            content_index: 1,
            delta,
            ..
        } if delta == "{\"q\":\"ru"
    ));
    assert!(matches!(
        &out[6],
        StreamEvent::ToolCallDelta {
            content_index: 1,
            delta,
            ..
        } if delta == "st\"}"
    ));

    // The two fragments must be stitched into one parsed argument object.
    let StreamEvent::ToolCallEnd {
        content_index,
        tool_call,
        ..
    } = &out[7]
    else {
        panic!("expected ToolCallEnd at position 7, got {:?}", out[7]);
    };
    assert_eq!(*content_index, 1);
    assert_eq!(tool_call.id, "tool_123");
    assert_eq!(tool_call.name, "search");
    assert_eq!(tool_call.arguments, json!({ "q": "rust" }));

    assert!(matches!(
        &out[8],
        StreamEvent::TextStart {
            content_index: 2,
            ..
        }
    ));
    assert!(matches!(
        &out[9],
        StreamEvent::TextDelta {
            content_index: 2,
            delta,
            ..
        } if delta == "done"
    ));
    assert!(matches!(
        &out[10],
        StreamEvent::TextEnd {
            content_index: 2,
            content,
            ..
        } if content == "done"
    ));

    let StreamEvent::Done { reason, message } = &out[11] else {
        panic!("expected Done at position 11, got {:?}", out[11]);
    };
    assert_eq!(*reason, StopReason::ToolUse);
    assert_eq!(message.stop_reason, StopReason::ToolUse);
}
1545
/// A `max_tokens` stop reason maps to `StopReason::Length`, and the usage
/// counters from `message_start` and `message_delta` are combined.
#[test]
fn test_message_delta_sets_length_stop_reason_and_usage() {
    let events = vec![
        json!({
            "type": "message_start",
            "message": { "usage": { "input_tokens": 5 } }
        }),
        json!({
            "type": "message_delta",
            "delta": { "stop_reason": "max_tokens" },
            "usage": { "output_tokens": 7 }
        }),
        json!({
            "type": "message_stop"
        }),
    ];

    let out = collect_events(&events);
    assert_eq!(out.len(), 2);

    let StreamEvent::Done { reason, message } = &out[1] else {
        panic!("expected Done as the second event, got {:?}", out[1]);
    };
    assert_eq!(*reason, StopReason::Length);
    assert_eq!(message.stop_reason, StopReason::Length);
    assert_eq!(message.usage.input, 5);
    assert_eq!(message.usage.output, 7);
    assert_eq!(message.usage.total_tokens, 12);
}
1575
/// When the reported token counts would exceed `u64::MAX`, `total_tokens`
/// is expected to clamp at `u64::MAX` instead of overflowing.
#[test]
fn test_usage_total_tokens_saturates_on_large_values() {
    let events = vec![
        json!({
            "type": "message_start",
            "message": {
                "usage": {
                    "input_tokens": u64::MAX,
                    "cache_read_input_tokens": 1,
                    "cache_creation_input_tokens": 1
                }
            }
        }),
        json!({
            "type": "message_delta",
            "delta": { "stop_reason": "end_turn" },
            "usage": { "output_tokens": 1 }
        }),
        json!({
            "type": "message_stop"
        }),
    ];

    let out = collect_events(&events);
    assert_eq!(out.len(), 2);

    let StreamEvent::Done { message, .. } = &out[1] else {
        panic!("expected Done as the second event, got {:?}", out[1]);
    };
    assert_eq!(message.usage.total_tokens, u64::MAX);
}
1607
/// Root of a provider-response fixture file as deserialized by `load_fixture`.
#[derive(Debug, Deserialize)]
struct ProviderFixture {
    /// Independent stream scenarios contained in the fixture.
    cases: Vec<ProviderCase>,
}
1612
/// One fixture scenario: raw provider events plus the summaries they should produce.
#[derive(Debug, Deserialize)]
struct ProviderCase {
    /// Label used in assertion messages to identify the failing case.
    name: String,
    /// Raw event payloads fed to `collect_events`.
    events: Vec<Value>,
    /// Expected summaries, compared against `summarize_event` output.
    expected: Vec<EventSummary>,
}
1619
1620 #[derive(Debug, Deserialize, Serialize, PartialEq)]
1621 struct EventSummary {
1622 kind: String,
1623 #[serde(default)]
1624 content_index: Option<usize>,
1625 #[serde(default)]
1626 delta: Option<String>,
1627 #[serde(default)]
1628 content: Option<String>,
1629 #[serde(default)]
1630 reason: Option<String>,
1631 }
1632
/// Replays every case in `anthropic_stream.json` and compares the summarized
/// events with the expectations stored alongside it.
#[test]
fn test_stream_fixtures() {
    let ProviderFixture { cases } = load_fixture("anthropic_stream.json");

    for ProviderCase {
        name,
        events,
        expected,
    } in cases
    {
        let produced = collect_events(&events);
        let actual: Vec<EventSummary> = produced.iter().map(summarize_event).collect();
        assert_eq!(actual, expected, "case {}", name);
    }
}
1642
/// A provider `error` event becomes a single `StreamEvent::Error` carrying
/// `StopReason::Error` and the provider's message.
#[test]
fn test_stream_error_event_maps_to_stop_reason_error() {
    let error_event = json!({
        "type": "error",
        "error": { "message": "nope" }
    });

    let out = collect_events(&[error_event]);
    assert_eq!(out.len(), 1);

    let StreamEvent::Error { reason, error } = &out[0] else {
        panic!("expected StreamEvent::Error, got {:?}", out[0]);
    };
    assert_eq!(*reason, StopReason::Error);
    assert_eq!(error.stop_reason, StopReason::Error);
    assert_eq!(error.error_message.as_deref(), Some("nope"));
}
1663
/// The transport closing after `message_stop` must not yield a second `Done`.
#[test]
fn test_stream_emits_single_done_when_transport_ends_after_message_stop() {
    let items = collect_stream_items_from_body(&success_sse_body());

    let mut done_count = 0usize;
    for item in &items {
        if let Ok(StreamEvent::Done { .. }) = item {
            done_count += 1;
        }
    }

    assert_eq!(done_count, 1, "expected exactly one terminal Done event");
}
1673
/// Nothing after an `error` event may be emitted, even if the body still
/// contains a `message_stop`.
#[test]
fn test_stream_error_event_is_terminal() {
    let body = concat!(
        r#"data: {"type":"error","error":{"message":"boom"}}"#,
        "\n\n",
        r#"data: {"type":"message_stop"}"#,
        "\n",
    );

    let items = collect_stream_items_from_body(body);

    assert_eq!(items.len(), 1, "Error should terminate the stream");
    assert!(matches!(
        items.first(),
        Some(Ok(StreamEvent::Error { .. }))
    ));
}
1689
/// An unparsable SSE payload surfaces as an `Err` item and ends the stream;
/// the `message_stop` that follows it must never be delivered.
#[test]
fn test_stream_parse_error_is_terminal() {
    let body = [
        r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
        "",
        r"data: {invalid-json}",
        "",
        r#"data: {"type":"message_stop"}"#,
        "",
    ]
    .join("\n");

    let out = collect_stream_items_from_body(&body);
    assert_eq!(out.len(), 2, "parse error should stop further events");
    assert!(matches!(out[0], Ok(StreamEvent::Start { .. })));
    match &out[1] {
        Ok(event) => panic!("expected a JSON parse error, got event {event:?}"),
        Err(err) => assert!(
            err.to_string().contains("JSON parse error"),
            "unexpected error text: {err}"
        ),
    }
}
1711
/// Splits a twelve-delta SSE stream into irregular byte fragments and checks
/// that deltas come out complete, in order, followed by exactly one `Done`.
#[test]
fn test_stream_fragmented_sse_transport_preserves_text_delta_order() {
    // "seg-00|" through "seg-11|".
    let response_parts: Vec<String> = (0..12).map(|idx| format!("seg-{idx:02}|")).collect();
    let expected_text = response_parts.concat();
    let part_refs: Vec<&str> = response_parts.iter().map(String::as_str).collect();

    let frames = build_text_stream_sse_frames(&part_refs);
    let out = collect_events_from_byte_chunks(split_ascii_stream_bytes(
        &frames,
        &[1, 2, 5, 3, 8, 13, 21],
    ));

    assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
    assert!(matches!(
        out.get(1),
        Some(StreamEvent::TextStart {
            content_index: 0,
            ..
        })
    ));

    let deltas = collect_text_deltas(&out);
    assert_eq!(deltas, response_parts);
    assert_eq!(deltas.concat(), expected_text);

    let final_text = out
        .iter()
        .find_map(|event| {
            if let StreamEvent::TextEnd { content, .. } = event {
                Some(content.clone())
            } else {
                None
            }
        })
        .expect("text_end event");
    assert_eq!(final_text, expected_text);

    let done_events = out
        .iter()
        .filter(|event| matches!(event, StreamEvent::Done { .. }));
    assert_eq!(done_events.count(), 1, "expected exactly one Done event");

    let other = out.last();
    let Some(StreamEvent::Done { reason, message }) = other else {
        panic!("expected final Done event, got {other:?}");
    };
    assert_eq!(*reason, StopReason::Stop);
    assert_eq!(message.stop_reason, StopReason::Stop);
}
1773
/// Same fragmentation check at higher volume: 128 deltas must come out as
/// 128 `TextDelta` events with unchanged content.
#[test]
fn test_stream_high_volume_fragmented_sse_preserves_delta_count_and_content() {
    let response_parts: Vec<String> = (0..128).map(|idx| format!("chunk-{idx:03}|")).collect();
    let expected_text = response_parts.concat();
    let part_refs: Vec<&str> = response_parts.iter().map(String::as_str).collect();

    let frames = build_text_stream_sse_frames(&part_refs);
    let out = collect_events_from_byte_chunks(split_ascii_stream_bytes(
        &frames,
        &[1, 1, 2, 3, 5, 8, 13, 21, 34],
    ));

    let deltas = collect_text_deltas(&out);
    assert_eq!(
        deltas.len(),
        response_parts.len(),
        "expected one TextDelta per text fragment"
    );
    assert_eq!(deltas, response_parts);
    assert_eq!(deltas.concat(), expected_text);

    let final_text = out
        .iter()
        .find_map(|event| {
            if let StreamEvent::TextEnd { content, .. } = event {
                Some(content.clone())
            } else {
                None
            }
        })
        .expect("text_end event");
    assert_eq!(final_text, expected_text);
}
1806
/// A plain API-key request sends `x-api-key` and `anthropic-version`, no beta
/// header, and a streaming body.
#[test]
fn test_stream_sets_required_headers() {
    let captured = run_stream_and_capture_headers(CacheRetention::None)
        .expect("captured request for required headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(header("x-api-key"), Some("sk-ant-test-key"));
    assert_eq!(header("anthropic-version"), Some(ANTHROPIC_API_VERSION));
    assert!(header("anthropic-beta").is_none());
    assert!(captured.body.contains("\"stream\":true"));
}
1825
/// Short cache retention adds the prompt-caching flag as the `anthropic-beta` header.
#[test]
fn test_stream_adds_prompt_caching_beta_header_when_enabled() {
    let captured = run_stream_and_capture_headers(CacheRetention::Short)
        .expect("captured request for beta header");

    let beta = captured.headers.get("anthropic-beta").map(String::as_str);

    assert_eq!(beta, Some("prompt-caching-2024-07-31"));
}
1835
/// An `sk-ant-oat` token is sent as a bearer token together with the OAuth
/// companion headers, and never as `x-api-key`.
#[test]
fn test_stream_uses_oauth_bearer_auth_headers() {
    let captured =
        run_stream_and_capture_headers_with_api_key(CacheRetention::None, "sk-ant-oat-test")
            .expect("captured request for oauth headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(header("authorization"), Some("Bearer sk-ant-oat-test"));
    assert!(header("x-api-key").is_none());
    assert_eq!(
        header("anthropic-dangerous-direct-browser-access"),
        Some("true")
    );
    assert_eq!(header("x-app"), Some("cli"));
    assert!(header("anthropic-beta").is_some_and(|value| value.contains("oauth-2025-04-20")));
    assert!(header("user-agent").is_some_and(|value| value.contains("pi_agent_rust/")));
}
1870
/// A token carrying the `__pi_anthropic_oauth_bearer__:` marker is sent as a
/// bearer token with the marker stripped, plus the OAuth beta flag.
#[test]
fn test_stream_uses_bearer_headers_for_marked_anthropic_oauth_token() {
    let captured = run_stream_and_capture_headers_with_api_key(
        CacheRetention::None,
        "__pi_anthropic_oauth_bearer__:sk-ant-api-like-token",
    )
    .expect("captured request for marked oauth headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(
        header("authorization"),
        Some("Bearer sk-ant-api-like-token")
    );
    assert!(header("x-api-key").is_none());
    assert!(header("anthropic-beta").is_some_and(|value| value.contains("oauth-2025-04-20")));
}
1888
/// For the Anthropic provider, a `claude-oauth-token` credential is sent as a
/// bearer token rather than as `x-api-key`.
#[test]
fn test_stream_claude_style_non_sk_token_uses_bearer_auth_headers() {
    let captured =
        run_stream_and_capture_headers_with_api_key(CacheRetention::None, "claude-oauth-token")
            .expect("captured request for claude bearer headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(header("authorization"), Some("Bearer claude-oauth-token"));
    assert!(header("x-api-key").is_none());
}
1900
/// A Kimi OAuth token is sent as a bearer token with the `x-msh-*` device
/// headers and without any Anthropic-specific OAuth headers.
#[test]
fn test_stream_kimi_oauth_uses_bearer_and_kimi_headers() {
    let captured = run_stream_and_capture_headers_for_provider_with_api_key(
        CacheRetention::None,
        "kimi-for-coding",
        "kimi-oauth-token",
    )
    .expect("captured request for kimi oauth headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(header("authorization"), Some("Bearer kimi-oauth-token"));
    assert_eq!(header("x-msh-platform"), Some("kimi_cli"));

    for absent in [
        "x-api-key",
        "anthropic-dangerous-direct-browser-access",
        "anthropic-beta",
    ] {
        assert!(header(absent).is_none(), "unexpected header {absent}");
    }

    for required in [
        "x-msh-version",
        "x-msh-device-name",
        "x-msh-device-model",
        "x-msh-os-version",
        "x-msh-device-id",
    ] {
        assert!(header(required).is_some(), "missing header {required}");
    }
}
1930
/// A Kimi `sk-` API key is sent as `x-api-key`, with no bearer or `x-msh-*` headers.
#[test]
fn test_stream_kimi_api_key_uses_x_api_key_header() {
    let captured = run_stream_and_capture_headers_for_provider_with_api_key(
        CacheRetention::None,
        "kimi-for-coding",
        "sk-kimi-api-key",
    )
    .expect("captured request for kimi api-key headers");
    let header = |name: &str| captured.headers.get(name).map(String::as_str);

    assert_eq!(header("x-api-key"), Some("sk-kimi-api-key"));
    assert!(header("authorization").is_none());
    assert!(header("x-msh-platform").is_none());
}
1946
/// With OAuth and short cache retention, one `anthropic-beta` header carries
/// both the OAuth flag and the prompt-caching flag.
#[test]
fn test_stream_oauth_beta_header_includes_prompt_caching_when_enabled() {
    let captured =
        run_stream_and_capture_headers_with_api_key(CacheRetention::Short, "sk-ant-oat-test")
            .expect("captured request for oauth + cache beta header");

    let beta = captured
        .headers
        .get("anthropic-beta")
        .map(String::as_str)
        .expect("anthropic-beta header");

    for flag in ["oauth-2025-04-20", "prompt-caching-2024-07-31"] {
        assert!(beta.contains(flag), "beta header {beta:?} lacks {flag}");
    }
}
1959
/// A non-success HTTP response makes `stream` fail with an error that names
/// the status code and includes the message from the JSON error body.
#[test]
fn test_stream_http_error_includes_status_and_body_message() {
    // `_rx` keeps the capture channel open so the server thread's send succeeds.
    let (base_url, _rx) = spawn_test_server(
        401,
        "application/json",
        r#"{"type":"error","error":{"type":"authentication_error","message":"Invalid API key"}}"#,
    );
    let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
    let context = Context {
        system_prompt: None,
        messages: vec![Message::User(crate::model::UserMessage {
            content: UserContent::Text("ping".to_string()),
            timestamp: 0,
        })]
        .into(),
        tools: Vec::new().into(),
    };
    let options = StreamOptions {
        api_key: Some("test-key".to_string()),
        ..Default::default()
    };

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    let result = runtime.block_on(async { provider.stream(&context, &options).await });
    let Err(err) = result else {
        panic!("expected stream() to fail for an HTTP 401 response");
    };
    let message = err.to_string();
    assert!(
        message.contains("Anthropic API error (HTTP 401)"),
        "unexpected error text: {message}"
    );
    assert!(
        message.contains("Invalid API key"),
        "unexpected error text: {message}"
    );
}
1993
/// `with_provider_name` changes what `name()` reports.
#[test]
fn test_provider_name_reflects_override() {
    const OVERRIDE: &str = "kimi-for-coding";

    let provider = AnthropicProvider::new("claude-test").with_provider_name(OVERRIDE);

    assert_eq!(provider.name(), OVERRIDE);
}
1999
/// HTTP request as recorded by the thread started in `spawn_test_server`.
#[derive(Debug)]
struct CapturedRequest {
    /// Request headers keyed by lower-cased name (see `parse_headers`).
    headers: HashMap<String, String>,
    /// Request body decoded lossily as UTF-8.
    body: String,
}
2005
2006 fn run_stream_and_capture_headers(cache_retention: CacheRetention) -> Option<CapturedRequest> {
2007 run_stream_and_capture_headers_with_api_key(cache_retention, "sk-ant-test-key")
2008 }
2009
2010 fn run_stream_and_capture_headers_with_api_key(
2011 cache_retention: CacheRetention,
2012 api_key: &str,
2013 ) -> Option<CapturedRequest> {
2014 run_stream_and_capture_headers_for_provider_with_api_key(
2015 cache_retention,
2016 "anthropic",
2017 api_key,
2018 )
2019 }
2020
2021 fn run_stream_and_capture_headers_for_provider_with_api_key(
2022 cache_retention: CacheRetention,
2023 provider_name: &str,
2024 api_key: &str,
2025 ) -> Option<CapturedRequest> {
2026 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
2027 let provider = AnthropicProvider::new("claude-test")
2028 .with_provider_name(provider_name)
2029 .with_base_url(base_url);
2030 let context = Context {
2031 system_prompt: Some("test system".to_string().into()),
2032 messages: vec![Message::User(crate::model::UserMessage {
2033 content: UserContent::Text("ping".to_string()),
2034 timestamp: 0,
2035 })]
2036 .into(),
2037 tools: Vec::new().into(),
2038 };
2039 let options = StreamOptions {
2040 api_key: Some(api_key.to_string()),
2041 cache_retention,
2042 ..Default::default()
2043 };
2044
2045 let runtime = RuntimeBuilder::current_thread()
2046 .build()
2047 .expect("runtime build");
2048 runtime.block_on(async {
2049 let mut stream = provider.stream(&context, &options).await.expect("stream");
2050 while let Some(event) = stream.next().await {
2051 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
2052 break;
2053 }
2054 }
2055 });
2056
2057 rx.recv_timeout(Duration::from_secs(2)).ok()
2058 }
2059
2060 fn collect_stream_items_from_body(body: &str) -> Vec<Result<StreamEvent>> {
2061 let (base_url, _rx) = spawn_test_server(200, "text/event-stream", body);
2062 let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
2063 let context = Context {
2064 system_prompt: Some("test system".to_string().into()),
2065 messages: vec![Message::User(crate::model::UserMessage {
2066 content: UserContent::Text("ping".to_string()),
2067 timestamp: 0,
2068 })]
2069 .into(),
2070 tools: Vec::new().into(),
2071 };
2072 let options = StreamOptions {
2073 api_key: Some("sk-ant-test-key".to_string()),
2074 ..Default::default()
2075 };
2076
2077 let runtime = RuntimeBuilder::current_thread()
2078 .build()
2079 .expect("runtime build");
2080 runtime.block_on(async {
2081 let mut stream = provider.stream(&context, &options).await.expect("stream");
2082 let mut items = Vec::new();
2083 while let Some(item) = stream.next().await {
2084 items.push(item);
2085 }
2086 items
2087 })
2088 }
2089
/// Minimal successful SSE body: `message_start`, `message_delta` with an
/// `end_turn` stop reason, then `message_stop`.
///
/// Events are separated by a blank line; the final event is followed by a
/// single newline only.
fn success_sse_body() -> String {
    const EVENTS: [&str; 3] = [
        r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
        r#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":1}}"#,
        r#"data: {"type":"message_stop"}"#,
    ];

    let mut body = String::new();
    for (position, event) in EVENTS.iter().enumerate() {
        if position > 0 {
            body.push('\n');
        }
        body.push_str(event);
        body.push('\n');
    }
    body
}
2101
2102 fn spawn_test_server(
2103 status_code: u16,
2104 content_type: &str,
2105 body: &str,
2106 ) -> (String, mpsc::Receiver<CapturedRequest>) {
2107 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
2108 let addr = listener.local_addr().expect("local addr");
2109 let (tx, rx) = mpsc::channel();
2110 let body = body.to_string();
2111 let content_type = content_type.to_string();
2112
2113 std::thread::spawn(move || {
2114 let (mut socket, _) = listener.accept().expect("accept");
2115 socket
2116 .set_read_timeout(Some(Duration::from_secs(2)))
2117 .expect("set read timeout");
2118
2119 let mut bytes = Vec::new();
2120 let mut chunk = [0_u8; 4096];
2121 loop {
2122 match socket.read(&mut chunk) {
2123 Ok(0) => break,
2124 Ok(n) => {
2125 bytes.extend_from_slice(&chunk[..n]);
2126 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
2127 break;
2128 }
2129 }
2130 Err(err)
2131 if err.kind() == std::io::ErrorKind::WouldBlock
2132 || err.kind() == std::io::ErrorKind::TimedOut =>
2133 {
2134 break;
2135 }
2136 Err(err) => panic!(),
2137 }
2138 }
2139
2140 let header_end = bytes
2141 .windows(4)
2142 .position(|window| window == b"\r\n\r\n")
2143 .expect("request header boundary");
2144 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
2145 let headers = parse_headers(&header_text);
2146 let mut request_body = bytes[header_end + 4..].to_vec();
2147
2148 let content_length = headers
2149 .get("content-length")
2150 .and_then(|value| value.parse::<usize>().ok())
2151 .unwrap_or(0);
2152 while request_body.len() < content_length {
2153 match socket.read(&mut chunk) {
2154 Ok(0) => break,
2155 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
2156 Err(err)
2157 if err.kind() == std::io::ErrorKind::WouldBlock
2158 || err.kind() == std::io::ErrorKind::TimedOut =>
2159 {
2160 break;
2161 }
2162 Err(err) => panic!(),
2163 }
2164 }
2165
2166 let captured = CapturedRequest {
2167 headers,
2168 body: String::from_utf8_lossy(&request_body).to_string(),
2169 };
2170 tx.send(captured).expect("send captured request");
2171
2172 let reason = match status_code {
2173 401 => "Unauthorized",
2174 500 => "Internal Server Error",
2175 _ => "OK",
2176 };
2177 let response = format!(
2178 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
2179 body.len()
2180 );
2181 socket
2182 .write_all(response.as_bytes())
2183 .expect("write response");
2184 socket.flush().expect("flush response");
2185 });
2186
2187 (format!("http://{addr}/messages"), rx)
2188 }
2189
/// Parses the header block of an HTTP request into a map.
///
/// The first line (the request line) is skipped. Each remaining line is split
/// at its first `:`; names are trimmed and lower-cased, values are trimmed.
/// Lines without a colon are ignored, and a repeated name keeps its last value.
fn parse_headers(header_text: &str) -> HashMap<String, String> {
    header_text
        .lines()
        .skip(1)
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_ascii_lowercase(), value.trim().to_string()))
        .collect()
}
2199
2200 fn load_fixture(file_name: &str) -> ProviderFixture {
2201 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
2202 .join("tests/fixtures/provider_responses")
2203 .join(file_name);
2204 let raw = std::fs::read_to_string(path).expect("fixture read");
2205 serde_json::from_str(&raw).expect("fixture parse")
2206 }
2207
2208 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
2209 let runtime = RuntimeBuilder::current_thread()
2210 .build()
2211 .expect("runtime build");
2212 runtime.block_on(async move {
2213 let byte_stream = stream::iter(
2214 events
2215 .iter()
2216 .map(|event| {
2217 let data = match event {
2218 Value::String(text) => text.clone(),
2219 _ => serde_json::to_string(event).expect("serialize event"),
2220 };
2221 format!("data: {data}\n\n").into_bytes()
2222 })
2223 .map(Ok),
2224 );
2225 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
2226 let mut state = StreamState::new(
2227 event_source,
2228 "claude-test".to_string(),
2229 "anthropic-messages".to_string(),
2230 "anthropic".to_string(),
2231 );
2232 let mut out = Vec::new();
2233
2234 while let Some(item) = state.event_source.next().await {
2235 let msg = item.expect("SSE event");
2236 if msg.event == "ping" {
2237 continue;
2238 }
2239 if let Some(event) = state.process_event(&msg.data).expect("process_event") {
2240 out.push(event);
2241 }
2242 }
2243
2244 out
2245 })
2246 }
2247
2248 fn collect_events_from_byte_chunks(chunks: Vec<Vec<u8>>) -> Vec<StreamEvent> {
2249 let runtime = RuntimeBuilder::current_thread()
2250 .build()
2251 .expect("runtime build");
2252 runtime.block_on(async move {
2253 let byte_stream = stream::iter(chunks.into_iter().map(Ok));
2254 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
2255 let mut state = StreamState::new(
2256 event_source,
2257 "claude-test".to_string(),
2258 "anthropic-messages".to_string(),
2259 "anthropic".to_string(),
2260 );
2261 let mut out = Vec::new();
2262
2263 while let Some(item) = state.event_source.next().await {
2264 let msg = item.expect("SSE event");
2265 if msg.event == "ping" {
2266 continue;
2267 }
2268 if let Some(event) = state.process_event(&msg.data).expect("process_event") {
2269 out.push(event);
2270 }
2271 }
2272
2273 out
2274 })
2275 }
2276
2277 fn build_text_stream_sse_frames(text_parts: &[&str]) -> Vec<String> {
2278 let message_start = json!({
2279 "type": "message_start",
2280 "message": {
2281 "usage": {
2282 "input_tokens": 10,
2283 "cache_creation_input_tokens": 0,
2284 "cache_read_input_tokens": 0,
2285 "output_tokens": 1
2286 }
2287 }
2288 });
2289 let content_start = json!({
2290 "type": "content_block_start",
2291 "index": 0,
2292 "content_block": { "type": "text" }
2293 });
2294 let content_stop = json!({
2295 "type": "content_block_stop",
2296 "index": 0
2297 });
2298 let message_delta = json!({
2299 "type": "message_delta",
2300 "delta": { "stop_reason": "end_turn" },
2301 "usage": { "output_tokens": text_parts.len().max(1) }
2302 });
2303
2304 let mut frames = vec![
2305 format!("event: message_start\ndata: {message_start}\n\n"),
2306 format!("event: content_block_start\ndata: {content_start}\n\n"),
2307 ];
2308
2309 for (idx, text) in text_parts.iter().enumerate() {
2310 if idx % 4 == 1 {
2311 frames.push("event: ping\ndata: {\"type\":\"ping\"}\n\n".to_string());
2312 }
2313 let content_delta = json!({
2314 "type": "content_block_delta",
2315 "index": 0,
2316 "delta": { "type": "text_delta", "text": text }
2317 });
2318 frames.push(format!(
2319 "event: content_block_delta\ndata: {content_delta}\n\n"
2320 ));
2321 }
2322
2323 frames.push(format!(
2324 "event: content_block_stop\ndata: {content_stop}\n\n"
2325 ));
2326 frames.push(format!("event: message_delta\ndata: {message_delta}\n\n"));
2327 frames.push("event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n".to_string());
2328 frames
2329 }
2330
/// Concatenates `frames` and cuts the bytes into chunks whose lengths cycle
/// through `fragment_sizes`.
///
/// The final chunk is shortened to whatever bytes remain. A size larger than
/// the remaining input (up to `usize::MAX`) simply takes the rest; no offset
/// arithmetic is done, so oversized fragments cannot overflow.
///
/// # Panics
///
/// Panics if `fragment_sizes` is empty or contains a zero, or if the
/// concatenated frames are not pure ASCII.
fn split_ascii_stream_bytes(frames: &[String], fragment_sizes: &[usize]) -> Vec<Vec<u8>> {
    assert!(
        !fragment_sizes.is_empty(),
        "fragment_sizes must contain at least one size"
    );
    assert!(
        fragment_sizes.iter().all(|size| *size > 0),
        "fragment_sizes must be positive"
    );

    let joined = frames.concat();
    assert!(
        joined.is_ascii(),
        "test-only chunk fragmentation expects ASCII SSE fixtures"
    );

    let mut remaining = joined.as_bytes();
    let mut chunks = Vec::new();
    // `cycle` never ends on a non-empty slice; the loop stops once the input is consumed.
    for &size in fragment_sizes.iter().cycle() {
        if remaining.is_empty() {
            break;
        }
        let (head, tail) = remaining.split_at(size.min(remaining.len()));
        chunks.push(head.to_vec());
        remaining = tail;
    }
    chunks
}
2360
2361 fn collect_text_deltas(events: &[StreamEvent]) -> Vec<String> {
2362 events
2363 .iter()
2364 .filter_map(|event| match event {
2365 StreamEvent::TextDelta { delta, .. } => Some(delta.clone()),
2366 _ => None,
2367 })
2368 .collect()
2369 }
2370
2371 fn summarize_event(event: &StreamEvent) -> EventSummary {
2372 match event {
2373 StreamEvent::Start { .. } => EventSummary {
2374 kind: "start".to_string(),
2375 content_index: None,
2376 delta: None,
2377 content: None,
2378 reason: None,
2379 },
2380 StreamEvent::TextStart { content_index, .. } => EventSummary {
2381 kind: "text_start".to_string(),
2382 content_index: Some(*content_index),
2383 delta: None,
2384 content: None,
2385 reason: None,
2386 },
2387 StreamEvent::TextDelta {
2388 content_index,
2389 delta,
2390 ..
2391 } => EventSummary {
2392 kind: "text_delta".to_string(),
2393 content_index: Some(*content_index),
2394 delta: Some(delta.clone()),
2395 content: None,
2396 reason: None,
2397 },
2398 StreamEvent::TextEnd {
2399 content_index,
2400 content,
2401 ..
2402 } => EventSummary {
2403 kind: "text_end".to_string(),
2404 content_index: Some(*content_index),
2405 delta: None,
2406 content: Some(content.clone()),
2407 reason: None,
2408 },
2409 StreamEvent::Done { reason, .. } => EventSummary {
2410 kind: "done".to_string(),
2411 content_index: None,
2412 delta: None,
2413 content: None,
2414 reason: Some(reason_to_string(*reason)),
2415 },
2416 StreamEvent::Error { reason, .. } => EventSummary {
2417 kind: "error".to_string(),
2418 content_index: None,
2419 delta: None,
2420 content: None,
2421 reason: Some(reason_to_string(*reason)),
2422 },
2423 _ => EventSummary {
2424 kind: "other".to_string(),
2425 content_index: None,
2426 delta: None,
2427 content: None,
2428 reason: None,
2429 },
2430 }
2431 }
2432
2433 fn reason_to_string(reason: StopReason) -> String {
2434 match reason {
2435 StopReason::Stop => "stop",
2436 StopReason::Length => "length",
2437 StopReason::ToolUse => "tool_use",
2438 StopReason::Error => "error",
2439 StopReason::Aborted => "aborted",
2440 }
2441 .to_string()
2442 }
2443
/// Headers configured through `CompatConfig::custom_headers` must appear on the captured
/// request, and `x-api-key` must still carry the key passed via `StreamOptions::api_key`.
#[test]
fn test_compat_custom_headers_injected_into_request() {
    let sse_body = success_sse_body();
    let (base_url, request_rx) = spawn_test_server(200, "text/event-stream", &sse_body);

    let extra_headers = HashMap::from([
        ("X-Custom-Tag".to_string(), "anthropic-override".to_string()),
        ("X-Routing-Hint".to_string(), "us-east-1".to_string()),
    ]);
    let provider = AnthropicProvider::new("claude-test")
        .with_base_url(base_url)
        .with_compat(Some(crate::models::CompatConfig {
            custom_headers: Some(extra_headers),
            ..Default::default()
        }));

    let options = StreamOptions {
        api_key: Some("sk-ant-test-key".to_string()),
        ..Default::default()
    };
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context {
        system_prompt: Some("test".to_string().into()),
        messages: vec![user_turn].into(),
        tools: Vec::new().into(),
    };

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    // Drive the stream until the terminal `Done` event.
    runtime.block_on(async {
        let mut events = provider.stream(&context, &options).await.expect("stream");
        while let Some(item) = events.next().await {
            if let StreamEvent::Done { .. } = item.expect("stream event") {
                break;
            }
        }
    });

    let request = request_rx
        .recv_timeout(Duration::from_secs(2))
        .expect("captured request");
    let sent = &request.headers;
    assert_eq!(
        sent.get("x-custom-tag").map(String::as_str),
        Some("anthropic-override"),
        "compat custom header X-Custom-Tag missing"
    );
    assert_eq!(
        sent.get("x-routing-hint").map(String::as_str),
        Some("us-east-1"),
        "compat custom header X-Routing-Hint missing"
    );
    // The regular API-key header is expected next to the custom ones.
    assert_eq!(
        sent.get("x-api-key").map(String::as_str),
        Some("sk-ant-test-key"),
    );
}
2507
/// With no API key in `StreamOptions`, an `Authorization` header supplied through compat
/// custom headers must be sent unchanged, `x-api-key` must be absent, and the request
/// must carry `anthropic-dangerous-direct-browser-access: true`, `x-app: cli` and an
/// `anthropic-beta` value containing `oauth-2025-04-20`.
#[test]
fn test_compat_authorization_header_works_without_api_key() {
    let sse_body = success_sse_body();
    let (base_url, request_rx) = spawn_test_server(200, "text/event-stream", &sse_body);

    let auth_override = HashMap::from([(
        "Authorization".to_string(),
        "Bearer sk-ant-oat-compat".to_string(),
    )]);
    let compat = crate::models::CompatConfig {
        custom_headers: Some(auth_override),
        ..Default::default()
    };
    let provider = AnthropicProvider::new("claude-test")
        .with_base_url(base_url)
        .with_compat(Some(compat));

    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context {
        system_prompt: Some("test".to_string().into()),
        messages: vec![user_turn].into(),
        tools: Vec::new().into(),
    };
    // Deliberately no api_key: the compat header is the only credential.
    let options = StreamOptions::default();

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    runtime.block_on(async {
        let mut events = provider.stream(&context, &options).await.expect("stream");
        while let Some(item) = events.next().await {
            if let StreamEvent::Done { .. } = item.expect("stream event") {
                break;
            }
        }
    });

    let request = request_rx
        .recv_timeout(Duration::from_secs(2))
        .expect("captured request");
    let sent = &request.headers;
    assert_eq!(
        sent.get("authorization").map(String::as_str),
        Some("Bearer sk-ant-oat-compat")
    );
    assert!(!sent.contains_key("x-api-key"));
    assert_eq!(
        sent.get("anthropic-dangerous-direct-browser-access")
            .map(String::as_str),
        Some("true")
    );
    assert_eq!(sent.get("x-app").map(String::as_str), Some("cli"));
    assert!(
        sent.get("anthropic-beta")
            .is_some_and(|flags| flags.contains("oauth-2025-04-20"))
    );
}
2575
/// When compat supplies `Authorization` and the per-request headers supply `X-API-Key`,
/// both headers must be sent as given, and the request must still carry
/// `anthropic-dangerous-direct-browser-access: true`, `x-app: cli` and an
/// `anthropic-beta` value containing `oauth-2025-04-20`.
#[test]
fn test_authorization_override_wins_side_effects_over_x_api_key_override() {
    let sse_body = success_sse_body();
    let (base_url, request_rx) = spawn_test_server(200, "text/event-stream", &sse_body);

    let auth_override = HashMap::from([(
        "Authorization".to_string(),
        "Bearer sk-ant-oat-compat".to_string(),
    )]);
    let compat = crate::models::CompatConfig {
        custom_headers: Some(auth_override),
        ..Default::default()
    };
    let provider = AnthropicProvider::new("claude-test")
        .with_base_url(base_url)
        .with_compat(Some(compat));

    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context {
        system_prompt: Some("test".to_string().into()),
        messages: vec![user_turn].into(),
        tools: Vec::new().into(),
    };
    // The competing credential arrives via per-request headers rather than compat.
    let options = StreamOptions {
        headers: HashMap::from([("X-API-Key".to_string(), "header-ant-key".to_string())]),
        ..Default::default()
    };

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    runtime.block_on(async {
        let mut events = provider.stream(&context, &options).await.expect("stream");
        while let Some(item) = events.next().await {
            if let StreamEvent::Done { .. } = item.expect("stream event") {
                break;
            }
        }
    });

    let request = request_rx
        .recv_timeout(Duration::from_secs(2))
        .expect("captured request");
    let sent = &request.headers;
    assert_eq!(
        sent.get("authorization").map(String::as_str),
        Some("Bearer sk-ant-oat-compat")
    );
    assert_eq!(
        sent.get("x-api-key").map(String::as_str),
        Some("header-ant-key")
    );
    assert_eq!(
        sent.get("anthropic-dangerous-direct-browser-access")
            .map(String::as_str),
        Some("true")
    );
    assert_eq!(sent.get("x-app").map(String::as_str), Some("cli"));
    assert!(
        sent.get("anthropic-beta")
            .is_some_and(|flags| flags.contains("oauth-2025-04-20"))
    );
}
2654
/// An `X-API-Key` entry in `StreamOptions::headers` must be enough to issue a request
/// when no `api_key` is set: the header is sent as given and no `authorization` header
/// is added.
#[test]
fn test_stream_option_x_api_key_header_works_without_api_key() {
    let sse_body = success_sse_body();
    let (base_url, request_rx) = spawn_test_server(200, "text/event-stream", &sse_body);
    let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);

    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context {
        system_prompt: Some("test".to_string().into()),
        messages: vec![user_turn].into(),
        tools: Vec::new().into(),
    };
    let options = StreamOptions {
        headers: HashMap::from([("X-API-Key".to_string(), "header-ant-key".to_string())]),
        ..Default::default()
    };

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    runtime.block_on(async {
        let mut events = provider.stream(&context, &options).await.expect("stream");
        while let Some(item) = events.next().await {
            if let StreamEvent::Done { .. } = item.expect("stream event") {
                break;
            }
        }
    });

    let request = request_rx
        .recv_timeout(Duration::from_secs(2))
        .expect("captured request");
    let sent = &request.headers;
    assert_eq!(
        sent.get("x-api-key").map(String::as_str),
        Some("header-ant-key")
    );
    assert!(!sent.contains_key("authorization"));
}
2702
/// With `with_compat(None)` the request must still carry the `x-api-key` taken from
/// `StreamOptions::api_key` and must not contain an `x-custom-tag` header.
#[test]
fn test_compat_none_does_not_affect_headers() {
    let sse_body = success_sse_body();
    let (base_url, request_rx) = spawn_test_server(200, "text/event-stream", &sse_body);

    let provider = AnthropicProvider::new("claude-test")
        .with_base_url(base_url)
        .with_compat(None);

    let options = StreamOptions {
        api_key: Some("sk-ant-test-key".to_string()),
        ..Default::default()
    };
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context {
        system_prompt: Some("test".to_string().into()),
        messages: vec![user_turn].into(),
        tools: Vec::new().into(),
    };

    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("runtime build");
    runtime.block_on(async {
        let mut events = provider.stream(&context, &options).await.expect("stream");
        while let Some(item) = events.next().await {
            if let StreamEvent::Done { .. } = item.expect("stream event") {
                break;
            }
        }
    });

    let request = request_rx
        .recv_timeout(Duration::from_secs(2))
        .expect("captured request");
    let sent = &request.headers;
    // Standard authentication header is unaffected by the absent compat config.
    assert_eq!(
        sent.get("x-api-key").map(String::as_str),
        Some("sk-ant-test-key"),
    );
    assert!(
        !sent.contains_key("x-custom-tag"),
        "No custom headers should be present with compat=None"
    );
}
2750
2751 mod proptest_process_event {
2756 use super::*;
2757 use proptest::prelude::*;
2758
2759 fn make_state()
2760 -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
2761 {
2762 let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
2763 let sse = crate::sse::SseStream::new(Box::pin(empty));
2764 StreamState::new(
2765 sse,
2766 "claude-test".into(),
2767 "anthropic-messages".into(),
2768 "anthropic".into(),
2769 )
2770 }
2771
2772 fn small_string() -> impl Strategy<Value = String> {
2773 prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
2774 }
2775
2776 fn optional_string() -> impl Strategy<Value = Option<String>> {
2777 prop_oneof![Just(None), small_string().prop_map(Some),]
2778 }
2779
2780 fn token_count() -> impl Strategy<Value = u64> {
2781 prop_oneof![
2782 5 => 0u64..10_000u64,
2783 2 => Just(0u64),
2784 1 => Just(u64::MAX),
2785 1 => (u64::MAX - 100)..=u64::MAX,
2786 ]
2787 }
2788
2789 fn block_type() -> impl Strategy<Value = String> {
2790 prop_oneof![
2791 Just("text".to_string()),
2792 Just("thinking".to_string()),
2793 Just("tool_use".to_string()),
2794 Just("unknown_block_type".to_string()),
2795 "[a-z_]{1,12}",
2796 ]
2797 }
2798
2799 fn delta_type() -> impl Strategy<Value = String> {
2800 prop_oneof![
2801 Just("text_delta".to_string()),
2802 Just("thinking_delta".to_string()),
2803 Just("input_json_delta".to_string()),
2804 Just("signature_delta".to_string()),
2805 Just("unknown_delta".to_string()),
2806 "[a-z_]{1,16}",
2807 ]
2808 }
2809
2810 fn content_index() -> impl Strategy<Value = u32> {
2811 prop_oneof![
2812 5 => 0u32..5u32,
2813 2 => Just(0u32),
2814 1 => Just(u32::MAX),
2815 1 => 1000u32..2000u32,
2816 ]
2817 }
2818
2819 fn stop_reason_str() -> impl Strategy<Value = String> {
2820 prop_oneof![
2821 Just("end_turn".to_string()),
2822 Just("max_tokens".to_string()),
2823 Just("tool_use".to_string()),
2824 Just("stop_sequence".to_string()),
2825 Just("unknown_reason".to_string()),
2826 "[a-z_]{1,12}",
2827 ]
2828 }
2829
2830 fn anthropic_event_json() -> impl Strategy<Value = String> {
2833 prop_oneof![
2834 3 => token_count().prop_flat_map(|input| {
2836 (Just(input), token_count(), token_count()).prop_map(
2837 move |(cache_read, cache_write, _)| {
2838 serde_json::json!({
2839 "type": "message_start",
2840 "message": {
2841 "usage": {
2842 "input_tokens": input,
2843 "cache_read_input_tokens": cache_read,
2844 "cache_creation_input_tokens": cache_write
2845 }
2846 }
2847 })
2848 .to_string()
2849 },
2850 )
2851 }),
2852 1 => Just(r#"{"type":"message_start","message":{}}"#.to_string()),
2854 3 => (content_index(), block_type(), optional_string(), optional_string())
2856 .prop_map(|(idx, bt, id, name)| {
2857 let mut block = serde_json::json!({"type": bt});
2858 if let Some(id) = id {
2859 block["id"] = serde_json::Value::String(id);
2860 }
2861 if let Some(name) = name {
2862 block["name"] = serde_json::Value::String(name);
2863 }
2864 serde_json::json!({
2865 "type": "content_block_start",
2866 "index": idx,
2867 "content_block": block
2868 })
2869 .to_string()
2870 }),
2871 3 => (content_index(), delta_type(), optional_string(), optional_string(), optional_string(), optional_string())
2873 .prop_map(|(idx, dt, text, thinking, partial_json, sig)| {
2874 let mut delta = serde_json::json!({"type": dt});
2875 if let Some(t) = text { delta["text"] = serde_json::Value::String(t); }
2876 if let Some(t) = thinking { delta["thinking"] = serde_json::Value::String(t); }
2877 if let Some(p) = partial_json { delta["partial_json"] = serde_json::Value::String(p); }
2878 if let Some(s) = sig { delta["signature"] = serde_json::Value::String(s); }
2879 serde_json::json!({
2880 "type": "content_block_delta",
2881 "index": idx,
2882 "delta": delta
2883 })
2884 .to_string()
2885 }),
2886 2 => content_index().prop_map(|idx| {
2888 serde_json::json!({"type": "content_block_stop", "index": idx}).to_string()
2889 }),
2890 2 => (stop_reason_str(), token_count()).prop_map(|(sr, out)| {
2892 serde_json::json!({
2893 "type": "message_delta",
2894 "delta": {"stop_reason": sr},
2895 "usage": {"output_tokens": out}
2896 })
2897 .to_string()
2898 }),
2899 1 => stop_reason_str().prop_map(|sr| {
2901 serde_json::json!({
2902 "type": "message_delta",
2903 "delta": {"stop_reason": sr}
2904 })
2905 .to_string()
2906 }),
2907 2 => Just(r#"{"type":"message_stop"}"#.to_string()),
2909 2 => small_string().prop_map(|msg| {
2911 serde_json::json!({"type": "error", "error": {"message": msg}}).to_string()
2912 }),
2913 2 => Just(r#"{"type":"ping"}"#.to_string()),
2915 ]
2916 }
2917
2918 fn chaos_json() -> impl Strategy<Value = String> {
2920 prop_oneof![
2921 Just(String::new()),
2923 Just("{}".to_string()),
2924 Just("[]".to_string()),
2925 Just("null".to_string()),
2926 Just("true".to_string()),
2927 Just("42".to_string()),
2928 Just("{".to_string()),
2930 Just(r#"{"type":}"#.to_string()),
2931 Just(r#"{"type":null}"#.to_string()),
2932 "[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
2934 "[ -~]{0,64}",
2936 Just(r#"{"type":"message_start"}"#.to_string()),
2938 Just(r#"{"type":"content_block_delta"}"#.to_string()),
2939 Just(r#"{"type":"error"}"#.to_string()),
2940 ]
2941 }
2942
2943 proptest! {
2944 #![proptest_config(ProptestConfig {
2945 cases: 256,
2946 max_shrink_iters: 100,
2947 .. ProptestConfig::default()
2948 })]
2949
2950 #[test]
2951 fn process_event_valid_never_panics(data in anthropic_event_json()) {
2952 let mut state = make_state();
2953 let _ = state.process_event(&data);
2954 }
2955
2956 #[test]
2957 fn process_event_chaos_never_panics(data in chaos_json()) {
2958 let mut state = make_state();
2959 let _ = state.process_event(&data);
2960 }
2961
2962 #[test]
2963 fn process_event_sequence_never_panics(
2964 events in prop::collection::vec(anthropic_event_json(), 1..8)
2965 ) {
2966 let mut state = make_state();
2967 for event in &events {
2968 let _ = state.process_event(event);
2969 }
2970 }
2971 }
2972 }
2973}
2974
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Entry points for fuzz harnesses, compiled only when the `fuzzing` feature is
    //! enabled.

    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte stream type backing the fuzz processor: a pinned, boxed empty stream.
    type FuzzStream = Pin<Box<stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Wrapper around `StreamState` that exposes `process_event` publicly.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Creates a processor whose stream state is backed by an empty byte stream.
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(source),
                "claude-fuzz".into(),
                "anthropic-messages".into(),
                "anthropic".into(),
            );
            Self(state)
        }

        /// Feeds one SSE `data` payload to the wrapped state and returns the emitted
        /// event, if any, as a vector of zero or one elements.
        ///
        /// # Errors
        ///
        /// Propagates any error returned by the underlying `process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let emitted = self.0.process_event(data)?;
            Ok(emitted.into_iter().collect())
        }
    }
}