1use crate::auth::unmark_anthropic_oauth_bearer_token;
7use crate::error::{Error, Result};
8use crate::http::client::Client;
9use crate::model::{
10 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
11 ThinkingLevel, ToolCall, Usage, UserContent,
12};
13use crate::models::CompatConfig;
14use crate::provider::{CacheRetention, Context, Provider, StreamOptions, ToolDef};
15use crate::provider_metadata::canonical_provider_id;
16use crate::sse::SseStream;
17use async_trait::async_trait;
18use futures::StreamExt;
19use futures::stream::{self, Stream};
20use serde::{Deserialize, Serialize};
21use std::fs;
22use std::pin::Pin;
23
/// Endpoint used as the default `base_url` by `AnthropicProvider::new`.
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
/// `max_tokens` used when the stream options do not specify one.
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Marker string `is_anthropic_oauth_token` looks for to recognize Anthropic OAuth tokens.
const ANTHROPIC_OAUTH_TOKEN_PREFIX: &str = "sk-ant-oat";
/// Comma-separated beta flags added to `anthropic-beta` when the request uses bearer auth.
const ANTHROPIC_OAUTH_BETA_FLAGS: &str = "claude-code-20250219,oauth-2025-04-20";
/// Environment variable that overrides the Kimi share directory location.
const KIMI_SHARE_DIR_ENV_KEY: &str = "KIMI_SHARE_DIR";
34
35#[inline]
36fn is_anthropic_oauth_token(token: &str) -> bool {
37 token.contains(ANTHROPIC_OAUTH_TOKEN_PREFIX)
38}
39
40#[inline]
41fn is_anthropic_provider(provider: &str) -> bool {
42 canonical_provider_id(provider).unwrap_or(provider) == "anthropic"
43}
44
45#[inline]
46fn is_anthropic_bearer_token(provider: &str, token: &str) -> bool {
47 if !is_anthropic_provider(provider) {
48 return false;
49 }
50 let token = token.trim();
51 if token.is_empty() {
52 return false;
53 }
54
55 if is_anthropic_oauth_token(token) {
57 return true;
58 }
59
60 !token.starts_with("sk-ant-")
62}
63
64#[inline]
65fn is_kimi_coding_provider(provider: &str) -> bool {
66 canonical_provider_id(provider).unwrap_or(provider) == "kimi-for-coding"
67}
68
69#[inline]
70fn is_kimi_oauth_token(provider: &str, token: &str) -> bool {
71 is_kimi_coding_provider(provider) && !token.starts_with("sk-")
72}
73
/// Produces an ASCII-only header value from `value`, or `fallback` if nothing
/// usable remains.
///
/// * A value that is already pure ASCII and not blank is returned unchanged
///   (surrounding whitespace included).
/// * Otherwise non-ASCII characters are dropped and the remainder is trimmed.
/// * If the remainder is empty, `fallback` is returned.
fn sanitize_ascii_header_value(value: &str, fallback: &str) -> String {
    let has_visible_content = !value.trim().is_empty();
    if has_visible_content && value.is_ascii() {
        return value.to_owned();
    }

    let ascii_only: String = value.chars().filter(|c| c.is_ascii()).collect();
    match ascii_only.trim() {
        "" => fallback.to_owned(),
        kept => kept.to_owned(),
    }
}
90
/// Resolves the user's home directory using the supplied environment lookup.
///
/// Variables are consulted in this order, each trimmed and ignored when blank:
/// `HOME`, then `USERPROFILE`, then `HOMEDRIVE` combined with `HOMEPATH`.
/// When `HOMEPATH` starts with a path separator it is appended to the drive
/// textually; otherwise it is pushed onto the drive as a path component.
/// Returns `None` when no candidate is available.
fn home_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
where
    F: Fn(&str) -> Option<String>,
{
    use std::path::PathBuf;

    // Reads a variable, trims it, and treats a blank value as unset.
    let read = |key: &str| -> Option<String> {
        let raw = env_lookup(key)?;
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    };

    if let Some(home) = read("HOME") {
        return Some(PathBuf::from(home));
    }
    if let Some(profile) = read("USERPROFILE") {
        return Some(PathBuf::from(profile));
    }

    let drive = read("HOMEDRIVE")?;
    let path = read("HOMEPATH")?;
    if matches!(path.chars().next(), Some('\\' | '/')) {
        Some(PathBuf::from(format!("{drive}{path}")))
    } else {
        let mut combined = PathBuf::from(drive);
        combined.push(path);
        Some(combined)
    }
}
121
122fn home_dir() -> Option<std::path::PathBuf> {
123 home_dir_with_env_lookup(|key| std::env::var(key).ok())
124}
125
126fn kimi_share_dir_with_env_lookup<F>(env_lookup: F) -> Option<std::path::PathBuf>
127where
128 F: Fn(&str) -> Option<String>,
129{
130 env_lookup(KIMI_SHARE_DIR_ENV_KEY)
131 .map(|value| value.trim().to_string())
132 .filter(|value| !value.is_empty())
133 .map(std::path::PathBuf::from)
134 .or_else(|| home_dir_with_env_lookup(env_lookup).map(|home| home.join(".kimi")))
135}
136
137fn kimi_share_dir() -> Option<std::path::PathBuf> {
138 kimi_share_dir_with_env_lookup(|key| std::env::var(key).ok())
139}
140
141fn kimi_device_id_paths() -> Option<(std::path::PathBuf, std::path::PathBuf)> {
142 let primary = kimi_share_dir()?.join("device_id");
143 let legacy = home_dir().map_or_else(
144 || primary.clone(),
145 |home| home.join(".pi").join("agent").join("kimi-device-id"),
146 );
147 Some((primary, legacy))
148}
149
150fn kimi_device_id() -> String {
151 let generated = uuid::Uuid::new_v4().simple().to_string();
152 let Some((primary, legacy)) = kimi_device_id_paths() else {
153 return generated;
154 };
155
156 for path in [&primary, &legacy] {
157 if let Ok(existing) = fs::read_to_string(path) {
158 let existing = existing.trim();
159 if !existing.is_empty() {
160 return existing.to_string();
161 }
162 }
163 }
164
165 if let Some(parent) = primary.parent() {
166 let _ = fs::create_dir_all(parent);
167 }
168
169 let mut options = fs::OpenOptions::new();
170 options.write(true).create(true).truncate(true);
171
172 #[cfg(unix)]
173 {
174 use std::os::unix::fs::OpenOptionsExt;
175 options.mode(0o600);
176 }
177
178 if let Ok(mut file) = options.open(&primary) {
179 use std::io::Write;
180 let _ = file.write_all(generated.as_bytes());
181 }
182
183 generated
184}
185
186fn kimi_common_headers() -> Vec<(String, String)> {
187 let device_name = std::env::var("HOSTNAME")
188 .ok()
189 .or_else(|| std::env::var("COMPUTERNAME").ok())
190 .unwrap_or_else(|| "unknown".to_string());
191 let device_model = format!("{} {}", std::env::consts::OS, std::env::consts::ARCH);
192 let os_version = std::env::consts::OS.to_string();
193
194 vec![
195 (
196 "X-Msh-Platform".to_string(),
197 sanitize_ascii_header_value("kimi_cli", "unknown"),
198 ),
199 (
200 "X-Msh-Version".to_string(),
201 sanitize_ascii_header_value(env!("CARGO_PKG_VERSION"), "unknown"),
202 ),
203 (
204 "X-Msh-Device-Name".to_string(),
205 sanitize_ascii_header_value(&device_name, "unknown"),
206 ),
207 (
208 "X-Msh-Device-Model".to_string(),
209 sanitize_ascii_header_value(&device_model, "unknown"),
210 ),
211 (
212 "X-Msh-Os-Version".to_string(),
213 sanitize_ascii_header_value(&os_version, "unknown"),
214 ),
215 (
216 "X-Msh-Device-Id".to_string(),
217 sanitize_ascii_header_value(&kimi_device_id(), "unknown"),
218 ),
219 ]
220}
221
/// Provider that streams completions from an Anthropic-Messages-style endpoint.
///
/// Built with [`AnthropicProvider::new`] and customized through the `with_*`
/// builder methods.
pub struct AnthropicProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Model identifier placed in the request body.
    model: String,
    /// URL the request is POSTed to.
    base_url: String,
    /// Provider name; it selects the authentication scheme and is reported by `name()`.
    provider: String,
    /// Optional compatibility configuration; its custom headers are added to requests.
    compat: Option<CompatConfig>,
}
234
235impl AnthropicProvider {
236 pub fn new(model: impl Into<String>) -> Self {
238 Self {
239 client: Client::new(),
240 model: model.into(),
241 base_url: ANTHROPIC_API_URL.to_string(),
242 provider: "anthropic".to_string(),
243 compat: None,
244 }
245 }
246
247 #[must_use]
249 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
250 self.provider = provider.into();
251 self
252 }
253
254 #[must_use]
256 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
257 self.base_url = base_url.into();
258 self
259 }
260
261 #[must_use]
263 pub fn with_client(mut self, client: Client) -> Self {
264 self.client = client;
265 self
266 }
267
268 #[must_use]
273 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
274 self.compat = compat;
275 self
276 }
277
278 pub fn build_request<'a>(
280 &'a self,
281 context: &'a Context<'_>,
282 options: &StreamOptions,
283 ) -> AnthropicRequest<'a> {
284 let messages = context
285 .messages
286 .iter()
287 .map(convert_message_to_anthropic)
288 .collect();
289
290 let tools: Option<Vec<AnthropicTool<'_>>> = if context.tools.is_empty() {
291 None
292 } else {
293 Some(
294 context
295 .tools
296 .iter()
297 .map(convert_tool_to_anthropic)
298 .collect(),
299 )
300 };
301
302 let thinking = options.thinking_level.and_then(|level| {
304 if level == ThinkingLevel::Off {
305 None
306 } else {
307 let budget = options.thinking_budgets.as_ref().map_or_else(
308 || level.default_budget(),
309 |b| match level {
310 ThinkingLevel::Off => 0,
311 ThinkingLevel::Minimal => b.minimal,
312 ThinkingLevel::Low => b.low,
313 ThinkingLevel::Medium => b.medium,
314 ThinkingLevel::High => b.high,
315 ThinkingLevel::XHigh => b.xhigh,
316 },
317 );
318 Some(AnthropicThinking {
319 r#type: "enabled",
320 budget_tokens: budget,
321 })
322 }
323 });
324
325 let mut max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
326 if let Some(t) = &thinking {
327 if max_tokens <= t.budget_tokens {
328 max_tokens = t.budget_tokens + 4096;
329 }
330 }
331
332 AnthropicRequest {
333 model: &self.model,
334 messages,
335 system: context.system_prompt.as_deref(),
336 max_tokens,
337 temperature: options.temperature,
338 tools,
339 stream: true,
340 thinking,
341 }
342 }
343}
344
#[async_trait]
impl Provider for AnthropicProvider {
    /// Returns the provider name stored on this instance.
    fn name(&self) -> &str {
        &self.provider
    }

    /// Returns the API family identifier, always `anthropic-messages`.
    fn api(&self) -> &'static str {
        "anthropic-messages"
    }

    /// Returns the model identifier stored on this instance.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Sends a streaming request and returns the resulting event stream.
    ///
    /// Steps: resolve the credential, build the JSON body, pick the auth
    /// header scheme, add beta/custom/per-call headers, send, and wrap the
    /// SSE response in a stream of [`StreamEvent`]s.
    ///
    /// # Errors
    ///
    /// Returns a provider error when no credential is available or when the
    /// HTTP status is outside `200..300`. Errors from serializing the body or
    /// sending the request are propagated. Failures after the stream starts
    /// are yielded as `Err` items of the stream.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // Credential: explicit option first, then the ANTHROPIC_API_KEY env var.
        let raw_auth_value = options
            .api_key
            .clone()
            .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
            .ok_or_else(|| {
                Error::provider(
                    self.name(),
                    "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var.",
                )
            })?;
        // NOTE(review): `unmark_anthropic_oauth_bearer_token` presumably strips an
        // internal marker and returns the bare token when one is present — confirm
        // against `crate::auth`. A `Some` result forces bearer auth below.
        let forced_bearer_token = if is_anthropic_provider(&self.provider) {
            unmark_anthropic_oauth_bearer_token(&raw_auth_value).map(ToString::to_string)
        } else {
            None
        };
        let force_bearer = forced_bearer_token.is_some();
        let auth_value = forced_bearer_token.unwrap_or(raw_auth_value);

        let request_body = self.build_request(context, options);
        let anthropic_bearer_token =
            force_bearer || is_anthropic_bearer_token(&self.provider, &auth_value);
        let kimi_oauth_token = is_kimi_oauth_token(&self.provider, &auth_value);

        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream")
            .header("anthropic-version", ANTHROPIC_API_VERSION);

        // Exactly one auth scheme is applied: Anthropic bearer, Kimi bearer
        // (plus device headers), or the plain `X-API-Key` header.
        if anthropic_bearer_token {
            request = request
                .header("Authorization", format!("Bearer {auth_value}"))
                .header("anthropic-dangerous-direct-browser-access", "true")
                .header("x-app", "cli")
                .header(
                    "user-agent",
                    format!(
                        "pi_agent_rust/{} (external, cli)",
                        env!("CARGO_PKG_VERSION")
                    ),
                );
        } else if kimi_oauth_token {
            request = request
                .header("Authorization", format!("Bearer {auth_value}"))
                .header(
                    "user-agent",
                    format!(
                        "pi_agent_rust/{} (kimi-oauth, cli)",
                        env!("CARGO_PKG_VERSION")
                    ),
                );
            for (name, value) in kimi_common_headers() {
                request = request.header(name, value);
            }
        } else {
            request = request.header("X-API-Key", &auth_value);
        }

        // Beta flags are collected and sent as one comma-joined header.
        let mut beta_flags: Vec<&str> = Vec::new();
        if anthropic_bearer_token {
            beta_flags.push(ANTHROPIC_OAUTH_BETA_FLAGS);
        }
        if options.cache_retention != CacheRetention::None {
            beta_flags.push("prompt-caching-2024-07-31");
        }
        if !beta_flags.is_empty() {
            request = request.header("anthropic-beta", beta_flags.join(","));
        }

        // Compat custom headers are added before the per-call headers.
        // NOTE(review): whether a later header replaces an earlier one with the
        // same name depends on `Client`'s `header` semantics — confirm.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Include the response body in the error; if it cannot be read,
            // a placeholder describing the read failure is used instead.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                self.name(),
                format!("Anthropic API error (HTTP {status}): {body}"),
            ));
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Each unfold step pulls SSE messages until one produces an event to
        // yield; `state.done` stops the stream on the following poll.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            if msg.event == "ping" {
                                // Ping messages carry nothing to process; read the next one.
                            } else {
                                match state.process_event(&msg.data) {
                                    Ok(Some(event)) => {
                                        // Done and Error are terminal events.
                                        if matches!(
                                            &event,
                                            StreamEvent::Done { .. } | StreamEvent::Error { .. }
                                        ) {
                                            state.done = true;
                                        }
                                        return Some((Ok(event), state));
                                    }
                                    // The message produced no event; keep reading.
                                    Ok(None) => {}
                                    Err(e) => {
                                        state.done = true;
                                        return Some((Err(e), state));
                                    }
                                }
                            }
                        }
                        Some(Err(e)) => {
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // The SSE source ended without a terminal event:
                            // emit Done with whatever has been accumulated.
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
524
/// Mutable state carried across SSE messages while one response is streamed.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// SSE message source wrapping the response byte stream.
    event_source: SseStream<S>,
    /// Assistant message accumulated so far; moved out when the stream ends.
    partial: AssistantMessage,
    /// Concatenated `input_json_delta` fragments for the tool call in progress.
    current_tool_json: String,
    /// Id of the tool call in progress, as reported by its block start.
    current_tool_id: Option<String>,
    /// Name of the tool call in progress, as reported by its block start.
    current_tool_name: Option<String>,
    /// Set once a terminal event or error has been produced.
    done: bool,
}
540
541impl<S> StreamState<S>
542where
543 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
544{
545 const fn recompute_total_tokens(&mut self) {
546 self.partial.usage.total_tokens = self
547 .partial
548 .usage
549 .input
550 .saturating_add(self.partial.usage.output)
551 .saturating_add(self.partial.usage.cache_read)
552 .saturating_add(self.partial.usage.cache_write);
553 }
554
555 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
556 Self {
557 event_source,
558 partial: AssistantMessage {
559 content: Vec::new(),
560 api,
561 provider,
562 model,
563 usage: Usage::default(),
564 stop_reason: StopReason::Stop,
565 error_message: None,
566 timestamp: chrono::Utc::now().timestamp_millis(),
567 },
568 current_tool_json: String::new(),
569 current_tool_id: None,
570 current_tool_name: None,
571 done: false,
572 }
573 }
574
575 #[allow(clippy::too_many_lines)]
576 fn process_event(&mut self, data: &str) -> Result<Option<StreamEvent>> {
577 let event: AnthropicStreamEvent =
578 serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
579
580 match event {
581 AnthropicStreamEvent::MessageStart { message } => {
582 Ok(Some(self.handle_message_start(message)))
583 }
584 AnthropicStreamEvent::ContentBlockStart {
585 index,
586 content_block,
587 } => Ok(Some(self.handle_content_block_start(index, content_block))),
588 AnthropicStreamEvent::ContentBlockDelta { index, delta } => {
589 Ok(self.handle_content_block_delta(index, delta))
590 }
591 AnthropicStreamEvent::ContentBlockStop { index } => {
592 Ok(self.handle_content_block_stop(index))
593 }
594 AnthropicStreamEvent::MessageDelta { delta, usage } => {
595 self.handle_message_delta(&delta, usage);
596 Ok(None)
597 }
598 AnthropicStreamEvent::MessageStop => {
599 let reason = self.partial.stop_reason;
600 Ok(Some(StreamEvent::Done {
601 reason,
602 message: std::mem::take(&mut self.partial),
603 }))
604 }
605 AnthropicStreamEvent::Error { error } => {
606 self.partial.stop_reason = StopReason::Error;
607 self.partial.error_message = Some(error.message);
608 Ok(Some(StreamEvent::Error {
609 reason: StopReason::Error,
610 error: std::mem::take(&mut self.partial),
611 }))
612 }
613 AnthropicStreamEvent::Ping => Ok(None),
614 }
615 }
616
617 fn handle_message_start(&mut self, message: AnthropicMessageStart) -> StreamEvent {
618 if let Some(usage) = message.usage {
619 self.partial.usage.input = usage.input;
620 self.partial.usage.cache_read = usage.cache_read.unwrap_or(0);
621 self.partial.usage.cache_write = usage.cache_write.unwrap_or(0);
622 self.recompute_total_tokens();
623 }
624 StreamEvent::Start {
625 partial: self.partial.clone(),
626 }
627 }
628
629 fn handle_content_block_start(
630 &mut self,
631 index: u32,
632 content_block: AnthropicContentBlock,
633 ) -> StreamEvent {
634 let content_index = index as usize;
635
636 match content_block {
637 AnthropicContentBlock::Text => {
638 self.partial
639 .content
640 .push(ContentBlock::Text(TextContent::new("")));
641 StreamEvent::TextStart { content_index }
642 }
643 AnthropicContentBlock::Thinking => {
644 self.partial
645 .content
646 .push(ContentBlock::Thinking(ThinkingContent {
647 thinking: String::new(),
648 thinking_signature: None,
649 }));
650 StreamEvent::ThinkingStart { content_index }
651 }
652 AnthropicContentBlock::ToolUse { id, name } => {
653 self.current_tool_json.clear();
654 self.current_tool_id = id;
655 self.current_tool_name = name;
656 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
657 id: self.current_tool_id.clone().unwrap_or_default(),
658 name: self.current_tool_name.clone().unwrap_or_default(),
659 arguments: serde_json::Value::Null,
660 thought_signature: None,
661 }));
662 StreamEvent::ToolCallStart { content_index }
663 }
664 }
665 }
666
667 fn handle_content_block_delta(
668 &mut self,
669 index: u32,
670 delta: AnthropicDelta,
671 ) -> Option<StreamEvent> {
672 let idx = index as usize;
673
674 match delta {
675 AnthropicDelta::TextDelta { text } => {
676 if let Some(text) = text {
677 if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(idx) {
678 t.text.push_str(&text);
679 }
680 Some(StreamEvent::TextDelta {
681 content_index: idx,
682 delta: text,
683 })
684 } else {
685 None
686 }
687 }
688 AnthropicDelta::ThinkingDelta { thinking } => {
689 if let Some(thinking) = thinking {
690 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
691 t.thinking.push_str(&thinking);
692 }
693 Some(StreamEvent::ThinkingDelta {
694 content_index: idx,
695 delta: thinking,
696 })
697 } else {
698 None
699 }
700 }
701 AnthropicDelta::InputJsonDelta { partial_json } => {
702 if let Some(partial_json) = partial_json {
703 self.current_tool_json.push_str(&partial_json);
704 Some(StreamEvent::ToolCallDelta {
705 content_index: idx,
706 delta: partial_json,
707 })
708 } else {
709 None
710 }
711 }
712 AnthropicDelta::SignatureDelta { signature } => {
713 if let Some(sig) = signature {
717 if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(idx) {
718 t.thinking_signature = Some(sig);
719 }
720 }
721 None
722 }
723 }
724 }
725
726 fn handle_content_block_stop(&mut self, index: u32) -> Option<StreamEvent> {
727 let idx = index as usize;
728
729 match self.partial.content.get_mut(idx) {
730 Some(ContentBlock::Text(t)) => {
731 let content = t.text.clone();
735 Some(StreamEvent::TextEnd {
736 content_index: idx,
737 content,
738 })
739 }
740 Some(ContentBlock::Thinking(t)) => {
741 let content = t.thinking.clone();
743 Some(StreamEvent::ThinkingEnd {
744 content_index: idx,
745 content,
746 })
747 }
748 Some(ContentBlock::ToolCall(tc)) => {
749 let arguments: serde_json::Value =
750 match serde_json::from_str(&self.current_tool_json) {
751 Ok(args) => args,
752 Err(e) => {
753 tracing::warn!(
754 error = %e,
755 raw = %self.current_tool_json,
756 "Failed to parse tool arguments as JSON"
757 );
758 serde_json::Value::Null
759 }
760 };
761 let tool_call = ToolCall {
762 id: self.current_tool_id.take().unwrap_or_default(),
763 name: self.current_tool_name.take().unwrap_or_default(),
764 arguments: arguments.clone(),
765 thought_signature: None,
766 };
767 tc.arguments = arguments;
768 self.current_tool_json.clear();
769
770 Some(StreamEvent::ToolCallEnd {
771 content_index: idx,
772 tool_call,
773 })
774 }
775 _ => None,
776 }
777 }
778
779 #[allow(clippy::missing_const_for_fn)]
780 fn handle_message_delta(
781 &mut self,
782 delta: &AnthropicMessageDelta,
783 usage: Option<AnthropicDeltaUsage>,
784 ) {
785 if let Some(stop_reason) = delta.stop_reason {
786 self.partial.stop_reason = match stop_reason {
787 AnthropicStopReason::MaxTokens => StopReason::Length,
788 AnthropicStopReason::ToolUse => StopReason::ToolUse,
789 AnthropicStopReason::EndTurn | AnthropicStopReason::StopSequence => {
790 StopReason::Stop
791 }
792 };
793 }
794
795 if let Some(u) = usage {
796 self.partial.usage.output = u.output_tokens;
797 self.recompute_total_tokens();
798 }
799 }
800}
801
/// JSON request body produced by `AnthropicProvider::build_request`.
///
/// Optional fields are left out of the serialized output when `None`.
#[derive(Debug, Serialize)]
pub struct AnthropicRequest<'a> {
    /// Model identifier.
    model: &'a str,
    /// Conversation messages, in order.
    messages: Vec<AnthropicMessage<'a>>,
    /// System prompt, when the context has one.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<&'a str>,
    /// Upper bound on generated tokens.
    max_tokens: u32,
    /// Sampling temperature, when specified by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions, when the context has any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool<'a>>>,
    /// Whether a streamed response is requested.
    stream: bool,
    /// Thinking configuration, when thinking is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking: Option<AnthropicThinking>,
}
821
/// `thinking` section of the request body.
#[derive(Debug, Serialize)]
struct AnthropicThinking {
    /// Thinking mode tag; `build_request` sets it to `"enabled"`.
    r#type: &'static str,
    /// Token budget for thinking.
    budget_tokens: u32,
}
827
/// One conversation message in the request body.
#[derive(Debug, Serialize)]
struct AnthropicMessage<'a> {
    /// Message role; the converters in this file use `"user"` or `"assistant"`.
    role: &'static str,
    /// Content parts of the message.
    content: Vec<AnthropicContent<'a>>,
}
833
/// A content part of a request message, serialized with a snake_case `type` tag.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent<'a> {
    /// Plain text.
    Text {
        text: &'a str,
    },
    /// Thinking text together with its signature.
    Thinking {
        thinking: &'a str,
        signature: &'a str,
    },
    /// Image data.
    Image {
        source: AnthropicImageSource<'a>,
    },
    /// A tool invocation made by the assistant.
    ToolUse {
        id: &'a str,
        name: &'a str,
        input: &'a serde_json::Value,
    },
    /// The result of a tool invocation, linked to it by `tool_use_id`.
    ToolResult {
        tool_use_id: &'a str,
        content: Vec<AnthropicToolResultContent<'a>>,
        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
}
859
/// Image payload embedded in a request.
#[derive(Debug, Serialize)]
struct AnthropicImageSource<'a> {
    /// Source kind; the converters in this file always use `"base64"`.
    r#type: &'static str,
    /// MIME type of the image.
    media_type: &'a str,
    /// Image data string, passed through as stored on the image block.
    data: &'a str,
}
866
/// A content part inside a tool result, serialized with a snake_case `type` tag.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicToolResultContent<'a> {
    /// Plain text output.
    Text { text: &'a str },
    /// Image output.
    Image { source: AnthropicImageSource<'a> },
}
873
/// Tool definition in the request body, borrowed from a `ToolDef`.
#[derive(Debug, Serialize)]
struct AnthropicTool<'a> {
    /// Tool name.
    name: &'a str,
    /// Tool description.
    description: &'a str,
    /// Parameter schema, taken from `ToolDef::parameters`.
    input_schema: &'a serde_json::Value,
}
880
/// One streamed event payload, selected by its snake_case `type` tag.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicStreamEvent {
    /// Start of the response message.
    MessageStart {
        message: AnthropicMessageStart,
    },
    /// A new content block begins at `index`.
    ContentBlockStart {
        index: u32,
        content_block: AnthropicContentBlock,
    },
    /// Incremental content for the block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: AnthropicDelta,
    },
    /// The block at `index` is complete.
    ContentBlockStop {
        index: u32,
    },
    /// Message-level update: stop reason and, optionally, output usage.
    MessageDelta {
        delta: AnthropicMessageDelta,
        #[serde(default)]
        usage: Option<AnthropicDeltaUsage>,
    },
    /// End of the response message.
    MessageStop,
    /// Error reported inside the stream.
    Error {
        error: AnthropicError,
    },
    /// Ping event; produces no output.
    Ping,
}
913
/// The `message` object of a `message_start` event; only usage is retained.
#[derive(Debug, Deserialize)]
struct AnthropicMessageStart {
    /// Input-side usage; `None` when the field is absent.
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}
919
/// Input-side token usage reported at message start.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct AnthropicUsage {
    /// Read from `input_tokens`.
    #[serde(rename = "input_tokens")]
    input: u64,
    /// Read from `cache_read_input_tokens`; `None` when absent.
    #[serde(default, rename = "cache_read_input_tokens")]
    cache_read: Option<u64>,
    /// Read from `cache_creation_input_tokens`; `None` when absent.
    #[serde(default, rename = "cache_creation_input_tokens")]
    cache_write: Option<u64>,
}
932
/// Usage carried by a `message_delta` event.
#[derive(Debug, Deserialize)]
struct AnthropicDeltaUsage {
    /// Output token count; stored as the message's output usage.
    output_tokens: u64,
}
937
/// The `content_block` of a `content_block_start` event, selected by its
/// snake_case `type` tag.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContentBlock {
    /// A text block begins.
    Text,
    /// A thinking block begins.
    Thinking,
    /// A tool call begins; `id` and `name` default to `None` when absent.
    ToolUse {
        #[serde(default)]
        id: Option<String>,
        #[serde(default)]
        name: Option<String>,
    },
}
953
/// The `delta` of a `content_block_delta` event, selected by its snake_case
/// `type` tag. Every payload field defaults to `None` when absent.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[allow(clippy::enum_variant_names)]
enum AnthropicDelta {
    /// Text to append to a text block.
    TextDelta {
        #[serde(default)]
        text: Option<String>,
    },
    /// Text to append to a thinking block.
    ThinkingDelta {
        #[serde(default)]
        thinking: Option<String>,
    },
    /// Fragment of the tool call's JSON arguments.
    InputJsonDelta {
        #[serde(default)]
        partial_json: Option<String>,
    },
    /// Signature for a thinking block.
    SignatureDelta {
        #[serde(default)]
        signature: Option<String>,
    },
}
980
/// Stop reasons accepted in a `message_delta`, parsed from snake_case strings.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum AnthropicStopReason {
    /// `end_turn`; mapped to `StopReason::Stop`.
    EndTurn,
    /// `max_tokens`; mapped to `StopReason::Length`.
    MaxTokens,
    /// `tool_use`; mapped to `StopReason::ToolUse`.
    ToolUse,
    /// `stop_sequence`; mapped to `StopReason::Stop`.
    StopSequence,
}
992
/// The `delta` object of a `message_delta` event.
#[derive(Debug, Deserialize)]
struct AnthropicMessageDelta {
    /// Stop reason; `None` when the field is absent.
    #[serde(default)]
    stop_reason: Option<AnthropicStopReason>,
}
998
/// The `error` object of a streamed `error` event.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    /// Error text; stored as the message's `error_message`.
    message: String,
}
1003
1004fn convert_message_to_anthropic(message: &Message) -> AnthropicMessage<'_> {
1009 match message {
1010 Message::User(user) => AnthropicMessage {
1011 role: "user",
1012 content: convert_user_content(&user.content),
1013 },
1014 Message::Custom(custom) => AnthropicMessage {
1015 role: "user",
1016 content: vec![AnthropicContent::Text {
1017 text: &custom.content,
1018 }],
1019 },
1020 Message::Assistant(assistant) => AnthropicMessage {
1021 role: "assistant",
1022 content: assistant
1023 .content
1024 .iter()
1025 .filter_map(convert_content_block_to_anthropic)
1026 .collect(),
1027 },
1028 Message::ToolResult(result) => AnthropicMessage {
1029 role: "user",
1030 content: vec![AnthropicContent::ToolResult {
1031 tool_use_id: &result.tool_call_id,
1032 content: result
1033 .content
1034 .iter()
1035 .filter_map(|block| match block {
1036 ContentBlock::Text(t) => {
1037 Some(AnthropicToolResultContent::Text { text: &t.text })
1038 }
1039 ContentBlock::Image(img) => Some(AnthropicToolResultContent::Image {
1040 source: AnthropicImageSource {
1041 r#type: "base64",
1042 media_type: &img.mime_type,
1043 data: &img.data,
1044 },
1045 }),
1046 _ => None,
1047 })
1048 .collect(),
1049 is_error: if result.is_error { Some(true) } else { None },
1050 }],
1051 },
1052 }
1053}
1054
1055fn convert_user_content(content: &UserContent) -> Vec<AnthropicContent<'_>> {
1056 match content {
1057 UserContent::Text(text) => vec![AnthropicContent::Text { text }],
1058 UserContent::Blocks(blocks) => blocks
1059 .iter()
1060 .filter_map(|block| match block {
1061 ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
1062 ContentBlock::Image(img) => Some(AnthropicContent::Image {
1063 source: AnthropicImageSource {
1064 r#type: "base64",
1065 media_type: &img.mime_type,
1066 data: &img.data,
1067 },
1068 }),
1069 _ => None,
1070 })
1071 .collect(),
1072 }
1073}
1074
1075fn convert_content_block_to_anthropic(block: &ContentBlock) -> Option<AnthropicContent<'_>> {
1076 match block {
1077 ContentBlock::Text(t) => Some(AnthropicContent::Text { text: &t.text }),
1078 ContentBlock::ToolCall(tc) => Some(AnthropicContent::ToolUse {
1079 id: &tc.id,
1080 name: &tc.name,
1081 input: &tc.arguments,
1082 }),
1083 ContentBlock::Thinking(t) => {
1087 t.thinking_signature
1088 .as_ref()
1089 .map(|sig| AnthropicContent::Thinking {
1090 thinking: &t.thinking,
1091 signature: sig,
1092 })
1093 }
1094 ContentBlock::Image(_) => None,
1095 }
1096}
1097
1098fn convert_tool_to_anthropic(tool: &ToolDef) -> AnthropicTool<'_> {
1099 AnthropicTool {
1100 name: &tool.name,
1101 description: &tool.description,
1102 input_schema: &tool.parameters,
1103 }
1104}
1105
1106#[cfg(test)]
1111mod tests {
1112 use super::*;
1113 use asupersync::runtime::RuntimeBuilder;
1114 use futures::{StreamExt, stream};
1115 use serde::{Deserialize, Serialize};
1116 use serde_json::Value;
1117 use serde_json::json;
1118 use std::collections::HashMap;
1119 use std::io::{Read, Write};
1120 use std::net::TcpListener;
1121 use std::path::PathBuf;
1122 use std::sync::mpsc;
1123 use std::time::Duration;
1124
1125 #[test]
1126 fn home_dir_lookup_falls_back_to_userprofile() {
1127 let home = home_dir_with_env_lookup(|key| match key {
1128 "USERPROFILE" => Some("C:\\Users\\Ada".to_string()),
1129 _ => None,
1130 });
1131
1132 assert_eq!(home, Some(PathBuf::from("C:\\Users\\Ada")));
1133 }
1134
1135 #[test]
1136 fn home_dir_lookup_falls_back_to_homedrive_homepath() {
1137 let home = home_dir_with_env_lookup(|key| match key {
1138 "HOMEDRIVE" => Some("D:".to_string()),
1139 "HOMEPATH" => Some("\\Users\\Grace".to_string()),
1140 _ => None,
1141 });
1142
1143 assert_eq!(home, Some(PathBuf::from("D:\\Users\\Grace")));
1144 }
1145
1146 #[test]
1147 fn test_convert_user_text_message() {
1148 let message = Message::User(crate::model::UserMessage {
1149 content: UserContent::Text("Hello".to_string()),
1150 timestamp: 0,
1151 });
1152
1153 let converted = convert_message_to_anthropic(&message);
1154 assert_eq!(converted.role, "user");
1155 assert_eq!(converted.content.len(), 1);
1156 }
1157
1158 #[test]
1159 fn test_thinking_budget() {
1160 assert_eq!(ThinkingLevel::Minimal.default_budget(), 1024);
1161 assert_eq!(ThinkingLevel::Low.default_budget(), 2048);
1162 assert_eq!(ThinkingLevel::Medium.default_budget(), 8192);
1163 assert_eq!(ThinkingLevel::High.default_budget(), 16384);
1164 }
1165
1166 #[test]
1167 fn test_build_request_includes_system_tools_and_thinking() {
1168 let provider = AnthropicProvider::new("claude-test");
1169 let context = Context {
1170 system_prompt: Some("System prompt".to_string().into()),
1171 messages: vec![Message::User(crate::model::UserMessage {
1172 content: UserContent::Text("Ping".to_string()),
1173 timestamp: 0,
1174 })]
1175 .into(),
1176 tools: vec![ToolDef {
1177 name: "echo".to_string(),
1178 description: "Echo a string.".to_string(),
1179 parameters: json!({
1180 "type": "object",
1181 "properties": {
1182 "text": { "type": "string" }
1183 },
1184 "required": ["text"]
1185 }),
1186 }]
1187 .into(),
1188 };
1189 let options = StreamOptions {
1190 max_tokens: Some(128),
1191 temperature: Some(0.2),
1192 thinking_level: Some(ThinkingLevel::Medium),
1193 thinking_budgets: Some(crate::provider::ThinkingBudgets {
1194 minimal: 1024,
1195 low: 2048,
1196 medium: 9000,
1197 high: 16384,
1198 xhigh: 32768,
1199 }),
1200 ..Default::default()
1201 };
1202
1203 let request = provider.build_request(&context, &options);
1204 assert_eq!(request.model, "claude-test");
1205 assert_eq!(request.system, Some("System prompt"));
1206 assert_eq!(request.temperature, Some(0.2));
1207 assert!(request.stream);
1208 assert_eq!(request.max_tokens, 13_096);
1209
1210 let thinking = request.thinking.expect("thinking config");
1211 assert_eq!(thinking.r#type, "enabled");
1212 assert_eq!(thinking.budget_tokens, 9000);
1213
1214 assert_eq!(request.messages.len(), 1);
1215 assert_eq!(request.messages[0].role, "user");
1216 assert_eq!(request.messages[0].content.len(), 1);
1217 match &request.messages[0].content[0] {
1218 AnthropicContent::Text { text } => assert_eq!(*text, "Ping"),
1219 other => panic!("expected text content, got {other:?}"),
1220 }
1221
1222 let tools = request.tools.expect("tools");
1223 assert_eq!(tools.len(), 1);
1224 assert_eq!(tools[0].name, "echo");
1225 assert_eq!(tools[0].description, "Echo a string.");
1226 assert_eq!(
1227 *tools[0].input_schema,
1228 json!({
1229 "type": "object",
1230 "properties": {
1231 "text": { "type": "string" }
1232 },
1233 "required": ["text"]
1234 })
1235 );
1236 }
1237
1238 #[test]
1239 fn test_build_request_omits_optional_fields_by_default() {
1240 let provider = AnthropicProvider::new("claude-test");
1241 let context = Context::default();
1242 let options = StreamOptions::default();
1243
1244 let request = provider.build_request(&context, &options);
1245 assert_eq!(request.model, "claude-test");
1246 assert_eq!(request.system, None);
1247 assert!(request.tools.is_none());
1248 assert!(request.thinking.is_none());
1249 assert_eq!(request.max_tokens, DEFAULT_MAX_TOKENS);
1250 assert!(request.stream);
1251 }
1252
1253 #[test]
1254 #[allow(clippy::too_many_lines)]
1255 fn test_stream_parses_thinking_and_tool_call_events() {
1256 let events = vec![
1257 json!({
1258 "type": "message_start",
1259 "message": { "usage": { "input_tokens": 3 } }
1260 }),
1261 json!({
1262 "type": "content_block_start",
1263 "index": 0,
1264 "content_block": { "type": "thinking" }
1265 }),
1266 json!({
1267 "type": "content_block_delta",
1268 "index": 0,
1269 "delta": { "type": "thinking_delta", "thinking": "step 1" }
1270 }),
1271 json!({
1272 "type": "content_block_stop",
1273 "index": 0
1274 }),
1275 json!({
1276 "type": "content_block_start",
1277 "index": 1,
1278 "content_block": { "type": "tool_use", "id": "tool_123", "name": "search" }
1279 }),
1280 json!({
1281 "type": "content_block_delta",
1282 "index": 1,
1283 "delta": { "type": "input_json_delta", "partial_json": "{\"q\":\"ru" }
1284 }),
1285 json!({
1286 "type": "content_block_delta",
1287 "index": 1,
1288 "delta": { "type": "input_json_delta", "partial_json": "st\"}" }
1289 }),
1290 json!({
1291 "type": "content_block_stop",
1292 "index": 1
1293 }),
1294 json!({
1295 "type": "content_block_start",
1296 "index": 2,
1297 "content_block": { "type": "text" }
1298 }),
1299 json!({
1300 "type": "content_block_delta",
1301 "index": 2,
1302 "delta": { "type": "text_delta", "text": "done" }
1303 }),
1304 json!({
1305 "type": "content_block_stop",
1306 "index": 2
1307 }),
1308 json!({
1309 "type": "message_delta",
1310 "delta": { "stop_reason": "tool_use" },
1311 "usage": { "output_tokens": 5 }
1312 }),
1313 json!({
1314 "type": "message_stop"
1315 }),
1316 ];
1317
1318 let out = collect_events(&events);
1319 assert_eq!(out.len(), 12, "expected full stream event sequence");
1320
1321 assert!(matches!(&out[0], StreamEvent::Start { .. }));
1322 assert!(matches!(
1323 &out[1],
1324 StreamEvent::ThinkingStart {
1325 content_index: 0,
1326 ..
1327 }
1328 ));
1329 assert!(matches!(
1330 &out[2],
1331 StreamEvent::ThinkingDelta {
1332 content_index: 0,
1333 delta,
1334 ..
1335 } if delta == "step 1"
1336 ));
1337 assert!(matches!(
1338 &out[3],
1339 StreamEvent::ThinkingEnd {
1340 content_index: 0,
1341 content,
1342 ..
1343 } if content == "step 1"
1344 ));
1345 assert!(matches!(
1346 &out[4],
1347 StreamEvent::ToolCallStart {
1348 content_index: 1,
1349 ..
1350 }
1351 ));
1352 assert!(matches!(
1353 &out[5],
1354 StreamEvent::ToolCallDelta {
1355 content_index: 1,
1356 delta,
1357 ..
1358 } if delta == "{\"q\":\"ru"
1359 ));
1360 assert!(matches!(
1361 &out[6],
1362 StreamEvent::ToolCallDelta {
1363 content_index: 1,
1364 delta,
1365 ..
1366 } if delta == "st\"}"
1367 ));
1368 if let StreamEvent::ToolCallEnd {
1369 content_index,
1370 tool_call,
1371 ..
1372 } = &out[7]
1373 {
1374 assert_eq!(*content_index, 1);
1375 assert_eq!(tool_call.id, "tool_123");
1376 assert_eq!(tool_call.name, "search");
1377 assert_eq!(tool_call.arguments, json!({ "q": "rust" }));
1378 } else {
1379 panic!("expected ToolCallEnd event, got {:?}", out[7]);
1380 }
1381 assert!(matches!(
1382 &out[8],
1383 StreamEvent::TextStart {
1384 content_index: 2,
1385 ..
1386 }
1387 ));
1388 assert!(matches!(
1389 &out[9],
1390 StreamEvent::TextDelta {
1391 content_index: 2,
1392 delta,
1393 ..
1394 } if delta == "done"
1395 ));
1396 assert!(matches!(
1397 &out[10],
1398 StreamEvent::TextEnd {
1399 content_index: 2,
1400 content,
1401 ..
1402 } if content == "done"
1403 ));
1404 if let StreamEvent::Done { reason, message } = &out[11] {
1405 assert_eq!(*reason, StopReason::ToolUse);
1406 assert_eq!(message.stop_reason, StopReason::ToolUse);
1407 } else {
1408 panic!("expected Done event, got {:?}", out[11]);
1409 }
1410 }
1411
1412 #[test]
1413 fn test_message_delta_sets_length_stop_reason_and_usage() {
1414 let events = vec![
1415 json!({
1416 "type": "message_start",
1417 "message": { "usage": { "input_tokens": 5 } }
1418 }),
1419 json!({
1420 "type": "message_delta",
1421 "delta": { "stop_reason": "max_tokens" },
1422 "usage": { "output_tokens": 7 }
1423 }),
1424 json!({
1425 "type": "message_stop"
1426 }),
1427 ];
1428
1429 let out = collect_events(&events);
1430 assert_eq!(out.len(), 2);
1431 if let StreamEvent::Done { reason, message } = &out[1] {
1432 assert_eq!(*reason, StopReason::Length);
1433 assert_eq!(message.stop_reason, StopReason::Length);
1434 assert_eq!(message.usage.input, 5);
1435 assert_eq!(message.usage.output, 7);
1436 assert_eq!(message.usage.total_tokens, 12);
1437 } else {
1438 panic!("expected Done event, got {:?}", out[1]);
1439 }
1440 }
1441
1442 #[test]
1443 fn test_usage_total_tokens_saturates_on_large_values() {
1444 let events = vec![
1445 json!({
1446 "type": "message_start",
1447 "message": {
1448 "usage": {
1449 "input_tokens": u64::MAX,
1450 "cache_read_input_tokens": 1,
1451 "cache_creation_input_tokens": 1
1452 }
1453 }
1454 }),
1455 json!({
1456 "type": "message_delta",
1457 "delta": { "stop_reason": "end_turn" },
1458 "usage": { "output_tokens": 1 }
1459 }),
1460 json!({
1461 "type": "message_stop"
1462 }),
1463 ];
1464
1465 let out = collect_events(&events);
1466 assert_eq!(out.len(), 2);
1467 if let StreamEvent::Done { message, .. } = &out[1] {
1468 assert_eq!(message.usage.total_tokens, u64::MAX);
1469 } else {
1470 panic!("expected Done event, got {:?}", out[1]);
1471 }
1472 }
1473
    /// Top-level shape of a provider stream fixture file as read by
    /// `load_fixture`: a list of independent cases.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        /// Cases replayed one by one by `test_stream_fixtures`.
        cases: Vec<ProviderCase>,
    }
1478
    /// One fixture case: wire events to replay and the summaries they must yield.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        /// Label used in the assertion message when the case fails.
        name: String,
        /// Raw events handed to `collect_events`.
        events: Vec<Value>,
        /// Expected summaries, compared against the summarized stream output.
        expected: Vec<EventSummary>,
    }
1485
    /// Flattened, comparable view of a `StreamEvent` used by the fixture tests.
    ///
    /// Every field except `kind` is optional and defaults to `None` when it is
    /// absent from the fixture JSON.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event label as produced by `summarize_event` (e.g. `"text_delta"`).
        kind: String,
        /// Content block index, for events that carry one.
        #[serde(default)]
        content_index: Option<usize>,
        /// Incremental payload of a delta event.
        #[serde(default)]
        delta: Option<String>,
        /// Accumulated content of an end event.
        #[serde(default)]
        content: Option<String>,
        /// Stop reason label for `done` / `error` events.
        #[serde(default)]
        reason: Option<String>,
    }
1498
1499 #[test]
1500 fn test_stream_fixtures() {
1501 let fixture = load_fixture("anthropic_stream.json");
1502 for case in fixture.cases {
1503 let events = collect_events(&case.events);
1504 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1505 assert_eq!(summaries, case.expected, "case {}", case.name);
1506 }
1507 }
1508
1509 #[test]
1510 fn test_stream_error_event_maps_to_stop_reason_error() {
1511 let events = vec![json!({
1512 "type": "error",
1513 "error": { "message": "nope" }
1514 })];
1515
1516 let out = collect_events(&events);
1517 assert_eq!(out.len(), 1);
1518 assert!(
1519 matches!(&out[0], StreamEvent::Error { .. }),
1520 "expected StreamEvent::Error, got {:?}",
1521 out[0]
1522 );
1523 if let StreamEvent::Error { reason, error } = &out[0] {
1524 assert_eq!(*reason, StopReason::Error);
1525 assert_eq!(error.stop_reason, StopReason::Error);
1526 assert_eq!(error.error_message.as_deref(), Some("nope"));
1527 }
1528 }
1529
1530 #[test]
1531 fn test_stream_emits_single_done_when_transport_ends_after_message_stop() {
1532 let out = collect_stream_items_from_body(&success_sse_body());
1533 let done_count = out
1534 .iter()
1535 .filter(|item| matches!(item, Ok(StreamEvent::Done { .. })))
1536 .count();
1537 assert_eq!(done_count, 1, "expected exactly one terminal Done event");
1538 }
1539
1540 #[test]
1541 fn test_stream_error_event_is_terminal() {
1542 let body = [
1543 r#"data: {"type":"error","error":{"message":"boom"}}"#,
1544 "",
1545 r#"data: {"type":"message_stop"}"#,
1547 "",
1548 ]
1549 .join("\n");
1550
1551 let out = collect_stream_items_from_body(&body);
1552 assert_eq!(out.len(), 1, "Error should terminate the stream");
1553 assert!(matches!(out[0], Ok(StreamEvent::Error { .. })));
1554 }
1555
1556 #[test]
1557 fn test_stream_parse_error_is_terminal() {
1558 let body = [
1559 r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
1560 "",
1561 r"data: {invalid-json}",
1562 "",
1563 r#"data: {"type":"message_stop"}"#,
1565 "",
1566 ]
1567 .join("\n");
1568
1569 let out = collect_stream_items_from_body(&body);
1570 assert_eq!(out.len(), 2, "parse error should stop further events");
1571 assert!(matches!(out[0], Ok(StreamEvent::Start { .. })));
1572 match &out[1] {
1573 Ok(event) => panic!("expected parse error item, got event: {event:?}"),
1574 Err(err) => assert!(err.to_string().contains("JSON parse error")),
1575 }
1576 }
1577
1578 #[test]
1579 fn test_stream_sets_required_headers() {
1580 let captured = run_stream_and_capture_headers(CacheRetention::None)
1581 .expect("captured request for required headers");
1582 assert_eq!(
1583 captured.headers.get("x-api-key").map(String::as_str),
1584 Some("sk-ant-test-key")
1585 );
1586 assert_eq!(
1587 captured
1588 .headers
1589 .get("anthropic-version")
1590 .map(String::as_str),
1591 Some(ANTHROPIC_API_VERSION)
1592 );
1593 assert!(!captured.headers.contains_key("anthropic-beta"));
1594 assert!(captured.body.contains("\"stream\":true"));
1595 }
1596
1597 #[test]
1598 fn test_stream_adds_prompt_caching_beta_header_when_enabled() {
1599 let captured = run_stream_and_capture_headers(CacheRetention::Short)
1600 .expect("captured request for beta header");
1601 assert_eq!(
1602 captured.headers.get("anthropic-beta").map(String::as_str),
1603 Some("prompt-caching-2024-07-31")
1604 );
1605 }
1606
1607 #[test]
1608 fn test_stream_uses_oauth_bearer_auth_headers() {
1609 let captured =
1610 run_stream_and_capture_headers_with_api_key(CacheRetention::None, "sk-ant-oat-test")
1611 .expect("captured request for oauth headers");
1612 assert_eq!(
1613 captured.headers.get("authorization").map(String::as_str),
1614 Some("Bearer sk-ant-oat-test")
1615 );
1616 assert!(!captured.headers.contains_key("x-api-key"));
1617 assert_eq!(
1618 captured
1619 .headers
1620 .get("anthropic-dangerous-direct-browser-access")
1621 .map(String::as_str),
1622 Some("true")
1623 );
1624 assert_eq!(
1625 captured.headers.get("x-app").map(String::as_str),
1626 Some("cli")
1627 );
1628 assert!(
1629 captured
1630 .headers
1631 .get("anthropic-beta")
1632 .is_some_and(|value| value.contains("oauth-2025-04-20"))
1633 );
1634 assert!(
1635 captured
1636 .headers
1637 .get("user-agent")
1638 .is_some_and(|value| value.contains("pi_agent_rust/"))
1639 );
1640 }
1641
1642 #[test]
1643 fn test_stream_uses_bearer_headers_for_marked_anthropic_oauth_token() {
1644 let marked = "__pi_anthropic_oauth_bearer__:sk-ant-api-like-token";
1645 let captured = run_stream_and_capture_headers_with_api_key(CacheRetention::None, marked)
1646 .expect("captured request for marked oauth headers");
1647 assert_eq!(
1648 captured.headers.get("authorization").map(String::as_str),
1649 Some("Bearer sk-ant-api-like-token")
1650 );
1651 assert!(!captured.headers.contains_key("x-api-key"));
1652 assert!(
1653 captured
1654 .headers
1655 .get("anthropic-beta")
1656 .is_some_and(|value| value.contains("oauth-2025-04-20"))
1657 );
1658 }
1659
1660 #[test]
1661 fn test_stream_claude_style_non_sk_token_uses_bearer_auth_headers() {
1662 let captured =
1663 run_stream_and_capture_headers_with_api_key(CacheRetention::None, "claude-oauth-token")
1664 .expect("captured request for claude bearer headers");
1665 assert_eq!(
1666 captured.headers.get("authorization").map(String::as_str),
1667 Some("Bearer claude-oauth-token")
1668 );
1669 assert!(!captured.headers.contains_key("x-api-key"));
1670 }
1671
1672 #[test]
1673 fn test_stream_kimi_oauth_uses_bearer_and_kimi_headers() {
1674 let captured = run_stream_and_capture_headers_for_provider_with_api_key(
1675 CacheRetention::None,
1676 "kimi-for-coding",
1677 "kimi-oauth-token",
1678 )
1679 .expect("captured request for kimi oauth headers");
1680 assert_eq!(
1681 captured.headers.get("authorization").map(String::as_str),
1682 Some("Bearer kimi-oauth-token")
1683 );
1684 assert!(!captured.headers.contains_key("x-api-key"));
1685 assert!(
1686 !captured
1687 .headers
1688 .contains_key("anthropic-dangerous-direct-browser-access")
1689 );
1690 assert!(!captured.headers.contains_key("anthropic-beta"));
1691 assert_eq!(
1692 captured.headers.get("x-msh-platform").map(String::as_str),
1693 Some("kimi_cli")
1694 );
1695 assert!(captured.headers.contains_key("x-msh-version"));
1696 assert!(captured.headers.contains_key("x-msh-device-name"));
1697 assert!(captured.headers.contains_key("x-msh-device-model"));
1698 assert!(captured.headers.contains_key("x-msh-os-version"));
1699 assert!(captured.headers.contains_key("x-msh-device-id"));
1700 }
1701
1702 #[test]
1703 fn test_stream_kimi_api_key_uses_x_api_key_header() {
1704 let captured = run_stream_and_capture_headers_for_provider_with_api_key(
1705 CacheRetention::None,
1706 "kimi-for-coding",
1707 "sk-kimi-api-key",
1708 )
1709 .expect("captured request for kimi api-key headers");
1710 assert_eq!(
1711 captured.headers.get("x-api-key").map(String::as_str),
1712 Some("sk-kimi-api-key")
1713 );
1714 assert!(!captured.headers.contains_key("authorization"));
1715 assert!(!captured.headers.contains_key("x-msh-platform"));
1716 }
1717
1718 #[test]
1719 fn test_stream_oauth_beta_header_includes_prompt_caching_when_enabled() {
1720 let captured =
1721 run_stream_and_capture_headers_with_api_key(CacheRetention::Short, "sk-ant-oat-test")
1722 .expect("captured request for oauth + cache beta header");
1723 let beta = captured
1724 .headers
1725 .get("anthropic-beta")
1726 .expect("anthropic-beta header");
1727 assert!(beta.contains("oauth-2025-04-20"));
1728 assert!(beta.contains("prompt-caching-2024-07-31"));
1729 }
1730
1731 #[test]
1732 fn test_stream_http_error_includes_status_and_body_message() {
1733 let (base_url, _rx) = spawn_test_server(
1734 401,
1735 "application/json",
1736 r#"{"type":"error","error":{"type":"authentication_error","message":"Invalid API key"}}"#,
1737 );
1738 let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
1739 let context = Context {
1740 system_prompt: None,
1741 messages: vec![Message::User(crate::model::UserMessage {
1742 content: UserContent::Text("ping".to_string()),
1743 timestamp: 0,
1744 })]
1745 .into(),
1746 tools: Vec::new().into(),
1747 };
1748 let options = StreamOptions {
1749 api_key: Some("test-key".to_string()),
1750 ..Default::default()
1751 };
1752
1753 let runtime = RuntimeBuilder::current_thread()
1754 .build()
1755 .expect("runtime build");
1756 let result = runtime.block_on(async { provider.stream(&context, &options).await });
1757 let Err(err) = result else {
1758 panic!("expected HTTP error");
1759 };
1760 let message = err.to_string();
1761 assert!(message.contains("Anthropic API error (HTTP 401)"));
1762 assert!(message.contains("Invalid API key"));
1763 }
1764
1765 #[test]
1766 fn test_provider_name_reflects_override() {
1767 let provider = AnthropicProvider::new("claude-test").with_provider_name("kimi-for-coding");
1768 assert_eq!(provider.name(), "kimi-for-coding");
1769 }
1770
    /// HTTP request as observed by the server thread started in
    /// `spawn_test_server`.
    #[derive(Debug)]
    struct CapturedRequest {
        /// Request headers keyed by lower-cased name (see `parse_headers`).
        headers: HashMap<String, String>,
        /// Request body, lossily decoded as UTF-8.
        body: String,
    }
1776
1777 fn run_stream_and_capture_headers(cache_retention: CacheRetention) -> Option<CapturedRequest> {
1778 run_stream_and_capture_headers_with_api_key(cache_retention, "sk-ant-test-key")
1779 }
1780
1781 fn run_stream_and_capture_headers_with_api_key(
1782 cache_retention: CacheRetention,
1783 api_key: &str,
1784 ) -> Option<CapturedRequest> {
1785 run_stream_and_capture_headers_for_provider_with_api_key(
1786 cache_retention,
1787 "anthropic",
1788 api_key,
1789 )
1790 }
1791
1792 fn run_stream_and_capture_headers_for_provider_with_api_key(
1793 cache_retention: CacheRetention,
1794 provider_name: &str,
1795 api_key: &str,
1796 ) -> Option<CapturedRequest> {
1797 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1798 let provider = AnthropicProvider::new("claude-test")
1799 .with_provider_name(provider_name)
1800 .with_base_url(base_url);
1801 let context = Context {
1802 system_prompt: Some("test system".to_string().into()),
1803 messages: vec![Message::User(crate::model::UserMessage {
1804 content: UserContent::Text("ping".to_string()),
1805 timestamp: 0,
1806 })]
1807 .into(),
1808 tools: Vec::new().into(),
1809 };
1810 let options = StreamOptions {
1811 api_key: Some(api_key.to_string()),
1812 cache_retention,
1813 ..Default::default()
1814 };
1815
1816 let runtime = RuntimeBuilder::current_thread()
1817 .build()
1818 .expect("runtime build");
1819 runtime.block_on(async {
1820 let mut stream = provider.stream(&context, &options).await.expect("stream");
1821 while let Some(event) = stream.next().await {
1822 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1823 break;
1824 }
1825 }
1826 });
1827
1828 rx.recv_timeout(Duration::from_secs(2)).ok()
1829 }
1830
1831 fn collect_stream_items_from_body(body: &str) -> Vec<Result<StreamEvent>> {
1832 let (base_url, _rx) = spawn_test_server(200, "text/event-stream", body);
1833 let provider = AnthropicProvider::new("claude-test").with_base_url(base_url);
1834 let context = Context {
1835 system_prompt: Some("test system".to_string().into()),
1836 messages: vec![Message::User(crate::model::UserMessage {
1837 content: UserContent::Text("ping".to_string()),
1838 timestamp: 0,
1839 })]
1840 .into(),
1841 tools: Vec::new().into(),
1842 };
1843 let options = StreamOptions {
1844 api_key: Some("sk-ant-test-key".to_string()),
1845 ..Default::default()
1846 };
1847
1848 let runtime = RuntimeBuilder::current_thread()
1849 .build()
1850 .expect("runtime build");
1851 runtime.block_on(async {
1852 let mut stream = provider.stream(&context, &options).await.expect("stream");
1853 let mut items = Vec::new();
1854 while let Some(item) = stream.next().await {
1855 items.push(item);
1856 }
1857 items
1858 })
1859 }
1860
1861 fn success_sse_body() -> String {
1862 [
1863 r#"data: {"type":"message_start","message":{"usage":{"input_tokens":1}}}"#,
1864 "",
1865 r#"data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":1}}"#,
1866 "",
1867 r#"data: {"type":"message_stop"}"#,
1868 "",
1869 ]
1870 .join("\n")
1871 }
1872
1873 fn spawn_test_server(
1874 status_code: u16,
1875 content_type: &str,
1876 body: &str,
1877 ) -> (String, mpsc::Receiver<CapturedRequest>) {
1878 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1879 let addr = listener.local_addr().expect("local addr");
1880 let (tx, rx) = mpsc::channel();
1881 let body = body.to_string();
1882 let content_type = content_type.to_string();
1883
1884 std::thread::spawn(move || {
1885 let (mut socket, _) = listener.accept().expect("accept");
1886 socket
1887 .set_read_timeout(Some(Duration::from_secs(2)))
1888 .expect("set read timeout");
1889
1890 let mut bytes = Vec::new();
1891 let mut chunk = [0_u8; 4096];
1892 loop {
1893 match socket.read(&mut chunk) {
1894 Ok(0) => break,
1895 Ok(n) => {
1896 bytes.extend_from_slice(&chunk[..n]);
1897 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1898 break;
1899 }
1900 }
1901 Err(err)
1902 if err.kind() == std::io::ErrorKind::WouldBlock
1903 || err.kind() == std::io::ErrorKind::TimedOut =>
1904 {
1905 break;
1906 }
1907 Err(err) => panic!("read request failed: {err}"),
1908 }
1909 }
1910
1911 let header_end = bytes
1912 .windows(4)
1913 .position(|window| window == b"\r\n\r\n")
1914 .expect("request header boundary");
1915 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1916 let headers = parse_headers(&header_text);
1917 let mut request_body = bytes[header_end + 4..].to_vec();
1918
1919 let content_length = headers
1920 .get("content-length")
1921 .and_then(|value| value.parse::<usize>().ok())
1922 .unwrap_or(0);
1923 while request_body.len() < content_length {
1924 match socket.read(&mut chunk) {
1925 Ok(0) => break,
1926 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1927 Err(err)
1928 if err.kind() == std::io::ErrorKind::WouldBlock
1929 || err.kind() == std::io::ErrorKind::TimedOut =>
1930 {
1931 break;
1932 }
1933 Err(err) => panic!("read request body failed: {err}"),
1934 }
1935 }
1936
1937 let captured = CapturedRequest {
1938 headers,
1939 body: String::from_utf8_lossy(&request_body).to_string(),
1940 };
1941 tx.send(captured).expect("send captured request");
1942
1943 let reason = match status_code {
1944 401 => "Unauthorized",
1945 500 => "Internal Server Error",
1946 _ => "OK",
1947 };
1948 let response = format!(
1949 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1950 body.len()
1951 );
1952 socket
1953 .write_all(response.as_bytes())
1954 .expect("write response");
1955 socket.flush().expect("flush response");
1956 });
1957
1958 (format!("http://{addr}/messages"), rx)
1959 }
1960
1961 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1962 let mut headers = HashMap::new();
1963 for line in header_text.lines().skip(1) {
1964 if let Some((name, value)) = line.split_once(':') {
1965 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1966 }
1967 }
1968 headers
1969 }
1970
1971 fn load_fixture(file_name: &str) -> ProviderFixture {
1972 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1973 .join("tests/fixtures/provider_responses")
1974 .join(file_name);
1975 let raw = std::fs::read_to_string(path).expect("fixture read");
1976 serde_json::from_str(&raw).expect("fixture parse")
1977 }
1978
1979 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1980 let runtime = RuntimeBuilder::current_thread()
1981 .build()
1982 .expect("runtime build");
1983 runtime.block_on(async move {
1984 let byte_stream = stream::iter(
1985 events
1986 .iter()
1987 .map(|event| {
1988 let data = match event {
1989 Value::String(text) => text.clone(),
1990 _ => serde_json::to_string(event).expect("serialize event"),
1991 };
1992 format!("data: {data}\n\n").into_bytes()
1993 })
1994 .map(Ok),
1995 );
1996 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1997 let mut state = StreamState::new(
1998 event_source,
1999 "claude-test".to_string(),
2000 "anthropic-messages".to_string(),
2001 "anthropic".to_string(),
2002 );
2003 let mut out = Vec::new();
2004
2005 while let Some(item) = state.event_source.next().await {
2006 let msg = item.expect("SSE event");
2007 if msg.event == "ping" {
2008 continue;
2009 }
2010 if let Some(event) = state.process_event(&msg.data).expect("process_event") {
2011 out.push(event);
2012 }
2013 }
2014
2015 out
2016 })
2017 }
2018
2019 fn summarize_event(event: &StreamEvent) -> EventSummary {
2020 match event {
2021 StreamEvent::Start { .. } => EventSummary {
2022 kind: "start".to_string(),
2023 content_index: None,
2024 delta: None,
2025 content: None,
2026 reason: None,
2027 },
2028 StreamEvent::TextStart { content_index, .. } => EventSummary {
2029 kind: "text_start".to_string(),
2030 content_index: Some(*content_index),
2031 delta: None,
2032 content: None,
2033 reason: None,
2034 },
2035 StreamEvent::TextDelta {
2036 content_index,
2037 delta,
2038 ..
2039 } => EventSummary {
2040 kind: "text_delta".to_string(),
2041 content_index: Some(*content_index),
2042 delta: Some(delta.clone()),
2043 content: None,
2044 reason: None,
2045 },
2046 StreamEvent::TextEnd {
2047 content_index,
2048 content,
2049 ..
2050 } => EventSummary {
2051 kind: "text_end".to_string(),
2052 content_index: Some(*content_index),
2053 delta: None,
2054 content: Some(content.clone()),
2055 reason: None,
2056 },
2057 StreamEvent::Done { reason, .. } => EventSummary {
2058 kind: "done".to_string(),
2059 content_index: None,
2060 delta: None,
2061 content: None,
2062 reason: Some(reason_to_string(*reason)),
2063 },
2064 StreamEvent::Error { reason, .. } => EventSummary {
2065 kind: "error".to_string(),
2066 content_index: None,
2067 delta: None,
2068 content: None,
2069 reason: Some(reason_to_string(*reason)),
2070 },
2071 _ => EventSummary {
2072 kind: "other".to_string(),
2073 content_index: None,
2074 delta: None,
2075 content: None,
2076 reason: None,
2077 },
2078 }
2079 }
2080
2081 fn reason_to_string(reason: StopReason) -> String {
2082 match reason {
2083 StopReason::Stop => "stop",
2084 StopReason::Length => "length",
2085 StopReason::ToolUse => "tool_use",
2086 StopReason::Error => "error",
2087 StopReason::Aborted => "aborted",
2088 }
2089 .to_string()
2090 }
2091
2092 #[test]
2095 fn test_compat_custom_headers_injected_into_request() {
2096 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
2097
2098 let mut custom = HashMap::new();
2099 custom.insert("X-Custom-Tag".to_string(), "anthropic-override".to_string());
2100 custom.insert("X-Routing-Hint".to_string(), "us-east-1".to_string());
2101 let compat = crate::models::CompatConfig {
2102 custom_headers: Some(custom),
2103 ..Default::default()
2104 };
2105
2106 let provider = AnthropicProvider::new("claude-test")
2107 .with_base_url(base_url)
2108 .with_compat(Some(compat));
2109
2110 let context = Context {
2111 system_prompt: Some("test".to_string().into()),
2112 messages: vec![Message::User(crate::model::UserMessage {
2113 content: UserContent::Text("hi".to_string()),
2114 timestamp: 0,
2115 })]
2116 .into(),
2117 tools: Vec::new().into(),
2118 };
2119 let options = StreamOptions {
2120 api_key: Some("sk-ant-test-key".to_string()),
2121 ..Default::default()
2122 };
2123
2124 let runtime = RuntimeBuilder::current_thread()
2125 .build()
2126 .expect("runtime build");
2127 runtime.block_on(async {
2128 let mut stream = provider.stream(&context, &options).await.expect("stream");
2129 while let Some(event) = stream.next().await {
2130 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
2131 break;
2132 }
2133 }
2134 });
2135
2136 let captured = rx
2137 .recv_timeout(Duration::from_secs(2))
2138 .expect("captured request");
2139 assert_eq!(
2140 captured.headers.get("x-custom-tag").map(String::as_str),
2141 Some("anthropic-override"),
2142 "compat custom header X-Custom-Tag missing"
2143 );
2144 assert_eq!(
2145 captured.headers.get("x-routing-hint").map(String::as_str),
2146 Some("us-east-1"),
2147 "compat custom header X-Routing-Hint missing"
2148 );
2149 assert_eq!(
2151 captured.headers.get("x-api-key").map(String::as_str),
2152 Some("sk-ant-test-key"),
2153 );
2154 }
2155
2156 #[test]
2157 fn test_compat_none_does_not_affect_headers() {
2158 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
2159
2160 let provider = AnthropicProvider::new("claude-test")
2161 .with_base_url(base_url)
2162 .with_compat(None);
2163
2164 let context = Context {
2165 system_prompt: Some("test".to_string().into()),
2166 messages: vec![Message::User(crate::model::UserMessage {
2167 content: UserContent::Text("hi".to_string()),
2168 timestamp: 0,
2169 })]
2170 .into(),
2171 tools: Vec::new().into(),
2172 };
2173 let options = StreamOptions {
2174 api_key: Some("sk-ant-test-key".to_string()),
2175 ..Default::default()
2176 };
2177
2178 let runtime = RuntimeBuilder::current_thread()
2179 .build()
2180 .expect("runtime build");
2181 runtime.block_on(async {
2182 let mut stream = provider.stream(&context, &options).await.expect("stream");
2183 while let Some(event) = stream.next().await {
2184 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
2185 break;
2186 }
2187 }
2188 });
2189
2190 let captured = rx
2191 .recv_timeout(Duration::from_secs(2))
2192 .expect("captured request");
2193 assert_eq!(
2195 captured.headers.get("x-api-key").map(String::as_str),
2196 Some("sk-ant-test-key"),
2197 );
2198 assert!(
2199 !captured.headers.contains_key("x-custom-tag"),
2200 "No custom headers should be present with compat=None"
2201 );
2202 }
2203
2204 mod proptest_process_event {
2209 use super::*;
2210 use proptest::prelude::*;
2211
2212 fn make_state()
2213 -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
2214 {
2215 let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
2216 let sse = crate::sse::SseStream::new(Box::pin(empty));
2217 StreamState::new(
2218 sse,
2219 "claude-test".into(),
2220 "anthropic-messages".into(),
2221 "anthropic".into(),
2222 )
2223 }
2224
2225 fn small_string() -> impl Strategy<Value = String> {
2226 prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
2227 }
2228
2229 fn optional_string() -> impl Strategy<Value = Option<String>> {
2230 prop_oneof![Just(None), small_string().prop_map(Some),]
2231 }
2232
2233 fn token_count() -> impl Strategy<Value = u64> {
2234 prop_oneof![
2235 5 => 0u64..10_000u64,
2236 2 => Just(0u64),
2237 1 => Just(u64::MAX),
2238 1 => (u64::MAX - 100)..=u64::MAX,
2239 ]
2240 }
2241
2242 fn block_type() -> impl Strategy<Value = String> {
2243 prop_oneof![
2244 Just("text".to_string()),
2245 Just("thinking".to_string()),
2246 Just("tool_use".to_string()),
2247 Just("unknown_block_type".to_string()),
2248 "[a-z_]{1,12}",
2249 ]
2250 }
2251
2252 fn delta_type() -> impl Strategy<Value = String> {
2253 prop_oneof![
2254 Just("text_delta".to_string()),
2255 Just("thinking_delta".to_string()),
2256 Just("input_json_delta".to_string()),
2257 Just("signature_delta".to_string()),
2258 Just("unknown_delta".to_string()),
2259 "[a-z_]{1,16}",
2260 ]
2261 }
2262
2263 fn content_index() -> impl Strategy<Value = u32> {
2264 prop_oneof![
2265 5 => 0u32..5u32,
2266 2 => Just(0u32),
2267 1 => Just(u32::MAX),
2268 1 => 1000u32..2000u32,
2269 ]
2270 }
2271
2272 fn stop_reason_str() -> impl Strategy<Value = String> {
2273 prop_oneof![
2274 Just("end_turn".to_string()),
2275 Just("max_tokens".to_string()),
2276 Just("tool_use".to_string()),
2277 Just("stop_sequence".to_string()),
2278 Just("unknown_reason".to_string()),
2279 "[a-z_]{1,12}",
2280 ]
2281 }
2282
2283 fn anthropic_event_json() -> impl Strategy<Value = String> {
2286 prop_oneof![
2287 3 => token_count().prop_flat_map(|input| {
2289 (Just(input), token_count(), token_count()).prop_map(
2290 move |(cache_read, cache_write, _)| {
2291 serde_json::json!({
2292 "type": "message_start",
2293 "message": {
2294 "usage": {
2295 "input_tokens": input,
2296 "cache_read_input_tokens": cache_read,
2297 "cache_creation_input_tokens": cache_write
2298 }
2299 }
2300 })
2301 .to_string()
2302 },
2303 )
2304 }),
2305 1 => Just(r#"{"type":"message_start","message":{}}"#.to_string()),
2307 3 => (content_index(), block_type(), optional_string(), optional_string())
2309 .prop_map(|(idx, bt, id, name)| {
2310 let mut block = serde_json::json!({"type": bt});
2311 if let Some(id) = id {
2312 block["id"] = serde_json::Value::String(id);
2313 }
2314 if let Some(name) = name {
2315 block["name"] = serde_json::Value::String(name);
2316 }
2317 serde_json::json!({
2318 "type": "content_block_start",
2319 "index": idx,
2320 "content_block": block
2321 })
2322 .to_string()
2323 }),
2324 3 => (content_index(), delta_type(), optional_string(), optional_string(), optional_string(), optional_string())
2326 .prop_map(|(idx, dt, text, thinking, partial_json, sig)| {
2327 let mut delta = serde_json::json!({"type": dt});
2328 if let Some(t) = text { delta["text"] = serde_json::Value::String(t); }
2329 if let Some(t) = thinking { delta["thinking"] = serde_json::Value::String(t); }
2330 if let Some(p) = partial_json { delta["partial_json"] = serde_json::Value::String(p); }
2331 if let Some(s) = sig { delta["signature"] = serde_json::Value::String(s); }
2332 serde_json::json!({
2333 "type": "content_block_delta",
2334 "index": idx,
2335 "delta": delta
2336 })
2337 .to_string()
2338 }),
2339 2 => content_index().prop_map(|idx| {
2341 serde_json::json!({"type": "content_block_stop", "index": idx}).to_string()
2342 }),
2343 2 => (stop_reason_str(), token_count()).prop_map(|(sr, out)| {
2345 serde_json::json!({
2346 "type": "message_delta",
2347 "delta": {"stop_reason": sr},
2348 "usage": {"output_tokens": out}
2349 })
2350 .to_string()
2351 }),
2352 1 => stop_reason_str().prop_map(|sr| {
2354 serde_json::json!({
2355 "type": "message_delta",
2356 "delta": {"stop_reason": sr}
2357 })
2358 .to_string()
2359 }),
2360 2 => Just(r#"{"type":"message_stop"}"#.to_string()),
2362 2 => small_string().prop_map(|msg| {
2364 serde_json::json!({"type": "error", "error": {"message": msg}}).to_string()
2365 }),
2366 2 => Just(r#"{"type":"ping"}"#.to_string()),
2368 ]
2369 }
2370
2371 fn chaos_json() -> impl Strategy<Value = String> {
2373 prop_oneof![
2374 Just(String::new()),
2376 Just("{}".to_string()),
2377 Just("[]".to_string()),
2378 Just("null".to_string()),
2379 Just("true".to_string()),
2380 Just("42".to_string()),
2381 Just("{".to_string()),
2383 Just(r#"{"type":}"#.to_string()),
2384 Just(r#"{"type":null}"#.to_string()),
2385 "[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
2387 "[ -~]{0,64}",
2389 Just(r#"{"type":"message_start"}"#.to_string()),
2391 Just(r#"{"type":"content_block_delta"}"#.to_string()),
2392 Just(r#"{"type":"error"}"#.to_string()),
2393 ]
2394 }
2395
2396 proptest! {
2397 #![proptest_config(ProptestConfig {
2398 cases: 256,
2399 max_shrink_iters: 100,
2400 .. ProptestConfig::default()
2401 })]
2402
2403 #[test]
2404 fn process_event_valid_never_panics(data in anthropic_event_json()) {
2405 let mut state = make_state();
2406 let _ = state.process_event(&data);
2407 }
2408
2409 #[test]
2410 fn process_event_chaos_never_panics(data in chaos_json()) {
2411 let mut state = make_state();
2412 let _ = state.process_event(&data);
2413 }
2414
2415 #[test]
2416 fn process_event_sequence_never_panics(
2417 events in prop::collection::vec(anthropic_event_json(), 1..8)
2418 ) {
2419 let mut state = make_state();
2420 for event in &events {
2421 let _ = state.process_event(event);
2422 }
2423 }
2424 }
2425 }
2426}
2427
#[cfg(feature = "fuzzing")]
/// Facade exposing the stream-event processor to fuzz harnesses.
///
/// Compiled only when the `fuzzing` feature is enabled.
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Pinned, boxed empty stream of byte-chunk results used as the
    /// processor's underlying source.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Wrapper around a `StreamState` built over an empty stream, so that event
    /// payloads can be supplied directly through [`Processor::process_event`].
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        /// Same as [`Processor::new`].
        fn default() -> Self {
            Processor::new()
        }
    }

    impl Processor {
        /// Creates a processor over an empty source stream, using fixed
        /// placeholder identifiers for the remaining `StreamState::new`
        /// arguments.
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            let sse = crate::sse::SseStream::new(source);
            Self(StreamState::new(
                sse,
                "claude-fuzz".into(),
                "anthropic-messages".into(),
                "anthropic".into(),
            ))
        }

        /// Passes one event payload to the wrapped state.
        ///
        /// Returns a one-element vector when the state produces an event and an
        /// empty vector when it produces none.
        ///
        /// # Errors
        ///
        /// Propagates any error returned by the wrapped state's `process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let events = match self.0.process_event(data)? {
                Some(event) => vec![event],
                None => Vec::new(),
            };
            Ok(events)
        }
    }
}