// rab/provider/oauth/device_code.rs
use std::future::Future;
use std::pin::Pin;

use tokio::time::{Duration, sleep};
use tokio_util::sync::CancellationToken;

/// Error returned when the caller's cancellation token fires.
const CANCEL_MESSAGE: &str = "Login cancelled";
/// Error returned when the deadline passes without any `slow_down` response.
const TIMEOUT_MESSAGE: &str = "Device flow timed out";
/// Error returned when the deadline passes after at least one `slow_down`
/// response; the text itself suggests clock drift as the likely cause.
const SLOW_DOWN_TIMEOUT_MESSAGE: &str = concat!(
    "Device flow timed out after one or more slow_down responses. ",
    "This is often caused by clock drift in WSL or VM environments. ",
    "Please sync or restart the VM clock and try again."
);
/// Lower bound on the wait between polls, in milliseconds.
const MINIMUM_INTERVAL_MS: u64 = 1_000;
/// Wait between polls when the caller supplies no interval, in seconds.
const DEFAULT_POLL_INTERVAL_SECONDS: u64 = 5;
/// Amount added to the wait after each `slow_down` response, in milliseconds.
const SLOW_DOWN_INTERVAL_INCREMENT_MS: u64 = 5_000;

/// Outcome of a single poll attempt, as reported by the caller's poll closure.
///
/// The variants appear to mirror the OAuth 2.0 device-flow token responses
/// (RFC 8628 §3.5: success / `authorization_pending` / `slow_down` / error).
/// NOTE(review): the actual mapping lives in the caller's closure — confirm there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PollStatus<T> {
    /// The flow finished; polling stops and this value is returned as `Ok`.
    Complete(T),
    /// Not finished yet; poll again after the current interval.
    Pending,
    /// The server asked the client to back off; the poll interval is lengthened
    /// for this and every later wait.
    SlowDown,
    /// Terminal failure; polling stops and the message is returned as `Err`.
    Failed(String),
}

28pub type PollFn<'a, T> = Box<
30 dyn FnMut() -> Pin<Box<dyn Future<Output = Result<PollStatus<T>, String>> + Send>> + Send + 'a,
31>;
32
33pub struct PollOptions<'a, T> {
35 pub interval_seconds: Option<u32>,
36 pub expires_in_seconds: Option<u32>,
37 pub poll: PollFn<'a, T>,
38 pub cancel: Option<CancellationToken>,
39}
40
41pub async fn poll_device_code_flow<T>(mut options: PollOptions<'_, T>) -> Result<T, String> {
43 let deadline = match options.expires_in_seconds {
44 Some(secs) => std::time::Instant::now() + std::time::Duration::from_secs(secs as u64),
45 None => std::time::Instant::now() + std::time::Duration::from_secs(300), };
47
48 let mut interval_ms = std::cmp::max(
49 MINIMUM_INTERVAL_MS,
50 (options
51 .interval_seconds
52 .unwrap_or(DEFAULT_POLL_INTERVAL_SECONDS as u32) as u64)
53 * 1000,
54 );
55
56 let mut slow_down_responses = 0;
57
58 while std::time::Instant::now() < deadline {
59 if let Some(ref cancel) = options.cancel
60 && cancel.is_cancelled()
61 {
62 return Err(CANCEL_MESSAGE.to_string());
63 }
64
65 let result = (options.poll)().await?;
66 match result {
67 PollStatus::Complete(value) => return Ok(value),
68 PollStatus::Failed(msg) => return Err(msg),
69 PollStatus::SlowDown => {
70 slow_down_responses += 1;
71 interval_ms = std::cmp::max(
72 MINIMUM_INTERVAL_MS,
73 interval_ms + SLOW_DOWN_INTERVAL_INCREMENT_MS,
74 );
75 }
76 PollStatus::Pending => {}
77 }
78
79 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
80 if remaining.as_millis() == 0 {
81 break;
82 }
83
84 let sleep_ms = std::cmp::min(interval_ms, remaining.as_millis() as u64);
85 sleep(Duration::from_millis(sleep_ms)).await;
86 }
87
88 Err(if slow_down_responses > 0 {
89 SLOW_DOWN_TIMEOUT_MESSAGE.to_string()
90 } else {
91 TIMEOUT_MESSAGE.to_string()
92 })
93}