azure_identity/
managed_identity_credential.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4use crate::{
5    authentication_error, env::Env, AppServiceManagedIdentityCredential, ImdsId,
6    VirtualMachineManagedIdentityCredential,
7};
8use azure_core::credentials::{AccessToken, TokenCredential, TokenRequestOptions};
9use azure_core::http::ClientOptions;
10use std::sync::Arc;
11use tracing::info;
12
/// Selects which user-assigned identity a [`ManagedIdentityCredential`] authenticates.
///
/// Each variant carries the identifier in one of the forms Azure accepts.
#[derive(Clone, Debug)]
pub enum UserAssignedId {
    /// Identify the user-assigned identity by its client ID.
    ClientId(String),
    /// Identify the user-assigned identity by its object (principal) ID.
    ObjectId(String),
    /// Identify the user-assigned identity by its Azure resource ID.
    ResourceId(String),
}
23
24/// Authenticates a managed identity from Azure App Service or an Azure Virtual Machine.
25#[derive(Debug)]
26pub struct ManagedIdentityCredential {
27    credential: Arc<dyn TokenCredential>,
28}
29
30/// Options for constructing a new [`ManagedIdentityCredential`].
31#[derive(Clone, Debug, Default)]
32pub struct ManagedIdentityCredentialOptions {
33    /// Specifies a user-assigned identity the credential should authenticate.
34    /// When `None`, the credential will authenticate a system-assigned identity, if any.
35    pub user_assigned_id: Option<UserAssignedId>,
36
37    /// The [`ClientOptions`] to use for the credential's pipeline.
38    pub client_options: ClientOptions,
39
40    #[cfg(test)]
41    pub(crate) env: Env,
42}
43
44impl ManagedIdentityCredential {
45    /// Creates a new instance of `ManagedIdentityCredential`.
46    ///
47    /// # Arguments
48    /// * `options`: Options for configuring the credential. If `None`, the credential uses its default options.
49    ///
50    pub fn new(options: Option<ManagedIdentityCredentialOptions>) -> azure_core::Result<Arc<Self>> {
51        let options = options.unwrap_or_default();
52        #[cfg(test)]
53        let env = options.env;
54        #[cfg(not(test))]
55        let env = Env::default();
56        let source = get_source(&env);
57        let id = options
58            .user_assigned_id
59            .clone()
60            .map(Into::into)
61            .unwrap_or(ImdsId::SystemAssigned);
62
63        let credential: Arc<dyn TokenCredential> = match source {
64            ManagedIdentitySource::AppService => {
65                // App Service does accept resource IDs, however this crate's current implementation sends
66                // them in the wrong query parameter: https://github.com/Azure/azure-sdk-for-rust/issues/2407
67                if let ImdsId::MsiResId(_) = id {
68                    return Err(azure_core::Error::with_message_fn(
69                        azure_core::error::ErrorKind::Credential,
70                        || {
71                            "User-assigned resource IDs aren't supported for App Service. Use a client or object ID instead.".to_string()
72                        },
73                    ));
74                }
75                AppServiceManagedIdentityCredential::new(id, options.client_options, env)?
76            }
77            ManagedIdentitySource::Imds => {
78                VirtualMachineManagedIdentityCredential::new(id, options.client_options, env)?
79            }
80            _ => {
81                return Err(azure_core::Error::with_message_fn(
82                    azure_core::error::ErrorKind::Credential,
83                    || format!("{} managed identity isn't supported", source.as_str()),
84                ));
85            }
86        };
87
88        info!(user_assigned_id = ?options.user_assigned_id, "ManagedIdentityCredential will use {} managed identity", source.as_str());
89
90        Ok(Arc::new(Self { credential }))
91    }
92}
93
94#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
95#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
96impl TokenCredential for ManagedIdentityCredential {
97    async fn get_token(
98        &self,
99        scopes: &[&str],
100        options: Option<TokenRequestOptions<'_>>,
101    ) -> azure_core::Result<AccessToken> {
102        if scopes.len() != 1 {
103            return Err(azure_core::Error::with_message(
104                azure_core::error::ErrorKind::Credential,
105                "ManagedIdentityCredential requires exactly one scope".to_string(),
106            ));
107        }
108        self.credential
109            .get_token(scopes, options)
110            .await
111            .map_err(|err| authentication_error(stringify!(ManagedIdentityCredential), err))
112    }
113}
114
/// The managed identity hosting environments that `get_source` can detect.
#[derive(Clone, Copy, Debug)]
enum ManagedIdentitySource {
    AzureArc,
    AzureML,
    AppService,
    CloudShell,
    Imds,
    ServiceFabric,
}

impl ManagedIdentitySource {
    /// Returns this source's human-readable name, as used in log and error messages.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::AzureArc => "Azure Arc",
            Self::AzureML => "Azure ML",
            Self::AppService => "App Service",
            Self::CloudShell => "CloudShell",
            Self::Imds => "IMDS",
            Self::ServiceFabric => "Service Fabric",
        }
    }
}
137
/// Checked first by `get_source`; when set, the source is Service Fabric, App Service, Azure Arc,
/// or (if none of their companion variables are set) IMDS.
const IDENTITY_ENDPOINT: &str = "IDENTITY_ENDPOINT";
/// Together with `IDENTITY_ENDPOINT`, selects App Service (or Service Fabric with a thumbprint).
const IDENTITY_HEADER: &str = "IDENTITY_HEADER";
/// Together with `IDENTITY_ENDPOINT` and `IDENTITY_HEADER`, selects Service Fabric.
const IDENTITY_SERVER_THUMBPRINT: &str = "IDENTITY_SERVER_THUMBPRINT";
/// Together with `IDENTITY_ENDPOINT` but without `IDENTITY_HEADER`, selects Azure Arc.
const IMDS_ENDPOINT: &str = "IMDS_ENDPOINT";
/// Consulted only when `IDENTITY_ENDPOINT` is unset; selects Cloud Shell on its own.
const MSI_ENDPOINT: &str = "MSI_ENDPOINT";
/// Together with `MSI_ENDPOINT`, selects Azure ML.
const MSI_SECRET: &str = "MSI_SECRET";
144
145fn get_source(env: &Env) -> ManagedIdentitySource {
146    use ManagedIdentitySource::*;
147    if env.var(IDENTITY_ENDPOINT).is_ok() {
148        if env.var(IDENTITY_HEADER).is_ok() {
149            if env.var(IDENTITY_SERVER_THUMBPRINT).is_ok() {
150                return ServiceFabric;
151            }
152            return AppService;
153        } else if env.var(IMDS_ENDPOINT).is_ok() {
154            return AzureArc;
155        }
156    } else if env.var(MSI_ENDPOINT).is_ok() {
157        if env.var(MSI_SECRET).is_ok() {
158            return AzureML;
159        }
160        return CloudShell;
161    }
162    Imds
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        env::Env,
        tests::{LIVE_TEST_RESOURCE, LIVE_TEST_SCOPES},
    };
    use azure_core::http::{
        AsyncRawResponse, Method, RawResponse, Request, StatusCode, Transport, Url,
    };
    use azure_core::time::OffsetDateTime;
    use azure_core::Bytes;
    use azure_core::{error::ErrorKind, http::headers::Headers};
    use azure_core_test::{http::MockHttpClient, recorded};
    use futures::FutureExt;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Placeholder in mock response bodies; replaced with the expected expiry in Unix seconds.
    const EXPIRES_ON: &str = "EXPIRES_ON";

    /// Calls the `managed-identity` test endpoint of an app deployed at `authority`. Passes
    /// `storage_name` and, if given, the user-assigned `id` as query parameters, and asserts
    /// the app responds 200 OK.
    async fn run_deployed_test(
        authority: &str,
        storage_name: &str,
        id: Option<UserAssignedId>,
    ) -> azure_core::Result<()> {
        // Each variant becomes a `name=value&` query prefix; no ID contributes nothing.
        let id_param = id.map_or_else(String::new, |id| match id {
            UserAssignedId::ClientId(id) => format!("client-id={id}&"),
            UserAssignedId::ObjectId(id) => format!("object-id={id}&"),
            UserAssignedId::ResourceId(id) => format!("resource-id={id}&"),
        });
        let url = format!(
            "http://{authority}/api?test=managed-identity&{id_param}storage-name={storage_name}"
        );
        let u = Url::parse(&url).expect("invalid URL");
        let client = azure_core::http::new_http_client();
        let req = Request::new(u, Method::Get);

        let res = client.execute_request(&req).await.expect("request failed");
        let status = res.status();
        let body = res.into_body().collect_string().await?;
        assert_eq!(StatusCode::Ok, status, "Test app responded with '{body}'");

        Ok(())
    }

    /// Builds a credential for `source` (only IMDS and App Service are accepted) whose transport
    /// always answers 418 with body "is a teapot". Asserts that `get_token` fails with a
    /// `Credential` error whose message and wrapped `HttpResponse` error describe that response.
    async fn run_error_response_test(source: ManagedIdentitySource) {
        let expected_status = StatusCode::ImATeapot;
        let headers = Headers::default();
        let content: &str = "is a teapot";
        let body = Bytes::copy_from_slice(content.as_bytes());
        let expected_response =
            RawResponse::from_bytes(expected_status, headers.clone(), body.clone());
        let mock_headers = headers.clone();
        let mock_body = body.clone();
        let mock_client = MockHttpClient::new(move |_| {
            let headers = mock_headers.clone();
            let body = mock_body.clone();
            async move { Ok(AsyncRawResponse::from_bytes(expected_status, headers, body)) }.boxed()
        });
        let test_env = match source {
            ManagedIdentitySource::Imds => Env::from(&[][..]),
            ManagedIdentitySource::AppService => Env::from(
                &[
                    (
                        IDENTITY_ENDPOINT,
                        "http://localhost/metadata/identity/oauth2/token",
                    ),
                    (IDENTITY_HEADER, "secret"),
                ][..],
            ),
            other => panic!("unsupported managed identity source {other:?}"),
        };
        let options = ManagedIdentityCredentialOptions {
            client_options: ClientOptions {
                transport: Some(Transport::new(Arc::new(mock_client))),
                ..Default::default()
            },
            env: test_env,
            ..Default::default()
        };
        let credential = ManagedIdentityCredential::new(Some(options)).expect("credential");
        let err = credential
            .get_token(LIVE_TEST_SCOPES, None)
            .await
            .expect_err("expected error");
        assert!(matches!(err.kind(), ErrorKind::Credential));
        assert_eq!(
            "ManagedIdentityCredential authentication failed. The request failed: is a teapot\nTo troubleshoot, visit https://aka.ms/azsdk/rust/identity/troubleshoot#managed-id",
            err.to_string(),
        );
        match err
            .downcast_ref::<azure_core::Error>()
            .expect("returned error should wrap an azure_core::Error")
            .kind()
        {
            ErrorKind::HttpResponse {
                error_code: None,
                raw_response: Some(response),
                status,
            } => {
                assert_eq!(response.as_ref(), &expected_response);
                assert_eq!(expected_status, *status);
            }
            err => panic!("unexpected {err:?}"),
        };
    }

    /// Asserts that `get_source` detects `expected_source` from `env`. Then builds a credential
    /// from `options`, with `env` and a mock transport substituted, and requests a token four
    /// times.
    ///
    /// The mock asserts that each request matches `model_request`:
    /// * the same method;
    /// * the same query parameters, in any order;
    /// * the same URL apart from the query;
    /// * every header of `model_request` (extra headers are allowed).
    ///
    /// It replies with `response_format` after replacing `EXPIRES_ON` with an expiry one hour
    /// from now. Every token must carry that expiry and the secret `*`, and the mock must have
    /// been called exactly once.
    async fn run_supported_source_test(
        env: Env,
        options: Option<ManagedIdentityCredentialOptions>,
        expected_source: ManagedIdentitySource,
        model_request: Request,
        response_format: String,
    ) {
        let actual_source = get_source(&env);
        assert_eq!(
            std::mem::discriminant(&actual_source),
            std::mem::discriminant(&expected_source)
        );
        let token_requests = Arc::new(AtomicUsize::new(0));
        let token_requests_clone = token_requests.clone();
        let expires_on = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600;
        let mock_client = MockHttpClient::new(move |actual| {
            token_requests_clone.fetch_add(1, Ordering::SeqCst);
            let expected = model_request.clone();
            let response_format = response_format.clone();
            async move {
                assert_eq!(expected.method(), actual.method());

                let mut actual_params: Vec<_> =
                    actual.url().query_pairs().into_owned().collect();
                actual_params.sort();
                let mut expected_params: Vec<_> =
                    expected.url().query_pairs().into_owned().collect();
                expected_params.sort();
                assert_eq!(expected_params, actual_params);

                let mut actual_url = actual.url().clone();
                actual_url.set_query(None);
                let mut expected_url = expected.url().clone();
                expected_url.set_query(None);
                assert_eq!(actual_url, expected_url);

                // allow additional headers in the actual request so changing
                // the underlying client in the future won't break tests
                expected.headers().iter().for_each(|(k, v)| {
                    assert_eq!(actual.headers().get_str(k).unwrap(), v.as_str())
                });

                Ok(AsyncRawResponse::from_bytes(
                    StatusCode::Ok,
                    Headers::default(),
                    Bytes::from(response_format.replacen(
                        EXPIRES_ON,
                        &expires_on.to_string(),
                        1,
                    )),
                ))
            }
            .boxed()
        });
        let mut options = options.unwrap_or_default();
        options.env = env;
        options.client_options = ClientOptions {
            transport: Some(Transport::new(Arc::new(mock_client))),
            ..Default::default()
        };
        let cred = ManagedIdentityCredential::new(Some(options)).expect("credential");
        for _ in 0..4 {
            let token = cred.get_token(LIVE_TEST_SCOPES, None).await.expect("token");
            assert_eq!(token.expires_on.unix_timestamp(), expires_on as i64);
            assert_eq!(token.token.secret(), "*");
            assert_eq!(token_requests.load(Ordering::SeqCst), 1);
        }
    }

    /// Asserts that `get_source` detects `expected_source` from `env` and that constructing a
    /// credential in that environment fails with a `Credential` error.
    fn run_unsupported_source_test(env: Env, expected_source: ManagedIdentitySource) {
        let actual_source = get_source(&env);
        assert_eq!(
            std::mem::discriminant(&actual_source),
            std::mem::discriminant(&expected_source)
        );
        let result = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            env,
            ..Default::default()
        }));
        assert!(
            matches!(result, Err(ref e) if *e.kind() == azure_core::error::ErrorKind::Credential),
            "Expected constructor error"
        );
    }

    #[recorded::test(live)]
    async fn aci_user_assigned_live() -> azure_core::Result<()> {
        if env::var("CI_HAS_DEPLOYED_RESOURCES").is_err() {
            println!("Skipped: ACI live tests require deployed resources");
            return Ok(());
        }
        let ip = env::var("IDENTITY_ACI_IP_USER_ASSIGNED").expect("IDENTITY_ACI_IP_USER_ASSIGNED");
        let storage_name = env::var("IDENTITY_STORAGE_NAME_USER_ASSIGNED")
            .expect("IDENTITY_STORAGE_NAME_USER_ASSIGNED");
        let client_id = env::var("IDENTITY_USER_ASSIGNED_IDENTITY_CLIENT_ID")
            .expect("IDENTITY_USER_ASSIGNED_IDENTITY_CLIENT_ID");
        run_deployed_test(
            &format!("{ip}:8080"),
            &storage_name,
            Some(UserAssignedId::ClientId(client_id)),
        )
        .await?;

        Ok(())
    }

    /// Runs `run_supported_source_test` in a simulated App Service environment. Expects a GET to
    /// `IDENTITY_ENDPOINT` carrying `x-identity-header` (from `IDENTITY_HEADER`) and the query
    /// parameters for `options.user_assigned_id`.
    async fn run_app_service_test(options: Option<ManagedIdentityCredentialOptions>) {
        let endpoint = "http://localhost/metadata/identity/oauth2/token";
        let x_id_header = "x-id-header";
        let mut model = Request::new(endpoint.parse().unwrap(), Method::Get);
        model.insert_header("x-identity-header", x_id_header);
        let mut params = Vec::from([
            ("api-version", "2019-08-01"),
            ("resource", LIVE_TEST_RESOURCE),
        ]);
        if let Some(options) = options.as_ref() {
            if let Some(ref id) = options.user_assigned_id {
                match id {
                    UserAssignedId::ClientId(client_id) => {
                        params.push(("client_id", client_id));
                    }
                    UserAssignedId::ObjectId(object_id) => {
                        params.push(("object_id", object_id));
                    }
                    // Not currently exercised: the constructor rejects resource IDs for
                    // App Service (see `app_service_resource_id`).
                    UserAssignedId::ResourceId(resource_id) => {
                        params.push(("mi_res_id", resource_id));
                    }
                }
            }
        }
        model.url_mut().query_pairs_mut().extend_pairs(params);
        run_supported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, endpoint),
                    (IDENTITY_HEADER, x_id_header),
                ][..],
            ),
            options,
            ManagedIdentitySource::AppService,
            model,
            format!(
                r#"{{"access_token":"*","expires_on":"{}","resource":"{}","token_type":"Bearer"}}"#,
                EXPIRES_ON, LIVE_TEST_RESOURCE
            ),
        )
        .await;
    }

    #[tokio::test]
    async fn app_service() {
        run_app_service_test(None).await;
    }

    #[tokio::test]
    async fn app_service_client_id() {
        run_app_service_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ClientId("expected client ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    #[tokio::test]
    async fn app_service_error_response() {
        run_error_response_test(ManagedIdentitySource::AppService).await
    }

    #[tokio::test]
    async fn app_service_object_id() {
        run_app_service_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ObjectId("expected object ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    // Construction-only check (nothing is awaited), so a plain `#[test]` like the other
    // constructor-error tests below is enough.
    #[test]
    fn app_service_resource_id() {
        let result = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            env: Env::from(&[(IDENTITY_ENDPOINT, "..."), (IDENTITY_HEADER, "x-id-header")][..]),
            user_assigned_id: Some(UserAssignedId::ResourceId(
                "expected resource ID".to_string(),
            )),
            ..Default::default()
        }));
        assert!(
            matches!(result, Err(ref e) if *e.kind() == azure_core::error::ErrorKind::Credential),
            "Expected constructor error"
        );
    }

    #[test]
    fn arc() {
        run_unsupported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, "http://localhost"),
                    (IMDS_ENDPOINT, "..."),
                ][..],
            ),
            ManagedIdentitySource::AzureArc,
        );
    }

    #[test]
    fn azure_ml() {
        run_unsupported_source_test(
            Env::from(&[(MSI_ENDPOINT, "..."), (MSI_SECRET, "...")][..]),
            ManagedIdentitySource::AzureML,
        );
    }

    #[test]
    fn cloudshell() {
        run_unsupported_source_test(
            Env::from(&[(MSI_ENDPOINT, "http://localhost")][..]),
            ManagedIdentitySource::CloudShell,
        );
    }

    /// Requests a real token from IMDS when `IDENTITY_IMDS_AVAILABLE` is set (otherwise skips).
    /// Asserts the token is non-empty, its expiry is in UTC, and it has not yet expired.
    async fn run_imds_live_test(id: Option<UserAssignedId>) -> azure_core::Result<()> {
        if std::env::var("IDENTITY_IMDS_AVAILABLE").is_err() {
            println!("Skipped: IMDS isn't available");
            return Ok(());
        }

        let credential = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: id,
            ..Default::default()
        }))
        .expect("valid credential");

        let token = credential.get_token(LIVE_TEST_SCOPES, None).await?;

        assert!(!token.token.secret().is_empty());
        assert_eq!(time::UtcOffset::UTC, token.expires_on.offset());
        assert!(token.expires_on.unix_timestamp() > OffsetDateTime::now_utc().unix_timestamp());

        Ok(())
    }

    /// Runs `run_supported_source_test` with an empty environment (so IMDS is selected).
    /// Expects a GET to the IMDS token endpoint with `metadata: true` and the query parameters
    /// for `options.user_assigned_id`.
    async fn run_imds_test(options: Option<ManagedIdentityCredentialOptions>) {
        let mut model = Request::new(
            "http://169.254.169.254/metadata/identity/oauth2/token"
                .parse()
                .unwrap(),
            Method::Get,
        );
        model.insert_header("metadata", "true");

        let mut params = Vec::from([
            ("api-version", "2019-08-01"),
            ("resource", LIVE_TEST_RESOURCE),
        ]);
        if let Some(options) = options.as_ref() {
            if let Some(ref id) = options.user_assigned_id {
                match id {
                    UserAssignedId::ClientId(client_id) => {
                        params.push(("client_id", client_id));
                    }
                    UserAssignedId::ObjectId(object_id) => {
                        params.push(("object_id", object_id));
                    }
                    UserAssignedId::ResourceId(resource_id) => {
                        params.push(("msi_res_id", resource_id));
                    }
                }
            }
        }
        model.url_mut().query_pairs_mut().extend_pairs(params);

        run_supported_source_test(
            Env::from(&[][..]),
            options,
            ManagedIdentitySource::Imds,
            model,
            format!(
                r#"{{"token_type":"Bearer","expires_in":"85770","expires_on":"{}","ext_expires_in":86399,"access_token":"*","resource":"{}"}}"#,
                EXPIRES_ON, LIVE_TEST_RESOURCE
            ),
        )
        .await;
    }

    #[tokio::test]
    async fn imds_client_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ClientId("expected client ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    #[tokio::test]
    async fn imds_error_response() {
        run_error_response_test(ManagedIdentitySource::Imds).await
    }

    #[tokio::test]
    async fn imds_object_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ObjectId("expected object ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    #[tokio::test]
    async fn imds_resource_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ResourceId(
                "expected resource ID".to_string(),
            )),
            ..Default::default()
        }))
        .await;
    }

    #[tokio::test]
    async fn imds_system_assigned() {
        run_imds_test(None).await;
    }

    #[recorded::test(live)]
    async fn imds_system_assigned_live() -> azure_core::Result<()> {
        run_imds_live_test(None).await
    }

    #[tokio::test]
    async fn requires_one_scope() {
        let credential = ManagedIdentityCredential::new(None).expect("valid credential");
        // `expect_err` appends ": <Ok value>", so a failure reads "expected an error, got: ...".
        for scopes in [&[][..], &["A", "B"][..]].iter() {
            credential
                .get_token(scopes, None)
                .await
                .expect_err("expected an error, got");
        }
    }

    #[test]
    fn service_fabric() {
        run_unsupported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, "http://localhost"),
                    (IDENTITY_HEADER, "..."),
                    (IDENTITY_SERVER_THUMBPRINT, "..."),
                ][..],
            ),
            ManagedIdentitySource::ServiceFabric,
        );
    }
}