use crate::sinks::MetricSink;
use std::fs;
use std::io::{self, ErrorKind};
use std::os::unix::net::UnixDatagram;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
use std::{env, thread};
/// A directory under the system temp dir that is removed, together with
/// everything inside it, when this value is dropped.
#[derive(Debug)]
pub struct TempDir {
    base: PathBuf,
}

impl TempDir {
    /// Creates (if needed) `env::temp_dir()/prefix` and takes ownership of it.
    ///
    /// An already-existing directory is reused, and its existing contents are
    /// deleted along with it on drop. `prefix` is appended with path-join
    /// semantics, so an absolute `prefix` replaces the temp dir entirely.
    ///
    /// # Errors
    ///
    /// Returns any error reported while creating the directory.
    pub fn new<P>(prefix: P) -> io::Result<Self>
    where
        P: AsRef<Path>,
    {
        let mut base = env::temp_dir();
        base.push(prefix);
        fs::create_dir_all(&base).map(|()| TempDir { base })
    }

    /// Returns the path of `name` inside this directory. Nothing is created.
    pub fn new_path<P>(&self, name: P) -> PathBuf
    where
        P: AsRef<Path>,
    {
        let mut path = self.base.clone();
        path.push(name);
        path
    }
}

impl Drop for TempDir {
    fn drop(&mut self) {
        // Cleanup is best effort: a failure here must not panic inside `drop`.
        let _cleanup = fs::remove_dir_all(&self.base);
    }
}
/// Receiver for the text of datagrams read by [`UnixSocketServer`].
pub trait DatagramConsumer {
    /// Handles one received datagram.
    ///
    /// [`UnixSocketServer`] calls this only for datagrams that decode as valid UTF-8.
    fn accept(&self, datagram: String);
}

/// Any `Fn(String)` closure or function can be used directly as a consumer.
impl<F> DatagramConsumer for F
where
    F: Fn(String),
{
    fn accept(&self, datagram: String) {
        self(datagram)
    }
}
/// A minimal Unix datagram server for tests.
///
/// [`run`](Self::run) binds a socket at `path` and passes the text of every
/// datagram it receives to `consumer` until [`shutdown`](Self::shutdown) is requested.
pub struct UnixSocketServer {
    ready: AtomicBool,
    shutdown: AtomicBool,
    path: PathBuf,
    consumer: Arc<dyn DatagramConsumer + Send + Sync + 'static>,
    interval: Duration,
}

impl UnixSocketServer {
    /// Creates a server for the socket at `path`. Nothing is bound until `run` is called.
    ///
    /// `interval` becomes the socket's read timeout: the longest `run` blocks waiting
    /// for a datagram before re-checking whether shutdown has been requested.
    pub fn new<P, C>(path: P, interval: Duration, consumer: C) -> Self
    where
        P: AsRef<Path>,
        C: DatagramConsumer + Send + Sync + 'static,
    {
        UnixSocketServer {
            ready: AtomicBool::new(false),
            shutdown: AtomicBool::new(false),
            path: path.as_ref().to_path_buf(),
            consumer: Arc::new(consumer),
            interval,
        }
    }

    /// Returns `true` once `run` has bound the socket and is about to start receiving.
    pub fn is_ready(&self) -> bool {
        self.ready.load(Ordering::Acquire)
    }

    /// Binds the socket and receives datagrams until shutdown is requested.
    ///
    /// Any existing file at `path` is removed first. Each datagram is read into a
    /// 1024-byte buffer, so longer datagrams are truncated. It is then decoded as
    /// UTF-8 and handed to the consumer. Undecodable datagrams are reported on stderr
    /// and skipped.
    ///
    /// After [`shutdown`](Self::shutdown) the server keeps draining queued datagrams.
    /// It returns at the next read timeout or receive error.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be bound or its read timeout cannot be set.
    pub fn run(&self) -> io::Result<()> {
        // Best effort: a stale socket file left at `path` would make `bind` fail.
        let _ = fs::remove_file(&self.path);
        let socket = UnixDatagram::bind(&self.path)?;
        socket.set_read_timeout(Some(self.interval))?;
        let mut buf = [0u8; 1024];
        self.ready.store(true, Ordering::Release);
        loop {
            match socket.recv(&mut buf) {
                Ok(len) => match std::str::from_utf8(&buf[..len]) {
                    Ok(s) => self.consumer.accept(s.to_owned()),
                    Err(e) => eprintln!("Error: Couldn't decode string to utf-8 {}", e),
                },
                // A signal interrupted the call before anything was read; just retry.
                Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
                // The read timeout expired with nothing queued. Unix normally reports
                // this as WouldBlock, but std allows TimedOut on some platforms.
                Err(ref e)
                    if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut =>
                {
                    if self.shutdown.load(Ordering::Acquire) {
                        break;
                    }
                }
                Err(e) => {
                    eprintln!("Error: {} - {:?}", e, e.kind());
                    // Also honour shutdown here. Otherwise a persistent receive error
                    // would spin forever and the server thread could never be joined.
                    if self.shutdown.load(Ordering::Acquire) {
                        break;
                    }
                }
            }
        }
        Ok(())
    }

    /// Asks a running `run` loop to return once no more datagrams are queued.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Release);
    }
}
pub struct UnixServerHarness {
base: PathBuf,
server: Option<Arc<UnixSocketServer>>,
thread: Option<JoinHandle<()>>,
}
impl UnixServerHarness {
pub fn new<P>(prefix: P) -> Self
where
P: AsRef<Path>,
{
UnixServerHarness {
base: prefix.as_ref().to_path_buf(),
server: None,
thread: None,
}
}
pub fn run<C, F>(mut self, consumer: C, body: F)
where
C: DatagramConsumer + Send + Sync + 'static,
F: FnOnce(&Path),
{
let temp = TempDir::new(&self.base).unwrap();
let socket = temp.new_path("cadence.sock");
let server = Arc::new(UnixSocketServer::new(&socket, Duration::from_millis(100), consumer));
let server_local = server.clone();
let t = thread::spawn(move || {
server_local.run().unwrap();
});
while !server.is_ready() {
thread::yield_now();
}
self.server = Some(server);
self.thread = Some(t);
body(&socket);
}
pub fn run_quiet<F>(self, body: F)
where
F: FnOnce(&Path),
{
self.run(|_| (), body)
}
}
impl Drop for UnixServerHarness {
fn drop(&mut self) {
if let Some(s) = self.server.take() {
s.shutdown();
}
if let Some(t) = self.thread.take() {
let _ = t.join();
}
}
}
/// A thread-safe "every n-th call" gate used by the fault-injecting sinks.
///
/// Calls are counted from 1, so [`allow`](Every::allow) returns `true` on calls
/// `modulo`, `2 * modulo`, ... and `false` on all others.
struct Every {
    modulo: u64,
    counter: AtomicU64,
}

impl Every {
    /// Creates a gate that opens on every `modulo`-th call.
    ///
    /// # Panics
    ///
    /// Panics if `modulo` is zero.
    fn new(modulo: u64) -> Self {
        assert_ne!(modulo, 0, "modulo must be >= 1");
        let counter = AtomicU64::new(1);
        Every { modulo, counter }
    }

    /// Records one call and reports whether its 1-based number is a multiple of `modulo`.
    fn allow(&self) -> bool {
        let call_number = self.counter.fetch_add(1, Ordering::SeqCst);
        call_number % self.modulo == 0
    }
}
/// A [`MetricSink`] that panics on every n-th call to `emit`.
///
/// On all other calls it sends nothing and reports the whole metric as written.
pub struct PanickingMetricSink {
    every: Every,
}

impl PanickingMetricSink {
    /// Panics on the `every`-th, `2 * every`-th, ... call to `emit` (counting from 1).
    ///
    /// # Panics
    ///
    /// Panics if `every` is zero.
    pub fn every(every: u64) -> Self {
        let gate = Every::new(every);
        PanickingMetricSink { every: gate }
    }

    /// Panics on every call to `emit`.
    pub fn always() -> Self {
        PanickingMetricSink::every(1)
    }
}

impl MetricSink for PanickingMetricSink {
    fn emit(&self, m: &str) -> io::Result<usize> {
        if !self.every.allow() {
            return Ok(m.len());
        }
        panic!("This sink is supposed to panic");
    }
}
/// A [`MetricSink`] that fails every n-th call to `emit` with a `TimedOut` I/O error.
///
/// On all other calls it sends nothing and reports the whole metric as written.
pub struct ErrorMetricSink {
    every: Every,
}

impl ErrorMetricSink {
    /// Fails the `every`-th, `2 * every`-th, ... call to `emit` (counting from 1).
    ///
    /// # Panics
    ///
    /// Panics if `every` is zero.
    pub fn every(every: u64) -> Self {
        let gate = Every::new(every);
        ErrorMetricSink { every: gate }
    }

    /// Fails every call to `emit`.
    pub fn always() -> Self {
        ErrorMetricSink::every(1)
    }
}

impl MetricSink for ErrorMetricSink {
    fn emit(&self, m: &str) -> io::Result<usize> {
        match self.every.allow() {
            true => Err(io::ErrorKind::TimedOut.into()),
            false => Ok(m.len()),
        }
    }
}