#![doc(
html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
)]
#![doc(
html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
use std::{
io::{self, ErrorKind, Read, Write},
mem,
};
use futures::{
stream::{SplitSink, SplitStream},
Sink, Stream, StreamExt,
};
use jetstream_wireformat::WireFormat;
pub use tokio_util::codec::{Decoder, Encoder, Framed};
/// Marker trait for values that can be serialized with [`WireFormat`].
///
/// On native targets a message must additionally be `Sync`.
#[cfg(not(target_arch = "wasm32"))]
pub trait Message: Sync + WireFormat {}

/// Marker trait for values that can be serialized with [`WireFormat`].
///
/// On `wasm32` the `Sync` requirement is dropped (presumably because that
/// target is single-threaded — confirm).
#[cfg(target_arch = "wasm32")]
pub trait Message: WireFormat {}
/// Tag carried in a [`Context`] alongside its message.
///
/// A thin newtype over `u16`; `#[repr(transparent)]` gives it exactly the
/// layout of `u16`. Convert in with `Tag::from(u16)` and back out with
/// `u16::from(tag)`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Tag(u16);

impl From<u16> for Tag {
    /// Wraps a raw tag value.
    fn from(tag: u16) -> Self {
        Self(tag)
    }
}

impl From<Tag> for u16 {
    /// Recovers the raw tag value. The inner field is private, so this is the
    /// only way for code outside this module to read it.
    fn from(tag: Tag) -> Self {
        tag.0
    }
}
/// A message paired with its [`Tag`].
pub struct Context<T>
where
    T: WireFormat,
{
    /// Tag associated with `msg`.
    pub tag: Tag,
    /// The message payload.
    pub msg: T,
}
pub trait FromContext<T: WireFormat> {
fn from_context(ctx: Context<T>) -> Self;
}
impl<T: WireFormat> FromContext<T> for T {
fn from_context(ctx: Context<T>) -> Self {
ctx.msg
}
}
impl<T: WireFormat> FromContext<T> for Tag {
fn from_context(ctx: Context<T>) -> Self {
ctx.tag
}
}
/// Something that can be invoked once with a [`Context`], consuming itself.
pub trait Handler<T>
where
    T: WireFormat,
{
    /// Invokes the handler with `context`. Nothing is returned to the caller.
    fn call(self, context: Context<T>);
}
/// An RPC protocol: a request message type, a response message type, and an
/// async [`rpc`](Protocol::rpc) method that maps one request [`Frame`] to
/// either a response [`Frame`] or an error.
///
/// NOTE(review): `trait_variant::make(Send + Sync + Sized)` rewrites this
/// trait; presumably it makes the future returned by `rpc` carry these
/// bounds — confirm against the `trait_variant` documentation.
#[trait_variant::make(Send + Sync + Sized)]
pub trait Protocol: Send + Sync {
    /// Message type carried by request frames.
    type Request: Framer;
    /// Message type carried by response frames.
    type Response: Framer;
    /// Error returned by [`rpc`](Protocol::rpc); must be thread-safe and `'static`.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Version string of this protocol.
    ///
    /// NOTE(review): presumably used for version negotiation — confirm with callers.
    const VERSION: &'static str;
    /// Handles a single request frame, producing a response frame or `Self::Error`.
    async fn rpc(
        &mut self,
        frame: Frame<Self::Request>,
    ) -> Result<Frame<Self::Response>, Self::Error>;
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("io error: {0}")]
Io(#[from] io::Error),
#[error("generic error: {0}")]
Generic(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Custom(String),
#[error("invalid response")]
InvalidResponse,
}
/// A single wire frame: a `u16` tag plus a [`Framer`] message.
pub struct Frame<T>
where
    T: Framer,
{
    /// Tag written into the frame header.
    pub tag: u16,
    /// The message payload.
    pub msg: T,
}

impl<T> From<(u16, T)> for Frame<T>
where
    T: Framer,
{
    /// Builds a frame from a `(tag, message)` pair.
    fn from(pair: (u16, T)) -> Self {
        let (tag, msg) = pair;
        Frame { tag, msg }
    }
}
/// Length of the fixed frame header in bytes: `size: u32`, `type: u8`, `tag: u16`.
const FRAME_HEADER_SIZE: u32 =
    (mem::size_of::<u32>() + mem::size_of::<u8>() + mem::size_of::<u16>()) as u32;

/// Wire layout of a frame: `size: u32 | type: u8 | tag: u16 | body`.
///
/// `size` counts the whole frame, including the four bytes of `size` itself.
impl<T: Framer> WireFormat for Frame<T> {
    /// Total encoded length of the frame: fixed header plus message body.
    fn byte_size(&self) -> u32 {
        FRAME_HEADER_SIZE + self.msg.byte_size()
    }

    /// Writes the header (size, message type byte, tag) followed by the body.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        self.byte_size().encode(writer)?;
        self.msg.message_type().encode(writer)?;
        self.tag.encode(writer)?;
        self.msg.encode(writer)
    }

    /// Reads one frame. Header and body reads are limited to the declared `size`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` when the declared size is smaller than
    /// the fixed header. Propagates any error from reading the header or from
    /// `T::decode`.
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
        let byte_size: u32 = WireFormat::decode(reader)?;
        // Reject sizes that cannot even hold the header. Previously a size of
        // 4..=6 surfaced as `UnexpectedEof` from the bounded reader below, which
        // a streaming decoder may interpret as "wait for more bytes" even though
        // no further input can satisfy the bound.
        // NOTE(review): confirm how the codec maps `UnexpectedEof`.
        if byte_size < FRAME_HEADER_SIZE {
            return Err(io::Error::new(
                ErrorKind::InvalidData,
                format!(
                    "byte_size(= {}) is less than the {}-byte frame header",
                    byte_size, FRAME_HEADER_SIZE
                ),
            ));
        }
        // `size` includes its own four bytes, which have already been consumed.
        let remaining = u64::from(byte_size - mem::size_of::<u32>() as u32);
        let reader = &mut reader.take(remaining);
        let mut ty = [0u8];
        reader.read_exact(&mut ty)?;
        let tag: u16 = WireFormat::decode(reader)?;
        let msg = T::decode(reader, ty[0])?;
        Ok(Frame { tag, msg })
    }
}
/// A message body that reports its own type byte and (de)serializes itself.
///
/// The body is encoded without any header. `decode` receives the type byte
/// read from the frame header, so one implementation can dispatch on it.
pub trait Framer: Send + Sync + Sized {
    /// Type byte written into the frame header, right before the tag.
    fn message_type(&self) -> u8;

    /// Encoded size of the body in bytes, excluding the frame header.
    fn byte_size(&self) -> u32;

    /// Writes the body (no header) to `writer`.
    fn encode<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: Write;

    /// Reads a body whose header carried type byte `ty` from `reader`.
    fn decode<R>(reader: &mut R, ty: u8) -> io::Result<Self>
    where
        R: Read;
}
pub trait ServiceTransport<P: Protocol>:
Sink<Frame<P::Response>, Error = P::Error>
+ Stream<Item = Result<Frame<P::Request>, P::Error>>
+ Send
+ Sync
+ Unpin
{
}
impl<P: Protocol, T> ServiceTransport<P> for T where
T: Sink<Frame<P::Response>, Error = P::Error>
+ Stream<Item = Result<Frame<P::Request>, P::Error>>
+ Send
+ Sync
+ Unpin
{
}
pub trait ClientTransport<P: Protocol>:
Sink<Frame<P::Request>, Error = std::io::Error>
+ Stream<Item = Result<Frame<P::Response>, std::io::Error>>
+ Send
+ Sync
+ Unpin
{
}
impl<P: Protocol, T> ClientTransport<P> for T
where
Self: Sized,
T: Sink<Frame<P::Request>, Error = std::io::Error>
+ Stream<Item = Result<Frame<P::Response>, std::io::Error>>
+ Send
+ Sync
+ Unpin,
{
}
pub trait Channel<P: Protocol>: Unpin + Sized {
fn split(self) -> (SplitSink<Self, Frame<P::Request>>, SplitStream<Self>);
}
impl<P, T> Channel<P> for T
where
P: Protocol,
T: ClientTransport<P> + Unpin + Sized,
{
fn split(
self,
) -> (
SplitSink<Self, Frame<<P as Protocol>::Request>>,
SplitStream<Self>,
) {
StreamExt::split(self)
}
}