use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::Channel;
use tonic::{Request, Status, Streaming};
use tracing::debug;
use uuid::Uuid;
use crate::error::{Error, Result};
use crate::proto::grpc::sasl::{
sasl_authentication_service_client::SaslAuthenticationServiceClient, SaslMessage,
};
use super::sasl_client::{PlainSaslClientHandler, SaslClientHandler};
/// Authentication mechanism used when establishing a channel.
///
/// Parsed ASCII-case-insensitively via `FromStr` and displayed in lowercase
/// (`"nosasl"`, `"simple"`).
// NOTE(review): the serde derives use serde's default representation (the
// variant names `NoSasl` / `Simple`), which differs from the lowercase
// spellings accepted by `FromStr` — confirm this is intended for any
// serialized configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum AuthType {
    /// Skip SASL entirely; the channel is only tagged with a generated channel id.
    NoSasl,
    /// SASL PLAIN using a username, password and optional impersonation user.
    /// This is the default.
    #[default]
    Simple,
}
impl AuthType {
pub fn as_str(&self) -> &'static str {
match self {
AuthType::NoSasl => "nosasl",
AuthType::Simple => "simple",
}
}
}
impl fmt::Display for AuthType {
    /// Writes the canonical lowercase name returned by [`AuthType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
impl FromStr for AuthType {
    type Err = String;

    /// Parses an authentication type name, ignoring ASCII case.
    ///
    /// Accepted spellings are `nosasl` / `no_sasl` and `simple`. The names
    /// `custom` and `kerberos` are recognised but rejected with a dedicated
    /// "not yet implemented" message. Any other input yields a generic
    /// "unknown authentication type" error that echoes the original text.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Case-insensitive comparison without allocating a lowercased copy.
        let is = |name: &str| s.eq_ignore_ascii_case(name);

        if is("nosasl") || is("no_sasl") {
            return Ok(AuthType::NoSasl);
        }
        if is("simple") {
            return Ok(AuthType::Simple);
        }
        if is("custom") {
            return Err(
                "CUSTOM authentication is not yet implemented; use NOSASL or SIMPLE".to_string(),
            );
        }
        if is("kerberos") {
            return Err(
                "KERBEROS authentication is not yet implemented; use NOSASL or SIMPLE".to_string(),
            );
        }
        Err(format!(
            "unknown authentication type '{}'. Currently supported: nosasl, simple",
            s
        ))
    }
}
/// gRPC metadata key under which the channel id is attached to each request.
const CHANNEL_ID_METADATA_KEY: &str = "channel-id";
/// Tonic interceptor that adds a `channel-id` metadata entry to every request.
#[derive(Clone)]
pub struct ChannelIdInterceptor {
    // Raw id string; it is parsed into a metadata value on every intercepted call.
    channel_id: String,
}
impl ChannelIdInterceptor {
pub fn new(channel_id: String) -> Self {
Self { channel_id }
}
}
impl tonic::service::Interceptor for ChannelIdInterceptor {
    /// Stamps `request` with this interceptor's id under the `channel-id` key.
    ///
    /// Fails with `Status::internal("invalid channel-id")` when the id cannot be
    /// converted into an ASCII metadata value. In that case the request is not sent.
    fn call(&mut self, mut request: Request<()>) -> std::result::Result<Request<()>, Status> {
        let value = match self.channel_id.parse() {
            Ok(parsed) => parsed,
            Err(_) => return Err(Status::internal("invalid channel-id")),
        };
        request
            .metadata_mut()
            .insert(CHANNEL_ID_METADATA_KEY, value);
        Ok(request)
    }
}
/// Authenticates gRPC channels according to a configured [`AuthType`].
///
/// Construct with `ChannelAuthenticator::new` and adjust with the
/// `with_password` / `with_auth_timeout` builder methods.
// NOTE(review): if `Debug` is ever derived here, redact `password`.
pub struct ChannelAuthenticator {
    // Mechanism selected by `authenticate`.
    auth_type: AuthType,
    // Login user passed to the SASL PLAIN handler.
    username: String,
    // SASL PLAIN password; `new` sets the placeholder "noPassword".
    password: String,
    // Optional user to impersonate, forwarded to the PLAIN handler.
    impersonation_user: Option<String>,
    // Timeout applied while waiting on the SASL handshake; `new` sets 30 seconds.
    auth_timeout: Duration,
}
/// Result of `ChannelAuthenticator::authenticate`: a channel tagged with its id.
pub struct AuthenticatedChannel {
    /// The underlying channel, wrapped so each request carries `channel-id` metadata.
    pub channel: InterceptedService<Channel, ChannelIdInterceptor>,
    /// Id generated for this channel (a random v4 UUID string).
    pub channel_id: String,
    // Keeps the SASL stream open after SIMPLE authentication; `None` for NOSASL.
    _sasl_guard: Option<SaslStreamGuard>,
}
impl AuthenticatedChannel {
    /// Moves the SASL stream guard out of this channel, leaving `None` behind.
    ///
    /// Returns `None` if the channel was set up without SASL or the guard has
    /// already been taken. Dropping the returned guard drops the SASL request
    /// sender and response stream it holds.
    // NOTE(review): presumably the server ties the authenticated session to that
    // stream, so the guard should outlive any use of `channel` — confirm.
    pub fn take_sasl_guard(&mut self) -> Option<SaslStreamGuard> {
        std::mem::take(&mut self._sasl_guard)
    }
}
/// Holds both halves of the SASL `authenticate` stream open after a successful
/// SIMPLE handshake.
///
/// Neither field is read after construction. Dropping the guard drops the only
/// request sender, which ends the client's request stream, and drops the server
/// response stream.
// NOTE(review): presumably the server revokes the channel's authentication once
// this stream closes, which is why the guard is kept alive — confirm.
pub struct SaslStreamGuard {
    // Sole sender feeding the client-to-server half of the stream.
    _tx: mpsc::Sender<SaslMessage>,
    // Server-to-client half of the stream; never polled after the handshake.
    _response_stream: Streaming<SaslMessage>,
}
// SAFETY: both fields are private and no method on `SaslStreamGuard` exposes
// them, so a shared `&SaslStreamGuard` grants no access to their contents. That
// is what makes the `Sync` impl sound even if `Streaming` itself is not `Sync`.
// The `Send` impl asserts that both fields may be moved to, and dropped on,
// another thread. `mpsc::Sender<T>` is `Send` when `T: Send`.
// NOTE(review): confirm `Streaming<SaslMessage>` is already `Send` for the tonic
// version in use. If so, the `Send` impl is redundant and could be removed.
unsafe impl Send for SaslStreamGuard {}
unsafe impl Sync for SaslStreamGuard {}
impl ChannelAuthenticator {
pub fn new(auth_type: AuthType, username: String, impersonation_user: Option<String>) -> Self {
Self {
auth_type,
username,
password: "noPassword".to_string(),
impersonation_user,
auth_timeout: Duration::from_secs(30),
}
}
pub fn with_auth_timeout(mut self, timeout: Duration) -> Self {
self.auth_timeout = timeout;
self
}
pub fn with_password(mut self, password: String) -> Self {
self.password = password;
self
}
pub fn auth_type(&self) -> AuthType {
self.auth_type
}
pub async fn authenticate(&self, channel: Channel) -> Result<AuthenticatedChannel> {
match self.auth_type {
AuthType::NoSasl => self.authenticate_nosasl(channel),
AuthType::Simple => self.authenticate_simple(channel).await,
}
}
fn authenticate_nosasl(&self, channel: Channel) -> Result<AuthenticatedChannel> {
debug!(auth_type = "NOSASL", "skipping SASL authentication");
let channel_id = Uuid::new_v4().to_string();
let interceptor = ChannelIdInterceptor::new(channel_id.clone());
Ok(AuthenticatedChannel {
channel: InterceptedService::new(channel, interceptor),
channel_id,
_sasl_guard: None,
})
}
async fn authenticate_simple(&self, channel: Channel) -> Result<AuthenticatedChannel> {
let channel_id = Uuid::new_v4().to_string();
let channel_ref = format!("rust-client-{}", &channel_id[..8]);
debug!(
auth_type = "SIMPLE",
username = %self.username,
channel_id = %channel_id,
"starting SASL PLAIN authentication"
);
let sasl_handler = PlainSaslClientHandler::new_simple(
&self.username,
&self.password,
self.impersonation_user.as_deref(),
);
let initial_message = sasl_handler.initial_message(&channel_id, &channel_ref)?;
let (tx, rx) = mpsc::channel::<SaslMessage>(8);
tx.send(initial_message)
.await
.map_err(|_| Error::Internal {
message: "failed to send initial SASL message".to_string(),
source: None,
})?;
let stream = ReceiverStream::new(rx);
let mut sasl_client = SaslAuthenticationServiceClient::new(channel.clone());
let response = tokio::time::timeout(self.auth_timeout, sasl_client.authenticate(stream))
.await
.map_err(|_| Error::Internal {
message: format!(
"SASL authentication timed out ({}ms)",
self.auth_timeout.as_millis()
),
source: None,
})?
.map_err(|status| Error::GrpcError {
message: format!("SASL authentication RPC failed: {}", status),
source: status,
})?;
let mut response_stream = response.into_inner();
let auth_result = tokio::time::timeout(self.auth_timeout, async {
while let Some(server_msg) =
response_stream
.message()
.await
.map_err(|status| Error::GrpcError {
message: format!("SASL authentication response error: {}", status),
source: status,
})?
{
match sasl_handler.handle_message(&server_msg)? {
Some(client_response) => {
tx.send(client_response)
.await
.map_err(|_| Error::Internal {
message: "failed to send SASL response message".to_string(),
source: None,
})?;
}
None => {
debug!(
channel_id = %channel_id,
"SASL PLAIN authentication succeeded"
);
return Ok::<(), Error>(());
}
}
}
Err(Error::Internal {
message: "SASL authentication stream closed unexpectedly without receiving SUCCESS"
.to_string(),
source: None,
})
})
.await
.map_err(|_| Error::Internal {
message: format!(
"timed out waiting for SASL authentication response ({}ms)",
self.auth_timeout.as_millis()
),
source: None,
})?;
auth_result?;
let interceptor = ChannelIdInterceptor::new(channel_id.clone());
Ok(AuthenticatedChannel {
channel: InterceptedService::new(channel, interceptor),
channel_id,
_sasl_guard: Some(SaslStreamGuard {
_tx: tx,
_response_stream: response_stream,
}),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    #[test]
    fn test_auth_type_from_str() {
        // Parsing ignores ASCII case and accepts the `no_sasl` alias.
        let accepted = [
            ("nosasl", AuthType::NoSasl),
            ("NOSASL", AuthType::NoSasl),
            ("no_sasl", AuthType::NoSasl),
            ("simple", AuthType::Simple),
            ("SIMPLE", AuthType::Simple),
        ];
        for (text, expected) in accepted {
            assert_eq!(text.parse::<AuthType>().unwrap(), expected, "parsing {:?}", text);
        }
    }

    #[test]
    fn test_auth_type_from_str_unsupported() {
        // Recognised-but-unimplemented mechanisms and unknown names are both rejected.
        for text in ["custom", "kerberos", "invalid"] {
            assert!(
                text.parse::<AuthType>().is_err(),
                "expected {:?} to be rejected",
                text
            );
        }
    }

    #[test]
    fn test_auth_type_default() {
        let fallback: AuthType = Default::default();
        assert_eq!(fallback, AuthType::Simple);
    }

    #[test]
    fn test_auth_type_display() {
        for (variant, rendered) in [(AuthType::NoSasl, "nosasl"), (AuthType::Simple, "simple")] {
            assert_eq!(variant.to_string(), rendered);
        }
    }

    #[test]
    fn test_channel_id_interceptor() {
        let mut interceptor = ChannelIdInterceptor::new(String::from("test-id-123"));
        let intercepted = interceptor.call(Request::new(())).unwrap();
        let header = intercepted
            .metadata()
            .get(CHANNEL_ID_METADATA_KEY)
            .expect("interceptor should set the channel-id entry");
        assert_eq!(header.to_str().unwrap(), "test-id-123");
    }
}