use std::path::PathBuf;
use bytes::{Buf, BufMut, BytesMut};
use crabka_protocol::owned::sasl_authenticate_request::SaslAuthenticateRequest;
use crabka_protocol::owned::sasl_authenticate_response::SaslAuthenticateResponse;
use crabka_protocol::owned::sasl_handshake_request::SaslHandshakeRequest;
use crabka_protocol::owned::sasl_handshake_response::SaslHandshakeResponse;
use crabka_protocol::{Decode, Encode};
use crabka_security::{SaslMechanism, ScramClientExchange};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Kafka API key of the SaslHandshake request.
const API_KEY_SASL_HANDSHAKE: i16 = 17;
/// Kafka API key of the SaslAuthenticate request.
const API_KEY_SASL_AUTHENTICATE: i16 = 36;
/// `client_id` written into the header of every request this module sends.
const OUTBOUND_CLIENT_ID: &str = "crabka-client";
/// Passed to `GssapiClientExchange::new` (0x1_0000 = 65536). Presumably this is
/// the maximum receive size offered during GSSAPI negotiation — TODO confirm.
#[cfg(feature = "sspi-keytab")]
const GSSAPI_MAX_RECV_SIZE: u32 = 0x1_0000;
/// Credentials used to authenticate an outbound broker connection.
///
/// `Debug` is implemented by hand rather than derived so that passwords are
/// printed as `<redacted>` instead of leaking into logs; all other fields are
/// shown as the derive would show them.
#[derive(Clone)]
pub enum SaslCredentials {
    /// SASL/PLAIN username and password.
    Plain { username: String, password: String },
    /// SASL/SCRAM with the given SCRAM mechanism.
    Scram {
        mechanism: SaslMechanism,
        username: String,
        password: String,
    },
    /// SASL/GSSAPI driven from a keytab file.
    Gssapi {
        keytab_path: PathBuf,
        client_principal: String,
        service_name: String,
        kdc_url: String,
    },
}

impl std::fmt::Debug for SaslCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Plain { username, .. } => f
                .debug_struct("Plain")
                .field("username", username)
                // `format_args!` renders without quotes, marking it as a placeholder.
                .field("password", &format_args!("<redacted>"))
                .finish(),
            Self::Scram {
                mechanism,
                username,
                ..
            } => f
                .debug_struct("Scram")
                .field("mechanism", mechanism)
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
            Self::Gssapi {
                keytab_path,
                client_principal,
                service_name,
                kdc_url,
            } => f
                .debug_struct("Gssapi")
                .field("keytab_path", keytab_path)
                .field("client_principal", client_principal)
                .field("service_name", service_name)
                .field("kdc_url", kdc_url)
                .finish(),
        }
    }
}
impl SaslCredentials {
    /// Returns the SASL mechanism these credentials authenticate with.
    ///
    /// PLAIN and GSSAPI are implied by the variant; SCRAM carries its concrete
    /// mechanism as a field.
    #[must_use]
    pub fn mechanism(&self) -> SaslMechanism {
        match *self {
            Self::Scram { mechanism, .. } => mechanism,
            Self::Plain { .. } => SaslMechanism::Plain,
            Self::Gssapi { .. } => SaslMechanism::Gssapi,
        }
    }
}
/// Errors produced while authenticating an outbound connection.
#[derive(Debug, Error)]
pub enum OutboundSaslError {
    /// Reading from or writing to the broker stream failed.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
    /// The broker returned a non-zero error code, or a local mechanism step
    /// (SCRAM / GSSAPI) failed.
    #[error("sasl: {0}")]
    Sasl(String),
    /// A request could not be encoded, or a response could not be framed or
    /// decoded.
    #[error("codec: {0}")]
    Codec(String),
}
/// Runs the client side of SASL authentication on an already-connected
/// broker stream.
///
/// First sends a SaslHandshake announcing `creds.mechanism()`, then runs the
/// mechanism-specific SaslAuthenticate exchange. `server_name` is only used for
/// GSSAPI, where it becomes the host part of the target SPN.
///
/// # Errors
///
/// Returns [`OutboundSaslError`] if any I/O, encoding/decoding, or SASL step
/// fails, or if the broker rejects the handshake or the credentials.
pub async fn outbound_sasl<S>(
    stream: &mut S,
    creds: &SaslCredentials,
    server_name: &str,
) -> Result<(), OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    // Correlation ids start at 1; each request helper advances the counter.
    let mut next_corr_id: i32 = 1;
    let mechanism = creds.mechanism();
    send_sasl_handshake(stream, mechanism, &mut next_corr_id).await?;

    match creds {
        SaslCredentials::Plain { username, password } => {
            send_plain_authenticate(stream, username, password, &mut next_corr_id).await
        }
        SaslCredentials::Scram {
            mechanism: scram_mechanism,
            username,
            password,
        } => {
            run_scram_client(
                stream,
                username,
                password,
                *scram_mechanism,
                &mut next_corr_id,
            )
            .await
        }
        SaslCredentials::Gssapi {
            keytab_path,
            client_principal,
            service_name,
            kdc_url,
        } => {
            run_gssapi_client(
                stream,
                keytab_path,
                client_principal,
                service_name,
                server_name,
                kdc_url,
                &mut next_corr_id,
            )
            .await
        }
    }
}
/// Sends a SaslHandshake (v1) request naming `mechanism` and checks the reply.
///
/// Uses the current `*corr_id` and advances it once the response has been
/// received.
///
/// # Errors
///
/// `Io`/`Codec` on transport or (de)serialization failures, and `Sasl` when the
/// broker answers with a non-zero error code.
async fn send_sasl_handshake<S>(
    stream: &mut S,
    mechanism: SaslMechanism,
    corr_id: &mut i32,
) -> Result<(), OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    let mut encoded = BytesMut::new();
    SaslHandshakeRequest {
        mechanism: mechanism.wire_name().to_string(),
        ..Default::default()
    }
    .encode(&mut encoded, 1)
    .map_err(|e| OutboundSaslError::Codec(format!("SaslHandshake encode: {e}")))?;

    // Handshake v1 is not a flexible version, so no tagged-field header byte.
    let raw = round_trip(stream, API_KEY_SASL_HANDSHAKE, 1, *corr_id, false, &encoded).await?;
    *corr_id += 1;

    let mut cursor: &[u8] = raw.as_slice();
    let response = SaslHandshakeResponse::decode(&mut cursor, 1)
        .map_err(|e| OutboundSaslError::Codec(format!("SaslHandshake decode: {e}")))?;

    match response.error_code {
        0 => Ok(()),
        code => Err(OutboundSaslError::Sasl(format!(
            "SaslHandshake error_code={} (mechanism={})",
            code,
            mechanism.wire_name()
        ))),
    }
}
/// Performs a SASL/PLAIN (RFC 4616) authentication round.
///
/// The payload is `authzid NUL authcid NUL passwd`, with an empty `authzid`
/// followed by `user` and `pass`.
///
/// # Errors
///
/// - `Sasl` if `user` or `pass` contains a NUL character (forbidden by
///   RFC 4616), or if the broker replies with a non-zero error code.
/// - `Io`/`Codec` on transport or (de)serialization failures.
async fn send_plain_authenticate<S>(
    stream: &mut S,
    user: &str,
    pass: &str,
    corr_id: &mut i32,
) -> Result<(), OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    // NUL is the PLAIN field separator. An embedded NUL would make the broker
    // split the message at a different boundary than intended, so refuse such
    // credentials before sending anything.
    if user.contains('\0') || pass.contains('\0') {
        return Err(OutboundSaslError::Sasl(
            "SASL/PLAIN username and password must not contain NUL characters".to_string(),
        ));
    }

    let mut payload = Vec::with_capacity(2 + user.len() + pass.len());
    payload.push(0); // empty authzid
    payload.extend_from_slice(user.as_bytes());
    payload.push(0);
    payload.extend_from_slice(pass.as_bytes());

    let resp = send_sasl_authenticate(stream, payload, corr_id).await?;
    if resp.error_code != 0 {
        return Err(OutboundSaslError::Sasl(format!(
            "SaslAuthenticate(PLAIN) error_code={} error_message={:?}",
            resp.error_code, resp.error_message
        )));
    }
    Ok(())
}
async fn run_scram_client<S>(
stream: &mut S,
user: &str,
pass: &str,
mechanism: SaslMechanism,
corr_id: &mut i32,
) -> Result<(), OutboundSaslError>
where
S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
let mut exch = ScramClientExchange::new(user.to_string(), pass.as_bytes().to_vec(), mechanism);
let client_first = exch
.client_first()
.map_err(|e| OutboundSaslError::Sasl(format!("scram client_first: {e:?}")))?;
let resp1 = send_sasl_authenticate(stream, client_first, corr_id).await?;
if resp1.error_code != 0 {
return Err(OutboundSaslError::Sasl(format!(
"SaslAuthenticate(SCRAM round 1) error_code={} error_message={:?}",
resp1.error_code, resp1.error_message
)));
}
let server_first = resp1.auth_bytes.to_vec();
let client_final = exch
.step(&server_first)
.map_err(|e| OutboundSaslError::Sasl(format!("scram client step: {e:?}")))?;
let resp2 = send_sasl_authenticate(stream, client_final, corr_id).await?;
if resp2.error_code != 0 {
return Err(OutboundSaslError::Sasl(format!(
"SaslAuthenticate(SCRAM round 2) error_code={} error_message={:?}",
resp2.error_code, resp2.error_message
)));
}
exch.verify_server_final(&resp2.auth_bytes)
.map_err(|e| OutboundSaslError::Sasl(format!("server-final verify: {e:?}")))?;
Ok(())
}
/// Runs a SASL/GSSAPI exchange using an SSPI initiator loaded from a keytab.
///
/// The target SPN is `"{service_name}/{server_name}"`. Tokens produced by the
/// client exchange are sent in SaslAuthenticate requests, and each response's
/// `auth_bytes` are fed back into the exchange until it reports
/// `ClientStep::Done`.
///
/// # Errors
///
/// - `Sasl` if the keytab path is not valid UTF-8, if initiator setup or any
///   exchange step fails, or if the broker returns a non-zero error code.
/// - `Io`/`Codec` on transport or (de)serialization failures.
#[cfg(feature = "sspi-keytab")]
async fn run_gssapi_client<S>(
    stream: &mut S,
    keytab_path: &std::path::Path,
    client_principal: &str,
    service_name: &str,
    server_name: &str,
    kdc_url: &str,
    corr_id: &mut i32,
) -> Result<(), OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    use crabka_security::gssapi::client::{ClientStep, GssapiClientExchange};
    use crabka_security::gssapi::provider::SspiInitiator;

    let target_spn = format!("{service_name}/{server_name}");

    // The initiator takes the keytab location as a string. A lossy conversion
    // would silently replace non-UTF-8 bytes with U+FFFD and point at a
    // different file, so reject such paths with an explicit error instead.
    let keytab = keytab_path.to_str().ok_or_else(|| {
        OutboundSaslError::Sasl(format!(
            "GSSAPI keytab path is not valid UTF-8: {}",
            keytab_path.display()
        ))
    })?;

    let initiator = SspiInitiator::new(keytab, client_principal, &target_spn, kdc_url)
        .map_err(|e| OutboundSaslError::Sasl(format!("GSSAPI initiator init failed: {e}")))?;
    let mut exchange = GssapiClientExchange::new(Box::new(initiator), GSSAPI_MAX_RECV_SIZE, None);

    let mut step = exchange
        .step(None)
        .map_err(|e| OutboundSaslError::Sasl(format!("GSSAPI initiate failed: {e}")))?;

    // Keep shuttling tokens until the client side reports it is done.
    while let ClientStep::Token(token) = step {
        let resp = send_sasl_authenticate(stream, token, corr_id).await?;
        if resp.error_code != 0 {
            return Err(OutboundSaslError::Sasl(format!(
                "SaslAuthenticate(GSSAPI) error_code={} error_message={:?}",
                resp.error_code, resp.error_message
            )));
        }
        step = exchange
            .step(Some(&resp.auth_bytes))
            .map_err(|e| OutboundSaslError::Sasl(format!("GSSAPI step failed: {e}")))?;
    }
    Ok(())
}
/// Fallback used when the `sspi-keytab` feature is disabled: GSSAPI
/// credentials are always rejected with a `Sasl` error.
///
/// The function stays `async` (hence the lint allowance) so it has the same
/// signature as the feature-enabled version and the caller's `.await`
/// compiles either way.
#[cfg(not(feature = "sspi-keytab"))]
#[allow(clippy::unused_async)]
async fn run_gssapi_client<S>(
    _stream: &mut S,
    _keytab_path: &std::path::Path,
    _client_principal: &str,
    _service_name: &str,
    _server_name: &str,
    _kdc_url: &str,
    _corr_id: &mut i32,
) -> Result<(), OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    let reason = "GSSAPI client support requires the sspi-keytab feature";
    Err(OutboundSaslError::Sasl(reason.to_owned()))
}
/// Sends one SaslAuthenticate (v2) request carrying `auth_bytes` and returns the
/// decoded response.
///
/// Uses the current `*corr_id` and advances it after the response arrives. The
/// response's error code is NOT inspected here; callers decide how to react.
///
/// # Errors
///
/// `Io`/`Codec` on transport or (de)serialization failures.
async fn send_sasl_authenticate<S>(
    stream: &mut S,
    auth_bytes: Vec<u8>,
    corr_id: &mut i32,
) -> Result<SaslAuthenticateResponse, OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    let mut encoded = BytesMut::new();
    SaslAuthenticateRequest {
        auth_bytes: bytes::Bytes::from(auth_bytes),
        ..Default::default()
    }
    .encode(&mut encoded, 2)
    .map_err(|e| OutboundSaslError::Codec(format!("SaslAuthenticate encode: {e}")))?;

    // SaslAuthenticate v2 is sent with the flexible header (tagged-field byte).
    let raw =
        round_trip(stream, API_KEY_SASL_AUTHENTICATE, 2, *corr_id, true, &encoded).await?;
    *corr_id += 1;

    let mut cursor: &[u8] = raw.as_slice();
    let response = SaslAuthenticateResponse::decode(&mut cursor, 2)
        .map_err(|e| OutboundSaslError::Codec(format!("SaslAuthenticate decode: {e}")))?;
    Ok(response)
}
/// Sends one length-prefixed Kafka request and returns the response body with
/// the response header stripped.
///
/// The request header is `api_key`, `api_version`, `corr_id` and
/// [`OUTBOUND_CLIENT_ID`] (as an `i16`-length string). When `flexible` is set,
/// an empty tagged-fields byte follows the header. In that case the response is
/// also expected to carry a tagged-fields byte after its correlation id, except
/// for ApiVersions (api key 18).
///
/// # Errors
///
/// - `Io` on any transport failure.
/// - `Codec` if the client id or frame length does not fit its wire field, or
///   if the response exceeds `MAX_RESPONSE_BYTES`, is too short for its
///   header, or echoes a correlation id other than `corr_id`.
async fn round_trip<S>(
    stream: &mut S,
    api_key: i16,
    api_version: i16,
    corr_id: i32,
    flexible: bool,
    body: &[u8],
) -> Result<Vec<u8>, OutboundSaslError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + ?Sized,
{
    /// ApiVersions responses keep the non-flexible response header.
    const API_KEY_API_VERSIONS: i16 = 18;
    /// Upper bound on an accepted response frame. SASL responses are small;
    /// the cap stops a corrupt or hostile length prefix from forcing an
    /// allocation of up to 4 GiB.
    const MAX_RESPONSE_BYTES: u32 = 1 << 20;

    let client_id_len = i16::try_from(OUTBOUND_CLIENT_ID.len())
        .map_err(|_| OutboundSaslError::Codec("client_id too long".into()))?;

    // Length prefix, header and body go into one buffer so the request is a
    // single write. A separate 4-byte write followed by the frame can stall
    // behind Nagle's algorithm and the peer's delayed ACK.
    // Fixed header fields: api_key(2) + api_version(2) + corr_id(4) + id_len(2).
    let mut frame =
        BytesMut::with_capacity(4 + 10 + OUTBOUND_CLIENT_ID.len() + 1 + body.len());
    frame.put_u32(0); // length placeholder, patched below
    frame.put_i16(api_key);
    frame.put_i16(api_version);
    frame.put_i32(corr_id);
    frame.put_i16(client_id_len);
    frame.put_slice(OUTBOUND_CLIENT_ID.as_bytes());
    if flexible {
        // Empty tagged-fields section of the flexible request header.
        frame.put_u8(0);
    }
    frame.put_slice(body);
    let payload_len = u32::try_from(frame.len() - 4)
        .map_err(|_| OutboundSaslError::Codec("frame size exceeds u32".into()))?;
    frame[..4].copy_from_slice(&payload_len.to_be_bytes());

    stream.write_all(&frame).await?;
    stream.flush().await?;

    let resp_len = stream.read_u32().await?;
    if resp_len > MAX_RESPONSE_BYTES {
        return Err(OutboundSaslError::Codec(format!(
            "response frame of {resp_len} bytes exceeds the {MAX_RESPONSE_BYTES}-byte limit"
        )));
    }
    let mut resp = vec![0u8; resp_len as usize];
    stream.read_exact(&mut resp).await?;

    let mut cur = &resp[..];
    if cur.len() < 4 {
        return Err(OutboundSaslError::Codec("response missing corr_id".into()));
    }
    // A mismatched id means the stream is out of sync with our requests; the
    // body cannot be trusted to belong to this request.
    let resp_corr_id = cur.get_i32();
    if resp_corr_id != corr_id {
        return Err(OutboundSaslError::Codec(format!(
            "correlation id mismatch: sent {corr_id}, received {resp_corr_id}"
        )));
    }

    if flexible && api_key != API_KEY_API_VERSIONS {
        if cur.is_empty() {
            return Err(OutboundSaslError::Codec(
                "flexible response missing tagged-fields byte".into(),
            ));
        }
        // NOTE(review): assumes the header carries zero tagged fields; a
        // non-zero count would leave unparsed bytes in front of the body.
        let _tagged = cur.get_u8();
    }
    Ok(cur.to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crabka_protocol::owned::sasl_authenticate_response::SaslAuthenticateResponse;
    use crabka_protocol::owned::sasl_handshake_response::SaslHandshakeResponse;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Plays the broker for one request: reads a length-prefixed request from
    /// `stream` and answers with `body`, echoing the request's correlation id.
    /// With `flex_header`, an empty tagged-fields byte follows the correlation id.
    async fn reply_frame<S>(stream: &mut S, body: &[u8], flex_header: bool)
    where
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    {
        let request_len = stream.read_u32().await.unwrap();
        let mut request = vec![0u8; request_len as usize];
        stream.read_exact(&mut request).await.unwrap();

        // Request header layout: api_key (2) | api_version (2) | corr_id (4) | ...
        let mut corr_bytes = [0u8; 4];
        corr_bytes.copy_from_slice(&request[4..8]);
        let corr_id = i32::from_be_bytes(corr_bytes);

        let mut response = BytesMut::new();
        response.put_i32(corr_id);
        if flex_header {
            response.put_u8(0);
        }
        response.put_slice(body);

        let response_len = u32::try_from(response.len()).unwrap();
        stream.write_u32(response_len).await.unwrap();
        stream.write_all(&response).await.unwrap();
        stream.flush().await.unwrap();
    }

    #[tokio::test]
    async fn outbound_plain_completes() {
        let (mut client, mut server) = tokio::io::duplex(8192);

        let broker = tokio::spawn(async move {
            // Handshake v1: non-flexible response header.
            let ok_handshake = SaslHandshakeResponse {
                error_code: 0,
                ..Default::default()
            };
            let mut handshake_body = BytesMut::new();
            ok_handshake.encode(&mut handshake_body, 1).unwrap();
            reply_frame(&mut server, &handshake_body, false).await;

            // Authenticate v2: flexible response header.
            let ok_authenticate = SaslAuthenticateResponse {
                error_code: 0,
                ..Default::default()
            };
            let mut authenticate_body = BytesMut::new();
            ok_authenticate.encode(&mut authenticate_body, 2).unwrap();
            reply_frame(&mut server, &authenticate_body, true).await;
        });

        let creds = SaslCredentials::Plain {
            username: "u".into(),
            password: "p".into(),
        };
        outbound_sasl(&mut client, &creds, "localhost")
            .await
            .expect("PLAIN outbound handshake completes");
        broker.await.unwrap();
    }
}