1use super::{OidcError, OidcTokenProvider};
2
/// OIDC token provider for Azure Pipelines jobs.
///
/// Tokens are requested from the endpoint named by the `SYSTEM_OIDCREQUESTURI`
/// environment variable, authenticated with the job access token found in
/// `SYSTEM_ACCESSTOKEN`. The provider is stateless, so it is a cheap `Copy`
/// unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct AzureProvider;
4
5fn required_env(name: &str) -> Result<String, OidcError> {
6 std::env::var(name).map_err(|_| OidcError::MissingEnv(name.to_string()))
7}
8
9#[derive(serde::Deserialize)]
10struct AzureTokenResponse {
11 #[serde(rename = "oidcToken")]
12 oidc_token: String,
13}
14
15#[async_trait::async_trait]
16impl OidcTokenProvider for AzureProvider {
17 async fn fetch_token(&self, _audience: &str) -> Result<String, OidcError> {
18 let request_uri = required_env("SYSTEM_OIDCREQUESTURI")?;
19 let access_token = required_env("SYSTEM_ACCESSTOKEN")?;
20
21 let client = reqwest::Client::new();
24 let response = client
25 .post(&request_uri)
26 .header("Authorization", format!("Bearer {access_token}"))
27 .header("Content-Type", "application/json")
28 .body("{}")
29 .send()
30 .await
31 .map_err(|e| OidcError::Http(e.to_string()))?;
32
33 if !response.status().is_success() {
34 return Err(OidcError::Http(format!(
35 "Azure OIDC request failed with status {}",
36 response.status()
37 )));
38 }
39
40 let body = response
41 .json::<AzureTokenResponse>()
42 .await
43 .map_err(|e| OidcError::Http(e.to_string()))?;
44
45 Ok(body.oidc_token)
46 }
47}
48
#[cfg(test)]
// These tests hold a `std::sync::Mutex` guard across `.await` to serialize
// access to process-wide environment variables. `#[tokio::test]` builds a
// separate runtime per test (current-thread by default), so the guard never
// blocks the executor that has to run to release it.
#[allow(clippy::await_holding_lock)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use wiremock::matchers::{header, method};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    const REQUEST_URI_VAR: &str = "SYSTEM_OIDCREQUESTURI";
    const ACCESS_TOKEN_VAR: &str = "SYSTEM_ACCESSTOKEN";

    /// Serializes tests in this module that mutate the process environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Removes both Azure variables from the process environment.
    fn clear_azure_env() {
        std::env::remove_var(REQUEST_URI_VAR);
        std::env::remove_var(ACCESS_TOKEN_VAR);
    }

    /// Holds `ENV_LOCK` for the lifetime of a test. It starts from a clean
    /// environment and clears the Azure variables again on drop, even when
    /// the test panics.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // Recover from poisoning: one failing test must not turn every
            // later test into a spurious `PoisonError` panic.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            clear_azure_env();
            Self { _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so cleanup happens under the lock.
            clear_azure_env();
        }
    }

    #[tokio::test]
    async fn fetch_token_returns_jwt_from_mock_server() {
        let _env = EnvGuard::acquire();

        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(header("Authorization", "Bearer dummy-access-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "oidcToken": "azure-jwt-token"
            })))
            .mount(&server)
            .await;

        std::env::set_var(REQUEST_URI_VAR, server.uri());
        std::env::set_var(ACCESS_TOKEN_VAR, "dummy-access-token");

        let token = AzureProvider
            .fetch_token("https://api.deslicer.ai")
            .await
            .unwrap();

        assert_eq!(token, "azure-jwt-token");
    }

    #[tokio::test]
    async fn fetch_token_errors_on_non_success_status() {
        let _env = EnvGuard::acquire();

        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&server)
            .await;

        std::env::set_var(REQUEST_URI_VAR, server.uri());
        std::env::set_var(ACCESS_TOKEN_VAR, "dummy-access-token");

        let err = AzureProvider
            .fetch_token("https://api.deslicer.ai")
            .await
            .unwrap_err();

        assert!(matches!(err, OidcError::Http(_)));
    }

    #[tokio::test]
    async fn fetch_token_errors_when_env_missing() {
        // `EnvGuard::acquire` already removed both variables.
        let _env = EnvGuard::acquire();

        let err = AzureProvider
            .fetch_token("https://api.deslicer.ai")
            .await
            .unwrap_err();

        assert!(matches!(err, OidcError::MissingEnv(_)));
    }
}