use super::event::AuditEvent;
use super::exporter::AuditExporter;
use anyhow::{Context, Result};
use async_trait::async_trait;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
enum Connection {
Plain(TcpStream),
Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
impl Connection {
async fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
match self {
Connection::Plain(stream) => stream.write_all(data).await,
Connection::Tls(stream) => stream.write_all(data).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Connection::Plain(stream) => stream.flush().await,
Connection::Tls(stream) => stream.flush().await,
}
}
}
/// Audit exporter that ships events to a Logstash endpoint as newline-delimited
/// JSON over TCP, optionally wrapped in TLS.
pub struct LogstashExporter {
    /// Hostname or IP address of the Logstash endpoint.
    host: String,
    /// TCP port of the Logstash input.
    port: u16,
    /// Lazily-opened connection; `None` until first use, after a failed write,
    /// or after `close`.
    connection: Mutex<Option<Connection>>,
    /// Upper bound on how long establishing the TCP connection may take.
    connect_timeout: Duration,
    /// Pause applied before reconnecting when a send finds no usable connection.
    reconnect_delay: Duration,
    /// Whether new connections are wrapped in TLS.
    use_tls: bool,
    /// Connector used for the TLS handshake; populated by `with_tls(true)`.
    tls_connector: Option<TlsConnector>,
}
impl LogstashExporter {
    /// Creates an exporter for the Logstash endpoint at `host:port`.
    ///
    /// The exporter starts in plaintext mode (a warning is logged); call
    /// [`LogstashExporter::with_tls`] to enable TLS. No network I/O happens here:
    /// the connection is opened lazily on the first export.
    ///
    /// Defaults: 10 s connect timeout, 5 s pause before reconnecting.
    ///
    /// # Errors
    ///
    /// Returns an error if `host` is empty, or is neither an IP address nor a
    /// syntactically valid DNS hostname.
    pub fn new(host: &str, port: u16) -> Result<Self> {
        if host.is_empty() {
            anyhow::bail!("Logstash host cannot be empty");
        }
        if !Self::is_valid_host(host) {
            anyhow::bail!("Invalid host format: must be a valid hostname or IP address");
        }
        tracing::warn!(
            "Logstash exporter created without TLS encryption. \
             For production use, enable TLS with with_tls(true) to protect audit data in transit."
        );
        Ok(Self {
            host: host.to_string(),
            port,
            connection: Mutex::new(None),
            reconnect_delay: Duration::from_secs(5),
            connect_timeout: Duration::from_secs(10),
            use_tls: false,
            tls_connector: None,
        })
    }

    /// Enables or disables TLS for connections opened by this exporter.
    ///
    /// When enabled, the platform's native root certificates are loaded as trust
    /// anchors and no client certificate is presented. Certificates that rustls
    /// rejects are skipped, with a warning. If no certificate at all could be
    /// added, a warning is logged, because every handshake would then fail
    /// server verification.
    #[must_use]
    pub fn with_tls(mut self, enable: bool) -> Self {
        self.use_tls = enable;
        if enable {
            let mut root_store = RootCertStore::empty();
            let cert_result = rustls_native_certs::load_native_certs();
            // Count instead of silently discarding rejected certificates. An empty
            // trust store should be reported now, not surface later as an opaque
            // handshake error.
            let mut added = 0usize;
            let mut rejected = 0usize;
            for cert in cert_result.certs {
                if root_store.add(cert).is_ok() {
                    added += 1;
                } else {
                    rejected += 1;
                }
            }
            if !cert_result.errors.is_empty() {
                tracing::warn!(
                    "Some errors occurred while loading native certificates: {:?}",
                    cert_result.errors
                );
            }
            if rejected > 0 {
                tracing::warn!(
                    "Ignored {} native certificate(s) rejected by the TLS root store",
                    rejected
                );
            }
            if added == 0 {
                tracing::warn!(
                    "No trusted root certificates were loaded; TLS connections to \
                     Logstash will fail server certificate verification"
                );
            }
            let config = ClientConfig::builder()
                .with_root_certificates(root_store)
                .with_no_client_auth();
            self.tls_connector = Some(TlsConnector::from(Arc::new(config)));
            tracing::info!("TLS encryption enabled for Logstash exporter");
        } else {
            self.tls_connector = None;
            tracing::warn!(
                "TLS encryption disabled for Logstash exporter. \
                 Audit data will be transmitted unencrypted."
            );
        }
        self
    }

    /// Returns `true` if `host` is an IPv4/IPv6 address literal, or a DNS name of
    /// at most 253 bytes.
    ///
    /// Each dot-separated label of a DNS name must be 1–63 ASCII letters, digits
    /// or hyphens, and must not start or end with a hyphen. Because an empty
    /// label is invalid, a trailing dot is rejected.
    fn is_valid_host(host: &str) -> bool {
        if host.parse::<IpAddr>().is_ok() {
            return true;
        }
        host.len() <= 253 && host.split('.').all(Self::is_valid_label)
    }

    /// Checks a single DNS label against the rules listed on `is_valid_host`.
    fn is_valid_label(label: &str) -> bool {
        !label.is_empty()
            && label.len() <= 63
            && !label.starts_with('-')
            && !label.ends_with('-')
            && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
    }

    /// Opens a connection if none is currently held.
    ///
    /// The lock is held across the connect attempt, so concurrent callers wait
    /// for it instead of each opening their own socket.
    ///
    /// # Errors
    ///
    /// Propagates the error from `connect` (after logging it).
    async fn ensure_connected(&self) -> Result<()> {
        let mut conn = self.connection.lock().await;
        if conn.is_some() {
            return Ok(());
        }
        let stream = self.connect().await.map_err(|e| {
            tracing::warn!("Failed to connect to Logstash: {}", e);
            e
        })?;
        *conn = Some(stream);
        Ok(())
    }

    /// Establishes a new TCP connection, and a TLS session on top of it when TLS
    /// is enabled.
    ///
    /// # Errors
    ///
    /// Fails on connect timeout, TCP connect failure, a missing TLS connector, a
    /// host that is not a valid TLS server name, or a failed TLS handshake.
    async fn connect(&self) -> Result<Connection> {
        // Pass host and port separately. The joined "host:port" form cannot be
        // parsed as a socket address for IPv6 literals such as "::1".
        let addr = match self.host.parse::<IpAddr>() {
            Ok(IpAddr::V6(_)) => format!("[{}]:{}", self.host, self.port),
            _ => format!("{}:{}", self.host, self.port),
        };
        let tcp_stream = tokio::time::timeout(
            self.connect_timeout,
            TcpStream::connect((self.host.as_str(), self.port)),
        )
        .await
        .with_context(|| format!("Connection timeout to {addr}"))?
        .with_context(|| format!("Failed to connect to {addr}"))?;
        let connection = if self.use_tls {
            let connector = self
                .tls_connector
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("TLS enabled but connector not initialized"))?
                .clone();
            let server_name =
                ServerName::try_from(self.host.clone()).context("Invalid server name for TLS")?;
            let tls_stream = connector
                .connect(server_name, tcp_stream)
                .await
                .context("TLS handshake failed")?;
            tracing::info!("Connected to Logstash at {} with TLS", addr);
            Connection::Tls(Box::new(tls_stream))
        } else {
            tracing::info!("Connected to Logstash at {} (unencrypted)", addr);
            Connection::Plain(tcp_stream)
        };
        Ok(connection)
    }

    /// Writes `data` to the current connection and flushes it. If the write
    /// fails, or no connection is held, it waits `reconnect_delay`, reconnects,
    /// and retries once.
    ///
    /// The lock is held for the whole operation, including any reconnect. This
    /// way two senders cannot both reconnect and overwrite each other's
    /// connection.
    ///
    /// Delivery is at-least-once: if the failed write was partially sent, the
    /// retry resends the whole payload on the new connection.
    ///
    /// # Errors
    ///
    /// Fails if reconnecting fails, or if the retried write/flush fails.
    async fn send(&self, data: &[u8]) -> Result<()> {
        let mut conn = self.connection.lock().await;
        if let Some(ref mut stream) = *conn {
            match Self::write_and_flush(stream, data).await {
                Ok(()) => return Ok(()),
                Err(e) => {
                    tracing::warn!("Logstash write failed, reconnecting: {}", e);
                    *conn = None;
                }
            }
        }
        tokio::time::sleep(self.reconnect_delay).await;
        let mut fresh = self.connect().await?;
        Self::write_and_flush(&mut fresh, data)
            .await
            .context("Failed to write after reconnection")?;
        *conn = Some(fresh);
        Ok(())
    }

    /// Writes all of `data`, then flushes.
    ///
    /// The flush matters for TLS: the session can keep written plaintext
    /// buffered when the socket is not immediately writable.
    async fn write_and_flush(stream: &mut Connection, data: &[u8]) -> std::io::Result<()> {
        stream.write_all(data).await?;
        stream.flush().await
    }

    /// Serializes `event` as a single JSON object followed by `\n` (one JSON line).
    ///
    /// # Errors
    ///
    /// Fails if the event cannot be serialized.
    fn format_event(&self, event: &AuditEvent) -> Result<String> {
        let mut json = serde_json::to_string(event).context("Failed to serialize event")?;
        json.push('\n');
        Ok(json)
    }
}
#[async_trait]
impl AuditExporter for LogstashExporter {
async fn export(&self, event: AuditEvent) -> Result<()> {
self.ensure_connected().await?;
let data = self.format_event(&event)?;
self.send(data.as_bytes()).await
}
async fn export_batch(&self, events: &[AuditEvent]) -> Result<()> {
self.ensure_connected().await?;
let mut batch = String::new();
for event in events {
batch.push_str(&self.format_event(event)?);
}
self.send(batch.as_bytes()).await
}
async fn flush(&self) -> Result<()> {
let mut conn = self.connection.lock().await;
if let Some(ref mut stream) = *conn {
stream
.flush()
.await
.context("Failed to flush Logstash connection")?;
}
Ok(())
}
async fn close(&self) -> Result<()> {
let mut conn = self.connection.lock().await;
if let Some(stream) = conn.take() {
drop(stream);
tracing::info!("Closed Logstash connection");
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::audit::event::{EventResult, EventType};
    use std::net::{IpAddr, SocketAddr};
    use std::time::Duration;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Spawns a one-shot TCP server on an ephemeral localhost port.
    ///
    /// The server accepts a single connection and collects every
    /// newline-terminated line until EOF. The returned handle resolves to those
    /// lines, so tests must close the exporter before awaiting it.
    async fn mock_logstash_server() -> (SocketAddr, tokio::task::JoinHandle<Vec<String>>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let mut received_lines = Vec::new();
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buffer = String::new();
            loop {
                let mut chunk = [0u8; 1024];
                match socket.read(&mut chunk).await {
                    Ok(0) => break,
                    Ok(n) => {
                        buffer.push_str(&String::from_utf8_lossy(&chunk[..n]));
                        // Emit every complete line. A partial trailing line stays
                        // in `buffer` until more bytes arrive.
                        while let Some(pos) = buffer.find('\n') {
                            let line = buffer[..pos].to_string();
                            buffer.drain(..=pos);
                            received_lines.push(line);
                        }
                    }
                    Err(_) => break,
                }
            }
            received_lines
        });
        (addr, handle)
    }

    #[tokio::test]
    async fn test_logstash_exporter_creation() {
        let exporter = LogstashExporter::new("localhost", 5044);
        assert!(exporter.is_ok());
    }

    #[tokio::test]
    async fn test_logstash_exporter_invalid_host() {
        let exporter = LogstashExporter::new("", 5044);
        assert!(exporter.is_err());
    }

    /// Accepts hostnames and IPv4/IPv6 literals; rejects empty strings, labels
    /// with leading/trailing hyphens, empty labels, and illegal characters.
    #[tokio::test]
    async fn test_host_validation() {
        assert!(LogstashExporter::new("localhost", 5044).is_ok());
        assert!(LogstashExporter::new("logstash.example.com", 5044).is_ok());
        assert!(LogstashExporter::new("my-server-01.internal.example.com", 5044).is_ok());
        assert!(LogstashExporter::new("127.0.0.1", 5044).is_ok());
        assert!(LogstashExporter::new("192.168.1.100", 5044).is_ok());
        assert!(LogstashExporter::new("::1", 5044).is_ok());
        assert!(LogstashExporter::new("2001:db8::1", 5044).is_ok());
        assert!(LogstashExporter::new("", 5044).is_err());
        assert!(LogstashExporter::new("-invalid", 5044).is_err());
        assert!(LogstashExporter::new("invalid-", 5044).is_err());
        assert!(LogstashExporter::new("invalid..host", 5044).is_err());
        assert!(LogstashExporter::new("invalid host with spaces", 5044).is_err());
        assert!(LogstashExporter::new("invalid@host", 5044).is_err());
    }

    /// `with_tls(true)` installs a connector; `with_tls(false)` removes it.
    #[tokio::test]
    async fn test_with_tls() {
        let exporter = LogstashExporter::new("localhost", 5044)
            .unwrap()
            .with_tls(true);
        assert!(exporter.use_tls);
        assert!(exporter.tls_connector.is_some());
        let exporter = LogstashExporter::new("localhost", 5044)
            .unwrap()
            .with_tls(false);
        assert!(!exporter.use_tls);
        assert!(exporter.tls_connector.is_none());
    }

    /// A formatted event is one JSON document terminated by a newline.
    #[tokio::test]
    async fn test_format_event() {
        let exporter = LogstashExporter::new("localhost", 5044).unwrap();
        let event = AuditEvent::new(
            EventType::AuthSuccess,
            "alice".to_string(),
            "session-123".to_string(),
        );
        let formatted = exporter.format_event(&event).unwrap();
        assert!(formatted.ends_with('\n'));
        let json_part = formatted.trim_end();
        assert!(serde_json::from_str::<serde_json::Value>(json_part).is_ok());
    }

    #[tokio::test]
    async fn test_export_single_event() {
        let (addr, server_handle) = mock_logstash_server().await;
        let exporter = LogstashExporter::new(&addr.ip().to_string(), addr.port()).unwrap();
        let event = AuditEvent::new(
            EventType::SessionStart,
            "bob".to_string(),
            "session-456".to_string(),
        );
        let result = exporter.export(event).await;
        assert!(result.is_ok());
        exporter.close().await.unwrap();
        let received = server_handle.await.unwrap();
        assert_eq!(received.len(), 1);
        assert!(received[0].contains("session-456"));
        assert!(received[0].contains("bob"));
    }

    /// A batch arrives as one line per event, in order.
    #[tokio::test]
    async fn test_export_batch() {
        let (addr, server_handle) = mock_logstash_server().await;
        let exporter = LogstashExporter::new(&addr.ip().to_string(), addr.port()).unwrap();
        let events = vec![
            AuditEvent::new(
                EventType::AuthSuccess,
                "user1".to_string(),
                "session-1".to_string(),
            ),
            AuditEvent::new(
                EventType::FileUploaded,
                "user2".to_string(),
                "session-2".to_string(),
            )
            .with_result(EventResult::Success),
            AuditEvent::new(
                EventType::SessionEnd,
                "user3".to_string(),
                "session-3".to_string(),
            ),
        ];
        let result = exporter.export_batch(&events).await;
        assert!(result.is_ok());
        exporter.close().await.unwrap();
        let received = server_handle.await.unwrap();
        assert_eq!(received.len(), 3);
        assert!(received[0].contains("session-1"));
        assert!(received[1].contains("session-2"));
        assert!(received[2].contains("session-3"));
    }

    /// Exporting to an unroutable TEST-NET-1 address (RFC 5737) must fail
    /// instead of hanging.
    #[tokio::test]
    async fn test_connection_timeout() {
        let mut exporter = LogstashExporter::new("192.0.2.1", 5044).unwrap();
        // Shorten the 10 s default so the test stays fast. The assertion only
        // needs some connect error, whether a timeout or "unreachable".
        exporter.connect_timeout = Duration::from_millis(250);
        let event = AuditEvent::new(
            EventType::AuthSuccess,
            "test".to_string(),
            "session-test".to_string(),
        );
        let result = exporter.export(event).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_flush() {
        let (addr, server_handle) = mock_logstash_server().await;
        let exporter = LogstashExporter::new(&addr.ip().to_string(), addr.port()).unwrap();
        let event = AuditEvent::new(
            EventType::CommandExecuted,
            "charlie".to_string(),
            "session-789".to_string(),
        );
        exporter.export(event).await.unwrap();
        let result = exporter.flush().await;
        assert!(result.is_ok());
        exporter.close().await.unwrap();
        server_handle.await.unwrap();
    }

    /// `close` succeeds, and a second `close` with no connection is a no-op.
    #[tokio::test]
    async fn test_close() {
        let (addr, _server_handle) = mock_logstash_server().await;
        let exporter = LogstashExporter::new(&addr.ip().to_string(), addr.port()).unwrap();
        let event = AuditEvent::new(
            EventType::SessionStart,
            "dave".to_string(),
            "session-101".to_string(),
        );
        exporter.export(event).await.unwrap();
        let result = exporter.close().await;
        assert!(result.is_ok());
        let result = exporter.close().await;
        assert!(result.is_ok());
    }

    /// Optional fields (client IP, byte count) survive serialization, and the
    /// output is exactly one line.
    #[tokio::test]
    async fn test_json_lines_format() {
        let exporter = LogstashExporter::new("localhost", 5044).unwrap();
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        let event = AuditEvent::new(
            EventType::FileDownloaded,
            "eve".to_string(),
            "session-202".to_string(),
        )
        .with_client_ip(ip)
        .with_bytes(2048);
        let formatted = exporter.format_event(&event).unwrap();
        assert!(formatted.ends_with('\n'));
        let lines: Vec<&str> = formatted.lines().collect();
        assert_eq!(lines.len(), 1);
        let parsed: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(parsed["user"], "eve");
        assert_eq!(parsed["session_id"], "session-202");
        assert_eq!(parsed["bytes"], 2048);
    }
}