use anyhow::Result;
use async_trait::async_trait;
use oxios_gateway::channel::Channel;
use oxios_gateway::message::{IncomingMessage, OutgoingMessage};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex, RwLock};
/// In-process [`Channel`] used by the web gateway.
///
/// Incoming messages travel through a bounded mpsc queue that is drained by
/// [`Channel::receive`]. Outgoing messages are published on a broadcast
/// channel. `responses` holds one-shot reply slots keyed by message id, so a
/// caller waiting on one particular reply can be handed it directly.
pub struct WebChannel {
    /// Receiving half of the incoming queue. It sits behind a mutex so that
    /// `receive` can work through `&self`.
    incoming_rx: Mutex<mpsc::Receiver<IncomingMessage>>,
    /// Sending half of the incoming queue; cloned out by [`WebChannel::sender`].
    incoming_tx: mpsc::Sender<IncomingMessage>,
    /// Broadcast sender that every outgoing message is published on.
    outgoing_tx: broadcast::Sender<OutgoingMessage>,
    /// Pending one-shot reply slots keyed by message id. The `Arc` lets handles
    /// share this map.
    responses: Arc<RwLock<HashMap<uuid::Uuid, oneshot::Sender<OutgoingMessage>>>>,
}

impl WebChannel {
    /// Creates a channel whose incoming queue and outgoing broadcast both have
    /// capacity `buffer`.
    ///
    /// # Panics
    /// Panics if `buffer` is 0, because the tokio channel constructors reject a
    /// zero capacity.
    pub fn new(buffer: usize) -> Self {
        let (queue_tx, queue_rx) = mpsc::channel(buffer);
        // The broadcast receiver created here is dropped immediately. Listeners
        // obtain their own through `subscribe_outgoing`.
        let outgoing_tx = broadcast::channel(buffer).0;
        Self {
            incoming_rx: Mutex::new(queue_rx),
            incoming_tx: queue_tx,
            outgoing_tx,
            responses: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a new clone of the incoming-queue sender.
    pub fn sender(&self) -> mpsc::Sender<IncomingMessage> {
        mpsc::Sender::clone(&self.incoming_tx)
    }

    /// Registers and returns a new receiver on the outgoing broadcast.
    pub fn subscribe_outgoing(&self) -> broadcast::Receiver<OutgoingMessage> {
        self.outgoing_tx.subscribe()
    }

    /// Publishes `msg` to the current outgoing subscribers.
    ///
    /// This is best-effort: a send with no subscribers is ignored, so the
    /// method always returns `Ok(())`.
    pub fn broadcast_outgoing(&self, msg: OutgoingMessage) -> Result<()> {
        // An error here only means nobody is subscribed right now.
        let _delivered = self.outgoing_tx.send(msg).is_ok();
        Ok(())
    }

    /// Hands `msg` to the waiter registered under `msg.id`, if there is one,
    /// then publishes it on the outgoing broadcast.
    ///
    /// Always returns `Ok(())`. A missing waiter, a waiter that has gone away,
    /// or an absence of subscribers are all ignored.
    pub async fn deliver_response(&self, msg: OutgoingMessage) -> Result<()> {
        let id = msg.id;
        // The write guard is a temporary of this statement, so the lock is
        // released as soon as the entry has been taken out of the map.
        let waiter = self.responses.write().await.remove(&id);
        if let Some(waiter) = waiter {
            let _ = waiter.send(msg.clone());
        }
        let _ = self.outgoing_tx.send(msg);
        tracing::debug!(msg_id = %id, "Delivering response");
        Ok(())
    }
}
#[async_trait]
impl Channel for WebChannel {
    /// Identifier of this channel, which is always `"web"`.
    fn name(&self) -> &str {
        "web"
    }

    /// Waits for the next message on the incoming queue and returns whatever
    /// `recv` yields.
    ///
    /// The receiver mutex is held for the whole wait, so concurrent callers
    /// are served one at a time.
    async fn receive(&self) -> Result<Option<IncomingMessage>> {
        // The guard is a temporary of this statement, so it lives across the
        // `recv().await`.
        let next = self.incoming_rx.lock().await.recv().await;
        Ok(next)
    }

    /// Routes `msg` to the waiter registered under `msg.id`, if there is one,
    /// and always publishes it on the outgoing broadcast as well.
    async fn send(&self, msg: OutgoingMessage) -> Result<()> {
        // Take the pending slot out of the map. The write lock is released at
        // the end of this statement.
        let waiter = self.responses.write().await.remove(&msg.id);
        if let Some(waiter) = waiter {
            let _ = waiter.send(msg.clone());
            tracing::debug!(msg_id = %msg.id, "Correlated response to HTTP handler");
        }
        // Having no broadcast subscribers is not an error.
        let _ = self.outgoing_tx.send(msg);
        Ok(())
    }
}
/// Opaque `Debug` output: prints only the type name, with none of the fields.
impl std::fmt::Debug for WebChannel {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("WebChannel");
        out.finish()
    }
}
/// Cloneable handle onto a [`WebChannel`].
///
/// It shares the channel's incoming sender, its outgoing broadcast sender and
/// its pending-response map. Holders can inject messages and wait for the
/// correlated reply without owning the channel itself.
#[derive(Debug, Clone)]
pub struct WebChannelHandle {
    /// Sending half of the channel's incoming queue.
    pub incoming_tx: mpsc::Sender<IncomingMessage>,
    /// Broadcast sender that outgoing messages are published on.
    pub outgoing_tx: broadcast::Sender<OutgoingMessage>,
    /// Waiters registered by [`send_and_wait`](Self::send_and_wait), keyed by
    /// message id. This map is shared with the originating `WebChannel`.
    responses: Arc<RwLock<HashMap<uuid::Uuid, oneshot::Sender<OutgoingMessage>>>>,
}

impl WebChannelHandle {
    /// Builds a handle that shares `channel`'s senders and pending-response map.
    pub fn from_channel(channel: &WebChannel) -> Self {
        Self {
            incoming_tx: channel.sender(),
            outgoing_tx: channel.outgoing_tx.clone(),
            responses: Arc::clone(&channel.responses),
        }
    }

    /// Registers and returns a new receiver on the outgoing broadcast.
    pub fn subscribe(&self) -> broadcast::Receiver<OutgoingMessage> {
        self.outgoing_tx.subscribe()
    }

    /// Queues `msg` on the incoming channel without waiting for a reply.
    ///
    /// # Errors
    /// Fails if the incoming channel's receiver has been closed or dropped.
    pub async fn send_incoming(&self, msg: IncomingMessage) -> Result<()> {
        self.incoming_tx
            .send(msg)
            .await
            .map_err(|e| anyhow::anyhow!("{e}"))
    }

    /// Queues `msg` and waits for the outgoing message delivered under `msg.id`.
    ///
    /// A one-shot slot keyed by `msg.id` is registered before the message is
    /// queued. Whoever later sends an outgoing message with that id through
    /// the shared map completes this call.
    ///
    /// # Errors
    /// - The incoming channel is closed. The registration is removed again.
    /// - The slot is dropped before a reply arrives, for example because
    ///   another `send_and_wait` with the same id replaced it.
    pub async fn send_and_wait(&self, msg: IncomingMessage) -> Result<OutgoingMessage> {
        let (tx, rx) = oneshot::channel::<OutgoingMessage>();
        let msg_id = msg.id;
        {
            let mut responses = self.responses.write().await;
            // Purge slots whose waiting caller has gone away, for example a
            // cancelled request future. Without this they would stay in the
            // map forever whenever no reply with that id is produced. The cost
            // is linear in the number of pending requests.
            responses.retain(|_, waiter| !waiter.is_closed());
            responses.insert(msg_id, tx);
        }
        if let Err(e) = self.incoming_tx.send(msg).await {
            // Nobody will ever answer this message, so do not leak the slot.
            // Dropping our receiver closes our own registration. Purging
            // closed slots then removes it without touching a live waiter
            // that may have reused the same id.
            drop(rx);
            self.responses
                .write()
                .await
                .retain(|_, waiter| !waiter.is_closed());
            return Err(anyhow::anyhow!("{e}"));
        }
        rx.await
            .map_err(|e| anyhow::anyhow!("Response channel dropped: {e}"))
    }
}