1use std::fs;
20use std::io::Read;
21
22use crate::error::MsalError;
23use base64::engine::general_purpose::STANDARD;
24use base64::Engine;
25use openssl::pkey::Public;
26use openssl::rsa::Rsa;
27use openssl::x509::X509;
28use os_release::OsRelease;
29use reqwest::{header, Client, Url};
30use serde::{Deserialize, Serialize};
31use serde_json::json;
32use serde_json::to_string_pretty;
33use tracing::debug;
34use zeroize::{Zeroize, ZeroizeOnDrop};
35
/// Base URL of the Device Registration Service (DRS). Discovery is performed
/// against it, and it is the fallback host for the nonce, device and key
/// enrollment endpoints when discovery does not advertise one.
pub const DISCOVERY_URL: &str = "https://enterpriseregistration.windows.net";

/// `api-version` query value sent with the DRS discovery request.
const DRS_PROTOCOL_VERSION: &str = "1.9";

/// Request header that carries this client's package name on DRS calls.
pub const DRS_CLIENT_NAME_HEADER_FIELD: &str = "ocp-adrs-client-name";

/// Request header that carries this client's package version on DRS calls.
pub const DRS_CLIENT_VERSION_HEADER_FIELD: &str = "ocp-adrs-client-version";
40
/// Certificate object embedded in a successful DRS device-enrollment response.
#[cfg(feature = "broker")]
#[derive(Debug, Deserialize, Zeroize, ZeroizeOnDrop)]
struct Certificate {
    /// PEM body of the issued certificate without the BEGIN/END armour lines;
    /// `enroll_device` adds those before parsing it.
    #[serde(rename = "RawBody")]
    raw_body: String,
}
47
/// JSON body returned by a successful DRS device-enrollment request.
/// Only the `Certificate` member is read.
#[cfg(feature = "broker")]
#[derive(Debug, Deserialize, Zeroize, ZeroizeOnDrop)]
struct DRSResponse {
    /// The issued device certificate.
    #[serde(rename = "Certificate")]
    certificate: Certificate,
}
54
/// Components of an RSA public key that are serialised into a Windows CNG
/// `BCRYPT_RSAKEY_BLOB`-style byte buffer (magic `RSA1`).
///
/// The type derives `ZeroizeOnDrop`, so its buffers are wiped when it is dropped.
#[cfg(feature = "broker")]
#[derive(Zeroize, ZeroizeOnDrop)]
pub(crate) struct BcryptRsaKeyBlob {
    /// Key size in bits. It is written verbatim into the blob header.
    bit_length: u32,
    /// Public exponent bytes. They are appended right after the header.
    exponent: Vec<u8>,
    /// Modulus bytes. They are appended right after the exponent.
    modulus: Vec<u8>,
}
62
#[cfg(feature = "broker")]
impl BcryptRsaKeyBlob {
    /// Builds a key-blob description from a bit length, a public exponent and
    /// a modulus.
    ///
    /// The exponent and modulus slices are copied into owned buffers, so the
    /// caller keeps its inputs. `bit_length` is stored exactly as given.
    pub(crate) fn new(bit_length: u32, exponent: &[u8], modulus: &[u8]) -> Self {
        let (exponent, modulus) = (Vec::from(exponent), Vec::from(modulus));
        Self {
            bit_length,
            exponent,
            modulus,
        }
    }
}
73
/// Serialises an RSA public key into a Windows CNG `BCRYPT_RSAKEY_BLOB`-style
/// public-key buffer.
///
/// Layout (every header integer is a little-endian `u32`):
///
/// `"RSA1" | BitLength | cbPublicExp | cbModulus | cbPrime1 = 0 | cbPrime2 = 0`
///
/// The header is followed by the exponent bytes and then the modulus bytes.
///
/// Implemented as `TryFrom` rather than `TryInto`, so the standard blanket
/// impl still provides `blob.try_into()` for existing callers.
///
/// # Errors
///
/// Returns [`MsalError::GeneralFailure`] if the exponent or modulus is too
/// long for its length to fit in a `u32` header field.
#[cfg(feature = "broker")]
impl TryFrom<BcryptRsaKeyBlob> for Vec<u8> {
    type Error = MsalError;

    fn try_from(blob: BcryptRsaKeyBlob) -> Result<Self, Self::Error> {
        const MAGIC: &[u8; 4] = b"RSA1";

        let exponent_len = u32::try_from(blob.exponent.len()).map_err(|e| {
            MsalError::GeneralFailure(format!("Exponent len into u32 failed: {:?}", e))
        })?;
        let modulus_len = u32::try_from(blob.modulus.len()).map_err(|e| {
            MsalError::GeneralFailure(format!("Modulus len into u32 failed: {:?}", e))
        })?;

        // cbPrime1 / cbPrime2 are zero because only public material is encoded.
        let header = [blob.bit_length, exponent_len, modulus_len, 0, 0];

        let mut cng_blob = Vec::with_capacity(
            MAGIC.len() + header.len() * 4 + blob.exponent.len() + blob.modulus.len(),
        );
        cng_blob.extend_from_slice(MAGIC);
        for field in header {
            cng_blob.extend_from_slice(&field.to_le_bytes());
        }
        cng_blob.extend_from_slice(&blob.exponent);
        cng_blob.extend_from_slice(&blob.modulus);
        Ok(cng_blob)
    }
}
102
/// `ServicesService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct ServicesService {
    /// `ServicesEndpoint`, if advertised.
    #[serde(rename = "ServicesEndpoint")]
    pub endpoint: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
110
/// `DeviceRegistrationService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct DeviceRegistrationService {
    /// `RegistrationEndpoint`, if advertised.
    #[serde(rename = "RegistrationEndpoint")]
    pub endpoint: Option<String>,
    /// `RegistrationResourceId`, if advertised.
    #[serde(rename = "RegistrationResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
120
/// OAuth2 endpoints advertised inside the discovery document's
/// `AuthenticationService` section.
#[derive(Debug, Deserialize)]
pub struct OAuth2 {
    /// `AuthCodeEndpoint`, if advertised.
    #[serde(rename = "AuthCodeEndpoint")]
    pub auth_code_endpoint: Option<String>,
    /// `TokenEndpoint`, if advertised.
    #[serde(rename = "TokenEndpoint")]
    pub token_endpoint: Option<String>,
}
128
/// `AuthenticationService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct AuthenticationService {
    /// Nested `OAuth2` endpoints, if advertised.
    #[serde(rename = "OAuth2")]
    pub oauth2: Option<OAuth2>,
}
134
/// `IdentityProviderService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct IdentityProviderService {
    /// `Federated` flag, if advertised.
    #[serde(rename = "Federated")]
    pub federated: Option<bool>,
    /// `PassiveAuthEndpoint`, if advertised.
    #[serde(rename = "PassiveAuthEndpoint")]
    pub passive_auth_endpoint: Option<String>,
}
142
/// `DeviceJoinService` section of the DRS discovery document.
///
/// `Services::enroll_device` posts to `endpoint` with `service_version` as the
/// `api-version`. It falls back to `{DISCOVERY_URL}/EnrollmentServer/device/`
/// and `"2.0"` when either value is absent.
#[derive(Debug, Deserialize)]
pub struct DeviceJoinService {
    /// `JoinEndpoint`, if advertised.
    #[serde(rename = "JoinEndpoint")]
    pub endpoint: Option<String>,
    /// `JoinResourceId`, if advertised.
    #[serde(rename = "JoinResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
152
/// `KeyProvisioningService` section of the DRS discovery document.
///
/// It is read by `Services::provision_key`, which uses `endpoint` and
/// `service_version`, and by `Services::key_provisioning_resource_id`, which
/// uses `resource_id`.
#[derive(Debug, Deserialize)]
pub struct KeyProvisioningService {
    /// `KeyProvisionEndpoint`, if advertised.
    #[serde(rename = "KeyProvisionEndpoint")]
    pub endpoint: Option<String>,
    /// `KeyProvisionResourceId`, if advertised.
    #[serde(rename = "KeyProvisionResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
162
/// `WebAuthNService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct WebAuthNService {
    /// `WebAuthNEndpoint`, if advertised.
    #[serde(rename = "WebAuthNEndpoint")]
    pub endpoint: Option<String>,
    /// `WebAuthNResourceId`, if advertised.
    #[serde(rename = "WebAuthNResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
172
/// `DeviceManagementService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct DeviceManagementService {
    /// `DeviceManagementEndpoint`, if advertised.
    #[serde(rename = "DeviceManagementEndpoint")]
    pub endpoint: Option<String>,
    /// `DeviceManagementResourceId`, if advertised.
    #[serde(rename = "DeviceManagementResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
182
/// `MsaProviderData` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct MsaProviderData {
    /// `SiteId`, if advertised.
    #[serde(rename = "SiteId")]
    pub site_id: Option<String>,
    /// `SiteUrl`, if advertised.
    #[serde(rename = "SiteUrl")]
    pub site_url: Option<String>,
}
190
/// `PrecreateService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct PrecreateService {
    /// `PrecreateEndpoint`, if advertised.
    #[serde(rename = "PrecreateEndpoint")]
    pub endpoint: Option<String>,
    /// `PrecreateResourceId`, if advertised.
    #[serde(rename = "PrecreateResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
200
/// `TenantInfo` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct TenantInfo {
    /// `TenantId`, if advertised.
    #[serde(rename = "TenantId")]
    pub tenant_id: Option<String>,
    /// `TenantName`, if advertised.
    #[serde(rename = "TenantName")]
    pub tenant_name: Option<String>,
    /// `DisplayName`, if advertised.
    #[serde(rename = "DisplayName")]
    pub display_name: Option<String>,
}
210
/// `AzureRbacService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct AzureRbacService {
    /// `RbacPolicyEndpoint`, if advertised.
    #[serde(rename = "RbacPolicyEndpoint")]
    pub endpoint: Option<String>,
}
216
/// `BPLService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct BPLService {
    /// `BPLServiceEndpoint`, if advertised.
    #[serde(rename = "BPLServiceEndpoint")]
    pub endpoint: Option<String>,
    /// `BPLResourceId`, if advertised.
    #[serde(rename = "BPLResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
    /// `BPLProxyServicePrincipalId`, if advertised.
    #[serde(rename = "BPLProxyServicePrincipalId")]
    pub service_principal_id: Option<String>,
}
228
/// `DeviceJoinResourceService` section of the DRS discovery document.
#[derive(Debug, Deserialize)]
pub struct DeviceJoinResourceService {
    /// `Endpoint`, if advertised.
    #[serde(rename = "Endpoint")]
    pub endpoint: Option<String>,
    /// `ResourceId`, if advertised.
    #[serde(rename = "ResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
238
/// `NonceService` section of the DRS discovery document.
///
/// `Services::request_nonce` uses `endpoint` and `service_version`. It falls
/// back to `{DISCOVERY_URL}/EnrollmentServer/nonce/{tenant_id}/` and `"1.0"`
/// when either value is absent.
#[derive(Debug, Deserialize)]
pub struct NonceService {
    /// `Endpoint`, if advertised.
    #[serde(rename = "Endpoint")]
    pub endpoint: Option<String>,
    /// `ResourceId`, if advertised.
    #[serde(rename = "ResourceId")]
    pub resource_id: Option<String>,
    /// `ServiceVersion`, if advertised.
    #[serde(rename = "ServiceVersion")]
    pub service_version: Option<String>,
}
248
/// Body of a successful nonce-service response.
#[derive(Debug, Deserialize)]
struct NonceResp {
    /// The nonce; `Services::request_nonce` returns it to its caller.
    #[serde(rename = "Value")]
    value: String,
}
254
/// Returns the system vendor string from the kernel's DMI sysfs file, with
/// surrounding whitespace (e.g. a trailing newline) trimmed off.
///
/// Returns `None` if the file cannot be read or is not valid UTF-8.
fn get_manufacturer() -> Option<String> {
    const SYS_VENDOR_PATH: &str = "/sys/class/dmi/id/sys_vendor";

    fs::read_to_string(SYS_VENDOR_PATH)
        .ok()
        .map(|raw| raw.trim().to_owned())
}
264
/// Device attributes used for a DRS device-enrollment request.
///
/// Construct it with `EnrollAttrs::new`, which fills in defaults for any value
/// the caller leaves unset.
#[cfg(feature = "broker")]
#[derive(Clone, Serialize, Deserialize)]
pub struct EnrollAttrs {
    /// Sent as `DeviceDisplayName` in the enrollment payload.
    pub(crate) device_display_name: String,
    /// Sent as `DeviceType` in the enrollment payload.
    pub(crate) device_type: String,
    /// Sent as `JoinType` in the enrollment payload.
    join_type: u32,
    /// Sent as `OSVersion` in the enrollment payload.
    pub(crate) os_version: String,
    /// Sent as `TargetDomain` in the enrollment payload.
    pub(crate) target_domain: String,
    /// Distribution name taken from os-release. It is not included in the
    /// `enroll_device` payload.
    pub(crate) os_distribution: String,
    /// DMI system vendor, or empty if it was unavailable. It is not included
    /// in the `enroll_device` payload.
    pub(crate) manufacturer: String,
}
276
#[cfg(feature = "broker")]
impl EnrollAttrs {
    /// Collects the attributes for a DRS device-enrollment request and fills
    /// in every value the caller leaves as `None`.
    ///
    /// Defaults:
    /// * `device_display_name`: the machine hostname.
    /// * `device_type`: `"Linux"`.
    /// * `join_type`: `0`.
    /// * `os_version`: the os-release pretty name, a space, then its version id.
    ///
    /// `os_distribution` is always the os-release name. `manufacturer` is the
    /// DMI system vendor, or an empty string when it cannot be read.
    ///
    /// # Errors
    ///
    /// Returns [`MsalError::GeneralFailure`] in these cases:
    /// * os-release cannot be read.
    /// * A hostname is needed but cannot be obtained.
    /// * The hostname is not valid UTF-8.
    pub fn new(
        target_domain: String,
        device_display_name: Option<String>,
        device_type: Option<String>,
        join_type: Option<u32>,
        os_version: Option<String>,
    ) -> Result<Self, MsalError> {
        let release =
            OsRelease::new().map_err(|err| MsalError::GeneralFailure(format!("{}", err)))?;

        // Only query the hostname when the caller supplied no display name.
        let device_display_name = if let Some(name) = device_display_name {
            name
        } else {
            let host =
                hostname::get().map_err(|err| MsalError::GeneralFailure(format!("{}", err)))?;
            host.to_str().map(String::from).ok_or_else(|| {
                MsalError::GeneralFailure(
                    "Failed to get machine hostname for enrollment".to_string(),
                )
            })?
        };

        let os_version = os_version
            .unwrap_or_else(|| format!("{} {}", release.pretty_name, release.version_id));

        Ok(Self {
            device_display_name,
            device_type: device_type.unwrap_or_else(|| "Linux".to_string()),
            join_type: join_type.unwrap_or(0),
            os_version,
            target_domain,
            os_distribution: release.name,
            manufacturer: get_manufacturer().unwrap_or_default(),
        })
    }
}
350
/// Parsed DRS discovery document for a domain, together with the HTTP client
/// used to call the advertised services.
///
/// Every section is optional. The nonce, device-join and key-provisioning
/// request methods fall back to default endpoints under `DISCOVERY_URL` when a
/// section or field is missing.
#[derive(Debug, Deserialize)]
pub struct Services {
    /// Not part of the JSON. Serde fills it with `Client::default()` during
    /// deserialisation, and `Services::new` then replaces it with the client
    /// it used for discovery.
    #[serde(skip_deserializing)]
    client: Client,
    #[serde(rename = "ServicesService")]
    pub discovery_service: Option<ServicesService>,
    #[serde(rename = "DeviceRegistrationService")]
    pub device_registration_service: Option<DeviceRegistrationService>,
    #[serde(rename = "AuthenticationService")]
    pub authentication_service: Option<AuthenticationService>,
    #[serde(rename = "IdentityProviderService")]
    pub identity_provider_service: Option<IdentityProviderService>,
    /// Consulted by `enroll_device`.
    #[serde(rename = "DeviceJoinService")]
    pub device_join_service: Option<DeviceJoinService>,
    /// Consulted by `provision_key` and `key_provisioning_resource_id`.
    #[serde(rename = "KeyProvisioningService")]
    pub key_provisioning_service: Option<KeyProvisioningService>,
    #[serde(rename = "WebAuthNService")]
    pub web_auth_n_service: Option<WebAuthNService>,
    #[serde(rename = "DeviceManagementService")]
    pub device_management_service: Option<DeviceManagementService>,
    #[serde(rename = "MsaProviderData")]
    pub msa_provider_data: Option<MsaProviderData>,
    #[serde(rename = "PrecreateService")]
    pub precreate_service: Option<PrecreateService>,
    #[serde(rename = "TenantInfo")]
    pub tenant_info: Option<TenantInfo>,
    #[serde(rename = "AzureRbacService")]
    pub azure_rbac_service: Option<AzureRbacService>,
    #[serde(rename = "BPLService")]
    pub bpl_service: Option<BPLService>,
    #[serde(rename = "DeviceJoinResourceService")]
    pub device_join_resource_service: Option<DeviceJoinResourceService>,
    /// Consulted by `request_nonce`. It is kept private.
    #[serde(rename = "NonceService")]
    nonce_service: Option<NonceService>,
}
386
387impl Services {
388 pub async fn new(access_token: &str, domain_name: &str) -> Result<Self, MsalError> {
389 let discovery_url = if cfg!(feature = "custom_oidc_discovery_url") {
390 std::env::var("HIMMELBLAU_DISCOVERY_URL").unwrap_or_else(|_| DISCOVERY_URL.to_string())
391 } else {
392 DISCOVERY_URL.to_string()
393 };
394
395 let url = Url::parse_with_params(
396 &format!("{}/{}/Discover", discovery_url, domain_name),
397 &[("api-version", DRS_PROTOCOL_VERSION), ("managed", "True")],
398 )
399 .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
400
401 let client = reqwest::Client::new();
402 let resp = client
403 .get(url)
404 .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
405 .header(DRS_CLIENT_NAME_HEADER_FIELD, env!("CARGO_PKG_NAME"))
406 .header(DRS_CLIENT_VERSION_HEADER_FIELD, env!("CARGO_PKG_VERSION"))
407 .header(
408 "User-Agent",
409 format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")),
410 )
411 .header(header::ACCEPT, "application/json, text/plain, */*")
412 .send()
413 .await
414 .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
415 if resp.status().is_success() {
416 let mut json_resp: Services = resp
417 .json()
418 .await
419 .map_err(|e| MsalError::InvalidJson(format!("{}", e)))?;
420 json_resp.client = client;
421 Ok(json_resp)
422 } else {
423 Err(MsalError::GeneralFailure(
424 resp.text()
425 .await
426 .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
427 ))
428 }
429 }
430
431 pub async fn request_nonce(
432 &self,
433 tenant_id: &str,
434 access_token: &str,
435 ) -> Result<String, MsalError> {
436 let fallback_endpoint = format!("{}/EnrollmentServer/nonce/{}/", DISCOVERY_URL, tenant_id);
437 let url = match &self.nonce_service {
438 Some(nonce_service) => {
439 let endpoint = match &nonce_service.endpoint {
440 Some(endpoint) => endpoint,
441 None => &fallback_endpoint,
442 };
443 let service_version = match &nonce_service.service_version {
444 Some(service_version) => service_version,
445 None => "1.0",
446 };
447 Url::parse_with_params(endpoint, &[("api-version", &service_version)])
448 .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?
449 }
450 None => Url::parse_with_params(&fallback_endpoint, &[("api-version", "1.0")])
451 .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?,
452 };
453
454 let client = reqwest::Client::new();
455 let resp = client
456 .get(url)
457 .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
458 .send()
459 .await
460 .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?;
461 if resp.status().is_success() {
462 let json_resp: NonceResp = resp
463 .json()
464 .await
465 .map_err(|e| MsalError::InvalidJson(format!("{:?}", e)))?;
466 Ok(json_resp.value)
467 } else {
468 Err(MsalError::RequestFailed(format!("{}", resp.status())))
469 }
470 }
471
472 #[cfg(feature = "broker")]
473 pub async fn enroll_device(
474 &self,
475 access_token: &str,
476 attrs: EnrollAttrs,
477 transport_key: &Rsa<Public>,
478 csr_der: &Vec<u8>,
479 ) -> Result<(X509, String), MsalError> {
480 let fallback_endpoint = format!("{}/EnrollmentServer/device/", DISCOVERY_URL);
481 let (join_endpoint, service_version) = match &self.device_join_service {
482 Some(device_join_service) => {
483 let join_endpoint = match &device_join_service.endpoint {
484 Some(join_endpoint) => join_endpoint,
485 None => &fallback_endpoint,
486 };
487 let service_version = match &device_join_service.service_version {
488 Some(service_version) => service_version,
489 None => "2.0",
490 };
491 (join_endpoint, service_version)
492 }
493 None => (&fallback_endpoint, "2.0"),
494 };
495
496 let url = Url::parse_with_params(join_endpoint, &[("api-version", service_version)])
497 .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
498
499 let transport_key_blob: Vec<u8> = BcryptRsaKeyBlob::new(
500 2048,
501 &transport_key.e().to_vec(),
502 &transport_key.n().to_vec(),
503 )
504 .try_into()?;
505
506 let payload = json!({
507 "CertificateRequest": {
508 "Type": "pkcs10",
509 "Data": STANDARD.encode(csr_der)
510 },
511 "DeviceDisplayName": attrs.device_display_name,
512 "DeviceType": attrs.device_type,
513 "JoinType": attrs.join_type,
514 "OSVersion": attrs.os_version,
515 "TargetDomain": attrs.target_domain,
516 "TransportKey": STANDARD.encode(transport_key_blob),
517 "Attributes": {
518 "ReuseDevice": "true",
519 "ReturnClientSid": "true"
520 }
521 });
522 if let Ok(pretty) = to_string_pretty(&payload) {
523 debug!("POST {}: {}", url, pretty);
524 }
525 let resp = self
526 .client
527 .post(url)
528 .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
529 .header(header::CONTENT_TYPE, "application/json")
530 .header(DRS_CLIENT_NAME_HEADER_FIELD, env!("CARGO_PKG_NAME"))
531 .header(DRS_CLIENT_VERSION_HEADER_FIELD, env!("CARGO_PKG_VERSION"))
532 .header(header::ACCEPT, "application/json, text/plain, */*")
533 .json(&payload)
534 .send()
535 .await
536 .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
537 if resp.status().is_success() {
538 let res: DRSResponse = resp
539 .json()
540 .await
541 .map_err(|e| MsalError::InvalidJson(format!("{}", e)))?;
542 let cert = X509::from_pem(
543 format!(
544 "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
545 res.certificate.raw_body
546 )
547 .as_bytes(),
548 )
549 .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?;
550 let subject_name = cert.subject_name();
551 let device_id = match subject_name.entries().next() {
552 Some(entry) => entry
553 .data()
554 .as_utf8()
555 .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
556 None => {
557 return Err(MsalError::GeneralFailure(
558 "The device id was missing from the certificate response".to_string(),
559 ))
560 }
561 };
562 Ok((cert, device_id.to_string()))
563 } else {
564 Err(MsalError::GeneralFailure(
565 resp.text()
566 .await
567 .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
568 ))
569 }
570 }
571
572 pub fn key_provisioning_resource_id(&self) -> String {
573 match &self.key_provisioning_service {
574 Some(key_provisioning_service) => match &key_provisioning_service.resource_id {
575 Some(resource_id) => resource_id.clone(),
576 None => "urn:ms-drs:enterpriseregistration.windows.net".to_string(),
577 },
578 None => "urn:ms-drs:enterpriseregistration.windows.net".to_string(),
579 }
580 }
581
582 pub async fn provision_key(
583 &self,
584 access_token: &str,
585 pub_key: &Rsa<Public>,
586 ) -> Result<(), MsalError> {
587 let fallback_endpoint = format!("{}/EnrollmentServer/key/", DISCOVERY_URL);
588 let (endpoint, service_version) = match &self.key_provisioning_service {
589 Some(key_provisioning_service) => {
590 let endpoint = match &key_provisioning_service.endpoint {
591 Some(endpoint) => endpoint,
592 None => &fallback_endpoint,
593 };
594 let service_version = match &key_provisioning_service.service_version {
595 Some(service_version) => service_version,
596 None => "1.0",
597 };
598 (endpoint, service_version)
599 }
600 None => (&fallback_endpoint, "1.0"),
601 };
602
603 let key_blob: Vec<u8> =
604 BcryptRsaKeyBlob::new(2048, &pub_key.e().to_vec(), &pub_key.n().to_vec()).try_into()?;
605
606 let payload = json!({
609 "kngc": STANDARD.encode(key_blob),
610 });
611 let url = Url::parse_with_params(endpoint, &[("api-version", service_version)])
612 .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
613
614 debug!("POST {}: {{ \"kngc\": <PUBLIC KEY> }}", url);
615
616 let resp = self
617 .client
618 .post(url)
619 .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
620 .header(header::CONTENT_TYPE, "application/json")
621 .header(
622 header::USER_AGENT,
623 format!("Dsreg/10.0 ({})", env!("CARGO_PKG_NAME")),
624 )
625 .header(header::ACCEPT, "application/json")
626 .json(&payload)
627 .send()
628 .await
629 .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
630 if resp.status().is_success() {
631 Ok(())
632 } else {
633 Err(MsalError::GeneralFailure(
634 "Failed registering Key".to_string(),
635 ))
636 }
637 }
638}