azure_identity_helpers/device_code/mod.rs
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4//! Authorize using the device authorization grant flow
5//!
6//! This flow allows users to sign in to input-constrained devices such as a smart TV, `IoT` device, or printer.
7//!
8//! You can learn more about this authorization flow [here](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
9mod device_code_responses;
10
11use azure_core::{
12 error::{Error, ErrorKind},
13 http::{
14 ClientOptions, Context, Method, Pipeline, PipelineSendOptions, RawResponse, Request, Url,
15 headers::{self, content_type},
16 },
17 json::from_json,
18 sleep::sleep,
19};
20pub use device_code_responses::DeviceCodeAuthorization;
21use futures::stream::unfold;
22use serde::Deserialize;
23use std::pin::Pin;
24use time::Duration;
25use url::form_urlencoded;
26
27use crate::oauth_error::OAuthErrorResponse;
28
29/// Re-export of [`crate::oauth_error::OAuthErrorResponse`] under its
30/// previous name.
31///
32/// Use [`crate::oauth_error::OAuthErrorResponse`] directly; both refresh-token
33/// failures and device-code failures surface the same OAuth 2.0 error body
34/// (RFC 6749 §5.2), and the type now lives in its own module to make that
35/// sharing explicit.
36#[deprecated(
37 since = "0.2.0",
38 note = "use `azure_identity_helpers::oauth_error::OAuthErrorResponse` instead"
39)]
40pub type DeviceCodeErrorResponse = OAuthErrorResponse;
41
42/// Start the device authorization grant flow.
43///
44/// The user has only 15 minutes to sign in (the usual value for `expires_in`).
45///
46/// `pipeline` is the HTTP pipeline used to issue this request. The same
47/// pipeline (and the same `tenant_id` / `client_id`) must be passed to
48/// [`DeviceCodePhaseOneResponse::stream`] when polling the token endpoint
49/// afterwards. Callers running the flow from a long-lived credential should
50/// construct a single [`Pipeline`] once and reuse it to keep TLS sessions
51/// and HTTP connections pooled across the polling loop. A pipeline built
52/// with default options
53/// (`Pipeline::new(None, None, ClientOptions::default(), vec![], vec![], None)`)
54/// is sufficient unless custom retry, transport, or policy configuration is
55/// required.
56pub async fn start(
57 pipeline: &Pipeline,
58 tenant_id: &str,
59 client_id: &str,
60 scopes: &[&str],
61) -> azure_core::Result<DeviceCodePhaseOneResponse> {
62 let url = &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/devicecode");
63
64 let encoded = form_urlencoded::Serializer::new(String::new())
65 .append_pair("client_id", client_id)
66 .append_pair("scope", &scopes.join(" "))
67 .finish();
68
69 let rsp = post_form(pipeline, url, encoded).await?;
70 let rsp_status = rsp.status();
71 if !rsp_status.is_success() {
72 let rsp_body = rsp.into_body().into_string()?;
73 // The device-code endpoint returns a structured error body that
74 // matches `OAuthErrorResponse`. Wrap that as the source of
75 // the returned error so callers (and `format_aggregate_error`)
76 // see both the phase-one context (endpoint + status) and the
77 // AAD error/description/uri; fall back to embedding the raw
78 // body only when the response doesn't parse as the expected
79 // shape.
80 return Err(from_json::<_, OAuthErrorResponse>(&rsp_body).map_or_else(
81 |_| {
82 Error::with_message(
83 ErrorKind::Credential,
84 format!("device code endpoint returned status {rsp_status}: {rsp_body}"),
85 )
86 },
87 |parsed| {
88 Error::with_error(
89 ErrorKind::Credential,
90 parsed,
91 format!("device code endpoint returned status {rsp_status}"),
92 )
93 },
94 ));
95 }
96 rsp.into_body().json()
97}
98
99/// Contains the required information to allow a user to sign in.
100///
101/// The struct mirrors only the JSON fields the credential needs downstream
102/// (the device code, polling interval, and pre-formatted user message);
103/// other fields returned by the AAD device-code endpoint are ignored. The
104/// HTTP pipeline, tenant id, and client id are passed back in to
105/// [`Self::stream`] when polling rather than stored on the struct, which
106/// keeps this type's serde layout faithful to the wire format.
107#[derive(Debug, Clone, Deserialize)]
108pub struct DeviceCodePhaseOneResponse {
109 device_code: String,
110 #[serde(deserialize_with = "deserialize_polling_interval")]
111 interval: u64,
112 message: String,
113}
114
115/// Deserialize the polling interval, rejecting values outside `0..=i64::MAX`
116/// and clamping to a minimum of one second.
117///
118/// The wire format reports `interval` as a JSON number; pulling it through
119/// `i64` rejects both negative intervals (semantically nonsensical for a
120/// sleep duration) and values too large to ever pass to
121/// `time::Duration::seconds`. Returning an error here surfaces server
122/// misbehavior at the deserialize boundary rather than silently clamping
123/// inside the polling loop.
124///
125/// `interval = 0` is technically in-range but would turn the polling loop
126/// into a tight busy-loop hammering the AAD token endpoint, so it is
127/// clamped to one second. RFC 8628 §3.5 expects a server-supplied interval
128/// suitable for polling.
129fn deserialize_polling_interval<'de, D>(deserializer: D) -> Result<u64, D::Error>
130where
131 D: serde::Deserializer<'de>,
132{
133 let secs = i64::deserialize(deserializer)?;
134 let interval = u64::try_from(secs).map_err(|_| {
135 serde::de::Error::custom(format!(
136 "device code polling interval must be non-negative, got {secs}"
137 ))
138 })?;
139 Ok(interval.max(1))
140}
141
142pub(crate) fn default_pipeline() -> Pipeline {
143 Pipeline::new(None, None, ClientOptions::default(), vec![], vec![], None)
144}
145
146impl DeviceCodePhaseOneResponse {
147 /// The message containing human readable instructions for the user.
148 #[must_use]
149 pub fn message(&self) -> &str {
150 &self.message
151 }
152
153 /// Polls the token endpoint while the user signs in.
154 ///
155 /// `pipeline`, `tenant_id`, and `client_id` must match what was passed
156 /// to [`start`].
157 ///
158 /// This will continue until either success or a terminal error is
159 /// returned. Per [RFC 8628 §3.5][rfc] the `authorization_pending` and
160 /// `slow_down` server errors keep the poll loop alive; `slow_down`
161 /// additionally requires the client to extend its polling interval by
162 /// 5 seconds.
163 ///
164 /// [rfc]: https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
165 #[must_use]
166 pub fn stream<'a>(
167 &'a self,
168 pipeline: &'a Pipeline,
169 tenant_id: &'a str,
170 client_id: &'a str,
171 ) -> Pin<Box<impl futures::Stream<Item = azure_core::Result<DeviceCodeAuthorization>> + 'a>>
172 {
173 #[derive(Debug, Clone, PartialEq, Eq)]
174 enum NextState {
175 /// Keep polling, sleeping `interval` seconds first.
176 Continue {
177 interval: u64,
178 },
179 Finish,
180 }
181
182 Box::pin(unfold(
183 NextState::Continue {
184 interval: self.interval,
185 },
186 move |state: NextState| async move {
187 let NextState::Continue { interval } = state else {
188 return None;
189 };
190
191 let url =
192 &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token");
193
194 // Throttle as specified by Azure. `slow_down` responses bump
195 // this by 5 seconds for the next iteration (see below).
196 // `time::Duration::seconds` takes an i64; clamp to i64::MAX
197 // for the (effectively impossible) overflow case.
198 let secs = i64::try_from(interval).unwrap_or(i64::MAX);
199 sleep(Duration::seconds(secs)).await;
200
201 let encoded = form_urlencoded::Serializer::new(String::new())
202 .append_pair("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
203 .append_pair("client_id", client_id)
204 .append_pair("device_code", &self.device_code)
205 .finish();
206
207 match post_form(pipeline, url, encoded).await {
208 Ok(rsp) => {
209 let rsp_status = rsp.status();
210 let rsp_body = match rsp.into_body().into_string() {
211 Ok(b) => b,
212 Err(e) => return Some((Err(e), NextState::Finish)),
213 };
214 if rsp_status.is_success() {
215 match from_json::<_, DeviceCodeAuthorization>(&rsp_body) {
216 Ok(authorization) => Some((Ok(authorization), NextState::Finish)),
217 Err(error) => Some((Err(error), NextState::Finish)),
218 }
219 } else {
220 from_json::<_, OAuthErrorResponse>(&rsp_body).map_or_else(
221 |_| {
222 Some((
223 Err(Error::with_message(
224 ErrorKind::Credential,
225 format!(
226 "device code token endpoint returned status {rsp_status}: {rsp_body}"
227 ),
228 )),
229 NextState::Finish,
230 ))
231 },
232 |error_rsp| {
233 let next_state = match error_rsp.error() {
234 "authorization_pending" => NextState::Continue { interval },
235 // Per RFC 8628 §3.5 the client must
236 // extend its polling interval by 5s.
237 "slow_down" => NextState::Continue {
238 interval: interval.saturating_add(5),
239 },
240 _ => NextState::Finish,
241 };
242 Some((
243 Err(Error::new(ErrorKind::Credential, error_rsp)),
244 next_state,
245 ))
246 },
247 )
248 }
249 }
250 Err(error) => Some((Err(error), NextState::Finish)),
251 }
252 },
253 ))
254 }
255}
256
257async fn post_form(
258 pipeline: &Pipeline,
259 url: &str,
260 form_body: String,
261) -> azure_core::Result<RawResponse> {
262 let url = Url::parse(url)?;
263 let mut req = Request::new(url, Method::Post);
264 req.insert_header(
265 headers::CONTENT_TYPE,
266 content_type::APPLICATION_X_WWW_FORM_URLENCODED,
267 );
268 req.set_body(form_body);
269
270 // The device code token endpoint signals normal flow states (notably
271 // `authorization_pending` and `slow_down`) via 4xx responses with a
272 // structured body. Skip the pipeline's automatic success check so that
273 // we can inspect those bodies ourselves instead of having them turned
274 // into opaque transport errors.
275 pipeline
276 .send(
277 &Context::new(),
278 &mut req,
279 Some(PipelineSendOptions {
280 skip_checks: true,
281 ..PipelineSendOptions::default()
282 }),
283 )
284 .await
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is `Send`; the value itself is discarded.
    fn require_send<T: Send>(_value: T) {}

    /// Renders a phase-one response body whose `interval` field is the raw
    /// JSON text `interval`.
    fn phase_one_body(interval: &str) -> String {
        format!(
            r#"{{
                "device_code": "dc",
                "interval": {interval},
                "message": "go enter the code"
            }}"#
        )
    }

    /// Deserializes a phase-one response with the same JSON helper the
    /// production code uses.
    fn parse_phase_one(body: &str) -> azure_core::Result<DeviceCodePhaseOneResponse> {
        azure_core::json::from_json::<&str, DeviceCodePhaseOneResponse>(body)
    }

    #[test]
    fn ensure_that_start_is_send() {
        let pipeline = default_pipeline();
        let future = start(&pipeline, "UNUSED", "UNUSED", &[]);
        require_send(future);
    }

    #[test]
    fn interval_deserializes_when_in_range() -> azure_core::Result<()> {
        let parsed = parse_phase_one(&phase_one_body("5"))?;
        assert_eq!(parsed.interval, 5);
        Ok(())
    }

    #[test]
    fn interval_rejects_negative_values() {
        let result = parse_phase_one(&phase_one_body("-1"));
        assert!(
            result.is_err(),
            "negative interval must be rejected at deserialize time",
        );
    }

    #[test]
    fn interval_rejects_values_larger_than_i64_max() {
        // One past i64::MAX (9_223_372_036_854_775_807). Widening to i128
        // keeps the `+ 1` from overflowing.
        let too_large = (i128::from(i64::MAX) + 1).to_string();
        let result = parse_phase_one(&phase_one_body(&too_large));
        assert!(
            result.is_err(),
            "interval larger than i64::MAX must be rejected at deserialize time",
        );
    }

    #[test]
    fn interval_clamps_zero_to_one_second() -> azure_core::Result<()> {
        // A server-supplied interval of 0 would turn polling into a tight
        // loop on the AAD token endpoint; it must be clamped to one second.
        let parsed = parse_phase_one(&phase_one_body("0"))?;
        assert_eq!(parsed.interval, 1);
        Ok(())
    }
}