1use std::collections::HashMap;
2use std::error::Error;
3
4use base64::Engine;
5use base64::prelude::BASE64_STANDARD;
6use jwt::algorithm::openssl::PKeyWithDigest;
7use jwt::SigningAlgorithm;
8use openssl::hash::MessageDigest;
9use openssl::pkey::PKey;
10
11use token::Header;
12
13use crate::token::{AccessTokenResponse, aud, Payload};
14
15mod token;
16
17pub async fn acquire_token(tenant_id: String, client_id: String, scope: String, private_key_pem: &Vec<u8>, public_key_pem: &Vec<u8>) -> Result<AccessTokenResponse, Box<dyn Error>> {
18 let algorithm = PKeyWithDigest {
19 digest: MessageDigest::sha256(),
20 key: PKey::private_key_from_pem(&private_key_pem)?,
21 };
22
23 let header = Header::new(&public_key_pem)?;
24 let payload = Payload::new(tenant_id.to_owned(), client_id.to_string());
25 let header_json = serde_json::json!(header);
26 let payload_json = serde_json::json!(payload);
27
28 let header_base64 = BASE64_STANDARD.encode(header_json.to_string());
29 let payload_base64 = BASE64_STANDARD.encode(payload_json.to_string());
30 let result = algorithm.sign(&header_base64, &payload_base64).unwrap();
31 let client_assertion = format!("{}.{}.{}", header_base64, payload_base64, result);
32
33 let client = reqwest::Client::new();
34 let mut params = HashMap::new();
35 params.insert("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
36 params.insert("grant_type", "client_credentials");
37 let all_scope = format!("openid profile offline_access {}", scope);
38 params.insert("scope", &all_scope);
39 params.insert("client_assertion", &client_assertion);
40 params.insert("client_id", &client_id);
41
42 let res = client.post(aud(tenant_id.to_owned()))
43 .form(¶ms)
44 .send()
45 .await?;
46
47 let body_text = match res.text().await {
48 Ok(text) => text,
49 Err(e) => {
50 return Err(e.into());
51 }
52 };
53
54 let x: Result<AccessTokenResponse, _> = serde_json::from_str(&body_text);
55
56 let ret = match x {
57 Ok(token_response) => {
58 Ok(token_response)
59 }
60 Err(e) => {
61 println!("Error while parsing JSON: {}. Response text: {}", e, body_text);
62 Err(e.into())
63 }
64 };
65 ret
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check of `acquire_token` against the live token endpoint.
    ///
    /// Ignored by default because it makes a real HTTP request. It also embeds
    /// key material from `../keys/` at compile time. Run explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_acquire_token() -> Result<(), Box<dyn Error>> {
        let tenant_id = "72f988bf-86f1-41af-91ab-2d7cd011db47".to_string();
        let client_id = "064b969a-ed15-42fa-9044-f08081163a67".to_string();
        let scope = "https://graph.microsoft.com/.default".to_string();
        let private_key_pem = include_bytes!("../keys/private_key.pem").to_vec();
        let public_key_pem = include_bytes!("../keys/public_key.pem").to_vec();

        let token_response =
            acquire_token(tenant_id, client_id, scope, &private_key_pem, &public_key_pem).await?;

        assert_eq!(token_response.token_type, "Bearer");
        assert!(token_response.expires_in > 0);
        assert!(!token_response.access_token.is_empty());
        Ok(())
    }
}