Skip to main content

chalk_client/
auth.rs

//! Token management — exchanging client credentials for JWTs.
//!
//! The Chalk API uses the OAuth2 client-credentials flow:
//!
//! 1. POST your `client_id` and `client_secret` to `/v1/oauth/token`.
//! 2. The server returns a JWT (`access_token`) that expires after some time.
//! 3. Attach this JWT as a `Bearer` token on every subsequent request.
//!
//! This module performs the exchange *and* caches the resulting token, so the
//! auth endpoint is not hit on every query. The cache sits behind a
//! `tokio::sync::RwLock`, so multiple async tasks can share a single
//! `TokenManager`.

14use std::sync::Arc;
15
16use chrono::{DateTime, Utc};
17use tokio::sync::RwLock;
18
19use crate::config::ChalkClientConfig;
20use crate::error::{ChalkClientError, Result};
21use crate::types::{TokenExchangeRequest, TokenResponse};
22
/// Early-refresh margin in seconds: a cached token with this many seconds or
/// fewer left before its expiry is treated as stale and exchanged again.
const TOKEN_REFRESH_BUFFER_SECS: i64 = 60;
26
27/// A cached token plus the timestamp when it expires.
28#[derive(Debug, Clone)]
29struct CachedToken {
30    response: TokenResponse,
31    expires_at: DateTime<Utc>,
32}
33
34/// Manages authentication tokens for the Chalk client.
35///
36/// This struct is cheap to clone (the inner state is behind an `Arc`), so
37/// it can be shared between clients.
38#[derive(Clone)]
39pub struct TokenManager {
40    config: ChalkClientConfig,
41    http_client: reqwest::Client,
42    cache: Arc<RwLock<Option<CachedToken>>>,
43}
44
45impl TokenManager {
46    /// Create a new `TokenManager` from the given config.
47    pub fn new(config: ChalkClientConfig) -> Self {
48        Self {
49            config,
50            http_client: reqwest::Client::new(),
51            cache: Arc::new(RwLock::new(None)),
52        }
53    }
54
55    /// Get a valid token, fetching or refreshing as needed.
56    pub async fn get_token(&self) -> Result<TokenResponse> {
57        {
58            let cache = self.cache.read().await;
59            if let Some(cached) = cache.as_ref() {
60                if is_token_valid(cached) {
61                    return Ok(cached.response.clone());
62                }
63            }
64        }
65        let mut cache = self.cache.write().await;
66
67        if let Some(cached) = cache.as_ref() {
68            if is_token_valid(cached) {
69                return Ok(cached.response.clone());
70            }
71        }
72
73        let response = self.exchange_credentials().await?;
74        let expires_at = parse_expiry(&response);
75
76        *cache = Some(CachedToken {
77            response: response.clone(),
78            expires_at,
79        });
80
81        Ok(response)
82    }
83
84    /// POST to `/v1/oauth/token` to exchange client credentials for a JWT.
85    async fn exchange_credentials(&self) -> Result<TokenResponse> {
86        let url = format!("{}/v1/oauth/token", self.config.api_server);
87
88        let body = TokenExchangeRequest {
89            client_id: self.config.client_id.clone(),
90            client_secret: self.config.client_secret.clone(),
91            grant_type: "client_credentials".into(),
92        };
93
94        tracing::debug!("exchanging credentials at {}", url);
95
96        let resp = self
97            .http_client
98            .post(&url)
99            .json(&body)
100            .header("Content-Type", "application/json")
101            .header("User-Agent", "chalk-rust/0.1.0")
102            .send()
103            .await?;
104
105        let status = resp.status();
106        if !status.is_success() {
107            let body_text = resp.text().await.unwrap_or_default();
108            return Err(ChalkClientError::Auth(format!(
109                "token exchange failed (HTTP {}): {}",
110                status.as_u16(),
111                body_text
112            )));
113        }
114
115        let token: TokenResponse = resp.json().await?;
116        tracing::debug!(
117            "token exchanged successfully, primary_environment={:?}",
118            token.primary_environment
119        );
120
121        Ok(token)
122    }
123
124    /// Returns a reference to the underlying config.
125    pub fn config(&self) -> &ChalkClientConfig {
126        &self.config
127    }
128}
129
130fn is_token_valid(cached: &CachedToken) -> bool {
131    let now = Utc::now();
132    let remaining = cached.expires_at.signed_duration_since(now);
133    remaining.num_seconds() > TOKEN_REFRESH_BUFFER_SECS
134}
135
136fn parse_expiry(response: &TokenResponse) -> DateTime<Utc> {
137    if let Some(ref at) = response.expires_at {
138        if let Ok(parsed) = at.parse::<DateTime<Utc>>() {
139            return parsed;
140        }
141    }
142
143    if let Some(seconds) = response.expires_in {
144        return Utc::now() + chrono::Duration::seconds(seconds);
145    }
146
147    Utc::now() + chrono::Duration::hours(1)
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ChalkClientConfigBuilder;
    use std::collections::HashMap;

    fn test_config(api_server: &str) -> ChalkClientConfig {
        ChalkClientConfigBuilder::new()
            .client_id("test-id")
            .client_secret("test-secret")
            .api_server(api_server)
            .build()
            .unwrap()
    }

    /// A `TokenResponse` with the given expiry fields and every other field empty.
    fn token_response(expires_at: Option<&str>, expires_in: Option<i64>) -> TokenResponse {
        TokenResponse {
            access_token: "token".into(),
            expires_at: expires_at.map(Into::into),
            expires_in,
            primary_environment: None,
            engines: HashMap::new(),
            grpc_engines: HashMap::new(),
            environment_id_to_name: HashMap::new(),
            api_server: None,
        }
    }

    /// A cache entry holding a placeholder token that expires at `expires_at`.
    fn cached_token(expires_at: DateTime<Utc>) -> CachedToken {
        CachedToken {
            response: token_response(None, None),
            expires_at,
        }
    }

    #[test]
    fn test_parse_expiry_from_expires_at() {
        let expiry = parse_expiry(&token_response(Some("2099-12-31T23:59:59Z"), None));
        let expected: DateTime<Utc> = "2099-12-31T23:59:59Z".parse().unwrap();
        assert_eq!(expiry, expected);
        assert!(expiry > Utc::now());
    }

    #[test]
    fn test_parse_expiry_from_expires_in() {
        let expiry = parse_expiry(&token_response(None, Some(3600)));
        let diff = expiry.signed_duration_since(Utc::now()).num_seconds();
        assert!(diff > 3500 && diff <= 3600);
    }

    #[test]
    fn test_parse_expiry_unparseable_expires_at_falls_back_to_expires_in() {
        let expiry = parse_expiry(&token_response(Some("not a timestamp"), Some(120)));
        let diff = expiry.signed_duration_since(Utc::now()).num_seconds();
        assert!(diff > 60 && diff <= 120);
    }

    #[test]
    fn test_parse_expiry_defaults_to_one_hour() {
        let expiry = parse_expiry(&token_response(None, None));
        let diff = expiry.signed_duration_since(Utc::now()).num_seconds();
        assert!(diff > 3500 && diff <= 3600);
    }

    #[test]
    fn test_is_token_valid_expired() {
        let cached = cached_token(Utc::now() - chrono::Duration::minutes(10));
        assert!(!is_token_valid(&cached));
    }

    #[test]
    fn test_is_token_valid_within_refresh_buffer() {
        // Not yet expired, but inside the early-refresh margin.
        let cached = cached_token(Utc::now() + chrono::Duration::seconds(30));
        assert!(!is_token_valid(&cached));
    }

    #[test]
    fn test_is_token_valid_fresh() {
        let cached = cached_token(Utc::now() + chrono::Duration::minutes(30));
        assert!(is_token_valid(&cached));
    }

    #[tokio::test]
    async fn test_token_exchange_success() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "access_token": "mock-jwt-token",
                    "expires_in": 3600,
                    "primary_environment": "env-abc",
                    "engines": {"env-abc": "https://engine.chalk.ai"},
                    "grpc_engines": {"env-abc": "https://grpc.chalk.ai"},
                    "environment_id_to_name": {"env-abc": "production"}
                })
                .to_string(),
            )
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));

        let token = manager.get_token().await.unwrap();
        assert_eq!(token.access_token, "mock-jwt-token");
        assert_eq!(token.primary_environment.as_deref(), Some("env-abc"));
        assert_eq!(
            token.engines.get("env-abc").map(|s| s.as_str()),
            Some("https://engine.chalk.ai")
        );

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_caching() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "access_token": "cached-token",
                    "expires_in": 3600,
                    "engines": {},
                    "grpc_engines": {}
                })
                .to_string(),
            )
            .expect(1)
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));

        let t1 = manager.get_token().await.unwrap();
        let t2 = manager.get_token().await.unwrap();

        assert_eq!(t1.access_token, t2.access_token);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_expired_cached_token_is_refreshed() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "access_token": "refreshed-token",
                    "expires_in": 3600,
                    "engines": {},
                    "grpc_engines": {}
                })
                .to_string(),
            )
            .expect(1)
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        // Seed the cache with a token that is already past its expiry.
        *manager.cache.write().await =
            Some(cached_token(Utc::now() - chrono::Duration::minutes(10)));

        let token = manager.get_token().await.unwrap();
        assert_eq!(token.access_token, "refreshed-token");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_exchange_failure() {
        let mut server = mockito::Server::new_async().await;

        server
            .mock("POST", "/v1/oauth/token")
            .with_status(401)
            .with_body("invalid credentials")
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));

        let result = manager.get_token().await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("401"));
        assert!(err.contains("invalid credentials"));
    }

    #[tokio::test]
    async fn test_failed_exchange_is_not_cached() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("POST", "/v1/oauth/token")
            .with_status(503)
            .with_body("unavailable")
            .expect(2)
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));

        // A failed exchange must not populate the cache: the second call has
        // to reach the server again.
        assert!(manager.get_token().await.is_err());
        assert!(manager.get_token().await.is_err());
        mock.assert_async().await;
    }
}