Skip to main content

himmelblau/
discovery.rs

/*
   Unix Azure Entra ID implementation
   Copyright (C) David Mulder <dmulder@samba.org> 2024

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU Lesser General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public License
   along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
18
19use std::fs;
20use std::io::Read;
21
22use crate::error::MsalError;
23use base64::engine::general_purpose::STANDARD;
24use base64::Engine;
25use openssl::pkey::Public;
26use openssl::rsa::Rsa;
27use openssl::x509::X509;
28use os_release::OsRelease;
29use reqwest::{header, Client, Url};
30use serde::{Deserialize, Serialize};
31use serde_json::json;
32use serde_json::to_string_pretty;
33use tracing::debug;
34use zeroize::{Zeroize, ZeroizeOnDrop};
35
/// HTTP header used to send the client (package) name on DRS requests.
pub const DRS_CLIENT_NAME_HEADER_FIELD: &str = "ocp-adrs-client-name";
/// HTTP header used to send the client (package) version on DRS requests.
pub const DRS_CLIENT_VERSION_HEADER_FIELD: &str = "ocp-adrs-client-version";
/// Base URL of the enterprise registration service. It is used for the
/// discovery request and as the fallback base for service endpoints that
/// the discovery document does not provide.
pub const DISCOVERY_URL: &str = "https://enterpriseregistration.windows.net";
/// `api-version` query parameter sent with the discovery request.
const DRS_PROTOCOL_VERSION: &str = "1.9";
40
/// `Certificate` object inside the device enrollment response.
#[cfg(feature = "broker")]
#[derive(Debug, Deserialize, Zeroize, ZeroizeOnDrop)]
#[serde(rename_all = "PascalCase")]
struct Certificate {
    /// `RawBody`: the certificate body without PEM armor. The enrollment
    /// code adds the BEGIN/END CERTIFICATE lines before parsing it.
    raw_body: String,
}
47
/// Body of a successful device enrollment response.
///
/// Only the `Certificate` member is read. Other members are ignored because
/// unknown fields are not denied.
#[cfg(feature = "broker")]
#[derive(Debug, Deserialize, Zeroize, ZeroizeOnDrop)]
#[serde(rename_all = "PascalCase")]
struct DRSResponse {
    /// `Certificate`: the issued device certificate.
    certificate: Certificate,
}
54
/// RSA public key components to be serialized as a CNG RSA public key blob
/// (see this type's conversion into `Vec<u8>`).
///
/// The components are zeroized when the value is dropped.
#[cfg(feature = "broker")]
#[derive(Zeroize, ZeroizeOnDrop)]
pub(crate) struct BcryptRsaKeyBlob {
    /// Key size in bits, written into the blob header as given.
    bit_length: u32,
    /// Public exponent bytes. Callers in this file pass `BigNum::to_vec`
    /// output (big-endian).
    exponent: Vec<u8>,
    /// Modulus bytes. Callers in this file pass `BigNum::to_vec` output
    /// (big-endian).
    modulus: Vec<u8>,
}
62
#[cfg(feature = "broker")]
impl BcryptRsaKeyBlob {
    /// Build a blob description from raw key components.
    ///
    /// `exponent` and `modulus` are copied byte-for-byte. `bit_length` is
    /// stored as given and is not checked against `modulus`.
    pub(crate) fn new(bit_length: u32, exponent: &[u8], modulus: &[u8]) -> Self {
        let exponent = Vec::from(exponent);
        let modulus = Vec::from(modulus);
        Self {
            bit_length,
            exponent,
            modulus,
        }
    }
}
73
// Implement `TryFrom` rather than `TryInto`, as the standard library
// recommends. Existing `blob.try_into()` call sites keep working through the
// blanket `impl<T, U: TryFrom<T>> TryInto<U> for T`.
#[cfg(feature = "broker")]
impl TryFrom<BcryptRsaKeyBlob> for Vec<u8> {
    type Error = MsalError;

    /// Serialize `blob` as a CNG RSA public key blob.
    ///
    /// Layout (each length field is a little-endian `u32`):
    /// `"RSA1"` magic, `BitLength`, `cbPublicExp`, `cbModulus`, `cbPrime1`,
    /// `cbPrime2`, then the exponent bytes, then the modulus bytes. The two
    /// prime lengths are always written as zero.
    ///
    /// # Errors
    ///
    /// Returns `MsalError::GeneralFailure` if the exponent or modulus length
    /// does not fit in a `u32`.
    fn try_from(blob: BcryptRsaKeyBlob) -> Result<Self, Self::Error> {
        // Magic + BitLength + cbPublicExp + cbModulus + cbPrime1 + cbPrime2.
        const HEADER_LEN: usize = 6 * 4;

        let exponent_len = u32::try_from(blob.exponent.len()).map_err(|e| {
            MsalError::GeneralFailure(format!("Exponent len into u32 failed: {:?}", e))
        })?;
        let modulus_len = u32::try_from(blob.modulus.len()).map_err(|e| {
            MsalError::GeneralFailure(format!("Modulus len into u32 failed: {:?}", e))
        })?;
        // MS reserves spots for P and Q lengths, but doesn't permit P and Q in
        // the blob itself. Requests will be rejected if P and Q are specified.
        let prime_len: u32 = 0;

        let mut cng_blob =
            Vec::with_capacity(HEADER_LEN + blob.exponent.len() + blob.modulus.len());
        cng_blob.extend_from_slice(b"RSA1"); // Magic
        cng_blob.extend_from_slice(&blob.bit_length.to_le_bytes()); // BitLength
        cng_blob.extend_from_slice(&exponent_len.to_le_bytes()); // cbPublicExpLength
        cng_blob.extend_from_slice(&modulus_len.to_le_bytes()); // cbModulusLength
        cng_blob.extend_from_slice(&prime_len.to_le_bytes()); // cbPrime1Length
        cng_blob.extend_from_slice(&prime_len.to_le_bytes()); // cbPrime2Length
        cng_blob.extend_from_slice(&blob.exponent); // cbPublicExp
        cng_blob.extend_from_slice(&blob.modulus); // cbModulus
        Ok(cng_blob)
    }
}
102
103#[derive(Debug, Deserialize)]
104pub struct ServicesService {
105    #[serde(rename = "ServicesEndpoint")]
106    pub endpoint: Option<String>,
107    #[serde(rename = "ServiceVersion")]
108    pub service_version: Option<String>,
109}
110
111#[derive(Debug, Deserialize)]
112pub struct DeviceRegistrationService {
113    #[serde(rename = "RegistrationEndpoint")]
114    pub endpoint: Option<String>,
115    #[serde(rename = "RegistrationResourceId")]
116    pub resource_id: Option<String>,
117    #[serde(rename = "ServiceVersion")]
118    pub service_version: Option<String>,
119}
120
121#[derive(Debug, Deserialize)]
122pub struct OAuth2 {
123    #[serde(rename = "AuthCodeEndpoint")]
124    pub auth_code_endpoint: Option<String>,
125    #[serde(rename = "TokenEndpoint")]
126    pub token_endpoint: Option<String>,
127}
128
129#[derive(Debug, Deserialize)]
130pub struct AuthenticationService {
131    #[serde(rename = "OAuth2")]
132    pub oauth2: Option<OAuth2>,
133}
134
135#[derive(Debug, Deserialize)]
136pub struct IdentityProviderService {
137    #[serde(rename = "Federated")]
138    pub federated: Option<bool>,
139    #[serde(rename = "PassiveAuthEndpoint")]
140    pub passive_auth_endpoint: Option<String>,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct DeviceJoinService {
145    #[serde(rename = "JoinEndpoint")]
146    pub endpoint: Option<String>,
147    #[serde(rename = "JoinResourceId")]
148    pub resource_id: Option<String>,
149    #[serde(rename = "ServiceVersion")]
150    pub service_version: Option<String>,
151}
152
153#[derive(Debug, Deserialize)]
154pub struct KeyProvisioningService {
155    #[serde(rename = "KeyProvisionEndpoint")]
156    pub endpoint: Option<String>,
157    #[serde(rename = "KeyProvisionResourceId")]
158    pub resource_id: Option<String>,
159    #[serde(rename = "ServiceVersion")]
160    pub service_version: Option<String>,
161}
162
163#[derive(Debug, Deserialize)]
164pub struct WebAuthNService {
165    #[serde(rename = "WebAuthNEndpoint")]
166    pub endpoint: Option<String>,
167    #[serde(rename = "WebAuthNResourceId")]
168    pub resource_id: Option<String>,
169    #[serde(rename = "ServiceVersion")]
170    pub service_version: Option<String>,
171}
172
173#[derive(Debug, Deserialize)]
174pub struct DeviceManagementService {
175    #[serde(rename = "DeviceManagementEndpoint")]
176    pub endpoint: Option<String>,
177    #[serde(rename = "DeviceManagementResourceId")]
178    pub resource_id: Option<String>,
179    #[serde(rename = "ServiceVersion")]
180    pub service_version: Option<String>,
181}
182
183#[derive(Debug, Deserialize)]
184pub struct MsaProviderData {
185    #[serde(rename = "SiteId")]
186    pub site_id: Option<String>,
187    #[serde(rename = "SiteUrl")]
188    pub site_url: Option<String>,
189}
190
191#[derive(Debug, Deserialize)]
192pub struct PrecreateService {
193    #[serde(rename = "PrecreateEndpoint")]
194    pub endpoint: Option<String>,
195    #[serde(rename = "PrecreateResourceId")]
196    pub resource_id: Option<String>,
197    #[serde(rename = "ServiceVersion")]
198    pub service_version: Option<String>,
199}
200
201#[derive(Debug, Deserialize)]
202pub struct TenantInfo {
203    #[serde(rename = "TenantId")]
204    pub tenant_id: Option<String>,
205    #[serde(rename = "TenantName")]
206    pub tenant_name: Option<String>,
207    #[serde(rename = "DisplayName")]
208    pub display_name: Option<String>,
209}
210
211#[derive(Debug, Deserialize)]
212pub struct AzureRbacService {
213    #[serde(rename = "RbacPolicyEndpoint")]
214    pub endpoint: Option<String>,
215}
216
217#[derive(Debug, Deserialize)]
218pub struct BPLService {
219    #[serde(rename = "BPLServiceEndpoint")]
220    pub endpoint: Option<String>,
221    #[serde(rename = "BPLResourceId")]
222    pub resource_id: Option<String>,
223    #[serde(rename = "ServiceVersion")]
224    pub service_version: Option<String>,
225    #[serde(rename = "BPLProxyServicePrincipalId")]
226    pub service_principal_id: Option<String>,
227}
228
229#[derive(Debug, Deserialize)]
230pub struct DeviceJoinResourceService {
231    #[serde(rename = "Endpoint")]
232    pub endpoint: Option<String>,
233    #[serde(rename = "ResourceId")]
234    pub resource_id: Option<String>,
235    #[serde(rename = "ServiceVersion")]
236    pub service_version: Option<String>,
237}
238
239#[derive(Debug, Deserialize)]
240pub struct NonceService {
241    #[serde(rename = "Endpoint")]
242    pub endpoint: Option<String>,
243    #[serde(rename = "ResourceId")]
244    pub resource_id: Option<String>,
245    #[serde(rename = "ServiceVersion")]
246    pub service_version: Option<String>,
247}
248
249#[derive(Debug, Deserialize)]
250struct NonceResp {
251    #[serde(rename = "Value")]
252    value: String,
253}
254
/// Read the system manufacturer from the DMI `sys_vendor` sysfs attribute.
///
/// Returns the file contents with surrounding whitespace trimmed. Returns
/// `None` if the file cannot be opened or its contents cannot be read as
/// UTF-8 text.
fn get_manufacturer() -> Option<String> {
    const SYS_VENDOR: &str = "/sys/class/dmi/id/sys_vendor";

    let mut vendor = String::new();
    fs::File::open(SYS_VENDOR)
        .and_then(|mut file| file.read_to_string(&mut vendor))
        .ok()?;
    Some(vendor.trim().to_owned())
}
264
/// Device attributes collected for enrollment.
///
/// See `EnrollAttrs::new` for how defaults are chosen when a value is not
/// supplied.
#[cfg(feature = "broker")]
#[derive(Clone, Serialize, Deserialize)]
pub struct EnrollAttrs {
    /// Display name for the device (`new` defaults it to the hostname).
    pub(crate) device_display_name: String,
    /// Device type string (`new` defaults it to `Linux`).
    pub(crate) device_type: String,
    /// Join type code (`new` defaults it to 0).
    join_type: u32,
    /// OS version string (`new` defaults it to os-release pretty name plus
    /// version id).
    pub(crate) os_version: String,
    /// Domain the device is enrolled into.
    pub(crate) target_domain: String,
    /// The `name` field of the system's os-release data.
    pub(crate) os_distribution: String,
    /// DMI system vendor, or an empty string if it could not be read.
    pub(crate) manufacturer: String,
}
276
#[cfg(feature = "broker")]
impl EnrollAttrs {
    /// Initialize attributes for device enrollment.
    ///
    /// # Arguments
    ///
    /// * `target_domain` - The domain to be enrolled in.
    ///
    /// * `device_display_name` - An optional chosen display name for the
    ///   enrolled device. Defaults to the system hostname.
    ///
    /// * `device_type` - An optional device type. Defaults to `"Linux"`. This
    ///   affects which Intune policies are distributed to the client.
    ///
    /// * `join_type` - An optional join type. Defaults to 0. Possible values
    ///   are:
    ///     - 0: Azure AD join.
    ///     - 4: Azure AD register only.
    ///     - 6: Azure AD hybrid join.
    ///     - 8: Azure AD join.
    ///
    /// * `os_version` - An optional OS version. Defaults to the os-release
    ///   pretty name and version id, separated by a space.
    ///
    /// The OS distribution is always taken from the os-release `name`. The
    /// manufacturer is read from DMI sysfs data, or left empty if that fails.
    ///
    /// # Errors
    ///
    /// Returns `MsalError::GeneralFailure` if the os-release data cannot be
    /// read. It is also returned when no display name was supplied and the
    /// hostname cannot be obtained or is not valid UTF-8.
    pub fn new(
        target_domain: String,
        device_display_name: Option<String>,
        device_type: Option<String>,
        join_type: Option<u32>,
        os_version: Option<String>,
    ) -> Result<Self, MsalError> {
        let os_release =
            OsRelease::new().map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?;

        let device_display_name = match device_display_name {
            Some(name) => name,
            None => {
                let host =
                    hostname::get().map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?;
                host.to_str().map(String::from).ok_or_else(|| {
                    MsalError::GeneralFailure(
                        "Failed to get machine hostname for enrollment".to_string(),
                    )
                })?
            }
        };
        let device_type = device_type.unwrap_or_else(|| "Linux".to_string());
        let os_version = os_version.unwrap_or_else(|| {
            format!("{} {}", os_release.pretty_name, os_release.version_id)
        });

        Ok(EnrollAttrs {
            device_display_name,
            device_type,
            join_type: join_type.unwrap_or(0),
            os_version,
            target_domain,
            os_distribution: os_release.name,
            manufacturer: get_manufacturer().unwrap_or_default(),
        })
    }
}
350
351#[derive(Debug, Deserialize)]
352pub struct Services {
353    #[serde(skip_deserializing)]
354    client: Client,
355    #[serde(rename = "ServicesService")]
356    pub discovery_service: Option<ServicesService>,
357    #[serde(rename = "DeviceRegistrationService")]
358    pub device_registration_service: Option<DeviceRegistrationService>,
359    #[serde(rename = "AuthenticationService")]
360    pub authentication_service: Option<AuthenticationService>,
361    #[serde(rename = "IdentityProviderService")]
362    pub identity_provider_service: Option<IdentityProviderService>,
363    #[serde(rename = "DeviceJoinService")]
364    pub device_join_service: Option<DeviceJoinService>,
365    #[serde(rename = "KeyProvisioningService")]
366    pub key_provisioning_service: Option<KeyProvisioningService>,
367    #[serde(rename = "WebAuthNService")]
368    pub web_auth_n_service: Option<WebAuthNService>,
369    #[serde(rename = "DeviceManagementService")]
370    pub device_management_service: Option<DeviceManagementService>,
371    #[serde(rename = "MsaProviderData")]
372    pub msa_provider_data: Option<MsaProviderData>,
373    #[serde(rename = "PrecreateService")]
374    pub precreate_service: Option<PrecreateService>,
375    #[serde(rename = "TenantInfo")]
376    pub tenant_info: Option<TenantInfo>,
377    #[serde(rename = "AzureRbacService")]
378    pub azure_rbac_service: Option<AzureRbacService>,
379    #[serde(rename = "BPLService")]
380    pub bpl_service: Option<BPLService>,
381    #[serde(rename = "DeviceJoinResourceService")]
382    pub device_join_resource_service: Option<DeviceJoinResourceService>,
383    #[serde(rename = "NonceService")]
384    nonce_service: Option<NonceService>,
385}
386
387impl Services {
388    pub async fn new(access_token: &str, domain_name: &str) -> Result<Self, MsalError> {
389        let discovery_url = if cfg!(feature = "custom_oidc_discovery_url") {
390            std::env::var("HIMMELBLAU_DISCOVERY_URL").unwrap_or_else(|_| DISCOVERY_URL.to_string())
391        } else {
392            DISCOVERY_URL.to_string()
393        };
394
395        let url = Url::parse_with_params(
396            &format!("{}/{}/Discover", discovery_url, domain_name),
397            &[("api-version", DRS_PROTOCOL_VERSION), ("managed", "True")],
398        )
399        .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
400
401        let client = reqwest::Client::new();
402        let resp = client
403            .get(url)
404            .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
405            .header(DRS_CLIENT_NAME_HEADER_FIELD, env!("CARGO_PKG_NAME"))
406            .header(DRS_CLIENT_VERSION_HEADER_FIELD, env!("CARGO_PKG_VERSION"))
407            .header(
408                "User-Agent",
409                format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")),
410            )
411            .header(header::ACCEPT, "application/json, text/plain, */*")
412            .send()
413            .await
414            .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
415        if resp.status().is_success() {
416            let mut json_resp: Services = resp
417                .json()
418                .await
419                .map_err(|e| MsalError::InvalidJson(format!("{}", e)))?;
420            json_resp.client = client;
421            Ok(json_resp)
422        } else {
423            Err(MsalError::GeneralFailure(
424                resp.text()
425                    .await
426                    .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
427            ))
428        }
429    }
430
431    pub async fn request_nonce(
432        &self,
433        tenant_id: &str,
434        access_token: &str,
435    ) -> Result<String, MsalError> {
436        let fallback_endpoint = format!("{}/EnrollmentServer/nonce/{}/", DISCOVERY_URL, tenant_id);
437        let url = match &self.nonce_service {
438            Some(nonce_service) => {
439                let endpoint = match &nonce_service.endpoint {
440                    Some(endpoint) => endpoint,
441                    None => &fallback_endpoint,
442                };
443                let service_version = match &nonce_service.service_version {
444                    Some(service_version) => service_version,
445                    None => "1.0",
446                };
447                Url::parse_with_params(endpoint, &[("api-version", &service_version)])
448                    .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?
449            }
450            None => Url::parse_with_params(&fallback_endpoint, &[("api-version", "1.0")])
451                .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?,
452        };
453
454        let client = reqwest::Client::new();
455        let resp = client
456            .get(url)
457            .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
458            .send()
459            .await
460            .map_err(|e| MsalError::RequestFailed(format!("{:?}", e)))?;
461        if resp.status().is_success() {
462            let json_resp: NonceResp = resp
463                .json()
464                .await
465                .map_err(|e| MsalError::InvalidJson(format!("{:?}", e)))?;
466            Ok(json_resp.value)
467        } else {
468            Err(MsalError::RequestFailed(format!("{}", resp.status())))
469        }
470    }
471
472    #[cfg(feature = "broker")]
473    pub async fn enroll_device(
474        &self,
475        access_token: &str,
476        attrs: EnrollAttrs,
477        transport_key: &Rsa<Public>,
478        csr_der: &Vec<u8>,
479    ) -> Result<(X509, String), MsalError> {
480        let fallback_endpoint = format!("{}/EnrollmentServer/device/", DISCOVERY_URL);
481        let (join_endpoint, service_version) = match &self.device_join_service {
482            Some(device_join_service) => {
483                let join_endpoint = match &device_join_service.endpoint {
484                    Some(join_endpoint) => join_endpoint,
485                    None => &fallback_endpoint,
486                };
487                let service_version = match &device_join_service.service_version {
488                    Some(service_version) => service_version,
489                    None => "2.0",
490                };
491                (join_endpoint, service_version)
492            }
493            None => (&fallback_endpoint, "2.0"),
494        };
495
496        let url = Url::parse_with_params(join_endpoint, &[("api-version", service_version)])
497            .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
498
499        let transport_key_blob: Vec<u8> = BcryptRsaKeyBlob::new(
500            2048,
501            &transport_key.e().to_vec(),
502            &transport_key.n().to_vec(),
503        )
504        .try_into()?;
505
506        let payload = json!({
507            "CertificateRequest": {
508                "Type": "pkcs10",
509                "Data": STANDARD.encode(csr_der)
510            },
511            "DeviceDisplayName": attrs.device_display_name,
512            "DeviceType": attrs.device_type,
513            "JoinType": attrs.join_type,
514            "OSVersion": attrs.os_version,
515            "TargetDomain": attrs.target_domain,
516            "TransportKey": STANDARD.encode(transport_key_blob),
517            "Attributes": {
518                "ReuseDevice": "true",
519                "ReturnClientSid": "true"
520            }
521        });
522        if let Ok(pretty) = to_string_pretty(&payload) {
523            debug!("POST {}: {}", url, pretty);
524        }
525        let resp = self
526            .client
527            .post(url)
528            .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
529            .header(header::CONTENT_TYPE, "application/json")
530            .header(DRS_CLIENT_NAME_HEADER_FIELD, env!("CARGO_PKG_NAME"))
531            .header(DRS_CLIENT_VERSION_HEADER_FIELD, env!("CARGO_PKG_VERSION"))
532            .header(header::ACCEPT, "application/json, text/plain, */*")
533            .json(&payload)
534            .send()
535            .await
536            .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
537        if resp.status().is_success() {
538            let res: DRSResponse = resp
539                .json()
540                .await
541                .map_err(|e| MsalError::InvalidJson(format!("{}", e)))?;
542            let cert = X509::from_pem(
543                format!(
544                    "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
545                    res.certificate.raw_body
546                )
547                .as_bytes(),
548            )
549            .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?;
550            let subject_name = cert.subject_name();
551            let device_id = match subject_name.entries().next() {
552                Some(entry) => entry
553                    .data()
554                    .as_utf8()
555                    .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
556                None => {
557                    return Err(MsalError::GeneralFailure(
558                        "The device id was missing from the certificate response".to_string(),
559                    ))
560                }
561            };
562            Ok((cert, device_id.to_string()))
563        } else {
564            Err(MsalError::GeneralFailure(
565                resp.text()
566                    .await
567                    .map_err(|e| MsalError::GeneralFailure(format!("{}", e)))?,
568            ))
569        }
570    }
571
572    pub fn key_provisioning_resource_id(&self) -> String {
573        match &self.key_provisioning_service {
574            Some(key_provisioning_service) => match &key_provisioning_service.resource_id {
575                Some(resource_id) => resource_id.clone(),
576                None => "urn:ms-drs:enterpriseregistration.windows.net".to_string(),
577            },
578            None => "urn:ms-drs:enterpriseregistration.windows.net".to_string(),
579        }
580    }
581
582    pub async fn provision_key(
583        &self,
584        access_token: &str,
585        pub_key: &Rsa<Public>,
586    ) -> Result<(), MsalError> {
587        let fallback_endpoint = format!("{}/EnrollmentServer/key/", DISCOVERY_URL);
588        let (endpoint, service_version) = match &self.key_provisioning_service {
589            Some(key_provisioning_service) => {
590                let endpoint = match &key_provisioning_service.endpoint {
591                    Some(endpoint) => endpoint,
592                    None => &fallback_endpoint,
593                };
594                let service_version = match &key_provisioning_service.service_version {
595                    Some(service_version) => service_version,
596                    None => "1.0",
597                };
598                (endpoint, service_version)
599            }
600            None => (&fallback_endpoint, "1.0"),
601        };
602
603        let key_blob: Vec<u8> =
604            BcryptRsaKeyBlob::new(2048, &pub_key.e().to_vec(), &pub_key.n().to_vec()).try_into()?;
605
606        // [MS-KPP] 3.1.5.1.1.1 Request Body
607        // Register the public key
608        let payload = json!({
609            "kngc": STANDARD.encode(key_blob),
610        });
611        let url = Url::parse_with_params(endpoint, &[("api-version", service_version)])
612            .map_err(|e| MsalError::URLFormatFailed(format!("{}", e)))?;
613
614        debug!("POST {}: {{ \"kngc\": <PUBLIC KEY> }}", url);
615
616        let resp = self
617            .client
618            .post(url)
619            .header(header::AUTHORIZATION, format!("Bearer {}", access_token))
620            .header(header::CONTENT_TYPE, "application/json")
621            .header(
622                header::USER_AGENT,
623                format!("Dsreg/10.0 ({})", env!("CARGO_PKG_NAME")),
624            )
625            .header(header::ACCEPT, "application/json")
626            .json(&payload)
627            .send()
628            .await
629            .map_err(|e| MsalError::RequestFailed(format!("{}", e)))?;
630        if resp.status().is_success() {
631            Ok(())
632        } else {
633            Err(MsalError::GeneralFailure(
634                "Failed registering Key".to_string(),
635            ))
636        }
637    }
638}