msal_cert/
lib.rs

1use std::collections::HashMap;
2use std::error::Error;
3
4use base64::Engine;
5use base64::prelude::BASE64_STANDARD;
6use jwt::algorithm::openssl::PKeyWithDigest;
7use jwt::SigningAlgorithm;
8use openssl::hash::MessageDigest;
9use openssl::pkey::PKey;
10
11use token::Header;
12
13use crate::token::{AccessTokenResponse, aud, Payload};
14
15mod token;
16
17pub async fn acquire_token(tenant_id: String, client_id: String, scope: String, private_key_pem: &Vec<u8>, public_key_pem: &Vec<u8>) -> Result<AccessTokenResponse, Box<dyn Error>> {
18    let algorithm = PKeyWithDigest {
19        digest: MessageDigest::sha256(),
20        key: PKey::private_key_from_pem(&private_key_pem)?,
21    };
22
23    let header = Header::new(&public_key_pem)?;
24    let payload = Payload::new(tenant_id.to_owned(), client_id.to_string());
25    let header_json = serde_json::json!(header);
26    let payload_json = serde_json::json!(payload);
27
28    let header_base64 = BASE64_STANDARD.encode(header_json.to_string());
29    let payload_base64 = BASE64_STANDARD.encode(payload_json.to_string());
30    let result = algorithm.sign(&header_base64, &payload_base64).unwrap();
31    let client_assertion = format!("{}.{}.{}", header_base64, payload_base64, result);
32
33    let client = reqwest::Client::new();
34    let mut params = HashMap::new();
35    params.insert("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
36    params.insert("grant_type", "client_credentials");
37    let all_scope = format!("openid profile offline_access {}", scope);
38    params.insert("scope", &all_scope);
39    params.insert("client_assertion", &client_assertion);
40    params.insert("client_id", &client_id);
41
42    let res = client.post(aud(tenant_id.to_owned()))
43        .form(&params)
44        .send()
45        .await?;
46
47    let body_text = match res.text().await {
48        Ok(text) => text,
49        Err(e) => {
50            return Err(e.into());
51        }
52    };
53
54    let x: Result<AccessTokenResponse, _> = serde_json::from_str(&body_text);
55
56    let ret = match x {
57        Ok(token_response) => {
58            Ok(token_response)
59        }
60        Err(e) => {
61            println!("Error while parsing JSON: {}. Response text: {}", e, body_text);
62            Err(e.into())
63        }
64    };
65    ret
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check of `acquire_token` against the live token endpoint.
    ///
    /// Marked `#[ignore]` because it performs a real HTTP request; run it explicitly with
    /// `cargo test -- --ignored`. It also needs the PEM files under `keys/` to exist at
    /// compile time (they are embedded with `include_bytes!`).
    #[tokio::test]
    #[ignore]
    async fn test_acquire_token() -> Result<(), Box<dyn Error>> {
        // Setup test data
        let tenant_id = "72f988bf-86f1-41af-91ab-2d7cd011db47".to_string();
        let client_id = "064b969a-ed15-42fa-9044-f08081163a67".to_string();
        let scope = "https://graph.microsoft.com/.default".to_string();
        let private_key_pem = include_bytes!("../keys/private_key.pem").to_vec(); // Update with the correct path to your private key
        let public_key_pem = include_bytes!("../keys/public_key.pem").to_vec(); // Update with the correct path to your public key

        // Call the acquire_token function
        let token_response =
            acquire_token(tenant_id, client_id, scope, &private_key_pem, &public_key_pem).await?;

        // Validate the response
        assert_eq!(token_response.token_type, "Bearer");
        assert!(token_response.expires_in > 0);
        assert!(!token_response.access_token.is_empty());
        Ok(())
    }
}