use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::Duration,
};
use objects::object::OperationId;
use prost_types::Timestamp;
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, mpsc, oneshot};
use grpc::heddle::v1::HookEvent as ProtoHookEvent;
/// Number of events the broadcast channel retains for subscribers that have not
/// received them yet; a subscriber that falls further behind observes a lag.
const BROADCAST_CAPACITY: usize = 256;
/// Reply a hook subscriber sends back for an emitted event; it is routed to the
/// waiting emitter by `HookEventBroadcaster::deliver_response`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HookResponse {
    /// Abort value; an empty string when the field is absent from the input.
    // NOTE(review): presumably a non-empty value asks the emitter to abort — confirm with callers.
    #[serde(default)]
    pub abort: String,
    /// All other fields of the serialized object, collected via `#[serde(flatten)]`;
    /// `serde_json::Value::Null` by default.
    #[serde(flatten, default)]
    pub extra: serde_json::Value,
}
/// A registered wait for one reply, stored in the pending map under its hook
/// event id; `deliver_response` removes it and sends the reply through `sender`.
struct ResponseSlot {
    /// Sending half of the oneshot whose receiver is held by the waiting caller.
    sender: oneshot::Sender<HookResponse>,
}
/// Fans hook events out to subscribers and routes subscriber replies back to
/// the code waiting on them.
///
/// Cloning is cheap: every clone shares the same broadcast channel and
/// pending-reply map through an `Arc`.
#[derive(Clone)]
pub struct HookEventBroadcaster {
    inner: Arc<HookEventBroadcasterInner>,
}
/// Shared state behind every clone of `HookEventBroadcaster`.
struct HookEventBroadcasterInner {
    /// Fan-out source for emitted events; each `subscribe` call adds a receiver.
    sender: broadcast::Sender<ProtoHookEvent>,
    /// Reply slots awaiting a response, keyed by hook event id.
    pending: Mutex<HashMap<String, ResponseSlot>>,
}
impl Default for HookEventBroadcaster {
fn default() -> Self {
Self::new()
}
}
impl HookEventBroadcaster {
pub fn new() -> Self {
let (sender, _) = broadcast::channel(BROADCAST_CAPACITY);
Self {
inner: Arc::new(HookEventBroadcasterInner {
sender,
pending: Mutex::new(HashMap::new()),
}),
}
}
pub fn subscribe(&self) -> mpsc::Receiver<ProtoHookEvent> {
let mut rx = self.inner.sender.subscribe();
let (tx, out_rx) = mpsc::channel(16);
tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(event) => {
if tx.send(event).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {
break;
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
});
out_rx
}
pub fn emit(&self, event_name: impl Into<String>, payload_json: impl Into<String>) -> String {
let hook_event_id = OperationId::new().to_string();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
let event = ProtoHookEvent {
hook_event_id: hook_event_id.clone(),
event_name: event_name.into(),
payload_json: payload_json.into(),
emitted_at: Some(Timestamp {
seconds: now.as_secs() as i64,
nanos: now.subsec_nanos() as i32,
}),
};
let _ = self.inner.sender.send(event);
hook_event_id
}
pub fn emit_and_wait(
&self,
event_name: impl Into<String>,
payload_json: impl Into<String>,
timeout: Duration,
) -> (String, EmitWaiter) {
let (sender, receiver) = oneshot::channel();
let event_name = event_name.into();
let payload_json = payload_json.into();
let hook_event_id = OperationId::new().to_string();
{
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.insert(hook_event_id.clone(), ResponseSlot { sender });
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
let event = ProtoHookEvent {
hook_event_id: hook_event_id.clone(),
event_name,
payload_json,
emitted_at: Some(Timestamp {
seconds: now.as_secs() as i64,
nanos: now.subsec_nanos() as i32,
}),
};
let _ = self.inner.sender.send(event);
let waiter = EmitWaiter {
broker: self.clone(),
hook_event_id: hook_event_id.clone(),
receiver,
timeout,
};
(hook_event_id, waiter)
}
pub async fn await_response(
&self,
hook_event_id: &str,
timeout: Duration,
) -> Option<HookResponse> {
let receiver = {
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.remove(hook_event_id).map(|slot| slot.sender)
};
let receiver = match receiver {
Some(_already_taken) => {
let (sender, receiver) = oneshot::channel();
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.insert(hook_event_id.to_string(), ResponseSlot { sender });
receiver
}
None => {
let (sender, receiver) = oneshot::channel();
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.insert(hook_event_id.to_string(), ResponseSlot { sender });
receiver
}
};
match tokio::time::timeout(timeout, receiver).await {
Ok(Ok(response)) => Some(response),
Ok(Err(_canceled)) => None,
Err(_elapsed) => {
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.remove(hook_event_id);
None
}
}
}
pub fn deliver_response(&self, hook_event_id: &str, response: HookResponse) -> bool {
let slot = {
let mut pending = self
.inner
.pending
.lock()
.expect("hook broker pending map poisoned");
pending.remove(hook_event_id)
};
match slot {
Some(slot) => slot.sender.send(response).is_ok(),
None => false,
}
}
#[cfg(test)]
fn subscriber_count(&self) -> usize {
self.inner.sender.receiver_count()
}
}
/// Handle returned by `HookEventBroadcaster::emit_and_wait` for awaiting the
/// single reply to the emitted event.
pub struct EmitWaiter {
    /// Broadcaster whose pending map holds this waiter's reply slot.
    broker: HookEventBroadcaster,
    /// Id of the emitted event; also the key of the reply slot in the pending map.
    hook_event_id: String,
    /// Receiving half of the oneshot whose sender is stored in the pending map.
    receiver: oneshot::Receiver<HookResponse>,
    /// How long `wait` waits for the reply.
    timeout: Duration,
}
impl EmitWaiter {
    /// Waits, up to the timeout given to `emit_and_wait`, for the reply to this event.
    ///
    /// Returns `Some(response)` if a reply arrives in time, including one that
    /// lands exactly at the deadline. Returns `None` in two cases:
    /// - the timeout elapses;
    /// - the reply slot is dropped without a reply. This happens when
    ///   `HookEventBroadcaster::await_response` re-registers the same id.
    ///
    /// The pending-map entry is cleaned up when the waiter is dropped, so
    /// cancelling this future part-way through does not leak it.
    pub async fn wait(mut self) -> Option<HookResponse> {
        let outcome = tokio::time::timeout(self.timeout, &mut self.receiver).await;
        match outcome {
            Ok(Ok(response)) => Some(response),
            // Our slot was replaced by another registration for the same id.
            // The current entry is not ours, so `Drop` leaves it alone.
            Ok(Err(_)) => None,
            Err(_elapsed) => {
                // After `close`, a concurrent `deliver_response` either fails
                // or has already stored its value, which `try_recv` picks up.
                self.receiver.close();
                self.receiver.try_recv().ok()
            }
        }
    }
}

impl Drop for EmitWaiter {
    /// Removes this waiter's reply slot from the pending map if it is still there.
    ///
    /// This runs on every exit path, including when a `wait` future is
    /// cancelled or the waiter is dropped without ever being awaited.
    fn drop(&mut self) {
        // Closing our receiver makes our slot's sender report `is_closed()`.
        // That is how our slot is told apart from a newer registration for the
        // same id, whose receiver is still open.
        self.receiver.close();
        // Never panic inside `drop`: recover the guard if the mutex is poisoned.
        let mut pending = self
            .broker
            .inner
            .pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if matches!(pending.get(&self.hook_event_id), Some(slot) if slot.sender.is_closed()) {
            pending.remove(&self.hook_event_id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain `emit` reaches an existing subscriber with id, name and payload intact.
    #[tokio::test]
    async fn emit_round_trips_to_subscriber() {
        let broker = HookEventBroadcaster::new();
        let mut sub = broker.subscribe();
        tokio::task::yield_now().await;
        let id = broker.emit("pre_capture", "{\"thread\":\"t1\"}");
        let event = sub.recv().await.expect("event");
        assert_eq!(event.hook_event_id, id);
        assert_eq!(event.event_name, "pre_capture");
        assert!(event.payload_json.contains("t1"));
    }

    /// `emit_and_wait` + `EmitWaiter::wait` resolves with the delivered reply.
    #[tokio::test]
    async fn emit_and_wait_returns_delivered_reply() {
        let broker = HookEventBroadcaster::new();
        let _sub = broker.subscribe();
        tokio::task::yield_now().await;
        let (id, waiter) = broker.emit_and_wait("pre_capture", "{}", Duration::from_secs(1));
        let id_for_reply = id.clone();
        let broker_clone = broker.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            let _ = broker_clone.deliver_response(
                &id_for_reply,
                HookResponse {
                    abort: "veto".into(),
                    extra: serde_json::Value::Null,
                },
            );
        });
        let response = waiter.wait().await.expect("response");
        assert_eq!(response.abort, "veto");
    }

    /// A waiter with no reply times out and leaves no pending entry behind.
    #[tokio::test]
    async fn emit_and_wait_times_out_with_no_reply() {
        let broker = HookEventBroadcaster::new();
        let _sub = broker.subscribe();
        let (_id, waiter) = broker.emit_and_wait("pre_capture", "{}", Duration::from_millis(20));
        let response = waiter.wait().await;
        assert!(response.is_none());
        assert!(broker
            .inner
            .pending
            .lock()
            .expect("pending map")
            .is_empty());
    }

    /// `await_response` on an id from plain `emit` receives a later delivery.
    #[tokio::test]
    async fn await_response_receives_reply_for_emitted_event() {
        let broker = HookEventBroadcaster::new();
        let _sub = broker.subscribe();
        let id = broker.emit("post_capture", "{}");
        let id_for_reply = id.clone();
        let broker_clone = broker.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            let _ = broker_clone.deliver_response(
                &id_for_reply,
                HookResponse {
                    abort: "stop".into(),
                    extra: serde_json::Value::Null,
                },
            );
        });
        let response = broker
            .await_response(&id, Duration::from_secs(1))
            .await
            .expect("response");
        assert_eq!(response.abort, "stop");
    }

    /// Delivering to an id nobody registered is rejected.
    #[tokio::test]
    async fn deliver_to_unknown_id_returns_false() {
        let broker = HookEventBroadcaster::new();
        let accepted = broker.deliver_response("no-such-id", HookResponse::default());
        assert!(!accepted);
    }

    /// Each subscriber gets its own copy of every event.
    #[tokio::test]
    async fn subscribers_are_independent() {
        let broker = HookEventBroadcaster::new();
        let mut a = broker.subscribe();
        let mut b = broker.subscribe();
        tokio::task::yield_now().await;
        assert_eq!(broker.subscriber_count(), 2);
        broker.emit("post_capture", "{}");
        let event_a = a.recv().await.expect("a");
        let event_b = b.recv().await.expect("b");
        assert_eq!(event_a.hook_event_id, event_b.hook_event_id);
    }
}