use std::{
collections::HashSet,
fs::remove_file,
path::{Path, PathBuf},
str::FromStr,
sync::{Arc, atomic::AtomicBool},
};
use anyhow::Result;
use dashmap::DashMap;
use log::warn;
use shared_mem_queue::{
byte_queue::ByteQueue,
msg_queue::{MqError, MsgQueue},
};
use shared_memory::{Shmem, ShmemConf};
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt},
net::{UnixListener, UnixStream},
task::{JoinHandle, JoinSet},
};
use uuid::Uuid;
use super::{ReceptionistFacade, tokenizer::parse_pid};
/// Root process id a subscriber asked to follow; used as the key of the subscriber map.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct SubscriberRootPid(i32);
/// Listens on a Unix socket for trace subscribers and fans matching trace lines out to them.
///
/// Each connecting client sends the root pid it wants to follow; the receptionist creates a
/// dedicated operator for it and replies with the id of the shared-memory queue the client
/// should read from.
pub struct Receptionist {
    /// Filesystem path of the listening Unix socket; removed again by `shutdown`.
    address: PathBuf,
    /// Active subscribers, keyed by the root pid each one asked to follow.
    subscribers: Arc<DashMap<SubscriberRootPid, DedicatedOperator>>,
    /// Asks the accept loop to shut every subscriber down and exit.
    shutdown_request_tx: tokio::sync::oneshot::Sender<()>,
    /// Resolves once the accept loop has finished its shutdown sequence.
    shutdown_complete_rx: tokio::sync::oneshot::Receiver<()>,
}
impl Receptionist {
    /// Derives the Unix socket path for the receptionist serving the trace at `trace_path`.
    ///
    /// Normally this is `trace_path` with its extension replaced by `sock`. When that path is
    /// longer than 107 bytes (Linux's `sockaddr_un.sun_path` is 108 bytes including the NUL
    /// terminator), the file name is replaced by five random hex characters in the same
    /// directory. The result can still be too long if the directory itself is very long.
    pub fn get_receptionist_address(trace_path: &Path) -> PathBuf {
        const MAX_SOCKET_PATH_LEN: usize = 107;
        let mut path = trace_path.with_extension("sock");
        if path.as_os_str().len() > MAX_SOCKET_PATH_LEN {
            path.pop();
            // A UUID's string form is ASCII, so slicing on a byte index cannot split a char.
            #[allow(clippy::string_slice)]
            path.push(&Uuid::new_v4().to_string()[0..5]);
        }
        path
    }

    /// Binds a Unix listener at `address` and spawns the accept loop.
    ///
    /// Each connection performs a short handshake. The client writes its root pid as a
    /// big-endian `i32`. The receptionist then replies with the id of a freshly created
    /// shared-memory queue, and the connection is closed.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be bound (for example if `address` exists).
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime with I/O enabled.
    pub fn startup(address: PathBuf) -> Result<Self> {
        let (shutdown_request_tx, mut shutdown_request_rx) = tokio::sync::oneshot::channel::<()>();
        let (shutdown_complete_tx, shutdown_complete_rx) = tokio::sync::oneshot::channel::<()>();
        let retval = Self {
            address,
            subscribers: Arc::new(DashMap::new()),
            shutdown_request_tx,
            shutdown_complete_rx,
        };
        let listener = UnixListener::bind(&retval.address)?;
        let subscribers = retval.subscribers.clone();
        tokio::task::spawn(async move {
            loop {
                tokio::select! {
                    // Completes with `Ok` on an explicit shutdown request, and with `Err` when the
                    // `Receptionist` is dropped without calling `shutdown`.
                    shutdown_event = (&mut shutdown_request_rx) => {
                        let keys: Vec<SubscriberRootPid> =
                            subscribers.iter().map(|entry| entry.key().clone()).collect();
                        let mut set = JoinSet::new();
                        for key in keys {
                            // Skip entries that vanished since the keys were collected instead of
                            // panicking inside the accept loop.
                            if let Some((_pid, subscriber)) = subscribers.remove(&key) {
                                set.spawn(subscriber.shutdown());
                            }
                        }
                        set.join_all().await;
                        if shutdown_event.is_ok() {
                            let _ = shutdown_complete_tx.send(());
                        }
                        return;
                    },
                    stream = listener.accept() => {
                        let (mut stream, _addr) = match stream {
                            Ok(inner) => inner,
                            Err(err) => {
                                warn!("Receptionist listener.accept had error: {err:?}");
                                continue;
                            }
                        };
                        let subscribe_pid = match stream.read_i32().await {
                            Ok(pid) => SubscriberRootPid(pid),
                            Err(err) => {
                                warn!("Receptionist stream.read_i32 had error: {err:?}");
                                continue;
                            }
                        };
                        let operator = match DedicatedOperator::startup(&subscribe_pid) {
                            Ok(operator) => operator,
                            Err(err) => {
                                warn!("DedicatedOperator startup failed: {err:?}");
                                continue;
                            },
                        };
                        let unique_id = String::from(operator.unique_id());
                        let subscribe_pid_copy = subscribe_pid.clone();
                        if let Some(previous) = subscribers.insert(subscribe_pid, operator) {
                            // A new subscriber for the same root pid replaces the old one. Shut the
                            // old operator down cleanly (flush + end marker) rather than letting it
                            // drop, which aborts its writer mid-trace.
                            warn!("Receptionist replaced an existing subscriber for {subscribe_pid_copy:?}");
                            tokio::task::spawn(previous.shutdown());
                        }
                        if let Err(err) = stream.write_all(unique_id.as_bytes()).await {
                            warn!("Receptionist stream.write_all had error: {err:?}");
                            if let Some((_pid, operator)) = subscribers.remove(&subscribe_pid_copy) {
                                tokio::task::spawn(operator.shutdown());
                            }
                        }
                        // `stream` is dropped at the end of this branch, closing the connection.
                    }
                }
            }
        });
        Ok(retval)
    }
}
impl ReceptionistFacade for Receptionist {
    /// Forwards `trace_line` to every subscriber that follows the pid the line belongs to.
    ///
    /// Lines whose pid cannot be extracted or does not fit in an `i32` are logged and dropped.
    async fn peek_trace(&self, trace_line: &str) {
        if self.subscribers.is_empty() {
            return;
        }
        let mut tokenizer_input: &str = trace_line;
        let Ok(pid) = parse_pid(&mut tokenizer_input) else {
            warn!("peek_trace: parse_pid failed on input: {trace_line:?}");
            return;
        };
        // A malformed or out-of-range pid in one trace line must not take the process down.
        let pid = match i32::from_str(pid) {
            Ok(pid) => pid,
            Err(err) => {
                warn!("peek_trace: pid in {trace_line:?} is not a valid i32: {err:?}");
                return;
            }
        };
        // NOTE(review): each `iter()` entry holds a DashMap shard read lock across the `.await`
        // below. If `send` has to wait for channel capacity, a concurrent insert/remove on the
        // same shard blocks its thread. Consider cloning senders out before awaiting.
        for subscriber in self.subscribers.iter() {
            subscriber.value().peek_trace(pid, trace_line).await;
        }
    }

    /// Records `child_pid` as followed by every subscriber that already follows `parent_pid`.
    fn add_subprocess(&self, parent_pid: i32, child_pid: i32) {
        for mut subscriber in self.subscribers.iter_mut() {
            subscriber.value_mut().add_subprocess(parent_pid, child_pid);
        }
    }

    /// Stops following `pid`; subscribers left with no pids are removed and shut down.
    fn remove_process(&self, pid: i32) {
        let mut remove = vec![];
        for mut subscriber in self.subscribers.iter_mut() {
            if subscriber.value_mut().remove_process(pid) {
                remove.push(subscriber.key().clone());
            }
        }
        for rem in remove {
            if let Some((_pid, subscriber)) = self.subscribers.remove(&rem) {
                tokio::spawn(subscriber.shutdown());
            }
        }
    }

    /// Shuts down the accept loop and every subscriber, then removes the socket file.
    async fn shutdown(self) {
        // The accept loop owns the receiver and drops it only when it exits. If it is already
        // gone there is nothing to notify, so a failed send is not a reason to panic.
        let _ = self.shutdown_request_tx.send(());
        let _ = self.shutdown_complete_rx.await;
        let _ = remove_file(self.address);
    }
}
/// Per-subscriber state: the pids being followed and the shared-memory queue that matching
/// trace lines are written to.
struct DedicatedOperator {
    /// The subscriber's root pid, plus children added via `add_subprocess` whose parent was
    /// already tracked, minus pids removed via `remove_process`.
    pid_set: HashSet<i32>,
    /// Owns the shared-memory mapping the writer task writes into.
    _shmem: ShmemWrapper,
    /// OS id of the shared-memory segment; handed to the client so it can attach.
    unique_id: String,
    /// Writer-task channels and handle; `None` once shutdown (or drop) has begun.
    preshutdown_state: Option<DedicatedOperatorPreshutdownState>,
}
/// Handles a `DedicatedOperator` holds only until it starts shutting down.
struct DedicatedOperatorPreshutdownState {
    /// Tells the writer task to flush buffered lines, write the end marker, and exit.
    shutdown_request_tx: tokio::sync::oneshot::Sender<()>,
    /// Resolves once the writer task has completed its shutdown sequence.
    shutdown_complete_rx: tokio::sync::oneshot::Receiver<()>,
    /// Trace lines waiting to be written by the writer task.
    trace_tx: tokio::sync::mpsc::Sender<String>,
    /// The writer task; aborted if the operator is dropped without `shutdown`.
    spawned_task: JoinHandle<()>,
}
/// Owning wrapper around a `Shmem` mapping so it can sit inside `DedicatedOperator`, which is
/// stored in a shared `DashMap` and whose `shutdown` future is passed to `tokio::spawn`.
struct ShmemWrapper {
_inner_shmem: Shmem,
}
// SAFETY: `_inner_shmem` is private and never accessed; the wrapper exists only to own the
// mapping until it is dropped. NOTE(review): this assumes dropping (unmapping) a `Shmem` on a
// thread other than the one that created it is sound — confirm against the `shared_memory` crate.
unsafe impl Send for ShmemWrapper {}
// SAFETY: there is no `&self` access to the inner `Shmem`, so shared references cannot race on it.
unsafe impl Sync for ShmemWrapper {}
impl DedicatedOperator {
    /// Creates the shared-memory queue for a new subscriber following `pid` and spawns the
    /// writer task that copies forwarded trace lines into it.
    ///
    /// # Errors
    /// Returns an error if the shared-memory segment cannot be created.
    pub fn startup(pid: &SubscriberRootPid) -> Result<Self> {
        let (shutdown_request_tx, mut shutdown_request_rx) = tokio::sync::oneshot::channel::<()>();
        let (shutdown_complete_tx, shutdown_complete_rx) = tokio::sync::oneshot::channel::<()>();
        let (trace_tx, mut trace_rx) = tokio::sync::mpsc::channel::<String>(64);
        let unique_id = format!("testtrim-{}", Uuid::new_v4());
        let shmem = ShmemConf::new().size(65536).os_id(&unique_id).create()?;
        // SAFETY: pointer and length describe the mapping owned by `shmem`, which is moved into
        // the returned operator. NOTE(review): if the operator is dropped (not shut down) while
        // the writer is mid-write, the mapping may be released before the task stops.
        let bytequeue = unsafe { ByteQueue::create(shmem.as_ptr(), shmem.len()) };
        let spawned_task = tokio::task::spawn(async move {
            let mut msg_queue = MsgQueue::new(bytequeue, b"TT", [0u8; 0]);
            loop {
                tokio::select! {
                    shutdown_event = (&mut shutdown_request_rx) => {
                        // `shutdown` drops `trace_tx` before signalling, so this drains whatever
                        // is still buffered and then ends.
                        while let Some(trace) = trace_rx.recv().await {
                            if let Err(err) = msg_queue.write_blocking(trace.as_ref()) {
                                warn!("msg_queue write error while flushing: {err:?}");
                                break;
                            }
                        }
                        // A lone zero byte is the end-of-trace marker the reader looks for.
                        if let Err(err) = msg_queue.write_blocking(&[0u8]) {
                            warn!("msg_queue write error while writing EOF: {err:?}");
                        }
                        if shutdown_event.is_ok() {
                            let _ = shutdown_complete_tx.send(());
                        }
                        return;
                    },
                    // Matching on `Some` disables this branch once every sender is gone, so a
                    // closed channel doesn't spin the loop while waiting for the shutdown signal.
                    Some(trace) = trace_rx.recv() => {
                        // NOTE(review): `write_blocking` presumably waits synchronously for queue
                        // space, stalling this runtime worker — confirm against `shared_mem_queue`.
                        if let Err(err) = msg_queue.write_blocking(trace.as_ref()) {
                            warn!("msg_queue write error: {err:?}");
                            break;
                        }
                    },
                }
            }
        });
        let mut pid_set = HashSet::new();
        pid_set.insert(pid.0);
        Ok(DedicatedOperator {
            pid_set,
            _shmem: ShmemWrapper {
                _inner_shmem: shmem,
            },
            unique_id,
            preshutdown_state: Some(DedicatedOperatorPreshutdownState {
                shutdown_request_tx,
                shutdown_complete_rx,
                trace_tx,
                spawned_task,
            }),
        })
    }

    /// Queues `trace_line` for the writer task if `pid` is one of the followed pids.
    ///
    /// Waits for channel capacity when the writer is behind. Lines arriving after shutdown has
    /// begun are logged and dropped.
    pub async fn peek_trace(&self, pid: i32, trace_line: &str) {
        if !self.pid_set.contains(&pid) {
            return;
        }
        match &self.preshutdown_state {
            Some(state) => {
                if let Err(err) = state.trace_tx.send(String::from(trace_line)).await {
                    warn!("trace_tx write error: {err:?}");
                }
            }
            None => warn!("peek_trace: dropped interesting message due to shutdown in-progress"),
        }
    }

    /// Starts following `child_pid` if its parent is already followed.
    pub fn add_subprocess(&mut self, parent_pid: i32, child_pid: i32) {
        if self.pid_set.contains(&parent_pid) {
            self.pid_set.insert(child_pid);
        }
    }

    /// Stops following `pid`; returns `true` when no followed pids remain.
    pub fn remove_process(&mut self, pid: i32) -> bool {
        self.pid_set.remove(&pid);
        self.pid_set.is_empty()
    }

    /// Flushes buffered lines, writes the end-of-trace marker, and waits for the writer to exit.
    pub async fn shutdown(mut self) {
        if let Some(state) = self.preshutdown_state.take() {
            // Dropping the sender first lets the writer's drain loop observe end-of-channel.
            drop(state.trace_tx);
            // The writer exits on its own after a queue write error, dropping the receiver.
            // There is nothing to signal then, so a failed send must not panic.
            let _ = state.shutdown_request_tx.send(());
            let _ = state.shutdown_complete_rx.await;
        }
    }

    /// OS id of this subscriber's shared-memory segment.
    pub fn unique_id(&self) -> &str {
        &self.unique_id
    }
}
impl Drop for DedicatedOperator {
    /// Best-effort cleanup when an operator is dropped without `shutdown`.
    ///
    /// Signals the writer task and aborts it, so the subscriber may see an incomplete trace.
    fn drop(&mut self) {
        if let Some(state) = self.preshutdown_state.take() {
            warn!(
                "Drop of DedicatedOperator without a `shutdown`; an incomplete trace may be sent to a subscriber"
            );
            drop(state.trace_tx);
            // Panicking inside `drop` risks an abort. The writer may already have exited (and
            // dropped its receiver) after a queue write error, so ignore a failed send.
            let _ = state.shutdown_request_tx.send(());
            // NOTE(review): `abort` only takes effect at the task's next await point, while the
            // shared-memory mapping is released when this operator's fields drop right after.
            // Confirm the writer cannot still be inside `write_blocking` at that moment.
            state.spawned_task.abort();
        }
    }
}
/// Subscriber side of the protocol: registers a pid with a `Receptionist` and streams the
/// matching trace lines out of shared memory.
pub struct TraceClient {
    /// Lines decoded by the background reader, in the order they were read.
    trace_rx: tokio::sync::mpsc::Receiver<String>,
    /// Background reader task; taken and awaited by `shutdown`.
    read_join_handle: Option<JoinHandle<Result<()>>>,
    /// Set to ask the background reader to stop polling the queue.
    shutdown_signal: Arc<AtomicBool>,
}
impl TraceClient {
    /// Upper bound on the length of the queue id the receptionist replies with.
    const MAX_UNIQUE_ID_LEN: u64 = 64;

    /// Connects to the receptionist at `receptionist_address` and subscribes to trace lines for
    /// `pid`. Also starts a background reader that pulls lines out of the shared-memory queue.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the socket connection or handshake fails;
    /// - the receptionist closes the connection without handing out a queue id;
    /// - the id is not valid UTF-8.
    ///
    /// Errors attaching to or reading the queue are reported later by `shutdown`, as a logged
    /// warning.
    pub async fn try_create(receptionist_address: &Path, pid: i32) -> Result<Self> {
        let mut stream = UnixStream::connect(receptionist_address).await?;
        stream.write_i32(pid).await?;
        // The receptionist writes the id and then closes the connection. Read up to EOF (bounded)
        // rather than trusting a single `read` to return the whole id.
        let mut unique_id = Vec::new();
        stream
            .take(Self::MAX_UNIQUE_ID_LEN)
            .read_to_end(&mut unique_id)
            .await?;
        let unique_id = String::from_utf8(unique_id)?;
        // An empty reply means the receptionist could not create our queue. Fail now instead
        // of surfacing it only as an empty stream.
        anyhow::ensure!(
            !unique_id.is_empty(),
            "receptionist at {} closed the connection without providing a trace queue id",
            receptionist_address.display()
        );
        let (trace_tx, trace_rx) = tokio::sync::mpsc::channel::<String>(64);
        let shutdown_signal = Arc::new(AtomicBool::new(false));
        let shutdown_signal_inner = shutdown_signal.clone();
        let read_join_handle = tokio::task::spawn_blocking(move || {
            let shmem = ShmemConf::new().os_id(&unique_id).open()?;
            // SAFETY: pointer and length describe the mapping owned by `shmem`, which is declared
            // before `msgqueue` and therefore dropped after it.
            let bytequeue = unsafe { ByteQueue::attach(shmem.as_ptr(), shmem.len()) };
            let read_buffer = [0u8; 4096];
            let mut msgqueue = MsgQueue::new(bytequeue, b"TT", read_buffer);
            while !shutdown_signal_inner.load(std::sync::atomic::Ordering::Relaxed) {
                match msgqueue.read_or_fail() {
                    Ok(msg) => {
                        // A lone zero byte is the writer's end-of-trace marker.
                        if msg.len() == 1 && msg[0] == 0 {
                            break;
                        }
                        // Validate instead of assuming UTF-8. The bytes come from a named
                        // shared-memory segment this process does not exclusively control.
                        let line = String::from_utf8(msg.to_vec())?;
                        // The receiver was dropped or closed (by `shutdown`), so nobody wants
                        // further lines. Stop quietly; this is not an error.
                        if trace_tx.blocking_send(line).is_err() {
                            break;
                        }
                    }
                    Err(MqError::MqEmpty) => {
                        std::thread::sleep(std::time::Duration::from_millis(1));
                    }
                    Err(other) => {
                        return Err(other.into());
                    }
                }
            }
            anyhow::Result::<()>::Ok(())
        });
        Ok(TraceClient {
            trace_rx,
            read_join_handle: Some(read_join_handle),
            shutdown_signal,
        })
    }

    /// Waits for the next trace line.
    ///
    /// `Ok(None)` means the reader has stopped and no buffered lines remain. It may have stopped
    /// at the end-of-trace marker, on an error, or because of shutdown.
    ///
    /// # Errors
    /// Never returns `Err` in the current implementation.
    pub async fn next_line(&mut self) -> Result<Option<String>> {
        Ok(self.trace_rx.recv().await)
    }

    /// Stops the background reader and waits for it, logging any error it reported.
    ///
    /// # Panics
    /// Re-raises a panic from the background reader.
    pub async fn shutdown(mut self) {
        self.shutdown_signal
            .store(true, std::sync::atomic::Ordering::Relaxed);
        // Closing the channel wakes a reader blocked in `blocking_send` on a full channel.
        // Without it, shutting down before draining `next_line` would wait forever.
        self.trace_rx.close();
        if let Some(read_join_handle) = self.read_join_handle.take() {
            match read_join_handle.await {
                Ok(Ok(())) => {}
                Ok(Err(err)) => warn!("failure in TraceClient spawn process: {err:?}"),
                Err(join_err) if join_err.is_panic() => {
                    std::panic::resume_unwind(join_err.into_panic())
                }
                Err(join_err) => warn!("TraceClient reader task was cancelled: {join_err:?}"),
            }
        }
    }
}
impl Drop for TraceClient {
fn drop(&mut self) {
self.shutdown_signal
.store(true, std::sync::atomic::Ordering::Relaxed);
}
}
#[cfg(test)]
mod tests {
    use std::{path::PathBuf, time::Duration};

    use crate::sys_trace::strace::{
        ReceptionistFacade as _,
        shmem::{Receptionist, TraceClient},
    };
    use anyhow::Result;
    use tokio::time::timeout;
    use uuid::Uuid;

    /// How long a test waits for the next line before failing. This keeps a broken pipeline
    /// from hanging the whole test run, while leaving headroom for slow CI machines.
    const NEXT_LINE_TIMEOUT: Duration = Duration::from_secs(5);

    /// A socket path unique to this test, relative to the current directory.
    fn unique_address() -> PathBuf {
        PathBuf::from(format!("receptionist.sock-{}", Uuid::new_v4()))
    }

    /// `client.next_line()`, but panics instead of hanging if nothing arrives in time.
    async fn recv_line(client: &mut TraceClient) -> Result<Option<String>> {
        timeout(NEXT_LINE_TIMEOUT, client.next_line())
            .await
            .expect("timeout on client.next_line()")
    }

    #[tokio::test]
    async fn receptionist_basic_read_write() -> Result<()> {
        let address = unique_address();
        let receptionist = Receptionist::startup(address.clone())?;
        let mut client = TraceClient::try_create(&address, 100).await?;
        receptionist.peek_trace("100 close(3) = 0").await;
        assert_eq!(
            Some(String::from("100 close(3) = 0")),
            recv_line(&mut client).await?
        );
        receptionist.peek_trace("100 close(5) = 0").await;
        assert_eq!(
            Some(String::from("100 close(5) = 0")),
            recv_line(&mut client).await?
        );
        client.shutdown().await;
        receptionist.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn receptionist_filter_pid() -> Result<()> {
        let address = unique_address();
        let receptionist = Receptionist::startup(address.clone())?;
        let mut client = TraceClient::try_create(&address, 100).await?;
        receptionist.peek_trace("101 close(3) = 0").await;
        receptionist.peek_trace("100 close(3) = 0").await;
        assert_eq!(
            Some(String::from("100 close(3) = 0")),
            recv_line(&mut client).await?
        );
        client.shutdown().await;
        receptionist.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn receptionist_filter_subprocess() -> Result<()> {
        let address = unique_address();
        let receptionist = Receptionist::startup(address.clone())?;
        let mut client = TraceClient::try_create(&address, 100).await?;
        receptionist.add_subprocess(100, 103);
        receptionist.peek_trace("101 close(3) = 0").await;
        receptionist.peek_trace("100 close(3) = 0").await;
        assert_eq!(
            Some(String::from("100 close(3) = 0")),
            recv_line(&mut client).await?
        );
        receptionist.peek_trace("103 close(5) = 0").await;
        assert_eq!(
            Some(String::from("103 close(5) = 0")),
            recv_line(&mut client).await?
        );
        client.shutdown().await;
        receptionist.shutdown().await;
        Ok(())
    }

    #[tokio::test]
    async fn receptionist_auto_shutdown() -> Result<()> {
        let address = unique_address();
        let receptionist = Receptionist::startup(address.clone())?;
        let mut client = TraceClient::try_create(&address, 100).await?;
        receptionist.peek_trace("100 close(3) = 0").await;
        assert_eq!(
            Some(String::from("100 close(3) = 0")),
            recv_line(&mut client).await?
        );
        // Removing the only followed pid shuts the subscriber down, which ends the stream.
        receptionist.remove_process(100);
        assert_eq!(None, recv_line(&mut client).await?);
        client.shutdown().await;
        receptionist.shutdown().await;
        Ok(())
    }
}