1use std::collections::HashMap;
9use std::sync::{
10 Arc, Mutex as StdMutex,
11 atomic::{AtomicU64, Ordering},
12};
13use std::time::{Duration, Instant};
14
15use base64::Engine as _;
16use meerkat_core::{AuthBindingRef, AuthProfile, BackendProfile, CredentialSourceSpec, Provider};
17use parking_lot::Mutex;
18use serde::{Deserialize, Serialize};
19
20use crate::auth_oauth::{OAuthEndpoints, OAuthTokenRequestFormat};
21use crate::auth_store::{
22 PersistedAuthMode, credential_source_uses_persisted_store, persisted_auth_mode_for_auth_method,
23};
24
/// Default cap on outstanding flows (browser + device combined) held by an
/// `OAuthFlowRegistry` built with `OAuthFlowRegistry::new`.
const DEFAULT_MAX_OUTSTANDING_FLOWS: usize = 1024;

// ---- Anthropic -------------------------------------------------------------
/// Client id shared by the Claude.ai and Console Anthropic identities.
const ANTHROPIC_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Authorize endpoint for the Claude.ai flow (`OAuthProviderIdentity::AnthropicClaudeAi`).
const ANTHROPIC_AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize";
/// Authorize endpoint for the Console flow (`OAuthProviderIdentity::AnthropicConsoleApiKey`).
const ANTHROPIC_CONSOLE_AUTHORIZE_URL: &str = "https://platform.claude.com/oauth/authorize";
/// Token endpoint used by both Anthropic identities.
const ANTHROPIC_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
/// Scopes requested (and used as refresh scopes) by the Console flow.
/// NOTE(review): `org:create_api_key` presumably lets the flow mint an API key — confirm.
const ANTHROPIC_CONSOLE_SCOPES: &[&str] = &["org:create_api_key", "user:profile"];
/// Scopes requested (and used as refresh scopes) by the Claude.ai flow.
const ANTHROPIC_CLAUDE_AI_SCOPES: &[&str] = &[
    "user:profile",
    "user:inference",
    "user:sessions:claude_code",
    "user:mcp_servers",
    "user:file_upload",
];

// ---- OpenAI (ChatGPT) ------------------------------------------------------
/// Client id for the ChatGPT login flow.
const OPENAI_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Authorize endpoint for the ChatGPT login flow.
const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint for the ChatGPT login flow.
const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Scopes requested by the ChatGPT flow (no refresh scopes are sent for this provider).
const OPENAI_SCOPES: &[&str] = &[
    "openid",
    "profile",
    "email",
    "offline_access",
    "api.connectors.read",
    "api.connectors.invoke",
];
/// Value sent as the `originator` authorize parameter for the ChatGPT flow.
const OPENAI_ORIGINATOR: &str = "codex_cli_rs";

// ---- Google (Code Assist) --------------------------------------------------
/// Client id for the Google Code Assist flow.
/// NOTE(review): split across `concat!` pieces, presumably so naive secret
/// scanners do not flag the full literal — confirm intent before joining it.
const GOOGLE_CLIENT_ID: &str = concat!(
    "6812558",
    "09395-oo8ft2oprdrnp9e3aqf6av3hmdib135j",
    ".apps.googleusercontent.com",
);
/// Client secret returned by `OAuthProviderIdentity::client_secret` for Google;
/// split with `concat!` like the client id above.
const GOOGLE_CLIENT_SECRET: &str = concat!("GOCSP", "X-4uHgMPm", "-1o7Sk-geV6Cu5clXFsxl");
/// Authorize endpoint for the Google flow.
const GOOGLE_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
/// Token endpoint for the Google flow.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// Device-code endpoint; Google is the only identity here that configures one.
const GOOGLE_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code";
/// Scopes requested by the Google flow.
const GOOGLE_SCOPES: &[&str] = &[
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
];

// ---- Test hooks ------------------------------------------------------------
/// Env var that must hold a truthy value for `apply_test_oauth_endpoint_override`
/// to rewrite endpoints.
const TEST_OAUTH_ENDPOINT_OVERRIDE_ENV: &str = "MEERKAT_TEST_OAUTH_ENDPOINT_OVERRIDE";
/// Env var holding the mock-server base URL used by that override.
const TEST_OAUTH_BASE_URL_ENV: &str = "MEERKAT_TEST_OAUTH_BASE_URL";
70
/// The OAuth login flows this module knows how to drive.
///
/// Each variant pins down the client credentials, endpoints, persisted auth mode
/// and backend kind for its flow (see the inherent methods for the mappings).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OAuthProviderIdentity {
    /// Claude.ai login (aliases `anthropic`, `claude`, `claude.ai`).
    AnthropicClaudeAi,
    /// Anthropic Console login (alias `anthropic_console_api_key`).
    AnthropicConsoleApiKey,
    /// ChatGPT login for the OpenAI provider (aliases `openai`, `chatgpt`).
    OpenAiChatGpt,
    /// Google Code Assist login for the Gemini provider (aliases `google`, `gemini`, `code_assist`).
    GoogleCodeAssist,
}
78
79impl OAuthProviderIdentity {
80 pub fn from_alias(alias: &str) -> Option<Self> {
81 match alias {
82 "anthropic" | "claude" | "claude.ai" => Some(Self::AnthropicClaudeAi),
83 "anthropic_console_api_key" => Some(Self::AnthropicConsoleApiKey),
84 "openai" | "chatgpt" => Some(Self::OpenAiChatGpt),
85 "google" | "gemini" | "code_assist" => Some(Self::GoogleCodeAssist),
86 _ => None,
87 }
88 }
89
90 pub fn canonical_alias(self) -> &'static str {
91 match self {
92 Self::AnthropicClaudeAi => "anthropic",
93 Self::AnthropicConsoleApiKey => "anthropic_console_api_key",
94 Self::OpenAiChatGpt => "openai",
95 Self::GoogleCodeAssist => "google",
96 }
97 }
98
99 pub fn provider(self) -> Provider {
100 match self {
101 Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey => Provider::Anthropic,
102 Self::OpenAiChatGpt => Provider::OpenAI,
103 Self::GoogleCodeAssist => Provider::Gemini,
104 }
105 }
106
107 pub fn auth_mode(self) -> PersistedAuthMode {
108 match self {
109 Self::AnthropicClaudeAi => PersistedAuthMode::ClaudeAiOauth,
110 Self::AnthropicConsoleApiKey => PersistedAuthMode::OauthToApiKey,
111 Self::OpenAiChatGpt => PersistedAuthMode::ChatgptOauth,
112 Self::GoogleCodeAssist => PersistedAuthMode::GoogleOauth,
113 }
114 }
115
116 pub fn backend_kind(self) -> &'static str {
117 match self {
118 Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey => "anthropic_api",
119 Self::OpenAiChatGpt => "chatgpt_backend",
120 Self::GoogleCodeAssist => "google_code_assist",
121 }
122 }
123
124 pub fn client_secret(self) -> Option<&'static str> {
125 match self {
126 Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey | Self::OpenAiChatGpt => None,
127 Self::GoogleCodeAssist => Some(GOOGLE_CLIENT_SECRET),
128 }
129 }
130
131 pub fn endpoints(self, redirect_uri: impl Into<String>) -> OAuthEndpoints {
132 let endpoints = match self {
133 Self::AnthropicClaudeAi => OAuthEndpoints {
134 client_id: ANTHROPIC_CLIENT_ID.into(),
135 authorize_url: ANTHROPIC_AUTHORIZE_URL.into(),
136 token_url: ANTHROPIC_TOKEN_URL.into(),
137 device_code_url: None,
138 redirect_uri: redirect_uri.into(),
139 scopes: strings(ANTHROPIC_CLAUDE_AI_SCOPES),
140 extra_authorize_params: vec![("code".into(), "true".into())],
141 token_request_format: OAuthTokenRequestFormat::Json,
142 include_state_in_token_exchange: true,
143 refresh_scopes: strings(ANTHROPIC_CLAUDE_AI_SCOPES),
144 extra_headers: Vec::new(),
145 },
146 Self::AnthropicConsoleApiKey => OAuthEndpoints {
147 client_id: ANTHROPIC_CLIENT_ID.into(),
148 authorize_url: ANTHROPIC_CONSOLE_AUTHORIZE_URL.into(),
149 token_url: ANTHROPIC_TOKEN_URL.into(),
150 device_code_url: None,
151 redirect_uri: redirect_uri.into(),
152 scopes: strings(ANTHROPIC_CONSOLE_SCOPES),
153 extra_authorize_params: Vec::new(),
154 token_request_format: OAuthTokenRequestFormat::Json,
155 include_state_in_token_exchange: true,
156 refresh_scopes: strings(ANTHROPIC_CONSOLE_SCOPES),
157 extra_headers: Vec::new(),
158 },
159 Self::OpenAiChatGpt => OAuthEndpoints {
160 client_id: OPENAI_CLIENT_ID.into(),
161 authorize_url: OPENAI_AUTHORIZE_URL.into(),
162 token_url: OPENAI_TOKEN_URL.into(),
163 device_code_url: None,
164 redirect_uri: redirect_uri.into(),
165 scopes: strings(OPENAI_SCOPES),
166 extra_authorize_params: vec![
167 ("id_token_add_organizations".into(), "true".into()),
168 ("codex_cli_simplified_flow".into(), "true".into()),
169 ("originator".into(), OPENAI_ORIGINATOR.into()),
170 ],
171 token_request_format: OAuthTokenRequestFormat::FormUrlEncoded,
172 include_state_in_token_exchange: false,
173 refresh_scopes: Vec::new(),
174 extra_headers: Vec::new(),
175 },
176 Self::GoogleCodeAssist => OAuthEndpoints {
177 client_id: GOOGLE_CLIENT_ID.into(),
178 authorize_url: GOOGLE_AUTHORIZE_URL.into(),
179 token_url: GOOGLE_TOKEN_URL.into(),
180 device_code_url: Some(GOOGLE_DEVICE_CODE_URL.into()),
181 redirect_uri: redirect_uri.into(),
182 scopes: strings(GOOGLE_SCOPES),
183 extra_authorize_params: Vec::new(),
184 token_request_format: OAuthTokenRequestFormat::FormUrlEncoded,
185 include_state_in_token_exchange: false,
186 refresh_scopes: Vec::new(),
187 extra_headers: Vec::new(),
188 },
189 };
190 apply_test_oauth_endpoint_override(self, endpoints)
191 }
192}
193
194impl std::fmt::Display for OAuthProviderIdentity {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.write_str(self.canonical_alias())
197 }
198}
199
200#[derive(Debug, Clone)]
201pub struct OAuthProviderResolution {
202 pub identity: OAuthProviderIdentity,
203 pub provider: Provider,
204 pub endpoints: OAuthEndpoints,
205 pub auth_mode: PersistedAuthMode,
206 pub client_secret: Option<&'static str>,
207}
208
209#[derive(Debug, thiserror::Error, PartialEq, Eq)]
210#[error("Unknown provider '{provider}'. Supported: anthropic, openai, google.")]
211pub struct OAuthProviderResolutionError {
212 pub provider: String,
213}
214
215pub fn resolve_oauth_provider(
216 provider: &str,
217 redirect_uri: impl Into<String>,
218) -> Result<OAuthProviderResolution, OAuthProviderResolutionError> {
219 let identity = OAuthProviderIdentity::from_alias(provider).ok_or_else(|| {
220 OAuthProviderResolutionError {
221 provider: provider.to_string(),
222 }
223 })?;
224 Ok(OAuthProviderResolution {
225 identity,
226 provider: identity.provider(),
227 endpoints: identity.endpoints(redirect_uri),
228 auth_mode: identity.auth_mode(),
229 client_secret: identity.client_secret(),
230 })
231}
232
233#[doc(hidden)]
237pub fn apply_test_oauth_endpoint_override(
238 identity: OAuthProviderIdentity,
239 mut endpoints: OAuthEndpoints,
240) -> OAuthEndpoints {
241 #[cfg(not(target_arch = "wasm32"))]
242 {
243 let enabled = std::env::var(TEST_OAUTH_ENDPOINT_OVERRIDE_ENV)
244 .map(|value| {
245 matches!(
246 value.as_str(),
247 "1" | "true" | "TRUE" | "yes" | "YES" | "on" | "ON"
248 )
249 })
250 .unwrap_or(false);
251 if !enabled {
252 return endpoints;
253 }
254 let Ok(base_url) = std::env::var(TEST_OAUTH_BASE_URL_ENV) else {
255 return endpoints;
256 };
257 let base_url = base_url.trim_end_matches('/');
258 let provider = identity.canonical_alias();
259 endpoints.authorize_url = format!("{base_url}/{provider}/authorize");
260 endpoints.token_url = format!("{base_url}/{provider}/token");
261 if endpoints.device_code_url.is_some() {
262 endpoints.device_code_url = Some(format!("{base_url}/{provider}/device/code"));
263 }
264 }
265 endpoints
266}
267
268#[derive(Debug, thiserror::Error, PartialEq, Eq)]
269pub enum OAuthTargetValidationError {
270 #[error("OAuth target backend provider mismatch: expected {expected:?}, got {actual:?}")]
271 BackendProviderMismatch {
272 expected: Provider,
273 actual: Provider,
274 },
275 #[error(
276 "OAuth target backend_kind '{backend_kind}' cannot store credential mode {expected_mode:?}; expected backend_kind '{expected_backend_kind}'"
277 )]
278 BackendKindMismatch {
279 backend_kind: String,
280 expected_backend_kind: &'static str,
281 expected_mode: PersistedAuthMode,
282 },
283 #[error("OAuth target provider mismatch: expected {expected:?}, got {actual:?}")]
284 ProviderMismatch {
285 expected: Provider,
286 actual: Provider,
287 },
288 #[error(
289 "OAuth target auth_method '{auth_method}' cannot store credential mode {expected_mode:?}"
290 )]
291 AuthMethodMismatch {
292 auth_method: String,
293 expected_mode: PersistedAuthMode,
294 },
295 #[error(
296 "OAuth target source '{source_kind}' cannot store OAuth credentials; expected source.kind = 'managed_store' or 'platform_default'"
297 )]
298 SourceMismatch { source_kind: &'static str },
299}
300
301pub fn validate_oauth_login_target(
302 auth_profile: &AuthProfile,
303 identity: OAuthProviderIdentity,
304) -> Result<(), OAuthTargetValidationError> {
305 validate_oauth_target_for_auth_mode(auth_profile, identity.provider(), identity.auth_mode())
306}
307
308pub fn validate_oauth_login_binding(
309 backend_profile: &BackendProfile,
310 auth_profile: &AuthProfile,
311 identity: OAuthProviderIdentity,
312) -> Result<(), OAuthTargetValidationError> {
313 validate_oauth_target_binding_for_auth_mode(
314 backend_profile,
315 auth_profile,
316 identity.provider(),
317 identity.auth_mode(),
318 identity.backend_kind(),
319 )
320}
321
322pub fn validate_oauth_target_for_auth_mode(
323 auth_profile: &AuthProfile,
324 expected_provider: Provider,
325 expected_mode: PersistedAuthMode,
326) -> Result<(), OAuthTargetValidationError> {
327 if auth_profile.provider != expected_provider {
328 return Err(OAuthTargetValidationError::ProviderMismatch {
329 expected: expected_provider,
330 actual: auth_profile.provider,
331 });
332 }
333 match persisted_auth_mode_for_auth_method(&auth_profile.auth_method) {
334 Some(actual_mode) if actual_mode == expected_mode => {}
335 _ => {
336 return Err(OAuthTargetValidationError::AuthMethodMismatch {
337 auth_method: auth_profile.auth_method.clone(),
338 expected_mode,
339 });
340 }
341 }
342 if !oauth_source_can_store_flow_credentials(&auth_profile.source) {
343 return Err(OAuthTargetValidationError::SourceMismatch {
344 source_kind: source_kind_label(&auth_profile.source),
345 });
346 }
347 Ok(())
348}
349
350pub fn validate_oauth_target_binding_for_auth_mode(
351 backend_profile: &BackendProfile,
352 auth_profile: &AuthProfile,
353 expected_provider: Provider,
354 expected_mode: PersistedAuthMode,
355 expected_backend_kind: &'static str,
356) -> Result<(), OAuthTargetValidationError> {
357 validate_oauth_target_for_auth_mode(auth_profile, expected_provider, expected_mode)?;
358 if backend_profile.provider != expected_provider {
359 return Err(OAuthTargetValidationError::BackendProviderMismatch {
360 expected: expected_provider,
361 actual: backend_profile.provider,
362 });
363 }
364 if backend_profile.backend_kind != expected_backend_kind {
365 return Err(OAuthTargetValidationError::BackendKindMismatch {
366 backend_kind: backend_profile.backend_kind.clone(),
367 expected_backend_kind,
368 expected_mode,
369 });
370 }
371 Ok(())
372}
373
374fn oauth_source_can_store_flow_credentials(source: &CredentialSourceSpec) -> bool {
375 credential_source_uses_persisted_store(source)
376}
377
378fn source_kind_label(source: &CredentialSourceSpec) -> &'static str {
379 match source {
380 CredentialSourceSpec::InlineSecret { .. } => "inline_secret",
381 CredentialSourceSpec::ManagedStore => "managed_store",
382 CredentialSourceSpec::Env { .. } => "env",
383 CredentialSourceSpec::ExternalResolver { .. } => "external_resolver",
384 CredentialSourceSpec::PlatformDefault => "platform_default",
385 CredentialSourceSpec::Command { .. } => "command",
386 CredentialSourceSpec::FileDescriptor { .. } => "file_descriptor",
387 }
388}
389
390#[derive(Debug, Clone, PartialEq, Eq)]
391pub struct OAuthFlowRecord {
392 pub target: AuthBindingRef,
393 pub provider: OAuthProviderIdentity,
394 pub redirect_uri: String,
395 pub pkce_verifier: String,
396 pub created_at: Instant,
397}
398
399#[derive(Debug, Clone, PartialEq, Eq)]
400pub struct OAuthDeviceFlowRecord {
401 pub target: AuthBindingRef,
402 pub provider: OAuthProviderIdentity,
403 pub device_code: String,
404 pub created_at: Instant,
405 pub expires_at: Instant,
406}
407
408#[derive(Debug, Clone, Default, PartialEq, Eq)]
409pub struct OAuthPrunedFlows {
410 pub browser: Vec<(String, AuthBindingRef)>,
411 pub device: Vec<(String, AuthBindingRef)>,
412}
413
414impl OAuthPrunedFlows {
415 fn from_expired(
416 browser: Vec<(String, OAuthFlowRecord)>,
417 device: Vec<OAuthDeviceFlowRecord>,
418 ) -> Self {
419 Self {
420 browser: browser
421 .into_iter()
422 .map(|(flow_id, record)| (flow_id, record.target))
423 .collect(),
424 device: device
425 .into_iter()
426 .map(|record| (record.device_code, record.target))
427 .collect(),
428 }
429 }
430}
431
432#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
433pub struct OAuthFlowRegistrySnapshot {
434 #[serde(default)]
435 pub browser: Vec<PersistedOAuthBrowserFlow>,
436 #[serde(default)]
437 pub device: Vec<PersistedOAuthDeviceFlow>,
438}
439
440#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
441pub struct PersistedOAuthBrowserFlow {
442 pub state: String,
443 pub target: AuthBindingRef,
444 pub provider: String,
445 pub redirect_uri: String,
446 pub pkce_verifier: String,
447 pub created_at_millis: u64,
448 pub expires_at_millis: u64,
449}
450
451#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
452pub struct PersistedOAuthDeviceFlow {
453 pub target: AuthBindingRef,
454 pub provider: String,
455 pub device_code: String,
456 pub created_at_millis: u64,
457 pub expires_at_millis: u64,
458}
459
460#[derive(Debug, Clone, PartialEq, Eq)]
461struct OAuthDeviceFlowState {
462 record: OAuthDeviceFlowRecord,
463 poll_lease: Option<OAuthDevicePollLeaseState>,
464}
465
/// Identifies the poll lease held on a device flow; compared against a lease
/// handle's `lease_id` to confirm ownership.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct OAuthDevicePollLeaseState {
    id: u64,
}
470
471#[derive(Debug, thiserror::Error, PartialEq, Eq)]
472pub enum OAuthFlowError {
473 #[error("oauth state is missing or expired")]
474 Missing,
475 #[error("oauth registry projection payload missing after AuthMachine accepted {operation}")]
476 RegistryProjectionMissing { operation: &'static str },
477 #[error("oauth state provider mismatch: expected {expected}, got {actual}")]
478 ProviderMismatch {
479 expected: OAuthProviderIdentity,
480 actual: OAuthProviderIdentity,
481 },
482 #[error("oauth state redirect_uri mismatch")]
483 RedirectUriMismatch,
484 #[error("oauth state target mismatch: expected {expected:?}, got {actual:?}")]
485 TargetMismatch {
486 expected: Box<AuthBindingRef>,
487 actual: Box<AuthBindingRef>,
488 },
489 #[error("failed to generate oauth state token")]
490 StateGenerationFailed,
491 #[error("oauth state registry is at capacity ({max_outstanding} outstanding flows)")]
492 CapacityExceeded { max_outstanding: usize },
493 #[error("oauth device code poll is already in progress")]
494 DevicePollInProgress,
495 #[error("oauth device code is already admitted")]
496 DeviceCodeAlreadyAdmitted,
497 #[error("oauth device code expiry is out of range")]
498 DeviceExpiryOutOfRange,
499 #[error("oauth flow lifecycle transition rejected during {operation}: {detail}")]
500 LifecycleRejected {
501 operation: &'static str,
502 detail: String,
503 },
504 #[error("oauth flow durable persistence failed during {operation}: {detail}")]
505 PersistenceFailed {
506 operation: &'static str,
507 detail: String,
508 },
509}
510
511pub trait OAuthDevicePollLifecycle: Send + Sync {
512 fn device_flow_state_is_authmachine_owned(&self) -> bool {
513 false
514 }
515
516 fn finish_device_poll(
517 &self,
518 target: &AuthBindingRef,
519 device_code: &str,
520 ) -> Result<(), OAuthFlowError>;
521
522 fn consume_device_flow(
523 &self,
524 target: &AuthBindingRef,
525 device_code: &str,
526 provider: OAuthProviderIdentity,
527 ) -> Result<(), OAuthFlowError>;
528
529 fn expire_device_flow(
530 &self,
531 target: &AuthBindingRef,
532 device_code: &str,
533 ) -> Result<(), OAuthFlowError>;
534
535 fn restore_device_flow(&self, _record: &OAuthDeviceFlowRecord) -> Result<(), OAuthFlowError> {
536 Ok(())
537 }
538
539 fn device_flow_payloads_changed(&self) -> Result<(), OAuthFlowError> {
540 Ok(())
541 }
542
543 fn device_flow_payload_removed(
544 &self,
545 _record: &OAuthDeviceFlowRecord,
546 ) -> Result<(), OAuthFlowError> {
547 self.device_flow_payloads_changed()
548 }
549}
550
551pub struct OAuthDevicePollLease {
552 device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
553 target: AuthBindingRef,
554 device_code: String,
555 provider: OAuthProviderIdentity,
556 lease_id: u64,
557 lifecycle: Option<Arc<dyn OAuthDevicePollLifecycle>>,
558 operation_lock: Option<Arc<StdMutex<()>>>,
559 active: bool,
560}
561
562impl std::fmt::Debug for OAuthDevicePollLease {
563 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564 f.debug_struct("OAuthDevicePollLease")
565 .field("target", &self.target)
566 .field("device_code", &self.device_code)
567 .field("provider", &self.provider)
568 .field("lease_id", &self.lease_id)
569 .field("has_lifecycle", &self.lifecycle.is_some())
570 .field("has_operation_lock", &self.operation_lock.is_some())
571 .field("active", &self.active)
572 .finish()
573 }
574}
575
impl OAuthDevicePollLease {
    /// Creates an active lease handle for `device_code` holding lease `lease_id`.
    ///
    /// No lifecycle or operation lock is attached; add them with
    /// [`Self::with_lifecycle`] / [`Self::with_operation_lock`].
    fn new(
        device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
        target: AuthBindingRef,
        device_code: String,
        provider: OAuthProviderIdentity,
        lease_id: u64,
    ) -> Self {
        Self {
            device_flows,
            target,
            device_code,
            provider,
            lease_id,
            lifecycle: None,
            operation_lock: None,
            active: true,
        }
    }

    /// Returns the lifecycle's `device_flow_state_is_authmachine_owned()`, or `false`
    /// when no lifecycle is attached.
    pub fn terminal_flow_state_is_authmachine_owned(&self) -> bool {
        self.lifecycle
            .as_ref()
            .map(|lifecycle| lifecycle.device_flow_state_is_authmachine_owned())
            .unwrap_or(false)
    }

    /// Error to report when the local record is missing during `operation`:
    /// `RegistryProjectionMissing { operation }` if the lifecycle owns terminal state,
    /// otherwise plain `Missing`.
    fn local_missing_error(&self, operation: &'static str) -> OAuthFlowError {
        if self.terminal_flow_state_is_authmachine_owned() {
            OAuthFlowError::RegistryProjectionMissing { operation }
        } else {
            OAuthFlowError::Missing
        }
    }

    /// Attaches a lifecycle notified by `finish`, `consume` and `Drop`.
    pub fn with_lifecycle(mut self, lifecycle: Arc<dyn OAuthDevicePollLifecycle>) -> Self {
        self.lifecycle = Some(lifecycle);
        self
    }

    /// Attaches a lock that `finish` and `consume` hold for their whole duration.
    /// (`verify` and `Drop` do not take it.)
    pub fn with_operation_lock(mut self, operation_lock: Arc<StdMutex<()>>) -> Self {
        self.operation_lock = Some(operation_lock);
        self
    }

    /// Ends the poll without consuming the flow and releases this lease.
    ///
    /// Steps: take the operation lock (poisoning is recovered, not propagated); prune
    /// expired records and verify this lease still holds the flow; call the lifecycle's
    /// `finish_device_poll`; release the lease locally; then call
    /// `device_flow_payloads_changed`.
    ///
    /// # Errors
    /// A missing local record maps through `local_missing_error`. Other verification or
    /// release errors, and lifecycle errors, are returned as-is.
    ///
    /// NOTE(review): `active` is only cleared after a successful local release, so on
    /// earlier error returns `Drop` runs its cleanup as well. On the missing-record path
    /// this means `finish_device_poll` may be invoked twice.
    pub fn finish(mut self) -> Result<(), OAuthFlowError> {
        let _operation_guard = self.operation_lock.as_ref().map(|lock| {
            lock.lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        });
        // The map lock is scoped so it is not held across lifecycle callbacks.
        let verify_result = {
            let mut flows = self.device_flows.lock();
            prune_expired_device_locked(&mut flows);
            verify_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            )
            .map(|_| ())
        };
        if matches!(verify_result, Err(OAuthFlowError::Missing)) {
            // Best-effort: let an AuthMachine-owned lifecycle close its side of the poll.
            if self.terminal_flow_state_is_authmachine_owned()
                && let Some(lifecycle) = &self.lifecycle
            {
                let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
            }
            return Err(self.local_missing_error("finish_oauth_device_poll"));
        }
        verify_result?;

        // The lifecycle transition happens before the local release, so a lifecycle
        // rejection leaves the local lease in place.
        if let Some(lifecycle) = &self.lifecycle {
            lifecycle.finish_device_poll(&self.target, &self.device_code)?;
        }

        let result = {
            let mut flows = self.device_flows.lock();
            let result = release_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            );
            prune_expired_device_locked(&mut flows);
            result
        };
        match result {
            Ok(()) => {
                self.active = false;
                if let Some(lifecycle) = &self.lifecycle {
                    lifecycle.device_flow_payloads_changed()?;
                }
                Ok(())
            }
            Err(OAuthFlowError::Missing) => {
                Err(self.local_missing_error("finish_oauth_device_poll"))
            }
            Err(err) => Err(err),
        }
    }

    /// Confirms this lease still holds the flow and returns a copy of its record.
    ///
    /// Expired records are pruned before and after the check. Does not take the
    /// operation lock and does not release the lease.
    ///
    /// # Errors
    /// A missing record maps through `local_missing_error`; other errors pass through.
    pub fn verify(&self) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
        let mut flows = self.device_flows.lock();
        prune_expired_device_locked(&mut flows);
        let result = verify_device_poll_lease_locked(
            &mut flows,
            &self.device_code,
            self.provider,
            self.lease_id,
        );
        prune_expired_device_locked(&mut flows);
        if matches!(result, Err(OAuthFlowError::Missing)) {
            return Err(self.local_missing_error("verify_oauth_device_poll"));
        }
        result
    }

    /// Consumes the device flow, removing it from the registry, and returns its record.
    ///
    /// Steps:
    /// 1. Take the operation lock and verify the lease, keeping the verified record for
    ///    rollback.
    /// 2. Call the lifecycle's `consume_device_flow`.
    /// 3. Consume locally.
    /// 4. Call `device_flow_payload_removed`.
    ///
    /// If step 4 fails with anything other than `Missing`/`RegistryProjectionMissing`,
    /// the lifecycle is asked to restore the flow and the verified record is re-inserted
    /// locally without a lease. No capacity check is performed for that re-insert. If
    /// step 3 fails, the lifecycle is asked to restore the flow, and also to expire it
    /// when the failure was `Missing`.
    ///
    /// # Errors
    /// Missing records map through `local_missing_error`; other errors pass through.
    ///
    /// NOTE(review): on error returns before step 3 succeeds, `active` is still `true`,
    /// so `Drop` additionally releases the lease and notifies the lifecycle.
    pub fn consume(mut self) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
        let _operation_guard = self.operation_lock.as_ref().map(|lock| {
            lock.lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        });
        let verified = {
            let mut flows = self.device_flows.lock();
            prune_expired_device_locked(&mut flows);
            verify_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            )
        };
        let verified = match verified {
            Ok(verified) => verified,
            Err(OAuthFlowError::Missing) => {
                // Best-effort: let an AuthMachine-owned lifecycle close its side of the poll.
                if self.terminal_flow_state_is_authmachine_owned()
                    && let Some(lifecycle) = &self.lifecycle
                {
                    let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
                }
                return Err(self.local_missing_error("consume_oauth_device_flow"));
            }
            Err(err) => return Err(err),
        };

        if let Some(lifecycle) = &self.lifecycle {
            lifecycle.consume_device_flow(&self.target, &self.device_code, self.provider)?;
        }

        let result = {
            let mut flows = self.device_flows.lock();
            let result = consume_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            );
            prune_expired_device_locked(&mut flows);
            result
        };
        match result {
            Ok(record) => {
                self.active = false;
                if let Some(lifecycle) = &self.lifecycle
                    && let Err(err) = lifecycle.device_flow_payload_removed(&record)
                {
                    // Missing-style errors are surfaced without rolling back the local removal.
                    if matches!(
                        err,
                        OAuthFlowError::Missing | OAuthFlowError::RegistryProjectionMissing { .. }
                    ) {
                        return Err(err);
                    }
                    // Roll back: restore the lifecycle side and re-insert the verified record
                    // locally with no lease held.
                    let _ = lifecycle.restore_device_flow(&verified);
                    let mut flows = self.device_flows.lock();
                    flows.insert(
                        verified.device_code.clone(),
                        OAuthDeviceFlowState {
                            record: verified,
                            poll_lease: None,
                        },
                    );
                    return Err(err);
                }
                Ok(record)
            }
            Err(err) => {
                if matches!(err, OAuthFlowError::Missing)
                    && self.terminal_flow_state_is_authmachine_owned()
                {
                    return Err(self.local_missing_error("consume_oauth_device_flow"));
                }
                // The lifecycle already consumed in step 2; undo that, and expire the flow
                // there too if it vanished locally.
                if let Some(lifecycle) = &self.lifecycle {
                    let _ = lifecycle.restore_device_flow(&verified);
                    if matches!(err, OAuthFlowError::Missing) {
                        let _ = lifecycle.expire_device_flow(&self.target, &self.device_code);
                    }
                }
                Err(err)
            }
        }
    }
}
778
779impl Drop for OAuthDevicePollLease {
780 fn drop(&mut self) {
781 if !self.active {
782 return;
783 }
784 let mut flows = self.device_flows.lock();
785 prune_expired_device_locked(&mut flows);
786 let result = release_device_poll_lease_locked(
787 &mut flows,
788 &self.device_code,
789 self.provider,
790 self.lease_id,
791 );
792 prune_expired_device_locked(&mut flows);
793 if let Some(lifecycle) = &self.lifecycle {
794 if matches!(result, Err(OAuthFlowError::Missing)) {
795 if lifecycle.device_flow_state_is_authmachine_owned() {
796 let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
797 } else {
798 let _ = lifecycle.expire_device_flow(&self.target, &self.device_code);
799 }
800 } else {
801 let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
802 }
803 }
804 }
805}
806
807pub trait OAuthFlowAuthority: Send + Sync {
808 fn terminal_flow_state_is_authmachine_owned(&self) -> bool {
809 false
810 }
811
812 fn start(
813 &self,
814 target: AuthBindingRef,
815 provider: OAuthProviderIdentity,
816 redirect_uri: String,
817 pkce_verifier: String,
818 ) -> Result<String, OAuthFlowError>;
819
820 fn verify(
821 &self,
822 state: &str,
823 target: &AuthBindingRef,
824 provider: OAuthProviderIdentity,
825 redirect_uri: &str,
826 ) -> Result<OAuthFlowRecord, OAuthFlowError>;
827
828 fn consume(
829 &self,
830 state: &str,
831 target: &AuthBindingRef,
832 provider: OAuthProviderIdentity,
833 redirect_uri: &str,
834 ) -> Result<OAuthFlowRecord, OAuthFlowError>;
835
836 fn admit_device_code(
837 &self,
838 target: AuthBindingRef,
839 provider: OAuthProviderIdentity,
840 device_code: String,
841 expires_in: Duration,
842 ) -> Result<(), OAuthFlowError>;
843
844 fn verify_device_code(
845 &self,
846 device_code: &str,
847 target: &AuthBindingRef,
848 provider: OAuthProviderIdentity,
849 ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError>;
850
851 fn begin_device_code_poll(
852 &self,
853 device_code: &str,
854 target: &AuthBindingRef,
855 provider: OAuthProviderIdentity,
856 ) -> Result<OAuthDevicePollLease, OAuthFlowError>;
857}
858
859#[derive(Debug)]
860pub struct OAuthFlowRegistry {
861 ttl: Duration,
862 max_outstanding: usize,
863 flows: Mutex<HashMap<String, OAuthFlowRecord>>,
864 device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
865 next_device_poll_lease_id: AtomicU64,
866}
867
868impl OAuthFlowRegistry {
869 pub fn new(ttl: Duration) -> Self {
870 Self::new_with_capacity(ttl, DEFAULT_MAX_OUTSTANDING_FLOWS)
871 }
872
873 pub fn new_with_capacity(ttl: Duration, max_outstanding: usize) -> Self {
874 Self {
875 ttl,
876 max_outstanding: max_outstanding.max(1),
877 flows: Mutex::new(HashMap::new()),
878 device_flows: Arc::new(Mutex::new(HashMap::new())),
879 next_device_poll_lease_id: AtomicU64::new(1),
880 }
881 }
882
883 pub fn max_outstanding(&self) -> usize {
884 self.max_outstanding
885 }
886
887 pub fn ttl(&self) -> Duration {
888 self.ttl
889 }
890
891 pub fn new_state() -> Result<String, OAuthFlowError> {
892 new_state_token()
893 }
894
895 pub fn start(
896 &self,
897 target: AuthBindingRef,
898 provider: OAuthProviderIdentity,
899 redirect_uri: impl Into<String>,
900 pkce_verifier: impl Into<String>,
901 ) -> Result<String, OAuthFlowError> {
902 <Self as OAuthFlowAuthority>::start(
903 self,
904 target,
905 provider,
906 redirect_uri.into(),
907 pkce_verifier.into(),
908 )
909 }
910
911 pub fn verify(
912 &self,
913 state: &str,
914 target: &AuthBindingRef,
915 provider: OAuthProviderIdentity,
916 redirect_uri: &str,
917 ) -> Result<OAuthFlowRecord, OAuthFlowError> {
918 <Self as OAuthFlowAuthority>::verify(self, state, target, provider, redirect_uri)
919 }
920
921 pub fn consume(
922 &self,
923 state: &str,
924 target: &AuthBindingRef,
925 provider: OAuthProviderIdentity,
926 redirect_uri: &str,
927 ) -> Result<OAuthFlowRecord, OAuthFlowError> {
928 <Self as OAuthFlowAuthority>::consume(self, state, target, provider, redirect_uri)
929 }
930
931 pub fn admit_device_code(
932 &self,
933 target: AuthBindingRef,
934 provider: OAuthProviderIdentity,
935 device_code: impl Into<String>,
936 expires_in: Duration,
937 ) -> Result<(), OAuthFlowError> {
938 <Self as OAuthFlowAuthority>::admit_device_code(
939 self,
940 target,
941 provider,
942 device_code.into(),
943 expires_in,
944 )
945 }
946
947 pub fn verify_device_code(
948 &self,
949 device_code: &str,
950 target: &AuthBindingRef,
951 provider: OAuthProviderIdentity,
952 ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
953 <Self as OAuthFlowAuthority>::verify_device_code(self, device_code, target, provider)
954 }
955
956 pub fn begin_device_code_poll(
957 &self,
958 device_code: &str,
959 target: &AuthBindingRef,
960 provider: OAuthProviderIdentity,
961 ) -> Result<OAuthDevicePollLease, OAuthFlowError> {
962 <Self as OAuthFlowAuthority>::begin_device_code_poll(self, device_code, target, provider)
963 }
964
965 pub fn expire_device_code(
966 &self,
967 device_code: &str,
968 target: &AuthBindingRef,
969 provider: OAuthProviderIdentity,
970 ) -> Result<(), OAuthFlowError> {
971 let mut flows = self.device_flows.lock();
972 prune_expired_device_locked(&mut flows);
973 let Some(state) = flows.get(device_code) else {
974 return Err(OAuthFlowError::Missing);
975 };
976 verify_device_record(&state.record, target, provider)?;
977 flows.remove(device_code);
978 Ok(())
979 }
980
981 pub fn prune_expired_browser_flows(&self) -> Vec<(String, AuthBindingRef)> {
982 let mut flows = self.flows.lock();
983 take_expired_locked(&mut flows, self.ttl)
984 .into_iter()
985 .map(|(flow_id, record)| (flow_id, record.target))
986 .collect()
987 }
988
989 pub fn prune_expired_device_flows(&self) -> Vec<(String, AuthBindingRef)> {
990 let mut flows = self.device_flows.lock();
991 take_expired_device_locked(&mut flows)
992 .into_iter()
993 .map(|record| (record.device_code, record.target))
994 .collect()
995 }
996
997 pub fn retain_flows_with_lifecycle(
998 &self,
999 mut browser_active: impl FnMut(&AuthBindingRef, &str) -> bool,
1000 mut device_active: impl FnMut(&AuthBindingRef, &str) -> bool,
1001 ) -> OAuthPrunedFlows {
1002 let mut flows = self.flows.lock();
1003 let mut browser = Vec::new();
1004 flows.retain(|flow_id, record| {
1005 let keep = browser_active(&record.target, flow_id);
1006 if !keep {
1007 browser.push((flow_id.clone(), record.target.clone()));
1008 }
1009 keep
1010 });
1011
1012 let mut device_flows = self.device_flows.lock();
1013 let mut device = Vec::new();
1014 device_flows.retain(|device_code, state| {
1015 let keep = device_active(&state.record.target, device_code);
1016 if !keep {
1017 device.push((device_code.clone(), state.record.target.clone()));
1018 }
1019 keep
1020 });
1021 OAuthPrunedFlows { browser, device }
1022 }
1023
1024 pub fn snapshot_for_persistence(&self, now_millis: u64) -> OAuthFlowRegistrySnapshot {
1025 let flows = self.flows.lock();
1026 let browser = flows
1027 .iter()
1028 .filter_map(|(state, record)| {
1029 let elapsed = record.created_at.elapsed();
1030 if elapsed > self.ttl {
1031 return None;
1032 }
1033 let elapsed_millis = duration_millis_u64(elapsed);
1034 let ttl_millis = duration_millis_u64(self.ttl);
1035 Some(PersistedOAuthBrowserFlow {
1036 state: state.clone(),
1037 target: record.target.clone(),
1038 provider: record.provider.canonical_alias().to_string(),
1039 redirect_uri: record.redirect_uri.clone(),
1040 pkce_verifier: record.pkce_verifier.clone(),
1041 created_at_millis: now_millis.saturating_sub(elapsed_millis),
1042 expires_at_millis: now_millis
1043 .saturating_add(ttl_millis.saturating_sub(elapsed_millis)),
1044 })
1045 })
1046 .collect::<Vec<_>>();
1047 drop(flows);
1048
1049 let device_flows = self.device_flows.lock();
1050 let now = Instant::now();
1051 let device = device_flows
1052 .values()
1053 .filter_map(|state| {
1054 let remaining = state.record.expires_at.checked_duration_since(now)?;
1055 let created_elapsed = state.record.created_at.elapsed();
1056 Some(PersistedOAuthDeviceFlow {
1057 target: state.record.target.clone(),
1058 provider: state.record.provider.canonical_alias().to_string(),
1059 device_code: state.record.device_code.clone(),
1060 created_at_millis: now_millis
1061 .saturating_sub(duration_millis_u64(created_elapsed)),
1062 expires_at_millis: now_millis.saturating_add(duration_millis_u64(remaining)),
1063 })
1064 })
1065 .collect();
1066
1067 OAuthFlowRegistrySnapshot { browser, device }
1068 }
1069
1070 pub fn insert_restored_browser_flow(
1071 &self,
1072 state: String,
1073 target: AuthBindingRef,
1074 provider: OAuthProviderIdentity,
1075 redirect_uri: String,
1076 pkce_verifier: String,
1077 created_at: Instant,
1078 ) -> Result<(), OAuthFlowError> {
1079 let mut flows = self.flows.lock();
1080 let device_flows = self.device_flows.lock();
1081 if flows.len() + device_flows.len() >= self.max_outstanding {
1082 return Err(OAuthFlowError::CapacityExceeded {
1083 max_outstanding: self.max_outstanding,
1084 });
1085 }
1086 flows.insert(
1087 state,
1088 OAuthFlowRecord {
1089 target,
1090 provider,
1091 redirect_uri,
1092 pkce_verifier,
1093 created_at,
1094 },
1095 );
1096 Ok(())
1097 }
1098
1099 pub fn insert_restored_device_flow(
1100 &self,
1101 target: AuthBindingRef,
1102 provider: OAuthProviderIdentity,
1103 device_code: String,
1104 created_at: Instant,
1105 expires_at: Instant,
1106 ) -> Result<(), OAuthFlowError> {
1107 let flows = self.flows.lock();
1108 let mut device_flows = self.device_flows.lock();
1109 if device_flows.contains_key(&device_code) {
1110 return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1111 }
1112 if flows.len() + device_flows.len() >= self.max_outstanding {
1113 return Err(OAuthFlowError::CapacityExceeded {
1114 max_outstanding: self.max_outstanding,
1115 });
1116 }
1117 let record = OAuthDeviceFlowRecord {
1118 target,
1119 provider,
1120 device_code: device_code.clone(),
1121 created_at,
1122 expires_at,
1123 };
1124 device_flows.insert(
1125 device_code,
1126 OAuthDeviceFlowState {
1127 record,
1128 poll_lease: None,
1129 },
1130 );
1131 Ok(())
1132 }
1133
1134 pub fn start_with_pruned(
1135 &self,
1136 target: AuthBindingRef,
1137 provider: OAuthProviderIdentity,
1138 redirect_uri: String,
1139 pkce_verifier: String,
1140 ) -> Result<(String, OAuthPrunedFlows), OAuthFlowError> {
1141 let state = new_state_token()?;
1142 let record = OAuthFlowRecord {
1143 target,
1144 provider,
1145 redirect_uri,
1146 pkce_verifier,
1147 created_at: Instant::now(),
1148 };
1149 let mut flows = self.flows.lock();
1150 let mut device_flows = self.device_flows.lock();
1151 let expired_browser = take_expired_locked(&mut flows, self.ttl);
1152 let expired_device = take_expired_device_locked(&mut device_flows);
1153 if flows.len() + device_flows.len() >= self.max_outstanding {
1154 return Err(OAuthFlowError::CapacityExceeded {
1155 max_outstanding: self.max_outstanding,
1156 });
1157 }
1158 flows.insert(state.clone(), record);
1159 Ok((
1160 state,
1161 OAuthPrunedFlows::from_expired(expired_browser, expired_device),
1162 ))
1163 }
1164
1165 pub fn insert_browser_flow_with_pruned(
1166 &self,
1167 state: String,
1168 target: AuthBindingRef,
1169 provider: OAuthProviderIdentity,
1170 redirect_uri: String,
1171 pkce_verifier: String,
1172 ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1173 let record = OAuthFlowRecord {
1174 target,
1175 provider,
1176 redirect_uri,
1177 pkce_verifier,
1178 created_at: Instant::now(),
1179 };
1180 let mut flows = self.flows.lock();
1181 let mut device_flows = self.device_flows.lock();
1182 let expired_browser = take_expired_locked(&mut flows, self.ttl);
1183 let expired_device = take_expired_device_locked(&mut device_flows);
1184 if flows.len() + device_flows.len() >= self.max_outstanding {
1185 return Err(OAuthFlowError::CapacityExceeded {
1186 max_outstanding: self.max_outstanding,
1187 });
1188 }
1189 flows.insert(state, record);
1190 Ok(OAuthPrunedFlows::from_expired(
1191 expired_browser,
1192 expired_device,
1193 ))
1194 }
1195
1196 pub fn admit_device_code_with_pruned(
1197 &self,
1198 target: AuthBindingRef,
1199 provider: OAuthProviderIdentity,
1200 device_code: String,
1201 expires_in: Duration,
1202 ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1203 let mut flows = self.flows.lock();
1204 let mut device_flows = self.device_flows.lock();
1205 let expired_browser = take_expired_locked(&mut flows, self.ttl);
1206 let expired_device = take_expired_device_locked(&mut device_flows);
1207 if device_flows.contains_key(&device_code) {
1208 return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1209 }
1210 if flows.len() + device_flows.len() >= self.max_outstanding {
1211 return Err(OAuthFlowError::CapacityExceeded {
1212 max_outstanding: self.max_outstanding,
1213 });
1214 }
1215 let now = Instant::now();
1216 let expires_at = now
1217 .checked_add(expires_in)
1218 .ok_or(OAuthFlowError::DeviceExpiryOutOfRange)?;
1219 let record = OAuthDeviceFlowRecord {
1220 target,
1221 provider,
1222 device_code: device_code.clone(),
1223 created_at: now,
1224 expires_at,
1225 };
1226 device_flows.insert(
1227 device_code,
1228 OAuthDeviceFlowState {
1229 record,
1230 poll_lease: None,
1231 },
1232 );
1233 Ok(OAuthPrunedFlows::from_expired(
1234 expired_browser,
1235 expired_device,
1236 ))
1237 }
1238
1239 pub fn admit_device_code_with_pruned_without_capacity(
1240 &self,
1241 target: AuthBindingRef,
1242 provider: OAuthProviderIdentity,
1243 device_code: String,
1244 expires_in: Duration,
1245 ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1246 let mut flows = self.flows.lock();
1247 let mut device_flows = self.device_flows.lock();
1248 let expired_browser = take_expired_locked(&mut flows, self.ttl);
1249 let expired_device = take_expired_device_locked(&mut device_flows);
1250 if device_flows.contains_key(&device_code) {
1251 return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1252 }
1253 let now = Instant::now();
1254 let expires_at = now
1255 .checked_add(expires_in)
1256 .ok_or(OAuthFlowError::DeviceExpiryOutOfRange)?;
1257 let record = OAuthDeviceFlowRecord {
1258 target,
1259 provider,
1260 device_code: device_code.clone(),
1261 created_at: now,
1262 expires_at,
1263 };
1264 device_flows.insert(
1265 device_code,
1266 OAuthDeviceFlowState {
1267 record,
1268 poll_lease: None,
1269 },
1270 );
1271 Ok(OAuthPrunedFlows::from_expired(
1272 expired_browser,
1273 expired_device,
1274 ))
1275 }
1276}
1277
1278impl Default for OAuthFlowRegistry {
1279 fn default() -> Self {
1280 Self::new(Duration::from_secs(10 * 60))
1281 }
1282}
1283
1284impl OAuthFlowAuthority for OAuthFlowRegistry {
1285 fn start(
1286 &self,
1287 target: AuthBindingRef,
1288 provider: OAuthProviderIdentity,
1289 redirect_uri: String,
1290 pkce_verifier: String,
1291 ) -> Result<String, OAuthFlowError> {
1292 self.start_with_pruned(target, provider, redirect_uri, pkce_verifier)
1293 .map(|(state, _)| state)
1294 }
1295
1296 fn verify(
1297 &self,
1298 state: &str,
1299 target: &AuthBindingRef,
1300 provider: OAuthProviderIdentity,
1301 redirect_uri: &str,
1302 ) -> Result<OAuthFlowRecord, OAuthFlowError> {
1303 let mut flows = self.flows.lock();
1304 prune_expired_locked(&mut flows, self.ttl);
1305 let Some(record) = flows.get(state) else {
1306 return Err(OAuthFlowError::Missing);
1307 };
1308 verify_browser_record(record, target, provider, redirect_uri)?;
1309 Ok(record.clone())
1310 }
1311
1312 fn consume(
1313 &self,
1314 state: &str,
1315 target: &AuthBindingRef,
1316 provider: OAuthProviderIdentity,
1317 redirect_uri: &str,
1318 ) -> Result<OAuthFlowRecord, OAuthFlowError> {
1319 let mut flows = self.flows.lock();
1320 prune_expired_locked(&mut flows, self.ttl);
1321 let Some(record) = flows.get(state) else {
1322 return Err(OAuthFlowError::Missing);
1323 };
1324 verify_browser_record(record, target, provider, redirect_uri)?;
1325 flows.remove(state).ok_or(OAuthFlowError::Missing)
1326 }
1327
1328 fn admit_device_code(
1329 &self,
1330 target: AuthBindingRef,
1331 provider: OAuthProviderIdentity,
1332 device_code: String,
1333 expires_in: Duration,
1334 ) -> Result<(), OAuthFlowError> {
1335 self.admit_device_code_with_pruned(target, provider, device_code, expires_in)
1336 .map(|_| ())
1337 }
1338
1339 fn verify_device_code(
1340 &self,
1341 device_code: &str,
1342 target: &AuthBindingRef,
1343 provider: OAuthProviderIdentity,
1344 ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1345 let mut flows = self.device_flows.lock();
1346 prune_expired_device_locked(&mut flows);
1347 let Some(record) = flows.get(device_code) else {
1348 return Err(OAuthFlowError::Missing);
1349 };
1350 verify_device_record(&record.record, target, provider)?;
1351 Ok(record.record.clone())
1352 }
1353
1354 fn begin_device_code_poll(
1355 &self,
1356 device_code: &str,
1357 target: &AuthBindingRef,
1358 provider: OAuthProviderIdentity,
1359 ) -> Result<OAuthDevicePollLease, OAuthFlowError> {
1360 let mut flows = self.device_flows.lock();
1361 prune_expired_device_locked(&mut flows);
1362 let Some(state) = flows.get_mut(device_code) else {
1363 return Err(OAuthFlowError::Missing);
1364 };
1365 verify_device_record(&state.record, target, provider)?;
1366 if state.poll_lease.is_some() {
1367 return Err(OAuthFlowError::DevicePollInProgress);
1368 }
1369 let lease_id = self
1370 .next_device_poll_lease_id
1371 .fetch_add(1, Ordering::Relaxed);
1372 state.poll_lease = Some(OAuthDevicePollLeaseState { id: lease_id });
1373 Ok(OAuthDevicePollLease::new(
1374 Arc::clone(&self.device_flows),
1375 target.clone(),
1376 device_code.to_string(),
1377 provider,
1378 lease_id,
1379 ))
1380 }
1381}
1382
1383fn new_state_token() -> Result<String, OAuthFlowError> {
1384 let mut bytes = [0_u8; 32];
1385 getrandom::fill(&mut bytes).map_err(|_| OAuthFlowError::StateGenerationFailed)?;
1386 Ok(format!(
1387 "st-{}",
1388 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
1389 ))
1390}
1391
1392fn duration_millis_u64(duration: Duration) -> u64 {
1393 u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
1394}
1395
1396fn prune_expired_locked(flows: &mut HashMap<String, OAuthFlowRecord>, ttl: Duration) {
1397 let _ = take_expired_locked(flows, ttl);
1398}
1399
1400fn take_expired_locked(
1401 flows: &mut HashMap<String, OAuthFlowRecord>,
1402 ttl: Duration,
1403) -> Vec<(String, OAuthFlowRecord)> {
1404 let expired = flows
1405 .iter()
1406 .filter(|(_, record)| record.created_at.elapsed() > ttl)
1407 .map(|(flow_id, _)| flow_id.clone())
1408 .collect::<Vec<_>>();
1409 expired
1410 .into_iter()
1411 .filter_map(|flow_id| flows.remove(&flow_id).map(|record| (flow_id, record)))
1412 .collect()
1413}
1414
1415fn release_device_poll_lease_locked(
1416 flows: &mut HashMap<String, OAuthDeviceFlowState>,
1417 device_code: &str,
1418 provider: OAuthProviderIdentity,
1419 lease_id: u64,
1420) -> Result<(), OAuthFlowError> {
1421 let Some(state) = flows.get_mut(device_code) else {
1422 return Err(OAuthFlowError::Missing);
1423 };
1424 if state.record.provider != provider {
1425 return Err(OAuthFlowError::ProviderMismatch {
1426 expected: state.record.provider,
1427 actual: provider,
1428 });
1429 }
1430 match state.poll_lease {
1431 Some(lease) if lease.id == lease_id => {
1432 state.poll_lease = None;
1433 Ok(())
1434 }
1435 Some(_) => Err(OAuthFlowError::DevicePollInProgress),
1436 None => Err(OAuthFlowError::Missing),
1437 }
1438}
1439
1440fn verify_browser_record(
1441 record: &OAuthFlowRecord,
1442 target: &AuthBindingRef,
1443 provider: OAuthProviderIdentity,
1444 redirect_uri: &str,
1445) -> Result<(), OAuthFlowError> {
1446 if &record.target != target {
1447 return Err(OAuthFlowError::TargetMismatch {
1448 expected: Box::new(record.target.clone()),
1449 actual: Box::new(target.clone()),
1450 });
1451 }
1452 if record.provider != provider {
1453 return Err(OAuthFlowError::ProviderMismatch {
1454 expected: record.provider,
1455 actual: provider,
1456 });
1457 }
1458 if record.redirect_uri != redirect_uri {
1459 return Err(OAuthFlowError::RedirectUriMismatch);
1460 }
1461 Ok(())
1462}
1463
1464fn verify_device_record(
1465 record: &OAuthDeviceFlowRecord,
1466 target: &AuthBindingRef,
1467 provider: OAuthProviderIdentity,
1468) -> Result<(), OAuthFlowError> {
1469 if &record.target != target {
1470 return Err(OAuthFlowError::TargetMismatch {
1471 expected: Box::new(record.target.clone()),
1472 actual: Box::new(target.clone()),
1473 });
1474 }
1475 if record.provider != provider {
1476 return Err(OAuthFlowError::ProviderMismatch {
1477 expected: record.provider,
1478 actual: provider,
1479 });
1480 }
1481 Ok(())
1482}
1483
1484fn verify_device_poll_lease_locked(
1485 flows: &mut HashMap<String, OAuthDeviceFlowState>,
1486 device_code: &str,
1487 provider: OAuthProviderIdentity,
1488 lease_id: u64,
1489) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1490 let Some(state) = flows.get(device_code) else {
1491 return Err(OAuthFlowError::Missing);
1492 };
1493 if state.record.provider != provider {
1494 return Err(OAuthFlowError::ProviderMismatch {
1495 expected: state.record.provider,
1496 actual: provider,
1497 });
1498 }
1499 match state.poll_lease {
1500 Some(lease) if lease.id == lease_id => {}
1501 Some(_) => return Err(OAuthFlowError::DevicePollInProgress),
1502 None => return Err(OAuthFlowError::Missing),
1503 }
1504 Ok(state.record.clone())
1505}
1506
1507fn consume_device_poll_lease_locked(
1508 flows: &mut HashMap<String, OAuthDeviceFlowState>,
1509 device_code: &str,
1510 provider: OAuthProviderIdentity,
1511 lease_id: u64,
1512) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1513 verify_device_poll_lease_locked(flows, device_code, provider, lease_id)?;
1514 flows
1515 .remove(device_code)
1516 .map(|state| state.record)
1517 .ok_or(OAuthFlowError::Missing)
1518}
1519
1520fn prune_expired_device_locked(flows: &mut HashMap<String, OAuthDeviceFlowState>) {
1521 let _ = take_expired_device_locked(flows);
1522}
1523
1524fn take_expired_device_locked(
1525 flows: &mut HashMap<String, OAuthDeviceFlowState>,
1526) -> Vec<OAuthDeviceFlowRecord> {
1527 let now = Instant::now();
1528 let expired = flows
1529 .iter()
1530 .filter(|(_, state)| state.record.expires_at < now)
1531 .map(|(device_code, _)| device_code.clone())
1532 .collect::<Vec<_>>();
1533 expired
1534 .into_iter()
1535 .filter_map(|device_code| flows.remove(&device_code).map(|state| state.record))
1536 .collect()
1537}
1538
1539fn strings(values: &[&str]) -> Vec<String> {
1540 values.iter().map(|value| (*value).to_string()).collect()
1541}
1542
1543#[cfg(test)]
1544#[allow(clippy::expect_used)]
1545mod tests {
1546 use super::*;
1547
1548 fn target() -> AuthBindingRef {
1549 AuthBindingRef {
1550 realm: meerkat_core::RealmId::parse("dev").expect("valid realm"),
1551 binding: meerkat_core::BindingId::parse("default_openai").expect("valid binding"),
1552 profile: None,
1553 }
1554 }
1555
1556 fn alternate_target() -> AuthBindingRef {
1557 AuthBindingRef {
1558 realm: meerkat_core::RealmId::parse("dev").expect("valid realm"),
1559 binding: meerkat_core::BindingId::parse("alternate_openai").expect("valid binding"),
1560 profile: None,
1561 }
1562 }
1563
1564 #[test]
1565 fn oauth_state_pkce_round_trip() {
1566 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1567 let state = registry
1568 .start(
1569 target(),
1570 OAuthProviderIdentity::OpenAiChatGpt,
1571 "http://127.0.0.1/callback",
1572 "verifier",
1573 )
1574 .expect("state generation succeeds");
1575 let record = registry.consume(
1576 &state,
1577 &target(),
1578 OAuthProviderIdentity::OpenAiChatGpt,
1579 "http://127.0.0.1/callback",
1580 );
1581 assert!(
1582 record.is_ok(),
1583 "state should resolve once: {:?}",
1584 record.err()
1585 );
1586 if let Ok(record) = record {
1587 assert_eq!(record.pkce_verifier, "verifier");
1588 }
1589 assert!(matches!(
1590 registry.consume(
1591 &state,
1592 &target(),
1593 OAuthProviderIdentity::OpenAiChatGpt,
1594 "http://127.0.0.1/callback"
1595 ),
1596 Err(OAuthFlowError::Missing)
1597 ));
1598 }
1599
1600 #[test]
1601 fn oauth_state_rejects_mismatch() {
1602 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1603 let state = registry
1604 .start(
1605 target(),
1606 OAuthProviderIdentity::OpenAiChatGpt,
1607 "http://127.0.0.1/callback",
1608 "verifier",
1609 )
1610 .expect("state generation succeeds");
1611 assert!(matches!(
1612 registry.consume(
1613 &state,
1614 &target(),
1615 OAuthProviderIdentity::AnthropicClaudeAi,
1616 "http://127.0.0.1/callback"
1617 ),
1618 Err(OAuthFlowError::ProviderMismatch { .. })
1619 ));
1620 }
1621
1622 #[test]
1623 fn oauth_state_provider_mismatch_does_not_consume_state() {
1624 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1625 let state = registry
1626 .start(
1627 target(),
1628 OAuthProviderIdentity::OpenAiChatGpt,
1629 "http://127.0.0.1/callback",
1630 "verifier",
1631 )
1632 .expect("state generation succeeds");
1633 assert!(matches!(
1634 registry.consume(
1635 &state,
1636 &target(),
1637 OAuthProviderIdentity::AnthropicClaudeAi,
1638 "http://127.0.0.1/callback"
1639 ),
1640 Err(OAuthFlowError::ProviderMismatch { .. })
1641 ));
1642
1643 let record = registry
1644 .consume(
1645 &state,
1646 &target(),
1647 OAuthProviderIdentity::OpenAiChatGpt,
1648 "http://127.0.0.1/callback",
1649 )
1650 .expect("provider mismatch must leave state retryable");
1651 assert_eq!(record.pkce_verifier, "verifier");
1652 }
1653
1654 #[test]
1655 fn oauth_state_rejects_redirect_uri_mismatch() {
1656 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1657 let state = registry
1658 .start(
1659 target(),
1660 OAuthProviderIdentity::OpenAiChatGpt,
1661 "http://127.0.0.1/callback",
1662 "verifier",
1663 )
1664 .expect("state generation succeeds");
1665 assert!(matches!(
1666 registry.consume(
1667 &state,
1668 &target(),
1669 OAuthProviderIdentity::OpenAiChatGpt,
1670 "http://127.0.0.1/other"
1671 ),
1672 Err(OAuthFlowError::RedirectUriMismatch)
1673 ));
1674 }
1675
1676 #[test]
1677 fn oauth_state_redirect_uri_mismatch_does_not_consume_state() {
1678 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1679 let state = registry
1680 .start(
1681 target(),
1682 OAuthProviderIdentity::OpenAiChatGpt,
1683 "http://127.0.0.1/callback",
1684 "verifier",
1685 )
1686 .expect("state generation succeeds");
1687 assert!(matches!(
1688 registry.consume(
1689 &state,
1690 &target(),
1691 OAuthProviderIdentity::OpenAiChatGpt,
1692 "http://127.0.0.1/other"
1693 ),
1694 Err(OAuthFlowError::RedirectUriMismatch)
1695 ));
1696
1697 let record = registry
1698 .consume(
1699 &state,
1700 &target(),
1701 OAuthProviderIdentity::OpenAiChatGpt,
1702 "http://127.0.0.1/callback",
1703 )
1704 .expect("redirect mismatch must leave state retryable");
1705 assert_eq!(record.pkce_verifier, "verifier");
1706 }
1707
1708 #[test]
1709 fn oauth_state_target_mismatch_does_not_consume_state() {
1710 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1711 let state = registry
1712 .start(
1713 target(),
1714 OAuthProviderIdentity::OpenAiChatGpt,
1715 "http://127.0.0.1/callback",
1716 "verifier",
1717 )
1718 .expect("state generation succeeds");
1719 assert!(matches!(
1720 registry.consume(
1721 &state,
1722 &alternate_target(),
1723 OAuthProviderIdentity::OpenAiChatGpt,
1724 "http://127.0.0.1/callback"
1725 ),
1726 Err(OAuthFlowError::TargetMismatch { .. })
1727 ));
1728
1729 let record = registry
1730 .consume(
1731 &state,
1732 &target(),
1733 OAuthProviderIdentity::OpenAiChatGpt,
1734 "http://127.0.0.1/callback",
1735 )
1736 .expect("target mismatch must leave state retryable");
1737 assert_eq!(record.pkce_verifier, "verifier");
1738 }
1739
1740 #[test]
1741 fn oauth_state_verify_does_not_consume_state() {
1742 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1743 let state = registry
1744 .start(
1745 target(),
1746 OAuthProviderIdentity::OpenAiChatGpt,
1747 "http://127.0.0.1/callback",
1748 "verifier",
1749 )
1750 .expect("state generation succeeds");
1751
1752 let verified = registry
1753 .verify(
1754 &state,
1755 &target(),
1756 OAuthProviderIdentity::OpenAiChatGpt,
1757 "http://127.0.0.1/callback",
1758 )
1759 .expect("verify should read state before terminal commit");
1760
1761 assert_eq!(verified.pkce_verifier, "verifier");
1762 registry
1763 .consume(
1764 &state,
1765 &target(),
1766 OAuthProviderIdentity::OpenAiChatGpt,
1767 "http://127.0.0.1/callback",
1768 )
1769 .expect("verified browser state remains available for terminal consume");
1770 }
1771
1772 #[test]
1773 fn oauth_state_random_tokens_are_urlsafe_and_unique() {
1774 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1775 let first = registry
1776 .start(
1777 target(),
1778 OAuthProviderIdentity::OpenAiChatGpt,
1779 "http://127.0.0.1/callback",
1780 "verifier-a",
1781 )
1782 .expect("state generation succeeds");
1783 let second = registry
1784 .start(
1785 target(),
1786 OAuthProviderIdentity::OpenAiChatGpt,
1787 "http://127.0.0.1/callback",
1788 "verifier-b",
1789 )
1790 .expect("state generation succeeds");
1791 assert_ne!(first, second);
1792 assert!(first.starts_with("st-"));
1793 assert!(
1794 first[3..]
1795 .chars()
1796 .all(|ch| { ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' })
1797 );
1798 }
1799
1800 #[test]
1801 fn oauth_state_expired_records_are_pruned() {
1802 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1803 registry.flows.lock().insert(
1804 "st-old".to_string(),
1805 OAuthFlowRecord {
1806 target: target(),
1807 provider: OAuthProviderIdentity::OpenAiChatGpt,
1808 redirect_uri: "http://127.0.0.1/callback".to_string(),
1809 pkce_verifier: "verifier".to_string(),
1810 created_at: Instant::now()
1811 .checked_sub(Duration::from_secs(61))
1812 .expect("test duration is representable"),
1813 },
1814 );
1815 assert!(matches!(
1816 registry.consume(
1817 "st-old",
1818 &target(),
1819 OAuthProviderIdentity::OpenAiChatGpt,
1820 "http://127.0.0.1/callback"
1821 ),
1822 Err(OAuthFlowError::Missing)
1823 ));
1824 }
1825
1826 #[test]
1827 fn oauth_state_registry_rejects_start_at_capacity() {
1828 let registry = OAuthFlowRegistry::new_with_capacity(Duration::from_secs(60), 2);
1829 let first = registry
1830 .start(
1831 target(),
1832 OAuthProviderIdentity::OpenAiChatGpt,
1833 "http://127.0.0.1/callback",
1834 "first",
1835 )
1836 .expect("state generation succeeds");
1837 let second = registry
1838 .start(
1839 target(),
1840 OAuthProviderIdentity::OpenAiChatGpt,
1841 "http://127.0.0.1/callback",
1842 "second",
1843 )
1844 .expect("state generation succeeds");
1845 let third = registry.start(
1846 target(),
1847 OAuthProviderIdentity::OpenAiChatGpt,
1848 "http://127.0.0.1/callback",
1849 "third",
1850 );
1851
1852 assert!(matches!(
1853 third,
1854 Err(OAuthFlowError::CapacityExceeded { max_outstanding: 2 })
1855 ));
1856 assert!(
1857 registry
1858 .consume(
1859 &first,
1860 &target(),
1861 OAuthProviderIdentity::OpenAiChatGpt,
1862 "http://127.0.0.1/callback"
1863 )
1864 .is_ok()
1865 );
1866 assert!(
1867 registry
1868 .consume(
1869 &second,
1870 &target(),
1871 OAuthProviderIdentity::OpenAiChatGpt,
1872 "http://127.0.0.1/callback"
1873 )
1874 .is_ok()
1875 );
1876 }
1877
1878 #[test]
1879 fn oauth_state_cannot_cross_login_lifecycle_authorities() {
1880 let admitting_authority = OAuthFlowRegistry::new(Duration::from_secs(60));
1881 let unrelated_authority = OAuthFlowRegistry::new(Duration::from_secs(60));
1882 let state = admitting_authority
1883 .start(
1884 target(),
1885 OAuthProviderIdentity::OpenAiChatGpt,
1886 "http://127.0.0.1/callback",
1887 "verifier",
1888 )
1889 .expect("state generation succeeds");
1890
1891 assert!(matches!(
1892 unrelated_authority.consume(
1893 &state,
1894 &target(),
1895 OAuthProviderIdentity::OpenAiChatGpt,
1896 "http://127.0.0.1/callback"
1897 ),
1898 Err(OAuthFlowError::Missing)
1899 ));
1900
1901 let record = admitting_authority
1902 .consume(
1903 &state,
1904 &target(),
1905 OAuthProviderIdentity::OpenAiChatGpt,
1906 "http://127.0.0.1/callback",
1907 )
1908 .expect("admitting authority owns the flow");
1909 assert_eq!(record.pkce_verifier, "verifier");
1910 }
1911
1912 #[test]
1913 fn oauth_device_flow_is_retained_until_terminal_consume() {
1914 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1915 registry
1916 .admit_device_code(
1917 target(),
1918 OAuthProviderIdentity::GoogleCodeAssist,
1919 "device-code",
1920 Duration::from_secs(600),
1921 )
1922 .expect("device code admitted");
1923
1924 let observed = registry
1925 .verify_device_code(
1926 "device-code",
1927 &target(),
1928 OAuthProviderIdentity::GoogleCodeAssist,
1929 )
1930 .expect("pending device flow remains visible");
1931 assert_eq!(observed.device_code, "device-code");
1932
1933 let poll = registry
1934 .begin_device_code_poll(
1935 "device-code",
1936 &target(),
1937 OAuthProviderIdentity::GoogleCodeAssist,
1938 )
1939 .expect("poll begins");
1940 let consumed = poll.consume().expect("terminal device flow consumes");
1941 assert_eq!(consumed.provider, OAuthProviderIdentity::GoogleCodeAssist);
1942 assert!(matches!(
1943 registry.verify_device_code(
1944 "device-code",
1945 &target(),
1946 OAuthProviderIdentity::GoogleCodeAssist
1947 ),
1948 Err(OAuthFlowError::Missing)
1949 ));
1950 }
1951
1952 #[test]
1953 fn oauth_device_poll_verify_keeps_terminal_consume_available() {
1954 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1955 registry
1956 .admit_device_code(
1957 target(),
1958 OAuthProviderIdentity::GoogleCodeAssist,
1959 "device-code",
1960 Duration::from_secs(600),
1961 )
1962 .expect("device code admitted");
1963 let poll = registry
1964 .begin_device_code_poll(
1965 "device-code",
1966 &target(),
1967 OAuthProviderIdentity::GoogleCodeAssist,
1968 )
1969 .expect("poll begins");
1970
1971 let verified = poll
1972 .verify()
1973 .expect("terminal preflight verifies the current lease");
1974
1975 assert_eq!(verified.device_code, "device-code");
1976 assert!(matches!(
1977 registry.begin_device_code_poll(
1978 "device-code",
1979 &target(),
1980 OAuthProviderIdentity::GoogleCodeAssist
1981 ),
1982 Err(OAuthFlowError::DevicePollInProgress)
1983 ));
1984 let consumed = poll.consume().expect("verified lease still consumes");
1985 assert_eq!(consumed.device_code, "device-code");
1986 }
1987
1988 #[test]
1989 fn oauth_device_admission_does_not_replace_active_poll_lease() {
1990 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1991 registry
1992 .admit_device_code(
1993 target(),
1994 OAuthProviderIdentity::GoogleCodeAssist,
1995 "device-code",
1996 Duration::from_secs(600),
1997 )
1998 .expect("device code admitted");
1999 let poll = registry
2000 .begin_device_code_poll(
2001 "device-code",
2002 &target(),
2003 OAuthProviderIdentity::GoogleCodeAssist,
2004 )
2005 .expect("poll begins");
2006
2007 let duplicate = registry.admit_device_code(
2008 target(),
2009 OAuthProviderIdentity::GoogleCodeAssist,
2010 "device-code",
2011 Duration::from_secs(600),
2012 );
2013
2014 assert_eq!(duplicate, Err(OAuthFlowError::DeviceCodeAlreadyAdmitted));
2015 let consumed = poll
2016 .consume()
2017 .expect("duplicate admission must not replace the active poll lease");
2018 assert_eq!(consumed.device_code, "device-code");
2019 }
2020
2021 #[test]
2022 fn oauth_device_flow_rejects_provider_mismatch() {
2023 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2024 registry
2025 .admit_device_code(
2026 target(),
2027 OAuthProviderIdentity::GoogleCodeAssist,
2028 "device-code",
2029 Duration::from_secs(600),
2030 )
2031 .expect("device code admitted");
2032
2033 assert!(matches!(
2034 registry.verify_device_code(
2035 "device-code",
2036 &target(),
2037 OAuthProviderIdentity::OpenAiChatGpt
2038 ),
2039 Err(OAuthFlowError::ProviderMismatch { .. })
2040 ));
2041 let poll = registry
2042 .begin_device_code_poll(
2043 "device-code",
2044 &target(),
2045 OAuthProviderIdentity::GoogleCodeAssist,
2046 )
2047 .expect("correct provider poll begins after mismatch");
2048 assert!(poll.consume().is_ok());
2049 }
2050
2051 #[test]
2052 fn oauth_device_flow_rejects_target_mismatch_without_consuming() {
2053 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2054 registry
2055 .admit_device_code(
2056 target(),
2057 OAuthProviderIdentity::GoogleCodeAssist,
2058 "device-code",
2059 Duration::from_secs(600),
2060 )
2061 .expect("device code admitted");
2062
2063 assert!(matches!(
2064 registry.verify_device_code(
2065 "device-code",
2066 &alternate_target(),
2067 OAuthProviderIdentity::GoogleCodeAssist
2068 ),
2069 Err(OAuthFlowError::TargetMismatch { .. })
2070 ));
2071 let poll = registry
2072 .begin_device_code_poll(
2073 "device-code",
2074 &target(),
2075 OAuthProviderIdentity::GoogleCodeAssist,
2076 )
2077 .expect("correct target poll begins after mismatch");
2078 assert!(poll.consume().is_ok());
2079 }
2080
2081 #[test]
2082 fn oauth_device_flow_expired_records_are_pruned() {
2083 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2084 registry.device_flows.lock().insert(
2085 "device-code".to_string(),
2086 OAuthDeviceFlowState {
2087 record: OAuthDeviceFlowRecord {
2088 target: target(),
2089 provider: OAuthProviderIdentity::GoogleCodeAssist,
2090 device_code: "device-code".to_string(),
2091 created_at: Instant::now()
2092 .checked_sub(Duration::from_secs(601))
2093 .expect("test duration is representable"),
2094 expires_at: Instant::now()
2095 .checked_sub(Duration::from_secs(1))
2096 .expect("test duration is representable"),
2097 },
2098 poll_lease: None,
2099 },
2100 );
2101
2102 assert!(matches!(
2103 registry.verify_device_code(
2104 "device-code",
2105 &target(),
2106 OAuthProviderIdentity::GoogleCodeAssist
2107 ),
2108 Err(OAuthFlowError::Missing)
2109 ));
2110 }
2111
2112 #[test]
2113 fn oauth_device_flow_rejects_unrepresentable_expiry() {
2114 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2115 let err = registry
2116 .admit_device_code(
2117 target(),
2118 OAuthProviderIdentity::GoogleCodeAssist,
2119 "device-code",
2120 Duration::MAX,
2121 )
2122 .expect_err("unrepresentable device expiry should be rejected");
2123
2124 assert_eq!(err, OAuthFlowError::DeviceExpiryOutOfRange);
2125 assert!(matches!(
2126 registry.verify_device_code(
2127 "device-code",
2128 &target(),
2129 OAuthProviderIdentity::GoogleCodeAssist
2130 ),
2131 Err(OAuthFlowError::Missing)
2132 ));
2133 }
2134
2135 #[test]
2136 fn oauth_device_terminal_consume_rejects_local_expiry_boundary() {
2137 let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2138 registry
2139 .admit_device_code(
2140 target(),
2141 OAuthProviderIdentity::GoogleCodeAssist,
2142 "device-code",
2143 Duration::from_secs(600),
2144 )
2145 .expect("device code admitted");
2146 let poll = registry
2147 .begin_device_code_poll(
2148 "device-code",
2149 &target(),
2150 OAuthProviderIdentity::GoogleCodeAssist,
2151 )
2152 .expect("poll begins before local expiry boundary");
2153 {
2154 let mut flows = registry.device_flows.lock();
2155 flows
2156 .get_mut("device-code")
2157 .expect("device flow exists")
2158 .record
2159 .expires_at = Instant::now()
2160 .checked_sub(Duration::from_secs(1))
2161 .expect("test duration is representable");
2162 }
2163
2164 assert!(matches!(poll.consume(), Err(OAuthFlowError::Missing)));
2165 assert!(matches!(
2166 registry.verify_device_code(
2167 "device-code",
2168 &target(),
2169 OAuthProviderIdentity::GoogleCodeAssist
2170 ),
2171 Err(OAuthFlowError::Missing)
2172 ));
2173 }
2174
#[test]
fn oauth_device_terminal_consume_rejects_intervening_prune_while_polling() {
    // An expired device flow that is pruned by an unrelated registry operation
    // while a poll lease is outstanding must not be consumable by that lease.
    let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
    registry
        .admit_device_code(
            target(),
            OAuthProviderIdentity::GoogleCodeAssist,
            "device-code",
            Duration::from_secs(600),
        )
        .expect("device code admitted");
    let poll = registry
        .begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist,
        )
        .expect("poll begins");
    // A second poll is refused while the first lease is held.
    assert!(matches!(
        registry.begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist
        ),
        Err(OAuthFlowError::DevicePollInProgress)
    ));

    // Push the in-progress record one second past its deadline.
    {
        let mut flows = registry.device_flows.lock();
        flows
            .get_mut("device-code")
            .expect("device flow exists")
            .record
            .expires_at = Instant::now()
            .checked_sub(Duration::from_secs(1))
            .expect("test duration is representable");
    }

    registry
        .start(
            target(),
            OAuthProviderIdentity::OpenAiChatGpt,
            "http://127.0.0.1/callback",
            "verifier",
        )
        .expect("unrelated start prunes expired device flow");

    // `consume` already rejects an expired record on its own, so only a direct
    // look at the map proves that the intervening prune removed the record.
    assert!(
        !registry.device_flows.lock().contains_key("device-code"),
        "unrelated start should prune the expired in-progress device flow"
    );
    assert!(matches!(poll.consume(), Err(OAuthFlowError::Missing)));
}
2222
#[test]
fn oauth_device_poll_drop_releases_in_progress_lifecycle() {
    let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
    registry
        .admit_device_code(
            target(),
            OAuthProviderIdentity::GoogleCodeAssist,
            "device-code",
            Duration::from_secs(600),
        )
        .expect("device code admitted");

    // Every poll attempt in this test targets the same admitted flow.
    let begin_poll = || {
        registry.begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist,
        )
    };

    // While the first lease is alive, another poll is refused.
    let lease = begin_poll().expect("poll begins");
    assert!(matches!(
        begin_poll(),
        Err(OAuthFlowError::DevicePollInProgress)
    ));

    // Dropping the lease without consuming must free the flow for a new poll.
    drop(lease);
    begin_poll().expect("dropped poll lease releases the in-progress lifecycle");
}
2261
#[test]
fn oauth_device_poll_drop_prunes_expired_in_progress_record() {
    // Dropping an unconsumed poll lease whose record has already expired must
    // prune (remove) that record from the registry.
    let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
    registry
        .admit_device_code(
            target(),
            OAuthProviderIdentity::GoogleCodeAssist,
            "device-code",
            Duration::from_secs(600),
        )
        .expect("device code admitted");
    let poll = registry
        .begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist,
        )
        .expect("poll begins");

    // Push the in-progress record one second past its deadline.
    {
        let mut flows = registry.device_flows.lock();
        flows
            .get_mut("device-code")
            .expect("device flow exists")
            .record
            .expires_at = Instant::now()
            .checked_sub(Duration::from_secs(1))
            .expect("test duration is representable");
    }

    drop(poll);

    // Inspect the map before any other registry call could prune lazily.
    // `verify_device_code` would also report `Missing` for an expired record
    // that is still resident, so it alone cannot prove the drop pruned it.
    assert!(
        !registry.device_flows.lock().contains_key("device-code"),
        "dropping the poll lease should prune the expired in-progress record"
    );
    assert!(matches!(
        registry.verify_device_code(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist
        ),
        Err(OAuthFlowError::Missing)
    ));
}
2302
2303 struct RejectConsumeLifecycle;
2304
2305 impl OAuthDevicePollLifecycle for RejectConsumeLifecycle {
2306 fn finish_device_poll(
2307 &self,
2308 _target: &AuthBindingRef,
2309 _device_code: &str,
2310 ) -> Result<(), OAuthFlowError> {
2311 Ok(())
2312 }
2313
2314 fn consume_device_flow(
2315 &self,
2316 _target: &AuthBindingRef,
2317 _device_code: &str,
2318 _provider: OAuthProviderIdentity,
2319 ) -> Result<(), OAuthFlowError> {
2320 Err(OAuthFlowError::LifecycleRejected {
2321 operation: "consume_oauth_device_flow",
2322 detail: "injected failure".to_string(),
2323 })
2324 }
2325
2326 fn expire_device_flow(
2327 &self,
2328 _target: &AuthBindingRef,
2329 _device_code: &str,
2330 ) -> Result<(), OAuthFlowError> {
2331 Ok(())
2332 }
2333 }
2334
#[test]
fn oauth_device_lifecycle_consume_failure_keeps_flow_retryable() {
    // A lifecycle hook that rejects the terminal consume must leave the device
    // flow in place, so a later poll (with the default lifecycle) can succeed.
    let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
    registry
        .admit_device_code(
            target(),
            OAuthProviderIdentity::GoogleCodeAssist,
            "device-code",
            Duration::from_secs(600),
        )
        .expect("device code admitted");

    let rejecting_poll = registry
        .begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist,
        )
        .expect("poll begins")
        .with_lifecycle(std::sync::Arc::new(RejectConsumeLifecycle));
    assert!(matches!(
        rejecting_poll.consume(),
        Err(OAuthFlowError::LifecycleRejected {
            operation: "consume_oauth_device_flow",
            ..
        })
    ));

    // The failed consume neither removed the flow nor left it marked in progress.
    let retry = registry
        .begin_device_code_poll(
            "device-code",
            &target(),
            OAuthProviderIdentity::GoogleCodeAssist,
        )
        .expect("failed lifecycle consume keeps device flow retryable");
    assert!(retry.consume().is_ok());
}
2372
#[test]
fn oauth_provider_resolution_preserves_aliases() {
    // (alias, expected identity, expected provider, expected persisted auth mode)
    let cases = [
        (
            "anthropic",
            OAuthProviderIdentity::AnthropicClaudeAi,
            Provider::Anthropic,
            PersistedAuthMode::ClaudeAiOauth,
        ),
        (
            "claude",
            OAuthProviderIdentity::AnthropicClaudeAi,
            Provider::Anthropic,
            PersistedAuthMode::ClaudeAiOauth,
        ),
        (
            "claude.ai",
            OAuthProviderIdentity::AnthropicClaudeAi,
            Provider::Anthropic,
            PersistedAuthMode::ClaudeAiOauth,
        ),
        (
            "openai",
            OAuthProviderIdentity::OpenAiChatGpt,
            Provider::OpenAI,
            PersistedAuthMode::ChatgptOauth,
        ),
        (
            "chatgpt",
            OAuthProviderIdentity::OpenAiChatGpt,
            Provider::OpenAI,
            PersistedAuthMode::ChatgptOauth,
        ),
        (
            "google",
            OAuthProviderIdentity::GoogleCodeAssist,
            Provider::Gemini,
            PersistedAuthMode::GoogleOauth,
        ),
        (
            "gemini",
            OAuthProviderIdentity::GoogleCodeAssist,
            Provider::Gemini,
            PersistedAuthMode::GoogleOauth,
        ),
        (
            "code_assist",
            OAuthProviderIdentity::GoogleCodeAssist,
            Provider::Gemini,
            PersistedAuthMode::GoogleOauth,
        ),
    ];

    for (alias, identity, provider, auth_mode) in cases {
        // Name the alias in every failure message: without it, a failing row
        // in this table is indistinguishable from the others.
        let resolved = resolve_oauth_provider(alias, "http://127.0.0.1/callback")
            .unwrap_or_else(|err| panic!("alias {alias:?} should resolve: {err:?}"));
        assert_eq!(resolved.identity, identity, "identity for alias {alias:?}");
        assert_eq!(resolved.provider, provider, "provider for alias {alias:?}");
        assert_eq!(resolved.auth_mode, auth_mode, "auth mode for alias {alias:?}");
        assert_eq!(
            resolved.endpoints.redirect_uri, "http://127.0.0.1/callback",
            "redirect uri for alias {alias:?}"
        );
    }
}
2435
#[test]
fn oauth_provider_resolution_exposes_google_device_secret() {
    // Google Code Assist must resolve with a device-code endpoint and carry the
    // Google client secret. An empty redirect URI is passed.
    let google = resolve_oauth_provider("code_assist", "").expect("google code assist resolves");

    assert_eq!(google.identity, OAuthProviderIdentity::GoogleCodeAssist);
    assert!(matches!(google.endpoints.device_code_url, Some(_)));
    assert_eq!(google.client_secret, Some(GOOGLE_CLIENT_SECRET));
}
2445
#[cfg(feature = "oauth")]
#[test]
fn openai_provider_resolution_matches_codex_authorize_contract() {
    const REDIRECT: &str = "http://localhost:1455/auth/callback";

    let resolved = resolve_oauth_provider("openai", REDIRECT).expect("openai resolves");
    let endpoints = &resolved.endpoints;

    // The token exchange is form-encoded and uses the caller's redirect URI.
    assert_eq!(endpoints.redirect_uri, REDIRECT);
    assert_eq!(
        endpoints.token_request_format,
        OAuthTokenRequestFormat::FormUrlEncoded
    );

    // The authorize URL must carry the Codex-specific query parameters.
    let pkce = crate::auth_oauth::PkcePair::generate_s256();
    let authorize_url = endpoints.authorize_url_with_pkce(&pkce.challenge, "state-abc");
    let required_params = [
        "redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback",
        "id_token_add_organizations=true",
        "codex_cli_simplified_flow=true",
        "originator=codex_cli_rs",
    ];
    for param in required_params {
        assert!(authorize_url.contains(param));
    }
}
2471
#[test]
fn anthropic_provider_resolution_matches_claude_code_token_contract() {
    let resolved = resolve_oauth_provider("anthropic", "http://localhost:1455/callback")
        .expect("anthropic resolves");
    let endpoints = &resolved.endpoints;

    // Token exchange: JSON body that also carries `state`.
    assert_eq!(endpoints.token_request_format, OAuthTokenRequestFormat::Json);
    assert!(endpoints.include_state_in_token_exchange);

    // Refresh requests ask for the claude.ai scope set.
    assert_eq!(endpoints.refresh_scopes, strings(ANTHROPIC_CLAUDE_AI_SCOPES));

    // The authorize request must include the extra `code=true` parameter.
    let has_code_param = endpoints
        .extra_authorize_params
        .iter()
        .any(|(key, value)| key == "code" && value == "true");
    assert!(has_code_param);
}
2493}