use anyhow::{Result, anyhow};
use async_zmq::{Context, Dealer, Router, Sink, SinkExt, StreamExt};
use bytes::Bytes;
use derive_getters::Dissolve;
use futures::TryStreamExt;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, os::fd::FromRawFd, sync::Arc, time::Duration, vec::IntoIter};
use tokio::{
sync::{Mutex, mpsc},
task::{JoinError, JoinHandle},
};
use tokio_util::sync::CancellationToken;
/// Control-plane messages exchanged per request stream (cancellation
/// handshake, terminal error, completion).
/// NOTE(review): declared but not consumed anywhere in this file —
/// presumably part of the wider wire protocol; confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    /// Request that the stream identified by `request_id` be cancelled.
    Cancel { request_id: String },
    /// Acknowledgement that the cancel was observed.
    CancelAck { request_id: String },
    /// Terminal failure for the request, with a human-readable description.
    Error { request_id: String, error: String },
    /// The request finished successfully.
    Complete { request_id: String },
}
/// Envelope distinguishing raw payload bytes from control-plane messages.
/// NOTE(review): declared but not consumed anywhere in this file — confirm
/// against the wider protocol before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    /// Opaque payload bytes for the data plane.
    Data(Vec<u8>),
    /// A control-plane message (cancel/ack/error/complete).
    Control(ControlMessage),
}
/// Outcome of forwarding one payload to a per-request stream; decided in
/// `Server::run` after attempting delivery on the request's mpsc channel.
enum StreamAction {
    /// Delivered without waiting (`try_send` succeeded); payload size in bytes.
    SendEager(usize),
    /// Delivered only after awaiting channel capacity (backpressure path).
    SendDelayed(usize),
    /// Receiver hung up; the stream should be dropped from the router state.
    Close,
}
/// Shared routing table mapping request ids to their per-request channels.
/// Wrapped in `Arc<tokio::sync::Mutex<..>>` by `Server` and mutated from
/// both the receive loop and external registrants.
struct RouterState {
    // Data-plane senders, keyed by request id.
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    // Control-plane senders, keyed by request id.
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
impl RouterState {
fn new() -> Self {
Self {
active_streams: HashMap::new(),
control_channels: HashMap::new(),
}
}
fn register_stream(
&mut self,
request_id: String,
data_tx: mpsc::Sender<Bytes>,
control_tx: mpsc::Sender<ControlMessage>,
) {
self.active_streams.insert(request_id.clone(), data_tx);
self.control_channels.insert(request_id, control_tx);
}
fn remove_stream(&mut self, request_id: &str) {
self.active_streams.remove(request_id);
self.control_channels.remove(request_id);
}
}
/// Handle to a bound ZMQ ROUTER socket whose receive loop runs on a
/// background tokio task (see `Server::run`). Cloning is cheap: the routing
/// state is behind an `Arc`.
#[derive(Clone, Dissolve)]
pub struct Server {
    // Routing table shared with the background receive loop.
    state: Arc<Mutex<RouterState>>,
    // Child token governing the receive loop's lifetime.
    cancel_token: CancellationToken,
    // Raw fd of the ROUTER socket as reported by `get_fd()`.
    // NOTE(review): stored but not used in this file — presumably consumed
    // via `Dissolve` by a caller; confirm.
    fd: i32,
}
impl Server {
    /// Binds a ROUTER socket at `address`, spawns the receive loop, and
    /// returns the server plus a handle for cancelling/joining it.
    ///
    /// Token layout: `cancel_token` (caller) -> `child` (handle/server) ->
    /// grandchild (receive loop). Cancelling the handle stops the loop
    /// without disturbing the caller's token; if the loop fails or its task
    /// panics, the caller's token is cancelled so the application notices.
    ///
    /// # Errors
    /// Fails if the socket cannot be created/bound or its fd read.
    pub async fn new(
        context: &Context,
        address: &str,
        cancel_token: CancellationToken,
    ) -> Result<(Self, ServerExecutionHandle)> {
        let router = async_zmq::router(address)?.with_context(context).bind()?;
        // Raw fd of the underlying socket; surfaced via `Server.fd`.
        let fd = router.as_raw_socket().get_fd()?;
        let state = Arc::new(Mutex::new(RouterState::new()));
        let child = cancel_token.child_token();
        let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
        // Watcher: surfaces both a JoinError (panic/abort) and a task-level
        // error, cancelling the caller's token in either case.
        let watch_task = tokio::spawn(async move {
            let result = primary_task.await.inspect_err(|e| {
                tracing::error!("zmq server/router task failed: {}", e);
                cancel_token.cancel();
            })?;
            result.inspect_err(|e| {
                tracing::error!("zmq server/router task failed: {}", e);
                cancel_token.cancel();
            })
        });
        let handle = ServerExecutionHandle {
            task: watch_task,
            cancel_token: child.clone(),
        };
        Ok((
            Self {
                state,
                cancel_token: child,
                fd,
            },
            handle,
        ))
    }

    /// Receive loop: reads 3-frame messages `[identity, request_id, payload]`
    /// from the ROUTER socket and forwards each payload to the per-request
    /// channel registered in `state`.
    ///
    /// Returns `Ok(())` on cancellation or socket end-of-stream; returns an
    /// error only when the 3-frame contract is violated.
    async fn run(
        router: Router<IntoIter<Vec<u8>>, Vec<u8>>,
        state: Arc<Mutex<RouterState>>,
        token: CancellationToken,
    ) -> Result<()> {
        let mut router = router;
        loop {
            let frames = tokio::select! {
                biased;
                frames = router.next() => {
                    match frames {
                        Some(Ok(frames)) => frames,
                        Some(Err(e)) => {
                            tracing::warn!("Error receiving message: {}", e);
                            continue;
                        }
                        None => break,
                    }
                }
                _ = token.cancelled() => {
                    tracing::info!("Server shutting down");
                    break;
                }
            };
            if frames.len() != 3 {
                anyhow::bail!(
                    "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
                    frames.len()
                );
            }
            let request_id = String::from_utf8_lossy(&frames[1]).to_string();
            let message = frames[2].to_vec();
            let message_size = message.len();
            // Clone the sender out of the map in its own statement so the
            // MutexGuard is dropped before any await. The previous
            // `if let Some(tx) = state.lock().await.active_streams.get(..)`
            // kept the guard alive for the entire if-let body, which
            // (a) held the lock across the backpressure `send().await`,
            // stalling all other requests on one slow consumer, and
            // (b) deadlocked on the re-entrant `state.lock().await` in the
            // `Close` branch — tokio's Mutex is not reentrant.
            let sender = state.lock().await.active_streams.get(&request_id).cloned();
            if let Some(tx) = sender {
                let action = match tx.try_send(message.into()) {
                    Ok(_) => {
                        tracing::trace!(
                            request_id,
                            "response data sent eagerly to stream: {} bytes",
                            message_size
                        );
                        StreamAction::SendEager(message_size)
                    }
                    Err(e) => match e {
                        mpsc::error::TrySendError::Closed(_) => {
                            tracing::info!(request_id, "response stream was closed");
                            StreamAction::Close
                        }
                        mpsc::error::TrySendError::Full(data) => {
                            tracing::warn!(
                                request_id,
                                "response stream is full; backpressure alert"
                            );
                            // Wait for capacity; if the receiver hung up
                            // meanwhile, treat it as a closed stream.
                            if (tx.send(data).await).is_err() {
                                StreamAction::Close
                            } else {
                                StreamAction::SendDelayed(message_size)
                            }
                        }
                    },
                };
                match action {
                    StreamAction::SendEager(_size) => {}
                    StreamAction::SendDelayed(_size) => {}
                    StreamAction::Close => {
                        // Receiver is gone: drop the data stream so future
                        // payloads for this id are discarded cheaply.
                        // (Control channel intentionally left as-is,
                        // matching previous behavior.)
                        state.lock().await.active_streams.remove(&request_id);
                    }
                }
            } else {
                tracing::trace!(request_id, "no active stream for request_id");
            }
        }
        Ok(())
    }
}
/// Owner-side handle for the server's background task: supports querying
/// status, requesting shutdown, and awaiting completion.
pub struct ServerExecutionHandle {
    // Watcher task that wraps the receive loop (see `Server::new`).
    task: JoinHandle<Result<()>>,
    // Child token shared with the `Server`; cancelling it stops the loop.
    cancel_token: CancellationToken,
}
impl ServerExecutionHandle {
    /// True once the background task has run to completion (ok or err).
    pub fn is_finished(&self) -> bool {
        self.task.is_finished()
    }

    /// True if shutdown has been requested via [`Self::cancel`] (or a parent
    /// token).
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Requests cooperative shutdown of the background task.
    pub fn cancel(&self) {
        self.cancel_token.cancel()
    }

    /// Awaits the background task, surfacing both join failures
    /// (panic/abort) and task-level errors.
    pub async fn join(self) -> Result<()> {
        match self.task.await {
            Ok(task_result) => task_result,
            Err(join_error) => Err(join_error.into()),
        }
    }
}
/// Thin wrapper around a connected ZMQ DEALER socket, the client-side peer
/// of `Server`'s ROUTER.
pub struct Client {
    // The connected DEALER socket; exposed mutably via `dealer()`.
    dealer: Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
}
impl Client {
    /// Connects a DEALER socket to `address` within `context`.
    ///
    /// # Errors
    /// Fails if the socket cannot be created or the connect call fails.
    fn new(context: &Context, address: &str) -> Result<Self> {
        let builder = async_zmq::dealer(address)?.with_context(context);
        let socket = builder.connect()?;
        Ok(Self { dealer: socket })
    }

    /// Mutable access to the underlying DEALER socket for send/close.
    fn dealer(&mut self) -> &mut Dealer<IntoIter<Vec<u8>>, Vec<u8>> {
        &mut self.dealer
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// End-to-end smoke test: a DEALER client sends `[request_id, payload]`;
    /// the ROUTER prepends the identity frame (making the 3-frame shape
    /// `Server::run` expects) and the payload is forwarded to the registered
    /// stream.
    /// NOTE(review): binds a fixed TCP port (1337) — may collide with other
    /// tests or services on the host; consider an ephemeral port.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = Context::new();
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();
        let id = "test-request".to_string();
        // Register the data channel directly (bypasses register_stream, so
        // no control channel exists for this id).
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);
        let mut client = Client::new(&context, address)?;
        // Payload deliberately equals the request id, so the forwarded bytes
        // are directly comparable below.
        client
            .dealer()
            .send(vec![id.as_bytes().to_vec(), id.as_bytes().to_vec()].into())
            .await?;
        let receive_result = rx.recv().await;
        let received = receive_result.unwrap();
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");
        client.dealer().close().await?;
        // Cooperative shutdown, then propagate any task error.
        handle.cancel();
        handle.join().await?;
        println!("done");
        Ok(())
    }
}