use std::io::{self, Write};
use std::sync::Arc;
use std::time::Duration;
use sharded_sink::{ShardedSink, SinkAction, SinkConfig};
use crate::BoxError;
use crate::config::Environment;
/// User-facing settings for offloading log output and audit events onto a
/// `sharded_sink::ShardedSink`.
///
/// Deserialized with `#[serde(default)]`, so any field missing from the input
/// takes its value from `ShardedSinkConfig::default()`. When converted to the
/// sink's own `SinkConfig`, zero values of `shards`, `ring_capacity`,
/// `drain_batch` and `idle_sleep_micros` are raised to 1.
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(default)]
#[non_exhaustive]
pub struct ShardedSinkConfig {
/// Master switch. When `false`, `init_tracing_sharded`,
/// `init_tracing_with_otel_sharded` and `ShardedAuditLogger::wrap` fall back to
/// their non-sharded behaviour. Default: `false`.
pub enabled: bool,
/// Number of shards forwarded to `SinkConfig::shards`; `0` is treated as `1`.
/// Default: `1`.
pub shards: usize,
/// Ring capacity forwarded to `SinkConfig::ring_capacity`; `0` is treated as `1`.
/// Default: `8192`. NOTE(review): whether this is per shard or in total is
/// decided by `sharded_sink` — confirm there.
pub ring_capacity: usize,
/// Batch size forwarded to `SinkConfig::drain_batch`; `0` is treated as `1`.
/// Default: `256`. Presumably the maximum number of items handed to one
/// `drain` call — TODO confirm against `sharded_sink`.
pub drain_batch: usize,
/// Idle sleep in microseconds, converted to `SinkConfig::idle_sleep`; `0` is
/// treated as `1`. Default: `100`.
pub idle_sleep_micros: u64,
/// Shutdown timeout in whole seconds, converted to `SinkConfig::shutdown_timeout`.
/// `None` is forwarded unchanged (presumably "no timeout" — confirm in
/// `sharded_sink`). Default: `Some(5)`.
pub shutdown_timeout_secs: Option<u64>,
}
impl Default for ShardedSinkConfig {
    /// Sharding disabled, with one shard, an 8192-entry ring, drain batches of 256,
    /// a 100 µs idle sleep and a 5 second shutdown timeout.
    fn default() -> Self {
        let (shards, ring_capacity, drain_batch) = (1, 8192, 256);
        Self {
            enabled: false,
            shards,
            ring_capacity,
            drain_batch,
            idle_sleep_micros: 100,
            shutdown_timeout_secs: Some(5),
        }
    }
}
impl ShardedSinkConfig {
    /// Builds the `sharded_sink` configuration for a sink called `name`.
    ///
    /// Counts and the idle sleep are clamped to a minimum of 1, so a zeroed user
    /// config still yields a usable sink. Time values are converted from plain
    /// integers (microseconds / seconds) to `Duration`. All other `SinkConfig`
    /// fields keep `SinkConfig::default()` values.
    fn to_sink_config(&self, name: &'static str) -> SinkConfig {
        let at_least_one = |n: usize| n.max(1);

        let mut sink_cfg = SinkConfig::default();
        sink_cfg.name = name;
        sink_cfg.shards = at_least_one(self.shards);
        sink_cfg.ring_capacity = at_least_one(self.ring_capacity);
        sink_cfg.drain_batch = at_least_one(self.drain_batch);
        sink_cfg.idle_sleep = Duration::from_micros(self.idle_sleep_micros.max(1));
        sink_cfg.shutdown_timeout = self.shutdown_timeout_secs.map(Duration::from_secs);
        sink_cfg
    }
}
/// The bytes one `ShardedLineWriter` accumulated for a single formatted event.
type LogLine = Vec<u8>;

/// Sink drain action that writes each batch of [`LogLine`]s to an `io::Write`.
///
/// `SinkAction::drain` only receives `&self`, so the writer sits behind a `Mutex` to
/// obtain `&mut W`. The action value is shared through an `Arc` (see `spawn_log_sink`).
/// The `Write + Send + 'static` requirements live on the `SinkAction` impl rather
/// than on this definition, following the Rust API guidelines on struct bounds.
struct WriteDrain<W> {
    writer: std::sync::Mutex<W>,
}
impl<W: Write + Send + 'static> SinkAction<LogLine> for WriteDrain<W> {
    /// Writes every line of `batch` to the wrapped writer, then flushes once.
    ///
    /// A poisoned mutex is recovered rather than propagated. Write and flush errors
    /// are deliberately ignored: this is the log output path itself, so there is no
    /// better place to report them.
    async fn drain(&self, batch: &mut Vec<LogLine>) {
        let mut out = self
            .writer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        batch.iter().for_each(|line| {
            let _ = out.write_all(line);
        });
        let _ = out.flush();
    }
}
#[derive(Clone)]
struct ShardedMakeWriter {
sink: Arc<ShardedSink<LogLine>>,
}
struct ShardedLineWriter {
sink: Arc<ShardedSink<LogLine>>,
buf: Vec<u8>,
}
impl Write for ShardedLineWriter {
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
self.buf.extend_from_slice(bytes);
Ok(bytes.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl Drop for ShardedLineWriter {
fn drop(&mut self) {
if !self.buf.is_empty() {
let line = std::mem::take(&mut self.buf);
let _ = self.sink.push(line);
}
}
}
impl<'writer> tracing_subscriber::fmt::MakeWriter<'writer> for ShardedMakeWriter {
type Writer = ShardedLineWriter;
fn make_writer(&'writer self) -> Self::Writer {
ShardedLineWriter {
sink: Arc::clone(&self.sink),
buf: Vec::new(),
}
}
}
/// Handle to the background log sink started by the `init_tracing*_sharded` functions.
///
/// Call `shutdown().await` before the process exits so lines still queued in the sink
/// are written. Dropping the guard does not shut the sink down, because the installed
/// `MakeWriter` holds its own reference to the sink. When sharding was disabled, the
/// guard holds no sink and `shutdown` is a no-op.
#[derive(Debug)]
#[must_use = "call `shutdown().await` on the guard so queued log lines are drained"]
pub struct ShardedLogGuard {
    sink: Option<Arc<ShardedSink<LogLine>>>,
}
impl ShardedLogGuard {
pub async fn shutdown(self) -> Result<(), BoxError> {
if let Some(sink) = self.sink {
sink.shutdown().await.map_err(BoxError::from)?;
}
Ok(())
}
}
fn spawn_log_sink<W: Write + Send + 'static>(
cfg: &ShardedSinkConfig,
writer: W,
) -> (ShardedMakeWriter, ShardedLogGuard) {
let action = Arc::new(WriteDrain {
writer: std::sync::Mutex::new(writer),
});
let sink = Arc::new(ShardedSink::spawn_default_overload(
cfg.to_sink_config("rusty-gasket-log"),
action,
));
(
ShardedMakeWriter {
sink: Arc::clone(&sink),
},
ShardedLogGuard { sink: Some(sink) },
)
}
/// Installs the global `tracing` subscriber. When `cfg.enabled` is set, stdout output
/// is routed through a sharded sink.
///
/// If the sink is disabled, this defers to `super::init_tracing` and returns an inert
/// guard. Otherwise the filter comes from `EnvFilter::try_from_default_env`, falling
/// back to `info`. `Environment::Local` gets the pretty formatter. Other environments
/// get the JSON formatter when the `json-log` feature is on, and the default formatter
/// otherwise. Shut the returned guard down before exiting so queued lines are written.
///
/// # Panics
///
/// Panics (via `SubscriberInitExt::init`) if a global subscriber is already installed.
#[must_use]
pub fn init_tracing_sharded(env: Environment, cfg: &ShardedSinkConfig) -> ShardedLogGuard {
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::{fmt, EnvFilter};

    if !cfg.enabled {
        super::init_tracing(env);
        return ShardedLogGuard { sink: None };
    }

    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
    let (writer, log_guard) = spawn_log_sink(cfg, io::stdout());
    let base = tracing_subscriber::registry().with(env_filter);

    if matches!(env, Environment::Local) {
        base.with(
            fmt::layer()
                .with_target(true)
                .with_thread_ids(false)
                .with_writer(writer)
                .pretty(),
        )
        .init();
    } else {
        #[cfg(feature = "json-log")]
        {
            base.with(fmt::layer().json().with_target(true).with_writer(writer))
                .init();
        }
        #[cfg(not(feature = "json-log"))]
        {
            base.with(fmt::layer().with_target(true).with_writer(writer))
                .init();
        }
    }

    log_guard
}
/// Same as `init_tracing_sharded`, but also installs an OpenTelemetry layer that uses
/// the tracer `service_name` obtained from `provider`.
///
/// If the sink is disabled, this defers to `super::init_tracing_with_otel` and returns
/// an inert guard. The fallback filter turns off `h2`, `hyper`, `rustls` and `tonic`.
/// Presumably this keeps the OTLP exporter's own transport logging out of the pipeline
/// — confirm with the non-sharded variant.
///
/// # Panics
///
/// Panics (via `SubscriberInitExt::init`) if a global subscriber is already installed.
#[cfg(feature = "otlp")]
#[must_use]
pub fn init_tracing_with_otel_sharded(
    env: Environment,
    provider: &opentelemetry_sdk::trace::SdkTracerProvider,
    service_name: &'static str,
    cfg: &ShardedSinkConfig,
) -> ShardedLogGuard {
    use opentelemetry::trace::TracerProvider;
    use tracing_subscriber::prelude::*;
    use tracing_subscriber::{fmt, EnvFilter};

    if !cfg.enabled {
        super::init_tracing_with_otel(env, provider, service_name);
        return ShardedLogGuard { sink: None };
    }

    let env_filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("info,h2=off,hyper=off,rustls=off,tonic=off"));
    let otel_layer = tracing_opentelemetry::layer().with_tracer(provider.tracer(service_name));
    let (writer, log_guard) = spawn_log_sink(cfg, io::stdout());
    let base = tracing_subscriber::registry()
        .with(env_filter)
        .with(otel_layer);

    if matches!(env, Environment::Local) {
        base.with(fmt::layer().with_target(true).with_writer(writer).pretty())
            .init();
    } else {
        #[cfg(feature = "json-log")]
        {
            base.with(fmt::layer().json().with_target(true).with_writer(writer))
                .init();
        }
        #[cfg(not(feature = "json-log"))]
        {
            base.with(fmt::layer().with_target(true).with_writer(writer))
                .init();
        }
    }

    log_guard
}
/// Offloads `AuditLogger::log_auth_event` calls onto a `ShardedSink` drain.
#[cfg(feature = "auth")]
mod audit_offload {
    use super::{Arc, ShardedSinkConfig};
    use crate::auth::{AuditLogger, AuditLoggerHandle, AuthAuditEvent};
    use sharded_sink::{ShardedSink, SinkAction};

    /// Drain side: replays every buffered event into the wrapped logger.
    struct AuditDrain {
        inner: AuditLoggerHandle,
    }

    impl SinkAction<AuthAuditEvent> for AuditDrain {
        async fn drain(&self, batch: &mut Vec<AuthAuditEvent>) {
            batch
                .iter()
                .for_each(|event| self.inner.logger().log_auth_event(event));
        }
    }

    /// `AuditLogger` that clones each event into a sharded sink instead of logging inline.
    ///
    /// NOTE(review): `wrap` erases this type into an `AuditLoggerHandle`, and no
    /// shutdown method is exposed here. Events still queued at process exit may
    /// therefore be lost.
    #[derive(Debug)]
    pub struct ShardedAuditLogger {
        sink: Arc<ShardedSink<AuthAuditEvent>>,
    }

    impl ShardedAuditLogger {
        /// Wraps `inner` so its events are forwarded from a `"rusty-gasket-audit"` sink.
        ///
        /// Returns `inner` unchanged when `cfg.enabled` is `false`.
        #[must_use]
        pub fn wrap(inner: AuditLoggerHandle, cfg: &ShardedSinkConfig) -> AuditLoggerHandle {
            if !cfg.enabled {
                return inner;
            }
            let sink_cfg = cfg.to_sink_config("rusty-gasket-audit");
            let drain = Arc::new(AuditDrain { inner });
            let sink: ShardedSink<AuthAuditEvent> =
                ShardedSink::spawn_default_overload(sink_cfg, drain);
            AuditLoggerHandle::shared(Arc::new(Self {
                sink: Arc::new(sink),
            }))
        }
    }

    impl AuditLogger for ShardedAuditLogger {
        fn log_auth_event(&self, event: &AuthAuditEvent) {
            // Best effort: if the sink rejects the push, the event is dropped.
            let _accepted = self.sink.push(event.clone());
        }
    }
}
#[cfg(feature = "auth")]
pub use audit_offload::ShardedAuditLogger;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default_is_disabled_single_shard() {
        let cfg = ShardedSinkConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.shards, 1);
    }

    #[test]
    fn to_sink_config_maps_and_clamps_fields() {
        let cfg = ShardedSinkConfig {
            enabled: true,
            shards: 0,
            ring_capacity: 0,
            drain_batch: 0,
            idle_sleep_micros: 0,
            shutdown_timeout_secs: None,
        };
        let sink_cfg = cfg.to_sink_config("test");
        assert_eq!(sink_cfg.name, "test");
        assert_eq!(sink_cfg.shards, 1, "shards clamped to >= 1");
        assert_eq!(sink_cfg.ring_capacity, 1, "ring_capacity clamped to >= 1");
        assert_eq!(sink_cfg.drain_batch, 1, "drain_batch clamped to >= 1");
        assert_eq!(
            sink_cfg.idle_sleep,
            Duration::from_micros(1),
            "idle_sleep clamped to >= 1us"
        );
        assert_eq!(sink_cfg.shutdown_timeout, None);
    }

    #[test]
    fn to_sink_config_converts_time_units() {
        let cfg = ShardedSinkConfig {
            idle_sleep_micros: 250,
            shutdown_timeout_secs: Some(7),
            ..Default::default()
        };
        let sink_cfg = cfg.to_sink_config("units");
        assert_eq!(sink_cfg.idle_sleep, Duration::from_micros(250));
        assert_eq!(sink_cfg.shutdown_timeout, Some(Duration::from_secs(7)));
    }

    /// In-memory `Write` target shared with the test so drained bytes can be inspected.
    #[derive(Clone)]
    struct BufWriter(Arc<std::sync::Mutex<Vec<u8>>>);

    impl Write for BufWriter {
        fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
            let mut guard = self.0.lock().unwrap_or_else(|p| p.into_inner());
            guard.extend_from_slice(bytes);
            Ok(bytes.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn formatted_lines_are_written_off_thread() {
        let buf = Arc::new(std::sync::Mutex::new(Vec::<u8>::new()));
        let cfg = ShardedSinkConfig {
            enabled: true,
            ..Default::default()
        };
        let action = Arc::new(WriteDrain {
            writer: std::sync::Mutex::new(BufWriter(Arc::clone(&buf))),
        });
        let sink = ShardedSink::spawn_default_overload(cfg.to_sink_config("test-log"), action);
        for i in 0..200_u32 {
            assert!(sink.push(format!("line {i}\n").into_bytes()));
        }
        sink.shutdown().await.expect("shutdown");
        let written = String::from_utf8(buf.lock().expect("buf lock").clone()).expect("utf8");
        assert_eq!(written.lines().count(), 200);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn make_writer_pushes_whole_lines_and_guard_drains() {
        use tracing_subscriber::fmt::MakeWriter as _;

        let buf = Arc::new(std::sync::Mutex::new(Vec::<u8>::new()));
        let cfg = ShardedSinkConfig {
            enabled: true,
            ..Default::default()
        };
        let (make_writer, guard) = spawn_log_sink(&cfg, BufWriter(Arc::clone(&buf)));
        for i in 0..10_u32 {
            // Split each line across two writes: the writer must buffer until drop.
            let mut writer = make_writer.make_writer();
            writer
                .write_all(format!("event {i}").as_bytes())
                .expect("buffered write");
            writer.write_all(b"\n").expect("buffered write");
        }
        // A writer that never received bytes must not enqueue an empty line.
        drop(make_writer.make_writer());
        guard.shutdown().await.expect("shutdown");

        let written = String::from_utf8(buf.lock().expect("buf lock").clone()).expect("utf8");
        let mut lines: Vec<&str> = written.lines().collect();
        lines.sort_unstable();
        let mut expected: Vec<String> = (0..10_u32).map(|i| format!("event {i}")).collect();
        expected.sort_unstable();
        assert_eq!(lines, expected);
    }

    #[cfg(feature = "auth")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn audit_wrap_disabled_forwards_inline() {
        use crate::auth::{AuditLogger, AuditLoggerHandle, AuthAuditEvent, AuthAuditOutcome};
        use std::sync::atomic::{AtomicUsize, Ordering};

        #[derive(Clone)]
        struct Counting(Arc<AtomicUsize>);
        impl AuditLogger for Counting {
            fn log_auth_event(&self, _event: &AuthAuditEvent) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let inner = AuditLoggerHandle::new(Counting(Arc::clone(&count)));
        let cfg = ShardedSinkConfig::default();
        let handle = ShardedAuditLogger::wrap(inner, &cfg);
        handle.logger().log_auth_event(&AuthAuditEvent::new(
            "rid",
            "1.2.3.4",
            AuthAuditOutcome::Success,
        ));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[cfg(feature = "auth")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn audit_wrap_enabled_forwards_off_path() {
        use crate::auth::{AuditLogger, AuditLoggerHandle, AuthAuditEvent, AuthAuditOutcome};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Duration;

        #[derive(Clone)]
        struct Counting(Arc<AtomicUsize>);
        impl AuditLogger for Counting {
            fn log_auth_event(&self, _event: &AuthAuditEvent) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let count = Arc::new(AtomicUsize::new(0));
        let inner = AuditLoggerHandle::new(Counting(Arc::clone(&count)));
        let cfg = ShardedSinkConfig {
            enabled: true,
            ..Default::default()
        };
        let handle = ShardedAuditLogger::wrap(inner, &cfg);
        for _ in 0..50 {
            handle.logger().log_auth_event(&AuthAuditEvent::new(
                "rid",
                "1.2.3.4",
                AuthAuditOutcome::Success,
            ));
        }
        // Forwarding happens on the sink's drain side; poll for up to ~1s.
        let mut forwarded = 0;
        for _ in 0..100 {
            forwarded = count.load(Ordering::SeqCst);
            if forwarded == 50 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert_eq!(forwarded, 50, "all audit events forwarded to inner logger");
    }
}