Skip to main content

meerkat_auth_core/
oauth_flow.rs

//! Short-lived OAuth login flow authority.
//!
//! Runtime surfaces own an explicit authority instance that correlates each
//! OAuth `state` with its PKCE verifier and tracks the device-code flow
//! lifecycle. `start` records a flow before it is returned to the client;
//! completion must verify and consume that state through the same authority
//! before committing terminal login state.
7
8use std::collections::HashMap;
9use std::sync::{
10    Arc, Mutex as StdMutex,
11    atomic::{AtomicU64, Ordering},
12};
13use std::time::{Duration, Instant};
14
15use base64::Engine as _;
16use meerkat_core::{AuthBindingRef, AuthProfile, BackendProfile, CredentialSourceSpec, Provider};
17use parking_lot::Mutex;
18use serde::{Deserialize, Serialize};
19
20use crate::auth_oauth::{OAuthEndpoints, OAuthTokenRequestFormat};
21use crate::auth_store::{
22    PersistedAuthMode, credential_source_uses_persisted_store, persisted_auth_mode_for_auth_method,
23};
24
/// Outstanding-flow cap used by `OAuthFlowRegistry::new`.
const DEFAULT_MAX_OUTSTANDING_FLOWS: usize = 1024;

// Anthropic. The Claude.ai and Console API-key logins share a client id and
// token endpoint but use different authorize URLs and scope sets.
const ANTHROPIC_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
const ANTHROPIC_AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize";
const ANTHROPIC_CONSOLE_AUTHORIZE_URL: &str = "https://platform.claude.com/oauth/authorize";
const ANTHROPIC_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
/// Scopes for `OAuthProviderIdentity::AnthropicConsoleApiKey` (also used on refresh).
const ANTHROPIC_CONSOLE_SCOPES: &[&str] = &["org:create_api_key", "user:profile"];
/// Scopes for `OAuthProviderIdentity::AnthropicClaudeAi` (also used on refresh).
const ANTHROPIC_CLAUDE_AI_SCOPES: &[&str] = &[
    "user:profile",
    "user:inference",
    "user:sessions:claude_code",
    "user:mcp_servers",
    "user:file_upload",
];

// OpenAI (ChatGPT login).
const OPENAI_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
const OPENAI_SCOPES: &[&str] = &[
    "openid",
    "profile",
    "email",
    "offline_access",
    "api.connectors.read",
    "api.connectors.invoke",
];
/// Value sent as the `originator` authorize parameter.
const OPENAI_ORIGINATOR: &str = "codex_cli_rs";

// Google Code Assist, the only provider here with a device-code endpoint and
// a client secret. NOTE(review): the id/secret are assembled from `concat!`
// fragments, presumably so secret scanners do not flag the literals. Confirm
// before collapsing them into single strings.
const GOOGLE_CLIENT_ID: &str = concat!(
    "6812558",
    "09395-oo8ft2oprdrnp9e3aqf6av3hmdib135j",
    ".apps.googleusercontent.com",
);
const GOOGLE_CLIENT_SECRET: &str = concat!("GOCSP", "X-4uHgMPm", "-1o7Sk-geV6Cu5clXFsxl");
const GOOGLE_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
const GOOGLE_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code";
const GOOGLE_SCOPES: &[&str] = &[
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
];

/// Flag variable that must hold a truthy spelling ("1", "true", "yes", "on",
/// all-lowercase or all-uppercase) for `apply_test_oauth_endpoint_override`
/// to rewrite endpoints.
const TEST_OAUTH_ENDPOINT_OVERRIDE_ENV: &str = "MEERKAT_TEST_OAUTH_ENDPOINT_OVERRIDE";
/// Base URL that the test override prefixes onto `/{provider}/authorize`,
/// `/{provider}/token` and `/{provider}/device/code`.
const TEST_OAUTH_BASE_URL_ENV: &str = "MEERKAT_TEST_OAUTH_BASE_URL";
70
/// OAuth login flavors this module knows how to drive.
///
/// Each identity fixes a [`Provider`], a persisted credential mode, a backend
/// kind and a set of OAuth endpoints (see the inherent methods).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OAuthProviderIdentity {
    /// Claude.ai login, persisted as `PersistedAuthMode::ClaudeAiOauth`.
    AnthropicClaudeAi,
    /// Anthropic Console login (requests `org:create_api_key`), persisted as
    /// `PersistedAuthMode::OauthToApiKey`.
    AnthropicConsoleApiKey,
    /// ChatGPT login, persisted as `PersistedAuthMode::ChatgptOauth`.
    OpenAiChatGpt,
    /// Google Code Assist login, persisted as `PersistedAuthMode::GoogleOauth`.
    /// This is the only identity with a device-code endpoint and a client secret.
    GoogleCodeAssist,
}
78
79impl OAuthProviderIdentity {
80    pub fn from_alias(alias: &str) -> Option<Self> {
81        match alias {
82            "anthropic" | "claude" | "claude.ai" => Some(Self::AnthropicClaudeAi),
83            "anthropic_console_api_key" => Some(Self::AnthropicConsoleApiKey),
84            "openai" | "chatgpt" => Some(Self::OpenAiChatGpt),
85            "google" | "gemini" | "code_assist" => Some(Self::GoogleCodeAssist),
86            _ => None,
87        }
88    }
89
90    pub fn canonical_alias(self) -> &'static str {
91        match self {
92            Self::AnthropicClaudeAi => "anthropic",
93            Self::AnthropicConsoleApiKey => "anthropic_console_api_key",
94            Self::OpenAiChatGpt => "openai",
95            Self::GoogleCodeAssist => "google",
96        }
97    }
98
99    pub fn provider(self) -> Provider {
100        match self {
101            Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey => Provider::Anthropic,
102            Self::OpenAiChatGpt => Provider::OpenAI,
103            Self::GoogleCodeAssist => Provider::Gemini,
104        }
105    }
106
107    pub fn auth_mode(self) -> PersistedAuthMode {
108        match self {
109            Self::AnthropicClaudeAi => PersistedAuthMode::ClaudeAiOauth,
110            Self::AnthropicConsoleApiKey => PersistedAuthMode::OauthToApiKey,
111            Self::OpenAiChatGpt => PersistedAuthMode::ChatgptOauth,
112            Self::GoogleCodeAssist => PersistedAuthMode::GoogleOauth,
113        }
114    }
115
116    pub fn backend_kind(self) -> &'static str {
117        match self {
118            Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey => "anthropic_api",
119            Self::OpenAiChatGpt => "chatgpt_backend",
120            Self::GoogleCodeAssist => "google_code_assist",
121        }
122    }
123
124    pub fn client_secret(self) -> Option<&'static str> {
125        match self {
126            Self::AnthropicClaudeAi | Self::AnthropicConsoleApiKey | Self::OpenAiChatGpt => None,
127            Self::GoogleCodeAssist => Some(GOOGLE_CLIENT_SECRET),
128        }
129    }
130
131    pub fn endpoints(self, redirect_uri: impl Into<String>) -> OAuthEndpoints {
132        let endpoints = match self {
133            Self::AnthropicClaudeAi => OAuthEndpoints {
134                client_id: ANTHROPIC_CLIENT_ID.into(),
135                authorize_url: ANTHROPIC_AUTHORIZE_URL.into(),
136                token_url: ANTHROPIC_TOKEN_URL.into(),
137                device_code_url: None,
138                redirect_uri: redirect_uri.into(),
139                scopes: strings(ANTHROPIC_CLAUDE_AI_SCOPES),
140                extra_authorize_params: vec![("code".into(), "true".into())],
141                token_request_format: OAuthTokenRequestFormat::Json,
142                include_state_in_token_exchange: true,
143                refresh_scopes: strings(ANTHROPIC_CLAUDE_AI_SCOPES),
144                extra_headers: Vec::new(),
145            },
146            Self::AnthropicConsoleApiKey => OAuthEndpoints {
147                client_id: ANTHROPIC_CLIENT_ID.into(),
148                authorize_url: ANTHROPIC_CONSOLE_AUTHORIZE_URL.into(),
149                token_url: ANTHROPIC_TOKEN_URL.into(),
150                device_code_url: None,
151                redirect_uri: redirect_uri.into(),
152                scopes: strings(ANTHROPIC_CONSOLE_SCOPES),
153                extra_authorize_params: Vec::new(),
154                token_request_format: OAuthTokenRequestFormat::Json,
155                include_state_in_token_exchange: true,
156                refresh_scopes: strings(ANTHROPIC_CONSOLE_SCOPES),
157                extra_headers: Vec::new(),
158            },
159            Self::OpenAiChatGpt => OAuthEndpoints {
160                client_id: OPENAI_CLIENT_ID.into(),
161                authorize_url: OPENAI_AUTHORIZE_URL.into(),
162                token_url: OPENAI_TOKEN_URL.into(),
163                device_code_url: None,
164                redirect_uri: redirect_uri.into(),
165                scopes: strings(OPENAI_SCOPES),
166                extra_authorize_params: vec![
167                    ("id_token_add_organizations".into(), "true".into()),
168                    ("codex_cli_simplified_flow".into(), "true".into()),
169                    ("originator".into(), OPENAI_ORIGINATOR.into()),
170                ],
171                token_request_format: OAuthTokenRequestFormat::FormUrlEncoded,
172                include_state_in_token_exchange: false,
173                refresh_scopes: Vec::new(),
174                extra_headers: Vec::new(),
175            },
176            Self::GoogleCodeAssist => OAuthEndpoints {
177                client_id: GOOGLE_CLIENT_ID.into(),
178                authorize_url: GOOGLE_AUTHORIZE_URL.into(),
179                token_url: GOOGLE_TOKEN_URL.into(),
180                device_code_url: Some(GOOGLE_DEVICE_CODE_URL.into()),
181                redirect_uri: redirect_uri.into(),
182                scopes: strings(GOOGLE_SCOPES),
183                extra_authorize_params: Vec::new(),
184                token_request_format: OAuthTokenRequestFormat::FormUrlEncoded,
185                include_state_in_token_exchange: false,
186                refresh_scopes: Vec::new(),
187                extra_headers: Vec::new(),
188            },
189        };
190        apply_test_oauth_endpoint_override(self, endpoints)
191    }
192}
193
194impl std::fmt::Display for OAuthProviderIdentity {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        f.write_str(self.canonical_alias())
197    }
198}
199
/// Everything needed to run an OAuth login for one resolved provider alias.
#[derive(Debug, Clone)]
pub struct OAuthProviderResolution {
    /// Identity the alias resolved to.
    pub identity: OAuthProviderIdentity,
    /// `identity.provider()`.
    pub provider: Provider,
    /// `identity.endpoints(redirect_uri)`, including any test override.
    pub endpoints: OAuthEndpoints,
    /// `identity.auth_mode()`.
    pub auth_mode: PersistedAuthMode,
    /// `identity.client_secret()`.
    pub client_secret: Option<&'static str>,
}

/// Returned by [`resolve_oauth_provider`] for an unrecognized alias.
///
/// NOTE(review): the message lists only the canonical names. Extra aliases
/// such as "claude", "chatgpt" and "gemini" are also accepted.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
#[error("Unknown provider '{provider}'. Supported: anthropic, openai, google.")]
pub struct OAuthProviderResolutionError {
    /// The alias exactly as supplied by the caller.
    pub provider: String,
}
214
215pub fn resolve_oauth_provider(
216    provider: &str,
217    redirect_uri: impl Into<String>,
218) -> Result<OAuthProviderResolution, OAuthProviderResolutionError> {
219    let identity = OAuthProviderIdentity::from_alias(provider).ok_or_else(|| {
220        OAuthProviderResolutionError {
221            provider: provider.to_string(),
222        }
223    })?;
224    Ok(OAuthProviderResolution {
225        identity,
226        provider: identity.provider(),
227        endpoints: identity.endpoints(redirect_uri),
228        auth_mode: identity.auth_mode(),
229        client_secret: identity.client_secret(),
230    })
231}
232
233/// Apply the local OAuth fixture endpoint override used by release-grade auth
234/// smoke tests. Production code only observes this when both explicit
235/// `MEERKAT_TEST_*` environment variables are set.
236#[doc(hidden)]
237pub fn apply_test_oauth_endpoint_override(
238    identity: OAuthProviderIdentity,
239    mut endpoints: OAuthEndpoints,
240) -> OAuthEndpoints {
241    #[cfg(not(target_arch = "wasm32"))]
242    {
243        let enabled = std::env::var(TEST_OAUTH_ENDPOINT_OVERRIDE_ENV)
244            .map(|value| {
245                matches!(
246                    value.as_str(),
247                    "1" | "true" | "TRUE" | "yes" | "YES" | "on" | "ON"
248                )
249            })
250            .unwrap_or(false);
251        if !enabled {
252            return endpoints;
253        }
254        let Ok(base_url) = std::env::var(TEST_OAUTH_BASE_URL_ENV) else {
255            return endpoints;
256        };
257        let base_url = base_url.trim_end_matches('/');
258        let provider = identity.canonical_alias();
259        endpoints.authorize_url = format!("{base_url}/{provider}/authorize");
260        endpoints.token_url = format!("{base_url}/{provider}/token");
261        if endpoints.device_code_url.is_some() {
262            endpoints.device_code_url = Some(format!("{base_url}/{provider}/device/code"));
263        }
264    }
265    endpoints
266}
267
/// Reasons an auth/backend profile cannot receive credentials from an OAuth
/// login.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum OAuthTargetValidationError {
    /// The backend profile belongs to a different provider than the login.
    #[error("OAuth target backend provider mismatch: expected {expected:?}, got {actual:?}")]
    BackendProviderMismatch {
        expected: Provider,
        actual: Provider,
    },
    /// The backend profile's `backend_kind` is not the one this login requires.
    #[error(
        "OAuth target backend_kind '{backend_kind}' cannot store credential mode {expected_mode:?}; expected backend_kind '{expected_backend_kind}'"
    )]
    BackendKindMismatch {
        backend_kind: String,
        expected_backend_kind: &'static str,
        expected_mode: PersistedAuthMode,
    },
    /// The auth profile belongs to a different provider than the login.
    #[error("OAuth target provider mismatch: expected {expected:?}, got {actual:?}")]
    ProviderMismatch {
        expected: Provider,
        actual: Provider,
    },
    /// The auth profile's `auth_method` maps to a different persisted mode,
    /// or to no mode at all.
    #[error(
        "OAuth target auth_method '{auth_method}' cannot store credential mode {expected_mode:?}"
    )]
    AuthMethodMismatch {
        auth_method: String,
        expected_mode: PersistedAuthMode,
    },
    /// The auth profile's credential source cannot persist OAuth credentials.
    /// `source_kind` is a snake_case variant label.
    #[error(
        "OAuth target source '{source_kind}' cannot store OAuth credentials; expected source.kind = 'managed_store' or 'platform_default'"
    )]
    SourceMismatch { source_kind: &'static str },
}
300
301pub fn validate_oauth_login_target(
302    auth_profile: &AuthProfile,
303    identity: OAuthProviderIdentity,
304) -> Result<(), OAuthTargetValidationError> {
305    validate_oauth_target_for_auth_mode(auth_profile, identity.provider(), identity.auth_mode())
306}
307
308pub fn validate_oauth_login_binding(
309    backend_profile: &BackendProfile,
310    auth_profile: &AuthProfile,
311    identity: OAuthProviderIdentity,
312) -> Result<(), OAuthTargetValidationError> {
313    validate_oauth_target_binding_for_auth_mode(
314        backend_profile,
315        auth_profile,
316        identity.provider(),
317        identity.auth_mode(),
318        identity.backend_kind(),
319    )
320}
321
322pub fn validate_oauth_target_for_auth_mode(
323    auth_profile: &AuthProfile,
324    expected_provider: Provider,
325    expected_mode: PersistedAuthMode,
326) -> Result<(), OAuthTargetValidationError> {
327    if auth_profile.provider != expected_provider {
328        return Err(OAuthTargetValidationError::ProviderMismatch {
329            expected: expected_provider,
330            actual: auth_profile.provider,
331        });
332    }
333    match persisted_auth_mode_for_auth_method(&auth_profile.auth_method) {
334        Some(actual_mode) if actual_mode == expected_mode => {}
335        _ => {
336            return Err(OAuthTargetValidationError::AuthMethodMismatch {
337                auth_method: auth_profile.auth_method.clone(),
338                expected_mode,
339            });
340        }
341    }
342    if !oauth_source_can_store_flow_credentials(&auth_profile.source) {
343        return Err(OAuthTargetValidationError::SourceMismatch {
344            source_kind: source_kind_label(&auth_profile.source),
345        });
346    }
347    Ok(())
348}
349
350pub fn validate_oauth_target_binding_for_auth_mode(
351    backend_profile: &BackendProfile,
352    auth_profile: &AuthProfile,
353    expected_provider: Provider,
354    expected_mode: PersistedAuthMode,
355    expected_backend_kind: &'static str,
356) -> Result<(), OAuthTargetValidationError> {
357    validate_oauth_target_for_auth_mode(auth_profile, expected_provider, expected_mode)?;
358    if backend_profile.provider != expected_provider {
359        return Err(OAuthTargetValidationError::BackendProviderMismatch {
360            expected: expected_provider,
361            actual: backend_profile.provider,
362        });
363    }
364    if backend_profile.backend_kind != expected_backend_kind {
365        return Err(OAuthTargetValidationError::BackendKindMismatch {
366            backend_kind: backend_profile.backend_kind.clone(),
367            expected_backend_kind,
368            expected_mode,
369        });
370    }
371    Ok(())
372}
373
374fn oauth_source_can_store_flow_credentials(source: &CredentialSourceSpec) -> bool {
375    credential_source_uses_persisted_store(source)
376}
377
378fn source_kind_label(source: &CredentialSourceSpec) -> &'static str {
379    match source {
380        CredentialSourceSpec::InlineSecret { .. } => "inline_secret",
381        CredentialSourceSpec::ManagedStore => "managed_store",
382        CredentialSourceSpec::Env { .. } => "env",
383        CredentialSourceSpec::ExternalResolver { .. } => "external_resolver",
384        CredentialSourceSpec::PlatformDefault => "platform_default",
385        CredentialSourceSpec::Command { .. } => "command",
386        CredentialSourceSpec::FileDescriptor { .. } => "file_descriptor",
387    }
388}
389
/// Outstanding browser (authorization-code + PKCE) login flow.
///
/// The registry stores it in its browser-flow map under a `String` key,
/// presumably the OAuth `state` (cf. `PersistedOAuthBrowserFlow::state`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthFlowRecord {
    /// Auth binding the flow was started for.
    pub target: AuthBindingRef,
    /// Identity whose endpoints the flow uses.
    pub provider: OAuthProviderIdentity,
    /// Redirect URI given at start. Completion must present the same one
    /// (see `OAuthFlowError::RedirectUriMismatch`).
    pub redirect_uri: String,
    /// PKCE code verifier. NOTE(review): this is secret material, yet the
    /// derived `Debug` prints it verbatim. Consider redacting it.
    pub pkce_verifier: String,
    /// Monotonic creation time (presumably compared against the registry
    /// TTL; confirm).
    pub created_at: Instant,
}

/// Admitted device-code flow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthDeviceFlowRecord {
    /// Auth binding the flow was admitted for.
    pub target: AuthBindingRef,
    /// Identity the device code was issued for.
    pub provider: OAuthProviderIdentity,
    /// Provider-issued device code. It is also the key of the device-flow map.
    pub device_code: String,
    /// Monotonic admission time.
    pub created_at: Instant,
    /// Monotonic deadline after which the flow is presumably pruned as
    /// expired; confirm in `prune_expired_device_locked`.
    pub expires_at: Instant,
}

/// Expired flows that were removed, each reduced to `(identifier, target)`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OAuthPrunedFlows {
    /// `(map key, target)` for each pruned browser flow.
    pub browser: Vec<(String, AuthBindingRef)>,
    /// `(device_code, target)` for each pruned device flow.
    pub device: Vec<(String, AuthBindingRef)>,
}
413
414impl OAuthPrunedFlows {
415    fn from_expired(
416        browser: Vec<(String, OAuthFlowRecord)>,
417        device: Vec<OAuthDeviceFlowRecord>,
418    ) -> Self {
419        Self {
420            browser: browser
421                .into_iter()
422                .map(|(flow_id, record)| (flow_id, record.target))
423                .collect(),
424            device: device
425                .into_iter()
426                .map(|record| (record.device_code, record.target))
427                .collect(),
428        }
429    }
430}
431
/// Serializable snapshot of all outstanding flows.
///
/// Live records use `Instant`, which cannot be serialized, so persisted
/// entries carry millisecond timestamps instead.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthFlowRegistrySnapshot {
    /// Outstanding browser flows. A missing list deserializes as empty.
    #[serde(default)]
    pub browser: Vec<PersistedOAuthBrowserFlow>,
    /// Outstanding device flows. A missing list deserializes as empty.
    #[serde(default)]
    pub device: Vec<PersistedOAuthDeviceFlow>,
}

/// Persisted form of an [`OAuthFlowRecord`] plus its `state` key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedOAuthBrowserFlow {
    pub state: String,
    pub target: AuthBindingRef,
    /// Provider identity as text (presumably the canonical alias; confirm).
    pub provider: String,
    pub redirect_uri: String,
    /// PKCE verifier. NOTE(review): persisted in clear text.
    pub pkce_verifier: String,
    /// Timestamps in milliseconds. The epoch (wall clock vs. process start)
    /// is not visible here; confirm against the snapshot writer.
    pub created_at_millis: u64,
    pub expires_at_millis: u64,
}

/// Persisted form of an [`OAuthDeviceFlowRecord`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedOAuthDeviceFlow {
    pub target: AuthBindingRef,
    /// Provider identity as text (presumably the canonical alias; confirm).
    pub provider: String,
    pub device_code: String,
    /// Timestamps in milliseconds; epoch to be confirmed (see browser flow).
    pub created_at_millis: u64,
    pub expires_at_millis: u64,
}

/// Device-flow map entry: the record plus whichever poll lease holds it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct OAuthDeviceFlowState {
    record: OAuthDeviceFlowRecord,
    /// `Some` while a poll lease holds the flow. A rolled-back consume
    /// re-inserts the record with `None`.
    poll_lease: Option<OAuthDevicePollLeaseState>,
}

/// Identity of the poll lease currently holding a device flow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OAuthDevicePollLeaseState {
    /// Presumably compared against `OAuthDevicePollLease::lease_id` when a
    /// lease is verified or released.
    id: u64,
}
470
/// Errors from registering, verifying, polling or consuming OAuth flows.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum OAuthFlowError {
    /// The state or device code is not (or no longer) registered, e.g. it
    /// was never started, was already consumed, or was pruned as expired.
    #[error("oauth state is missing or expired")]
    Missing,
    /// A poll lease found its local entry missing while the lifecycle hook
    /// reports AuthMachine ownership of terminal state. `operation` names the
    /// lease operation that noticed it.
    #[error("oauth registry projection payload missing after AuthMachine accepted {operation}")]
    RegistryProjectionMissing { operation: &'static str },
    /// The flow was recorded for a different provider identity.
    #[error("oauth state provider mismatch: expected {expected}, got {actual}")]
    ProviderMismatch {
        expected: OAuthProviderIdentity,
        actual: OAuthProviderIdentity,
    },
    /// The supplied redirect URI differs from the recorded one.
    #[error("oauth state redirect_uri mismatch")]
    RedirectUriMismatch,
    /// The flow was recorded for a different auth binding. The bindings are
    /// boxed, presumably to keep the error type small.
    #[error("oauth state target mismatch: expected {expected:?}, got {actual:?}")]
    TargetMismatch {
        expected: Box<AuthBindingRef>,
        actual: Box<AuthBindingRef>,
    },
    /// Generating a `state` token failed.
    #[error("failed to generate oauth state token")]
    StateGenerationFailed,
    /// The registry already holds `max_outstanding` flows.
    #[error("oauth state registry is at capacity ({max_outstanding} outstanding flows)")]
    CapacityExceeded { max_outstanding: usize },
    /// Another poll lease presumably already holds this device code.
    #[error("oauth device code poll is already in progress")]
    DevicePollInProgress,
    /// This device code has already been admitted.
    #[error("oauth device code is already admitted")]
    DeviceCodeAlreadyAdmitted,
    /// The requested device-code lifetime cannot be represented
    /// (presumably overflowing `Instant`; confirm).
    #[error("oauth device code expiry is out of range")]
    DeviceExpiryOutOfRange,
    /// An external lifecycle owner rejected a transition.
    #[error("oauth flow lifecycle transition rejected during {operation}: {detail}")]
    LifecycleRejected {
        operation: &'static str,
        detail: String,
    },
    /// Writing flow state to durable storage failed.
    #[error("oauth flow durable persistence failed during {operation}: {detail}")]
    PersistenceFailed {
        operation: &'static str,
        detail: String,
    },
}
510
/// Hook that an [`OAuthDevicePollLease`] notifies as it finishes, consumes,
/// expires or restores a device flow. It lets an external owner (presumably
/// the AuthMachine and/or durable storage) mirror the local registry.
pub trait OAuthDevicePollLifecycle: Send + Sync {
    /// Whether terminal device-flow state is owned by the AuthMachine rather
    /// than the local registry.
    ///
    /// When `true`, leases report a missing local flow as
    /// `RegistryProjectionMissing` instead of `Missing`, and skip the
    /// restore/expire rollback. Defaults to `false`.
    fn device_flow_state_is_authmachine_owned(&self) -> bool {
        false
    }

    /// Called when a poll lease ends without consuming the flow. This happens
    /// in `finish`, before the local lease is released, and best-effort from
    /// `Drop`. An error aborts `finish`.
    fn finish_device_poll(
        &self,
        target: &AuthBindingRef,
        device_code: &str,
    ) -> Result<(), OAuthFlowError>;

    /// Called by `consume` after local verification but before the flow is
    /// removed locally. An error aborts the consume.
    fn consume_device_flow(
        &self,
        target: &AuthBindingRef,
        device_code: &str,
        provider: OAuthProviderIdentity,
    ) -> Result<(), OAuthFlowError>;

    /// Called best-effort when the flow turned out to be missing locally and
    /// the AuthMachine does not own terminal state.
    fn expire_device_flow(
        &self,
        target: &AuthBindingRef,
        device_code: &str,
    ) -> Result<(), OAuthFlowError>;

    /// Roll back a `consume_device_flow` for `_record` after a later step
    /// failed. No-op by default.
    fn restore_device_flow(&self, _record: &OAuthDeviceFlowRecord) -> Result<(), OAuthFlowError> {
        Ok(())
    }

    /// Called after the local device-flow map changed (e.g. a lease was
    /// released by `finish`). No-op by default.
    fn device_flow_payloads_changed(&self) -> Result<(), OAuthFlowError> {
        Ok(())
    }

    /// Called after `consume` removed `_record` from the local map. Defaults
    /// to [`Self::device_flow_payloads_changed`].
    fn device_flow_payload_removed(
        &self,
        _record: &OAuthDeviceFlowRecord,
    ) -> Result<(), OAuthFlowError> {
        self.device_flow_payloads_changed()
    }
}
550
/// Handle for an in-progress device-code poll, as returned by
/// [`OAuthFlowAuthority::begin_device_code_poll`].
///
/// Settle it with `finish` (release without consuming) or `consume`. If it
/// is dropped while still active, `Drop` releases the lease on a best-effort
/// basis.
pub struct OAuthDevicePollLease {
    /// Device-flow map, presumably shared with the registry that issued the
    /// lease.
    device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
    /// Auth binding the flow belongs to.
    target: AuthBindingRef,
    /// Device code (key into `device_flows`).
    device_code: String,
    /// Identity the flow must match.
    provider: OAuthProviderIdentity,
    /// Lease id that must match the flow's current poll lease.
    lease_id: u64,
    /// Optional hook notified on finish/consume/drop.
    lifecycle: Option<Arc<dyn OAuthDevicePollLifecycle>>,
    /// Optional lock held for the whole duration of `finish` and `consume`.
    operation_lock: Option<Arc<StdMutex<()>>>,
    /// `true` until `finish` or `consume` has released or removed the lease.
    /// While set, `Drop` performs the release.
    active: bool,
}
561
562impl std::fmt::Debug for OAuthDevicePollLease {
563    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
564        f.debug_struct("OAuthDevicePollLease")
565            .field("target", &self.target)
566            .field("device_code", &self.device_code)
567            .field("provider", &self.provider)
568            .field("lease_id", &self.lease_id)
569            .field("has_lifecycle", &self.lifecycle.is_some())
570            .field("has_operation_lock", &self.operation_lock.is_some())
571            .field("active", &self.active)
572            .finish()
573    }
574}
575
impl OAuthDevicePollLease {
    /// Build an active lease on `device_code` holding poll lease `lease_id`.
    /// No lifecycle hook or operation lock is attached; see `with_lifecycle`
    /// and `with_operation_lock`.
    fn new(
        device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
        target: AuthBindingRef,
        device_code: String,
        provider: OAuthProviderIdentity,
        lease_id: u64,
    ) -> Self {
        Self {
            device_flows,
            target,
            device_code,
            provider,
            lease_id,
            lifecycle: None,
            operation_lock: None,
            active: true,
        }
    }

    /// Whether the attached lifecycle hook reports AuthMachine ownership of
    /// terminal device-flow state. Returns `false` when no hook is attached.
    pub fn terminal_flow_state_is_authmachine_owned(&self) -> bool {
        self.lifecycle
            .as_ref()
            .map(|lifecycle| lifecycle.device_flow_state_is_authmachine_owned())
            .unwrap_or(false)
    }

    /// Error to report when the local map no longer holds this flow.
    ///
    /// Returns `RegistryProjectionMissing { operation }` under AuthMachine
    /// ownership (the local map is then presumably only a projection of
    /// that state), and plain `Missing` otherwise.
    fn local_missing_error(&self, operation: &'static str) -> OAuthFlowError {
        if self.terminal_flow_state_is_authmachine_owned() {
            OAuthFlowError::RegistryProjectionMissing { operation }
        } else {
            OAuthFlowError::Missing
        }
    }

    /// Attach a hook that is notified as the lease is finished, consumed or
    /// dropped.
    pub fn with_lifecycle(mut self, lifecycle: Arc<dyn OAuthDevicePollLifecycle>) -> Self {
        self.lifecycle = Some(lifecycle);
        self
    }

    /// Attach a lock that `finish` and `consume` hold for their whole
    /// duration. `verify` and `Drop` do not take it.
    pub fn with_operation_lock(mut self, operation_lock: Arc<StdMutex<()>>) -> Self {
        self.operation_lock = Some(operation_lock);
        self
    }

    /// Release this poll lease without consuming the device flow.
    ///
    /// Steps, in order:
    /// 1. under the map lock, prune expired flows and verify that this lease
    ///    still holds a live flow;
    /// 2. notify the hook via `finish_device_poll`;
    /// 3. release the lease locally;
    /// 4. report `device_flow_payloads_changed`.
    ///
    /// # Errors
    ///
    /// `Missing` / `RegistryProjectionMissing` when the flow is gone. Other
    /// verification, release and lifecycle errors are returned unchanged.
    /// Every error except a step-4 failure leaves `active` set, so `Drop`
    /// still attempts a best-effort release afterwards.
    pub fn finish(mut self) -> Result<(), OAuthFlowError> {
        // A poisoned operation lock is recovered, not propagated.
        let _operation_guard = self.operation_lock.as_ref().map(|lock| {
            lock.lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        });
        // Scoped so the map lock is released before any lifecycle call.
        let verify_result = {
            let mut flows = self.device_flows.lock();
            prune_expired_device_locked(&mut flows);
            verify_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            )
            .map(|_| ())
        };
        if matches!(verify_result, Err(OAuthFlowError::Missing)) {
            // Under AuthMachine ownership the hook is still told the poll
            // ended; its result is deliberately ignored.
            if self.terminal_flow_state_is_authmachine_owned()
                && let Some(lifecycle) = &self.lifecycle
            {
                let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
            }
            return Err(self.local_missing_error("finish_oauth_device_poll"));
        }
        verify_result?;

        // The hook may veto before anything changes locally.
        if let Some(lifecycle) = &self.lifecycle {
            lifecycle.finish_device_poll(&self.target, &self.device_code)?;
        }

        let result = {
            let mut flows = self.device_flows.lock();
            let result = release_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            );
            prune_expired_device_locked(&mut flows);
            result
        };
        match result {
            Ok(()) => {
                // Released locally: `Drop` must not release again.
                self.active = false;
                if let Some(lifecycle) = &self.lifecycle {
                    lifecycle.device_flow_payloads_changed()?;
                }
                Ok(())
            }
            Err(OAuthFlowError::Missing) => {
                Err(self.local_missing_error("finish_oauth_device_poll"))
            }
            Err(err) => Err(err),
        }
    }

    /// Check that this lease still holds a live flow and return its record.
    ///
    /// Expired flows are pruned before and after the check. This takes only
    /// the map lock: no operation lock, and no lifecycle notification.
    pub fn verify(&self) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
        let mut flows = self.device_flows.lock();
        prune_expired_device_locked(&mut flows);
        let result = verify_device_poll_lease_locked(
            &mut flows,
            &self.device_code,
            self.provider,
            self.lease_id,
        );
        prune_expired_device_locked(&mut flows);
        if matches!(result, Err(OAuthFlowError::Missing)) {
            return Err(self.local_missing_error("verify_oauth_device_poll"));
        }
        result
    }

    /// Consume the device flow held by this lease and return its record.
    ///
    /// Steps, in order:
    /// 1. verify the lease under the map lock;
    /// 2. let the hook consume the flow (`consume_device_flow`);
    /// 3. remove the flow locally;
    /// 4. report the removal via `device_flow_payload_removed`.
    ///
    /// Rollback:
    /// - If step 3 fails, the hook gets `restore_device_flow`, plus
    ///   `expire_device_flow` when the failure was `Missing`. The exception
    ///   is a `Missing` failure under AuthMachine ownership, which is
    ///   reported without rollback.
    /// - If step 4 fails with anything other than a missing error, the
    ///   verified record is restored both in the hook and in the local map
    ///   (without a poll lease).
    ///
    /// # Errors
    ///
    /// `Missing` / `RegistryProjectionMissing`, verification errors, and
    /// lifecycle errors.
    pub fn consume(mut self) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
        // A poisoned operation lock is recovered, not propagated.
        let _operation_guard = self.operation_lock.as_ref().map(|lock| {
            lock.lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        });
        let verified = {
            let mut flows = self.device_flows.lock();
            prune_expired_device_locked(&mut flows);
            verify_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            )
        };
        // `verified` is kept as the rollback snapshot for the steps below.
        let verified = match verified {
            Ok(verified) => verified,
            Err(OAuthFlowError::Missing) => {
                if self.terminal_flow_state_is_authmachine_owned()
                    && let Some(lifecycle) = &self.lifecycle
                {
                    let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
                }
                return Err(self.local_missing_error("consume_oauth_device_flow"));
            }
            Err(err) => return Err(err),
        };

        if let Some(lifecycle) = &self.lifecycle {
            lifecycle.consume_device_flow(&self.target, &self.device_code, self.provider)?;
        }

        let result = {
            let mut flows = self.device_flows.lock();
            let result = consume_device_poll_lease_locked(
                &mut flows,
                &self.device_code,
                self.provider,
                self.lease_id,
            );
            prune_expired_device_locked(&mut flows);
            result
        };
        match result {
            Ok(record) => {
                self.active = false;
                if let Some(lifecycle) = &self.lifecycle
                    && let Err(err) = lifecycle.device_flow_payload_removed(&record)
                {
                    // The owner no longer has the flow: nothing to restore.
                    if matches!(
                        err,
                        OAuthFlowError::Missing | OAuthFlowError::RegistryProjectionMissing { .. }
                    ) {
                        return Err(err);
                    }
                    // Undo the consume in the hook, then locally. The hook
                    // is called before the map lock is taken.
                    let _ = lifecycle.restore_device_flow(&verified);
                    let mut flows = self.device_flows.lock();
                    flows.insert(
                        verified.device_code.clone(),
                        OAuthDeviceFlowState {
                            record: verified,
                            poll_lease: None,
                        },
                    );
                    return Err(err);
                }
                Ok(record)
            }
            Err(err) => {
                // NOTE(review): `active` is still set on this path, so `Drop`
                // will also release the lease and notify the hook afterwards.
                if matches!(err, OAuthFlowError::Missing)
                    && self.terminal_flow_state_is_authmachine_owned()
                {
                    return Err(self.local_missing_error("consume_oauth_device_flow"));
                }
                if let Some(lifecycle) = &self.lifecycle {
                    let _ = lifecycle.restore_device_flow(&verified);
                    if matches!(err, OAuthFlowError::Missing) {
                        let _ = lifecycle.expire_device_flow(&self.target, &self.device_code);
                    }
                }
                Err(err)
            }
        }
    }
}
778
779impl Drop for OAuthDevicePollLease {
780    fn drop(&mut self) {
781        if !self.active {
782            return;
783        }
784        let mut flows = self.device_flows.lock();
785        prune_expired_device_locked(&mut flows);
786        let result = release_device_poll_lease_locked(
787            &mut flows,
788            &self.device_code,
789            self.provider,
790            self.lease_id,
791        );
792        prune_expired_device_locked(&mut flows);
793        if let Some(lifecycle) = &self.lifecycle {
794            if matches!(result, Err(OAuthFlowError::Missing)) {
795                if lifecycle.device_flow_state_is_authmachine_owned() {
796                    let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
797                } else {
798                    let _ = lifecycle.expire_device_flow(&self.target, &self.device_code);
799                }
800            } else {
801                let _ = lifecycle.finish_device_poll(&self.target, &self.device_code);
802            }
803        }
804    }
805}
806
/// Authority that correlates OAuth login flows with their targets.
///
/// Browser flows: `start` records a flow and returns its `state`. Completion
/// must `verify` and `consume` that state through the same authority.
///
/// Device flows: `admit_device_code` registers a provider-issued device
/// code. Polling then happens under an [`OAuthDevicePollLease`] obtained from
/// `begin_device_code_poll`.
pub trait OAuthFlowAuthority: Send + Sync {
    /// Whether terminal flow state is owned by the AuthMachine rather than
    /// this authority. Defaults to `false`.
    fn terminal_flow_state_is_authmachine_owned(&self) -> bool {
        false
    }

    /// Record a new browser flow for `target` and return the `state` string
    /// that identifies it.
    fn start(
        &self,
        target: AuthBindingRef,
        provider: OAuthProviderIdentity,
        redirect_uri: String,
        pkce_verifier: String,
    ) -> Result<String, OAuthFlowError>;

    /// Check `state` against the recorded target, provider and redirect URI
    /// without consuming it.
    fn verify(
        &self,
        state: &str,
        target: &AuthBindingRef,
        provider: OAuthProviderIdentity,
        redirect_uri: &str,
    ) -> Result<OAuthFlowRecord, OAuthFlowError>;

    /// Verify `state` like [`Self::verify`] and consume the flow.
    fn consume(
        &self,
        state: &str,
        target: &AuthBindingRef,
        provider: OAuthProviderIdentity,
        redirect_uri: &str,
    ) -> Result<OAuthFlowRecord, OAuthFlowError>;

    /// Register a provider-issued `device_code` for `target`, valid for
    /// `expires_in`.
    fn admit_device_code(
        &self,
        target: AuthBindingRef,
        provider: OAuthProviderIdentity,
        device_code: String,
        expires_in: Duration,
    ) -> Result<(), OAuthFlowError>;

    /// Check a registered device code against `target` and `provider`
    /// without consuming it.
    fn verify_device_code(
        &self,
        device_code: &str,
        target: &AuthBindingRef,
        provider: OAuthProviderIdentity,
    ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError>;

    /// Acquire a poll lease on a registered device code. The lease is
    /// presumably exclusive; see `OAuthFlowError::DevicePollInProgress`.
    fn begin_device_code_poll(
        &self,
        device_code: &str,
        target: &AuthBindingRef,
        provider: OAuthProviderIdentity,
    ) -> Result<OAuthDevicePollLease, OAuthFlowError>;
}
858
/// In-memory [`OAuthFlowAuthority`] holding outstanding browser and device
/// flows. Both kinds share a single `max_outstanding` cap.
#[derive(Debug)]
pub struct OAuthFlowRegistry {
    /// Lifetime of a browser flow, measured from `OAuthFlowRecord::created_at`.
    /// Device flows expire at their own `expires_at` instead.
    ttl: Duration,
    /// Cap on browser + device flows held at once (never below 1).
    max_outstanding: usize,
    /// Browser flows keyed by their `state` token.
    flows: Mutex<HashMap<String, OAuthFlowRecord>>,
    /// Device flows keyed by device code. Behind an `Arc` so each
    /// `OAuthDevicePollLease` can hold the map and release its lease on drop.
    device_flows: Arc<Mutex<HashMap<String, OAuthDeviceFlowState>>>,
    /// Source of poll-lease ids; starts at 1 and is incremented per lease.
    next_device_poll_lease_id: AtomicU64,
}
867
868impl OAuthFlowRegistry {
869    pub fn new(ttl: Duration) -> Self {
870        Self::new_with_capacity(ttl, DEFAULT_MAX_OUTSTANDING_FLOWS)
871    }
872
873    pub fn new_with_capacity(ttl: Duration, max_outstanding: usize) -> Self {
874        Self {
875            ttl,
876            max_outstanding: max_outstanding.max(1),
877            flows: Mutex::new(HashMap::new()),
878            device_flows: Arc::new(Mutex::new(HashMap::new())),
879            next_device_poll_lease_id: AtomicU64::new(1),
880        }
881    }
882
883    pub fn max_outstanding(&self) -> usize {
884        self.max_outstanding
885    }
886
887    pub fn ttl(&self) -> Duration {
888        self.ttl
889    }
890
891    pub fn new_state() -> Result<String, OAuthFlowError> {
892        new_state_token()
893    }
894
895    pub fn start(
896        &self,
897        target: AuthBindingRef,
898        provider: OAuthProviderIdentity,
899        redirect_uri: impl Into<String>,
900        pkce_verifier: impl Into<String>,
901    ) -> Result<String, OAuthFlowError> {
902        <Self as OAuthFlowAuthority>::start(
903            self,
904            target,
905            provider,
906            redirect_uri.into(),
907            pkce_verifier.into(),
908        )
909    }
910
911    pub fn verify(
912        &self,
913        state: &str,
914        target: &AuthBindingRef,
915        provider: OAuthProviderIdentity,
916        redirect_uri: &str,
917    ) -> Result<OAuthFlowRecord, OAuthFlowError> {
918        <Self as OAuthFlowAuthority>::verify(self, state, target, provider, redirect_uri)
919    }
920
921    pub fn consume(
922        &self,
923        state: &str,
924        target: &AuthBindingRef,
925        provider: OAuthProviderIdentity,
926        redirect_uri: &str,
927    ) -> Result<OAuthFlowRecord, OAuthFlowError> {
928        <Self as OAuthFlowAuthority>::consume(self, state, target, provider, redirect_uri)
929    }
930
931    pub fn admit_device_code(
932        &self,
933        target: AuthBindingRef,
934        provider: OAuthProviderIdentity,
935        device_code: impl Into<String>,
936        expires_in: Duration,
937    ) -> Result<(), OAuthFlowError> {
938        <Self as OAuthFlowAuthority>::admit_device_code(
939            self,
940            target,
941            provider,
942            device_code.into(),
943            expires_in,
944        )
945    }
946
947    pub fn verify_device_code(
948        &self,
949        device_code: &str,
950        target: &AuthBindingRef,
951        provider: OAuthProviderIdentity,
952    ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
953        <Self as OAuthFlowAuthority>::verify_device_code(self, device_code, target, provider)
954    }
955
956    pub fn begin_device_code_poll(
957        &self,
958        device_code: &str,
959        target: &AuthBindingRef,
960        provider: OAuthProviderIdentity,
961    ) -> Result<OAuthDevicePollLease, OAuthFlowError> {
962        <Self as OAuthFlowAuthority>::begin_device_code_poll(self, device_code, target, provider)
963    }
964
965    pub fn expire_device_code(
966        &self,
967        device_code: &str,
968        target: &AuthBindingRef,
969        provider: OAuthProviderIdentity,
970    ) -> Result<(), OAuthFlowError> {
971        let mut flows = self.device_flows.lock();
972        prune_expired_device_locked(&mut flows);
973        let Some(state) = flows.get(device_code) else {
974            return Err(OAuthFlowError::Missing);
975        };
976        verify_device_record(&state.record, target, provider)?;
977        flows.remove(device_code);
978        Ok(())
979    }
980
981    pub fn prune_expired_browser_flows(&self) -> Vec<(String, AuthBindingRef)> {
982        let mut flows = self.flows.lock();
983        take_expired_locked(&mut flows, self.ttl)
984            .into_iter()
985            .map(|(flow_id, record)| (flow_id, record.target))
986            .collect()
987    }
988
989    pub fn prune_expired_device_flows(&self) -> Vec<(String, AuthBindingRef)> {
990        let mut flows = self.device_flows.lock();
991        take_expired_device_locked(&mut flows)
992            .into_iter()
993            .map(|record| (record.device_code, record.target))
994            .collect()
995    }
996
997    pub fn retain_flows_with_lifecycle(
998        &self,
999        mut browser_active: impl FnMut(&AuthBindingRef, &str) -> bool,
1000        mut device_active: impl FnMut(&AuthBindingRef, &str) -> bool,
1001    ) -> OAuthPrunedFlows {
1002        let mut flows = self.flows.lock();
1003        let mut browser = Vec::new();
1004        flows.retain(|flow_id, record| {
1005            let keep = browser_active(&record.target, flow_id);
1006            if !keep {
1007                browser.push((flow_id.clone(), record.target.clone()));
1008            }
1009            keep
1010        });
1011
1012        let mut device_flows = self.device_flows.lock();
1013        let mut device = Vec::new();
1014        device_flows.retain(|device_code, state| {
1015            let keep = device_active(&state.record.target, device_code);
1016            if !keep {
1017                device.push((device_code.clone(), state.record.target.clone()));
1018            }
1019            keep
1020        });
1021        OAuthPrunedFlows { browser, device }
1022    }
1023
1024    pub fn snapshot_for_persistence(&self, now_millis: u64) -> OAuthFlowRegistrySnapshot {
1025        let flows = self.flows.lock();
1026        let browser = flows
1027            .iter()
1028            .filter_map(|(state, record)| {
1029                let elapsed = record.created_at.elapsed();
1030                if elapsed > self.ttl {
1031                    return None;
1032                }
1033                let elapsed_millis = duration_millis_u64(elapsed);
1034                let ttl_millis = duration_millis_u64(self.ttl);
1035                Some(PersistedOAuthBrowserFlow {
1036                    state: state.clone(),
1037                    target: record.target.clone(),
1038                    provider: record.provider.canonical_alias().to_string(),
1039                    redirect_uri: record.redirect_uri.clone(),
1040                    pkce_verifier: record.pkce_verifier.clone(),
1041                    created_at_millis: now_millis.saturating_sub(elapsed_millis),
1042                    expires_at_millis: now_millis
1043                        .saturating_add(ttl_millis.saturating_sub(elapsed_millis)),
1044                })
1045            })
1046            .collect::<Vec<_>>();
1047        drop(flows);
1048
1049        let device_flows = self.device_flows.lock();
1050        let now = Instant::now();
1051        let device = device_flows
1052            .values()
1053            .filter_map(|state| {
1054                let remaining = state.record.expires_at.checked_duration_since(now)?;
1055                let created_elapsed = state.record.created_at.elapsed();
1056                Some(PersistedOAuthDeviceFlow {
1057                    target: state.record.target.clone(),
1058                    provider: state.record.provider.canonical_alias().to_string(),
1059                    device_code: state.record.device_code.clone(),
1060                    created_at_millis: now_millis
1061                        .saturating_sub(duration_millis_u64(created_elapsed)),
1062                    expires_at_millis: now_millis.saturating_add(duration_millis_u64(remaining)),
1063                })
1064            })
1065            .collect();
1066
1067        OAuthFlowRegistrySnapshot { browser, device }
1068    }
1069
1070    pub fn insert_restored_browser_flow(
1071        &self,
1072        state: String,
1073        target: AuthBindingRef,
1074        provider: OAuthProviderIdentity,
1075        redirect_uri: String,
1076        pkce_verifier: String,
1077        created_at: Instant,
1078    ) -> Result<(), OAuthFlowError> {
1079        let mut flows = self.flows.lock();
1080        let device_flows = self.device_flows.lock();
1081        if flows.len() + device_flows.len() >= self.max_outstanding {
1082            return Err(OAuthFlowError::CapacityExceeded {
1083                max_outstanding: self.max_outstanding,
1084            });
1085        }
1086        flows.insert(
1087            state,
1088            OAuthFlowRecord {
1089                target,
1090                provider,
1091                redirect_uri,
1092                pkce_verifier,
1093                created_at,
1094            },
1095        );
1096        Ok(())
1097    }
1098
1099    pub fn insert_restored_device_flow(
1100        &self,
1101        target: AuthBindingRef,
1102        provider: OAuthProviderIdentity,
1103        device_code: String,
1104        created_at: Instant,
1105        expires_at: Instant,
1106    ) -> Result<(), OAuthFlowError> {
1107        let flows = self.flows.lock();
1108        let mut device_flows = self.device_flows.lock();
1109        if device_flows.contains_key(&device_code) {
1110            return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1111        }
1112        if flows.len() + device_flows.len() >= self.max_outstanding {
1113            return Err(OAuthFlowError::CapacityExceeded {
1114                max_outstanding: self.max_outstanding,
1115            });
1116        }
1117        let record = OAuthDeviceFlowRecord {
1118            target,
1119            provider,
1120            device_code: device_code.clone(),
1121            created_at,
1122            expires_at,
1123        };
1124        device_flows.insert(
1125            device_code,
1126            OAuthDeviceFlowState {
1127                record,
1128                poll_lease: None,
1129            },
1130        );
1131        Ok(())
1132    }
1133
1134    pub fn start_with_pruned(
1135        &self,
1136        target: AuthBindingRef,
1137        provider: OAuthProviderIdentity,
1138        redirect_uri: String,
1139        pkce_verifier: String,
1140    ) -> Result<(String, OAuthPrunedFlows), OAuthFlowError> {
1141        let state = new_state_token()?;
1142        let record = OAuthFlowRecord {
1143            target,
1144            provider,
1145            redirect_uri,
1146            pkce_verifier,
1147            created_at: Instant::now(),
1148        };
1149        let mut flows = self.flows.lock();
1150        let mut device_flows = self.device_flows.lock();
1151        let expired_browser = take_expired_locked(&mut flows, self.ttl);
1152        let expired_device = take_expired_device_locked(&mut device_flows);
1153        if flows.len() + device_flows.len() >= self.max_outstanding {
1154            return Err(OAuthFlowError::CapacityExceeded {
1155                max_outstanding: self.max_outstanding,
1156            });
1157        }
1158        flows.insert(state.clone(), record);
1159        Ok((
1160            state,
1161            OAuthPrunedFlows::from_expired(expired_browser, expired_device),
1162        ))
1163    }
1164
1165    pub fn insert_browser_flow_with_pruned(
1166        &self,
1167        state: String,
1168        target: AuthBindingRef,
1169        provider: OAuthProviderIdentity,
1170        redirect_uri: String,
1171        pkce_verifier: String,
1172    ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1173        let record = OAuthFlowRecord {
1174            target,
1175            provider,
1176            redirect_uri,
1177            pkce_verifier,
1178            created_at: Instant::now(),
1179        };
1180        let mut flows = self.flows.lock();
1181        let mut device_flows = self.device_flows.lock();
1182        let expired_browser = take_expired_locked(&mut flows, self.ttl);
1183        let expired_device = take_expired_device_locked(&mut device_flows);
1184        if flows.len() + device_flows.len() >= self.max_outstanding {
1185            return Err(OAuthFlowError::CapacityExceeded {
1186                max_outstanding: self.max_outstanding,
1187            });
1188        }
1189        flows.insert(state, record);
1190        Ok(OAuthPrunedFlows::from_expired(
1191            expired_browser,
1192            expired_device,
1193        ))
1194    }
1195
1196    pub fn admit_device_code_with_pruned(
1197        &self,
1198        target: AuthBindingRef,
1199        provider: OAuthProviderIdentity,
1200        device_code: String,
1201        expires_in: Duration,
1202    ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1203        let mut flows = self.flows.lock();
1204        let mut device_flows = self.device_flows.lock();
1205        let expired_browser = take_expired_locked(&mut flows, self.ttl);
1206        let expired_device = take_expired_device_locked(&mut device_flows);
1207        if device_flows.contains_key(&device_code) {
1208            return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1209        }
1210        if flows.len() + device_flows.len() >= self.max_outstanding {
1211            return Err(OAuthFlowError::CapacityExceeded {
1212                max_outstanding: self.max_outstanding,
1213            });
1214        }
1215        let now = Instant::now();
1216        let expires_at = now
1217            .checked_add(expires_in)
1218            .ok_or(OAuthFlowError::DeviceExpiryOutOfRange)?;
1219        let record = OAuthDeviceFlowRecord {
1220            target,
1221            provider,
1222            device_code: device_code.clone(),
1223            created_at: now,
1224            expires_at,
1225        };
1226        device_flows.insert(
1227            device_code,
1228            OAuthDeviceFlowState {
1229                record,
1230                poll_lease: None,
1231            },
1232        );
1233        Ok(OAuthPrunedFlows::from_expired(
1234            expired_browser,
1235            expired_device,
1236        ))
1237    }
1238
1239    pub fn admit_device_code_with_pruned_without_capacity(
1240        &self,
1241        target: AuthBindingRef,
1242        provider: OAuthProviderIdentity,
1243        device_code: String,
1244        expires_in: Duration,
1245    ) -> Result<OAuthPrunedFlows, OAuthFlowError> {
1246        let mut flows = self.flows.lock();
1247        let mut device_flows = self.device_flows.lock();
1248        let expired_browser = take_expired_locked(&mut flows, self.ttl);
1249        let expired_device = take_expired_device_locked(&mut device_flows);
1250        if device_flows.contains_key(&device_code) {
1251            return Err(OAuthFlowError::DeviceCodeAlreadyAdmitted);
1252        }
1253        let now = Instant::now();
1254        let expires_at = now
1255            .checked_add(expires_in)
1256            .ok_or(OAuthFlowError::DeviceExpiryOutOfRange)?;
1257        let record = OAuthDeviceFlowRecord {
1258            target,
1259            provider,
1260            device_code: device_code.clone(),
1261            created_at: now,
1262            expires_at,
1263        };
1264        device_flows.insert(
1265            device_code,
1266            OAuthDeviceFlowState {
1267                record,
1268                poll_lease: None,
1269            },
1270        );
1271        Ok(OAuthPrunedFlows::from_expired(
1272            expired_browser,
1273            expired_device,
1274        ))
1275    }
1276}
1277
1278impl Default for OAuthFlowRegistry {
1279    fn default() -> Self {
1280        Self::new(Duration::from_secs(10 * 60))
1281    }
1282}
1283
1284impl OAuthFlowAuthority for OAuthFlowRegistry {
1285    fn start(
1286        &self,
1287        target: AuthBindingRef,
1288        provider: OAuthProviderIdentity,
1289        redirect_uri: String,
1290        pkce_verifier: String,
1291    ) -> Result<String, OAuthFlowError> {
1292        self.start_with_pruned(target, provider, redirect_uri, pkce_verifier)
1293            .map(|(state, _)| state)
1294    }
1295
1296    fn verify(
1297        &self,
1298        state: &str,
1299        target: &AuthBindingRef,
1300        provider: OAuthProviderIdentity,
1301        redirect_uri: &str,
1302    ) -> Result<OAuthFlowRecord, OAuthFlowError> {
1303        let mut flows = self.flows.lock();
1304        prune_expired_locked(&mut flows, self.ttl);
1305        let Some(record) = flows.get(state) else {
1306            return Err(OAuthFlowError::Missing);
1307        };
1308        verify_browser_record(record, target, provider, redirect_uri)?;
1309        Ok(record.clone())
1310    }
1311
1312    fn consume(
1313        &self,
1314        state: &str,
1315        target: &AuthBindingRef,
1316        provider: OAuthProviderIdentity,
1317        redirect_uri: &str,
1318    ) -> Result<OAuthFlowRecord, OAuthFlowError> {
1319        let mut flows = self.flows.lock();
1320        prune_expired_locked(&mut flows, self.ttl);
1321        let Some(record) = flows.get(state) else {
1322            return Err(OAuthFlowError::Missing);
1323        };
1324        verify_browser_record(record, target, provider, redirect_uri)?;
1325        flows.remove(state).ok_or(OAuthFlowError::Missing)
1326    }
1327
1328    fn admit_device_code(
1329        &self,
1330        target: AuthBindingRef,
1331        provider: OAuthProviderIdentity,
1332        device_code: String,
1333        expires_in: Duration,
1334    ) -> Result<(), OAuthFlowError> {
1335        self.admit_device_code_with_pruned(target, provider, device_code, expires_in)
1336            .map(|_| ())
1337    }
1338
1339    fn verify_device_code(
1340        &self,
1341        device_code: &str,
1342        target: &AuthBindingRef,
1343        provider: OAuthProviderIdentity,
1344    ) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1345        let mut flows = self.device_flows.lock();
1346        prune_expired_device_locked(&mut flows);
1347        let Some(record) = flows.get(device_code) else {
1348            return Err(OAuthFlowError::Missing);
1349        };
1350        verify_device_record(&record.record, target, provider)?;
1351        Ok(record.record.clone())
1352    }
1353
1354    fn begin_device_code_poll(
1355        &self,
1356        device_code: &str,
1357        target: &AuthBindingRef,
1358        provider: OAuthProviderIdentity,
1359    ) -> Result<OAuthDevicePollLease, OAuthFlowError> {
1360        let mut flows = self.device_flows.lock();
1361        prune_expired_device_locked(&mut flows);
1362        let Some(state) = flows.get_mut(device_code) else {
1363            return Err(OAuthFlowError::Missing);
1364        };
1365        verify_device_record(&state.record, target, provider)?;
1366        if state.poll_lease.is_some() {
1367            return Err(OAuthFlowError::DevicePollInProgress);
1368        }
1369        let lease_id = self
1370            .next_device_poll_lease_id
1371            .fetch_add(1, Ordering::Relaxed);
1372        state.poll_lease = Some(OAuthDevicePollLeaseState { id: lease_id });
1373        Ok(OAuthDevicePollLease::new(
1374            Arc::clone(&self.device_flows),
1375            target.clone(),
1376            device_code.to_string(),
1377            provider,
1378            lease_id,
1379        ))
1380    }
1381}
1382
1383fn new_state_token() -> Result<String, OAuthFlowError> {
1384    let mut bytes = [0_u8; 32];
1385    getrandom::fill(&mut bytes).map_err(|_| OAuthFlowError::StateGenerationFailed)?;
1386    Ok(format!(
1387        "st-{}",
1388        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
1389    ))
1390}
1391
/// Whole milliseconds in `duration`, saturating at `u64::MAX`.
fn duration_millis_u64(duration: Duration) -> u64 {
    let millis: u128 = duration.as_millis();
    millis.try_into().unwrap_or(u64::MAX)
}
1395
1396fn prune_expired_locked(flows: &mut HashMap<String, OAuthFlowRecord>, ttl: Duration) {
1397    let _ = take_expired_locked(flows, ttl);
1398}
1399
1400fn take_expired_locked(
1401    flows: &mut HashMap<String, OAuthFlowRecord>,
1402    ttl: Duration,
1403) -> Vec<(String, OAuthFlowRecord)> {
1404    let expired = flows
1405        .iter()
1406        .filter(|(_, record)| record.created_at.elapsed() > ttl)
1407        .map(|(flow_id, _)| flow_id.clone())
1408        .collect::<Vec<_>>();
1409    expired
1410        .into_iter()
1411        .filter_map(|flow_id| flows.remove(&flow_id).map(|record| (flow_id, record)))
1412        .collect()
1413}
1414
1415fn release_device_poll_lease_locked(
1416    flows: &mut HashMap<String, OAuthDeviceFlowState>,
1417    device_code: &str,
1418    provider: OAuthProviderIdentity,
1419    lease_id: u64,
1420) -> Result<(), OAuthFlowError> {
1421    let Some(state) = flows.get_mut(device_code) else {
1422        return Err(OAuthFlowError::Missing);
1423    };
1424    if state.record.provider != provider {
1425        return Err(OAuthFlowError::ProviderMismatch {
1426            expected: state.record.provider,
1427            actual: provider,
1428        });
1429    }
1430    match state.poll_lease {
1431        Some(lease) if lease.id == lease_id => {
1432            state.poll_lease = None;
1433            Ok(())
1434        }
1435        Some(_) => Err(OAuthFlowError::DevicePollInProgress),
1436        None => Err(OAuthFlowError::Missing),
1437    }
1438}
1439
1440fn verify_browser_record(
1441    record: &OAuthFlowRecord,
1442    target: &AuthBindingRef,
1443    provider: OAuthProviderIdentity,
1444    redirect_uri: &str,
1445) -> Result<(), OAuthFlowError> {
1446    if &record.target != target {
1447        return Err(OAuthFlowError::TargetMismatch {
1448            expected: Box::new(record.target.clone()),
1449            actual: Box::new(target.clone()),
1450        });
1451    }
1452    if record.provider != provider {
1453        return Err(OAuthFlowError::ProviderMismatch {
1454            expected: record.provider,
1455            actual: provider,
1456        });
1457    }
1458    if record.redirect_uri != redirect_uri {
1459        return Err(OAuthFlowError::RedirectUriMismatch);
1460    }
1461    Ok(())
1462}
1463
1464fn verify_device_record(
1465    record: &OAuthDeviceFlowRecord,
1466    target: &AuthBindingRef,
1467    provider: OAuthProviderIdentity,
1468) -> Result<(), OAuthFlowError> {
1469    if &record.target != target {
1470        return Err(OAuthFlowError::TargetMismatch {
1471            expected: Box::new(record.target.clone()),
1472            actual: Box::new(target.clone()),
1473        });
1474    }
1475    if record.provider != provider {
1476        return Err(OAuthFlowError::ProviderMismatch {
1477            expected: record.provider,
1478            actual: provider,
1479        });
1480    }
1481    Ok(())
1482}
1483
1484fn verify_device_poll_lease_locked(
1485    flows: &mut HashMap<String, OAuthDeviceFlowState>,
1486    device_code: &str,
1487    provider: OAuthProviderIdentity,
1488    lease_id: u64,
1489) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1490    let Some(state) = flows.get(device_code) else {
1491        return Err(OAuthFlowError::Missing);
1492    };
1493    if state.record.provider != provider {
1494        return Err(OAuthFlowError::ProviderMismatch {
1495            expected: state.record.provider,
1496            actual: provider,
1497        });
1498    }
1499    match state.poll_lease {
1500        Some(lease) if lease.id == lease_id => {}
1501        Some(_) => return Err(OAuthFlowError::DevicePollInProgress),
1502        None => return Err(OAuthFlowError::Missing),
1503    }
1504    Ok(state.record.clone())
1505}
1506
1507fn consume_device_poll_lease_locked(
1508    flows: &mut HashMap<String, OAuthDeviceFlowState>,
1509    device_code: &str,
1510    provider: OAuthProviderIdentity,
1511    lease_id: u64,
1512) -> Result<OAuthDeviceFlowRecord, OAuthFlowError> {
1513    verify_device_poll_lease_locked(flows, device_code, provider, lease_id)?;
1514    flows
1515        .remove(device_code)
1516        .map(|state| state.record)
1517        .ok_or(OAuthFlowError::Missing)
1518}
1519
1520fn prune_expired_device_locked(flows: &mut HashMap<String, OAuthDeviceFlowState>) {
1521    let _ = take_expired_device_locked(flows);
1522}
1523
1524fn take_expired_device_locked(
1525    flows: &mut HashMap<String, OAuthDeviceFlowState>,
1526) -> Vec<OAuthDeviceFlowRecord> {
1527    let now = Instant::now();
1528    let expired = flows
1529        .iter()
1530        .filter(|(_, state)| state.record.expires_at < now)
1531        .map(|(device_code, _)| device_code.clone())
1532        .collect::<Vec<_>>();
1533    expired
1534        .into_iter()
1535        .filter_map(|device_code| flows.remove(&device_code).map(|state| state.record))
1536        .collect()
1537}
1538
/// Converts a slice of string literals into owned `String`s, preserving order.
fn strings(values: &[&str]) -> Vec<String> {
    values.iter().copied().map(String::from).collect()
}
1542
1543#[cfg(test)]
1544#[allow(clippy::expect_used)]
1545mod tests {
1546    use super::*;
1547
1548    fn target() -> AuthBindingRef {
1549        AuthBindingRef {
1550            realm: meerkat_core::RealmId::parse("dev").expect("valid realm"),
1551            binding: meerkat_core::BindingId::parse("default_openai").expect("valid binding"),
1552            profile: None,
1553        }
1554    }
1555
1556    fn alternate_target() -> AuthBindingRef {
1557        AuthBindingRef {
1558            realm: meerkat_core::RealmId::parse("dev").expect("valid realm"),
1559            binding: meerkat_core::BindingId::parse("alternate_openai").expect("valid binding"),
1560            profile: None,
1561        }
1562    }
1563
1564    #[test]
1565    fn oauth_state_pkce_round_trip() {
1566        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1567        let state = registry
1568            .start(
1569                target(),
1570                OAuthProviderIdentity::OpenAiChatGpt,
1571                "http://127.0.0.1/callback",
1572                "verifier",
1573            )
1574            .expect("state generation succeeds");
1575        let record = registry.consume(
1576            &state,
1577            &target(),
1578            OAuthProviderIdentity::OpenAiChatGpt,
1579            "http://127.0.0.1/callback",
1580        );
1581        assert!(
1582            record.is_ok(),
1583            "state should resolve once: {:?}",
1584            record.err()
1585        );
1586        if let Ok(record) = record {
1587            assert_eq!(record.pkce_verifier, "verifier");
1588        }
1589        assert!(matches!(
1590            registry.consume(
1591                &state,
1592                &target(),
1593                OAuthProviderIdentity::OpenAiChatGpt,
1594                "http://127.0.0.1/callback"
1595            ),
1596            Err(OAuthFlowError::Missing)
1597        ));
1598    }
1599
1600    #[test]
1601    fn oauth_state_rejects_mismatch() {
1602        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1603        let state = registry
1604            .start(
1605                target(),
1606                OAuthProviderIdentity::OpenAiChatGpt,
1607                "http://127.0.0.1/callback",
1608                "verifier",
1609            )
1610            .expect("state generation succeeds");
1611        assert!(matches!(
1612            registry.consume(
1613                &state,
1614                &target(),
1615                OAuthProviderIdentity::AnthropicClaudeAi,
1616                "http://127.0.0.1/callback"
1617            ),
1618            Err(OAuthFlowError::ProviderMismatch { .. })
1619        ));
1620    }
1621
1622    #[test]
1623    fn oauth_state_provider_mismatch_does_not_consume_state() {
1624        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1625        let state = registry
1626            .start(
1627                target(),
1628                OAuthProviderIdentity::OpenAiChatGpt,
1629                "http://127.0.0.1/callback",
1630                "verifier",
1631            )
1632            .expect("state generation succeeds");
1633        assert!(matches!(
1634            registry.consume(
1635                &state,
1636                &target(),
1637                OAuthProviderIdentity::AnthropicClaudeAi,
1638                "http://127.0.0.1/callback"
1639            ),
1640            Err(OAuthFlowError::ProviderMismatch { .. })
1641        ));
1642
1643        let record = registry
1644            .consume(
1645                &state,
1646                &target(),
1647                OAuthProviderIdentity::OpenAiChatGpt,
1648                "http://127.0.0.1/callback",
1649            )
1650            .expect("provider mismatch must leave state retryable");
1651        assert_eq!(record.pkce_verifier, "verifier");
1652    }
1653
1654    #[test]
1655    fn oauth_state_rejects_redirect_uri_mismatch() {
1656        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1657        let state = registry
1658            .start(
1659                target(),
1660                OAuthProviderIdentity::OpenAiChatGpt,
1661                "http://127.0.0.1/callback",
1662                "verifier",
1663            )
1664            .expect("state generation succeeds");
1665        assert!(matches!(
1666            registry.consume(
1667                &state,
1668                &target(),
1669                OAuthProviderIdentity::OpenAiChatGpt,
1670                "http://127.0.0.1/other"
1671            ),
1672            Err(OAuthFlowError::RedirectUriMismatch)
1673        ));
1674    }
1675
1676    #[test]
1677    fn oauth_state_redirect_uri_mismatch_does_not_consume_state() {
1678        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1679        let state = registry
1680            .start(
1681                target(),
1682                OAuthProviderIdentity::OpenAiChatGpt,
1683                "http://127.0.0.1/callback",
1684                "verifier",
1685            )
1686            .expect("state generation succeeds");
1687        assert!(matches!(
1688            registry.consume(
1689                &state,
1690                &target(),
1691                OAuthProviderIdentity::OpenAiChatGpt,
1692                "http://127.0.0.1/other"
1693            ),
1694            Err(OAuthFlowError::RedirectUriMismatch)
1695        ));
1696
1697        let record = registry
1698            .consume(
1699                &state,
1700                &target(),
1701                OAuthProviderIdentity::OpenAiChatGpt,
1702                "http://127.0.0.1/callback",
1703            )
1704            .expect("redirect mismatch must leave state retryable");
1705        assert_eq!(record.pkce_verifier, "verifier");
1706    }
1707
1708    #[test]
1709    fn oauth_state_target_mismatch_does_not_consume_state() {
1710        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1711        let state = registry
1712            .start(
1713                target(),
1714                OAuthProviderIdentity::OpenAiChatGpt,
1715                "http://127.0.0.1/callback",
1716                "verifier",
1717            )
1718            .expect("state generation succeeds");
1719        assert!(matches!(
1720            registry.consume(
1721                &state,
1722                &alternate_target(),
1723                OAuthProviderIdentity::OpenAiChatGpt,
1724                "http://127.0.0.1/callback"
1725            ),
1726            Err(OAuthFlowError::TargetMismatch { .. })
1727        ));
1728
1729        let record = registry
1730            .consume(
1731                &state,
1732                &target(),
1733                OAuthProviderIdentity::OpenAiChatGpt,
1734                "http://127.0.0.1/callback",
1735            )
1736            .expect("target mismatch must leave state retryable");
1737        assert_eq!(record.pkce_verifier, "verifier");
1738    }
1739
1740    #[test]
1741    fn oauth_state_verify_does_not_consume_state() {
1742        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1743        let state = registry
1744            .start(
1745                target(),
1746                OAuthProviderIdentity::OpenAiChatGpt,
1747                "http://127.0.0.1/callback",
1748                "verifier",
1749            )
1750            .expect("state generation succeeds");
1751
1752        let verified = registry
1753            .verify(
1754                &state,
1755                &target(),
1756                OAuthProviderIdentity::OpenAiChatGpt,
1757                "http://127.0.0.1/callback",
1758            )
1759            .expect("verify should read state before terminal commit");
1760
1761        assert_eq!(verified.pkce_verifier, "verifier");
1762        registry
1763            .consume(
1764                &state,
1765                &target(),
1766                OAuthProviderIdentity::OpenAiChatGpt,
1767                "http://127.0.0.1/callback",
1768            )
1769            .expect("verified browser state remains available for terminal consume");
1770    }
1771
1772    #[test]
1773    fn oauth_state_random_tokens_are_urlsafe_and_unique() {
1774        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1775        let first = registry
1776            .start(
1777                target(),
1778                OAuthProviderIdentity::OpenAiChatGpt,
1779                "http://127.0.0.1/callback",
1780                "verifier-a",
1781            )
1782            .expect("state generation succeeds");
1783        let second = registry
1784            .start(
1785                target(),
1786                OAuthProviderIdentity::OpenAiChatGpt,
1787                "http://127.0.0.1/callback",
1788                "verifier-b",
1789            )
1790            .expect("state generation succeeds");
1791        assert_ne!(first, second);
1792        assert!(first.starts_with("st-"));
1793        assert!(
1794            first[3..]
1795                .chars()
1796                .all(|ch| { ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' })
1797        );
1798    }
1799
1800    #[test]
1801    fn oauth_state_expired_records_are_pruned() {
1802        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1803        registry.flows.lock().insert(
1804            "st-old".to_string(),
1805            OAuthFlowRecord {
1806                target: target(),
1807                provider: OAuthProviderIdentity::OpenAiChatGpt,
1808                redirect_uri: "http://127.0.0.1/callback".to_string(),
1809                pkce_verifier: "verifier".to_string(),
1810                created_at: Instant::now()
1811                    .checked_sub(Duration::from_secs(61))
1812                    .expect("test duration is representable"),
1813            },
1814        );
1815        assert!(matches!(
1816            registry.consume(
1817                "st-old",
1818                &target(),
1819                OAuthProviderIdentity::OpenAiChatGpt,
1820                "http://127.0.0.1/callback"
1821            ),
1822            Err(OAuthFlowError::Missing)
1823        ));
1824    }
1825
1826    #[test]
1827    fn oauth_state_registry_rejects_start_at_capacity() {
1828        let registry = OAuthFlowRegistry::new_with_capacity(Duration::from_secs(60), 2);
1829        let first = registry
1830            .start(
1831                target(),
1832                OAuthProviderIdentity::OpenAiChatGpt,
1833                "http://127.0.0.1/callback",
1834                "first",
1835            )
1836            .expect("state generation succeeds");
1837        let second = registry
1838            .start(
1839                target(),
1840                OAuthProviderIdentity::OpenAiChatGpt,
1841                "http://127.0.0.1/callback",
1842                "second",
1843            )
1844            .expect("state generation succeeds");
1845        let third = registry.start(
1846            target(),
1847            OAuthProviderIdentity::OpenAiChatGpt,
1848            "http://127.0.0.1/callback",
1849            "third",
1850        );
1851
1852        assert!(matches!(
1853            third,
1854            Err(OAuthFlowError::CapacityExceeded { max_outstanding: 2 })
1855        ));
1856        assert!(
1857            registry
1858                .consume(
1859                    &first,
1860                    &target(),
1861                    OAuthProviderIdentity::OpenAiChatGpt,
1862                    "http://127.0.0.1/callback"
1863                )
1864                .is_ok()
1865        );
1866        assert!(
1867            registry
1868                .consume(
1869                    &second,
1870                    &target(),
1871                    OAuthProviderIdentity::OpenAiChatGpt,
1872                    "http://127.0.0.1/callback"
1873                )
1874                .is_ok()
1875        );
1876    }
1877
1878    #[test]
1879    fn oauth_state_cannot_cross_login_lifecycle_authorities() {
1880        let admitting_authority = OAuthFlowRegistry::new(Duration::from_secs(60));
1881        let unrelated_authority = OAuthFlowRegistry::new(Duration::from_secs(60));
1882        let state = admitting_authority
1883            .start(
1884                target(),
1885                OAuthProviderIdentity::OpenAiChatGpt,
1886                "http://127.0.0.1/callback",
1887                "verifier",
1888            )
1889            .expect("state generation succeeds");
1890
1891        assert!(matches!(
1892            unrelated_authority.consume(
1893                &state,
1894                &target(),
1895                OAuthProviderIdentity::OpenAiChatGpt,
1896                "http://127.0.0.1/callback"
1897            ),
1898            Err(OAuthFlowError::Missing)
1899        ));
1900
1901        let record = admitting_authority
1902            .consume(
1903                &state,
1904                &target(),
1905                OAuthProviderIdentity::OpenAiChatGpt,
1906                "http://127.0.0.1/callback",
1907            )
1908            .expect("admitting authority owns the flow");
1909        assert_eq!(record.pkce_verifier, "verifier");
1910    }
1911
1912    #[test]
1913    fn oauth_device_flow_is_retained_until_terminal_consume() {
1914        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1915        registry
1916            .admit_device_code(
1917                target(),
1918                OAuthProviderIdentity::GoogleCodeAssist,
1919                "device-code",
1920                Duration::from_secs(600),
1921            )
1922            .expect("device code admitted");
1923
1924        let observed = registry
1925            .verify_device_code(
1926                "device-code",
1927                &target(),
1928                OAuthProviderIdentity::GoogleCodeAssist,
1929            )
1930            .expect("pending device flow remains visible");
1931        assert_eq!(observed.device_code, "device-code");
1932
1933        let poll = registry
1934            .begin_device_code_poll(
1935                "device-code",
1936                &target(),
1937                OAuthProviderIdentity::GoogleCodeAssist,
1938            )
1939            .expect("poll begins");
1940        let consumed = poll.consume().expect("terminal device flow consumes");
1941        assert_eq!(consumed.provider, OAuthProviderIdentity::GoogleCodeAssist);
1942        assert!(matches!(
1943            registry.verify_device_code(
1944                "device-code",
1945                &target(),
1946                OAuthProviderIdentity::GoogleCodeAssist
1947            ),
1948            Err(OAuthFlowError::Missing)
1949        ));
1950    }
1951
1952    #[test]
1953    fn oauth_device_poll_verify_keeps_terminal_consume_available() {
1954        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1955        registry
1956            .admit_device_code(
1957                target(),
1958                OAuthProviderIdentity::GoogleCodeAssist,
1959                "device-code",
1960                Duration::from_secs(600),
1961            )
1962            .expect("device code admitted");
1963        let poll = registry
1964            .begin_device_code_poll(
1965                "device-code",
1966                &target(),
1967                OAuthProviderIdentity::GoogleCodeAssist,
1968            )
1969            .expect("poll begins");
1970
1971        let verified = poll
1972            .verify()
1973            .expect("terminal preflight verifies the current lease");
1974
1975        assert_eq!(verified.device_code, "device-code");
1976        assert!(matches!(
1977            registry.begin_device_code_poll(
1978                "device-code",
1979                &target(),
1980                OAuthProviderIdentity::GoogleCodeAssist
1981            ),
1982            Err(OAuthFlowError::DevicePollInProgress)
1983        ));
1984        let consumed = poll.consume().expect("verified lease still consumes");
1985        assert_eq!(consumed.device_code, "device-code");
1986    }
1987
1988    #[test]
1989    fn oauth_device_admission_does_not_replace_active_poll_lease() {
1990        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
1991        registry
1992            .admit_device_code(
1993                target(),
1994                OAuthProviderIdentity::GoogleCodeAssist,
1995                "device-code",
1996                Duration::from_secs(600),
1997            )
1998            .expect("device code admitted");
1999        let poll = registry
2000            .begin_device_code_poll(
2001                "device-code",
2002                &target(),
2003                OAuthProviderIdentity::GoogleCodeAssist,
2004            )
2005            .expect("poll begins");
2006
2007        let duplicate = registry.admit_device_code(
2008            target(),
2009            OAuthProviderIdentity::GoogleCodeAssist,
2010            "device-code",
2011            Duration::from_secs(600),
2012        );
2013
2014        assert_eq!(duplicate, Err(OAuthFlowError::DeviceCodeAlreadyAdmitted));
2015        let consumed = poll
2016            .consume()
2017            .expect("duplicate admission must not replace the active poll lease");
2018        assert_eq!(consumed.device_code, "device-code");
2019    }
2020
2021    #[test]
2022    fn oauth_device_flow_rejects_provider_mismatch() {
2023        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2024        registry
2025            .admit_device_code(
2026                target(),
2027                OAuthProviderIdentity::GoogleCodeAssist,
2028                "device-code",
2029                Duration::from_secs(600),
2030            )
2031            .expect("device code admitted");
2032
2033        assert!(matches!(
2034            registry.verify_device_code(
2035                "device-code",
2036                &target(),
2037                OAuthProviderIdentity::OpenAiChatGpt
2038            ),
2039            Err(OAuthFlowError::ProviderMismatch { .. })
2040        ));
2041        let poll = registry
2042            .begin_device_code_poll(
2043                "device-code",
2044                &target(),
2045                OAuthProviderIdentity::GoogleCodeAssist,
2046            )
2047            .expect("correct provider poll begins after mismatch");
2048        assert!(poll.consume().is_ok());
2049    }
2050
2051    #[test]
2052    fn oauth_device_flow_rejects_target_mismatch_without_consuming() {
2053        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2054        registry
2055            .admit_device_code(
2056                target(),
2057                OAuthProviderIdentity::GoogleCodeAssist,
2058                "device-code",
2059                Duration::from_secs(600),
2060            )
2061            .expect("device code admitted");
2062
2063        assert!(matches!(
2064            registry.verify_device_code(
2065                "device-code",
2066                &alternate_target(),
2067                OAuthProviderIdentity::GoogleCodeAssist
2068            ),
2069            Err(OAuthFlowError::TargetMismatch { .. })
2070        ));
2071        let poll = registry
2072            .begin_device_code_poll(
2073                "device-code",
2074                &target(),
2075                OAuthProviderIdentity::GoogleCodeAssist,
2076            )
2077            .expect("correct target poll begins after mismatch");
2078        assert!(poll.consume().is_ok());
2079    }
2080
2081    #[test]
2082    fn oauth_device_flow_expired_records_are_pruned() {
2083        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2084        registry.device_flows.lock().insert(
2085            "device-code".to_string(),
2086            OAuthDeviceFlowState {
2087                record: OAuthDeviceFlowRecord {
2088                    target: target(),
2089                    provider: OAuthProviderIdentity::GoogleCodeAssist,
2090                    device_code: "device-code".to_string(),
2091                    created_at: Instant::now()
2092                        .checked_sub(Duration::from_secs(601))
2093                        .expect("test duration is representable"),
2094                    expires_at: Instant::now()
2095                        .checked_sub(Duration::from_secs(1))
2096                        .expect("test duration is representable"),
2097                },
2098                poll_lease: None,
2099            },
2100        );
2101
2102        assert!(matches!(
2103            registry.verify_device_code(
2104                "device-code",
2105                &target(),
2106                OAuthProviderIdentity::GoogleCodeAssist
2107            ),
2108            Err(OAuthFlowError::Missing)
2109        ));
2110    }
2111
2112    #[test]
2113    fn oauth_device_flow_rejects_unrepresentable_expiry() {
2114        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2115        let err = registry
2116            .admit_device_code(
2117                target(),
2118                OAuthProviderIdentity::GoogleCodeAssist,
2119                "device-code",
2120                Duration::MAX,
2121            )
2122            .expect_err("unrepresentable device expiry should be rejected");
2123
2124        assert_eq!(err, OAuthFlowError::DeviceExpiryOutOfRange);
2125        assert!(matches!(
2126            registry.verify_device_code(
2127                "device-code",
2128                &target(),
2129                OAuthProviderIdentity::GoogleCodeAssist
2130            ),
2131            Err(OAuthFlowError::Missing)
2132        ));
2133    }
2134
2135    #[test]
2136    fn oauth_device_terminal_consume_rejects_local_expiry_boundary() {
2137        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2138        registry
2139            .admit_device_code(
2140                target(),
2141                OAuthProviderIdentity::GoogleCodeAssist,
2142                "device-code",
2143                Duration::from_secs(600),
2144            )
2145            .expect("device code admitted");
2146        let poll = registry
2147            .begin_device_code_poll(
2148                "device-code",
2149                &target(),
2150                OAuthProviderIdentity::GoogleCodeAssist,
2151            )
2152            .expect("poll begins before local expiry boundary");
2153        {
2154            let mut flows = registry.device_flows.lock();
2155            flows
2156                .get_mut("device-code")
2157                .expect("device flow exists")
2158                .record
2159                .expires_at = Instant::now()
2160                .checked_sub(Duration::from_secs(1))
2161                .expect("test duration is representable");
2162        }
2163
2164        assert!(matches!(poll.consume(), Err(OAuthFlowError::Missing)));
2165        assert!(matches!(
2166            registry.verify_device_code(
2167                "device-code",
2168                &target(),
2169                OAuthProviderIdentity::GoogleCodeAssist
2170            ),
2171            Err(OAuthFlowError::Missing)
2172        ));
2173    }
2174
2175    #[test]
2176    fn oauth_device_terminal_consume_rejects_intervening_prune_while_polling() {
2177        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2178        registry
2179            .admit_device_code(
2180                target(),
2181                OAuthProviderIdentity::GoogleCodeAssist,
2182                "device-code",
2183                Duration::from_secs(600),
2184            )
2185            .expect("device code admitted");
2186        let poll = registry
2187            .begin_device_code_poll(
2188                "device-code",
2189                &target(),
2190                OAuthProviderIdentity::GoogleCodeAssist,
2191            )
2192            .expect("poll begins");
2193        assert!(matches!(
2194            registry.begin_device_code_poll(
2195                "device-code",
2196                &target(),
2197                OAuthProviderIdentity::GoogleCodeAssist
2198            ),
2199            Err(OAuthFlowError::DevicePollInProgress)
2200        ));
2201        {
2202            let mut flows = registry.device_flows.lock();
2203            flows
2204                .get_mut("device-code")
2205                .expect("device flow exists")
2206                .record
2207                .expires_at = Instant::now()
2208                .checked_sub(Duration::from_secs(1))
2209                .expect("test duration is representable");
2210        }
2211
2212        registry
2213            .start(
2214                target(),
2215                OAuthProviderIdentity::OpenAiChatGpt,
2216                "http://127.0.0.1/callback",
2217                "verifier",
2218            )
2219            .expect("unrelated start prunes expired device flow");
2220        assert!(matches!(poll.consume(), Err(OAuthFlowError::Missing)));
2221    }
2222
2223    #[test]
2224    fn oauth_device_poll_drop_releases_in_progress_lifecycle() {
2225        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2226        registry
2227            .admit_device_code(
2228                target(),
2229                OAuthProviderIdentity::GoogleCodeAssist,
2230                "device-code",
2231                Duration::from_secs(600),
2232            )
2233            .expect("device code admitted");
2234
2235        let poll = registry
2236            .begin_device_code_poll(
2237                "device-code",
2238                &target(),
2239                OAuthProviderIdentity::GoogleCodeAssist,
2240            )
2241            .expect("poll begins");
2242        assert!(matches!(
2243            registry.begin_device_code_poll(
2244                "device-code",
2245                &target(),
2246                OAuthProviderIdentity::GoogleCodeAssist
2247            ),
2248            Err(OAuthFlowError::DevicePollInProgress)
2249        ));
2250
2251        drop(poll);
2252
2253        registry
2254            .begin_device_code_poll(
2255                "device-code",
2256                &target(),
2257                OAuthProviderIdentity::GoogleCodeAssist,
2258            )
2259            .expect("dropped poll lease releases the in-progress lifecycle");
2260    }
2261
2262    #[test]
2263    fn oauth_device_poll_drop_prunes_expired_in_progress_record() {
2264        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2265        registry
2266            .admit_device_code(
2267                target(),
2268                OAuthProviderIdentity::GoogleCodeAssist,
2269                "device-code",
2270                Duration::from_secs(600),
2271            )
2272            .expect("device code admitted");
2273        let poll = registry
2274            .begin_device_code_poll(
2275                "device-code",
2276                &target(),
2277                OAuthProviderIdentity::GoogleCodeAssist,
2278            )
2279            .expect("poll begins");
2280        {
2281            let mut flows = registry.device_flows.lock();
2282            flows
2283                .get_mut("device-code")
2284                .expect("device flow exists")
2285                .record
2286                .expires_at = Instant::now()
2287                .checked_sub(Duration::from_secs(1))
2288                .expect("test duration is representable");
2289        }
2290
2291        drop(poll);
2292
2293        assert!(matches!(
2294            registry.verify_device_code(
2295                "device-code",
2296                &target(),
2297                OAuthProviderIdentity::GoogleCodeAssist
2298            ),
2299            Err(OAuthFlowError::Missing)
2300        ));
2301    }
2302
2303    struct RejectConsumeLifecycle;
2304
2305    impl OAuthDevicePollLifecycle for RejectConsumeLifecycle {
2306        fn finish_device_poll(
2307            &self,
2308            _target: &AuthBindingRef,
2309            _device_code: &str,
2310        ) -> Result<(), OAuthFlowError> {
2311            Ok(())
2312        }
2313
2314        fn consume_device_flow(
2315            &self,
2316            _target: &AuthBindingRef,
2317            _device_code: &str,
2318            _provider: OAuthProviderIdentity,
2319        ) -> Result<(), OAuthFlowError> {
2320            Err(OAuthFlowError::LifecycleRejected {
2321                operation: "consume_oauth_device_flow",
2322                detail: "injected failure".to_string(),
2323            })
2324        }
2325
2326        fn expire_device_flow(
2327            &self,
2328            _target: &AuthBindingRef,
2329            _device_code: &str,
2330        ) -> Result<(), OAuthFlowError> {
2331            Ok(())
2332        }
2333    }
2334
2335    #[test]
2336    fn oauth_device_lifecycle_consume_failure_keeps_flow_retryable() {
2337        let registry = OAuthFlowRegistry::new(Duration::from_secs(60));
2338        registry
2339            .admit_device_code(
2340                target(),
2341                OAuthProviderIdentity::GoogleCodeAssist,
2342                "device-code",
2343                Duration::from_secs(600),
2344            )
2345            .expect("device code admitted");
2346        let poll = registry
2347            .begin_device_code_poll(
2348                "device-code",
2349                &target(),
2350                OAuthProviderIdentity::GoogleCodeAssist,
2351            )
2352            .expect("poll begins")
2353            .with_lifecycle(std::sync::Arc::new(RejectConsumeLifecycle));
2354
2355        assert!(matches!(
2356            poll.consume(),
2357            Err(OAuthFlowError::LifecycleRejected {
2358                operation: "consume_oauth_device_flow",
2359                ..
2360            })
2361        ));
2362
2363        let retry = registry
2364            .begin_device_code_poll(
2365                "device-code",
2366                &target(),
2367                OAuthProviderIdentity::GoogleCodeAssist,
2368            )
2369            .expect("failed lifecycle consume keeps device flow retryable");
2370        assert!(retry.consume().is_ok());
2371    }
2372
2373    #[test]
2374    fn oauth_provider_resolution_preserves_aliases() {
2375        let cases = [
2376            (
2377                "anthropic",
2378                OAuthProviderIdentity::AnthropicClaudeAi,
2379                Provider::Anthropic,
2380                PersistedAuthMode::ClaudeAiOauth,
2381            ),
2382            (
2383                "claude",
2384                OAuthProviderIdentity::AnthropicClaudeAi,
2385                Provider::Anthropic,
2386                PersistedAuthMode::ClaudeAiOauth,
2387            ),
2388            (
2389                "claude.ai",
2390                OAuthProviderIdentity::AnthropicClaudeAi,
2391                Provider::Anthropic,
2392                PersistedAuthMode::ClaudeAiOauth,
2393            ),
2394            (
2395                "openai",
2396                OAuthProviderIdentity::OpenAiChatGpt,
2397                Provider::OpenAI,
2398                PersistedAuthMode::ChatgptOauth,
2399            ),
2400            (
2401                "chatgpt",
2402                OAuthProviderIdentity::OpenAiChatGpt,
2403                Provider::OpenAI,
2404                PersistedAuthMode::ChatgptOauth,
2405            ),
2406            (
2407                "google",
2408                OAuthProviderIdentity::GoogleCodeAssist,
2409                Provider::Gemini,
2410                PersistedAuthMode::GoogleOauth,
2411            ),
2412            (
2413                "gemini",
2414                OAuthProviderIdentity::GoogleCodeAssist,
2415                Provider::Gemini,
2416                PersistedAuthMode::GoogleOauth,
2417            ),
2418            (
2419                "code_assist",
2420                OAuthProviderIdentity::GoogleCodeAssist,
2421                Provider::Gemini,
2422                PersistedAuthMode::GoogleOauth,
2423            ),
2424        ];
2425
2426        for (alias, identity, provider, auth_mode) in cases {
2427            let resolved =
2428                resolve_oauth_provider(alias, "http://127.0.0.1/callback").expect("alias resolves");
2429            assert_eq!(resolved.identity, identity);
2430            assert_eq!(resolved.provider, provider);
2431            assert_eq!(resolved.auth_mode, auth_mode);
2432            assert_eq!(resolved.endpoints.redirect_uri, "http://127.0.0.1/callback");
2433        }
2434    }
2435
2436    #[test]
2437    fn oauth_provider_resolution_exposes_google_device_secret() {
2438        let resolved =
2439            resolve_oauth_provider("code_assist", "").expect("google code assist resolves");
2440
2441        assert_eq!(resolved.identity, OAuthProviderIdentity::GoogleCodeAssist);
2442        assert!(resolved.endpoints.device_code_url.is_some());
2443        assert_eq!(resolved.client_secret, Some(GOOGLE_CLIENT_SECRET));
2444    }
2445
2446    #[cfg(feature = "oauth")]
2447    #[test]
2448    fn openai_provider_resolution_matches_codex_authorize_contract() {
2449        let resolved = resolve_oauth_provider("openai", "http://localhost:1455/auth/callback")
2450            .expect("openai resolves");
2451        let pkce = crate::auth_oauth::PkcePair::generate_s256();
2452        let authorize_url = resolved
2453            .endpoints
2454            .authorize_url_with_pkce(&pkce.challenge, "state-abc");
2455
2456        assert_eq!(
2457            resolved.endpoints.redirect_uri,
2458            "http://localhost:1455/auth/callback"
2459        );
2460        assert_eq!(
2461            resolved.endpoints.token_request_format,
2462            OAuthTokenRequestFormat::FormUrlEncoded
2463        );
2464        assert!(
2465            authorize_url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback")
2466        );
2467        assert!(authorize_url.contains("id_token_add_organizations=true"));
2468        assert!(authorize_url.contains("codex_cli_simplified_flow=true"));
2469        assert!(authorize_url.contains("originator=codex_cli_rs"));
2470    }
2471
2472    #[test]
2473    fn anthropic_provider_resolution_matches_claude_code_token_contract() {
2474        let resolved = resolve_oauth_provider("anthropic", "http://localhost:1455/callback")
2475            .expect("anthropic resolves");
2476
2477        assert_eq!(
2478            resolved.endpoints.token_request_format,
2479            OAuthTokenRequestFormat::Json
2480        );
2481        assert!(resolved.endpoints.include_state_in_token_exchange);
2482        assert_eq!(
2483            resolved.endpoints.refresh_scopes,
2484            strings(ANTHROPIC_CLAUDE_AI_SCOPES)
2485        );
2486        assert!(
2487            resolved
2488                .endpoints
2489                .extra_authorize_params
2490                .contains(&("code".to_string(), "true".to_string()))
2491        );
2492    }
2493}