use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use rivven_core::{
AuthManager, Config, OffsetManager, ServiceAuthConfig, ServiceAuthManager, TopicManager,
};
#[cfg(feature = "tls")]
use rivven_core::{
tls::{MtlsMode, TlsAcceptor, TlsConfig, TlsIdentity, TlsServerStream},
AuthSession,
};
use crate::auth_handler::{AuthenticatedHandler, ConnectionAuth};
use crate::handler::RequestHandler;
use crate::protocol::{Request, Response, WireFormat};
/// Listener, limit, and authentication settings for [`SecureServer`].
#[derive(Debug, Clone)]
pub struct SecureServerConfig {
    /// Address the TCP listener binds to.
    pub bind_addr: SocketAddr,
    /// TLS settings; an acceptor is built only when this is `Some` and its
    /// `enabled` flag is true.
    #[cfg(feature = "tls")]
    pub tls_config: Option<TlsConfig>,
    /// Maximum number of concurrently served connections; connections beyond
    /// this are accepted and immediately dropped.
    pub max_connections: usize,
    /// Upper bound on the TLS handshake duration.
    pub connection_timeout: Duration,
    /// Maximum wait for the next frame's length prefix or body before the
    /// connection is closed.
    pub idle_timeout: Duration,
    /// Largest accepted request frame body, in bytes; larger frames receive a
    /// `MESSAGE_TOO_LARGE` error response.
    pub max_message_size: usize,
    /// Passed to `AuthenticatedHandler`; also makes connections that are not
    /// authenticated via mTLS start `Unauthenticated` instead of `Anonymous`.
    pub require_auth: bool,
    /// When true, a `ServiceAuthManager` is created for mTLS certificate auth.
    pub enable_service_auth: bool,
    // NOTE(review): set by `SecureServerBuilder::enable_service_auth`, but no code
    // in this file reads it — `ServiceAuthManager::new()` is called without it.
    // Confirm whether it should configure the service auth manager.
    pub service_auth_config: Option<ServiceAuthConfig>,
}
impl Default for SecureServerConfig {
fn default() -> Self {
Self {
bind_addr: "0.0.0.0:9092".parse().unwrap(),
#[cfg(feature = "tls")]
tls_config: None,
max_connections: 10_000,
connection_timeout: Duration::from_secs(30),
idle_timeout: Duration::from_secs(300),
max_message_size: 10 * 1024 * 1024, require_auth: false,
enable_service_auth: false,
service_auth_config: None,
}
}
}
/// Per-connection security state carried through the request loop.
#[derive(Debug, Clone)]
pub struct ConnectionSecurityContext {
    /// Remote peer address.
    pub client_addr: SocketAddr,
    /// Negotiated TLS session details; `None` for plaintext connections.
    #[cfg(feature = "tls")]
    pub tls_info: Option<TlsConnectionInfo>,
    /// Current authentication state. It is passed by `&mut` to
    /// `AuthenticatedHandler::handle` for every request, so the handler may
    /// update it during the connection's lifetime.
    pub auth_state: ConnectionAuth,
    /// Service identity established through mTLS certificate auth, if any.
    #[cfg(feature = "tls")]
    pub service_identity: Option<ServiceIdentity>,
}
/// Details of the TLS session negotiated for one connection.
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub struct TlsConnectionInfo {
    /// `Debug` rendering of the negotiated protocol version, or `"unknown"`.
    pub protocol_version: String,
    /// Name of the negotiated cipher suite, if the stream reports one.
    pub cipher_suite: Option<String>,
    /// Identity parsed from the first certificate the client presented, if any.
    pub client_cert: Option<TlsIdentity>,
    /// Negotiated ALPN protocol, decoded lossily as UTF-8.
    pub alpn_protocol: Option<String>,
}
/// Service-account identity derived from a successfully authenticated client
/// certificate (only constructed on the `tls` code path in this file).
#[derive(Debug, Clone)]
pub struct ServiceIdentity {
    /// Service account name from the authenticated service session.
    pub service_id: String,
    /// Certificate common name, or empty if the certificate has none.
    pub common_name: String,
    /// Certificate subject used for authentication (falls back to the common name).
    pub subject: String,
    /// Certificate fingerprint as reported by `TlsIdentity`.
    pub fingerprint: String,
    /// Permissions of the authenticated service session.
    pub roles: Vec<String>,
}
/// TCP server for the length-prefixed request protocol, optionally over
/// TLS/mTLS, with a cap on concurrently served connections.
pub struct SecureServer {
    /// Listener, limit, and auth settings.
    config: SecureServerConfig,
    /// Topic manager handed to `RequestHandler::new`.
    topic_manager: TopicManager,
    /// Offset manager handed to `RequestHandler::new`.
    offset_manager: OffsetManager,
    /// Auth manager handed to `AuthenticatedHandler::new`.
    auth_manager: Arc<AuthManager>,
    // Only read while building the TLS security context (mTLS service auth);
    // without the `tls` feature nothing reads it, hence the lint allowance.
    #[allow(dead_code)]
    service_auth_manager: Option<Arc<ServiceAuthManager>>,
    /// Present when a TLS config is supplied and enabled.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<TlsAcceptor>,
    /// One permit per live connection; sized from `max_connections`.
    connection_semaphore: Arc<Semaphore>,
}
impl SecureServer {
pub async fn new(
core_config: Config,
server_config: SecureServerConfig,
) -> anyhow::Result<Self> {
Self::with_auth_manager(core_config, server_config, None).await
}
pub async fn with_auth_manager(
core_config: Config,
server_config: SecureServerConfig,
auth_manager: Option<Arc<AuthManager>>,
) -> anyhow::Result<Self> {
let topic_manager = TopicManager::new(core_config.clone());
if let Err(e) = topic_manager.recover().await {
tracing::warn!("Failed to recover topics from disk: {}", e);
}
let offset_manager = OffsetManager::with_persistence(
std::path::PathBuf::from(&core_config.data_dir).join("offsets"),
);
let auth_manager =
auth_manager.unwrap_or_else(|| Arc::new(AuthManager::new(Default::default())));
let service_auth_manager = if server_config.enable_service_auth {
Some(Arc::new(ServiceAuthManager::new()))
} else {
None
};
#[cfg(feature = "tls")]
let tls_acceptor = if let Some(ref tls_config) = server_config.tls_config {
if tls_config.enabled {
Some(TlsAcceptor::new(tls_config)?)
} else {
None
}
} else {
None
};
let connection_semaphore = Arc::new(Semaphore::new(server_config.max_connections));
Ok(Self {
config: server_config,
topic_manager,
offset_manager,
auth_manager,
service_auth_manager,
#[cfg(feature = "tls")]
tls_acceptor,
connection_semaphore,
})
}
pub async fn start(self) -> anyhow::Result<()> {
let listener = TcpListener::bind(self.config.bind_addr).await?;
#[cfg(feature = "tls")]
let mode = if self.tls_acceptor.is_some() {
if let Some(ref cfg) = self.config.tls_config {
match cfg.mtls_mode {
MtlsMode::Required => "mTLS (client cert required)",
MtlsMode::Optional => "TLS (client cert optional)",
MtlsMode::Disabled => "TLS",
}
} else {
"plaintext"
}
} else {
"plaintext"
};
#[cfg(not(feature = "tls"))]
let mode = "plaintext";
info!(
"Secure server listening on {} (mode: {}, auth: {})",
self.config.bind_addr,
mode,
if self.config.require_auth {
"required"
} else {
"optional"
}
);
let auth_handler_inner =
RequestHandler::new(self.topic_manager.clone(), self.offset_manager.clone());
let auth_handler = Arc::new(AuthenticatedHandler::new(
auth_handler_inner,
self.auth_manager.clone(),
self.config.require_auth,
));
let server = Arc::new(self);
loop {
let permit = match server.connection_semaphore.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
warn!(
"Max connections reached ({}), rejecting",
server.config.max_connections
);
if let Ok((stream, _)) = listener.accept().await {
drop(stream);
}
continue;
}
};
match listener.accept().await {
Ok((tcp_stream, client_addr)) => {
let server = server.clone();
let auth_handler = auth_handler.clone();
tokio::spawn(async move {
let _permit = permit;
if let Err(e) = server
.handle_connection(tcp_stream, client_addr, auth_handler)
.await
{
debug!("Connection error from {}: {}", client_addr, e);
}
});
}
Err(e) => {
error!("Accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
}
async fn handle_connection(
&self,
tcp_stream: TcpStream,
client_addr: SocketAddr,
auth_handler: Arc<AuthenticatedHandler>,
) -> anyhow::Result<()> {
tcp_stream.set_nodelay(true)?;
let _timeout = self.config.connection_timeout;
#[cfg(feature = "tls")]
if let Some(ref tls_acceptor) = self.tls_acceptor {
let tls_stream =
match tokio::time::timeout(_timeout, tls_acceptor.accept_tcp(tcp_stream)).await {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => {
warn!("TLS handshake failed from {}: {}", client_addr, e);
return Ok(());
}
Err(_) => {
warn!("TLS handshake timeout from {}", client_addr);
return Ok(());
}
};
let security_ctx = self
.build_tls_security_context(client_addr, &tls_stream)
.await?;
return self
.handle_secure_connection(tls_stream, security_ctx, auth_handler)
.await;
}
let security_ctx = ConnectionSecurityContext {
client_addr,
#[cfg(feature = "tls")]
tls_info: None,
auth_state: if self.config.require_auth {
ConnectionAuth::Unauthenticated
} else {
ConnectionAuth::Anonymous
},
#[cfg(feature = "tls")]
service_identity: None,
};
self.handle_secure_connection(tcp_stream, security_ctx, auth_handler)
.await
}
#[cfg(feature = "tls")]
async fn build_tls_security_context(
&self,
client_addr: SocketAddr,
tls_stream: &TlsServerStream<TcpStream>,
) -> anyhow::Result<ConnectionSecurityContext> {
let protocol_version = tls_stream
.protocol_version()
.map(|v| format!("{:?}", v))
.unwrap_or_else(|| "unknown".to_string());
let alpn = tls_stream
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
let client_cert = tls_stream.peer_certificates().and_then(|certs| {
if certs.is_empty() {
None
} else {
Some(TlsIdentity::from_certificate(&certs[0]))
}
});
let tls_info = TlsConnectionInfo {
protocol_version,
cipher_suite: tls_stream.cipher_suite_name(),
client_cert: client_cert.clone(),
alpn_protocol: alpn,
};
let (auth_state, service_identity) = if let Some(ref cert_identity) = client_cert {
if let Some(ref svc_auth) = self.service_auth_manager {
let cert_subject = cert_identity
.subject
.clone()
.unwrap_or_else(|| cert_identity.common_name.clone().unwrap_or_default());
if !cert_subject.is_empty() {
let client_ip_str = client_addr.ip().to_string();
match svc_auth.authenticate_certificate(&cert_subject, &client_ip_str) {
Ok(session) => {
info!(
"mTLS authenticated service '{}' from {} (cert: {})",
session.service_account,
client_addr,
cert_identity.common_name.as_deref().unwrap_or("?")
);
let svc_identity = ServiceIdentity {
service_id: session.service_account.clone(),
common_name: cert_identity.common_name.clone().unwrap_or_default(),
subject: cert_subject,
fingerprint: cert_identity.fingerprint.clone(),
roles: session.permissions.clone(),
};
let auth_session = AuthSession {
id: session.id.clone(),
principal_name: session.service_account.clone(),
principal_type: rivven_core::PrincipalType::ServiceAccount,
permissions: std::collections::HashSet::new(), created_at: std::time::Instant::now(),
expires_at: std::time::Instant::now()
+ session.time_until_expiration(),
client_ip: client_addr.ip().to_string(),
};
(
ConnectionAuth::Authenticated(auth_session),
Some(svc_identity),
)
}
Err(e) => {
warn!(
"mTLS auth failed for cert '{}' from {}: {}",
cert_subject, client_addr, e
);
(ConnectionAuth::Unauthenticated, None)
}
}
} else {
warn!("Client cert has no subject from {}", client_addr);
(ConnectionAuth::Unauthenticated, None)
}
} else {
debug!(
"Client cert provided but service auth not enabled from {}",
client_addr
);
(
if self.config.require_auth {
ConnectionAuth::Unauthenticated
} else {
ConnectionAuth::Anonymous
},
None,
)
}
} else {
(
if self.config.require_auth {
ConnectionAuth::Unauthenticated
} else {
ConnectionAuth::Anonymous
},
None,
)
};
Ok(ConnectionSecurityContext {
client_addr,
tls_info: Some(tls_info),
auth_state,
service_identity,
})
}
async fn handle_secure_connection<S>(
&self,
mut stream: S,
mut security_ctx: ConnectionSecurityContext,
auth_handler: Arc<AuthenticatedHandler>,
) -> anyhow::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let mut buffer = BytesMut::with_capacity(8192);
let client_addr = security_ctx.client_addr;
let client_ip = client_addr.ip().to_string();
#[cfg(feature = "tls")]
let has_tls = security_ctx.tls_info.is_some();
#[cfg(not(feature = "tls"))]
let has_tls = false;
debug!(
"Connection established: addr={}, tls={}, auth={:?}",
client_addr,
has_tls,
std::mem::discriminant(&security_ctx.auth_state)
);
loop {
let mut len_buf = [0u8; 4];
match tokio::time::timeout(self.config.idle_timeout, stream.read_exact(&mut len_buf))
.await
{
Ok(Ok(_)) => {}
Ok(Err(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
debug!("Client {} disconnected gracefully", client_addr);
return Ok(());
}
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
debug!("Idle timeout for {}", client_addr);
return Ok(());
}
}
let msg_len = u32::from_be_bytes(len_buf) as usize;
if msg_len > self.config.max_message_size {
warn!(
"Message too large from {}: {} bytes (max: {})",
client_addr, msg_len, self.config.max_message_size
);
let response = Response::Error {
message: format!("MESSAGE_TOO_LARGE: {} bytes exceeds limit", msg_len),
};
self.send_response_with_format(&mut stream, &response, WireFormat::Postcard)
.await?;
continue;
}
buffer.clear();
buffer.resize(msg_len, 0);
match tokio::time::timeout(self.config.idle_timeout, stream.read_exact(&mut buffer))
.await
{
Ok(Ok(_)) => {}
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
debug!(
"Read timeout during message body from {} - closing connection",
client_addr
);
return Ok(());
}
}
let (request, wire_format) = match Request::from_wire(&buffer) {
Ok((req, fmt)) => (req, fmt),
Err(e) => {
warn!("Invalid request from {}: {}", client_addr, e);
let response = Response::Error {
message: format!("INVALID_REQUEST: {}", e),
};
self.send_response_with_format(&mut stream, &response, WireFormat::Postcard)
.await?;
continue;
}
};
let response = auth_handler
.handle(request, &mut security_ctx.auth_state, &client_ip)
.await;
self.send_response_with_format(&mut stream, &response, wire_format)
.await?;
}
}
async fn send_response_with_format<S>(
&self,
stream: &mut S,
response: &Response,
format: WireFormat,
) -> anyhow::Result<()>
where
S: AsyncWrite + Unpin,
{
let response_bytes = response.to_wire(format)?;
let len = response_bytes.len() as u32;
stream.write_all(&len.to_be_bytes()).await?;
stream.write_all(&response_bytes).await?;
stream.flush().await?;
Ok(())
}
}
/// Fluent builder for [`SecureServer`], starting from
/// `SecureServerConfig::default()`.
pub struct SecureServerBuilder {
    /// Core broker configuration forwarded to `SecureServer::new`.
    core_config: Config,
    /// Server settings mutated by the builder's setter methods.
    server_config: SecureServerConfig,
}
impl SecureServerBuilder {
pub fn new(core_config: Config) -> Self {
Self {
core_config,
server_config: SecureServerConfig::default(),
}
}
pub fn bind(mut self, addr: SocketAddr) -> Self {
self.server_config.bind_addr = addr;
self
}
#[cfg(feature = "tls")]
pub fn with_tls(mut self, tls_config: TlsConfig) -> Self {
self.server_config.tls_config = Some(tls_config);
self
}
pub fn max_connections(mut self, max: usize) -> Self {
self.server_config.max_connections = max;
self
}
pub fn connection_timeout(mut self, timeout: Duration) -> Self {
self.server_config.connection_timeout = timeout;
self
}
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.server_config.idle_timeout = timeout;
self
}
pub fn max_message_size(mut self, size: usize) -> Self {
self.server_config.max_message_size = size;
self
}
pub fn require_auth(mut self, require: bool) -> Self {
self.server_config.require_auth = require;
self
}
pub fn enable_service_auth(mut self, config: ServiceAuthConfig) -> Self {
self.server_config.enable_service_auth = true;
self.server_config.service_auth_config = Some(config);
self
}
pub async fn build(self) -> anyhow::Result<SecureServer> {
SecureServer::new(self.core_config, self.server_config).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration has its expected value.
    #[test]
    fn test_default_config() {
        let config = SecureServerConfig::default();
        assert_eq!(
            config.bind_addr,
            "0.0.0.0:9092".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(config.max_connections, 10_000);
        assert_eq!(config.connection_timeout, Duration::from_secs(30));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert!(!config.require_auth);
        assert!(!config.enable_service_auth);
        assert!(config.service_auth_config.is_none());
    }

    /// Address, connection-limit and auth setters are applied.
    #[test]
    fn test_builder() {
        let core_config = Config::default();
        let builder = SecureServerBuilder::new(core_config)
            .bind("127.0.0.1:9999".parse().unwrap())
            .max_connections(5000)
            .require_auth(true);
        assert_eq!(
            builder.server_config.bind_addr,
            "127.0.0.1:9999".parse().unwrap()
        );
        assert_eq!(builder.server_config.max_connections, 5000);
        assert!(builder.server_config.require_auth);
    }

    /// Timeout and size setters are applied, and untouched fields keep defaults.
    #[test]
    fn test_builder_timeouts_and_size_limit() {
        let builder = SecureServerBuilder::new(Config::default())
            .connection_timeout(Duration::from_secs(5))
            .idle_timeout(Duration::from_secs(60))
            .max_message_size(1024);
        assert_eq!(
            builder.server_config.connection_timeout,
            Duration::from_secs(5)
        );
        assert_eq!(builder.server_config.idle_timeout, Duration::from_secs(60));
        assert_eq!(builder.server_config.max_message_size, 1024);
        assert_eq!(builder.server_config.max_connections, 10_000);
        assert!(!builder.server_config.require_auth);
    }
}