use crate::error::ZerobusError;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
/// JSON body returned by the OIDC token endpoint on a successful refresh.
///
/// Only `access_token` is required; the other fields are `Option`s, so serde
/// deserializes responses that omit them (missing fields become `None`).
// NOTE(review): the derived `Debug` prints `access_token` in clear text — do not
// log this struct with `{:?}`.
#[derive(Debug, Serialize, Deserialize)]
struct TokenResponse {
    /// The access token; this is the value `refresh_token` returns to its caller.
    access_token: String,
    /// Token type reported by the server, if any (presumably `"Bearer"` per RFC 6749).
    token_type: Option<String>,
    /// Token lifetime reported by the server, if any (presumably seconds, per RFC 6749).
    expires_in: Option<u64>,
    /// Granted scope(s) reported by the server, if any.
    scope: Option<String>,
}
pub fn is_token_expired_error(error: &ZerobusError) -> bool {
matches!(error, ZerobusError::AuthenticationError(_))
}
/// Exchange OAuth client credentials for a fresh access token.
///
/// Sends a `client_credentials` grant as a form-encoded POST to
/// `<unity_catalog_url>/oidc/v1/token` (with a 30 s timeout) and returns the
/// `access_token` field of the JSON response.
///
/// # Arguments
///
/// * `unity_catalog_url` - Base URL of the workspace; trailing `/` characters are ignored.
/// * `client_id` - OAuth client ID, sent in the form body.
/// * `client_secret` - OAuth client secret, sent in the form body and never logged here.
///
/// # Errors
///
/// Returns [`ZerobusError::TokenRefreshError`] if the HTTP client cannot be built,
/// the request cannot be sent, the endpoint answers with a non-success status,
/// the body cannot be parsed as a token response, or the returned token is empty.
pub async fn refresh_token(
    unity_catalog_url: &str,
    client_id: &str,
    client_secret: &str,
) -> Result<String, ZerobusError> {
    info!("Refreshing authentication token from {}", unity_catalog_url);

    let token_url = token_endpoint_url(unity_catalog_url);
    debug!("Token endpoint: {}", token_url);

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(30))
        .build()
        .map_err(|e| {
            ZerobusError::TokenRefreshError(format!("Failed to create HTTP client: {}", e))
        })?;

    let params = [
        ("grant_type", "client_credentials"),
        ("client_id", client_id),
        ("client_secret", client_secret),
    ];

    let response = client
        .post(&token_url)
        .form(&params)
        .send()
        .await
        .map_err(|e| {
            ZerobusError::TokenRefreshError(format!("Failed to send token refresh request: {}", e))
        })?;

    // Capture the status up front: reading the body below consumes `response`.
    let status = response.status();
    if !status.is_success() {
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        let message = format!("Token refresh failed with status {}: {}", status, error_text);
        warn!("{}", message);
        return Err(ZerobusError::TokenRefreshError(message));
    }

    let token_response: TokenResponse = response.json().await.map_err(|e| {
        ZerobusError::TokenRefreshError(format!("Failed to parse token response: {}", e))
    })?;

    // An empty token would otherwise only surface later as a confusing auth failure.
    if token_response.access_token.is_empty() {
        return Err(ZerobusError::TokenRefreshError(
            "Token endpoint returned an empty access_token".to_string(),
        ));
    }

    debug!(
        "Token refresh successful, expires_in: {:?}",
        token_response.expires_in
    );
    Ok(token_response.access_token)
}

/// Build the OIDC token endpoint URL from a base URL, tolerating any number of trailing slashes.
fn token_endpoint_url(base_url: &str) -> String {
    format!("{}/oidc/v1/token", base_url.trim_end_matches('/'))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_token_expired_error() {
        // Any authentication error is classified as an expired token.
        let auth_error = ZerobusError::AuthenticationError("token expired".to_string());
        assert!(is_token_expired_error(&auth_error));

        let config_error = ZerobusError::ConfigurationError("test".to_string());
        assert!(!is_token_expired_error(&config_error));

        // A failed refresh must not look "expired"; otherwise a caller that refreshes
        // on expiry could keep retrying a refresh that cannot succeed.
        let refresh_error = ZerobusError::TokenRefreshError("refresh failed".to_string());
        assert!(!is_token_expired_error(&refresh_error));
    }

    /// Performs a real HTTP request with dummy credentials; run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to a Databricks workspace"]
    async fn test_refresh_token_integration() {
        let result = refresh_token(
            "https://test.cloud.databricks.com",
            "test_client_id",
            "test_client_secret",
        )
        .await;
        // Every failure path in `refresh_token` maps to `TokenRefreshError`.
        assert!(matches!(result, Err(ZerobusError::TokenRefreshError(_))));
    }
}