1use std::sync::Arc;
15
16use chrono::{DateTime, Utc};
17use tokio::sync::RwLock;
18
19use crate::config::ChalkClientConfig;
20use crate::error::{ChalkClientError, Result};
21use crate::types::{TokenExchangeRequest, TokenResponse};
22
/// Safety margin, in seconds, applied before a cached token's expiry.
///
/// A cached token is handed out only while strictly more than this many whole
/// seconds remain; once it gets this close to expiring, `get_token` fetches a
/// new one instead.
const TOKEN_REFRESH_BUFFER_SECS: i64 = 60;
26
/// A token returned by the exchange endpoint, paired with the absolute instant
/// after which it should no longer be used.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Full response from the token endpoint; cloned out to callers on cache hits.
    response: TokenResponse,
    /// Expiry computed by `parse_expiry` at the time the token was fetched.
    expires_at: DateTime<Utc>,
}
33
34#[derive(Clone)]
39pub struct TokenManager {
40 config: ChalkClientConfig,
41 http_client: reqwest::Client,
42 cache: Arc<RwLock<Option<CachedToken>>>,
43}
44
45impl TokenManager {
46 pub fn new(config: ChalkClientConfig) -> Self {
48 Self {
49 config,
50 http_client: reqwest::Client::new(),
51 cache: Arc::new(RwLock::new(None)),
52 }
53 }
54
55 pub async fn get_token(&self) -> Result<TokenResponse> {
57 {
58 let cache = self.cache.read().await;
59 if let Some(cached) = cache.as_ref() {
60 if is_token_valid(cached) {
61 return Ok(cached.response.clone());
62 }
63 }
64 }
65 let mut cache = self.cache.write().await;
66
67 if let Some(cached) = cache.as_ref() {
68 if is_token_valid(cached) {
69 return Ok(cached.response.clone());
70 }
71 }
72
73 let response = self.exchange_credentials().await?;
74 let expires_at = parse_expiry(&response);
75
76 *cache = Some(CachedToken {
77 response: response.clone(),
78 expires_at,
79 });
80
81 Ok(response)
82 }
83
84 async fn exchange_credentials(&self) -> Result<TokenResponse> {
86 let url = format!("{}/v1/oauth/token", self.config.api_server);
87
88 let body = TokenExchangeRequest {
89 client_id: self.config.client_id.clone(),
90 client_secret: self.config.client_secret.clone(),
91 grant_type: "client_credentials".into(),
92 };
93
94 tracing::debug!("exchanging credentials at {}", url);
95
96 let resp = self
97 .http_client
98 .post(&url)
99 .json(&body)
100 .header("Content-Type", "application/json")
101 .header("User-Agent", "chalk-rust/0.1.0")
102 .send()
103 .await?;
104
105 let status = resp.status();
106 if !status.is_success() {
107 let body_text = resp.text().await.unwrap_or_default();
108 return Err(ChalkClientError::Auth(format!(
109 "token exchange failed (HTTP {}): {}",
110 status.as_u16(),
111 body_text
112 )));
113 }
114
115 let token: TokenResponse = resp.json().await?;
116 tracing::debug!(
117 "token exchanged successfully, primary_environment={:?}",
118 token.primary_environment
119 );
120
121 Ok(token)
122 }
123
124 pub fn config(&self) -> &ChalkClientConfig {
126 &self.config
127 }
128}
129
130fn is_token_valid(cached: &CachedToken) -> bool {
131 let now = Utc::now();
132 let remaining = cached.expires_at.signed_duration_since(now);
133 remaining.num_seconds() > TOKEN_REFRESH_BUFFER_SECS
134}
135
136fn parse_expiry(response: &TokenResponse) -> DateTime<Utc> {
137 if let Some(ref at) = response.expires_at {
138 if let Ok(parsed) = at.parse::<DateTime<Utc>>() {
139 return parsed;
140 }
141 }
142
143 if let Some(seconds) = response.expires_in {
144 return Utc::now() + chrono::Duration::seconds(seconds);
145 }
146
147 Utc::now() + chrono::Duration::hours(1)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ChalkClientConfigBuilder;
    use std::collections::HashMap;

    /// Config with dummy credentials that targets `api_server`.
    fn test_config(api_server: &str) -> ChalkClientConfig {
        ChalkClientConfigBuilder::new()
            .client_id("test-id")
            .client_secret("test-secret")
            .api_server(api_server)
            .build()
            .expect("test config should build")
    }

    /// A `TokenResponse` whose only varying fields are the two expiry hints.
    fn response_with_expiry(expires_at: Option<&str>, expires_in: Option<i64>) -> TokenResponse {
        TokenResponse {
            access_token: "token".into(),
            expires_at: expires_at.map(Into::into),
            expires_in,
            primary_environment: None,
            engines: HashMap::new(),
            grpc_engines: HashMap::new(),
            environment_id_to_name: HashMap::new(),
            api_server: None,
        }
    }

    /// A cached dummy token that expires at `expires_at`.
    fn cached_until(expires_at: DateTime<Utc>) -> CachedToken {
        CachedToken {
            response: response_with_expiry(None, None),
            expires_at,
        }
    }

    #[test]
    fn test_parse_expiry_from_expires_at() {
        let expiry = parse_expiry(&response_with_expiry(Some("2099-12-31T23:59:59Z"), None));
        assert!(expiry > Utc::now());
    }

    #[test]
    fn test_parse_expiry_from_expires_in() {
        let expiry = parse_expiry(&response_with_expiry(None, Some(3600)));
        let seconds_left = expiry.signed_duration_since(Utc::now()).num_seconds();
        assert!((3501..=3600).contains(&seconds_left));
    }

    #[test]
    fn test_is_token_valid_expired() {
        let stale = cached_until(Utc::now() - chrono::Duration::minutes(10));
        assert!(!is_token_valid(&stale));
    }

    #[test]
    fn test_is_token_valid_fresh() {
        let fresh = cached_until(Utc::now() + chrono::Duration::minutes(30));
        assert!(is_token_valid(&fresh));
    }

    #[tokio::test]
    async fn test_token_exchange_success() {
        let mut server = mockito::Server::new_async().await;

        let payload = serde_json::json!({
            "access_token": "mock-jwt-token",
            "expires_in": 3600,
            "primary_environment": "env-abc",
            "engines": {"env-abc": "https://engine.chalk.ai"},
            "grpc_engines": {"env-abc": "https://grpc.chalk.ai"},
            "environment_id_to_name": {"env-abc": "production"}
        });
        let endpoint = server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(payload.to_string())
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let token = manager.get_token().await.unwrap();

        assert_eq!(token.access_token, "mock-jwt-token");
        assert_eq!(token.primary_environment.as_deref(), Some("env-abc"));
        assert_eq!(
            token.engines.get("env-abc").map(|engine| engine.as_str()),
            Some("https://engine.chalk.ai")
        );

        endpoint.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_caching() {
        let mut server = mockito::Server::new_async().await;

        let payload = serde_json::json!({
            "access_token": "cached-token",
            "expires_in": 3600,
            "engines": {},
            "grpc_engines": {}
        });
        // The endpoint must be hit exactly once across both calls below.
        let endpoint = server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(payload.to_string())
            .expect(1)
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let first = manager.get_token().await.unwrap();
        let second = manager.get_token().await.unwrap();

        assert_eq!(first.access_token, second.access_token);
        endpoint.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_exchange_failure() {
        let mut server = mockito::Server::new_async().await;

        server
            .mock("POST", "/v1/oauth/token")
            .with_status(401)
            .with_body("invalid credentials")
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let message = manager
            .get_token()
            .await
            .expect_err("a 401 response must surface as an error")
            .to_string();

        assert!(message.contains("401"));
        assert!(message.contains("invalid credentials"));
    }
}