use crate::{
    config::DeviceFlowConfig,
    error::{DeviceFlowError, Result},
    provider::Provider,
    token::TokenManager,
    types::{
        AuthorizationResponse, DeviceAuthorizationRequest, DeviceTokenRequest, ErrorResponse,
        TokenResponse,
    },
};
use reqwest::{Client, RequestBuilder};
use secrecy::ExposeSecret;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::time::sleep;
use url::Url;
19
20#[derive(Debug, Clone)]
22pub struct DeviceFlow {
23 provider: Provider,
25
26 config: DeviceFlowConfig,
28
29 client: Client,
31
32 auth_response: Option<AuthorizationResponse>,
34}
35
36impl DeviceFlow {
37 pub fn new(provider: Provider, config: DeviceFlowConfig) -> Result<Self> {
39 config.validate(provider)?;
41
42 let client = Self::build_client(&config)?;
43
44 Ok(Self {
45 provider,
46 config,
47 client,
48 auth_response: None,
49 })
50 }
51
52 pub async fn initialize(&mut self) -> Result<&AuthorizationResponse> {
54 let auth_endpoint = if let Some(ref config) = self.config.generic_provider_config {
55 config.device_authorization_endpoint.clone()
56 } else {
57 Url::parse(self.provider.device_authorization_endpoint()).map_err(|e| {
58 DeviceFlowError::other(format!("Invalid authorization endpoint: {e}"))
59 })?
60 };
61
62 let scopes = self.config.effective_scopes(self.provider);
63 let scope_string = if scopes.is_empty() {
64 None
65 } else {
66 Some(scopes.join(" "))
67 };
68
69 let request = DeviceAuthorizationRequest {
70 client_id: self.config.client_id.clone(),
71 scope: scope_string,
72 };
73
74 let mut req_builder = self.client.post(auth_endpoint).form(&request);
75
76 if let Some(ref client_secret) = self.config.client_secret {
78 req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
79 }
80
81 for (key, value) in self.provider.headers() {
83 req_builder = req_builder.header(key, value);
84 }
85
86 for (key, value) in &self.config.additional_headers {
88 req_builder = req_builder.header(key, value);
89 }
90
91 let response = req_builder.send().await?;
92
93 if !response.status().is_success() {
94 let error_response: ErrorResponse = response.json().await?;
95 return Err(DeviceFlowError::oauth_error(
96 error_response.error,
97 error_response.error_description.unwrap_or_default(),
98 ));
99 }
100
101 let auth_response: AuthorizationResponse = response.json().await?;
102 self.auth_response = Some(auth_response);
103
104 Ok(self.auth_response.as_ref().unwrap())
105 }
106
107 pub async fn poll_for_token(&self) -> Result<TokenResponse> {
109 let auth_response = self
110 .auth_response
111 .as_ref()
112 .ok_or_else(|| DeviceFlowError::other("Must call initialize() first"))?;
113
114 let token_endpoint = if let Some(ref config) = self.config.generic_provider_config {
115 config.token_endpoint.clone()
116 } else {
117 Url::parse(self.provider.token_endpoint())
118 .map_err(|e| DeviceFlowError::other(format!("Invalid token endpoint: {e}")))?
119 };
120
121 let request = DeviceTokenRequest {
122 grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
123 device_code: auth_response.device_code().to_string(),
124 client_id: self.config.client_id.clone(),
125 };
126
127 let mut poll_interval = self.config.effective_poll_interval(self.provider);
128 let mut attempts = 0;
129
130 loop {
131 if attempts >= self.config.max_attempts {
132 return Err(DeviceFlowError::MaxAttemptsExceeded(
133 self.config.max_attempts,
134 ));
135 }
136
137 attempts += 1;
138
139 if attempts > 1 {
141 sleep(poll_interval).await;
142 }
143
144 let mut req_builder = self.client.post(token_endpoint.clone()).form(&request);
145
146 if let Some(ref client_secret) = self.config.client_secret {
148 req_builder = req_builder.form(&[("client_secret", client_secret.expose_secret())]);
149 }
150
151 for (key, value) in self.provider.headers() {
153 req_builder = req_builder.header(key, value);
154 }
155
156 for (key, value) in &self.config.additional_headers {
158 req_builder = req_builder.header(key, value);
159 }
160
161 let response = req_builder.send().await?;
162
163 if response.status().is_success() {
164 let mut token_response: TokenResponse = response.json().await?;
165
166 token_response.issued_at = OffsetDateTime::now_utc();
168
169 return Ok(token_response);
170 }
171
172 let error_response: ErrorResponse = response.json().await?;
174
175 match error_response.error.as_str() {
176 "authorization_pending" => {
177 continue;
179 }
180 "slow_down" => {
181 poll_interval = Duration::from_secs(
183 (poll_interval.as_secs() as f64 * self.config.backoff_multiplier) as u64,
184 )
185 .min(self.config.max_poll_interval);
186 continue;
187 }
188 "access_denied" => {
189 return Err(DeviceFlowError::AuthorizationDenied);
190 }
191 "expired_token" => {
192 return Err(DeviceFlowError::ExpiredToken);
193 }
194 _ => {
195 return Err(DeviceFlowError::oauth_error(
196 error_response.error,
197 error_response.error_description.unwrap_or_default(),
198 ));
199 }
200 }
201 }
202 }
203
204 pub async fn run(&mut self) -> Result<TokenManager> {
206 let _auth_response = self.initialize().await?;
207 let token_response = self.poll_for_token().await?;
208
209 TokenManager::new(token_response, self.provider, self.config.clone())
210 }
211
212 pub fn authorization_response(&self) -> Option<&AuthorizationResponse> {
214 self.auth_response.as_ref()
215 }
216
217 pub fn provider(&self) -> Provider {
219 self.provider
220 }
221
222 pub fn config(&self) -> &DeviceFlowConfig {
224 &self.config
225 }
226
227 pub fn is_initialized(&self) -> bool {
229 self.auth_response.is_some()
230 }
231
232 pub fn reset(&mut self) {
234 self.auth_response = None;
235 }
236
237 pub fn with_provider(self, provider: Provider) -> Result<Self> {
239 Self::new(provider, self.config)
240 }
241
242 pub fn with_config(mut self, config: DeviceFlowConfig) -> Result<Self> {
244 config.validate(self.provider)?;
245 self.client = Self::build_client(&config)?;
246 self.config = config;
247 Ok(self)
248 }
249
250 fn build_client(config: &DeviceFlowConfig) -> Result<Client> {
252 let mut client_builder = Client::builder().timeout(config.request_timeout);
253
254 if let Some(ref user_agent) = config.user_agent {
255 client_builder = client_builder.user_agent(user_agent);
256 }
257
258 client_builder.build().map_err(DeviceFlowError::from)
259 }
260}
261
262pub async fn run_device_flow(provider: Provider, config: DeviceFlowConfig) -> Result<TokenManager> {
264 let mut device_flow = DeviceFlow::new(provider, config)?;
265 device_flow.run().await
266}
267
268pub async fn run_device_flow_with_callback<F>(
270 provider: Provider,
271 config: DeviceFlowConfig,
272 callback: F,
273) -> Result<TokenManager>
274where
275 F: FnOnce(&AuthorizationResponse) -> Result<()>,
276{
277 let mut device_flow = DeviceFlow::new(provider, config)?;
278 let auth_response = device_flow.initialize().await?;
279
280 callback(auth_response)?;
282
283 let token_response = device_flow.poll_for_token().await?;
284 TokenManager::new(token_response, provider, device_flow.config)
285}