use std::sync::Arc;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock;
use crate::config::ChalkClientConfig;
use crate::error::{ChalkClientError, Result};
use crate::types::{TokenExchangeRequest, TokenResponse};
/// Safety margin, in seconds, applied when deciding whether a cached token is
/// still usable: `is_token_valid` accepts a token only while more than this
/// many whole seconds remain before its recorded expiry.
const TOKEN_REFRESH_BUFFER_SECS: i64 = 60;
/// A token response held in the `TokenManager` cache together with the
/// instant at which it should be treated as expired.
#[derive(Debug, Clone)]
struct CachedToken {
/// The token response exactly as returned by the credential exchange.
response: TokenResponse,
/// Expiry instant derived from `response` by `parse_expiry` when the entry
/// was stored.
expires_at: DateTime<Utc>,
}
/// Obtains access tokens by exchanging the configured client credentials and
/// caches the most recent response.
///
/// The cache lives behind an `Arc`, so clones of a `TokenManager` share the
/// same cached token.
#[derive(Clone)]
pub struct TokenManager {
/// Client configuration; supplies the API server and the client credentials.
config: ChalkClientConfig,
/// HTTP client used for the credential exchange request.
http_client: reqwest::Client,
/// Most recently fetched token, or `None` until the first successful exchange.
cache: Arc<RwLock<Option<CachedToken>>>,
}
impl TokenManager {
    /// Creates a manager for `config` with a fresh HTTP client and an empty
    /// token cache.
    ///
    /// No request is sent here; the first credential exchange happens on the
    /// first call to [`TokenManager::get_token`].
    pub fn new(config: ChalkClientConfig) -> Self {
        Self {
            config,
            http_client: reqwest::Client::new(),
            cache: Arc::new(RwLock::new(None)),
        }
    }

    /// Returns a usable token, reusing the cached one when it is still valid
    /// and exchanging credentials otherwise.
    ///
    /// # Errors
    ///
    /// Returns whatever error the credential exchange produces. A failed
    /// exchange leaves the cache untouched.
    pub async fn get_token(&self) -> Result<TokenResponse> {
        // Fast path: a shared read lock is enough when the cached token is
        // still valid. The guard is scoped so it is released before the write
        // lock below is requested.
        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.as_ref() {
                if is_token_valid(cached) {
                    return Ok(cached.response.clone());
                }
            }
        }

        // Slow path: take the write lock and check again, because another
        // caller may have refreshed the token while this one was waiting.
        let mut cache = self.cache.write().await;
        if let Some(cached) = cache.as_ref() {
            if is_token_valid(cached) {
                return Ok(cached.response.clone());
            }
        }

        // The write lock stays held across the exchange, so callers that
        // arrive during a refresh wait for its result instead of issuing
        // their own request.
        let response = self.exchange_credentials().await?;
        let expires_at = parse_expiry(&response);
        *cache = Some(CachedToken {
            response: response.clone(),
            expires_at,
        });
        Ok(response)
    }

    /// Posts the configured client id and secret as a `client_credentials`
    /// grant to `{api_server}/v1/oauth/token` and decodes the JSON reply.
    ///
    /// # Errors
    ///
    /// * [`ChalkClientError::Auth`] carrying the HTTP status code and the
    ///   response body when the server answers with a non-success status.
    /// * The error converted by `?` when sending the request or decoding the
    ///   response body fails.
    async fn exchange_credentials(&self) -> Result<TokenResponse> {
        let url = format!("{}/v1/oauth/token", self.config.api_server);
        let body = TokenExchangeRequest {
            client_id: self.config.client_id.clone(),
            client_secret: self.config.client_secret.clone(),
            grant_type: "client_credentials".into(),
        };
        tracing::debug!("exchanging credentials at {}", url);

        // `.json()` already sets `Content-Type: application/json`. The header
        // is deliberately not added a second time: `.header()` appends rather
        // than replaces, which would put a duplicate Content-Type on the wire.
        let resp = self
            .http_client
            .post(&url)
            .json(&body)
            .header("User-Agent", "chalk-rust/0.1.0")
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: if the error body cannot be read, report the
            // status with an empty body rather than masking the auth failure.
            let body_text = resp.text().await.unwrap_or_default();
            return Err(ChalkClientError::Auth(format!(
                "token exchange failed (HTTP {}): {}",
                status.as_u16(),
                body_text
            )));
        }

        let token: TokenResponse = resp.json().await?;
        tracing::debug!(
            "token exchanged successfully, primary_environment={:?}",
            token.primary_environment
        );
        Ok(token)
    }

    /// Returns the configuration this manager was created with.
    pub fn config(&self) -> &ChalkClientConfig {
        &self.config
    }
}
/// Reports whether `cached` may still be handed out to callers.
///
/// A token qualifies only while the whole seconds left until its
/// `expires_at` exceed `TOKEN_REFRESH_BUFFER_SECS`; anything closer to expiry
/// (or already expired) is reported as invalid.
fn is_token_valid(cached: &CachedToken) -> bool {
    let seconds_left = cached
        .expires_at
        .signed_duration_since(Utc::now())
        .num_seconds();
    seconds_left > TOKEN_REFRESH_BUFFER_SECS
}
fn parse_expiry(response: &TokenResponse) -> DateTime<Utc> {
if let Some(ref at) = response.expires_at {
if let Ok(parsed) = at.parse::<DateTime<Utc>>() {
return parsed;
}
}
if let Some(seconds) = response.expires_in {
return Utc::now() + chrono::Duration::seconds(seconds);
}
Utc::now() + chrono::Duration::hours(1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ChalkClientConfigBuilder;
    use std::collections::HashMap;

    /// Endpoint path the mock server must answer for credential exchanges.
    const TOKEN_PATH: &str = "/v1/oauth/token";

    /// Builds a client configuration with fixed test credentials that points
    /// at `api_server`.
    fn test_config(api_server: &str) -> ChalkClientConfig {
        let builder = ChalkClientConfigBuilder::new()
            .client_id("test-id")
            .client_secret("test-secret")
            .api_server(api_server);
        builder.build().unwrap()
    }

    /// Builds a token response that carries only the given expiry hints;
    /// every other field is empty.
    fn bare_token(expires_at: Option<&str>, expires_in: Option<i64>) -> TokenResponse {
        TokenResponse {
            access_token: "token".into(),
            expires_at: expires_at.map(Into::into),
            expires_in,
            primary_environment: None,
            engines: HashMap::new(),
            grpc_engines: HashMap::new(),
            environment_id_to_name: HashMap::new(),
            api_server: None,
        }
    }

    /// Wraps an otherwise empty token in a cache entry that expires at
    /// `expires_at`.
    fn cached_until(expires_at: DateTime<Utc>) -> CachedToken {
        CachedToken {
            response: bare_token(None, None),
            expires_at,
        }
    }

    #[test]
    fn test_parse_expiry_from_expires_at() {
        let far_future = bare_token(Some("2099-12-31T23:59:59Z"), None);
        assert!(parse_expiry(&far_future) > Utc::now());
    }

    #[test]
    fn test_parse_expiry_from_expires_in() {
        let one_hour = bare_token(None, Some(3600));
        let remaining = parse_expiry(&one_hour)
            .signed_duration_since(Utc::now())
            .num_seconds();
        assert!((3501..=3600).contains(&remaining));
    }

    #[test]
    fn test_is_token_valid_expired() {
        let stale = cached_until(Utc::now() - chrono::Duration::minutes(10));
        assert!(!is_token_valid(&stale));
    }

    #[test]
    fn test_is_token_valid_fresh() {
        let fresh = cached_until(Utc::now() + chrono::Duration::minutes(30));
        assert!(is_token_valid(&fresh));
    }

    #[tokio::test]
    async fn test_token_exchange_success() {
        let payload = serde_json::json!({
            "access_token": "mock-jwt-token",
            "expires_in": 3600,
            "primary_environment": "env-abc",
            "engines": {"env-abc": "https://engine.chalk.ai"},
            "grpc_engines": {"env-abc": "https://grpc.chalk.ai"},
            "environment_id_to_name": {"env-abc": "production"}
        });

        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", TOKEN_PATH)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(payload.to_string())
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let token = manager.get_token().await.unwrap();

        assert_eq!(token.access_token, "mock-jwt-token");
        assert_eq!(token.primary_environment.as_deref(), Some("env-abc"));
        let engine_url = token.engines.get("env-abc").map(|s| s.as_str());
        assert_eq!(engine_url, Some("https://engine.chalk.ai"));
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_caching() {
        let payload = serde_json::json!({
            "access_token": "cached-token",
            "expires_in": 3600,
            "engines": {},
            "grpc_engines": {}
        });

        let mut server = mockito::Server::new_async().await;
        // Exactly one exchange is expected; the second call must hit the cache.
        let mock = server
            .mock("POST", TOKEN_PATH)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(payload.to_string())
            .expect(1)
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let first = manager.get_token().await.unwrap();
        let second = manager.get_token().await.unwrap();

        assert_eq!(first.access_token, second.access_token);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_token_exchange_failure() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("POST", TOKEN_PATH)
            .with_status(401)
            .with_body("invalid credentials")
            .create_async()
            .await;

        let manager = TokenManager::new(test_config(&server.url()));
        let outcome = manager.get_token().await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        for expected in ["401", "invalid credentials"] {
            assert!(message.contains(expected));
        }
    }
}