1use std::collections::VecDeque;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use runtime::{
5 load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6 OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11use crate::providers::RetryPolicy;
12use crate::sse::SseParser;
13use crate::types::{MessageRequest, MessageResponse, StreamEvent};
14
/// Response header consulted first when extracting a request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted only when [`REQUEST_ID_HEADER`] is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Value sent in the `anthropic-version` header on every messages request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Base URL of the Anthropic API, used when no override is configured.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
19
/// Credentials attached to outgoing Anthropic API requests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthSource {
    /// No credentials are attached.
    None,
    /// An API key, sent in the `x-api-key` header.
    ApiKey(String),
    /// A bearer token, sent in the `Authorization: Bearer …` header.
    BearerToken(String),
    /// Both an API key and a bearer token; both headers are sent.
    ApiKeyAndBearer {
        /// Value for the `x-api-key` header.
        api_key: String,
        /// Value for the `Authorization: Bearer …` header.
        bearer_token: String,
    },
}
30
31impl AuthSource {
32 pub fn from_env() -> Result<Self, ApiError> {
33 let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
34 let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
35 match (api_key, auth_token) {
36 (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
37 api_key,
38 bearer_token,
39 }),
40 (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
41 (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
42 (None, None) => Err(ApiError::missing_credentials(
43 "Anthropic",
44 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
45 )),
46 }
47 }
48
49 #[must_use]
50 pub fn api_key(&self) -> Option<&str> {
51 match self {
52 Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
53 Self::None | Self::BearerToken(_) => None,
54 }
55 }
56
57 #[must_use]
58 pub fn bearer_token(&self) -> Option<&str> {
59 match self {
60 Self::BearerToken(token)
61 | Self::ApiKeyAndBearer {
62 bearer_token: token,
63 ..
64 } => Some(token),
65 Self::None | Self::ApiKey(_) => None,
66 }
67 }
68
69 #[must_use]
70 pub fn masked_authorization_header(&self) -> &'static str {
71 if self.bearer_token().is_some() {
72 "Bearer [REDACTED]"
73 } else {
74 "<absent>"
75 }
76 }
77
78 pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
79 if let Some(api_key) = self.api_key() {
80 request_builder = request_builder.header("x-api-key", api_key);
81 }
82 if let Some(token) = self.bearer_token() {
83 request_builder = request_builder.bearer_auth(token);
84 }
85 request_builder
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
90pub struct OAuthTokenSet {
91 pub access_token: String,
92 pub refresh_token: Option<String>,
93 pub expires_at: Option<u64>,
94 #[serde(default)]
95 pub scopes: Vec<String>,
96}
97
98impl From<OAuthTokenSet> for AuthSource {
99 fn from(value: OAuthTokenSet) -> Self {
100 Self::BearerToken(value.access_token)
101 }
102}
103
104#[derive(Debug, Clone)]
105pub struct CodineerApiClient {
106 http: reqwest::Client,
107 auth: AuthSource,
108 base_url: String,
109 retry: RetryPolicy,
110}
111
112impl CodineerApiClient {
113 #[must_use]
114 pub fn new(api_key: impl Into<String>) -> Self {
115 Self {
116 http: reqwest::Client::new(),
117 auth: AuthSource::ApiKey(api_key.into()),
118 base_url: DEFAULT_BASE_URL.to_string(),
119 retry: RetryPolicy::default(),
120 }
121 }
122
123 #[must_use]
124 pub fn from_auth(auth: AuthSource) -> Self {
125 Self {
126 http: reqwest::Client::new(),
127 auth,
128 base_url: DEFAULT_BASE_URL.to_string(),
129 retry: RetryPolicy::default(),
130 }
131 }
132
133 pub fn from_env() -> Result<Self, ApiError> {
134 Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
135 }
136
137 #[must_use]
138 pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
139 self.auth = auth;
140 self
141 }
142
143 #[must_use]
144 pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
145 match (
146 self.auth.api_key().map(ToOwned::to_owned),
147 auth_token.filter(|token| !token.is_empty()),
148 ) {
149 (Some(api_key), Some(bearer_token)) => {
150 self.auth = AuthSource::ApiKeyAndBearer {
151 api_key,
152 bearer_token,
153 };
154 }
155 (Some(api_key), None) => {
156 self.auth = AuthSource::ApiKey(api_key);
157 }
158 (None, Some(bearer_token)) => {
159 self.auth = AuthSource::BearerToken(bearer_token);
160 }
161 (None, None) => {
162 self.auth = AuthSource::None;
163 }
164 }
165 self
166 }
167
168 #[must_use]
169 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
170 self.base_url = base_url.into();
171 self
172 }
173
174 #[must_use]
175 pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
176 self.retry = retry;
177 self
178 }
179
180 #[must_use]
181 pub fn auth_source(&self) -> &AuthSource {
182 &self.auth
183 }
184
185 pub async fn send_message(
186 &self,
187 request: &MessageRequest,
188 ) -> Result<MessageResponse, ApiError> {
189 let request = MessageRequest {
190 stream: false,
191 ..request.clone()
192 };
193 let response = self.send_with_retry(&request).await?;
194 let request_id = request_id_from_headers(response.headers());
195 let mut response = response
196 .json::<MessageResponse>()
197 .await
198 .map_err(ApiError::from)?;
199 if response.request_id.is_none() {
200 response.request_id = request_id;
201 }
202 Ok(response)
203 }
204
205 pub async fn stream_message(
206 &self,
207 request: &MessageRequest,
208 ) -> Result<MessageStream, ApiError> {
209 let response = self
210 .send_with_retry(&request.clone().with_streaming())
211 .await?;
212 Ok(MessageStream {
213 request_id: request_id_from_headers(response.headers()),
214 response,
215 parser: SseParser::new(),
216 pending: VecDeque::new(),
217 done: false,
218 })
219 }
220
221 pub async fn exchange_oauth_code(
222 &self,
223 config: &OAuthConfig,
224 request: &OAuthTokenExchangeRequest,
225 ) -> Result<OAuthTokenSet, ApiError> {
226 let response = self
227 .http
228 .post(&config.token_url)
229 .header("content-type", "application/x-www-form-urlencoded")
230 .form(&request.form_params())
231 .send()
232 .await
233 .map_err(ApiError::from)?;
234 let response = expect_success(response).await?;
235 response
236 .json::<OAuthTokenSet>()
237 .await
238 .map_err(ApiError::from)
239 }
240
241 pub async fn refresh_oauth_token(
242 &self,
243 config: &OAuthConfig,
244 request: &OAuthRefreshRequest,
245 ) -> Result<OAuthTokenSet, ApiError> {
246 let response = self
247 .http
248 .post(&config.token_url)
249 .header("content-type", "application/x-www-form-urlencoded")
250 .form(&request.form_params())
251 .send()
252 .await
253 .map_err(ApiError::from)?;
254 let response = expect_success(response).await?;
255 response
256 .json::<OAuthTokenSet>()
257 .await
258 .map_err(ApiError::from)
259 }
260
261 async fn send_with_retry(
262 &self,
263 request: &MessageRequest,
264 ) -> Result<reqwest::Response, ApiError> {
265 let mut attempts = 0;
266 let mut last_error: Option<ApiError>;
267
268 loop {
269 attempts += 1;
270 match self.send_raw_request(request).await {
271 Ok(response) => match expect_success(response).await {
272 Ok(response) => return Ok(response),
273 Err(error)
274 if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
275 {
276 last_error = Some(error);
277 }
278 Err(error) => return Err(error),
279 },
280 Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
281 last_error = Some(error);
282 }
283 Err(error) => return Err(error),
284 }
285
286 if attempts > self.retry.max_retries {
287 break;
288 }
289
290 tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
291 }
292
293 Err(ApiError::RetriesExhausted {
294 attempts,
295 last_error: Box::new(last_error.unwrap_or(ApiError::Auth(
296 "retry loop exited without capturing an error".into(),
297 ))),
298 })
299 }
300
301 async fn send_raw_request(
302 &self,
303 request: &MessageRequest,
304 ) -> Result<reqwest::Response, ApiError> {
305 let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
306 let request_builder = self
307 .http
308 .post(&request_url)
309 .header("anthropic-version", ANTHROPIC_VERSION)
310 .header("content-type", "application/json");
311 let mut request_builder = self.auth.apply(request_builder);
312
313 request_builder = request_builder.json(request);
314 request_builder.send().await.map_err(ApiError::from)
315 }
316
317 fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
318 let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
319 return Err(ApiError::BackoffOverflow {
320 attempt,
321 base_delay: self.retry.initial_backoff,
322 });
323 };
324 Ok(self
325 .retry
326 .initial_backoff
327 .checked_mul(multiplier)
328 .map_or(self.retry.max_backoff, |delay| {
329 delay.min(self.retry.max_backoff)
330 }))
331 }
332}
333
334impl AuthSource {
335 pub fn from_env_or_saved() -> Result<Self, ApiError> {
336 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
337 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
338 Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
339 api_key,
340 bearer_token,
341 }),
342 None => Ok(Self::ApiKey(api_key)),
343 };
344 }
345 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
346 return Ok(Self::BearerToken(bearer_token));
347 }
348 match load_saved_oauth_token() {
349 Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
350 if token_set.refresh_token.is_some() {
351 Err(ApiError::Auth(
352 "saved OAuth token is expired; load runtime OAuth config to refresh it"
353 .to_string(),
354 ))
355 } else {
356 Err(ApiError::ExpiredOAuthToken)
357 }
358 }
359 Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
360 Ok(None) => Err(ApiError::missing_credentials(
361 "Anthropic",
362 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
363 )),
364 Err(error) => Err(error),
365 }
366 }
367}
368
369#[must_use]
370pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
371 token_set
372 .expires_at
373 .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
374}
375
376pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
377 let Some(token_set) = load_saved_oauth_token()? else {
378 return Ok(None);
379 };
380 resolve_saved_oauth_token_set(config, token_set).map(Some)
381}
382
383pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
384 Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
385 || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
386 || load_saved_oauth_token()?.is_some())
387}
388
389pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
390where
391 F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
392{
393 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
394 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
395 Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
396 api_key,
397 bearer_token,
398 }),
399 None => Ok(AuthSource::ApiKey(api_key)),
400 };
401 }
402 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
403 return Ok(AuthSource::BearerToken(bearer_token));
404 }
405
406 let Some(token_set) = load_saved_oauth_token()? else {
407 return Err(ApiError::missing_credentials(
408 "Anthropic",
409 &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
410 ));
411 };
412 if !oauth_token_is_expired(&token_set) {
413 return Ok(AuthSource::BearerToken(token_set.access_token));
414 }
415 if token_set.refresh_token.is_none() {
416 return Err(ApiError::ExpiredOAuthToken);
417 }
418
419 let Some(config) = load_oauth_config()? else {
420 return Err(ApiError::Auth(
421 "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
422 ));
423 };
424 Ok(AuthSource::from(resolve_saved_oauth_token_set(
425 &config, token_set,
426 )?))
427}
428
429fn resolve_saved_oauth_token_set(
430 config: &OAuthConfig,
431 token_set: OAuthTokenSet,
432) -> Result<OAuthTokenSet, ApiError> {
433 if !oauth_token_is_expired(&token_set) {
434 return Ok(token_set);
435 }
436 let Some(refresh_token) = token_set.refresh_token.clone() else {
437 return Err(ApiError::ExpiredOAuthToken);
438 };
439 let client = CodineerApiClient::from_auth(AuthSource::None).with_base_url(read_base_url());
440 let refreshed = client_runtime_block_on(async {
441 client
442 .refresh_oauth_token(
443 config,
444 &OAuthRefreshRequest::from_config(
445 config,
446 refresh_token,
447 Some(token_set.scopes.clone()),
448 ),
449 )
450 .await
451 })?;
452 let resolved = OAuthTokenSet {
453 access_token: refreshed.access_token,
454 refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
455 expires_at: refreshed.expires_at,
456 scopes: refreshed.scopes,
457 };
458 save_oauth_credentials(&runtime::OAuthTokenSet {
459 access_token: resolved.access_token.clone(),
460 refresh_token: resolved.refresh_token.clone(),
461 expires_at: resolved.expires_at,
462 scopes: resolved.scopes.clone(),
463 })
464 .map_err(ApiError::from)?;
465 Ok(resolved)
466}
467
468fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
469where
470 F: std::future::Future<Output = Result<T, ApiError>>,
471{
472 match tokio::runtime::Handle::try_current() {
473 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
474 Err(_) => tokio::runtime::Runtime::new()
475 .map_err(ApiError::from)?
476 .block_on(future),
477 }
478}
479
480fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
481 let token_set = load_oauth_credentials().map_err(ApiError::from)?;
482 Ok(token_set.map(|token_set| OAuthTokenSet {
483 access_token: token_set.access_token,
484 refresh_token: token_set.refresh_token,
485 expires_at: token_set.expires_at,
486 scopes: token_set.scopes,
487 }))
488}
489
/// Current Unix time in whole seconds, or `0` if the system clock reads earlier
/// than the epoch.
fn now_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
495
496fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
497 match std::env::var(key) {
498 Ok(value) if !value.is_empty() => Ok(Some(value)),
499 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
500 Err(error) => Err(ApiError::from(error)),
501 }
502}
503
/// Test helper: returns the API key from the resolved auth source, or the bearer
/// token when there is no API key.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    match auth.api_key().or_else(|| auth.bearer_token()) {
        Some(credential) => Ok(credential.to_owned()),
        None => Err(ApiError::missing_credentials(
            "Anthropic",
            &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
        )),
    }
}
515
/// Test helper: returns `ANTHROPIC_AUTH_TOKEN` when it is set and non-empty.
/// Read errors are treated as absent.
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    read_env_non_empty("ANTHROPIC_AUTH_TOKEN").ok().flatten()
}
522
523#[must_use]
524pub fn read_base_url() -> String {
525 std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
526}
527
528fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
529 headers
530 .get(REQUEST_ID_HEADER)
531 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
532 .and_then(|value| value.to_str().ok())
533 .map(ToOwned::to_owned)
534}
535
536#[derive(Debug)]
537pub struct MessageStream {
538 request_id: Option<String>,
539 response: reqwest::Response,
540 parser: SseParser,
541 pending: VecDeque<StreamEvent>,
542 done: bool,
543}
544
545impl MessageStream {
546 #[must_use]
547 pub fn request_id(&self) -> Option<&str> {
548 self.request_id.as_deref()
549 }
550
551 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
552 loop {
553 if let Some(event) = self.pending.pop_front() {
554 return Ok(Some(event));
555 }
556
557 if self.done {
558 let remaining = self.parser.finish()?;
559 self.pending.extend(remaining);
560 if let Some(event) = self.pending.pop_front() {
561 return Ok(Some(event));
562 }
563 return Ok(None);
564 }
565
566 match self.response.chunk().await? {
567 Some(chunk) => {
568 self.pending.extend(self.parser.push(&chunk)?);
569 }
570 None => {
571 self.done = true;
572 }
573 }
574 }
575 }
576}
577
578async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
579 let status = response.status();
580 if status.is_success() {
581 return Ok(response);
582 }
583
584 let body = response.text().await.unwrap_or_else(|_| String::new());
585 let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
586 let retryable = is_retryable_status(status);
587
588 Err(ApiError::Api {
589 status,
590 error_type: parsed_error
591 .as_ref()
592 .map(|error| error.error.error_type.clone()),
593 message: parsed_error
594 .as_ref()
595 .map(|error| error.error.message.clone()),
596 body,
597 retryable,
598 })
599}
600
601const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
602 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
603}
604
605#[derive(Debug, Deserialize)]
606struct ApiErrorEnvelope {
607 error: ApiErrorBody,
608}
609
610#[derive(Debug, Deserialize)]
611struct ApiErrorBody {
612 #[serde(rename = "type")]
613 error_type: String,
614 message: String,
615}
616
// Unit tests for this module are loaded from `codineer_provider_tests.rs`
// (see the `#[path]` attribute).
#[cfg(test)]
#[path = "codineer_provider_tests.rs"]
mod tests;