use std::sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
};
use bytes::BytesMut;
use obs_proto::obs::v1::{ObsEnvelope, Tier};
use tokio::{
runtime::Handle,
sync::{Mutex as AsyncMutex, mpsc, oneshot},
task::JoinHandle,
};
use crate::{
config::QueuesConfig,
registry::{SchemaRegistry, ScrubbedEnvelope},
sink::Sink,
};
/// Counters shared between the tier workers and their callers.
///
/// Every field is an atomic so it can be bumped from any thread without a lock;
/// the updates visible in this file all use `Ordering::Relaxed`.
#[derive(Debug, Default)]
pub struct WorkerCounters {
    /// Bumped by `note_channel_full` for `Tier::Log`
    /// (presumably when a log envelope could not be queued — confirm at call sites).
    pub channel_full_log: AtomicU64,
    /// Bumped by `note_channel_full` for `Tier::Metric`.
    pub channel_full_metric: AtomicU64,
    /// Bumped by `note_channel_full` for `Tier::Trace`.
    pub channel_full_trace: AtomicU64,
    /// Bumped by `note_channel_full` for `Tier::Audit`.
    pub channel_full_audit: AtomicU64,
    /// Incremented by a worker task once per envelope it takes off its queue.
    /// NOTE(review): this includes envelopes that `deliver_one` drops because
    /// scrubbing failed, so it counts processed rather than strictly delivered.
    pub delivered: AtomicU64,
}
/// Handle to one tier's background delivery task.
///
/// Envelopes are queued on a bounded Tokio `mpsc` channel and consumed by a
/// task started in `TierWorker::spawn`, which scrubs each envelope and hands
/// it to `sink`.
pub struct TierWorker {
    /// Producer side of the queue; `None` once `shutdown()` has taken it.
    sender: parking_lot::Mutex<Option<mpsc::Sender<WorkerMsg>>>,
    /// Join handle of the worker task; taken and awaited by `shutdown()`.
    join: AsyncMutex<Option<JoinHandle<()>>>,
    /// Set by `shutdown()`; the worker task reads it to start winding down.
    shutdown: Arc<AtomicBool>,
    /// Destination sink; this handle also flushes / shuts it down directly.
    sink: Arc<dyn Sink>,
}
/// Debug output reports only whether the worker still accepts messages
/// (i.e. whether its sender has not yet been taken by `shutdown()`).
impl std::fmt::Debug for TierWorker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the flag first so the lock is released before any formatting.
        let alive = self.sender.lock().is_some();
        let mut out = f.debug_struct("TierWorker");
        out.field("alive", &alive);
        out.finish()
    }
}
impl TierWorker {
pub fn spawn(
capacity: usize,
sink: Arc<dyn Sink>,
registry: Arc<SchemaRegistry>,
counters: Arc<WorkerCounters>,
tier: Tier,
) -> Self {
let (tx, mut rx) = mpsc::channel::<WorkerMsg>(capacity.max(1));
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_in = Arc::clone(&shutdown);
let sink_in = Arc::clone(&sink);
let registry_in = registry;
let counters_in = counters;
let join = tokio::spawn(async move {
let mut scratch = BytesMut::with_capacity(4096);
while let Some(msg) = rx.recv().await {
if let Some(env) = msg.envelope.as_ref() {
deliver_one(env, ®istry_in, &mut scratch, &sink_in);
counters_in.delivered.fetch_add(1, Ordering::Relaxed);
if shutdown_in.load(Ordering::Relaxed) && rx.is_empty() {
break;
}
} else if let Some(ack) = msg.flush {
sink_in.flush().await;
let _ = ack.send(());
}
}
while let Ok(msg) = rx.try_recv() {
if let Some(env) = msg.envelope.as_ref() {
deliver_one(env, ®istry_in, &mut scratch, &sink_in);
counters_in.delivered.fetch_add(1, Ordering::Relaxed);
} else if let Some(ack) = msg.flush {
sink_in.flush().await;
let _ = ack.send(());
}
}
sink_in.flush().await;
let _ = tier;
});
Self {
sender: parking_lot::Mutex::new(Some(tx)),
join: AsyncMutex::new(Some(join)),
shutdown,
sink,
}
}
#[allow(clippy::result_large_err)]
pub fn try_send(&self, env: ObsEnvelope) -> Result<(), ObsEnvelope> {
let guard = self.sender.lock();
let Some(sender) = guard.as_ref() else {
return Err(env);
};
match sender.try_send(WorkerMsg::envelope(env)) {
Ok(()) => Ok(()),
Err(mpsc::error::TrySendError::Full(msg) | mpsc::error::TrySendError::Closed(msg)) => {
if let Some(env) = msg.into_envelope() {
Err(env)
} else {
Ok(())
}
}
}
}
#[allow(clippy::result_large_err, dead_code)]
pub async fn send_with_timeout(
&self,
env: ObsEnvelope,
timeout: std::time::Duration,
) -> Result<(), ObsEnvelope> {
let sender = match self.sender.lock().as_ref() {
Some(s) => s.clone(),
None => return Err(env),
};
let cloned = env.clone();
match tokio::time::timeout(timeout, sender.send(WorkerMsg::envelope(env))).await {
Ok(Ok(())) => Ok(()),
Ok(Err(mpsc::error::SendError(msg))) => {
if let Some(env) = msg.into_envelope() {
Err(env)
} else {
Ok(())
}
}
Err(_) => Err(cloned),
}
}
pub async fn flush(&self) {
let sender = {
let guard = self.sender.lock();
guard.as_ref().cloned()
};
let Some(sender) = sender else {
self.sink.flush().await;
return;
};
let (tx, rx) = oneshot::channel();
if sender.send(WorkerMsg::flush(tx)).await.is_ok() {
let _ = rx.await;
} else {
self.sink.flush().await;
}
}
pub async fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
self.sender.lock().take();
let mut guard = self.join.lock().await;
if let Some(join) = guard.take() {
let _ = join.await;
}
self.sink.shutdown().await;
}
#[allow(dead_code)]
pub fn sink(&self) -> &Arc<dyn Sink> {
&self.sink
}
}
/// A message on a tier worker's queue.
///
/// The `WorkerMsg::envelope` and `WorkerMsg::flush` constructors each set
/// exactly one of the two fields.
/// NOTE(review): an enum would make the "exactly one" invariant unrepresentable-to-break.
struct WorkerMsg {
    /// Envelope to scrub and deliver to the sink.
    envelope: Option<ObsEnvelope>,
    /// Flush request: the worker flushes the sink, then signals this sender.
    flush: Option<oneshot::Sender<()>>,
}
impl WorkerMsg {
fn envelope(env: ObsEnvelope) -> Self {
Self {
envelope: Some(env),
flush: None,
}
}
fn flush(ack: oneshot::Sender<()>) -> Self {
Self {
envelope: None,
flush: Some(ack),
}
}
fn into_envelope(self) -> Option<ObsEnvelope> {
self.envelope
}
}
/// Scrubs `env` against `registry`, using `scratch` as working space, and
/// passes the scrubbed result to `sink`.
///
/// If scrubbing fails, the envelope is dropped and the error is discarded;
/// nothing is reported to the caller.
fn deliver_one(
    env: &ObsEnvelope,
    registry: &Arc<SchemaRegistry>,
    scratch: &mut BytesMut,
    sink: &Arc<dyn Sink>,
) {
    // Start every envelope from an empty scratch buffer.
    scratch.clear();
    if let Ok(scrubbed) = ScrubbedEnvelope::scrub(env, registry, scratch) {
        sink.deliver(scrubbed);
    }
}
/// Spawns a [`TierWorker`] for `tier`, sized from the matching capacity in `cfg`.
///
/// Returns `None` when `tier` is not one of Log/Metric/Trace/Audit, or when
/// there is no current Tokio runtime to spawn the worker task on.
pub fn spawn_tier_worker(
    tier: Tier,
    cfg: &QueuesConfig,
    sink: Arc<dyn Sink>,
    registry: Arc<SchemaRegistry>,
    counters: Arc<WorkerCounters>,
) -> Option<TierWorker> {
    // NOTE(review): Audit shares the log queue capacity; confirm whether a
    // dedicated audit setting was intended.
    let configured = match tier {
        Tier::Log | Tier::Audit => cfg.log,
        Tier::Metric => cfg.metric,
        Tier::Trace => cfg.trace,
        _ => return None,
    };
    // The worker is started with `tokio::spawn`, so an ambient runtime is required.
    Handle::try_current().ok()?;
    Some(TierWorker::spawn(
        configured as usize,
        sink,
        registry,
        counters,
        tier,
    ))
}
/// Increments the "channel full" counter belonging to `tier`.
///
/// Tiers without a dedicated counter are ignored.
pub fn note_channel_full(counters: &WorkerCounters, tier: Tier) {
    let counter = match tier {
        Tier::Log => Some(&counters.channel_full_log),
        Tier::Metric => Some(&counters.channel_full_metric),
        Tier::Trace => Some(&counters.channel_full_trace),
        Tier::Audit => Some(&counters.channel_full_audit),
        _ => None,
    };
    if let Some(counter) = counter {
        counter.fetch_add(1, Ordering::Relaxed);
    }
}