Skip to main content

kontext_dev_sdk/management/auth.rs

1use std::sync::Arc;
2use std::time::{Duration, SystemTime};
3
4use serde::{Deserialize, Serialize};
5use tokio::sync::Mutex;
6
7use crate::KontextDevError;
8
9#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
10#[serde(rename_all = "camelCase")]
11pub struct ServiceAccountCredentials {
12    pub client_id: String,
13    pub client_secret: String,
14}
15
16#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
17pub struct AccessToken {
18    pub access_token: String,
19    #[serde(default = "default_token_type")]
20    pub token_type: String,
21    #[serde(default)]
22    pub expires_in: Option<i64>,
23    #[serde(default)]
24    pub scope: Option<String>,
25}
26
/// Serde default for `AccessToken::token_type`: a response that omits
/// `token_type` is treated as a bearer token.
fn default_token_type() -> String {
    String::from("Bearer")
}
30
31#[derive(Clone, Debug)]
32struct CachedToken {
33    token: AccessToken,
34    expires_at: Option<SystemTime>,
35}
36
37impl CachedToken {
38    fn is_valid(&self) -> bool {
39        match self.expires_at {
40            Some(expires_at) => {
41                let now = SystemTime::now();
42                let refresh_margin = Duration::from_secs(60);
43                now + refresh_margin < expires_at
44            }
45            None => true,
46        }
47    }
48}
49
50#[derive(Clone)]
51pub struct TokenManager {
52    token_url: String,
53    credentials: ServiceAccountCredentials,
54    scopes: Vec<String>,
55    audience: String,
56    http: reqwest::Client,
57    cache: Arc<Mutex<Option<CachedToken>>>,
58}
59
60impl TokenManager {
61    pub fn new(
62        token_url: String,
63        credentials: ServiceAccountCredentials,
64        scopes: Vec<String>,
65        audience: String,
66    ) -> Self {
67        Self {
68            token_url,
69            credentials,
70            scopes,
71            audience,
72            http: reqwest::Client::new(),
73            cache: Arc::new(Mutex::new(None)),
74        }
75    }
76
77    pub async fn token(&self) -> Result<String, KontextDevError> {
78        let mut guard = self.cache.lock().await;
79        if let Some(cached) = guard.as_ref()
80            && cached.is_valid()
81        {
82            return Ok(cached.token.access_token.clone());
83        }
84
85        let token = authenticate_service_account(
86            &self.http,
87            &self.token_url,
88            &self.credentials,
89            &self.scopes,
90            &self.audience,
91        )
92        .await?;
93
94        let expires_at = token
95            .expires_in
96            .and_then(|s| u64::try_from(s).ok())
97            .map(|s| SystemTime::now() + Duration::from_secs(s));
98
99        *guard = Some(CachedToken {
100            token: token.clone(),
101            expires_at,
102        });
103
104        Ok(token.access_token)
105    }
106
107    pub async fn refresh(&self) -> Result<(), KontextDevError> {
108        let token = authenticate_service_account(
109            &self.http,
110            &self.token_url,
111            &self.credentials,
112            &self.scopes,
113            &self.audience,
114        )
115        .await?;
116
117        let expires_at = token
118            .expires_in
119            .and_then(|s| u64::try_from(s).ok())
120            .map(|s| SystemTime::now() + Duration::from_secs(s));
121
122        let mut guard = self.cache.lock().await;
123        *guard = Some(CachedToken { token, expires_at });
124        Ok(())
125    }
126
127    pub async fn clear(&self) {
128        let mut guard = self.cache.lock().await;
129        *guard = None;
130    }
131}
132
133pub async fn authenticate_service_account(
134    http: &reqwest::Client,
135    token_url: &str,
136    credentials: &ServiceAccountCredentials,
137    scopes: &[String],
138    audience: &str,
139) -> Result<AccessToken, KontextDevError> {
140    let scope = scopes.join(" ");
141    let response = http
142        .post(token_url)
143        .basic_auth(
144            credentials.client_id.as_str(),
145            Some(credentials.client_secret.as_str()),
146        )
147        .form(&[
148            ("grant_type", "client_credentials"),
149            ("scope", scope.as_str()),
150            ("audience", audience),
151        ])
152        .send()
153        .await
154        .map_err(|err| KontextDevError::TokenRequest {
155            token_url: token_url.to_string(),
156            message: err.to_string(),
157        })?;
158
159    if !response.status().is_success() {
160        let status = response.status();
161        let body = response.text().await.unwrap_or_default();
162        return Err(KontextDevError::TokenRequest {
163            token_url: token_url.to_string(),
164            message: format!("{status}: {body}"),
165        });
166    }
167
168    response
169        .json::<AccessToken>()
170        .await
171        .map_err(|err| KontextDevError::TokenRequest {
172            token_url: token_url.to_string(),
173            message: err.to_string(),
174        })
175}