Skip to main content

azure_identity_helpers/device_code/
mod.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4//! Authorize using the device authorization grant flow
5//!
6//! This flow allows users to sign in to input-constrained devices such as a smart TV, `IoT` device, or printer.
7//!
8//! You can learn more about this authorization flow [here](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
9mod device_code_responses;
10
11use azure_core::{
12    error::{Error, ErrorKind},
13    http::{
14        ClientOptions, Context, Method, Pipeline, PipelineSendOptions, RawResponse, Request, Url,
15        headers::{self, content_type},
16    },
17    json::from_json,
18    sleep::sleep,
19};
20pub use device_code_responses::DeviceCodeAuthorization;
21use futures::stream::unfold;
22use serde::Deserialize;
23use std::pin::Pin;
24use time::Duration;
25use url::form_urlencoded;
26
27use crate::oauth_error::OAuthErrorResponse;
28
29/// Re-export of [`crate::oauth_error::OAuthErrorResponse`] under its
30/// previous name.
31///
32/// Use [`crate::oauth_error::OAuthErrorResponse`] directly; both refresh-token
33/// failures and device-code failures surface the same OAuth 2.0 error body
34/// (RFC 6749 §5.2), and the type now lives in its own module to make that
35/// sharing explicit.
36#[deprecated(
37    since = "0.2.0",
38    note = "use `azure_identity_helpers::oauth_error::OAuthErrorResponse` instead"
39)]
40pub type DeviceCodeErrorResponse = OAuthErrorResponse;
41
42/// Start the device authorization grant flow.
43///
44/// The user has only 15 minutes to sign in (the usual value for `expires_in`).
45///
46/// `pipeline` is the HTTP pipeline used to issue this request. The same
47/// pipeline (and the same `tenant_id` / `client_id`) must be passed to
48/// [`DeviceCodePhaseOneResponse::stream`] when polling the token endpoint
49/// afterwards. Callers running the flow from a long-lived credential should
50/// construct a single [`Pipeline`] once and reuse it to keep TLS sessions
51/// and HTTP connections pooled across the polling loop. A pipeline built
52/// with default options
53/// (`Pipeline::new(None, None, ClientOptions::default(), vec![], vec![], None)`)
54/// is sufficient unless custom retry, transport, or policy configuration is
55/// required.
56pub async fn start(
57    pipeline: &Pipeline,
58    tenant_id: &str,
59    client_id: &str,
60    scopes: &[&str],
61) -> azure_core::Result<DeviceCodePhaseOneResponse> {
62    let url = &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/devicecode");
63
64    let encoded = form_urlencoded::Serializer::new(String::new())
65        .append_pair("client_id", client_id)
66        .append_pair("scope", &scopes.join(" "))
67        .finish();
68
69    let rsp = post_form(pipeline, url, encoded).await?;
70    let rsp_status = rsp.status();
71    if !rsp_status.is_success() {
72        let rsp_body = rsp.into_body().into_string()?;
73        // The device-code endpoint returns a structured error body that
74        // matches `OAuthErrorResponse`. Wrap that as the source of
75        // the returned error so callers (and `format_aggregate_error`)
76        // see both the phase-one context (endpoint + status) and the
77        // AAD error/description/uri; fall back to embedding the raw
78        // body only when the response doesn't parse as the expected
79        // shape.
80        return Err(from_json::<_, OAuthErrorResponse>(&rsp_body).map_or_else(
81            |_| {
82                Error::with_message(
83                    ErrorKind::Credential,
84                    format!("device code endpoint returned status {rsp_status}: {rsp_body}"),
85                )
86            },
87            |parsed| {
88                Error::with_error(
89                    ErrorKind::Credential,
90                    parsed,
91                    format!("device code endpoint returned status {rsp_status}"),
92                )
93            },
94        ));
95    }
96    rsp.into_body().json()
97}
98
99/// Contains the required information to allow a user to sign in.
100///
101/// The struct mirrors only the JSON fields the credential needs downstream
102/// (the device code, polling interval, and pre-formatted user message);
103/// other fields returned by the AAD device-code endpoint are ignored. The
104/// HTTP pipeline, tenant id, and client id are passed back in to
105/// [`Self::stream`] when polling rather than stored on the struct, which
106/// keeps this type's serde layout faithful to the wire format.
107#[derive(Debug, Clone, Deserialize)]
108pub struct DeviceCodePhaseOneResponse {
109    device_code: String,
110    #[serde(deserialize_with = "deserialize_polling_interval")]
111    interval: u64,
112    message: String,
113}
114
115/// Deserialize the polling interval, rejecting values outside `0..=i64::MAX`
116/// and clamping to a minimum of one second.
117///
118/// The wire format reports `interval` as a JSON number; pulling it through
119/// `i64` rejects both negative intervals (semantically nonsensical for a
120/// sleep duration) and values too large to ever pass to
121/// `time::Duration::seconds`. Returning an error here surfaces server
122/// misbehavior at the deserialize boundary rather than silently clamping
123/// inside the polling loop.
124///
125/// `interval = 0` is technically in-range but would turn the polling loop
126/// into a tight busy-loop hammering the AAD token endpoint, so it is
127/// clamped to one second. RFC 8628 §3.5 expects a server-supplied interval
128/// suitable for polling.
129fn deserialize_polling_interval<'de, D>(deserializer: D) -> Result<u64, D::Error>
130where
131    D: serde::Deserializer<'de>,
132{
133    let secs = i64::deserialize(deserializer)?;
134    let interval = u64::try_from(secs).map_err(|_| {
135        serde::de::Error::custom(format!(
136            "device code polling interval must be non-negative, got {secs}"
137        ))
138    })?;
139    Ok(interval.max(1))
140}
141
142pub(crate) fn default_pipeline() -> Pipeline {
143    Pipeline::new(None, None, ClientOptions::default(), vec![], vec![], None)
144}
145
146impl DeviceCodePhaseOneResponse {
147    /// The message containing human readable instructions for the user.
148    #[must_use]
149    pub fn message(&self) -> &str {
150        &self.message
151    }
152
153    /// Polls the token endpoint while the user signs in.
154    ///
155    /// `pipeline`, `tenant_id`, and `client_id` must match what was passed
156    /// to [`start`].
157    ///
158    /// This will continue until either success or a terminal error is
159    /// returned. Per [RFC 8628 §3.5][rfc] the `authorization_pending` and
160    /// `slow_down` server errors keep the poll loop alive; `slow_down`
161    /// additionally requires the client to extend its polling interval by
162    /// 5 seconds.
163    ///
164    /// [rfc]: https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
165    #[must_use]
166    pub fn stream<'a>(
167        &'a self,
168        pipeline: &'a Pipeline,
169        tenant_id: &'a str,
170        client_id: &'a str,
171    ) -> Pin<Box<impl futures::Stream<Item = azure_core::Result<DeviceCodeAuthorization>> + 'a>>
172    {
173        #[derive(Debug, Clone, PartialEq, Eq)]
174        enum NextState {
175            /// Keep polling, sleeping `interval` seconds first.
176            Continue {
177                interval: u64,
178            },
179            Finish,
180        }
181
182        Box::pin(unfold(
183            NextState::Continue {
184                interval: self.interval,
185            },
186            move |state: NextState| async move {
187                let NextState::Continue { interval } = state else {
188                    return None;
189                };
190
191                let url =
192                    &format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token");
193
194                // Throttle as specified by Azure. `slow_down` responses bump
195                // this by 5 seconds for the next iteration (see below).
196                // `time::Duration::seconds` takes an i64; clamp to i64::MAX
197                // for the (effectively impossible) overflow case.
198                let secs = i64::try_from(interval).unwrap_or(i64::MAX);
199                sleep(Duration::seconds(secs)).await;
200
201                let encoded = form_urlencoded::Serializer::new(String::new())
202                    .append_pair("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
203                    .append_pair("client_id", client_id)
204                    .append_pair("device_code", &self.device_code)
205                    .finish();
206
207                match post_form(pipeline, url, encoded).await {
208                    Ok(rsp) => {
209                        let rsp_status = rsp.status();
210                        let rsp_body = match rsp.into_body().into_string() {
211                            Ok(b) => b,
212                            Err(e) => return Some((Err(e), NextState::Finish)),
213                        };
214                        if rsp_status.is_success() {
215                            match from_json::<_, DeviceCodeAuthorization>(&rsp_body) {
216                                Ok(authorization) => Some((Ok(authorization), NextState::Finish)),
217                                Err(error) => Some((Err(error), NextState::Finish)),
218                            }
219                        } else {
220                            from_json::<_, OAuthErrorResponse>(&rsp_body).map_or_else(
221                                |_| {
222                                    Some((
223                                        Err(Error::with_message(
224                                            ErrorKind::Credential,
225                                            format!(
226                                                "device code token endpoint returned status {rsp_status}: {rsp_body}"
227                                            ),
228                                        )),
229                                        NextState::Finish,
230                                    ))
231                                },
232                                |error_rsp| {
233                                    let next_state = match error_rsp.error() {
234                                        "authorization_pending" => NextState::Continue { interval },
235                                        // Per RFC 8628 §3.5 the client must
236                                        // extend its polling interval by 5s.
237                                        "slow_down" => NextState::Continue {
238                                            interval: interval.saturating_add(5),
239                                        },
240                                        _ => NextState::Finish,
241                                    };
242                                    Some((
243                                        Err(Error::new(ErrorKind::Credential, error_rsp)),
244                                        next_state,
245                                    ))
246                                },
247                            )
248                        }
249                    }
250                    Err(error) => Some((Err(error), NextState::Finish)),
251                }
252            },
253        ))
254    }
255}
256
257async fn post_form(
258    pipeline: &Pipeline,
259    url: &str,
260    form_body: String,
261) -> azure_core::Result<RawResponse> {
262    let url = Url::parse(url)?;
263    let mut req = Request::new(url, Method::Post);
264    req.insert_header(
265        headers::CONTENT_TYPE,
266        content_type::APPLICATION_X_WWW_FORM_URLENCODED,
267    );
268    req.set_body(form_body);
269
270    // The device code token endpoint signals normal flow states (notably
271    // `authorization_pending` and `slow_down`) via 4xx responses with a
272    // structured body. Skip the pipeline's automatic success check so that
273    // we can inspect those bodies ourselves instead of having them turned
274    // into opaque transport errors.
275    pipeline
276        .send(
277            &Context::new(),
278            &mut req,
279            Some(PipelineSendOptions {
280                skip_checks: true,
281                ..PipelineSendOptions::default()
282            }),
283        )
284        .await
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when its argument is `Send`; used as a static assertion.
    fn require_send<T: Send>(_value: T) {}

    #[test]
    fn ensure_that_start_is_send() {
        let pipeline = default_pipeline();
        let pending = start(&pipeline, "UNUSED", "UNUSED", &[]);
        require_send(pending);
    }

    #[test]
    fn interval_deserializes_when_in_range() -> azure_core::Result<()> {
        let body = r#"{
            "device_code": "dc",
            "interval": 5,
            "message": "go enter the code"
        }"#;
        let DeviceCodePhaseOneResponse { interval, .. } =
            azure_core::json::from_json::<_, DeviceCodePhaseOneResponse>(body)?;
        assert_eq!(interval, 5);
        Ok(())
    }

    #[test]
    fn interval_rejects_negative_values() {
        let body = r#"{
            "device_code": "dc",
            "interval": -1,
            "message": "go enter the code"
        }"#;
        let outcome: azure_core::Result<DeviceCodePhaseOneResponse> =
            azure_core::json::from_json(body);
        assert!(
            outcome.is_err(),
            "negative interval must be rejected at deserialize time",
        );
    }

    #[test]
    fn interval_rejects_values_larger_than_i64_max() {
        // i64::MAX = 9_223_372_036_854_775_807; add 1 to overflow.
        let body = r#"{
            "device_code": "dc",
            "interval": 9223372036854775808,
            "message": "go enter the code"
        }"#;
        let outcome: azure_core::Result<DeviceCodePhaseOneResponse> =
            azure_core::json::from_json(body);
        assert!(
            outcome.is_err(),
            "interval larger than i64::MAX must be rejected at deserialize time",
        );
    }

    #[test]
    fn interval_clamps_zero_to_one_second() -> azure_core::Result<()> {
        // A server-supplied interval of 0 would turn the polling loop into a
        // tight loop on the AAD token endpoint; clamp to a one-second minimum.
        let body = r#"{
            "device_code": "dc",
            "interval": 0,
            "message": "go enter the code"
        }"#;
        let DeviceCodePhaseOneResponse { interval, .. } =
            azure_core::json::from_json::<_, DeviceCodePhaseOneResponse>(body)?;
        assert_eq!(interval, 1);
        Ok(())
    }
}