use crate::{
error::ChannelError,
i2np::Message,
primitives::{Lease, RouterId, TunnelId},
tunnel::pool::{context::TunnelMessageRecycle, TunnelMessage, TunnelPoolConfig},
};
use futures::Stream;
use futures_channel::oneshot;
use thingbuf::mpsc;
use alloc::vec::Vec;
use core::{
fmt,
pin::Pin,
task::{Context, Poll},
};
/// Event delivered to a [`TunnelPoolHandle`] over its event channel.
///
/// NOTE(review): `Clone` + `Default` (with `Dummy` as the default) presumably exist to
/// satisfy the `thingbuf` channel these events travel over (`mpsc::channel` with its
/// default recycling policy) — confirm before removing either derive.
#[derive(Default, Debug, Clone)]
pub enum TunnelPoolEvent {
    /// Tunnel pool shutdown notification.
    TunnelPoolShutDown,

    /// An inbound tunnel was built.
    InboundTunnelBuilt {
        /// ID of the inbound tunnel.
        tunnel_id: TunnelId,

        /// Lease associated with the inbound tunnel.
        lease: Lease,
    },

    /// An outbound tunnel was built.
    OutboundTunnelBuilt {
        /// ID of the outbound tunnel.
        tunnel_id: TunnelId,
    },

    /// An inbound tunnel expired.
    InboundTunnelExpired {
        /// ID of the expired inbound tunnel.
        tunnel_id: TunnelId,
    },

    /// An outbound tunnel expired.
    OutboundTunnelExpired {
        /// ID of the expired outbound tunnel.
        tunnel_id: TunnelId,
    },

    /// An inbound tunnel is about to expire.
    ///
    /// The `#[allow(unused)]` suggests this variant is not constructed anywhere yet.
    #[allow(unused)]
    InboundTunnelExpiring {
        /// ID of the expiring inbound tunnel.
        tunnel_id: TunnelId,
    },

    /// An outbound tunnel is about to expire.
    ///
    /// The `#[allow(unused)]` suggests this variant is not constructed anywhere yet.
    #[allow(unused)]
    OutboundTunnelExpiring {
        /// ID of the expiring outbound tunnel.
        tunnel_id: TunnelId,
    },

    /// An I2NP message for the pool's owner.
    ///
    /// NOTE(review): presumably a message that arrived through one of the pool's
    /// tunnels — confirm against the code that emits this event.
    Message {
        /// The I2NP message.
        message: Message,
    },

    /// Placeholder used as the `Default` value; not a real event.
    #[default]
    Dummy,
}
impl fmt::Display for TunnelPoolEvent {
    /// Writes only the variant name, qualified as `TunnelPoolEvent::<Variant>`.
    ///
    /// Variant payloads (tunnel IDs, leases, messages) are intentionally not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::TunnelPoolShutDown => "TunnelPoolEvent::TunnelPoolShutDown",
            Self::InboundTunnelBuilt { .. } => "TunnelPoolEvent::InboundTunnelBuilt",
            Self::OutboundTunnelBuilt { .. } => "TunnelPoolEvent::OutboundTunnelBuilt",
            Self::InboundTunnelExpired { .. } => "TunnelPoolEvent::InboundTunnelExpired",
            Self::OutboundTunnelExpired { .. } => "TunnelPoolEvent::OutboundTunnelExpired",
            Self::InboundTunnelExpiring { .. } => "TunnelPoolEvent::InboundTunnelExpiring",
            Self::OutboundTunnelExpiring { .. } => "TunnelPoolEvent::OutboundTunnelExpiring",
            Self::Message { .. } => "TunnelPoolEvent::Message",
            Self::Dummy => "TunnelPoolEvent::Dummy",
        };

        f.write_str(variant_name)
    }
}
/// Cloneable handle to the sending half of the tunnel message channel.
///
/// Messages are submitted through the builder returned by
/// [`TunnelMessageSender::send_message`].
#[derive(Clone)]
pub struct TunnelMessageSender(mpsc::Sender<TunnelMessage, TunnelMessageRecycle>);
impl TunnelMessageSender {
    /// Begin building a send of `message` over this sender's channel.
    ///
    /// The returned [`TunnelSender`] starts with no delivery target and no pinned
    /// outbound tunnel; set a target with `router_delivery()` or `tunnel_delivery()`
    /// and then submit it with `try_send()` or `send()`.
    pub fn send_message(&self, message: Vec<u8>) -> TunnelSender<'_> {
        let channel = &self.0;

        TunnelSender {
            tx: channel,
            message,
            kind: None,
            outbound_tunnel: None,
        }
    }
}
/// Delivery target selected on a [`TunnelSender`].
enum DeliveryKind {
    /// Deliver to tunnel `tunnel_id` of router `router_id`.
    ///
    /// Turned into `TunnelMessage::TunnelDeliveryViaRoute` when sent.
    TunnelDelivery {
        /// ID of the destination tunnel.
        tunnel_id: TunnelId,

        /// ID of the router paired with `tunnel_id`.
        router_id: RouterId,
    },

    /// Deliver to router `router_id`.
    ///
    /// Turned into `TunnelMessage::RouterDeliveryViaRoute` when sent.
    RouterDelivery {
        /// ID of the destination router.
        router_id: RouterId,
    },
}
/// Builder for a single tunnel message, created by `send_message()`.
///
/// A delivery target must be chosen with [`TunnelSender::router_delivery`] or
/// [`TunnelSender::tunnel_delivery`] before the message is submitted with
/// [`TunnelSender::try_send`] or [`TunnelSender::send`]. An outbound tunnel can
/// optionally be pinned with [`TunnelSender::via_outbound_tunnel`].
///
/// The builder only records what to send. Dropping it without calling `try_send()`
/// or `send()` silently discards the message, hence `#[must_use]`.
#[must_use = "the message is only queued once `try_send()` or `send()` is called"]
pub struct TunnelSender<'a> {
    /// Delivery target; `None` until one of the delivery methods is called.
    kind: Option<DeliveryKind>,

    /// Message payload.
    message: Vec<u8>,

    /// Outbound tunnel to send through, if one was pinned.
    ///
    /// NOTE(review): when `None`, the pool presumably picks a tunnel itself — confirm.
    outbound_tunnel: Option<TunnelId>,

    /// Borrowed sending half of the tunnel message channel.
    tx: &'a mpsc::Sender<TunnelMessage, TunnelMessageRecycle>,
}
impl TunnelSender<'_> {
pub fn router_delivery(mut self, router_id: RouterId) -> Self {
self.kind = Some(DeliveryKind::RouterDelivery { router_id });
self
}
pub fn tunnel_delivery(mut self, router_id: RouterId, tunnel_id: TunnelId) -> Self {
self.kind = Some(DeliveryKind::TunnelDelivery {
tunnel_id,
router_id,
});
self
}
pub fn via_outbound_tunnel(mut self, tunnel_id: TunnelId) -> Self {
self.outbound_tunnel = Some(tunnel_id);
self
}
pub fn try_send(self) -> Result<(), ChannelError> {
let message = match self.kind.expect("to exist") {
DeliveryKind::TunnelDelivery {
tunnel_id,
router_id,
} => TunnelMessage::TunnelDeliveryViaRoute {
router_id,
tunnel_id,
outbound_tunnel: self.outbound_tunnel,
message: self.message,
},
DeliveryKind::RouterDelivery { router_id } => TunnelMessage::RouterDeliveryViaRoute {
router_id,
outbound_tunnel: self.outbound_tunnel,
message: self.message,
},
};
self.tx.try_send(message).map_err(From::from)
}
#[allow(unused)]
pub async fn send(self) -> Result<(), ChannelError> {
let message = match self.kind.expect("to exist") {
DeliveryKind::TunnelDelivery {
tunnel_id,
router_id,
} => TunnelMessage::TunnelDeliveryViaRoute {
router_id,
tunnel_id,
outbound_tunnel: self.outbound_tunnel,
message: self.message,
},
DeliveryKind::RouterDelivery { router_id } => TunnelMessage::RouterDeliveryViaRoute {
router_id,
outbound_tunnel: self.outbound_tunnel,
message: self.message,
},
};
self.tx.send(message).await.map_err(|_| ChannelError::Closed)
}
}
/// Handle to a tunnel pool.
///
/// Exposes the pool's configuration and a sender for outbound tunnel messages. It
/// implements `Stream`, yielding the [`TunnelPoolEvent`]s received on `event_rx`.
pub struct TunnelPoolHandle {
    /// Configuration of the tunnel pool, returned by `config()`.
    config: TunnelPoolConfig,

    /// Receiving half of the event channel, polled by the `Stream` impl.
    event_rx: mpsc::Receiver<TunnelPoolEvent>,

    /// Sender for tunnel messages; `sender()` hands out clones of it.
    sender: TunnelMessageSender,

    /// Shutdown signal; taken and fired by `shutdown()`, so `None` afterwards.
    ///
    /// NOTE(review): this field is read by `shutdown()`, so the `#[allow(unused)]`
    /// looks redundant — confirm before removing. Also note that dropping the handle
    /// drops this sender; per `futures_channel::oneshot`, the receiver then observes
    /// cancellation.
    #[allow(unused)]
    shutdown_tx: Option<oneshot::Sender<()>>,
}
impl TunnelPoolHandle {
    /// Capacity of the bounded channel carrying [`TunnelPoolEvent`]s to the handle.
    const EVENT_CHANNEL_CAPACITY: usize = 64;

    /// Create a new handle for a tunnel pool configured with `config`.
    ///
    /// `message_tx` is the sending half of the channel that outbound
    /// [`TunnelMessage`]s are submitted to.
    ///
    /// Returns the handle, the sending half of its event channel, and the receiving
    /// half of its shutdown signal. Presumably the last two are handed to the tunnel
    /// pool itself — TODO confirm against the caller in the parent module.
    pub(super) fn new(
        config: TunnelPoolConfig,
        message_tx: mpsc::Sender<TunnelMessage, TunnelMessageRecycle>,
    ) -> (Self, mpsc::Sender<TunnelPoolEvent>, oneshot::Receiver<()>) {
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (event_tx, event_rx) = mpsc::channel(Self::EVENT_CHANNEL_CAPACITY);

        (
            Self {
                config,
                event_rx,
                sender: TunnelMessageSender(message_tx),
                shutdown_tx: Some(shutdown_tx),
            },
            event_tx,
            shutdown_rx,
        )
    }

    /// Signal the tunnel pool to shut down.
    ///
    /// Only the first call sends the signal; later calls are no-ops. Sending is
    /// best-effort, and its result is ignored.
    pub fn shutdown(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // If the receiving side is already gone there is nobody left to notify.
            let _ = tx.send(());
        }
    }

    /// Get a reference to the tunnel pool's configuration.
    pub fn config(&self) -> &TunnelPoolConfig {
        &self.config
    }

    /// Begin building a send of `message` through this pool's message channel.
    pub fn send_message(&self, message: Vec<u8>) -> TunnelSender<'_> {
        self.sender.send_message(message)
    }

    /// Get a cloned [`TunnelMessageSender`] that can outlive borrows of the handle.
    pub fn sender(&self) -> TunnelMessageSender {
        self.sender.clone()
    }

    /// Create a handle with a default `TunnelPoolConfig` for tests.
    ///
    /// Returns `(handle, message_rx, event_tx, shutdown_rx)`, exactly like
    /// `from_config`.
    #[cfg(test)]
    pub fn create() -> (
        Self,
        mpsc::Receiver<TunnelMessage, TunnelMessageRecycle>,
        mpsc::Sender<TunnelPoolEvent>,
        oneshot::Receiver<()>,
    ) {
        Self::from_config(Default::default())
    }

    /// Create a handle configured with `config` for tests, backed by a fresh message channel.
    ///
    /// Returns, in order:
    /// - the handle;
    /// - the receiving half of the message channel;
    /// - the sending half of the event channel;
    /// - the shutdown receiver.
    #[cfg(test)]
    pub fn from_config(
        config: TunnelPoolConfig,
    ) -> (
        Self,
        mpsc::Receiver<TunnelMessage, TunnelMessageRecycle>,
        mpsc::Sender<TunnelPoolEvent>,
        oneshot::Receiver<()>,
    ) {
        let (message_tx, message_rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let (handle, event_tx, shutdown_rx) = Self::new(config, message_tx);

        (handle, message_rx, event_tx, shutdown_rx)
    }
}
impl Stream for TunnelPoolHandle {
    type Item = TunnelPoolEvent;

    /// Poll for the next [`TunnelPoolEvent`].
    ///
    /// This delegates directly to the event receiver's `poll_recv`, so readiness and
    /// end-of-stream follow that channel's semantics.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let events = &self.event_rx;
        events.poll_recv(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn send_to_router_via_any() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote = RouterId::random();

        sender
            .send_message(vec![1, 3, 3, 7])
            .router_delivery(remote.clone())
            .send()
            .await
            .unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::RouterDeliveryViaRoute {
                router_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert!(outbound_tunnel.is_none());
            }
            _ => panic!("invalid message"),
        }
    }

    #[tokio::test]
    async fn send_to_tunnel_via_any() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote_router = RouterId::random();
        let remote_tunnel = TunnelId::random();

        sender
            .send_message(vec![1, 3, 3, 7])
            .tunnel_delivery(remote_router.clone(), remote_tunnel)
            .send()
            .await
            .unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::TunnelDeliveryViaRoute {
                router_id,
                tunnel_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote_router);
                assert_eq!(tunnel_id, remote_tunnel);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert!(outbound_tunnel.is_none());
            }
            _ => panic!("invalid message"),
        }
    }

    #[tokio::test]
    async fn send_to_router_via_route() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote = RouterId::random();
        let obgw = TunnelId::random();

        sender
            .send_message(vec![1, 3, 3, 7])
            .router_delivery(remote.clone())
            .via_outbound_tunnel(obgw)
            .send()
            .await
            .unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::RouterDeliveryViaRoute {
                router_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert_eq!(outbound_tunnel, Some(obgw));
            }
            _ => panic!("invalid message"),
        }
    }

    #[tokio::test]
    async fn send_to_tunnel_via_route() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote_router = RouterId::random();
        let remote_tunnel = TunnelId::random();
        let obgw = TunnelId::random();

        sender
            .send_message(vec![1, 3, 3, 7])
            .tunnel_delivery(remote_router.clone(), remote_tunnel)
            .via_outbound_tunnel(obgw)
            .send()
            .await
            .unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::TunnelDeliveryViaRoute {
                router_id,
                tunnel_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote_router);
                assert_eq!(tunnel_id, remote_tunnel);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert_eq!(outbound_tunnel, Some(obgw));
            }
            _ => panic!("invalid message"),
        }
    }

    /// `try_send()` must produce the same message as `send()` does.
    #[tokio::test]
    async fn try_send_to_tunnel_via_route() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote_router = RouterId::random();
        let remote_tunnel = TunnelId::random();
        let obgw = TunnelId::random();

        sender
            .send_message(vec![1, 3, 3, 7])
            .tunnel_delivery(remote_router.clone(), remote_tunnel)
            .via_outbound_tunnel(obgw)
            .try_send()
            .unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::TunnelDeliveryViaRoute {
                router_id,
                tunnel_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote_router);
                assert_eq!(tunnel_id, remote_tunnel);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert_eq!(outbound_tunnel, Some(obgw));
            }
            _ => panic!("invalid message"),
        }
    }

    #[tokio::test]
    async fn try_send_to_router_via_any() {
        let (tx, rx) = mpsc::with_recycle(64, TunnelMessageRecycle::default());
        let sender = TunnelMessageSender(tx);
        let remote = RouterId::random();

        sender.send_message(vec![1, 3, 3, 7]).router_delivery(remote.clone()).try_send().unwrap();

        match rx.recv().await.unwrap() {
            TunnelMessage::RouterDeliveryViaRoute {
                router_id,
                outbound_tunnel,
                message,
            } => {
                assert_eq!(router_id, remote);
                assert_eq!(message, vec![1, 3, 3, 7]);
                assert!(outbound_tunnel.is_none());
            }
            _ => panic!("invalid message"),
        }
    }

    /// `shutdown()` fires the signal once, and a second call is a harmless no-op.
    #[tokio::test]
    async fn shutdown_signals_pool() {
        let (mut handle, _message_rx, _event_tx, shutdown_rx) = TunnelPoolHandle::create();

        handle.shutdown();
        handle.shutdown();

        assert!(shutdown_rx.await.is_ok());
    }

    /// Dropping the handle without calling `shutdown()` cancels the shutdown receiver.
    #[tokio::test]
    async fn dropping_handle_cancels_shutdown_receiver() {
        let (handle, _message_rx, _event_tx, shutdown_rx) = TunnelPoolHandle::create();

        drop(handle);

        assert!(shutdown_rx.await.is_err());
    }
}