Skip to main content

rab/provider/oauth/
device_code.rs

//! OAuth device code flow poller, matching pi's `pollOAuthDeviceCodeFlow`.
//!
//! Polls the token endpoint at the configured interval until the user
//! completes the login in their browser, the flow times out, or the login is cancelled.
5
6use std::future::Future;
7use std::pin::Pin;
8use tokio::time::{Duration, sleep};
9use tokio_util::sync::CancellationToken;
10
/// Error returned when the caller cancels the login.
const CANCEL_MESSAGE: &str = "Login cancelled";
/// Error returned when the flow's deadline passes without a result.
const TIMEOUT_MESSAGE: &str = "Device flow timed out";
/// Timeout error used instead of [`TIMEOUT_MESSAGE`] once the endpoint has answered
/// `slow_down` at least once.
const SLOW_DOWN_TIMEOUT_MESSAGE: &str = "Device flow timed out after one or more slow_down responses. \
     This is often caused by clock drift in WSL or VM environments. \
     Please sync or restart the VM clock and try again.";

/// Lower bound on the wait between two polls, in milliseconds.
const MINIMUM_INTERVAL_MS: u64 = 1_000;
/// Poll interval used when none is configured, in seconds.
const DEFAULT_POLL_INTERVAL_SECONDS: u64 = 5;
/// Extra delay added to the poll interval for every `slow_down` response, in milliseconds.
const SLOW_DOWN_INTERVAL_INCREMENT_MS: u64 = 5_000;
19
/// Result from a single poll attempt.
///
/// The device code poller stops on `Complete` or `Failed`, keeps polling on `Pending`,
/// and lengthens its polling interval on `SlowDown`.
///
/// The derives are conditional on `T`, e.g. `Debug` is available whenever `T: Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PollStatus<T> {
    /// The login finished; carries the value handed back to the caller.
    Complete(T),
    /// The login has not finished yet; poll again after the current interval.
    Pending,
    /// The endpoint asked the client to poll less often; the interval is increased.
    SlowDown,
    /// A terminal failure; the message becomes the flow's error.
    Failed(String),
}
27
28/// Async poll function type for device code flow.
29pub type PollFn<'a, T> = Box<
30    dyn FnMut() -> Pin<Box<dyn Future<Output = Result<PollStatus<T>, String>> + Send>> + Send + 'a,
31>;
32
/// Options for the device code poller.
pub struct PollOptions<'a, T> {
    /// Seconds to wait between polls. `None` means `DEFAULT_POLL_INTERVAL_SECONDS`.
    /// The poller never waits less than `MINIMUM_INTERVAL_MS`, so `Some(0)` behaves like 1 s.
    /// The interval grows with every `PollStatus::SlowDown` response.
    pub interval_seconds: Option<u32>,
    /// How long the flow may run before it times out, in seconds.
    /// `None` means 5 minutes.
    pub expires_in_seconds: Option<u32>,
    /// Called once per polling round to query the token endpoint.
    /// An `Err` from the returned future aborts the flow with that message.
    pub poll: PollFn<'a, T>,
    /// Optional cancellation token.
    /// Once cancelled, the poller stops with a "Login cancelled" error.
    pub cancel: Option<CancellationToken>,
}
40
41/// Poll the token endpoint until the user completes login or the flow fails.
42pub async fn poll_device_code_flow<T>(mut options: PollOptions<'_, T>) -> Result<T, String> {
43    let deadline = match options.expires_in_seconds {
44        Some(secs) => std::time::Instant::now() + std::time::Duration::from_secs(secs as u64),
45        None => std::time::Instant::now() + std::time::Duration::from_secs(300), // 5 min default
46    };
47
48    let mut interval_ms = std::cmp::max(
49        MINIMUM_INTERVAL_MS,
50        (options
51            .interval_seconds
52            .unwrap_or(DEFAULT_POLL_INTERVAL_SECONDS as u32) as u64)
53            * 1000,
54    );
55
56    let mut slow_down_responses = 0;
57
58    while std::time::Instant::now() < deadline {
59        if let Some(ref cancel) = options.cancel
60            && cancel.is_cancelled()
61        {
62            return Err(CANCEL_MESSAGE.to_string());
63        }
64
65        let result = (options.poll)().await?;
66        match result {
67            PollStatus::Complete(value) => return Ok(value),
68            PollStatus::Failed(msg) => return Err(msg),
69            PollStatus::SlowDown => {
70                slow_down_responses += 1;
71                interval_ms = std::cmp::max(
72                    MINIMUM_INTERVAL_MS,
73                    interval_ms + SLOW_DOWN_INTERVAL_INCREMENT_MS,
74                );
75            }
76            PollStatus::Pending => {}
77        }
78
79        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
80        if remaining.as_millis() == 0 {
81            break;
82        }
83
84        let sleep_ms = std::cmp::min(interval_ms, remaining.as_millis() as u64);
85        sleep(Duration::from_millis(sleep_ms)).await;
86    }
87
88    Err(if slow_down_responses > 0 {
89        SLOW_DOWN_TIMEOUT_MESSAGE.to_string()
90    } else {
91        TIMEOUT_MESSAGE.to_string()
92    })
93}