use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use tokio::sync::mpsc;
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use crate::error::private::InnerError;
use crate::peer::Command;
use crate::{
Error,
Message,
MessageType,
ReceivedMessage,
ReceivedRequestHandle,
SentRequestHandle,
};
use crate::request::RequestHandleCommand;
/// Tracker-side state for a single request (sent or received).
struct TrackedRequest<Body> {
/// Sender half of the channel whose receiver is owned by the request handle.
/// The tracker pushes incoming messages for this request through it, followed by
/// a final `RequestHandleCommand::Close` when the request is closed.
incoming_tx: mpsc::UnboundedSender<RequestHandleCommand<Body>>,
/// Flag shared with the request handle. The tracker sets it to `true` when the
/// request is removed explicitly or completed by an incoming response.
closed: Arc<AtomicBool>,
}
/// Tracks outgoing (sent) and incoming (received) requests by request ID and routes
/// incoming messages to the matching request handle.
pub struct RequestTracker<Body> {
/// Next candidate ID for an outgoing request; advanced with wrapping arithmetic.
next_sent_request_id: u32,
/// Command channel; a clone is handed to every request handle this tracker creates.
command_tx: mpsc::UnboundedSender<Command<Body>>,
/// Requests initiated locally, keyed by request ID.
sent_requests: BTreeMap<u32, TrackedRequest<Body>>,
/// Requests initiated by the remote side, keyed by request ID.
received_requests: BTreeMap<u32, TrackedRequest<Body>>,
}
impl<Body> RequestTracker<Body> {
pub fn new(command_tx: mpsc::UnboundedSender<Command<Body>>) -> Self {
Self {
next_sent_request_id: 0,
command_tx,
sent_requests: BTreeMap::new(),
received_requests: BTreeMap::new(),
}
}
pub fn allocate_sent_request(&mut self, service_id: i32) -> Result<SentRequestHandle<Body>, Error> {
for _ in 0..100 {
let request_id = self.next_sent_request_id;
self.next_sent_request_id = self.next_sent_request_id.wrapping_add(1);
if let Entry::Vacant(entry) = self.sent_requests.entry(request_id) {
let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
let closed = Arc::new(AtomicBool::new(false));
let tracked_request = TrackedRequest {
incoming_tx,
closed: closed.clone(),
};
entry.insert(tracked_request);
return Ok(SentRequestHandle::new(request_id, service_id, closed, incoming_rx, self.command_tx.clone()));
}
}
Err(InnerError::NoFreeRequestIdFound.into())
}
pub fn remove_sent_request(&mut self, request_id: u32) -> Result<(), Error> {
let tracked_request = self.sent_requests.remove(&request_id).ok_or(InnerError::UnknownRequestId { request_id })?;
tracked_request.closed.store(true, Ordering::Release);
let _: Result<_, _> = tracked_request.incoming_tx.send(RequestHandleCommand::Close);
Ok(())
}
pub fn register_received_request(
&mut self,
request_id: u32,
service_id: i32,
body: Body,
) -> Result<(ReceivedRequestHandle<Body>, Body), Error> {
match self.received_requests.entry(request_id) {
Entry::Occupied(_entry) => {
Err(InnerError::DuplicateRequestId { request_id }.into())
},
Entry::Vacant(entry) => {
let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
let closed = Arc::new(AtomicBool::new(false));
let tracked_request = TrackedRequest {
incoming_tx,
closed: closed.clone(),
};
entry.insert(tracked_request);
Ok((ReceivedRequestHandle::new(request_id, service_id, closed, incoming_rx, self.command_tx.clone()), body))
},
}
}
pub fn remove_received_request(&mut self, request_id: u32) -> Result<(), Error> {
let tracked_request = self.received_requests.remove(&request_id).ok_or(InnerError::UnknownRequestId { request_id })?;
tracked_request.closed.store(true, Ordering::Release);
let _: Result<_, _> = tracked_request.incoming_tx.send(RequestHandleCommand::Close);
Ok(())
}
pub async fn process_incoming_message(&mut self, message: Message<Body>) -> Result<Option<ReceivedMessage<Body>>, Error> {
match message.header.message_type {
MessageType::Request => {
let (received_request, body) = self.register_received_request(message.header.request_id, message.header.service_id, message.body)?;
Ok(Some(ReceivedMessage::Request(received_request, body)))
},
MessageType::Response => {
self.process_incoming_response(message).await?;
Ok(None)
},
MessageType::RequesterUpdate => {
self.process_incoming_requester_update(message).await?;
Ok(None)
},
MessageType::ResponderUpdate => {
self.process_incoming_responder_update(message).await?;
Ok(None)
},
MessageType::Stream => Ok(Some(ReceivedMessage::Stream(message))),
}
}
async fn process_incoming_response(&mut self, message: Message<Body>) -> Result<(), Error> {
let request_id = message.header.request_id;
match self.sent_requests.entry(request_id) {
Entry::Vacant(_) => Err(InnerError::UnknownRequestId { request_id }.into()),
Entry::Occupied(entry) => {
let tracked_request = entry.remove();
let _: Result<_, _> = tracked_request.incoming_tx.send(RequestHandleCommand::Message(message));
tracked_request.closed.store(true, Ordering::Release);
let _: Result<_, _> = tracked_request.incoming_tx.send(RequestHandleCommand::Close);
Ok(())
},
}
}
async fn process_incoming_requester_update(&mut self, message: Message<Body>) -> Result<(), Error> {
let request_id = message.header.request_id;
match self.received_requests.entry(request_id) {
Entry::Vacant(_) => Err(InnerError::UnknownRequestId { request_id }.into()),
Entry::Occupied(mut entry) => {
if entry.get_mut().incoming_tx.send(RequestHandleCommand::Message(message)).is_err() {
entry.remove();
Err(InnerError::UnknownRequestId { request_id }.into())
} else {
Ok(())
}
},
}
}
async fn process_incoming_responder_update(&mut self, message: Message<Body>) -> Result<(), Error> {
let request_id = message.header.request_id;
match self.sent_requests.entry(request_id) {
Entry::Vacant(_) => Err(InnerError::UnknownRequestId { request_id }.into()),
Entry::Occupied(mut entry) => {
if entry.get_mut().incoming_tx.send(RequestHandleCommand::Message(message)).is_err() {
entry.remove();
Err(InnerError::UnknownRequestId { request_id }.into())
} else {
Ok(())
}
},
}
}
}
#[cfg(test)]
mod test {
use assert2::assert;
use assert2::let_assert;
use super::*;
use crate::MessageHeader;
// Minimal body type for the tests: every body is the same unit value and all
// error conversions are trivial.
struct Body;
impl crate::Body for Body {
fn empty() -> Self {
Self
}
fn from_error(_message: &str) -> Self {
Self
}
fn as_error(&self) -> Result<&str, std::str::Utf8Error> {
Ok("")
}
fn into_error(self) -> Result<String, std::string::FromUtf8Error> {
Ok(String::new())
}
}
// Incoming request flow:
// - register a request;
// - forward a requester update to its handle;
// - send an update and a response through the handle;
// - after removal, further requester updates for that ID are rejected.
#[tokio::test]
async fn test_incoming_request() {
let (command_tx, mut command_rx) = mpsc::unbounded_channel();
let mut tracker = RequestTracker::new(command_tx);
// Plays the role of the command consumer. It expects exactly these two outgoing
// messages in order and acknowledges each. It then expects the channel to close once
// every sender (tracker and handle) has been dropped.
let command_task = tokio::spawn(async move {
let_assert!(Some(Command::SendRawMessage(command)) = command_rx.recv().await);
assert!(command.message.header == MessageHeader::responder_update(1, 3));
assert!(let Ok(()) = command.result_tx.send(Ok(())));
let_assert!(Some(Command::SendRawMessage(command)) = command_rx.recv().await);
assert!(command.message.header == MessageHeader::response(1, 4));
assert!(let Ok(()) = command.result_tx.send(Ok(())));
assert!(let None = command_rx.recv().await);
});
let_assert!(Ok(Some(ReceivedMessage::Request(mut received_request, _body))) = tracker.process_incoming_message(Message::request(1, 2, Body)).await);
assert!(let Ok(None) = tracker.process_incoming_message(Message::requester_update(1, 10, Body)).await);
let_assert!(Some(update) = received_request.recv_update().await);
assert!(update.header == MessageHeader::requester_update(1, 10));
let_assert!(Ok(()) = received_request.send_update(3, Body).await);
let_assert!(Ok(()) = received_request.send_response(4, Body).await);
// Once removed, the tracker no longer knows the request ID, so updates for it are errors.
let_assert!(Ok(()) = tracker.remove_received_request(received_request.request_id()));
assert!(let Err(_) = tracker.process_incoming_message(Message::requester_update(1, 11, Body)).await);
// Dropping the last command senders lets the spawned task observe channel closure.
drop(received_request);
drop(tracker);
assert!(let Ok(()) = command_task.await);
}
// Outgoing request flow:
// - allocate a request and forward a responder update to its handle;
// - send an update through the handle;
// - deliver the response;
// - after the response, further responder updates for that ID are rejected.
#[tokio::test]
async fn test_outgoing_request() {
let (command_tx, mut command_rx) = mpsc::unbounded_channel();
let mut tracker = RequestTracker::new(command_tx);
let_assert!(Ok(mut sent_request) = tracker.allocate_sent_request(3));
let request_id = sent_request.request_id();
// Plays the role of the command consumer. It expects only the requester update sent
// below and acknowledges it. It then expects the channel to close once every sender
// has been dropped.
let command_task = tokio::spawn(async move {
let_assert!(Some(Command::SendRawMessage(command)) = command_rx.recv().await);
assert!(command.message.header == MessageHeader::requester_update(request_id, 13));
assert!(let Ok(()) = command.result_tx.send(Ok(())));
assert!(let None = command_rx.recv().await);
});
assert!(let Ok(None) = tracker.process_incoming_message(Message::responder_update(sent_request.request_id(), 12, Body)).await);
let_assert!(Some(update) = sent_request.recv_update().await);
assert!(update.header == MessageHeader::responder_update(sent_request.request_id(), 12));
let_assert!(Ok(()) = sent_request.send_update(13, Body).await);
// A response is delivered to the handle and ends tracking of the sent request.
assert!(let Ok(None) = tracker.process_incoming_message(Message::response(sent_request.request_id(), 14, Body)).await);
let_assert!(Ok(update) = sent_request.recv_response().await);
assert!(update.header == MessageHeader::response(sent_request.request_id(), 14));
// The request is no longer tracked, so a late responder update is rejected.
assert!(let Err(_) = tracker
.process_incoming_message(Message::responder_update(sent_request.request_id(), 15, Body))
.await
);
drop(tracker);
drop(sent_request);
assert!(let Ok(()) = command_task.await);
}
}