kontext_dev_sdk/management/
auth.rs1use std::sync::Arc;
2use std::time::{Duration, SystemTime};
3
4use serde::{Deserialize, Serialize};
5use tokio::sync::Mutex;
6
7use crate::KontextDevError;
8
9#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
10#[serde(rename_all = "camelCase")]
11pub struct ServiceAccountCredentials {
12 pub client_id: String,
13 pub client_secret: String,
14}
15
16#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
17pub struct AccessToken {
18 pub access_token: String,
19 #[serde(default = "default_token_type")]
20 pub token_type: String,
21 #[serde(default)]
22 pub expires_in: Option<i64>,
23 #[serde(default)]
24 pub scope: Option<String>,
25}
26
/// Token type assumed when a token response omits the `token_type` field.
fn default_token_type() -> String {
    String::from("Bearer")
}
30
31#[derive(Clone, Debug)]
32struct CachedToken {
33 token: AccessToken,
34 expires_at: Option<SystemTime>,
35}
36
37impl CachedToken {
38 fn is_valid(&self) -> bool {
39 match self.expires_at {
40 Some(expires_at) => {
41 let now = SystemTime::now();
42 let refresh_margin = Duration::from_secs(60);
43 now + refresh_margin < expires_at
44 }
45 None => true,
46 }
47 }
48}
49
50#[derive(Clone)]
51pub struct TokenManager {
52 token_url: String,
53 credentials: ServiceAccountCredentials,
54 scopes: Vec<String>,
55 audience: String,
56 http: reqwest::Client,
57 cache: Arc<Mutex<Option<CachedToken>>>,
58}
59
60impl TokenManager {
61 pub fn new(
62 token_url: String,
63 credentials: ServiceAccountCredentials,
64 scopes: Vec<String>,
65 audience: String,
66 ) -> Self {
67 Self {
68 token_url,
69 credentials,
70 scopes,
71 audience,
72 http: reqwest::Client::new(),
73 cache: Arc::new(Mutex::new(None)),
74 }
75 }
76
77 pub async fn token(&self) -> Result<String, KontextDevError> {
78 let mut guard = self.cache.lock().await;
79 if let Some(cached) = guard.as_ref()
80 && cached.is_valid()
81 {
82 return Ok(cached.token.access_token.clone());
83 }
84
85 let token = authenticate_service_account(
86 &self.http,
87 &self.token_url,
88 &self.credentials,
89 &self.scopes,
90 &self.audience,
91 )
92 .await?;
93
94 let expires_at = token
95 .expires_in
96 .and_then(|s| u64::try_from(s).ok())
97 .map(|s| SystemTime::now() + Duration::from_secs(s));
98
99 *guard = Some(CachedToken {
100 token: token.clone(),
101 expires_at,
102 });
103
104 Ok(token.access_token)
105 }
106
107 pub async fn refresh(&self) -> Result<(), KontextDevError> {
108 let token = authenticate_service_account(
109 &self.http,
110 &self.token_url,
111 &self.credentials,
112 &self.scopes,
113 &self.audience,
114 )
115 .await?;
116
117 let expires_at = token
118 .expires_in
119 .and_then(|s| u64::try_from(s).ok())
120 .map(|s| SystemTime::now() + Duration::from_secs(s));
121
122 let mut guard = self.cache.lock().await;
123 *guard = Some(CachedToken { token, expires_at });
124 Ok(())
125 }
126
127 pub async fn clear(&self) {
128 let mut guard = self.cache.lock().await;
129 *guard = None;
130 }
131}
132
133pub async fn authenticate_service_account(
134 http: &reqwest::Client,
135 token_url: &str,
136 credentials: &ServiceAccountCredentials,
137 scopes: &[String],
138 audience: &str,
139) -> Result<AccessToken, KontextDevError> {
140 let scope = scopes.join(" ");
141 let response = http
142 .post(token_url)
143 .basic_auth(
144 credentials.client_id.as_str(),
145 Some(credentials.client_secret.as_str()),
146 )
147 .form(&[
148 ("grant_type", "client_credentials"),
149 ("scope", scope.as_str()),
150 ("audience", audience),
151 ])
152 .send()
153 .await
154 .map_err(|err| KontextDevError::TokenRequest {
155 token_url: token_url.to_string(),
156 message: err.to_string(),
157 })?;
158
159 if !response.status().is_success() {
160 let status = response.status();
161 let body = response.text().await.unwrap_or_default();
162 return Err(KontextDevError::TokenRequest {
163 token_url: token_url.to_string(),
164 message: format!("{status}: {body}"),
165 });
166 }
167
168 response
169 .json::<AccessToken>()
170 .await
171 .map_err(|err| KontextDevError::TokenRequest {
172 token_url: token_url.to_string(),
173 message: err.to_string(),
174 })
175}