use std::time::Duration;
use crate::auth::{
AuthConfig, AwsMskIamCredentialProvider, ChannelBinding, MskIamAuthenticator, OAuthBearerToken,
OAuthBearerTokenProvider, PlainCredentials, SaslMechanism, ScramClient, ScramMechanism,
SecurityProtocol, TlsConfig,
};
use crate::error::{KrafkaError, Result};
use zeroize::Zeroizing;
use super::connection::ConnectionConfig;
/// Connection settings paired with the security settings (protocol, TLS, SASL)
/// to apply to them.
///
/// Build one with [`SecureConnectionConfig::builder`], or use `default()` for a
/// plaintext configuration.
#[derive(Debug, Clone)]
pub struct SecureConnectionConfig {
/// Transport-level settings: timeouts, client id and `nodelay`.
pub connection: ConnectionConfig,
/// Security protocol plus optional TLS and SASL configuration.
pub auth: AuthConfig,
}
impl Default for SecureConnectionConfig {
    /// Default connection settings with a plaintext security setup
    /// (no TLS, no SASL).
    fn default() -> Self {
        let connection = ConnectionConfig::default();
        let auth = AuthConfig::plaintext();
        Self { connection, auth }
    }
}
impl SecureConnectionConfig {
pub fn builder() -> SecureConnectionConfigBuilder {
SecureConnectionConfigBuilder::default()
}
}
/// Fluent builder for [`SecureConnectionConfig`].
///
/// Starts from `ConnectionConfig::default()` and `AuthConfig::default()`
/// (via `#[derive(Default)]`).
// NOTE(review): `SecureConnectionConfig::default()` uses `AuthConfig::plaintext()`
// instead of `AuthConfig::default()`; confirm the two are equivalent so that
// `builder().build()` and `default()` agree.
#[must_use = "builders do nothing until .build() is called"]
#[derive(Debug, Default)]
pub struct SecureConnectionConfigBuilder {
// Transport settings, mutated by the timeout/client-id/nodelay setters.
connection: ConnectionConfig,
// Security settings; replaced wholesale by `auth()` and the SASL setters.
auth: AuthConfig,
}
impl SecureConnectionConfigBuilder {
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.connection.connect_timeout = timeout;
self
}
pub fn request_timeout(mut self, timeout: Duration) -> Self {
self.connection.request_timeout = timeout;
self
}
pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
self.connection.client_id = client_id.into();
self
}
pub fn nodelay(mut self, nodelay: bool) -> Self {
self.connection.nodelay = nodelay;
self
}
pub fn auth(mut self, auth: AuthConfig) -> Self {
self.auth = auth;
self
}
pub fn sasl_plain(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> crate::Result<Self> {
self.auth = AuthConfig::sasl_plain(username, password)?;
Ok(self)
}
pub fn sasl_scram_sha256(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.auth = AuthConfig::sasl_scram_sha256(username, password);
self
}
pub fn sasl_scram_sha512(
mut self,
username: impl Into<String>,
password: impl Into<String>,
) -> Self {
self.auth = AuthConfig::sasl_scram_sha512(username, password);
self
}
pub fn aws_msk_iam(
mut self,
access_key_id: impl Into<String>,
secret_access_key: impl Into<String>,
region: impl Into<String>,
) -> Self {
self.auth = AuthConfig::aws_msk_iam(access_key_id, secret_access_key, region);
self
}
pub fn aws_msk_iam_provider(
mut self,
provider: impl AwsMskIamCredentialProvider + 'static,
) -> Self {
self.auth = AuthConfig::aws_msk_iam_provider(provider);
self
}
pub fn sasl_oauthbearer(mut self, token: impl Into<String>) -> Self {
self.auth = AuthConfig::sasl_oauthbearer(token);
self
}
pub fn sasl_oauthbearer_token(mut self, token: OAuthBearerToken) -> Self {
self.auth = AuthConfig::sasl_oauthbearer_token(token);
self
}
pub fn sasl_oauthbearer_provider(
mut self,
provider: impl OAuthBearerTokenProvider + 'static,
) -> Self {
self.auth = AuthConfig::sasl_oauthbearer_provider(provider);
self
}
pub fn tls(mut self, tls_config: TlsConfig) -> Self {
self.auth.tls_config = Some(tls_config);
if self.auth.security_protocol == SecurityProtocol::Plaintext {
self.auth.security_protocol = SecurityProtocol::Ssl;
} else if self.auth.security_protocol == SecurityProtocol::SaslPlaintext {
self.auth.security_protocol = SecurityProtocol::SaslSsl;
}
self
}
pub fn build(self) -> SecureConnectionConfig {
SecureConnectionConfig {
connection: self.connection,
auth: self.auth,
}
}
}
/// What the caller should do after [`SaslAuthenticator::process_challenge`]
/// has consumed a server message.
#[non_exhaustive]
#[derive(Debug)]
pub enum ChallengeResponse {
/// Send these bytes to the server as the next client message (e.g. the SCRAM
/// client-final message). Zeroized on drop because they may carry proof material.
Continue(Zeroizing<Vec<u8>>),
/// Send `ack` to the server, then fail authentication with `error`.
///
/// Produced for OAUTHBEARER when the server rejects the token: the client
/// answers the server's error message with a single `0x01` byte before the
/// exchange is aborted.
AckThenFail {
ack: Vec<u8>,
error: KrafkaError,
},
/// The client side of the exchange is finished; nothing more to send.
Done,
}
/// Client-side state machine for one SASL authentication exchange.
///
/// Only the fields belonging to `mechanism` are populated; the others stay
/// `None` / `false`.
pub struct SaslAuthenticator {
// Mechanism selected from the `AuthConfig`; drives every `match` in the impl.
mechanism: SaslMechanism,
// PLAIN only: credentials encoded by `initial_response`.
plain_credentials: Option<PlainCredentials>,
// SCRAM-SHA-256/512 only: the SCRAM client and its state.
scram_client: Option<ScramClient>,
// AWS_MSK_IAM only: signer bound to a broker host; `None` until
// `new_msk_iam` / `set_msk_host` provides the host.
msk_iam_authenticator: Option<MskIamAuthenticator>,
// AWS_MSK_IAM only: set once the server's reply has been processed.
msk_iam_complete: bool,
// OAUTHBEARER only: token sent in the initial response.
oauthbearer_token: Option<OAuthBearerToken>,
// OAUTHBEARER only: set once the server's reply was accepted.
oauthbearer_complete: bool,
}
impl SaslAuthenticator {
pub fn new(auth: &AuthConfig, channel_binding: ChannelBinding) -> Result<Option<Self>> {
let Some(mechanism) = auth.sasl_mechanism.as_ref() else {
return Ok(None);
};
match mechanism {
SaslMechanism::Plain => Ok(Some(Self {
mechanism: SaslMechanism::Plain,
plain_credentials: auth.plain_credentials.clone(),
scram_client: None,
msk_iam_authenticator: None,
msk_iam_complete: false,
oauthbearer_token: None,
oauthbearer_complete: false,
})),
SaslMechanism::ScramSha256 => {
let creds = auth
.scram_credentials
.as_ref()
.ok_or_else(|| KrafkaError::auth("SCRAM-SHA-256 credentials not configured"))?;
Ok(Some(Self {
mechanism: SaslMechanism::ScramSha256,
plain_credentials: None,
scram_client: Some(ScramClient::new(
&creds.username,
&creds.password,
ScramMechanism::Sha256,
channel_binding,
)),
msk_iam_authenticator: None,
msk_iam_complete: false,
oauthbearer_token: None,
oauthbearer_complete: false,
}))
}
SaslMechanism::ScramSha512 => {
let creds = auth
.scram_credentials
.as_ref()
.ok_or_else(|| KrafkaError::auth("SCRAM-SHA-512 credentials not configured"))?;
Ok(Some(Self {
mechanism: SaslMechanism::ScramSha512,
plain_credentials: None,
scram_client: Some(ScramClient::new(
&creds.username,
&creds.password,
ScramMechanism::Sha512,
channel_binding,
)),
msk_iam_authenticator: None,
msk_iam_complete: false,
oauthbearer_token: None,
oauthbearer_complete: false,
}))
}
SaslMechanism::AwsMskIam => {
Ok(Some(Self {
mechanism: SaslMechanism::AwsMskIam,
plain_credentials: None,
scram_client: None,
msk_iam_authenticator: None,
msk_iam_complete: false,
oauthbearer_token: None,
oauthbearer_complete: false,
}))
}
SaslMechanism::OAuthBearer => {
let token = auth.oauthbearer_token.as_ref().cloned().ok_or_else(|| {
KrafkaError::auth(
"OAUTHBEARER mechanism requires an OAuth bearer token; \
if using a token provider, call resolve_provider_to_token() first",
)
})?;
Ok(Some(Self {
mechanism: SaslMechanism::OAuthBearer,
plain_credentials: None,
scram_client: None,
msk_iam_authenticator: None,
msk_iam_complete: false,
oauthbearer_token: Some(token),
oauthbearer_complete: false,
}))
}
SaslMechanism::Gssapi => Err(KrafkaError::auth(
"SASL/GSSAPI (Kerberos) is not available in the pure-Rust build; \
use OAUTHBEARER for token-based authentication or SCRAM-SHA-256/512 \
for password-based authentication",
)),
}
}
pub fn new_msk_iam(
auth: &AuthConfig,
host: &str,
clock_offset_secs: i64,
) -> Result<Option<Self>> {
if !matches!(auth.sasl_mechanism, Some(SaslMechanism::AwsMskIam)) {
return Ok(None);
}
let Some(creds) = auth.aws_msk_iam_credentials.as_ref() else {
return Ok(None);
};
let authenticator =
MskIamAuthenticator::new_with_clock_offset(creds, host, clock_offset_secs)?;
Ok(Some(Self {
mechanism: SaslMechanism::AwsMskIam,
plain_credentials: None,
scram_client: None,
msk_iam_authenticator: Some(authenticator),
msk_iam_complete: false,
oauthbearer_token: None,
oauthbearer_complete: false,
}))
}
pub fn set_msk_host(
&mut self,
auth: &AuthConfig,
host: &str,
clock_offset_secs: i64,
) -> Result<()> {
if self.mechanism == SaslMechanism::AwsMskIam {
let creds = auth.aws_msk_iam_credentials.as_ref().ok_or_else(|| {
KrafkaError::auth(
"AWS MSK IAM mechanism selected but no credentials available; \
if using a credential provider, ensure resolve_msk_iam_provider() \
is called before creating the authenticator",
)
})?;
self.msk_iam_authenticator = Some(MskIamAuthenticator::new_with_clock_offset(
creds,
host,
clock_offset_secs,
)?);
}
Ok(())
}
pub fn mechanism_name(&self) -> &str {
match self.mechanism {
SaslMechanism::Plain => "PLAIN",
SaslMechanism::ScramSha256 => "SCRAM-SHA-256",
SaslMechanism::ScramSha512 => "SCRAM-SHA-512",
SaslMechanism::AwsMskIam => "AWS_MSK_IAM",
SaslMechanism::OAuthBearer => "OAUTHBEARER",
SaslMechanism::Gssapi => "GSSAPI",
}
}
pub fn initial_response(&mut self) -> Result<Zeroizing<Vec<u8>>> {
match self.mechanism {
SaslMechanism::Plain => Ok(self
.plain_credentials
.as_ref()
.map(|c| c.to_auth_bytes())
.unwrap_or_default()),
SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => Ok(Zeroizing::new(
self.scram_client
.as_mut()
.map(|c| c.client_first_message())
.unwrap_or_default(),
)),
SaslMechanism::AwsMskIam => Ok(Zeroizing::new(
self.msk_iam_authenticator
.as_ref()
.map(|a| a.create_auth_payload())
.unwrap_or_default(),
)),
SaslMechanism::OAuthBearer => {
if let Some(token) = &self.oauthbearer_token {
if token.needs_refresh() {
return Err(KrafkaError::auth(
"OAuthBearer token is expired or too close to expiry; obtain a fresh token before connecting",
));
}
Ok(Zeroizing::new(token.to_gs2_initial_response()))
} else {
Ok(Zeroizing::new(Vec::new()))
}
}
SaslMechanism::Gssapi => Ok(Zeroizing::new(Vec::new())),
}
}
pub fn process_challenge(&mut self, challenge: &[u8]) -> Result<ChallengeResponse> {
match self.mechanism {
SaslMechanism::Plain => {
Ok(ChallengeResponse::Done)
}
SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
let scram = self
.scram_client
.as_mut()
.ok_or_else(|| KrafkaError::auth("SCRAM client not initialized"))?;
match scram.state() {
crate::auth::ScramState::WaitingServerFirst => {
let response = scram.process_server_first(challenge)?;
Ok(ChallengeResponse::Continue(Zeroizing::new(response)))
}
crate::auth::ScramState::WaitingServerFinal => {
scram.verify_server_final(challenge)?;
Ok(ChallengeResponse::Done)
}
_ => Err(KrafkaError::auth("Unexpected SCRAM state")),
}
}
SaslMechanism::AwsMskIam => {
self.msk_iam_complete = true;
Ok(ChallengeResponse::Done)
}
SaslMechanism::OAuthBearer => {
let token = self
.oauthbearer_token
.as_ref()
.ok_or_else(|| KrafkaError::auth("OAuthBearer token not configured"))?;
match token.process_server_response(challenge) {
Ok(()) => {
self.oauthbearer_complete = true;
Ok(ChallengeResponse::Done)
}
Err(e) => {
Ok(ChallengeResponse::AckThenFail {
ack: vec![0x01],
error: e,
})
}
}
}
SaslMechanism::Gssapi => Err(KrafkaError::auth(
"SASL/GSSAPI (Kerberos) is not available in the pure-Rust build",
)),
}
}
pub fn is_complete(&self) -> bool {
match self.mechanism {
SaslMechanism::Plain => true, SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => self
.scram_client
.as_ref()
.is_some_and(|c| *c.state() == crate::auth::ScramState::Complete),
SaslMechanism::AwsMskIam => self.msk_iam_complete,
SaslMechanism::OAuthBearer => self.oauthbearer_complete,
SaslMechanism::Gssapi => false,
}
}
}
#[cfg(test)]
// Test code may unwrap/expect/panic: any such failure is a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a [`SaslAuthenticator`] with no channel binding, panicking if
    /// construction fails or `auth` selects no SASL mechanism.
    fn build_authenticator(auth: &AuthConfig) -> SaslAuthenticator {
        SaslAuthenticator::new(auth, ChannelBinding::None)
            .expect("authenticator construction should succeed")
            .expect("a SASL mechanism should be configured")
    }

    /// Wall-clock milliseconds since the Unix epoch, shifted by `offset_ms`.
    fn epoch_ms_offset(offset_ms: i64) -> i64 {
        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as i64;
        now_ms + offset_ms
    }

    #[test]
    fn test_secure_connection_config_default() {
        let config = SecureConnectionConfig::default();
        let auth = &config.auth;
        assert_eq!(auth.security_protocol, SecurityProtocol::Plaintext);
        assert!(!auth.requires_tls());
        assert!(!auth.requires_sasl());
    }

    #[test]
    fn test_secure_connection_config_builder() {
        let config = SecureConnectionConfig::builder()
            .client_id("test-client")
            .connect_timeout(Duration::from_secs(5))
            .sasl_plain("user", "pass")
            .expect("PLAIN credentials should be accepted")
            .build();
        assert_eq!(config.connection.client_id, "test-client");
        assert_eq!(config.connection.connect_timeout, Duration::from_secs(5));
        assert!(config.auth.requires_sasl());
    }

    #[test]
    fn test_secure_connection_config_with_tls() {
        let config = SecureConnectionConfig::builder()
            .tls(TlsConfig::new())
            .build();
        assert!(config.auth.requires_tls());
    }

    #[test]
    fn test_sasl_authenticator_plain() {
        let auth = AuthConfig::sasl_plain("user", "pass").unwrap();
        let mut authenticator = build_authenticator(&auth);
        assert_eq!(authenticator.mechanism_name(), "PLAIN");

        // PLAIN message: empty authzid, NUL, username, NUL, password.
        let initial = authenticator.initial_response().unwrap();
        assert_eq!(&*initial, b"\0user\0pass");
        assert!(authenticator.is_complete());
    }

    #[test]
    fn test_sasl_authenticator_scram() {
        let auth = AuthConfig::sasl_scram_sha256("user", "pass");
        let mut authenticator = build_authenticator(&auth);
        assert_eq!(authenticator.mechanism_name(), "SCRAM-SHA-256");

        // client-first-message: GS2 header "n,," followed by username and nonce.
        let initial = authenticator.initial_response().unwrap();
        assert!(initial.starts_with(b"n,,n=user,r="));
        assert!(!authenticator.is_complete());
    }

    #[test]
    fn test_sasl_authenticator_msk_iam() {
        let auth = AuthConfig::aws_msk_iam("AKIAIOSFODNN7EXAMPLE", "secret", "us-east-1");
        let mut authenticator =
            SaslAuthenticator::new_msk_iam(&auth, "broker.kafka.us-east-1.amazonaws.com", 0)
                .unwrap()
                .unwrap();
        assert_eq!(authenticator.mechanism_name(), "AWS_MSK_IAM");

        let initial = authenticator.initial_response().unwrap();
        let payload = String::from_utf8(initial.to_vec()).unwrap();
        let expected_fragments = [
            "\"version\":\"2020_10_22\"",
            "\"host\":\"broker.kafka.us-east-1.amazonaws.com\"",
            "\"action\":\"kafka-cluster:Connect\"",
            "\"x-amz-signature\":",
        ];
        for fragment in expected_fragments {
            assert!(payload.contains(fragment), "payload is missing {fragment}");
        }

        assert!(!authenticator.is_complete());
        authenticator.process_challenge(&[]).unwrap();
        assert!(authenticator.is_complete());
    }

    #[test]
    fn test_secure_connection_config_builder_msk_iam() {
        let config = SecureConnectionConfig::builder()
            .aws_msk_iam("AKID", "secret", "us-east-1")
            .build();
        let auth = &config.auth;
        assert!(auth.requires_tls());
        assert!(auth.requires_sasl());
        assert_eq!(auth.sasl_mechanism, Some(SaslMechanism::AwsMskIam));
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer() {
        let auth = AuthConfig::sasl_oauthbearer("my-jwt-token");
        let mut authenticator = build_authenticator(&auth);
        assert_eq!(authenticator.mechanism_name(), "OAUTHBEARER");

        let initial = authenticator.initial_response().unwrap();
        assert_eq!(&*initial, b"n,,\x01auth=Bearer my-jwt-token\x01\x01");

        assert!(!authenticator.is_complete());
        authenticator.process_challenge(&[]).unwrap();
        assert!(authenticator.is_complete());
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_with_extensions() {
        let token = OAuthBearerToken::new("tok").with_extension("logicalCluster", "lkc-123");
        let auth = AuthConfig::sasl_oauthbearer_token(token);
        let mut authenticator = build_authenticator(&auth);

        let initial = authenticator.initial_response().unwrap();
        let text = String::from_utf8_lossy(&initial);
        assert!(text.starts_with("n,,\x01auth=Bearer tok"));
        assert!(text.contains("logicalCluster=lkc-123"));
        assert!(text.ends_with("\x01\x01"));
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_server_error() {
        let auth = AuthConfig::sasl_oauthbearer("bad-token");
        let mut authenticator = build_authenticator(&auth);
        let _ = authenticator.initial_response().unwrap();

        let outcome = authenticator
            .process_challenge(br#"{"status":"invalid_token"}"#)
            .unwrap();
        match outcome {
            ChallengeResponse::AckThenFail { ack, error } => {
                assert_eq!(ack, vec![0x01]);
                assert!(error.to_string().contains("invalid_token"));
            }
            other => panic!("expected AckThenFail, got {other:?}"),
        }
        assert!(!authenticator.is_complete());
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_missing_token() {
        let auth = AuthConfig {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::OAuthBearer),
            oauthbearer_token: None,
            ..Default::default()
        };
        assert!(SaslAuthenticator::new(&auth, ChannelBinding::None).is_err());
    }

    #[test]
    fn test_sasl_authenticator_gssapi_fails_gracefully() {
        let auth = AuthConfig {
            security_protocol: SecurityProtocol::SaslPlaintext,
            sasl_mechanism: Some(SaslMechanism::Gssapi),
            ..Default::default()
        };
        assert!(SaslAuthenticator::new(&auth, ChannelBinding::None).is_err());
    }

    #[test]
    fn test_secure_connection_config_builder_oauthbearer() {
        let config = SecureConnectionConfig::builder()
            .sasl_oauthbearer("my-token")
            .build();
        let auth = &config.auth;
        assert!(auth.requires_sasl());
        assert_eq!(auth.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
        assert!(auth.oauthbearer_token.is_some());
    }

    #[test]
    fn test_secure_connection_config_builder_oauthbearer_token() {
        let token = OAuthBearerToken::new("tok").with_extension("key", "val");
        let config = SecureConnectionConfig::builder()
            .sasl_oauthbearer_token(token)
            .build();
        let auth = &config.auth;
        assert!(auth.requires_sasl());
        assert_eq!(auth.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
    }

    #[test]
    fn test_secure_connection_config_builder_oauthbearer_provider() {
        let config = SecureConnectionConfig::builder()
            .sasl_oauthbearer_provider(|| async { Ok(OAuthBearerToken::new("provider-token")) })
            .build();
        let auth = &config.auth;
        assert!(auth.requires_sasl());
        assert_eq!(auth.sasl_mechanism, Some(SaslMechanism::OAuthBearer));
        assert!(auth.oauthbearer_provider.is_some());
        assert!(auth.oauthbearer_token.is_none());
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_expired_token_rejected() {
        // Expired one hour ago.
        let token =
            OAuthBearerToken::new("expired-jwt").with_lifetime_ms(epoch_ms_offset(-3_600_000));
        let auth = AuthConfig::sasl_oauthbearer_token(token);
        let mut authenticator = build_authenticator(&auth);

        let err = authenticator
            .initial_response()
            .expect_err("an expired token must be rejected");
        assert!(err.to_string().contains("expired"));
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_valid_token_accepted() {
        // Valid for another hour.
        let token =
            OAuthBearerToken::new("valid-jwt").with_lifetime_ms(epoch_ms_offset(3_600_000));
        let auth = AuthConfig::sasl_oauthbearer_token(token);
        let mut authenticator = build_authenticator(&auth);

        assert!(authenticator.initial_response().is_ok());
    }

    #[test]
    fn test_sasl_authenticator_oauthbearer_near_expiry_token_rejected() {
        // Expires in ten seconds, inside the refresh window.
        let token =
            OAuthBearerToken::new("near-expiry-jwt").with_lifetime_ms(epoch_ms_offset(10_000));
        let auth = AuthConfig::sasl_oauthbearer_token(token);
        let mut authenticator = build_authenticator(&auth);

        let err = authenticator
            .initial_response()
            .expect_err("a token near expiry must be rejected");
        assert!(err.to_string().contains("too close to expiry"));
    }
}