Skip to main content

authly_flow/
device_flow.rs

1use authly_core::{AuthError, OAuthToken};
2use serde::{Deserialize, Serialize};
3use std::time::Duration;
4use tokio::time::sleep;
5
6/// Represents the response from the device authorization endpoint.
7/// Defined in RFC 8628 Section 3.2.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct DeviceAuthorizationResponse {
10    /// The device verification code.
11    pub device_code: String,
12    /// The end-user verification code.
13    pub user_code: String,
14    /// The end-user verification URI on the authorization server.
15    pub verification_uri: String,
16    /// A verification URI that includes the "user_code" (or other information)
17    /// to optimize the end-user interaction.
18    pub verification_uri_complete: Option<String>,
19    /// The lifetime in seconds of the "device_code" and "user_code".
20    pub expires_in: u64,
21    /// The minimum amount of time in seconds that the client SHOULD wait
22    /// between polling requests to the token endpoint.
23    pub interval: Option<u64>,
24}
25
26/// Orchestrates the Device Authorization Flow (RFC 8628).
27pub struct DeviceFlow {
28    client_id: String,
29    device_authorization_url: String,
30    token_url: String,
31    http_client: reqwest::Client,
32}
33
34impl DeviceFlow {
35    /// Creates a new `DeviceFlow` instance.
36    pub fn new(client_id: String, device_authorization_url: String, token_url: String) -> Self {
37        Self {
38            client_id,
39            device_authorization_url,
40            token_url,
41            http_client: reqwest::Client::new(),
42        }
43    }
44
45    /// Initiates the device authorization request.
46    /// Returns a `DeviceAuthorizationResponse` which contains the codes and URIs
47    /// to be displayed to the user.
48    pub async fn initiate_device_authorization(
49        &self,
50        scopes: &[&str],
51    ) -> Result<DeviceAuthorizationResponse, AuthError> {
52        let scope_param = scopes.join(" ");
53
54        let response = self
55            .http_client
56            .post(&self.device_authorization_url)
57            .header("Accept", "application/json")
58            .form(&[("client_id", &self.client_id), ("scope", &scope_param)])
59            .send()
60            .await
61            .map_err(|_| AuthError::Network)?;
62
63        if !response.status().is_success() {
64            let error_text = response.text().await.unwrap_or_default();
65            return Err(AuthError::Provider(format!(
66                "Device authorization request failed: {}",
67                error_text
68            )));
69        }
70
71        response
72            .json::<DeviceAuthorizationResponse>()
73            .await
74            .map_err(|e| {
75                AuthError::Provider(format!(
76                    "Failed to parse device authorization response: {}",
77                    e
78                ))
79            })
80    }
81
82    /// Polls the token endpoint until an access token is granted or an error occurs.
83    /// This function respects the `interval` specified by the provider and handles
84    /// common device flow errors like `authorization_pending` and `slow_down`.
85    pub async fn poll_for_token(
86        &self,
87        device_code: &str,
88        interval: Option<u64>,
89    ) -> Result<OAuthToken, AuthError> {
90        let mut current_interval = interval.unwrap_or(5);
91
92        loop {
93            let response = self
94                .http_client
95                .post(&self.token_url)
96                .header("Accept", "application/json")
97                .form(&[
98                    ("client_id", &self.client_id),
99                    ("device_code", &device_code.to_string()),
100                    (
101                        "grant_type",
102                        &"urn:ietf:params:oauth:grant-type:device_code".to_string(),
103                    ),
104                ])
105                .send()
106                .await
107                .map_err(|_| AuthError::Network)?;
108
109            let status = response.status();
110
111            if status.is_success() {
112                return response.json::<OAuthToken>().await.map_err(|e| {
113                    AuthError::Provider(format!("Failed to parse token response: {}", e))
114                });
115            } else {
116                let error_resp: serde_json::Value = response
117                    .json()
118                    .await
119                    .map_err(|_| AuthError::Provider("Failed to parse error response".into()))?;
120
121                let error = error_resp["error"].as_str().unwrap_or("unknown_error");
122
123                match error {
124                    "authorization_pending" => {
125                        // Keep polling
126                    }
127                    "slow_down" => {
128                        current_interval += 5;
129                    }
130                    "access_denied" => {
131                        return Err(AuthError::Provider("Access denied by user".into()));
132                    }
133                    "expired_token" => {
134                        return Err(AuthError::Provider("Device code expired".into()));
135                    }
136                    _ => {
137                        return Err(AuthError::Provider(format!(
138                            "Token polling failed: {}",
139                            error
140                        )));
141                    }
142                }
143            }
144
145            sleep(Duration::from_secs(current_interval)).await;
146        }
147    }
148}