oauth_device_flows/
device_flow.rs

1//! Main device flow implementation
2
3use crate::{
4    config::DeviceFlowConfig,
5    error::{DeviceFlowError, Result},
6    provider::Provider,
7    token::TokenManager,
8    types::{
9        AuthorizationResponse, DeviceAuthorizationRequest, DeviceTokenRequest, ErrorResponse,
10        TokenResponse,
11    },
12};
13use reqwest::Client;
14use secrecy::ExposeSecret;
15use std::time::Duration;
16use time::OffsetDateTime;
17use tokio::time::sleep;
18use url::Url;
19
20/// Main device flow implementation
21#[derive(Debug, Clone)]
22pub struct DeviceFlow {
23    /// OAuth provider
24    provider: Provider,
25
26    /// Configuration
27    config: DeviceFlowConfig,
28
29    /// HTTP client
30    client: Client,
31
32    /// Current authorization response (if any)
33    auth_response: Option<AuthorizationResponse>,
34}
35
36impl DeviceFlow {
37    /// Create a new device flow instance
38    pub fn new(provider: Provider, config: DeviceFlowConfig) -> Result<Self> {
39        // Validate configuration
40        config.validate(provider)?;
41
42        let client = Self::build_client(&config)?;
43
44        Ok(Self {
45            provider,
46            config,
47            client,
48            auth_response: None,
49        })
50    }
51
52    /// Initialize the device authorization flow
53    pub async fn initialize(&mut self) -> Result<&AuthorizationResponse> {
54        let auth_endpoint = if let Some(ref config) = self.config.generic_provider_config {
55            config.device_authorization_endpoint.clone()
56        } else {
57            Url::parse(self.provider.device_authorization_endpoint()).map_err(|e| {
58                DeviceFlowError::other(format!("Invalid authorization endpoint: {e}"))
59            })?
60        };
61
62        let scopes = self.config.effective_scopes(self.provider);
63        let scope_string = if scopes.is_empty() {
64            None
65        } else {
66            Some(scopes.join(" "))
67        };
68
69        let request = DeviceAuthorizationRequest {
70            client_id: self.config.client_id.clone(),
71            scope: scope_string,
72        };
73
74        let mut req_builder = self.client.post(auth_endpoint).form(&request);
75
76        // Add client secret if required
77        if let Some(ref client_secret) = self.config.client_secret {
78            req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
79        }
80
81        // Add provider-specific headers
82        for (key, value) in self.provider.headers() {
83            req_builder = req_builder.header(key, value);
84        }
85
86        // Add additional headers
87        for (key, value) in &self.config.additional_headers {
88            req_builder = req_builder.header(key, value);
89        }
90
91        let response = req_builder.send().await?;
92
93        if !response.status().is_success() {
94            let error_response: ErrorResponse = response.json().await?;
95            return Err(DeviceFlowError::oauth_error(
96                error_response.error,
97                error_response.error_description.unwrap_or_default(),
98            ));
99        }
100
101        let auth_response: AuthorizationResponse = response.json().await?;
102        self.auth_response = Some(auth_response);
103
104        Ok(self.auth_response.as_ref().unwrap())
105    }
106
107    /// Poll for the token
108    pub async fn poll_for_token(&self) -> Result<TokenResponse> {
109        let auth_response = self
110            .auth_response
111            .as_ref()
112            .ok_or_else(|| DeviceFlowError::other("Must call initialize() first"))?;
113
114        let token_endpoint = if let Some(ref config) = self.config.generic_provider_config {
115            config.token_endpoint.clone()
116        } else {
117            Url::parse(self.provider.token_endpoint())
118                .map_err(|e| DeviceFlowError::other(format!("Invalid token endpoint: {e}")))?
119        };
120
121        let request = DeviceTokenRequest {
122            grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
123            device_code: auth_response.device_code().to_string(),
124            client_id: self.config.client_id.clone(),
125        };
126
127        let mut poll_interval = self.config.effective_poll_interval(self.provider);
128        let mut attempts = 0;
129
130        loop {
131            if attempts >= self.config.max_attempts {
132                return Err(DeviceFlowError::MaxAttemptsExceeded(
133                    self.config.max_attempts,
134                ));
135            }
136
137            attempts += 1;
138
139            // Wait before making the request (except for the first attempt)
140            if attempts > 1 {
141                sleep(poll_interval).await;
142            }
143
144            let mut req_builder = self.client.post(token_endpoint.clone()).form(&request);
145
146            // Add client secret if required
147            if let Some(ref client_secret) = self.config.client_secret {
148                req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
149            }
150
151            // Add provider-specific headers
152            for (key, value) in self.provider.headers() {
153                req_builder = req_builder.header(key, value);
154            }
155
156            // Add additional headers
157            for (key, value) in &self.config.additional_headers {
158                req_builder = req_builder.header(key, value);
159            }
160
161            let response = req_builder.send().await?;
162
163            if response.status().is_success() {
164                let mut token_response: TokenResponse = response.json().await?;
165
166                // Set the issued_at timestamp
167                token_response.issued_at = OffsetDateTime::now_utc();
168
169                return Ok(token_response);
170            }
171
172            // Handle error responses
173            let error_response: ErrorResponse = response.json().await?;
174
175            match error_response.error.as_str() {
176                "authorization_pending" => {
177                    // Continue polling
178                    continue;
179                }
180                "slow_down" => {
181                    // Increase polling interval
182                    poll_interval = Duration::from_secs(
183                        (poll_interval.as_secs() as f64 * self.config.backoff_multiplier) as u64,
184                    )
185                    .min(self.config.max_poll_interval);
186                    continue;
187                }
188                "access_denied" => {
189                    return Err(DeviceFlowError::AuthorizationDenied);
190                }
191                "expired_token" => {
192                    return Err(DeviceFlowError::ExpiredToken);
193                }
194                _ => {
195                    return Err(DeviceFlowError::oauth_error(
196                        error_response.error,
197                        error_response.error_description.unwrap_or_default(),
198                    ));
199                }
200            }
201        }
202    }
203
204    /// Run the complete device flow and return a token manager
205    pub async fn run(&mut self) -> Result<TokenManager> {
206        let _auth_response = self.initialize().await?;
207        let token_response = self.poll_for_token().await?;
208
209        TokenManager::new(token_response, self.provider, self.config.clone())
210    }
211
212    /// Get the current authorization response
213    pub fn authorization_response(&self) -> Option<&AuthorizationResponse> {
214        self.auth_response.as_ref()
215    }
216
217    /// Get the provider
218    pub fn provider(&self) -> Provider {
219        self.provider
220    }
221
222    /// Get the configuration
223    pub fn config(&self) -> &DeviceFlowConfig {
224        &self.config
225    }
226
227    /// Check if the device flow has been initialized
228    pub fn is_initialized(&self) -> bool {
229        self.auth_response.is_some()
230    }
231
232    /// Reset the device flow (clear authorization response)
233    pub fn reset(&mut self) {
234        self.auth_response = None;
235    }
236
237    /// Create a new device flow for a different provider with the same config
238    pub fn with_provider(self, provider: Provider) -> Result<Self> {
239        Self::new(provider, self.config)
240    }
241
242    /// Update the configuration
243    pub fn with_config(mut self, config: DeviceFlowConfig) -> Result<Self> {
244        config.validate(self.provider)?;
245        self.client = Self::build_client(&config)?;
246        self.config = config;
247        Ok(self)
248    }
249
250    /// Build HTTP client with configuration
251    fn build_client(config: &DeviceFlowConfig) -> Result<Client> {
252        let mut client_builder = Client::builder().timeout(config.request_timeout);
253
254        if let Some(ref user_agent) = config.user_agent {
255            client_builder = client_builder.user_agent(user_agent);
256        }
257
258        client_builder.build().map_err(DeviceFlowError::from)
259    }
260}
261
262/// Convenience function to run a complete device flow
263pub async fn run_device_flow(provider: Provider, config: DeviceFlowConfig) -> Result<TokenManager> {
264    let mut device_flow = DeviceFlow::new(provider, config)?;
265    device_flow.run().await
266}
267
268/// Convenience function to run a device flow with a callback for user interaction
269pub async fn run_device_flow_with_callback<F>(
270    provider: Provider,
271    config: DeviceFlowConfig,
272    callback: F,
273) -> Result<TokenManager>
274where
275    F: FnOnce(&AuthorizationResponse) -> Result<()>,
276{
277    let mut device_flow = DeviceFlow::new(provider, config)?;
278    let auth_response = device_flow.initialize().await?;
279
280    // Call the user-provided callback with the authorization response
281    callback(auth_response)?;
282
283    let token_response = device_flow.poll_for_token().await?;
284    TokenManager::new(token_response, provider, device_flow.config)
285}