pub mod msk_iam;
pub mod oauthbearer;
pub mod scram;
pub mod tls;
pub use msk_iam::MskIamAuthenticator;
pub use oauthbearer::{OAuthBearerToken, OAuthBearerTokenProvider, OAuthBearerTokenProviderHandle};
pub use scram::{
ChannelBinding, MAX_PBKDF2_ITERATIONS, MIN_PBKDF2_ITERATIONS, ScramClient, ScramMechanism,
ScramState,
};
pub use tls::{
MaybeSecureStream, build_tls_config, build_tls_connector, connect_tls,
extract_tls_server_end_point,
};
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
/// Source of AWS credentials for MSK IAM authentication.
///
/// Implementors return a boxed, `Send` future that resolves to a set of
/// [`AwsMskIamCredentials`] or to an error.
pub trait AwsMskIamCredentialProvider: Send + Sync {
/// Produces credentials asynchronously.
///
/// The returned future is allowed to borrow from `self` (lifetime `'_`).
fn provide_credentials(
&self,
) -> Pin<Box<dyn Future<Output = crate::error::Result<AwsMskIamCredentials>> + Send + '_>>;
}
/// Blanket implementation: any `Send + Sync` closure that returns a `'static`
/// credentials future can be used directly as a provider.
impl<F, Fut> AwsMskIamCredentialProvider for F
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = crate::error::Result<AwsMskIamCredentials>> + Send + 'static,
{
    fn provide_credentials(
        &self,
    ) -> Pin<Box<dyn Future<Output = crate::error::Result<AwsMskIamCredentials>> + Send + '_>> {
        // Call the closure to obtain its future, then pin that future on the heap.
        let pending = (self)();
        Box::pin(pending)
    }
}
/// Cloneable, type-erased handle around an [`AwsMskIamCredentialProvider`].
///
/// The provider is stored behind an `Arc`, so clones of the handle share the
/// same underlying provider.
#[derive(Clone)]
pub struct AwsMskIamCredentialProviderHandle(Arc<dyn AwsMskIamCredentialProvider>);
impl AwsMskIamCredentialProviderHandle {
    /// Wraps `provider` in a shared, type-erased handle.
    pub fn new(provider: impl AwsMskIamCredentialProvider + 'static) -> Self {
        let shared: Arc<dyn AwsMskIamCredentialProvider> = Arc::new(provider);
        Self(shared)
    }

    /// Asks the wrapped provider for credentials and awaits the result.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the provider's future resolves to.
    pub async fn provide_credentials(&self) -> crate::error::Result<AwsMskIamCredentials> {
        let Self(provider) = self;
        provider.provide_credentials().await
    }
}
/// Opaque `Debug` output: the wrapped provider is never inspected, only a
/// fixed placeholder is printed.
impl fmt::Debug for AwsMskIamCredentialProviderHandle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[AwsMskIamCredentialProvider]")
    }
}
/// Combination of transport security and SASL used for a connection.
///
/// The [`fmt::Display`] form is the upper-case protocol name:
/// `PLAINTEXT`, `SSL`, `SASL_PLAINTEXT` or `SASL_SSL`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SecurityProtocol {
    /// Neither TLS nor SASL (the default).
    #[default]
    Plaintext,
    /// TLS without SASL.
    Ssl,
    /// SASL without TLS.
    SaslPlaintext,
    /// SASL over TLS.
    SaslSsl,
}

impl fmt::Display for SecurityProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SecurityProtocol::Plaintext => "PLAINTEXT",
            SecurityProtocol::Ssl => "SSL",
            SecurityProtocol::SaslPlaintext => "SASL_PLAINTEXT",
            SecurityProtocol::SaslSsl => "SASL_SSL",
        };
        // `pad` (unlike `write!`) honours width, alignment and precision
        // flags, so `{:<16}` and friends behave as they do for `&str`.
        f.pad(name)
    }
}
/// SASL mechanism selected for authentication.
///
/// The [`fmt::Display`] form is the mechanism name string, e.g.
/// `SCRAM-SHA-256` or `OAUTHBEARER`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaslMechanism {
    /// Displayed as `PLAIN`.
    Plain,
    /// Displayed as `SCRAM-SHA-256`.
    ScramSha256,
    /// Displayed as `SCRAM-SHA-512`.
    ScramSha512,
    /// Displayed as `AWS_MSK_IAM`.
    AwsMskIam,
    /// Displayed as `OAUTHBEARER`.
    OAuthBearer,
    /// Displayed as `GSSAPI`.
    Gssapi,
}

impl fmt::Display for SaslMechanism {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SaslMechanism::Plain => "PLAIN",
            SaslMechanism::ScramSha256 => "SCRAM-SHA-256",
            SaslMechanism::ScramSha512 => "SCRAM-SHA-512",
            SaslMechanism::AwsMskIam => "AWS_MSK_IAM",
            SaslMechanism::OAuthBearer => "OAUTHBEARER",
            SaslMechanism::Gssapi => "GSSAPI",
        };
        // `pad` (unlike `write!`) honours width, alignment and precision
        // flags, so the mechanism name can be aligned in formatted output.
        f.pad(name)
    }
}
/// Username/password pair for the SASL `PLAIN` mechanism.
///
/// Derives `Zeroize` and `ZeroizeOnDrop`, so both fields are wiped when the
/// value is dropped. The fields are public, so values built with a struct
/// literal bypass the validation done by [`PlainCredentials::new`].
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct PlainCredentials {
/// Account name; `new` rejects empty values and values containing NUL.
pub username: String,
/// Secret; `new` rejects values containing NUL.
pub password: String,
}
impl PlainCredentials {
    /// Creates validated PLAIN credentials.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when `username` is empty, or when
    /// `username` or `password` contains a NUL byte. NUL is the field
    /// separator in the message built by [`PlainCredentials::to_auth_bytes`],
    /// so an embedded NUL would corrupt that message.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> crate::Result<Self> {
        // Wrap the values first: if validation fails, dropping `creds` wipes
        // the rejected secret (`ZeroizeOnDrop`) instead of freeing it as-is.
        let creds = Self {
            username: username.into(),
            password: password.into(),
        };
        if creds.username.is_empty() {
            return Err(crate::error::KrafkaError::config(
                "PLAIN username must not be empty",
            ));
        }
        if creds.username.contains('\0') {
            return Err(crate::error::KrafkaError::config(
                "PLAIN username must not contain null bytes",
            ));
        }
        if creds.password.contains('\0') {
            return Err(crate::error::KrafkaError::config(
                "PLAIN password must not contain null bytes",
            ));
        }
        Ok(creds)
    }

    /// Encodes the credentials as `NUL username NUL password`.
    ///
    /// The leading NUL corresponds to an empty first field. The result is
    /// wrapped in [`Zeroizing`] so the bytes are wiped when it is dropped.
    pub fn to_auth_bytes(&self) -> Zeroizing<Vec<u8>> {
        let user = self.username.as_bytes();
        let pass = self.password.as_bytes();
        // Allocate the exact size up front (two separators plus both fields).
        // A growing `Vec` would reallocate and leave partial copies of the
        // secret behind in freed, un-wiped buffers.
        let mut auth = Zeroizing::new(Vec::with_capacity(user.len() + pass.len() + 2));
        auth.push(0);
        auth.extend_from_slice(user);
        auth.push(0);
        auth.extend_from_slice(pass);
        auth
    }
}
/// `Debug` output that shows the username but never the password.
impl fmt::Debug for PlainCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("PlainCredentials");
        out.field("username", &self.username);
        // The real password is replaced by a fixed placeholder.
        out.field("password", &"[REDACTED]");
        out.finish()
    }
}
/// Username/password pair used with the SCRAM mechanisms.
///
/// Derives `Zeroize` and `ZeroizeOnDrop`, so both fields are wiped when the
/// value is dropped.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct ScramCredentials {
/// Account name.
pub username: String,
/// Secret used for authentication.
pub password: String,
}
impl ScramCredentials {
    /// Builds SCRAM credentials from anything convertible into `String`.
    ///
    /// No validation is performed on either value.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        let (username, password) = (username.into(), password.into());
        Self { username, password }
    }
}
/// `Debug` output that shows the username but never the password.
impl fmt::Debug for ScramCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ScramCredentials");
        out.field("username", &self.username);
        // The real password is replaced by a fixed placeholder.
        out.field("password", &"[REDACTED]");
        out.finish()
    }
}
/// AWS credentials plus region used for MSK IAM authentication.
///
/// Derives `Zeroize` and `ZeroizeOnDrop`, so all fields are wiped when the
/// value is dropped.
#[non_exhaustive]
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct AwsMskIamCredentials {
/// AWS access key id (shown in `Debug` output).
pub access_key_id: String,
/// AWS secret access key (redacted in `Debug` output).
pub secret_access_key: String,
/// Optional session token (redacted in `Debug` output when present).
pub session_token: Option<String>,
/// AWS region name.
pub region: String,
}
impl AwsMskIamCredentials {
    /// Creates credentials without a session token.
    pub fn new(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        region: impl Into<String>,
    ) -> Self {
        Self {
            access_key_id: access_key_id.into(),
            secret_access_key: secret_access_key.into(),
            session_token: None,
            region: region.into(),
        }
    }

    /// Creates credentials that carry a session token.
    pub fn with_session_token(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        session_token: impl Into<String>,
        region: impl Into<String>,
    ) -> Self {
        Self {
            access_key_id: access_key_id.into(),
            secret_access_key: secret_access_key.into(),
            session_token: Some(session_token.into()),
            region: region.into(),
        }
    }

    /// Reads credentials from the process environment.
    ///
    /// Variables consulted: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`,
    /// `AWS_SESSION_TOKEN` (optional) and `AWS_REGION`, falling back to
    /// `AWS_DEFAULT_REGION`.
    ///
    /// A variable that is set but blank (empty or whitespace only) is treated
    /// exactly like an unset one. An empty `AWS_SESSION_TOKEN` therefore yields
    /// `None` rather than `Some("")`, and an empty `AWS_REGION` still falls
    /// back to `AWS_DEFAULT_REGION`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the access key id, the secret
    /// access key, or both region variables are missing, blank, or not valid
    /// Unicode.
    pub fn from_env() -> crate::error::Result<Self> {
        // Unset, non-Unicode and blank values are all reported as "absent".
        fn non_blank_var(name: &str) -> Option<String> {
            std::env::var(name)
                .ok()
                .filter(|value| !value.trim().is_empty())
        }

        let access_key_id = non_blank_var("AWS_ACCESS_KEY_ID").ok_or_else(|| {
            crate::error::KrafkaError::config("AWS_ACCESS_KEY_ID environment variable not set")
        })?;
        let secret_access_key = non_blank_var("AWS_SECRET_ACCESS_KEY").ok_or_else(|| {
            crate::error::KrafkaError::config("AWS_SECRET_ACCESS_KEY environment variable not set")
        })?;
        let session_token = non_blank_var("AWS_SESSION_TOKEN");
        let region = non_blank_var("AWS_REGION")
            .or_else(|| non_blank_var("AWS_DEFAULT_REGION"))
            .ok_or_else(|| {
                crate::error::KrafkaError::config(
                    "AWS_REGION or AWS_DEFAULT_REGION environment variable not set",
                )
            })?;
        Ok(Self {
            access_key_id,
            secret_access_key,
            session_token,
            region,
        })
    }

    /// Loads credentials through the AWS SDK default configuration for
    /// `region`.
    ///
    /// Only available with the `aws-msk` feature.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the loaded AWS config exposes no
    /// credentials provider, or when that provider fails to produce
    /// credentials.
    #[cfg(feature = "aws-msk")]
    pub async fn from_default_chain(region: impl Into<String>) -> crate::error::Result<Self> {
        use aws_config::BehaviorVersion;
        use aws_credential_types::provider::ProvideCredentials;

        let region_str = region.into();
        let region = aws_config::Region::new(region_str.clone());
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(region)
            .load()
            .await;
        let credentials_provider = config.credentials_provider().ok_or_else(|| {
            crate::error::KrafkaError::config("No credentials provider available in AWS config")
        })?;
        let credentials = credentials_provider
            .provide_credentials()
            .await
            .map_err(|e| {
                crate::error::KrafkaError::config(format!("Failed to load AWS credentials: {e}"))
            })?;
        Ok(Self {
            access_key_id: credentials.access_key_id().to_string(),
            secret_access_key: credentials.secret_access_key().to_string(),
            session_token: credentials.session_token().map(|s| s.to_string()),
            region: region_str,
        })
    }
}
/// `Debug` output that shows the access key id and region but replaces the
/// secret access key and any session token with a placeholder.
impl fmt::Debug for AwsMskIamCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        // Keep the `Some`/`None` shape visible without revealing the token.
        let masked_token: Option<&str> = self.session_token.as_ref().map(|_| REDACTED);

        let mut out = f.debug_struct("AwsMskIamCredentials");
        out.field("access_key_id", &self.access_key_id);
        out.field("secret_access_key", &REDACTED);
        out.field("session_token", &masked_token);
        out.field("region", &self.region);
        out.finish()
    }
}
/// TLS settings, assembled with the consuming `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// CA certificate path, set by [`TlsConfig::with_ca_cert`].
    pub(crate) ca_cert_path: Option<String>,
    /// Client certificate path, set by [`TlsConfig::with_client_cert`].
    pub(crate) client_cert_path: Option<String>,
    /// Client private key path, set by [`TlsConfig::with_client_cert`].
    pub(crate) client_key_path: Option<String>,
    /// Flag enabled by `with_native_roots` (feature `native-tls-roots`).
    pub(crate) use_native_roots: bool,
    /// Server certificate verification flag; `true` by default.
    pub(crate) verify_server_cert: bool,
    /// SNI hostname override, set by [`TlsConfig::with_sni_hostname`].
    pub(crate) sni_hostname: Option<String>,
    /// ALPN protocol identifiers; empty by default.
    pub(crate) alpn_protocols: Vec<Vec<u8>>,
}

impl Default for TlsConfig {
    /// Verification on, native roots off, nothing else configured.
    fn default() -> Self {
        Self {
            verify_server_cert: true,
            use_native_roots: false,
            ca_cert_path: None,
            client_cert_path: None,
            client_key_path: None,
            sni_hostname: None,
            alpn_protocols: Vec::new(),
        }
    }
}

impl TlsConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a configuration with server certificate verification turned
    /// off. Only available with the `danger-insecure-tls` feature.
    #[cfg(feature = "danger-insecure-tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "danger-insecure-tls")))]
    pub fn insecure() -> Self {
        Self {
            verify_server_cert: false,
            ..Self::default()
        }
    }

    /// Sets the CA certificate path.
    pub fn with_ca_cert(self, path: impl Into<String>) -> Self {
        Self {
            ca_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Enables the native-roots flag. Only available with the
    /// `native-tls-roots` feature.
    #[cfg(feature = "native-tls-roots")]
    #[cfg_attr(docsrs, doc(cfg(feature = "native-tls-roots")))]
    pub fn with_native_roots(self) -> Self {
        Self {
            use_native_roots: true,
            ..self
        }
    }

    /// Sets the client certificate and private key paths together.
    pub fn with_client_cert(
        self,
        cert_path: impl Into<String>,
        key_path: impl Into<String>,
    ) -> Self {
        Self {
            client_cert_path: Some(cert_path.into()),
            client_key_path: Some(key_path.into()),
            ..self
        }
    }

    /// Sets the SNI hostname.
    pub fn with_sni_hostname(self, hostname: impl Into<String>) -> Self {
        Self {
            sni_hostname: Some(hostname.into()),
            ..self
        }
    }

    /// CA certificate path, if one was set.
    pub fn ca_cert_path(&self) -> Option<&str> {
        self.ca_cert_path.as_deref()
    }

    /// Client certificate path, if one was set.
    pub fn client_cert_path(&self) -> Option<&str> {
        self.client_cert_path.as_deref()
    }

    /// Client private key path, if one was set.
    pub fn client_key_path(&self) -> Option<&str> {
        self.client_key_path.as_deref()
    }

    /// Whether the native-roots flag is enabled.
    pub fn use_native_roots(&self) -> bool {
        self.use_native_roots
    }

    /// Whether server certificate verification is enabled.
    pub fn verify_server_cert(&self) -> bool {
        self.verify_server_cert
    }

    /// SNI hostname, if one was set.
    pub fn sni_hostname(&self) -> Option<&str> {
        self.sni_hostname.as_deref()
    }

    /// Replaces the ALPN protocol list with `protocols`.
    pub fn with_alpn_protocols(self, protocols: Vec<Vec<u8>>) -> Self {
        Self {
            alpn_protocols: protocols,
            ..self
        }
    }

    /// Replaces the ALPN protocol list with the single entry `kafka`.
    pub fn with_kafka_alpn(self) -> Self {
        let kafka: Vec<u8> = b"kafka".to_vec();
        self.with_alpn_protocols(vec![kafka])
    }

    /// Configured ALPN protocol identifiers.
    pub fn alpn_protocols(&self) -> &[Vec<u8>] {
        self.alpn_protocols.as_slice()
    }
}
/// Aggregated authentication settings: the security protocol, an optional
/// SASL mechanism with its credentials or credential provider, and optional
/// TLS settings.
///
/// `Default` yields the default security protocol with every optional field
/// unset.
#[derive(Debug, Clone, Default)]
pub struct AuthConfig {
/// Selected security protocol.
pub(crate) security_protocol: SecurityProtocol,
/// SASL mechanism, when one is configured.
pub(crate) sasl_mechanism: Option<SaslMechanism>,
/// Credentials set by the `sasl_plain*` constructors.
pub(crate) plain_credentials: Option<PlainCredentials>,
/// Credentials set by the `sasl_scram_*` constructors.
pub(crate) scram_credentials: Option<ScramCredentials>,
/// Static MSK IAM credentials, or the result of `resolve_msk_iam_provider`.
pub(crate) aws_msk_iam_credentials: Option<AwsMskIamCredentials>,
/// Provider consulted by `resolve_msk_iam_provider`.
pub(crate) aws_msk_iam_credential_provider: Option<AwsMskIamCredentialProviderHandle>,
/// Static OAUTHBEARER token, or the result of `resolve_provider_to_token`.
pub(crate) oauthbearer_token: Option<OAuthBearerToken>,
/// Provider consulted by `resolve_provider_to_token`.
pub(crate) oauthbearer_provider: Option<OAuthBearerTokenProviderHandle>,
/// TLS settings, set by the SSL and MSK IAM constructors.
pub(crate) tls_config: Option<TlsConfig>,
}
impl AuthConfig {
    /// Configuration with the `Plaintext` protocol and nothing else set.
    pub fn plaintext() -> Self {
        Self {
            security_protocol: SecurityProtocol::Plaintext,
            ..Self::default()
        }
    }

    /// Configuration with the `Ssl` protocol and the given TLS settings.
    pub fn ssl(tls_config: TlsConfig) -> Self {
        Self {
            security_protocol: SecurityProtocol::Ssl,
            tls_config: Some(tls_config),
            ..Self::default()
        }
    }

    /// SASL `PLAIN` over `SaslPlaintext`.
    ///
    /// # Errors
    ///
    /// Fails when [`PlainCredentials::new`] rejects the username or password.
    pub fn sasl_plain(
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> crate::Result<Self> {
        let credentials = PlainCredentials::new(username, password)?;
        Ok(Self {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::Plain),
            plain_credentials: Some(credentials),
            ..Self::default()
        })
    }

    /// SASL `PLAIN` over `SaslSsl` with the given TLS settings.
    ///
    /// # Errors
    ///
    /// Fails when [`PlainCredentials::new`] rejects the username or password.
    pub fn sasl_plain_ssl(
        username: impl Into<String>,
        password: impl Into<String>,
        tls_config: TlsConfig,
    ) -> crate::Result<Self> {
        // Same as the plaintext variant, with the protocol and TLS swapped in.
        let base = Self::sasl_plain(username, password)?;
        Ok(Self {
            security_protocol: SecurityProtocol::SaslSsl,
            tls_config: Some(tls_config),
            ..base
        })
    }

    /// SCRAM-SHA-256 over `SaslPlaintext`.
    pub fn sasl_scram_sha256(username: impl Into<String>, password: impl Into<String>) -> Self {
        let credentials = ScramCredentials::new(username, password);
        Self {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::ScramSha256),
            scram_credentials: Some(credentials),
            ..Self::default()
        }
    }

    /// SCRAM-SHA-512 over `SaslPlaintext`.
    pub fn sasl_scram_sha512(username: impl Into<String>, password: impl Into<String>) -> Self {
        let credentials = ScramCredentials::new(username, password);
        Self {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::ScramSha512),
            scram_credentials: Some(credentials),
            ..Self::default()
        }
    }

    /// AWS MSK IAM over `SaslSsl` with static credentials built from the
    /// given parts and a default [`TlsConfig`].
    pub fn aws_msk_iam(
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<String>,
        region: impl Into<String>,
    ) -> Self {
        let credentials = AwsMskIamCredentials::new(access_key_id, secret_access_key, region);
        Self::aws_msk_iam_with_credentials(credentials)
    }

    /// AWS MSK IAM over `SaslSsl` with the given static credentials and a
    /// default [`TlsConfig`].
    pub fn aws_msk_iam_with_credentials(credentials: AwsMskIamCredentials) -> Self {
        Self {
            security_protocol: SecurityProtocol::SaslSsl,
            sasl_mechanism: Some(SaslMechanism::AwsMskIam),
            aws_msk_iam_credentials: Some(credentials),
            tls_config: Some(TlsConfig::new()),
            ..Self::default()
        }
    }

    /// AWS MSK IAM over `SaslSsl` with a credential provider and a default
    /// [`TlsConfig`]; no static credentials are stored.
    pub fn aws_msk_iam_provider(provider: impl AwsMskIamCredentialProvider + 'static) -> Self {
        let handle = AwsMskIamCredentialProviderHandle::new(provider);
        Self {
            security_protocol: SecurityProtocol::SaslSsl,
            sasl_mechanism: Some(SaslMechanism::AwsMskIam),
            aws_msk_iam_credential_provider: Some(handle),
            tls_config: Some(TlsConfig::new()),
            ..Self::default()
        }
    }

    /// OAUTHBEARER over `SaslPlaintext` with a token built from a string.
    pub fn sasl_oauthbearer(token: impl Into<String>) -> Self {
        Self::sasl_oauthbearer_token(OAuthBearerToken::new(token))
    }

    /// OAUTHBEARER over `SaslSsl` with a token built from a string.
    pub fn sasl_oauthbearer_ssl(token: impl Into<String>, tls_config: TlsConfig) -> Self {
        Self::sasl_oauthbearer_token_ssl(OAuthBearerToken::new(token), tls_config)
    }

    /// OAUTHBEARER over `SaslPlaintext` with a pre-built token.
    pub fn sasl_oauthbearer_token(token: OAuthBearerToken) -> Self {
        Self {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::OAuthBearer),
            oauthbearer_token: Some(token),
            ..Self::default()
        }
    }

    /// OAUTHBEARER over `SaslSsl` with a pre-built token.
    pub fn sasl_oauthbearer_token_ssl(token: OAuthBearerToken, tls_config: TlsConfig) -> Self {
        Self {
            security_protocol: SecurityProtocol::SaslSsl,
            tls_config: Some(tls_config),
            ..Self::sasl_oauthbearer_token(token)
        }
    }

    /// OAUTHBEARER over `SaslPlaintext` with a token provider; no static
    /// token is stored.
    pub fn sasl_oauthbearer_provider(provider: impl OAuthBearerTokenProvider + 'static) -> Self {
        let handle = OAuthBearerTokenProviderHandle::new(provider);
        Self {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::OAuthBearer),
            oauthbearer_provider: Some(handle),
            ..Self::default()
        }
    }

    /// OAUTHBEARER over `SaslSsl` with a token provider.
    pub fn sasl_oauthbearer_provider_ssl(
        provider: impl OAuthBearerTokenProvider + 'static,
        tls_config: TlsConfig,
    ) -> Self {
        Self {
            security_protocol: SecurityProtocol::SaslSsl,
            tls_config: Some(tls_config),
            ..Self::sasl_oauthbearer_provider(provider)
        }
    }

    /// Calls the OAUTHBEARER provider, if any, and returns a copy of this
    /// configuration holding the resulting static token and no provider.
    ///
    /// Returns `Ok(None)` when the mechanism is not OAUTHBEARER or no
    /// provider is configured.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error.
    pub async fn resolve_provider_to_token(&self) -> crate::error::Result<Option<AuthConfig>> {
        if self.sasl_mechanism != Some(SaslMechanism::OAuthBearer) {
            return Ok(None);
        }
        let Some(provider) = self.oauthbearer_provider.as_ref() else {
            return Ok(None);
        };

        let token = provider.provide_token().await?;
        let mut resolved = self.clone();
        resolved.oauthbearer_token = Some(token);
        resolved.oauthbearer_provider = None;
        Ok(Some(resolved))
    }

    /// Calls the MSK IAM credential provider, if any, and returns a copy of
    /// this configuration holding the resulting static credentials and no
    /// provider.
    ///
    /// Returns `Ok(None)` when the mechanism is not AWS MSK IAM or no
    /// provider is configured.
    ///
    /// # Errors
    ///
    /// Propagates the provider's error.
    pub async fn resolve_msk_iam_provider(&self) -> crate::error::Result<Option<AuthConfig>> {
        if self.sasl_mechanism != Some(SaslMechanism::AwsMskIam) {
            return Ok(None);
        }
        let Some(provider) = self.aws_msk_iam_credential_provider.as_ref() else {
            return Ok(None);
        };

        let credentials = provider.provide_credentials().await?;
        let mut resolved = self.clone();
        resolved.aws_msk_iam_credentials = Some(credentials);
        resolved.aws_msk_iam_credential_provider = None;
        Ok(Some(resolved))
    }

    /// `true` for the `Ssl` and `SaslSsl` protocols.
    pub fn requires_tls(&self) -> bool {
        let protocol = &self.security_protocol;
        *protocol == SecurityProtocol::Ssl || *protocol == SecurityProtocol::SaslSsl
    }

    /// `true` for the `SaslPlaintext` and `SaslSsl` protocols.
    pub fn requires_sasl(&self) -> bool {
        let protocol = &self.security_protocol;
        *protocol == SecurityProtocol::SaslPlaintext || *protocol == SecurityProtocol::SaslSsl
    }

    /// Selected security protocol.
    pub fn security_protocol(&self) -> &SecurityProtocol {
        &self.security_protocol
    }

    /// Configured SASL mechanism, if any.
    pub fn sasl_mechanism(&self) -> Option<&SaslMechanism> {
        self.sasl_mechanism.as_ref()
    }

    /// Configured PLAIN credentials, if any.
    pub fn plain_credentials(&self) -> Option<&PlainCredentials> {
        self.plain_credentials.as_ref()
    }

    /// Configured SCRAM credentials, if any.
    pub fn scram_credentials(&self) -> Option<&ScramCredentials> {
        self.scram_credentials.as_ref()
    }

    /// Configured static MSK IAM credentials, if any.
    pub fn aws_msk_iam_credentials(&self) -> Option<&AwsMskIamCredentials> {
        self.aws_msk_iam_credentials.as_ref()
    }

    /// Configured MSK IAM credential provider, if any.
    pub fn aws_msk_iam_credential_provider(&self) -> Option<&AwsMskIamCredentialProviderHandle> {
        self.aws_msk_iam_credential_provider.as_ref()
    }

    /// Configured static OAUTHBEARER token, if any.
    pub fn oauthbearer_token(&self) -> Option<&OAuthBearerToken> {
        self.oauthbearer_token.as_ref()
    }

    /// Configured OAUTHBEARER token provider, if any.
    pub fn oauthbearer_provider(&self) -> Option<&OAuthBearerTokenProviderHandle> {
        self.oauthbearer_provider.as_ref()
    }

    /// Configured TLS settings, if any.
    pub fn tls_config(&self) -> Option<&TlsConfig> {
        self.tls_config.as_ref()
    }
}
// Unit tests for the credential types, `TlsConfig` and the `AuthConfig`
// constructors and provider-resolution helpers defined in this module.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
// Display output of every security protocol variant.
#[test]
fn test_security_protocol_display() {
assert_eq!(SecurityProtocol::Plaintext.to_string(), "PLAINTEXT");
assert_eq!(SecurityProtocol::Ssl.to_string(), "SSL");
assert_eq!(
SecurityProtocol::SaslPlaintext.to_string(),
"SASL_PLAINTEXT"
);
assert_eq!(SecurityProtocol::SaslSsl.to_string(), "SASL_SSL");
}
// Display output of a subset of the SASL mechanisms.
#[test]
fn test_sasl_mechanism_display() {
assert_eq!(SaslMechanism::Plain.to_string(), "PLAIN");
assert_eq!(SaslMechanism::ScramSha256.to_string(), "SCRAM-SHA-256");
assert_eq!(SaslMechanism::AwsMskIam.to_string(), "AWS_MSK_IAM");
}
// PLAIN credentials encode as NUL, username, NUL, password.
#[test]
fn test_plain_credentials() {
let creds = PlainCredentials::new("user", "pass").unwrap();
let auth_bytes = creds.to_auth_bytes();
assert_eq!(&*auth_bytes, b"\0user\0pass");
}
// The plaintext config needs neither TLS nor SASL.
#[test]
fn test_auth_config_plaintext() {
let config = AuthConfig::plaintext();
assert_eq!(config.security_protocol, SecurityProtocol::Plaintext);
assert!(!config.requires_tls());
assert!(!config.requires_sasl());
}
// SASL/PLAIN without TLS stores credentials and requires SASL only.
#[test]
fn test_auth_config_sasl_plain() {
let config = AuthConfig::sasl_plain("user", "pass").unwrap();
assert_eq!(config.security_protocol, SecurityProtocol::SaslPlaintext);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::Plain));
assert!(config.plain_credentials.is_some());
assert!(!config.requires_tls());
assert!(config.requires_sasl());
}
// MSK IAM with static credentials uses SASL_SSL.
#[test]
fn test_auth_config_aws_msk_iam() {
let config = AuthConfig::aws_msk_iam("access_key", "secret_key", "us-east-1");
assert_eq!(config.security_protocol, SecurityProtocol::SaslSsl);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::AwsMskIam));
assert!(config.aws_msk_iam_credentials.is_some());
assert!(config.requires_tls());
assert!(config.requires_sasl());
}
// TlsConfig builder methods populate the matching fields
// (only compiled with the `native-tls-roots` feature).
#[test]
#[cfg(feature = "native-tls-roots")]
fn test_tls_config() {
let config = TlsConfig::new()
.with_ca_cert("/path/to/ca.pem")
.with_client_cert("/path/to/client.pem", "/path/to/client.key")
.with_native_roots();
assert!(config.verify_server_cert);
assert!(config.use_native_roots());
assert_eq!(config.ca_cert_path, Some("/path/to/ca.pem".to_string()));
assert_eq!(
config.client_cert_path,
Some("/path/to/client.pem".to_string())
);
}
// Debug output of PLAIN credentials shows the username but not the password.
#[test]
fn test_credentials_debug_redacts_password() {
let creds = PlainCredentials::new("user", "secret").unwrap();
let debug_str = format!("{creds:?}");
assert!(debug_str.contains("user"));
assert!(debug_str.contains("[REDACTED]"));
assert!(!debug_str.contains("secret"));
}
// `AwsMskIamCredentials::new` leaves the session token unset.
#[test]
fn test_aws_msk_credentials_manual_creation() {
let creds = AwsMskIamCredentials::new("AKID123", "secret123", "us-west-2");
assert_eq!(creds.access_key_id, "AKID123");
assert_eq!(creds.region, "us-west-2");
assert!(creds.session_token.is_none());
}
// `with_session_token` stores the supplied token.
#[test]
fn test_aws_msk_credentials_with_session_token() {
let creds = AwsMskIamCredentials::with_session_token(
"AKID123",
"secret123",
"token123",
"us-east-1",
);
assert_eq!(creds.access_key_id, "AKID123");
assert_eq!(creds.session_token, Some("token123".to_string()));
}
// Debug output of MSK credentials shows the key id but not the secret.
#[test]
fn test_aws_msk_credentials_debug_redacts() {
let creds = AwsMskIamCredentials::new("AKID123", "supersecret", "us-east-1");
let debug_str = format!("{creds:?}");
assert!(debug_str.contains("AKID123"));
assert!(debug_str.contains("[REDACTED]"));
assert!(!debug_str.contains("supersecret"));
}
// OAUTHBEARER from a string token, without TLS.
#[test]
fn test_auth_config_sasl_oauthbearer() {
let config = AuthConfig::sasl_oauthbearer("my-token");
assert_eq!(config.security_protocol, SecurityProtocol::SaslPlaintext);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_token.is_some());
assert!(!config.requires_tls());
assert!(config.requires_sasl());
}
// OAUTHBEARER from a string token, with TLS.
#[test]
fn test_auth_config_sasl_oauthbearer_ssl() {
let config = AuthConfig::sasl_oauthbearer_ssl("my-token", TlsConfig::new());
assert_eq!(config.security_protocol, SecurityProtocol::SaslSsl);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_token.is_some());
assert!(config.tls_config.is_some());
assert!(config.requires_tls());
assert!(config.requires_sasl());
}
// OAUTHBEARER from a pre-built token carrying an extension.
#[test]
fn test_auth_config_sasl_oauthbearer_token() {
let token = OAuthBearerToken::new("jwt").with_extension("logicalCluster", "lkc-1");
let config = AuthConfig::sasl_oauthbearer_token(token);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_token.is_some());
}
// OAUTHBEARER from a pre-built token, with TLS.
#[test]
fn test_auth_config_sasl_oauthbearer_token_ssl() {
let token = OAuthBearerToken::new("jwt");
let config = AuthConfig::sasl_oauthbearer_token_ssl(token, TlsConfig::new());
assert_eq!(config.security_protocol, SecurityProtocol::SaslSsl);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_token.is_some());
assert!(config.tls_config.is_some());
}
// A provider-based OAUTHBEARER config stores the provider and no token.
#[test]
fn test_auth_config_sasl_oauthbearer_provider() {
let config =
AuthConfig::sasl_oauthbearer_provider(|| async { Ok(OAuthBearerToken::new("tok")) });
assert_eq!(config.security_protocol, SecurityProtocol::SaslPlaintext);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_provider.is_some());
assert!(config.oauthbearer_token.is_none());
assert!(!config.requires_tls());
assert!(config.requires_sasl());
}
// Provider-based OAUTHBEARER config with TLS.
#[test]
fn test_auth_config_sasl_oauthbearer_provider_ssl() {
let config = AuthConfig::sasl_oauthbearer_provider_ssl(
|| async { Ok(OAuthBearerToken::new("tok")) },
TlsConfig::new(),
);
assert_eq!(config.security_protocol, SecurityProtocol::SaslSsl);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert!(config.oauthbearer_provider.is_some());
assert!(config.tls_config.is_some());
assert!(config.requires_tls());
assert!(config.requires_sasl());
}
// Debug output of a provider-based config prints only a placeholder.
#[test]
fn test_auth_config_provider_debug_no_secrets() {
let config =
AuthConfig::sasl_oauthbearer_provider(|| async { Ok(OAuthBearerToken::new("secret")) });
let debug = format!("{config:?}");
assert!(!debug.contains("secret"));
assert!(debug.contains("[OAuthBearerTokenProvider]"));
}
// Resolving calls the provider and swaps it for the returned token.
#[tokio::test]
async fn test_resolve_provider_to_token_calls_provider() {
let config =
AuthConfig::sasl_oauthbearer_provider(|| async { Ok(OAuthBearerToken::new("fresh")) });
let resolved = config.resolve_provider_to_token().await.unwrap().unwrap();
assert!(resolved.oauthbearer_token.is_some());
assert_eq!(
resolved
.oauthbearer_token
.unwrap()
.to_gs2_initial_response(),
OAuthBearerToken::new("fresh").to_gs2_initial_response()
);
assert!(resolved.oauthbearer_provider.is_none());
assert_eq!(resolved.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
assert_eq!(resolved.security_protocol, SecurityProtocol::SaslPlaintext);
}
// Resolving keeps the TLS settings and the SASL_SSL protocol.
#[tokio::test]
async fn test_resolve_provider_to_token_preserves_tls() {
let config = AuthConfig::sasl_oauthbearer_provider_ssl(
|| async { Ok(OAuthBearerToken::new("tok")) },
TlsConfig::new(),
);
let resolved = config.resolve_provider_to_token().await.unwrap().unwrap();
assert!(resolved.tls_config.is_some());
assert_eq!(resolved.security_protocol, SecurityProtocol::SaslSsl);
}
// A static-token config has nothing to resolve.
#[tokio::test]
async fn test_resolve_provider_to_token_returns_none_for_static() {
let config = AuthConfig::sasl_oauthbearer("static-tok");
assert!(config.resolve_provider_to_token().await.unwrap().is_none());
}
// A non-OAUTHBEARER config has nothing to resolve.
#[tokio::test]
async fn test_resolve_provider_to_token_returns_none_for_non_oauth() {
let config = AuthConfig::sasl_plain("user", "pass").unwrap();
assert!(config.resolve_provider_to_token().await.unwrap().is_none());
}
// Errors from the token provider surface to the caller.
#[tokio::test]
async fn test_resolve_provider_to_token_propagates_error() {
let config = AuthConfig::sasl_oauthbearer_provider(|| async {
Err(crate::error::KrafkaError::auth("oauth server down"))
});
let err = config.resolve_provider_to_token().await.unwrap_err();
assert!(err.to_string().contains("oauth server down"));
}
// A provider-based MSK IAM config stores the provider, no static
// credentials, and a TLS config.
#[test]
fn test_auth_config_aws_msk_iam_provider() {
let config = AuthConfig::aws_msk_iam_provider(|| async {
Ok(AwsMskIamCredentials::new("AKID", "secret", "us-east-1"))
});
assert_eq!(config.security_protocol, SecurityProtocol::SaslSsl);
assert_eq!(config.sasl_mechanism, Some(SaslMechanism::AwsMskIam));
assert!(config.aws_msk_iam_credential_provider.is_some());
assert!(config.aws_msk_iam_credentials.is_none());
assert!(config.tls_config.is_some());
}
// Debug output of a provider-based MSK IAM config prints only a placeholder.
#[test]
fn test_msk_iam_provider_debug_no_secrets() {
let config = AuthConfig::aws_msk_iam_provider(|| async {
Ok(AwsMskIamCredentials::new("AKID", "secret", "us-east-1"))
});
let debug = format!("{config:?}");
assert!(!debug.contains("secret"));
assert!(debug.contains("[AwsMskIamCredentialProvider]"));
}
// Resolving calls the provider and swaps it for the returned credentials.
#[tokio::test]
async fn test_resolve_msk_iam_provider_calls_provider() {
let config = AuthConfig::aws_msk_iam_provider(|| async {
Ok(AwsMskIamCredentials::new("AKID", "secret", "us-east-1"))
});
let resolved = config.resolve_msk_iam_provider().await.unwrap().unwrap();
assert!(resolved.aws_msk_iam_credentials.is_some());
assert_eq!(
resolved.aws_msk_iam_credentials.as_ref().unwrap().region,
"us-east-1"
);
assert!(resolved.aws_msk_iam_credential_provider.is_none());
assert_eq!(resolved.sasl_mechanism, Some(SaslMechanism::AwsMskIam));
assert_eq!(resolved.security_protocol, SecurityProtocol::SaslSsl);
}
// A static-credentials config has nothing to resolve.
#[tokio::test]
async fn test_resolve_msk_iam_provider_returns_none_for_static() {
let config = AuthConfig::aws_msk_iam("AKID", "secret", "us-east-1");
assert!(config.resolve_msk_iam_provider().await.unwrap().is_none());
}
// A non-MSK config has nothing to resolve.
#[tokio::test]
async fn test_resolve_msk_iam_provider_returns_none_for_non_msk() {
let config = AuthConfig::sasl_plain("user", "pass").unwrap();
assert!(config.resolve_msk_iam_provider().await.unwrap().is_none());
}
// Errors from the credential provider surface to the caller.
#[tokio::test]
async fn test_resolve_msk_iam_provider_propagates_error() {
let config = AuthConfig::aws_msk_iam_provider(|| async {
Err(crate::error::KrafkaError::auth(
"AWS credential fetch failed",
))
});
let err = config.resolve_msk_iam_provider().await.unwrap_err();
assert!(err.to_string().contains("AWS credential fetch failed"));
}
}