use std::net::{SocketAddr, TcpStream};
use std::sync::mpsc;
use comp_cat_rs::effect::io::Io;
use crate::codec;
use crate::error::Error;
use crate::protocol::Envelope;
/// A message-oriented connection that carries [`Envelope`]s in both directions.
///
/// Both operations consume the transport and hand it back inside the successful
/// result, so a transport is threaded through a chain of [`Io`] actions rather
/// than borrowed. When an action fails, only the [`Error`] is produced and the
/// transport is not returned.
pub trait Transport
where
    Self: Sized + Send + 'static,
{
    /// Returns an action that sends `envelope` and yields the transport back.
    fn send(self, envelope: Envelope) -> Io<Error, Self>;

    /// Returns an action that receives the next envelope, paired with the transport.
    fn recv(self) -> Io<Error, (Envelope, Self)>;
}
/// A [`Transport`] that exchanges envelopes over a single `std::net::TcpStream`.
///
/// Construct one with [`TcpTransport::connect`] (client side) or
/// [`TcpTransport::from_stream`] (e.g. for an accepted server-side stream).
#[derive(Debug)]
pub struct TcpTransport {
    /// The connected socket; it is read and written through `&TcpStream`.
    stream: TcpStream,
}
impl TcpTransport {
    /// Returns an action that opens a TCP connection to `addr`.
    ///
    /// The connect call is wrapped in `Io::suspend`, so it is made when the
    /// returned action is run rather than when this function is called.
    ///
    /// # Errors
    /// Running the action yields an [`Error`] converted from the underlying
    /// `std::io::Error` if the connection cannot be established.
    #[must_use]
    pub fn connect(addr: SocketAddr) -> Io<Error, Self> {
        Io::suspend(move || {
            let stream = TcpStream::connect(addr)?;
            Ok(Self::from_stream(stream))
        })
    }

    /// Wraps an already-connected stream, such as one returned by `TcpListener::accept`.
    #[must_use]
    pub fn from_stream(stream: TcpStream) -> Self {
        Self { stream }
    }
}
impl Transport for TcpTransport {
    /// Encodes `envelope` and writes it to the socket as one frame.
    ///
    /// The frame is first encoded into an in-memory buffer and then handed to the
    /// socket with a single `write_all`. As a result, a failure inside
    /// `codec::encode` can never leave a partial frame on the wire (which would
    /// desynchronise the peer). It also means the codec's individual writes do not
    /// each become a separate syscall on the unbuffered `TcpStream`.
    ///
    /// # Errors
    /// Fails if `codec::encode` fails, or with an [`Error`] converted from
    /// `std::io::Error` if writing to the socket fails.
    fn send(self, envelope: Envelope) -> Io<Error, Self> {
        Io::suspend(move || {
            let mut frame = Vec::new();
            codec::encode(&mut frame, &envelope)?;
            // `Write` is implemented for `&TcpStream`, so a shared borrow suffices.
            let mut writer: &TcpStream = &self.stream;
            std::io::Write::write_all(&mut writer, &frame)?;
            Ok(self)
        })
    }

    /// Reads and decodes one envelope from the socket.
    ///
    /// # Errors
    /// Fails if `codec::decode` fails, including when the underlying read fails.
    fn recv(self) -> Io<Error, (Envelope, Self)> {
        Io::suspend(move || {
            // Read straight from the stream: a `BufReader` created per call would be
            // dropped afterwards, discarding any bytes of the next frame it had
            // already pulled off the socket.
            let mut reader: &TcpStream = &self.stream;
            let envelope: Envelope = codec::decode(&mut reader)?;
            Ok((envelope, self))
        })
    }
}
/// An in-process [`Transport`] backed by two `std::sync::mpsc` channels.
///
/// Envelopes are encoded to byte frames on send and decoded on receive, so the
/// codec is exercised exactly as it is for TCP. Connected endpoints are created
/// with [`ChannelTransport::pair`].
#[derive(Debug)]
pub struct ChannelTransport {
    /// Outgoing frames; the peer endpoint holds the matching receiver.
    sender: mpsc::Sender<Vec<u8>>,
    /// Incoming frames; the peer endpoint holds the matching sender.
    receiver: mpsc::Receiver<Vec<u8>>,
}
impl ChannelTransport {
    /// Creates two endpoints wired to each other.
    ///
    /// Frames sent by the first endpoint are received by the second, and vice
    /// versa. Each direction has its own unbounded `mpsc` channel.
    #[must_use]
    pub fn pair() -> (Self, Self) {
        // One channel per direction: left -> right, then right -> left.
        let (left_tx, right_rx) = mpsc::channel();
        let (right_tx, left_rx) = mpsc::channel();

        let left = Self {
            sender: left_tx,
            receiver: left_rx,
        };
        let right = Self {
            sender: right_tx,
            receiver: right_rx,
        };
        (left, right)
    }
}
impl Transport for ChannelTransport {
fn send(self, envelope: Envelope) -> Io<Error, Self> {
Io::suspend(move || {
let mut buf = Vec::new();
codec::encode(&mut buf, &envelope)?;
self.sender
.send(buf)
.map_err(|_| Error::ConnectionClosed)?;
Ok(self)
})
}
fn recv(self) -> Io<Error, (Envelope, Self)> {
Io::suspend(move || {
let buf = self
.receiver
.recv()
.map_err(|_| Error::ConnectionClosed)?;
let mut cursor = std::io::Cursor::new(buf);
let envelope: Envelope = codec::decode(&mut cursor)?;
Ok((envelope, self))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;

    /// Runs one request/response exchange over any pair of connected transports.
    ///
    /// `client` sends a request that `server` must receive intact. `server` then
    /// replies, and `client` must receive that response intact. An unexpected
    /// variant panics with a descriptive message; it is not disguised as a
    /// protocol error value.
    fn exchange<T: Transport>(client: T, server: T) -> Result<(), Error> {
        let request = Envelope::Request {
            id: RequestId::new(7),
            payload: r#""hello""#.to_owned(),
        };
        let client = client.send(request).run()?;
        let (received, server) = server.recv().run()?;
        match received {
            Envelope::Request { id, payload } => {
                assert_eq!(id.value(), 7);
                assert_eq!(payload, r#""hello""#);
            }
            _ => panic!("server expected Envelope::Request"),
        }

        let reply = Envelope::Response {
            id: RequestId::new(7),
            payload: r#""world""#.to_owned(),
        };
        let _server = server.send(reply).run()?;
        let (received, _client) = client.recv().run()?;
        match received {
            Envelope::Response { id, payload } => {
                assert_eq!(id.value(), 7);
                assert_eq!(payload, r#""world""#);
            }
            _ => panic!("client expected Envelope::Response"),
        }
        Ok(())
    }

    #[test]
    fn channel_transport_round_trip() -> Result<(), Error> {
        let (a, b) = ChannelTransport::pair();
        exchange(a, b)
    }

    #[test]
    fn tcp_transport_round_trip() -> Result<(), Error> {
        // Read timeouts turn a framing bug (e.g. a decoder waiting for more bytes)
        // into a test failure instead of a hung test run.
        let timeout = Some(std::time::Duration::from_secs(5));
        let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
        let client = TcpTransport::connect(listener.local_addr()?).run()?;
        client.stream.set_read_timeout(timeout)?;
        let (stream, _) = listener.accept()?;
        stream.set_read_timeout(timeout)?;
        exchange(client, TcpTransport::from_stream(stream))
    }

    #[test]
    fn channel_recv_reports_closed_peer() {
        let (a, b) = ChannelTransport::pair();
        // Dropping `a` drops the only sender feeding `b`'s receiver.
        drop(a);
        assert!(matches!(b.recv().run(), Err(Error::ConnectionClosed)));
    }

    #[test]
    fn channel_send_reports_closed_peer() {
        let (a, b) = ChannelTransport::pair();
        // Dropping `a` drops the only receiver for frames sent by `b`.
        drop(a);
        let envelope = Envelope::Request {
            id: RequestId::new(1),
            payload: r#""orphan""#.to_owned(),
        };
        assert!(matches!(b.send(envelope).run(), Err(Error::ConnectionClosed)));
    }
}