use std::{
borrow::Cow, collections::HashMap, convert::Infallible, sync::Arc
};
use axum::{http::response, response::sse::{Event, KeepAlive, Sse}};
use futures::stream::{Stream, StreamExt};
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
/// Compact integer handle for a topic name.
///
/// `ChannelInner::intern_topic` hands these out sequentially, so subscription
/// bookkeeping can key on a `u32` instead of the topic string.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct TopicId(u32);
use crate::{Site, callables, signals};
/// Failure reported by a [`ChannelService`] when a request could not be submitted.
///
/// The variant names the kind of operation that failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelError {
    /// A subscribe request failed.
    SubscribeError,
    /// A publish request failed.
    PublishError,
    /// An unsubscribe request failed.
    UnsubscribeError,
}

impl std::fmt::Display for ChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::SubscribeError => "failed to subscribe to channel",
            Self::PublishError => "failed to publish to channel",
            Self::UnsubscribeError => "failed to unsubscribe from channel",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for ChannelError {}
#[derive(Debug)]
pub enum Unsubscribe {
User(Cow<'static, str>),
Channel(Cow<'static, str>),
Subscription(uuid::Uuid),
}
/// Handle describing a freshly created subscription.
#[derive(Debug)]
pub struct ChannelSubscription {
    /// Identifier of the subscription inside the channel service.
    pub id: uuid::Uuid,
    /// URL of the subscription resource. In `LocalChannelService` this is built
    /// by `make_absolute_url` from the service's base URL, channel key and id.
    pub url: String,
}
pub trait ChannelService {
fn send(
&self,
req: ChannelRequest,
) -> impl Future<Output = Result<(), ChannelError>>;
}
/// Operation carried by a [`ChannelRequest`].
#[derive(Debug)]
pub enum ChannelMessage {
    /// Send `message` on `topic`. When `user_keys` is `Some`, only
    /// subscriptions whose user key is listed receive it.
    Publish {
        topic: String,
        message: String,
        user_keys: Option<Vec<String>>,
    },
    /// Register a subscription for `user_key` under `channel_key`, listening
    /// on every topic in `topics`.
    Subscribe {
        user_key: String,
        channel_key: String,
        topics: Vec<String>,
    },
    /// Remove the subscriptions selected by `unsub`.
    Unsubscribe { unsub: Unsubscribe },
}
/// Reply sent on [`ChannelRequest::reply`], one variant per [`ChannelMessage`] kind.
#[derive(Debug)]
pub enum ChannelResponse {
    /// Number of subscriptions the published event was handed to.
    Publish(usize),
    /// The new subscription plus the receiving end of its event queue.
    Subscription((ChannelSubscription, mpsc::Receiver<Arc<ChannelEvent>>)),
    /// Number of subscriptions removed.
    Unsubscribe(usize),
}
/// Event fanned out to subscribers; a single instance is shared between
/// subscriber queues through an `Arc`.
#[derive(Debug)]
pub struct ChannelEvent {
    /// Interned id of `topic` (when built by `ChannelInner::publish`).
    topic_id: TopicId,
    /// When set, restricts delivery to subscriptions whose user key is listed.
    user_keys: Option<Vec<String>>,
    /// Topic name the event was published on.
    pub topic: String,
    /// Event payload.
    pub message: String,
}
/// A message for the channel service together with the one-shot channel that
/// receives its response.
#[derive(Debug)]
pub struct ChannelRequest {
    /// Operation to perform.
    pub message: ChannelMessage,
    /// Where the [`ChannelResponse`] for `message` is sent.
    pub reply: oneshot::Sender<ChannelResponse>,
}
/// Bookkeeping for one live subscription held by `ChannelInner`.
#[derive(Debug)]
struct LocalSubscription {
    /// Subscription id; also the key in `ChannelInner::subscriptions`.
    id: uuid::Uuid,
    /// Key in `ChannelInner::user_index`; matched against publish-time `user_keys`.
    user_key: String,
    /// Key in `ChannelInner::channel_index`.
    channel_key: String,
    /// Interned topics this subscription listens on.
    topics: Vec<TopicId>,
    /// Sending half of the subscriber's bounded event queue.
    sender: mpsc::Sender<Arc<ChannelEvent>>,
}
/// In-process channel service handle.
///
/// Requests are pushed onto `sender`; presumably its receiver is the one
/// passed to `LocalChannelService::run` (confirm at the construction site).
pub struct LocalChannelService {
    /// Buffer size of each subscriber's event queue.
    queue_size: usize,
    /// Prefix used when building subscription URLs.
    base_url: String,
    /// Sending half of the request queue.
    sender: mpsc::Sender<ChannelRequest>,
}
impl ChannelService for LocalChannelService {
async fn send(
&self,
req: ChannelRequest,
) -> Result<(), ChannelError> {
self.sender
.send(req)
.await
.map_err(|_| ChannelError::PublishError)?;
Ok(())
}
}
impl LocalChannelService {
pub fn new(sender: mpsc::Sender<ChannelRequest>, base_url: String, queue_size: usize) -> Self {
Self {
queue_size,
base_url,
sender,
}
}
pub fn make_absolute_url(&self, sub_id: uuid::Uuid, channel_key: &str) -> String {
format!(
"{}{}",
self.base_url,
format!("/channels/{}/subscriptions/{}", channel_key, sub_id)
)
}
fn handle_request(
&self,
inner: &mut ChannelInner,
request: ChannelRequest,
) -> Result<(), ChannelResponse> {
match request.message {
ChannelMessage::Publish { topic, message, user_keys } => {
let delivered = inner.publish(topic, message, user_keys);
request.reply.send(ChannelResponse::Publish(delivered))
}
ChannelMessage::Subscribe {
user_key,
channel_key,
topics,
} => {
let (sub_id, receiver) =
inner.subscribe(user_key.clone(), channel_key.clone(), topics);
let subscription = ChannelSubscription {
id: sub_id,
url: self.make_absolute_url(sub_id, &channel_key),
};
request
.reply
.send(ChannelResponse::Subscription((subscription, receiver)))
}
ChannelMessage::Unsubscribe { unsub } => match unsub {
Unsubscribe::Subscription(sub_id) => {
let success = inner.unsubscribe(sub_id);
request.reply.send(ChannelResponse::Unsubscribe(success))
}
Unsubscribe::User(user_key) => {
let count = inner.unsubscribe_by_user(&user_key);
request.reply.send(ChannelResponse::Unsubscribe(count))
}
Unsubscribe::Channel(channel_key) => {
let count = inner.unsubscribe_by_channel(&channel_key);
request.reply.send(ChannelResponse::Unsubscribe(count))
}
},
}
}
pub async fn run(self: Arc<Self>, mut rx: mpsc::Receiver<ChannelRequest>, site: Site) {
let capacity = self.queue_size;
let mut inner = ChannelInner::new(capacity);
let shutdown = site.shutdown_notifier();
let mut batch = Vec::with_capacity(32);
loop {
tokio::select! {
Some(request) = rx.recv() => {
batch.push(request);
while let Ok(req) = rx.try_recv() {
batch.push(req);
if batch.len() >= 32 {
break;
}
}
for request in batch.drain(..) {
if let Err(response) = self.handle_request(&mut inner, request) {
tracing::error!("Failed to handle channel request: {:?}", response);
}
}
},
_ = shutdown.notified() => {
tracing::info!("Channel service shutting down due to site shutdown");
break;
},
else => {
tracing::info!("Channel service receiver closed, shutting down");
break;
}
}
}
}
}
/// Subscription state owned by the service loop (`LocalChannelService::run`
/// creates one and mutates it while handling requests).
struct ChannelInner {
    /// Buffer size for each new subscriber event queue.
    capacity: usize,
    /// Every live subscription, by id.
    subscriptions: HashMap<uuid::Uuid, LocalSubscription>,
    /// Subscription ids listening on each interned topic.
    topic_map: HashMap<TopicId, Vec<uuid::Uuid>>,
    /// Topic name to interned id.
    topic_interner: HashMap<String, TopicId>,
    /// Subscription ids per user key.
    user_index: HashMap<String, Vec<uuid::Uuid>>,
    /// Subscription ids per channel key.
    channel_index: HashMap<String, Vec<uuid::Uuid>>,
    /// Scratch list of subscriptions found closed while publishing.
    dead_subscriptions: Vec<uuid::Uuid>,
    /// Scratch list reused by the bulk unsubscribe operations.
    batch_buffer: Vec<uuid::Uuid>,
    /// Next id handed out by `intern_topic`.
    next_topic_id: u32,
}
impl ChannelInner {
fn new(capacity: usize) -> Self {
Self {
capacity,
topic_map: HashMap::new(),
topic_interner: HashMap::new(),
user_index: HashMap::new(),
channel_index: HashMap::new(),
subscriptions: HashMap::new(),
dead_subscriptions: Vec::with_capacity(16),
batch_buffer: Vec::with_capacity(32),
next_topic_id: 0,
}
}
fn intern_topic(&mut self, topic: &str) -> TopicId {
if let Some(id) = self.topic_interner.get(topic) {
return *id;
}
let id = TopicId(self.next_topic_id);
self.next_topic_id += 1;
self.topic_interner.insert(topic.to_string(), id);
id
}
fn publish(&mut self, topic: String, message: String, user_keys: Option<Vec<String>>) -> usize {
let topic_id = self.intern_topic(&topic);
let event = Arc::new(ChannelEvent {
topic_id,
topic,
message,
user_keys,
});
let mut delivered = 0;
if let Some(sub_ids) = self.topic_map.get(&topic_id) {
for &sub_id in sub_ids {
let sub = match self.subscriptions.get(&sub_id) {
Some(sub) => {
if let Some(ref keys) = event.user_keys {
if keys.iter().all(|u| u != &sub.user_key){
continue;
}
}
sub
}
None => continue,
};
if let Err(err) = sub.sender.try_send(event.clone()) {
match err {
mpsc::error::TrySendError::Full(_) => {
tracing::warn!("Subscriber {} channel full, dropping event", sub.id);
continue;
}
mpsc::error::TrySendError::Closed(_) => {
self.dead_subscriptions.push(sub.id);
}
}
} else {
delivered += 1;
}
}
}
if !self.dead_subscriptions.is_empty() {
Self::remove_subscriptions_from(
&mut self.subscriptions,
&mut self.topic_map,
&mut self.user_index,
&mut self.channel_index,
&mut self.dead_subscriptions,
);
}
delivered
}
fn subscribe(
&mut self,
user_key: String,
channel_key: String,
topics: Vec<String>,
) -> (uuid::Uuid, mpsc::Receiver<Arc<ChannelEvent>>) {
let (tx, rx) = mpsc::channel::<Arc<ChannelEvent>>(self.capacity);
let sub_id = uuid::Uuid::now_v7();
let topic_ids: Vec<TopicId> = topics.iter().map(|t| self.intern_topic(t)).collect();
let subscription = LocalSubscription {
id: sub_id,
user_key: user_key.clone(),
channel_key: channel_key.clone(),
sender: tx,
topics: topic_ids.clone(),
};
for &topic_id in &topic_ids {
self.topic_map
.entry(topic_id)
.or_insert_with(|| Vec::with_capacity(4))
.push(sub_id);
}
self.user_index
.entry(user_key)
.or_insert_with(|| Vec::with_capacity(4))
.push(sub_id);
self.channel_index
.entry(channel_key)
.or_insert_with(|| Vec::with_capacity(4))
.push(sub_id);
self.subscriptions.insert(sub_id, subscription);
(sub_id, rx)
}
fn unsubscribe(&mut self, sub_id: uuid::Uuid) -> usize {
if let Some(sub) = self.subscriptions.remove(&sub_id) {
Self::remove_from_topics(&mut self.topic_map, sub_id, &sub.topics);
Self::remove_from_index(&mut self.user_index, &sub.user_key, sub_id);
Self::remove_from_index(&mut self.channel_index, &sub.channel_key, sub_id);
1
} else {
0
}
}
fn unsubscribe_by_user(&mut self, user_key: &str) -> usize {
self.batch_buffer.clear();
if let Some(ids) = self.user_index.remove(user_key) {
self.batch_buffer.extend(ids);
}
Self::remove_subscriptions_from(
&mut self.subscriptions,
&mut self.topic_map,
&mut self.user_index,
&mut self.channel_index,
&mut self.batch_buffer,
)
}
fn unsubscribe_by_channel(&mut self, channel_key: &str) -> usize {
self.batch_buffer.clear();
if let Some(ids) = self.channel_index.remove(channel_key) {
self.batch_buffer.extend(ids);
}
Self::remove_subscriptions_from(
&mut self.subscriptions,
&mut self.topic_map,
&mut self.user_index,
&mut self.channel_index,
&mut self.batch_buffer,
)
}
fn remove_from_topics(
topic_map: &mut HashMap<TopicId, Vec<uuid::Uuid>>,
sub_id: uuid::Uuid,
topics: &[TopicId],
) {
for &topic_id in topics {
if let Some(vec) = topic_map.get_mut(&topic_id) {
if let Some(pos) = vec.iter().position(|&id| id == sub_id) {
vec.swap_remove(pos);
}
if vec.is_empty() {
topic_map.remove(&topic_id);
}
}
}
}
fn remove_from_index(
index: &mut HashMap<String, Vec<uuid::Uuid>>,
key: &str,
sub_id: uuid::Uuid,
) {
if let Some(vec) = index.get_mut(key) {
if let Some(pos) = vec.iter().position(|&id| id == sub_id) {
vec.swap_remove(pos);
}
if vec.is_empty() {
index.remove(key);
}
}
}
fn remove_subscriptions_from(
subscriptions: &mut HashMap<uuid::Uuid, LocalSubscription>,
topic_map: &mut HashMap<TopicId, Vec<uuid::Uuid>>,
user_index: &mut HashMap<String, Vec<uuid::Uuid>>,
channel_index: &mut HashMap<String, Vec<uuid::Uuid>>,
buffer: &mut Vec<uuid::Uuid>,
) -> usize {
let count = buffer.len();
for sub_id in buffer.drain(..) {
if let Some(sub) = subscriptions.remove(&sub_id) {
Self::remove_from_topics(topic_map, sub_id, &sub.topics);
Self::remove_from_index(user_index, &sub.user_key, sub_id);
Self::remove_from_index(channel_index, &sub.channel_key, sub_id);
}
}
count
}
}
/// A signal handler waiting to be registered with a `signals::SignalRegistry`.
pub struct BeaconHandler {
    /// Handler passed to `SignalRegistry::register` by `BeaconEngine`.
    handler: signals::Signaller,
}
/// Collection of beacon handlers that `BeaconEngine` moves into a signal registry.
pub struct BeaconRegistry {
    /// Pending handlers; emptied by `BeaconEngine::register_signal_handlers`.
    handlers: Vec<BeaconHandler>,
}
/// Converts signal payloads into [`ChannelEvent`]s (see `handle_signal`).
pub struct BeaconEngine {
    /// Handlers to hand over to the signal registry.
    registry: BeaconRegistry,
    /// Queue that `handle_signal` pushes events onto.
    sender: mpsc::Sender<ChannelEvent>,
    /// Associated channel service; not used by the methods visible in this file section.
    service: Arc<LocalChannelService>,
}
impl BeaconEngine {
fn handle_signal<P: callables::Payloadable>(&self, item: &P) {
let data = serde_json::to_string(&item).unwrap_or_default();
let topic = P::schema_name();
let event = ChannelEvent {
topic_id: TopicId(0), topic: topic.to_string(),
message: data,
user_keys: None,
};
if let Err(err) = self.sender.try_send(event) {
tracing::error!("Failed to send beacon event: {}", err);
}
}
fn register_signal_handlers(&mut self, signal_registry: &mut signals::SignalRegistry) {
for handler in self.registry.handlers.drain(..) {
signal_registry.register(handler.handler);
}
}
pub fn run(self){
}
}