use std::collections::BTreeMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::error::{KrafkaError, Result};
/// Margin, in milliseconds (30 seconds), before a token's expiry timestamp at
/// which `OAuthBearerToken::needs_refresh` starts reporting `true`.
const OAUTHBEARER_EXPIRY_SKEW_MARGIN_MS: i64 = 30_000;
/// Returns the current wall-clock time as whole milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`; a millisecond count too
/// large for `i64` saturates to `i64::MAX` instead of wrapping.
fn current_epoch_ms() -> i64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    let millis: u128 = match since_epoch {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    i64::try_from(millis).unwrap_or(i64::MAX)
}
/// Asynchronous source of [`OAuthBearerToken`]s.
///
/// The method returns a boxed `Send` future instead of being an `async fn`.
/// This keeps the trait object-safe, so providers can be stored as
/// `Arc<dyn OAuthBearerTokenProvider>`. The `Send + Sync` supertraits allow a
/// provider to be shared across threads.
pub trait OAuthBearerTokenProvider: Send + Sync {
    /// Starts obtaining a token; the returned future may borrow `self`.
    ///
    /// # Errors
    ///
    /// The future resolves to an error when no token can be produced.
    fn provide_token(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<OAuthBearerToken>> + Send + '_>>;
}
/// Blanket implementation that lets any zero-argument closure (or function)
/// be used directly as a token provider, provided it returns a
/// `Send + 'static` future of `Result<OAuthBearerToken>`.
///
/// Because of the `'static` bound on `Fut`, the returned future cannot borrow
/// from the closure itself. Clone any captured state before moving it into the
/// future.
impl<F, Fut> OAuthBearerTokenProvider for F
where
    F: Send + Sync + Fn() -> Fut,
    Fut: Send + 'static + Future<Output = Result<OAuthBearerToken>>,
{
    fn provide_token(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<OAuthBearerToken>> + Send + '_>> {
        // Invoke the closure to obtain its future, then erase the concrete
        // future type behind a pinned box.
        let pending: Fut = (self)();
        Box::pin(pending)
    }
}
/// Cheaply clonable, type-erased handle to an [`OAuthBearerTokenProvider`].
///
/// Clones share the same underlying provider: only the `Arc` reference count
/// is incremented.
#[derive(Clone)]
pub struct OAuthBearerTokenProviderHandle(Arc<dyn OAuthBearerTokenProvider>);
impl OAuthBearerTokenProviderHandle {
    /// Wraps `provider` in a shared, type-erased handle.
    pub fn new(provider: impl OAuthBearerTokenProvider + 'static) -> Self {
        let erased: Arc<dyn OAuthBearerTokenProvider> = Arc::new(provider);
        Self(erased)
    }

    /// Asks the wrapped provider for a token and awaits the result.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the provider's future resolves to.
    pub async fn provide_token(&self) -> Result<OAuthBearerToken> {
        let Self(inner) = self;
        inner.provide_token().await
    }
}
impl fmt::Debug for OAuthBearerTokenProviderHandle {
    /// Emits a fixed placeholder; the wrapped provider is never inspected, so
    /// nothing it holds can leak into logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[OAuthBearerTokenProvider]")
    }
}
/// An OAUTHBEARER access token, plus optional SASL extensions and an optional
/// expiry timestamp.
///
/// The token string is zeroized on drop (`ZeroizeOnDrop`). The extensions and
/// the lifetime are excluded from zeroization via `#[zeroize(skip)]`. Every
/// clone is a separate copy that is zeroized independently when dropped.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct OAuthBearerToken {
/// Raw bearer token (e.g. a JWT), sent after `auth=Bearer ` in the initial response.
token_value: String,
/// SASL extension key/value pairs. `BTreeMap` keeps them in ascending key
/// order when serialized.
#[zeroize(skip)]
extensions: BTreeMap<String, String>,
/// Absolute expiry in milliseconds since the Unix epoch, if known.
#[zeroize(skip)]
lifetime_ms: Option<i64>,
}
impl OAuthBearerToken {
    /// Creates a token carrying `token_value`, with no SASL extensions and no
    /// known expiry.
    #[must_use]
    pub fn new(token_value: impl Into<String>) -> Self {
        Self {
            token_value: token_value.into(),
            extensions: BTreeMap::new(),
            lifetime_ms: None,
        }
    }

    /// Adds a SASL extension key/value pair, replacing any earlier value for
    /// the same key.
    ///
    /// Extensions are serialized in ascending key order after the `auth` pair.
    /// No validation is performed here.
    ///
    /// NOTE(review): RFC 7628 restricts keys to ALPHA characters. Any of the
    /// following would corrupt the framing of the initial response or
    /// duplicate the mandatory field, so callers must supply well-formed
    /// pairs:
    /// - a key containing `=`;
    /// - a key or value containing `\x01`;
    /// - an extension named `auth`.
    #[must_use]
    pub fn with_extension(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extensions.insert(key.into(), value.into());
        self
    }

    /// Sets the token's expiry as an absolute timestamp in milliseconds since
    /// the Unix epoch. It is compared against the system clock by
    /// [`is_expired`](Self::is_expired) and [`needs_refresh`](Self::needs_refresh).
    #[must_use]
    pub fn with_lifetime_ms(mut self, lifetime_ms: i64) -> Self {
        self.lifetime_ms = Some(lifetime_ms);
        self
    }

    /// Returns the absolute expiry timestamp (ms since the Unix epoch), if set.
    #[must_use]
    pub fn lifetime_ms(&self) -> Option<i64> {
        self.lifetime_ms
    }

    /// Returns `true` once the current time has reached the expiry timestamp.
    ///
    /// A token without a lifetime never expires.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.lifetime_ms
            .is_some_and(|lifetime_ms| current_epoch_ms() >= lifetime_ms)
    }

    /// Returns `true` once the current time is within
    /// `OAUTHBEARER_EXPIRY_SKEW_MARGIN_MS` of the expiry timestamp (or past it),
    /// so the token can be replaced before it actually lapses.
    ///
    /// A token without a lifetime never needs a refresh.
    #[must_use]
    pub fn needs_refresh(&self) -> bool {
        self.lifetime_ms.is_some_and(|lifetime_ms| {
            // Saturate so a lifetime near `i64::MIN` cannot overflow.
            current_epoch_ms() >= lifetime_ms.saturating_sub(OAUTHBEARER_EXPIRY_SKEW_MARGIN_MS)
        })
    }

    /// Builds the RFC 7628 OAUTHBEARER initial client response:
    ///
    /// `n,,` `\x01` `auth=Bearer <token>` (`\x01` `key=value`)* `\x01` `\x01`
    ///
    /// The message is laid out as follows:
    /// - the GS2 header (`n`, i.e. no channel binding, and an empty
    ///   authorization identity);
    /// - the `auth` pair, then each extension;
    /// - every pair is terminated by `\x01`, and a final `\x01` ends the
    ///   message.
    ///
    /// NOTE(review): the returned buffer contains the raw token and is not
    /// zeroized. Callers may want to wipe it after sending.
    #[must_use]
    pub(crate) fn to_gs2_initial_response(&self) -> Vec<u8> {
        const GS2_HEADER: &[u8] = b"n,,";
        const AUTH_PREFIX: &[u8] = b"auth=Bearer ";
        const KVSEP: u8 = 0x01;

        // Size the buffer exactly. The fixed part is: header + kvsep + auth
        // pair + the two trailing kvseps. Each extension adds
        // `kvsep key '=' value`.
        let extensions_len: usize = self
            .extensions
            .iter()
            .map(|(key, value)| 1 + key.len() + 1 + value.len())
            .sum();
        let capacity = GS2_HEADER.len()
            + 1
            + AUTH_PREFIX.len()
            + self.token_value.len()
            + 2
            + extensions_len;

        let mut response = Vec::with_capacity(capacity);
        response.extend_from_slice(GS2_HEADER);
        response.push(KVSEP);
        response.extend_from_slice(AUTH_PREFIX);
        response.extend_from_slice(self.token_value.as_bytes());
        for (key, value) in &self.extensions {
            response.push(KVSEP);
            response.extend_from_slice(key.as_bytes());
            response.push(b'=');
            response.extend_from_slice(value.as_bytes());
        }
        // The first kvsep terminates the last pair; the second ends the message.
        response.push(KVSEP);
        response.push(KVSEP);
        debug_assert_eq!(response.len(), capacity);
        response
    }

    /// Interprets the server's reply to the initial response.
    ///
    /// An empty challenge means authentication succeeded.
    ///
    /// # Errors
    ///
    /// Any non-empty challenge is treated as a failure. Its bytes are decoded
    /// lossily as UTF-8 and included in a [`KrafkaError::auth`] error.
    ///
    /// NOTE(review): RFC 7628 §3.2.2 expects the client to answer a server
    /// error with a single `\x01` before the exchange ends; presumably the
    /// caller handles that.
    pub(crate) fn process_server_response(&self, challenge: &[u8]) -> Result<()> {
        if challenge.is_empty() {
            return Ok(());
        }
        let error_msg = String::from_utf8_lossy(challenge);
        Err(KrafkaError::auth(format!(
            "OAUTHBEARER authentication failed: {error_msg}"
        )))
    }
}
impl fmt::Debug for OAuthBearerToken {
    /// Formats the token for diagnostics with the bearer value masked.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure exhaustively so that adding a field forces a decision
        // here about whether it is safe to print.
        let Self {
            token_value: _,
            extensions,
            lifetime_ms,
        } = self;
        let mut out = f.debug_struct("OAuthBearerToken");
        out.field("token_value", &"[REDACTED]");
        out.field("extensions", extensions);
        out.field("lifetime_ms", lifetime_ms);
        out.finish()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// One hour in milliseconds; used to build lifetimes far away from "now".
    const ONE_HOUR_MS: i64 = 3_600_000;

    #[test]
    fn test_oauthbearer_token_basic() {
        let token = OAuthBearerToken::new("my-jwt-token");
        let response = token.to_gs2_initial_response();
        let expected = b"n,,\x01auth=Bearer my-jwt-token\x01\x01";
        assert_eq!(response, expected);
    }

    #[test]
    fn test_oauthbearer_token_with_single_extension() {
        let token = OAuthBearerToken::new("my-token").with_extension("logicalCluster", "lkc-123");
        let response = token.to_gs2_initial_response();
        let response_str = String::from_utf8_lossy(&response);
        assert!(response_str.starts_with("n,,\x01auth=Bearer my-token"));
        assert!(response_str.contains("\x01logicalCluster=lkc-123"));
        assert!(response_str.ends_with("\x01\x01"));
    }

    #[test]
    fn test_oauthbearer_token_with_multiple_extensions() {
        let token = OAuthBearerToken::new("tok")
            .with_extension("ext1", "val1")
            .with_extension("ext2", "val2");
        let response = token.to_gs2_initial_response();
        // Pin the exact RFC 7628 framing: each pair is terminated by 0x01,
        // followed by one final 0x01.
        assert_eq!(
            response,
            b"n,,\x01auth=Bearer tok\x01ext1=val1\x01ext2=val2\x01\x01"
        );
    }

    #[test]
    fn test_oauthbearer_extensions_emitted_in_key_order() {
        let token = OAuthBearerToken::new("tok")
            .with_extension("zeta", "2")
            .with_extension("alpha", "1");
        assert_eq!(
            token.to_gs2_initial_response(),
            b"n,,\x01auth=Bearer tok\x01alpha=1\x01zeta=2\x01\x01"
        );
    }

    #[test]
    fn test_oauthbearer_extension_same_key_replaces_value() {
        let token = OAuthBearerToken::new("tok")
            .with_extension("k", "old")
            .with_extension("k", "new");
        assert_eq!(
            token.to_gs2_initial_response(),
            b"n,,\x01auth=Bearer tok\x01k=new\x01\x01"
        );
    }

    #[test]
    fn test_oauthbearer_debug_redacts_token() {
        let token = OAuthBearerToken::new("secret-token-value");
        let debug = format!("{token:?}");
        assert!(!debug.contains("secret-token-value"));
        assert!(debug.contains("[REDACTED]"));
    }

    #[test]
    fn test_oauthbearer_debug_shows_non_secret_fields() {
        let token = OAuthBearerToken::new("secret")
            .with_extension("logicalCluster", "lkc-1")
            .with_lifetime_ms(42);
        let debug = format!("{token:?}");
        assert!(debug.contains("logicalCluster"));
        assert!(debug.contains("lkc-1"));
        assert!(debug.contains("Some(42)"));
        assert!(!debug.contains("secret"));
    }

    #[test]
    fn test_oauthbearer_server_response_success_empty() {
        let token = OAuthBearerToken::new("tok");
        assert!(token.process_server_response(b"").is_ok());
    }

    #[test]
    fn test_oauthbearer_server_response_error_json() {
        let token = OAuthBearerToken::new("tok");
        let error_json = br#"{"status":"invalid_token","scope":"openid"}"#;
        let result = token.process_server_response(error_json);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("invalid_token"));
    }

    #[test]
    fn test_oauthbearer_gs2_format_compliance() {
        let token = OAuthBearerToken::new("eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJhbGljZSJ9.sig");
        let response = token.to_gs2_initial_response();
        assert_eq!(&response[..3], b"n,,");
        assert_eq!(response[3], 0x01);
        assert_eq!(&response[4..16], b"auth=Bearer ");
        let len = response.len();
        assert_eq!(response[len - 2], 0x01);
        assert_eq!(response[len - 1], 0x01);
    }

    #[test]
    fn test_oauthbearer_empty_token_produces_valid_gs2() {
        let token = OAuthBearerToken::new("");
        let response = token.to_gs2_initial_response();
        assert_eq!(response, b"n,,\x01auth=Bearer \x01\x01");
    }

    #[test]
    fn test_oauthbearer_token_clone() {
        let token = OAuthBearerToken::new("tok").with_extension("k", "v");
        let cloned = token.clone();
        assert_eq!(
            cloned.to_gs2_initial_response(),
            token.to_gs2_initial_response()
        );
    }

    #[tokio::test]
    async fn test_token_provider_closure_impl() {
        let provider = || async { Ok(OAuthBearerToken::new("from-closure")) };
        let token = provider.provide_token().await.unwrap();
        assert_eq!(
            token.to_gs2_initial_response(),
            OAuthBearerToken::new("from-closure").to_gs2_initial_response()
        );
    }

    #[tokio::test]
    async fn test_token_provider_handle() {
        let handle = OAuthBearerTokenProviderHandle::new(|| async {
            Ok(OAuthBearerToken::new("handle-token"))
        });
        let token = handle.provide_token().await.unwrap();
        assert_eq!(
            token.to_gs2_initial_response(),
            OAuthBearerToken::new("handle-token").to_gs2_initial_response()
        );
    }

    #[test]
    fn test_token_provider_handle_clone() {
        let handle =
            OAuthBearerTokenProviderHandle::new(|| async { Ok(OAuthBearerToken::new("tok")) });
        let cloned = handle.clone();
        assert!(Arc::ptr_eq(&handle.0, &cloned.0));
    }

    #[test]
    fn test_token_provider_handle_debug_no_secrets() {
        let handle = OAuthBearerTokenProviderHandle::new(|| async {
            Ok(OAuthBearerToken::new("super-secret"))
        });
        let debug = format!("{handle:?}");
        assert_eq!(debug, "[OAuthBearerTokenProvider]");
        assert!(!debug.contains("super-secret"));
    }

    #[tokio::test]
    async fn test_token_provider_error_propagation() {
        let handle = OAuthBearerTokenProviderHandle::new(|| async {
            Err(KrafkaError::auth("token expired"))
        });
        let result = handle.provide_token().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("token expired"));
    }

    #[tokio::test]
    async fn test_token_provider_struct_impl() {
        struct StaticProvider {
            token: String,
        }

        impl OAuthBearerTokenProvider for StaticProvider {
            fn provide_token(
                &self,
            ) -> Pin<Box<dyn Future<Output = Result<OAuthBearerToken>> + Send + '_>> {
                let token = self.token.clone();
                Box::pin(async move { Ok(OAuthBearerToken::new(token)) })
            }
        }

        let provider = StaticProvider {
            token: "struct-token".to_string(),
        };
        let handle = OAuthBearerTokenProviderHandle::new(provider);
        let token = handle.provide_token().await.unwrap();
        assert_eq!(
            token.to_gs2_initial_response(),
            OAuthBearerToken::new("struct-token").to_gs2_initial_response()
        );
    }

    #[test]
    fn test_oauthbearer_token_not_expired_without_lifetime() {
        let token = OAuthBearerToken::new("tok");
        assert!(!token.is_expired());
        assert!(token.lifetime_ms().is_none());
    }

    #[test]
    fn test_oauthbearer_token_without_lifetime_never_needs_refresh() {
        let token = OAuthBearerToken::new("tok");
        assert!(!token.needs_refresh());
    }

    #[test]
    fn test_oauthbearer_token_not_expired_future_lifetime() {
        // Use the same clock helper as production code instead of a
        // hand-rolled, truncating `as i64` conversion.
        let future_ms = current_epoch_ms() + ONE_HOUR_MS;
        let token = OAuthBearerToken::new("tok").with_lifetime_ms(future_ms);
        assert!(!token.is_expired());
        assert_eq!(token.lifetime_ms(), Some(future_ms));
    }

    #[test]
    fn test_oauthbearer_token_expired_past_lifetime() {
        let past_ms = current_epoch_ms() - ONE_HOUR_MS;
        let token = OAuthBearerToken::new("tok").with_lifetime_ms(past_ms);
        assert!(token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauthbearer_token_min_lifetime_does_not_overflow() {
        // `needs_refresh` subtracts the skew margin; it must saturate, not panic.
        let token = OAuthBearerToken::new("tok").with_lifetime_ms(i64::MIN);
        assert!(token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauthbearer_token_needs_refresh_near_expiry() {
        let near_future_ms = current_epoch_ms() + 10_000;
        let token = OAuthBearerToken::new("tok").with_lifetime_ms(near_future_ms);
        assert!(!token.is_expired());
        assert!(token.needs_refresh());
    }

    #[test]
    fn test_oauthbearer_token_does_not_need_refresh_with_safe_margin() {
        let future_ms = current_epoch_ms() + 60_000;
        let token = OAuthBearerToken::new("tok").with_lifetime_ms(future_ms);
        assert!(!token.needs_refresh());
    }

    #[test]
    fn test_current_epoch_ms_is_plausible() {
        // 2020-01-01T00:00:00Z in epoch milliseconds.
        assert!(current_epoch_ms() > 1_577_836_800_000);
    }
}