1use std::collections::VecDeque;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use runtime::{
5 load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6 OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11use crate::providers::RetryPolicy;
12use crate::sse::SseParser;
13use crate::types::{MessageRequest, MessageResponse, StreamEvent};
14
/// Default API base URL, used by the client constructors and whenever
/// `ANTHROPIC_BASE_URL` cannot be read.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header on every messages request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Primary response header checked for the server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header checked when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
19
/// Credentials attached to outgoing Anthropic messages requests.
///
/// `Debug` is implemented by hand (not derived) so the secret values are never printed.
#[derive(Clone, PartialEq, Eq)]
pub enum AuthSource {
    /// No credentials; no auth headers are added.
    None,
    /// API key, sent as the `x-api-key` header.
    ApiKey(String),
    /// Bearer token, sent via `Authorization: Bearer <token>`.
    BearerToken(String),
    /// Both credentials; both headers are sent.
    ApiKeyAndBearer {
        api_key: String,
        bearer_token: String,
    },
}
30
31impl std::fmt::Debug for AuthSource {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::None => write!(f, "AuthSource::None"),
35 Self::ApiKey(_) => write!(f, "AuthSource::ApiKey(***)"),
36 Self::BearerToken(_) => write!(f, "AuthSource::BearerToken(***)"),
37 Self::ApiKeyAndBearer { .. } => write!(f, "AuthSource::ApiKeyAndBearer(***)"),
38 }
39 }
40}
41
42impl AuthSource {
43 pub fn from_env() -> Result<Self, ApiError> {
44 let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
45 let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
46 match (api_key, auth_token) {
47 (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
48 api_key,
49 bearer_token,
50 }),
51 (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
52 (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
53 (None, None) => Err(ApiError::missing_credentials(
54 "Anthropic",
55 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
56 )),
57 }
58 }
59
60 #[must_use]
61 pub fn api_key(&self) -> Option<&str> {
62 match self {
63 Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
64 Self::None | Self::BearerToken(_) => None,
65 }
66 }
67
68 #[must_use]
69 pub fn bearer_token(&self) -> Option<&str> {
70 match self {
71 Self::BearerToken(token)
72 | Self::ApiKeyAndBearer {
73 bearer_token: token,
74 ..
75 } => Some(token),
76 Self::None | Self::ApiKey(_) => None,
77 }
78 }
79
80 #[must_use]
81 pub fn masked_authorization_header(&self) -> &'static str {
82 if self.bearer_token().is_some() {
83 "Bearer [REDACTED]"
84 } else {
85 "<absent>"
86 }
87 }
88
89 pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
90 if let Some(api_key) = self.api_key() {
91 request_builder = request_builder.header("x-api-key", api_key);
92 }
93 if let Some(token) = self.bearer_token() {
94 request_builder = request_builder.bearer_auth(token);
95 }
96 request_builder
97 }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
101pub struct OAuthTokenSet {
102 pub access_token: String,
103 pub refresh_token: Option<String>,
104 pub expires_at: Option<u64>,
105 #[serde(default)]
106 pub scopes: Vec<String>,
107}
108
109impl From<OAuthTokenSet> for AuthSource {
110 fn from(value: OAuthTokenSet) -> Self {
111 Self::BearerToken(value.access_token)
112 }
113}
114
115impl From<runtime::ResolvedCredential> for AuthSource {
116 fn from(value: runtime::ResolvedCredential) -> Self {
117 match value {
118 runtime::ResolvedCredential::ApiKey(key) => Self::ApiKey(key),
119 runtime::ResolvedCredential::BearerToken(token) => Self::BearerToken(token),
120 runtime::ResolvedCredential::ApiKeyAndBearer {
121 api_key,
122 bearer_token,
123 } => Self::ApiKeyAndBearer {
124 api_key,
125 bearer_token,
126 },
127 }
128 }
129}
130
/// Client for the Anthropic messages API and OAuth token endpoints.
#[derive(Clone)]
pub struct CodineerApiClient {
    /// Underlying HTTP client.
    http: reqwest::Client,
    /// Credentials applied to messages requests (not to OAuth token requests).
    auth: AuthSource,
    /// API base URL; `/v1/messages` is appended after trimming any trailing `/`.
    base_url: String,
    /// Retry and backoff settings for messages requests.
    retry: RetryPolicy,
}
138
139impl std::fmt::Debug for CodineerApiClient {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("CodineerApiClient")
142 .field("base_url", &self.base_url)
143 .field("auth", &self.auth)
144 .finish()
145 }
146}
147
148impl CodineerApiClient {
149 #[must_use]
150 pub fn new(api_key: impl Into<String>) -> Self {
151 Self {
152 http: crate::default_http_client(),
153 auth: AuthSource::ApiKey(api_key.into()),
154 base_url: DEFAULT_BASE_URL.to_string(),
155 retry: RetryPolicy::default(),
156 }
157 }
158
159 #[must_use]
160 pub fn from_auth(auth: AuthSource) -> Self {
161 Self {
162 http: crate::default_http_client(),
163 auth,
164 base_url: DEFAULT_BASE_URL.to_string(),
165 retry: RetryPolicy::default(),
166 }
167 }
168
169 pub fn from_env() -> Result<Self, ApiError> {
170 Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
171 }
172
173 #[must_use]
174 pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
175 self.auth = auth;
176 self
177 }
178
179 #[must_use]
180 pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
181 match (
182 self.auth.api_key().map(ToOwned::to_owned),
183 auth_token.filter(|token| !token.is_empty()),
184 ) {
185 (Some(api_key), Some(bearer_token)) => {
186 self.auth = AuthSource::ApiKeyAndBearer {
187 api_key,
188 bearer_token,
189 };
190 }
191 (Some(api_key), None) => {
192 self.auth = AuthSource::ApiKey(api_key);
193 }
194 (None, Some(bearer_token)) => {
195 self.auth = AuthSource::BearerToken(bearer_token);
196 }
197 (None, None) => {
198 self.auth = AuthSource::None;
199 }
200 }
201 self
202 }
203
204 #[must_use]
205 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
206 self.base_url = base_url.into();
207 self
208 }
209
210 #[must_use]
211 pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
212 self.retry = retry;
213 self
214 }
215
216 #[must_use]
217 pub fn auth_source(&self) -> &AuthSource {
218 &self.auth
219 }
220
221 pub async fn send_message(
222 &self,
223 request: &MessageRequest,
224 ) -> Result<MessageResponse, ApiError> {
225 let request = MessageRequest {
226 stream: false,
227 ..request.clone()
228 };
229 let response = self.send_with_retry(&request).await?;
230 let request_id = request_id_from_headers(response.headers());
231 let mut response = response
232 .json::<MessageResponse>()
233 .await
234 .map_err(ApiError::from)?;
235 if response.request_id.is_none() {
236 response.request_id = request_id;
237 }
238 Ok(response)
239 }
240
241 pub async fn stream_message(
242 &self,
243 request: &MessageRequest,
244 ) -> Result<MessageStream, ApiError> {
245 let response = self
246 .send_with_retry(&request.clone().with_streaming())
247 .await?;
248 Ok(MessageStream {
249 request_id: request_id_from_headers(response.headers()),
250 response,
251 parser: SseParser::new(),
252 pending: VecDeque::new(),
253 done: false,
254 })
255 }
256
257 pub async fn exchange_oauth_code(
258 &self,
259 config: &OAuthConfig,
260 request: &OAuthTokenExchangeRequest,
261 ) -> Result<OAuthTokenSet, ApiError> {
262 let response = self
263 .http
264 .post(&config.token_url)
265 .header("content-type", "application/x-www-form-urlencoded")
266 .form(&request.form_params())
267 .send()
268 .await
269 .map_err(ApiError::from)?;
270 let response = expect_success(response).await?;
271 response
272 .json::<OAuthTokenSet>()
273 .await
274 .map_err(ApiError::from)
275 }
276
277 pub async fn refresh_oauth_token(
278 &self,
279 config: &OAuthConfig,
280 request: &OAuthRefreshRequest,
281 ) -> Result<OAuthTokenSet, ApiError> {
282 let response = self
283 .http
284 .post(&config.token_url)
285 .header("content-type", "application/x-www-form-urlencoded")
286 .form(&request.form_params())
287 .send()
288 .await
289 .map_err(ApiError::from)?;
290 let response = expect_success(response).await?;
291 response
292 .json::<OAuthTokenSet>()
293 .await
294 .map_err(ApiError::from)
295 }
296
297 async fn send_with_retry(
298 &self,
299 request: &MessageRequest,
300 ) -> Result<reqwest::Response, ApiError> {
301 let mut attempts = 0;
302 let mut last_error: Option<ApiError>;
303
304 loop {
305 attempts += 1;
306 match self.send_raw_request(request).await {
307 Ok(response) => match expect_success(response).await {
308 Ok(response) => return Ok(response),
309 Err(error)
310 if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
311 {
312 last_error = Some(error);
313 }
314 Err(error) => return Err(error),
315 },
316 Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
317 last_error = Some(error);
318 }
319 Err(error) => return Err(error),
320 }
321
322 if attempts > self.retry.max_retries {
323 break;
324 }
325
326 tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
327 }
328
329 Err(ApiError::RetriesExhausted {
330 attempts,
331 last_error: Box::new(last_error.unwrap_or(ApiError::Auth(
332 "retry loop exited without capturing an error".into(),
333 ))),
334 })
335 }
336
337 async fn send_raw_request(
338 &self,
339 request: &MessageRequest,
340 ) -> Result<reqwest::Response, ApiError> {
341 let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
342 let request_builder = self
343 .http
344 .post(&request_url)
345 .header("anthropic-version", ANTHROPIC_VERSION)
346 .header("content-type", "application/json");
347 let mut request_builder = self.auth.apply(request_builder);
348
349 request_builder = request_builder.json(request);
350 request_builder.send().await.map_err(ApiError::from)
351 }
352
353 fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
354 let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
355 return Err(ApiError::BackoffOverflow {
356 attempt,
357 base_delay: self.retry.initial_backoff,
358 });
359 };
360 Ok(self
361 .retry
362 .initial_backoff
363 .checked_mul(multiplier)
364 .map_or(self.retry.max_backoff, |delay| {
365 delay.min(self.retry.max_backoff)
366 }))
367 }
368}
369
370impl AuthSource {
371 pub fn from_env_or_saved() -> Result<Self, ApiError> {
372 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
373 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
374 Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
375 api_key,
376 bearer_token,
377 }),
378 None => Ok(Self::ApiKey(api_key)),
379 };
380 }
381 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
382 return Ok(Self::BearerToken(bearer_token));
383 }
384 match load_saved_oauth_token() {
385 Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
386 if token_set.refresh_token.is_some() {
387 Err(ApiError::Auth(
388 "saved OAuth token is expired; load runtime OAuth config to refresh it"
389 .to_string(),
390 ))
391 } else {
392 Err(ApiError::ExpiredOAuthToken)
393 }
394 }
395 Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
396 Ok(None) => Err(ApiError::missing_credentials(
397 "Anthropic",
398 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
399 )),
400 Err(error) => Err(error),
401 }
402 }
403}
404
405#[must_use]
406pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
407 token_set
408 .expires_at
409 .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
410}
411
412pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
413 let Some(token_set) = load_saved_oauth_token()? else {
414 return Ok(None);
415 };
416 resolve_saved_oauth_token_set(config, token_set).map(Some)
417}
418
419pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
420 Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
421 || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
422 || load_saved_oauth_token()?.is_some())
423}
424
425pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
426where
427 F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
428{
429 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
430 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
431 Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
432 api_key,
433 bearer_token,
434 }),
435 None => Ok(AuthSource::ApiKey(api_key)),
436 };
437 }
438 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
439 return Ok(AuthSource::BearerToken(bearer_token));
440 }
441
442 let Some(token_set) = load_saved_oauth_token()? else {
443 return Err(ApiError::missing_credentials(
444 "Anthropic",
445 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
446 ));
447 };
448 if !oauth_token_is_expired(&token_set) {
449 return Ok(AuthSource::BearerToken(token_set.access_token));
450 }
451 if token_set.refresh_token.is_none() {
452 return Err(ApiError::ExpiredOAuthToken);
453 }
454
455 let Some(config) = load_oauth_config()? else {
456 return Err(ApiError::Auth(
457 "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
458 ));
459 };
460 Ok(AuthSource::from(resolve_saved_oauth_token_set(
461 &config, token_set,
462 )?))
463}
464
465fn resolve_saved_oauth_token_set(
466 config: &OAuthConfig,
467 token_set: OAuthTokenSet,
468) -> Result<OAuthTokenSet, ApiError> {
469 if !oauth_token_is_expired(&token_set) {
470 return Ok(token_set);
471 }
472 let Some(refresh_token) = token_set.refresh_token.clone() else {
473 return Err(ApiError::ExpiredOAuthToken);
474 };
475 let client = CodineerApiClient::from_auth(AuthSource::None).with_base_url(read_base_url());
476 let refreshed = client_runtime_block_on(async {
477 client
478 .refresh_oauth_token(
479 config,
480 &OAuthRefreshRequest::from_config(
481 config,
482 refresh_token,
483 Some(token_set.scopes.clone()),
484 ),
485 )
486 .await
487 })?;
488 let resolved = OAuthTokenSet {
489 access_token: refreshed.access_token,
490 refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
491 expires_at: refreshed.expires_at,
492 scopes: refreshed.scopes,
493 };
494 save_oauth_credentials(&runtime::OAuthTokenSet {
495 access_token: resolved.access_token.clone(),
496 refresh_token: resolved.refresh_token.clone(),
497 expires_at: resolved.expires_at,
498 scopes: resolved.scopes.clone(),
499 })
500 .map_err(ApiError::from)?;
501 Ok(resolved)
502}
503
504fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
505where
506 F: std::future::Future<Output = Result<T, ApiError>>,
507{
508 match tokio::runtime::Handle::try_current() {
509 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
510 Err(_) => tokio::runtime::Runtime::new()
511 .map_err(ApiError::from)?
512 .block_on(future),
513 }
514}
515
516fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
517 let token_set = load_oauth_credentials().map_err(ApiError::from)?;
518 Ok(token_set.map(|token_set| OAuthTokenSet {
519 access_token: token_set.access_token,
520 refresh_token: token_set.refresh_token,
521 expires_at: token_set.expires_at,
522 scopes: token_set.scopes,
523 }))
524}
525
/// Current time as whole seconds since the Unix epoch; a clock set before 1970 yields 0.
fn now_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
531
532fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
533 match std::env::var(key) {
534 Ok(value) if !value.is_empty() => Ok(Some(value)),
535 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
536 Err(error) => Err(ApiError::from(error)),
537 }
538}
539
/// Test helper: returns the API key, or failing that the bearer token, from
/// `AuthSource::from_env_or_saved`.
///
/// The fallback error is built lazily (`ok_or_else`) rather than on every call.
///
/// # Errors
///
/// Propagates resolution errors. Returns a missing-credentials error if the resolved source
/// carries neither credential.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    auth.api_key()
        .or_else(|| auth.bearer_token())
        .map(ToOwned::to_owned)
        .ok_or_else(|| {
            ApiError::missing_credentials(
                "Anthropic",
                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
            )
        })
}
551
/// Test helper: returns `ANTHROPIC_AUTH_TOKEN` when it is set and non-empty.
///
/// Read errors (e.g. non-Unicode values) are treated as "not set".
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    match read_env_non_empty("ANTHROPIC_AUTH_TOKEN") {
        Ok(token) => token,
        Err(_) => None,
    }
}
558
559#[must_use]
560pub fn read_base_url() -> String {
561 std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
562}
563
564fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
565 headers
566 .get(REQUEST_ID_HEADER)
567 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
568 .and_then(|value| value.to_str().ok())
569 .map(ToOwned::to_owned)
570}
571
/// Incremental reader over a streaming (SSE) messages response.
#[derive(Debug)]
pub struct MessageStream {
    /// Request id captured from the response headers when the stream was opened.
    request_id: Option<String>,
    /// HTTP response whose body chunks are read.
    response: reqwest::Response,
    /// Parser that turns body chunks into stream events.
    parser: SseParser,
    /// Events already parsed but not yet returned to the caller.
    pending: VecDeque<StreamEvent>,
    /// Set once the response body has been fully read.
    done: bool,
}
580
581impl MessageStream {
582 #[must_use]
583 pub fn request_id(&self) -> Option<&str> {
584 self.request_id.as_deref()
585 }
586
587 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
588 loop {
589 if let Some(event) = self.pending.pop_front() {
590 return Ok(Some(event));
591 }
592
593 if self.done {
594 let remaining = self.parser.finish()?;
595 self.pending.extend(remaining);
596 if let Some(event) = self.pending.pop_front() {
597 return Ok(Some(event));
598 }
599 return Ok(None);
600 }
601
602 match self.response.chunk().await? {
603 Some(chunk) => {
604 self.pending.extend(self.parser.push(&chunk)?);
605 }
606 None => {
607 self.done = true;
608 }
609 }
610 }
611 }
612}
613
614async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
615 let status = response.status();
616 if status.is_success() {
617 return Ok(response);
618 }
619
620 let body = response.text().await.unwrap_or_else(|_| String::new());
621 let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
622 let retryable = is_retryable_status(status);
623
624 Err(ApiError::Api {
625 status,
626 error_type: parsed_error
627 .as_ref()
628 .map(|error| error.error.error_type.clone()),
629 message: parsed_error
630 .as_ref()
631 .map(|error| error.error.message.clone()),
632 body,
633 retryable,
634 })
635}
636
637const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
638 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
639}
640
/// Top-level JSON error envelope (`{"error": {...}}`), parsed best-effort from the body of
/// non-success responses.
#[derive(Debug, Deserialize)]
struct ApiErrorEnvelope {
    /// Inner error object.
    error: ApiErrorBody,
}
645
/// Inner error object of an API error response.
#[derive(Debug, Deserialize)]
struct ApiErrorBody {
    /// Error category, read from the JSON `type` field.
    #[serde(rename = "type")]
    error_type: String,
    /// Error message text from the response.
    message: String,
}
652
// Unit tests for this module live in `codineer_provider_tests.rs` (pulled in via `#[path]`).
#[cfg(test)]
#[path = "codineer_provider_tests.rs"]
mod tests;