#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::Arc;
use crabka_protocol::owned::sasl_authenticate_request::SaslAuthenticateRequest;
use crabka_protocol::owned::sasl_handshake_request::SaslHandshakeRequest;
use crabka_protocol::{Decode, Encode};
use crabka_raft::{ControllerHandle, DuplexStream, RaftHandshakeError, RaftListenerHandshake};
use crabka_security::{ListenerProtocol, SaslMechanism};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::OnceCell;
use tokio_rustls::TlsAcceptor;
use crate::network::auth::{
ConnectionAuth, SaslExchange, handle_authenticate_plain, handle_authenticate_scram,
handle_handshake, is_pre_auth_allowed,
};
pub type ControllerHandleArc = Arc<OnceCell<Arc<ControllerHandle>>>;
/// Kafka API key for `SaslHandshake` requests.
const API_KEY_SASL_HANDSHAKE: i16 = 17;
/// Kafka API key for `ApiVersions` requests.
const API_KEY_API_VERSIONS: i16 = 18;
/// Kafka API key for `SaslAuthenticate` requests.
const API_KEY_SASL_AUTHENTICATE: i16 = 36;
/// Broker-side upgrade for connections accepted on the controller (Raft)
/// listener: optional TLS, then an optional SASL exchange followed by a
/// CLUSTER_ACTION authorization check, depending on `protocol`.
pub struct BrokerRaftHandshake {
/// Acceptor used when `protocol.requires_tls()`; `upgrade` fails with a
/// `Tls` error if this is `None` on a TLS listener.
pub tls_acceptor: Option<TlsAcceptor>,
/// Credential map handed to `handle_authenticate_plain` for SASL/PLAIN
/// (presumably username -> password; confirm against that helper).
pub plain_credentials: HashMap<String, String>,
/// Mechanisms offered/accepted during `SaslHandshake` (passed to
/// `handle_handshake`).
pub enabled_sasl_mechanisms: Vec<SaslMechanism>,
/// Listener protocol; decides whether TLS and/or SASL run in `upgrade`.
pub protocol: ListenerProtocol,
/// Lazily-initialised controller handle, used for SCRAM credential lookup
/// and for the metadata image consulted by the authorizer.
pub controller: ControllerHandleArc,
/// Authorizer queried for CLUSTER_ACTION on `kafka-cluster` after a
/// successful SASL exchange.
pub authorizer: Arc<dyn crate::authorizer::Authorizer>,
}
/// Connection state a new controller-listener peer starts in, before any
/// SASL request has been processed: `ConnectionAuth::Anonymous`.
fn pre_auth_state() -> ConnectionAuth {
ConnectionAuth::Anonymous
}
impl BrokerRaftHandshake {
    /// Verifies that `principal` (connecting from `peer`) holds
    /// `CLUSTER_ACTION` on the `kafka-cluster` resource.
    ///
    /// The decision comes from `self.authorizer`, evaluated against the
    /// controller's current metadata image.
    ///
    /// # Errors
    ///
    /// Returns `RaftHandshakeError::Sasl` in two cases:
    /// - the controller handle has not been initialised yet;
    /// - the authorizer returns `Deny`. A warning naming the principal and
    ///   peer is logged first.
    fn authorize_cluster_action(
        &self,
        principal: &crabka_security::Principal,
        peer: &std::net::SocketAddr,
    ) -> Result<(), RaftHandshakeError> {
        use crate::authorizer::{AuthorizationRequest, AuthorizationResult};
        use crabka_metadata::{AclOperation, ResourceType};

        let Some(controller) = self.controller.get() else {
            return Err(RaftHandshakeError::Sasl(
                "controller handle not initialised for CLUSTER_ACTION authorization".into(),
            ));
        };

        let request = AuthorizationRequest {
            principal,
            host: peer,
            resource_type: ResourceType::Cluster,
            resource_name: "kafka-cluster",
            operation: AclOperation::ClusterAction,
        };
        let image = controller.current_image();
        let decision = self.authorizer.authorize(&*image, &request);

        if decision != AuthorizationResult::Deny {
            return Ok(());
        }

        tracing::warn!(
            principal = %principal.name,
            peer = %peer,
            "denying controller-listener peer: principal lacks CLUSTER_ACTION on kafka-cluster"
        );
        Err(RaftHandshakeError::Sasl(
            "principal not authorized for CLUSTER_ACTION on the controller listener".into(),
        ))
    }
}
#[async_trait::async_trait]
impl RaftListenerHandshake for BrokerRaftHandshake {
    /// Upgrades a freshly accepted controller-listener TCP connection.
    ///
    /// Steps, driven by `self.protocol`:
    /// 1. If TLS is required, wrap the socket with `self.tls_acceptor`.
    /// 2. If SASL is required, run the inbound SASL exchange and then require
    ///    CLUSTER_ACTION on `kafka-cluster` for the authenticated principal.
    ///
    /// NOTE(review): listeners that do not require SASL (plaintext or
    /// TLS-only) skip the CLUSTER_ACTION check entirely. Confirm this is
    /// intended.
    ///
    /// # Errors
    ///
    /// - `Tls`: the peer address is unavailable, a TLS listener has no
    ///   acceptor configured, or the TLS handshake fails.
    /// - Otherwise: errors propagated from the SASL exchange or the
    ///   authorization check.
    async fn upgrade(
        &self,
        stream: TcpStream,
    ) -> Result<Box<dyn DuplexStream>, RaftHandshakeError> {
        let peer = stream
            .peer_addr()
            .map_err(|e| RaftHandshakeError::Tls(e.to_string()))?;

        let mut conn: Box<dyn DuplexStream> = if !self.protocol.requires_tls() {
            Box::new(stream)
        } else {
            let Some(acceptor) = self.tls_acceptor.as_ref() else {
                return Err(RaftHandshakeError::Tls(
                    "tls_config required for TLS controller listener".into(),
                ));
            };
            let tls = acceptor
                .accept(stream)
                .await
                .map_err(|e| RaftHandshakeError::Tls(e.to_string()))?;
            Box::new(tls)
        };

        if !self.protocol.requires_sasl() {
            return Ok(conn);
        }

        let principal = run_inbound_sasl(&mut *conn, self).await?;
        self.authorize_cluster_action(&principal, &peer)?;
        Ok(conn)
    }
}
/// Drives the server side of the pre-authentication SASL exchange on a
/// controller-listener connection and returns the authenticated principal.
///
/// Requests are read one at a time. Handling per API key:
/// - `ApiVersions`: answered with a fixed table of supported APIs. The reply
///   is flushed before the next read, so it cannot sit buffered inside a
///   wrapping stream (e.g. a TLS session) while we block waiting for the
///   peer.
/// - `SaslHandshake`: negotiates a mechanism via `handle_handshake`.
/// - `SaslAuthenticate`: runs one PLAIN or SCRAM step. Once `auth` reports
///   authenticated, the principal is returned. Otherwise a SCRAM
///   continuation is expected and the loop keeps reading.
/// - Any other key: rejected, as is any key that `is_pre_auth_allowed` does
///   not permit.
///
/// # Errors
///
/// - I/O errors from reading or writing `stream`.
/// - `RaftHandshakeError::Protocol`: a request fails to decode or carries an
///   unexpected API key.
/// - `RaftHandshakeError::Sasl`, in any of these cases:
///   - a handshake/authenticate response has a non-zero `error_code` (the
///     response is still sent to the peer first);
///   - `SaslAuthenticate` arrives before a handshake;
///   - the negotiated mechanism is OAUTHBEARER or GSSAPI;
///   - the controller handle needed for SCRAM lookup is not initialised;
///   - authentication completes without a principal.
async fn run_inbound_sasl(
    stream: &mut dyn DuplexStream,
    cfg: &BrokerRaftHandshake,
) -> Result<crabka_security::Principal, RaftHandshakeError> {
    let mut auth = pre_auth_state();
    loop {
        let (api_key, api_version, corr_id, body) = read_kafka_request(stream).await?;
        if !is_pre_auth_allowed(api_key) && !auth.is_authenticated() {
            return Err(RaftHandshakeError::Sasl(format!(
                "pre-auth request api_key={api_key} rejected"
            )));
        }
        match api_key {
            API_KEY_API_VERSIONS => {
                let resp_bytes = build_api_versions_response(corr_id);
                stream.write_all(&resp_bytes).await?;
                // The next step is a blocking read of the peer's request, so
                // push any bytes buffered by the stream wrapper out now.
                stream.flush().await?;
            }
            API_KEY_SASL_HANDSHAKE => {
                let mut cur = body.as_slice();
                let req = SaslHandshakeRequest::decode(&mut cur, api_version)
                    .map_err(|e| RaftHandshakeError::Protocol(e.to_string()))?;
                let resp = handle_handshake(&req, &mut auth, &cfg.enabled_sasl_mechanisms);
                let error_code = resp.error_code;
                // Send the error response before failing so the peer learns why.
                write_response(stream, api_key, api_version, corr_id, &resp).await?;
                if error_code != 0 {
                    return Err(RaftHandshakeError::Sasl(format!(
                        "handshake error_code={error_code}"
                    )));
                }
            }
            API_KEY_SASL_AUTHENTICATE => {
                let mut cur = body.as_slice();
                let req = SaslAuthenticateRequest::decode(&mut cur, api_version)
                    .map_err(|e| RaftHandshakeError::Protocol(e.to_string()))?;
                let mech = match &auth {
                    ConnectionAuth::Negotiating { mechanism, .. } => *mechanism,
                    _ => {
                        return Err(RaftHandshakeError::Sasl(
                            "authenticate before handshake".into(),
                        ));
                    }
                };
                let resp = match mech {
                    SaslMechanism::Plain => {
                        handle_authenticate_plain(&req, &mut auth, &cfg.plain_credentials)
                    }
                    SaslMechanism::ScramSha256 | SaslMechanism::ScramSha512 => {
                        let controller = cfg.controller.get().ok_or_else(|| {
                            RaftHandshakeError::Sasl(
                                "controller handle not initialised for SCRAM lookup".into(),
                            )
                        })?;
                        handle_authenticate_scram(&req, &mut auth, controller.as_ref())
                    }
                    SaslMechanism::OAuthBearer => {
                        return Err(RaftHandshakeError::Sasl(
                            "OAUTHBEARER is not supported on the controller listener".into(),
                        ));
                    }
                    SaslMechanism::Gssapi => {
                        return Err(RaftHandshakeError::Sasl(
                            "GSSAPI is not yet wired on the controller listener".into(),
                        ));
                    }
                };
                let error_code = resp.error_code;
                write_response(stream, api_key, api_version, corr_id, &resp).await?;
                if error_code != 0 {
                    return Err(RaftHandshakeError::Sasl(format!(
                        "authenticate error_code={error_code}"
                    )));
                }
                if auth.is_authenticated() {
                    let principal = auth.principal().cloned().ok_or_else(|| {
                        RaftHandshakeError::Sasl(
                            "authenticated connection missing principal".into(),
                        )
                    })?;
                    return Ok(principal);
                }
                // A successful but incomplete step only happens mid-SCRAM
                // (PLAIN is single-round), so another SaslAuthenticate follows.
                debug_assert!(
                    matches!(
                        auth,
                        ConnectionAuth::Negotiating {
                            exchange: SaslExchange::Scram(_),
                            ..
                        }
                    ),
                    "expected SCRAM continuation after non-authenticated success"
                );
            }
            other => {
                return Err(RaftHandshakeError::Protocol(format!(
                    "unexpected api_key={other} during handshake"
                )));
            }
        }
    }
}
/// Reads one length-prefixed Kafka request frame from `stream` and strips its
/// request header.
///
/// Returns `(api_key, api_version, correlation_id, body)`. `body` holds the
/// bytes that follow the header; the client id is skipped, and for flexible
/// headers the tagged-field section is skipped too.
///
/// # Errors
///
/// - I/O errors from the underlying stream.
/// - `RaftHandshakeError::Protocol` in any of these cases:
///   - the declared frame size exceeds the pre-auth limit;
///   - the frame is shorter than the fixed header;
///   - the client id or tagged fields run past the end of the frame.
async fn read_kafka_request(
    stream: &mut dyn DuplexStream,
) -> Result<(i16, i16, i32, Vec<u8>), RaftHandshakeError> {
    // This reader only runs before the peer has authenticated, so the size
    // prefix is attacker-controlled: bound it before allocating. The limit
    // mirrors Kafka's `sasl.server.max.receive.size` default (512 KiB).
    const MAX_PRE_AUTH_FRAME_BYTES: usize = 512 * 1024;
    // api_key(2) + api_version(2) + correlation_id(4) + client_id length(2).
    const FIXED_HEADER_BYTES: usize = 10;

    let mut size_buf = [0u8; 4];
    stream.read_exact(&mut size_buf).await?;
    let size = usize::try_from(u32::from_be_bytes(size_buf))
        .map_err(|_| RaftHandshakeError::Protocol("request frame size overflow".into()))?;
    if size > MAX_PRE_AUTH_FRAME_BYTES {
        return Err(RaftHandshakeError::Protocol(format!(
            "request frame of {size} bytes exceeds pre-auth limit of \
             {MAX_PRE_AUTH_FRAME_BYTES} bytes"
        )));
    }
    if size < FIXED_HEADER_BYTES {
        return Err(RaftHandshakeError::Protocol("short request header".into()));
    }

    let mut frame = vec![0u8; size];
    stream.read_exact(&mut frame).await?;

    let api_key = i16::from_be_bytes([frame[0], frame[1]]);
    let api_version = i16::from_be_bytes([frame[2], frame[3]]);
    let corr_id = i32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]);
    let client_id_len = i16::from_be_bytes([frame[8], frame[9]]);

    let mut cursor = FIXED_HEADER_BYTES;
    // A negative length encodes a null client id with no following bytes.
    if let Ok(cid_len) = usize::try_from(client_id_len) {
        let cid_end = cursor
            .checked_add(cid_len)
            .ok_or_else(|| RaftHandshakeError::Protocol("client_id_len overflow".into()))?;
        if cid_end > frame.len() {
            return Err(RaftHandshakeError::Protocol(
                "client_id extends past frame".into(),
            ));
        }
        cursor = cid_end;
    }

    if is_request_header_flexible(api_key, api_version) {
        if cursor >= frame.len() {
            return Err(RaftHandshakeError::Protocol(
                "missing tagged-fields byte in flexible request header".into(),
            ));
        }
        cursor = skip_tagged_fields(&frame, cursor)?;
    }

    // Drop the header in place and hand back the remaining bytes as the body.
    frame.drain(..cursor);
    Ok((api_key, api_version, corr_id, frame))
}

/// Skips a KIP-482 tagged-field section that starts at `cursor`.
///
/// Returns the offset of the first byte after the section. The section is an
/// unsigned-varint count followed by that many `(tag, size, data[size])`
/// entries.
fn skip_tagged_fields(frame: &[u8], mut cursor: usize) -> Result<usize, RaftHandshakeError> {
    let count = read_unsigned_varint(frame, &mut cursor)?;
    // Each entry consumes at least two bytes or errors out, so this loop is
    // bounded by the frame length regardless of `count`.
    for _ in 0..count {
        let _tag = read_unsigned_varint(frame, &mut cursor)?;
        let field_len = usize::try_from(read_unsigned_varint(frame, &mut cursor)?)
            .map_err(|_| RaftHandshakeError::Protocol("tagged field size overflow".into()))?;
        cursor = cursor
            .checked_add(field_len)
            .filter(|&end| end <= frame.len())
            .ok_or_else(|| {
                RaftHandshakeError::Protocol("tagged field extends past frame".into())
            })?;
    }
    Ok(cursor)
}

/// Decodes a Kafka unsigned varint at `*cursor` and advances `*cursor` past it.
///
/// Encoding: 7 data bits per byte, least-significant group first, with the
/// high bit set on every byte except the last. At most 5 bytes are accepted.
fn read_unsigned_varint(frame: &[u8], cursor: &mut usize) -> Result<u32, RaftHandshakeError> {
    let mut value = 0u32;
    for shift in (0..32).step_by(7) {
        let byte = *frame.get(*cursor).ok_or_else(|| {
            RaftHandshakeError::Protocol("truncated varint in request header".into())
        })?;
        *cursor += 1;
        value |= u32::from(byte & 0x7f) << shift;
        if byte & 0x80 == 0 {
            return Ok(value);
        }
    }
    Err(RaftHandshakeError::Protocol(
        "varint longer than 5 bytes in request header".into(),
    ))
}
/// Encodes `resp` as a complete Kafka response frame and writes it to `stream`.
///
/// Frame layout, in order:
/// 1. 4-byte big-endian frame length.
/// 2. 4-byte correlation id.
/// 3. A single zero byte (empty tagged-field section), only when the
///    response header is flexible for `(api_key, api_version)`.
/// 4. The body, encoded at `api_version`.
///
/// The length prefix is back-patched from the bytes `encode` actually
/// produced, so it cannot drift from the payload. `encoded_len` is used only
/// as a capacity hint. The stream is flushed after writing, so the reply
/// leaves any wrapper buffer (e.g. a TLS session) before the caller blocks
/// on the peer's next request.
///
/// # Errors
///
/// - `RaftHandshakeError::Protocol`: encoding fails, or the frame does not
///   fit a `u32` length.
/// - I/O errors from writing or flushing `stream`.
async fn write_response<R: Encode>(
    stream: &mut dyn DuplexStream,
    api_key: i16,
    api_version: i16,
    corr_id: i32,
    resp: &R,
) -> Result<(), RaftHandshakeError> {
    const LEN_PREFIX_BYTES: usize = 4;
    let flexible = is_response_header_flexible(api_key, api_version);
    let header_len = 4 + usize::from(flexible);

    let mut out =
        Vec::with_capacity(LEN_PREFIX_BYTES + header_len + resp.encoded_len(api_version));
    // Placeholder for the length prefix; filled in once the body is encoded.
    out.extend_from_slice(&[0u8; LEN_PREFIX_BYTES]);
    out.extend_from_slice(&corr_id.to_be_bytes());
    if flexible {
        // Empty tagged-field section of a flexible response header.
        out.push(0);
    }
    resp.encode(&mut out, api_version)
        .map_err(|e| RaftHandshakeError::Protocol(e.to_string()))?;

    let frame_len = u32::try_from(out.len() - LEN_PREFIX_BYTES)
        .map_err(|_| RaftHandshakeError::Protocol("response frame exceeds u32".into()))?;
    out[..LEN_PREFIX_BYTES].copy_from_slice(&frame_len.to_be_bytes());

    stream.write_all(&out).await?;
    stream.flush().await?;
    Ok(())
}
fn is_request_header_flexible(api_key: i16, api_version: i16) -> bool {
match api_key {
API_KEY_SASL_AUTHENTICATE => api_version >= 2,
_ => false,
}
}
/// Whether the response to `(api_key, api_version)` uses the flexible
/// response header, i.e. a tagged-field byte after the correlation id.
///
/// Only `SaslAuthenticate` v2+ qualifies. `ApiVersions` replies are always
/// built with the classic header.
fn is_response_header_flexible(api_key: i16, api_version: i16) -> bool {
    matches!(api_key, API_KEY_SASL_AUTHENTICATE if api_version >= 2)
}
/// Builds a complete, length-prefixed `ApiVersions` response frame for
/// correlation id `corr_id`.
///
/// The frame uses the classic response header (correlation id only). The body
/// has the v1/v2 shape:
/// - `error_code: i16` = 0;
/// - an `i32`-counted array of `(api_key, min_version, max_version)` `i16`
///   triples;
/// - a trailing `throttle_time_ms: i32` = 0.
///
/// It advertises only the three APIs this listener accepts before
/// authentication.
fn build_api_versions_response(corr_id: i32) -> Vec<u8> {
    // (api_key, min_version, max_version). Each API advertises exactly the
    // versions the Kafka protocol defines for it: SaslHandshake has only
    // v0-1, so it must not claim v2. SaslAuthenticate defines v0-2.
    // ApiVersions is capped at v2 because the body below has the v1/v2 shape.
    const SUPPORTED: [(i16, i16, i16); 3] = [
        (API_KEY_SASL_HANDSHAKE, 0, 1),
        (API_KEY_SASL_AUTHENTICATE, 0, 2),
        (API_KEY_API_VERSIONS, 0, 2),
    ];

    let entry_count = i32::try_from(SUPPORTED.len()).expect("API table length fits in i32");
    let mut body = Vec::with_capacity(2 + 4 + SUPPORTED.len() * 6 + 4);
    body.extend_from_slice(&0i16.to_be_bytes()); // error_code: none
    body.extend_from_slice(&entry_count.to_be_bytes());
    for (api_key, min_version, max_version) in SUPPORTED {
        body.extend_from_slice(&api_key.to_be_bytes());
        body.extend_from_slice(&min_version.to_be_bytes());
        body.extend_from_slice(&max_version.to_be_bytes());
    }
    body.extend_from_slice(&0i32.to_be_bytes()); // throttle_time_ms

    // Frame length covers the correlation id plus the body.
    let total = 4 + body.len();
    let total_u32 = u32::try_from(total).expect("ApiVersions response fits in u32");
    let mut out = Vec::with_capacity(4 + total);
    out.extend_from_slice(&total_u32.to_be_bytes());
    out.extend_from_slice(&corr_id.to_be_bytes());
    out.extend_from_slice(&body);
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    /// Plaintext listener configuration with no credentials or mechanisms.
    fn plaintext_config() -> BrokerRaftHandshake {
        BrokerRaftHandshake {
            tls_acceptor: None,
            plain_credentials: HashMap::new(),
            enabled_sasl_mechanisms: Vec::new(),
            protocol: ListenerProtocol::Plaintext,
            controller: Arc::new(OnceCell::new()),
            authorizer: Arc::new(crate::authorizer::AllowAllAuthorizer),
        }
    }

    #[test]
    fn plaintext_passthrough_short_circuits() {
        let cfg = plaintext_config();
        let needs_tls = cfg.protocol.requires_tls();
        let needs_sasl = cfg.protocol.requires_sasl();
        assert!(!needs_tls);
        assert!(!needs_sasl);
    }

    #[test]
    fn header_flexibility_table_matches_outbound_encoder() {
        // (api_key, api_version, expected request-header flexibility)
        let request_cases = [
            (API_KEY_SASL_HANDSHAKE, 0, false),
            (API_KEY_SASL_HANDSHAKE, 1, false),
            (API_KEY_SASL_AUTHENTICATE, 1, false),
            (API_KEY_SASL_AUTHENTICATE, 2, true),
        ];
        for (key, version, expected) in request_cases {
            assert!(
                is_request_header_flexible(key, version) == expected,
                "request api_key={key} v{version}"
            );
        }

        // (api_key, api_version, expected response-header flexibility)
        let response_cases = [
            (API_KEY_SASL_HANDSHAKE, 0, false),
            (API_KEY_SASL_HANDSHAKE, 1, false),
            (API_KEY_SASL_AUTHENTICATE, 1, false),
            (API_KEY_SASL_AUTHENTICATE, 2, true),
            (API_KEY_API_VERSIONS, 0, false),
            (API_KEY_API_VERSIONS, 3, false),
        ];
        for (key, version, expected) in response_cases {
            assert!(
                is_response_header_flexible(key, version) == expected,
                "response api_key={key} v{version}"
            );
        }
    }
}