use std::sync::Arc;
use xet_client::cas_client::auth::{DirectRefreshRouteTokenRefresher, TokenRefresher};
use xet_client::common::http_client::build_http_client;
use xet_data::processing::configurations::TranslatorConfig;
use super::XetSession;
use super::auth_group_builder::AuthOptions;
use crate::error::XetError;
/// Builds the [`TranslatorConfig`] for `session` from the supplied `auth_options`.
///
/// When `auth_options.token_refresh` carries a refresh URL and headers, a
/// [`DirectRefreshRouteTokenRefresher`] is created and handed to the resulting
/// config. If, in addition, no endpoint was supplied, the refresh route is
/// queried once up front: the returned CAS URL becomes the endpoint, and the
/// returned token is adopted only when the caller did not supply token info.
///
/// If no endpoint is known after that step, the session's configured
/// `default_cas_endpoint` is used.
///
/// # Errors
///
/// Propagates any error from building the HTTP client, from the eager
/// `get_cas_jwt` call, or from `default_config`.
pub(super) async fn create_translator_config(
    session: &XetSession,
    auth_options: AuthOptions,
) -> Result<TranslatorConfig, XetError> {
    let session_id = session.inner.id.to_string();

    let AuthOptions {
        endpoint,
        custom_headers,
        token_info,
        token_refresh,
    } = auth_options;
    let mut resolved_endpoint = endpoint;
    let mut resolved_token = token_info;

    let token_refresher: Option<Arc<dyn TokenRefresher>> = match token_refresh {
        None => None,
        Some((refresh_url, refresh_headers)) => {
            let http = build_http_client(&session_id, None, Some(Arc::new(refresh_headers)))?;
            let refresher = DirectRefreshRouteTokenRefresher::new(refresh_url, http, None);

            // Only hit the refresh route eagerly when the endpoint is unknown;
            // a caller-provided endpoint skips the request entirely.
            if resolved_endpoint.is_none() {
                let jwt = refresher.get_cas_jwt().await?;
                // A caller-supplied token takes precedence over the fetched one.
                resolved_token = resolved_token.or(Some((jwt.access_token, jwt.exp)));
                resolved_endpoint = Some(jwt.cas_url);
            }

            let shared: Arc<dyn TokenRefresher> = Arc::new(refresher);
            Some(shared)
        },
    };

    let cas_endpoint = match resolved_endpoint {
        Some(url) => url,
        None => session.inner.config.data.default_cas_endpoint.clone(),
    };

    let mut config = xet_data::processing::data_client::default_config(
        cas_endpoint,
        resolved_token,
        token_refresher,
        custom_headers.map(Arc::new),
    )?;

    // An empty session id leaves whatever `default_config` produced untouched.
    if session_id.is_empty() {
        return Ok(config);
    }
    config.session.session_id = Some(session_id);
    Ok(config)
}
#[cfg(test)]
mod tests {
    use http::HeaderMap;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::xet_session::XetSessionBuilder;
    use crate::xet_session::auth_group_builder::AuthOptions;

    /// Starts a mock server with a `GET /token` route that answers 200 with a
    /// JSON body containing `cas_url` and `access_token`, and that expects to
    /// be called exactly `expected_calls` times.
    async fn start_token_server(cas_url: &str, access_token: &str, expected_calls: u64) -> MockServer {
        let server = MockServer::start().await;
        let body = serde_json::json!({
            "casUrl": cas_url,
            "exp": 9_999_999_999u64,
            "accessToken": access_token,
        });
        Mock::given(method("GET"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .expect(expected_calls)
            .mount(&server)
            .await;
        server
    }

    /// Returns the URL of the `/token` route on `server`.
    fn token_route(server: &MockServer) -> String {
        format!("{}/token", server.uri())
    }

    /// An explicitly supplied endpoint must be used as-is, and the refresh
    /// route must not be called while building the config.
    #[tokio::test]
    async fn test_endpoint_provided_skips_eager_refresh() {
        let server = start_token_server("https://should-not-be-used.example.com", "should-not-be-fetched", 0).await;
        let session = XetSessionBuilder::new().build().unwrap();

        let auth_options = AuthOptions {
            endpoint: Some("https://cas.example.com".to_string()),
            custom_headers: None,
            token_info: None,
            token_refresh: Some((token_route(&server), HeaderMap::new())),
        };

        let config = create_translator_config(&session, auth_options).await.unwrap();

        assert_eq!(config.session.endpoint, "https://cas.example.com");
    }

    /// Without an endpoint or token, one call to the refresh route supplies
    /// both the endpoint and the token.
    #[tokio::test]
    async fn test_eager_refresh_sets_endpoint_and_token() {
        let server = start_token_server("https://cas.example.com", "eagerly-fetched-token", 1).await;
        let session = XetSessionBuilder::new().build().unwrap();

        let auth_options = AuthOptions {
            endpoint: None,
            custom_headers: None,
            token_info: None,
            token_refresh: Some((token_route(&server), HeaderMap::new())),
        };

        let config = create_translator_config(&session, auth_options).await.unwrap();

        assert_eq!(config.session.endpoint, "https://cas.example.com");
        let auth = config.session.auth.expect("auth config should be set");
        assert_eq!(auth.token, "eagerly-fetched-token");
        assert_eq!(auth.token_expiration, 9_999_999_999);
    }

    /// Without an endpoint but with a pre-supplied token, the refresh route is
    /// still called once for the endpoint, while the supplied token is kept.
    #[tokio::test]
    async fn test_eager_refresh_preserves_existing_token_info() {
        let server = start_token_server("https://cas.example.com", "eagerly-fetched-token", 1).await;
        let session = XetSessionBuilder::new().build().unwrap();

        let auth_options = AuthOptions {
            endpoint: None,
            custom_headers: None,
            token_info: Some(("pre-supplied-token".to_string(), 1_000_000_000)),
            token_refresh: Some((token_route(&server), HeaderMap::new())),
        };

        let config = create_translator_config(&session, auth_options).await.unwrap();

        assert_eq!(config.session.endpoint, "https://cas.example.com");
        let auth = config.session.auth.expect("auth config should be set");
        assert_eq!(auth.token, "pre-supplied-token");
        assert_eq!(auth.token_expiration, 1_000_000_000);
    }
}