1use crate::{
5 authentication_error, env::Env, AppServiceManagedIdentityCredential, ImdsId,
6 VirtualMachineManagedIdentityCredential,
7};
8use azure_core::credentials::{AccessToken, TokenCredential, TokenRequestOptions};
9use azure_core::http::ClientOptions;
10use std::sync::Arc;
11use tracing::info;
12
/// Identifies a user-assigned managed identity.
///
/// Set [`ManagedIdentityCredentialOptions::user_assigned_id`] to one of these to authenticate as
/// a user-assigned identity; when it is `None` the credential uses the system-assigned identity.
///
/// Not every hosting environment accepts every kind of ID: App Service rejects
/// [`UserAssignedId::ResourceId`] when the credential is constructed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UserAssignedId {
    /// The client ID of the user-assigned identity.
    ClientId(String),
    /// The object ID of the user-assigned identity.
    ObjectId(String),
    /// The Azure resource ID of the user-assigned identity.
    ResourceId(String),
}
23
/// Authenticates with a managed identity.
///
/// [`ManagedIdentityCredential::new`] detects the hosting environment and builds an
/// environment-specific credential (App Service or IMDS); token requests are forwarded to it.
#[derive(Debug)]
pub struct ManagedIdentityCredential {
    // The environment-specific credential chosen in `new`; `get_token` delegates to it.
    credential: Arc<dyn TokenCredential>,
}
29
/// Options for constructing a [`ManagedIdentityCredential`].
#[derive(Clone, Debug, Default)]
pub struct ManagedIdentityCredentialOptions {
    /// The user-assigned identity to authenticate as.
    ///
    /// When `None` (the default), the credential uses the system-assigned identity.
    pub user_assigned_id: Option<UserAssignedId>,

    /// Client options, such as the HTTP transport, passed to the environment-specific
    /// credential that requests tokens.
    pub client_options: ClientOptions,

    /// Test-only override of the environment used to detect the hosting environment; it is
    /// also handed to the environment-specific credential.
    #[cfg(test)]
    pub(crate) env: Env,
}
43
44impl ManagedIdentityCredential {
45 pub fn new(options: Option<ManagedIdentityCredentialOptions>) -> azure_core::Result<Arc<Self>> {
51 let options = options.unwrap_or_default();
52 #[cfg(test)]
53 let env = options.env;
54 #[cfg(not(test))]
55 let env = Env::default();
56 let source = get_source(&env);
57 let id = options
58 .user_assigned_id
59 .clone()
60 .map(Into::into)
61 .unwrap_or(ImdsId::SystemAssigned);
62
63 let credential: Arc<dyn TokenCredential> = match source {
64 ManagedIdentitySource::AppService => {
65 if let ImdsId::MsiResId(_) = id {
68 return Err(azure_core::Error::with_message_fn(
69 azure_core::error::ErrorKind::Credential,
70 || {
71 "User-assigned resource IDs aren't supported for App Service. Use a client or object ID instead.".to_string()
72 },
73 ));
74 }
75 AppServiceManagedIdentityCredential::new(id, options.client_options, env)?
76 }
77 ManagedIdentitySource::Imds => {
78 VirtualMachineManagedIdentityCredential::new(id, options.client_options, env)?
79 }
80 _ => {
81 return Err(azure_core::Error::with_message_fn(
82 azure_core::error::ErrorKind::Credential,
83 || format!("{} managed identity isn't supported", source.as_str()),
84 ));
85 }
86 };
87
88 info!(user_assigned_id = ?options.user_assigned_id, "ManagedIdentityCredential will use {} managed identity", source.as_str());
89
90 Ok(Arc::new(Self { credential }))
91 }
92}
93
94#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
95#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
96impl TokenCredential for ManagedIdentityCredential {
97 async fn get_token(
98 &self,
99 scopes: &[&str],
100 options: Option<TokenRequestOptions<'_>>,
101 ) -> azure_core::Result<AccessToken> {
102 if scopes.len() != 1 {
103 return Err(azure_core::Error::with_message(
104 azure_core::error::ErrorKind::Credential,
105 "ManagedIdentityCredential requires exactly one scope".to_string(),
106 ));
107 }
108 self.credential
109 .get_token(scopes, options)
110 .await
111 .map_err(|err| authentication_error(stringify!(ManagedIdentityCredential), err))
112 }
113}
114
/// A managed identity hosting environment, as detected by `get_source` from environment
/// variables.
///
/// Only `AppService` and `Imds` are supported by `ManagedIdentityCredential::new`; the other
/// variants make it return an error.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ManagedIdentitySource {
    /// Detected when `IDENTITY_ENDPOINT` and `IMDS_ENDPOINT` are set but `IDENTITY_HEADER` isn't.
    AzureArc,
    /// Detected when `MSI_ENDPOINT` and `MSI_SECRET` are set and `IDENTITY_ENDPOINT` isn't.
    AzureML,
    /// Detected when `IDENTITY_ENDPOINT` and `IDENTITY_HEADER` are set but
    /// `IDENTITY_SERVER_THUMBPRINT` isn't.
    AppService,
    /// Detected when `MSI_ENDPOINT` is set and neither `MSI_SECRET` nor `IDENTITY_ENDPOINT` is.
    CloudShell,
    /// The fallback when no other environment is detected.
    Imds,
    /// Detected when `IDENTITY_ENDPOINT`, `IDENTITY_HEADER`, and `IDENTITY_SERVER_THUMBPRINT`
    /// are all set.
    ServiceFabric,
}

impl ManagedIdentitySource {
    /// Returns a human-readable name for the source, used in log and error messages.
    pub fn as_str(&self) -> &'static str {
        match self {
            ManagedIdentitySource::AzureArc => "Azure Arc",
            ManagedIdentitySource::AzureML => "Azure ML",
            ManagedIdentitySource::AppService => "App Service",
            ManagedIdentitySource::CloudShell => "CloudShell",
            ManagedIdentitySource::Imds => "IMDS",
            ManagedIdentitySource::ServiceFabric => "Service Fabric",
        }
    }
}
137
// Names of the environment variables `get_source` probes to detect the hosting environment.
// Only their presence matters for detection; their values aren't inspected there.

// Set for App Service, Service Fabric, and Azure Arc.
const IDENTITY_ENDPOINT: &str = "IDENTITY_ENDPOINT";
// With IDENTITY_ENDPOINT: App Service, or Service Fabric when a thumbprint is also set.
const IDENTITY_HEADER: &str = "IDENTITY_HEADER";
// With IDENTITY_ENDPOINT and IDENTITY_HEADER: Service Fabric.
const IDENTITY_SERVER_THUMBPRINT: &str = "IDENTITY_SERVER_THUMBPRINT";
// With IDENTITY_ENDPOINT but without IDENTITY_HEADER: Azure Arc.
const IMDS_ENDPOINT: &str = "IMDS_ENDPOINT";
// Consulted only when IDENTITY_ENDPOINT is unset: Cloud Shell, or Azure ML with MSI_SECRET.
const MSI_ENDPOINT: &str = "MSI_ENDPOINT";
// With MSI_ENDPOINT: Azure ML.
const MSI_SECRET: &str = "MSI_SECRET";
144
145fn get_source(env: &Env) -> ManagedIdentitySource {
146 use ManagedIdentitySource::*;
147 if env.var(IDENTITY_ENDPOINT).is_ok() {
148 if env.var(IDENTITY_HEADER).is_ok() {
149 if env.var(IDENTITY_SERVER_THUMBPRINT).is_ok() {
150 return ServiceFabric;
151 }
152 return AppService;
153 } else if env.var(IMDS_ENDPOINT).is_ok() {
154 return AzureArc;
155 }
156 } else if env.var(MSI_ENDPOINT).is_ok() {
157 if env.var(MSI_SECRET).is_ok() {
158 return AzureML;
159 }
160 return CloudShell;
161 }
162 Imds
163}
164
// Unit tests drive the credential through a mock HTTP client and an injected `Env`; the
// `recorded::test(live)` tests talk to real endpoints and skip themselves when their
// environment isn't configured.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        env::Env,
        tests::{LIVE_TEST_RESOURCE, LIVE_TEST_SCOPES},
    };
    use azure_core::http::{
        AsyncRawResponse, Method, RawResponse, Request, StatusCode, Transport, Url,
    };
    use azure_core::time::OffsetDateTime;
    use azure_core::Bytes;
    use azure_core::{error::ErrorKind, http::headers::Headers};
    use azure_core_test::{http::MockHttpClient, recorded};
    use futures::FutureExt;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Placeholder in mock token response bodies; its first occurrence is replaced with a
    // computed Unix timestamp before the body is returned.
    const EXPIRES_ON: &str = "EXPIRES_ON";

    /// Sends a GET to a deployed test app at `authority` and asserts it answers 200.
    ///
    /// The app is told to run its managed identity test with `storage_name` and, when given,
    /// the user-assigned ID, all passed as query parameters.
    async fn run_deployed_test(
        authority: &str,
        storage_name: &str,
        id: Option<UserAssignedId>,
    ) -> azure_core::Result<()> {
        // Each parameter carries a trailing '&' so it can be spliced before `storage-name`.
        let id_param = id.map_or("".to_string(), |id| match id {
            UserAssignedId::ClientId(id) => format!("client-id={id}&"),
            UserAssignedId::ObjectId(id) => format!("object-id={id}&"),
            UserAssignedId::ResourceId(id) => format!("resource-id={id}&"),
        });
        let url = format!(
            "http://{authority}/api?test=managed-identity&{id_param}storage-name={storage_name}"
        );
        let u = Url::parse(&url).expect("invalid URL");
        let client = azure_core::http::new_http_client();
        let req = Request::new(u, Method::Get);

        let res = client.execute_request(&req).await.expect("request failed");
        let status = res.status();
        let body = res.into_body().collect_string().await?;
        assert_eq!(StatusCode::Ok, status, "Test app responded with '{body}'");

        Ok(())
    }

    /// Verifies how an HTTP error from the token endpoint surfaces to callers.
    ///
    /// The mock answers every request with 418 and body "is a teapot". The test asserts a
    /// `Credential` error with the expected message, wrapping an `HttpResponse` error that
    /// carries the original status and raw response. Only `Imds` and `AppService` sources are
    /// supported; any other source panics.
    async fn run_error_response_test(source: ManagedIdentitySource) {
        let expected_status = StatusCode::ImATeapot;
        let headers = Headers::default();
        let content: &str = "is a teapot";
        let body = Bytes::copy_from_slice(content.as_bytes());
        let expected_response =
            RawResponse::from_bytes(expected_status, headers.clone(), body.clone());
        let mock_headers = headers.clone();
        let mock_body = body.clone();
        let mock_client = MockHttpClient::new(move |_| {
            let headers = mock_headers.clone();
            let body = mock_body.clone();
            async move { Ok(AsyncRawResponse::from_bytes(expected_status, headers, body)) }.boxed()
        });
        // An empty environment is detected as IMDS; these two variables select App Service.
        let test_env = match source {
            ManagedIdentitySource::Imds => Env::from(&[][..]),
            ManagedIdentitySource::AppService => Env::from(
                &[
                    (
                        IDENTITY_ENDPOINT,
                        "http://localhost/metadata/identity/oauth2/token",
                    ),
                    (IDENTITY_HEADER, "secret"),
                ][..],
            ),
            other => panic!("unsupported managed identity source {:?}", other),
        };
        let options = ManagedIdentityCredentialOptions {
            client_options: ClientOptions {
                transport: Some(Transport::new(Arc::new(mock_client))),
                ..Default::default()
            },
            env: test_env,
            ..Default::default()
        };
        let credential = ManagedIdentityCredential::new(Some(options)).expect("credential");
        let err = credential
            .get_token(LIVE_TEST_SCOPES, None)
            .await
            .expect_err("expected an error");
        assert!(matches!(err.kind(), ErrorKind::Credential));
        assert_eq!(
            "ManagedIdentityCredential authentication failed. The request failed: is a teapot\nTo troubleshoot, visit https://aka.ms/azsdk/rust/identity/troubleshoot#managed-id",
            err.to_string(),
        );
        // The outer Credential error must wrap the original HTTP error unchanged.
        match err
            .downcast_ref::<azure_core::Error>()
            .expect("returned error should wrap an azure_core::Error")
            .kind()
        {
            ErrorKind::HttpResponse {
                error_code: None,
                raw_response: Some(response),
                status,
            } => {
                assert_eq!(response.as_ref(), &expected_response);
                assert_eq!(expected_status, *status);
            }
            err => panic!("unexpected {:?}", err),
        };
    }

    /// Verifies detection, the outgoing request, and token caching for a supported source.
    ///
    /// Asserts that `env` is detected as `expected_source`. The mock checks each request against
    /// `model_request` and answers with `response_format`, replacing its first `EXPIRES_ON` with a
    /// timestamp one hour in the future. `env` overrides any `env` already set in `options`.
    async fn run_supported_source_test(
        env: Env,
        options: Option<ManagedIdentityCredentialOptions>,
        expected_source: ManagedIdentitySource,
        model_request: Request,
        response_format: String,
    ) {
        let actual_source = get_source(&env);
        // Compared by discriminant because `ManagedIdentitySource` is compared here without
        // relying on `PartialEq`.
        assert_eq!(
            std::mem::discriminant(&actual_source),
            std::mem::discriminant(&expected_source)
        );
        // Counts HTTP requests the credential sends, to prove tokens are reused.
        let token_requests = Arc::new(AtomicUsize::new(0));
        let token_requests_clone = token_requests.clone();
        let expires_on = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600;
        let mock_client = MockHttpClient::new(move |actual| {
            {
                token_requests_clone.fetch_add(1, Ordering::SeqCst);
                let expected = model_request.clone();
                let response_format = response_format.clone();
                async move {
                    assert_eq!(expected.method(), actual.method());

                    // Query parameters are compared as sorted lists, so their order is irrelevant.
                    let mut actual_params: Vec<_> =
                        actual.url().query_pairs().into_owned().collect();
                    actual_params.sort();
                    let mut expected_params: Vec<_> =
                        expected.url().query_pairs().into_owned().collect();
                    expected_params.sort();
                    assert_eq!(expected_params, actual_params);

                    // With the query stripped, the rest of the URL must match exactly.
                    let mut actual_url = actual.url().clone();
                    actual_url.set_query(None);
                    let mut expected_url = expected.url().clone();
                    expected_url.set_query(None);
                    assert_eq!(actual_url, expected_url);

                    // Every header on the model request must be present with the same value;
                    // extra headers on the actual request are allowed.
                    expected.headers().iter().for_each(|(k, v)| {
                        assert_eq!(actual.headers().get_str(k).unwrap(), v.as_str())
                    });

                    Ok(AsyncRawResponse::from_bytes(
                        StatusCode::Ok,
                        Headers::default(),
                        Bytes::from(response_format.replacen(
                            EXPIRES_ON,
                            &expires_on.to_string(),
                            1,
                        )),
                    ))
                }
            }
            .boxed()
        });
        let mut options = options.unwrap_or_default();
        options.env = env;
        options.client_options = ClientOptions {
            transport: Some(Transport::new(Arc::new(mock_client))),
            ..Default::default()
        };
        let cred = ManagedIdentityCredential::new(Some(options)).expect("credential");
        // Four calls must produce exactly one HTTP request; later calls must not reach the
        // endpoint (presumably served from a token cache).
        for _ in 0..4 {
            let token = cred.get_token(LIVE_TEST_SCOPES, None).await.expect("token");
            assert_eq!(token.expires_on.unix_timestamp(), expires_on as i64);
            assert_eq!(token.token.secret(), "*");
            assert_eq!(token_requests.load(Ordering::SeqCst), 1);
        }
    }

    /// Asserts that `env` is detected as `expected_source` and that constructing the credential
    /// for it fails with a `Credential` error.
    fn run_unsupported_source_test(env: Env, expected_source: ManagedIdentitySource) {
        let actual_source = get_source(&env);
        assert_eq!(
            std::mem::discriminant(&actual_source),
            std::mem::discriminant(&expected_source)
        );
        let result = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            env,
            ..Default::default()
        }));
        assert!(
            matches!(result, Err(ref e) if *e.kind() == azure_core::error::ErrorKind::Credential),
            "Expected constructor error"
        );
    }

    // Live test against a deployed container app using a user-assigned client ID; skipped
    // unless CI_HAS_DEPLOYED_RESOURCES is set.
    #[recorded::test(live)]
    async fn aci_user_assigned_live() -> azure_core::Result<()> {
        if env::var("CI_HAS_DEPLOYED_RESOURCES").is_err() {
            println!("Skipped: ACI live tests require deployed resources");
            return Ok(());
        }
        let ip = env::var("IDENTITY_ACI_IP_USER_ASSIGNED").expect("IDENTITY_ACI_IP_USER_ASSIGNED");
        let storage_name = env::var("IDENTITY_STORAGE_NAME_USER_ASSIGNED")
            .expect("IDENTITY_STORAGE_NAME_USER_ASSIGNED");
        let client_id = env::var("IDENTITY_USER_ASSIGNED_IDENTITY_CLIENT_ID")
            .expect("IDENTITY_USER_ASSIGNED_IDENTITY_CLIENT_ID");
        run_deployed_test(
            &format!("{}:8080", ip),
            &storage_name,
            Some(UserAssignedId::ClientId(client_id)),
        )
        .await?;

        Ok(())
    }

    /// Runs `run_supported_source_test` for App Service with the expected request shape:
    /// GET to the endpoint with an `x-identity-header` header, `api-version=2019-08-01`, the
    /// resource, and the user-assigned ID parameter when one is configured.
    async fn run_app_service_test(options: Option<ManagedIdentityCredentialOptions>) {
        let endpoint = "http://localhost/metadata/identity/oauth2/token";
        let x_id_header = "x-id-header";
        let mut model = Request::new(endpoint.parse().unwrap(), Method::Get);
        model.insert_header("x-identity-header", x_id_header);
        let mut params = Vec::from([
            ("api-version", "2019-08-01"),
            ("resource", LIVE_TEST_RESOURCE),
        ]);
        if let Some(options) = options.as_ref() {
            if let Some(ref id) = options.user_assigned_id {
                match id {
                    UserAssignedId::ClientId(client_id) => {
                        params.push(("client_id", client_id));
                    }
                    UserAssignedId::ObjectId(object_id) => {
                        params.push(("object_id", object_id));
                    }
                    // NOTE(review): `ManagedIdentityCredential::new` rejects resource IDs for
                    // App Service (see `app_service_resource_id`), so this branch looks
                    // unreachable from the current tests.
                    UserAssignedId::ResourceId(resource_id) => {
                        params.push(("mi_res_id", resource_id));
                    }
                }
            }
        }
        model.url_mut().query_pairs_mut().extend_pairs(params);
        run_supported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, endpoint),
                    (IDENTITY_HEADER, x_id_header),
                ][..],
            ),
            options,
            ManagedIdentitySource::AppService,
            model,
            format!(
                r#"{{"access_token":"*","expires_on":"{}","resource":"{}","token_type":"Bearer"}}"#,
                EXPIRES_ON, LIVE_TEST_RESOURCE
            )
            .to_string(),
        )
        .await;
    }

    // App Service with the system-assigned identity.
    #[tokio::test]
    async fn app_service() {
        run_app_service_test(None).await;
    }

    // App Service with a user-assigned client ID.
    #[tokio::test]
    async fn app_service_client_id() {
        run_app_service_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ClientId("expected client ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    // App Service error responses are wrapped correctly.
    #[tokio::test]
    async fn app_service_error_response() {
        run_error_response_test(ManagedIdentitySource::AppService).await
    }

    // App Service with a user-assigned object ID.
    #[tokio::test]
    async fn app_service_object_id() {
        run_app_service_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ObjectId("expected object ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    // App Service rejects user-assigned resource IDs at construction time.
    #[tokio::test]
    async fn app_service_resource_id() {
        let result = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            env: Env::from(&[(IDENTITY_ENDPOINT, "..."), (IDENTITY_HEADER, "x-id-header")][..]),
            user_assigned_id: Some(UserAssignedId::ResourceId(
                "expected resource ID".to_string(),
            )),
            ..Default::default()
        }));
        assert!(
            matches!(result, Err(ref e) if *e.kind() == azure_core::error::ErrorKind::Credential),
            "Expected constructor error"
        );
    }

    // Azure Arc is detected but unsupported.
    #[test]
    fn arc() {
        run_unsupported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, "http://localhost"),
                    (IMDS_ENDPOINT, "..."),
                ][..],
            ),
            ManagedIdentitySource::AzureArc,
        );
    }

    // Azure ML is detected but unsupported.
    #[test]
    fn azure_ml() {
        run_unsupported_source_test(
            Env::from(&[(MSI_ENDPOINT, "..."), (MSI_SECRET, "...")][..]),
            ManagedIdentitySource::AzureML,
        );
    }

    // Cloud Shell is detected but unsupported.
    #[test]
    fn cloudshell() {
        run_unsupported_source_test(
            Env::from(&[(MSI_ENDPOINT, "http://localhost")][..]),
            ManagedIdentitySource::CloudShell,
        );
    }

    /// Requests a real token via IMDS and checks it is non-empty, UTC, and not yet expired.
    /// Skipped unless IDENTITY_IMDS_AVAILABLE is set.
    async fn run_imds_live_test(id: Option<UserAssignedId>) -> azure_core::Result<()> {
        if std::env::var("IDENTITY_IMDS_AVAILABLE").is_err() {
            println!("Skipped: IMDS isn't available");
            return Ok(());
        }

        let credential = ManagedIdentityCredential::new(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: id,
            ..Default::default()
        }))
        .expect("valid credential");

        let token = credential.get_token(LIVE_TEST_SCOPES, None).await?;

        assert!(!token.token.secret().is_empty());
        assert_eq!(time::UtcOffset::UTC, token.expires_on.offset());
        assert!(token.expires_on.unix_timestamp() > OffsetDateTime::now_utc().unix_timestamp());

        Ok(())
    }

    /// Runs `run_supported_source_test` for IMDS with the expected request shape: GET to
    /// 169.254.169.254 with a `metadata: true` header, `api-version=2019-08-01`, the resource,
    /// and the user-assigned ID parameter when one is configured (`msi_res_id` for resource IDs).
    async fn run_imds_test(options: Option<ManagedIdentityCredentialOptions>) {
        let mut model = Request::new(
            "http://169.254.169.254/metadata/identity/oauth2/token"
                .parse()
                .unwrap(),
            Method::Get,
        );
        model.insert_header("metadata", "true");

        let mut params = Vec::from([
            ("api-version", "2019-08-01"),
            ("resource", LIVE_TEST_RESOURCE),
        ]);
        if let Some(options) = options.as_ref() {
            if let Some(ref id) = options.user_assigned_id {
                match id {
                    UserAssignedId::ClientId(client_id) => {
                        params.push(("client_id", client_id));
                    }
                    UserAssignedId::ObjectId(object_id) => {
                        params.push(("object_id", object_id));
                    }
                    UserAssignedId::ResourceId(resource_id) => {
                        params.push(("msi_res_id", resource_id));
                    }
                }
            }
        }
        model.url_mut().query_pairs_mut().extend_pairs(params);

        // An empty environment is detected as IMDS.
        run_supported_source_test(
            Env::from(&[][..]),
            options,
            ManagedIdentitySource::Imds,
            model,
            format!(r#"{{"token_type":"Bearer","expires_in":"85770","expires_on":"{}","ext_expires_in":86399,"access_token":"*","resource":"{}"}}"#, EXPIRES_ON, LIVE_TEST_RESOURCE).to_string(),
        ).await;
    }

    // IMDS with a user-assigned client ID.
    #[tokio::test]
    async fn imds_client_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ClientId("expected client ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    // IMDS error responses are wrapped correctly.
    #[tokio::test]
    async fn imds_error_response() {
        run_error_response_test(ManagedIdentitySource::Imds).await
    }

    // IMDS with a user-assigned object ID.
    #[tokio::test]
    async fn imds_object_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ObjectId("expected object ID".to_string())),
            ..Default::default()
        }))
        .await;
    }

    // IMDS with a user-assigned resource ID (accepted, unlike App Service).
    #[tokio::test]
    async fn imds_resource_id() {
        run_imds_test(Some(ManagedIdentityCredentialOptions {
            user_assigned_id: Some(UserAssignedId::ResourceId(
                "expected resource ID".to_string(),
            )),
            ..Default::default()
        }))
        .await;
    }

    // IMDS with the system-assigned identity.
    #[tokio::test]
    async fn imds_system_assigned() {
        run_imds_test(None).await;
    }

    // Live IMDS test with the system-assigned identity.
    #[recorded::test(live)]
    async fn imds_system_assigned_live() -> azure_core::Result<()> {
        run_imds_live_test(None).await
    }

    // `get_token` rejects zero scopes and more than one scope.
    // NOTE(review): relies on `new(None)` succeeding with the default test `Env`, i.e. on the
    // default environment not being detected as an unsupported source — confirm.
    #[tokio::test]
    async fn requires_one_scope() {
        let credential = ManagedIdentityCredential::new(None).expect("valid credential");
        for scopes in [&[][..], &["A", "B"][..]].iter() {
            credential
                .get_token(scopes, None)
                .await
                .expect_err("expected an error, got");
        }
    }

    // Service Fabric is detected but unsupported.
    #[test]
    fn service_fabric() {
        run_unsupported_source_test(
            Env::from(
                &[
                    (IDENTITY_ENDPOINT, "http://localhost"),
                    (IDENTITY_HEADER, "..."),
                    (IDENTITY_SERVER_THUMBPRINT, "..."),
                ][..],
            ),
            ManagedIdentitySource::ServiceFabric,
        );
    }
}