1use std::{
2 convert::TryInto,
3 mem::size_of,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc,
8 },
9};
10
11use bincode::{deserialize, serialize_into, serialized_size};
12use bytes::{BufMut, Bytes, BytesMut};
13use futures::{channel::mpsc, future::ready, Sink, SinkExt, Stream, StreamExt};
14use serde::{Deserialize, Serialize};
15use tokio::io::{split, AsyncRead, AsyncWrite};
16use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
17
18use crate::error::{Error, Result};
19
/// Opaque 64-bit identifier.
///
/// Rendered by [`std::fmt::Display`] as sixteen upper-case hexadecimal digits
/// enclosed in square brackets.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Id(u64);

impl Id {
    /// The identifier whose raw value is zero.
    pub const NULL: Id = Self(0);
}

impl std::fmt::Display for Id {
    /// Writes the raw value zero-padded to 16 upper-case hex digits, wrapped
    /// in `[` and `]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        write!(f, "[{:016X}]", raw)
    }
}
33
34#[derive(Clone)]
35pub struct IdGenerator(Arc<AtomicU64>);
36
37impl IdGenerator {
38 pub fn new() -> Self {
39 Self(Arc::new(AtomicU64::new(5)))
40 }
41
42 pub fn next(&self) -> Id {
43 Id(self.0.fetch_add(1, Ordering::SeqCst))
44 }
45}
46
47impl Default for IdGenerator {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53pub struct RpcFrame(Bytes);
54
55impl RpcFrame {
56 pub fn new<T: Serialize>(id: Id, data: T) -> Result<Self> {
57 let cap = size_of::<Id>() + serialized_size(&data)? as usize;
58 let mut buf = BytesMut::with_capacity(cap);
59 buf.put_u64(id.0);
60 let mut writer = buf.writer();
61 serialize_into(&mut writer, &data)?;
62 let buf = writer.into_inner();
63 assert_eq!(cap, buf.capacity());
64 Ok(Self(buf.freeze()))
65 }
66
67 pub fn id(&self) -> Result<Id> {
68 self.0
69 .get(0..size_of::<Id>())
70 .map(|buf| {
71 Id(u64::from_be_bytes(
72 buf.try_into().expect("infallible: hardcode slice size"),
73 ))
74 })
75 .ok_or(Error::Serialize(None))
76 }
77
78 pub fn data<'a, T: Deserialize<'a>>(&'a self) -> Result<T> {
79 Ok(deserialize(
80 self.0
81 .get(size_of::<Id>()..)
82 .ok_or(Error::Serialize(None))?,
83 )?)
84 }
85}
86
/// A pinned, boxed, type-erased [`Stream`] yielding `T` that is `Send`,
/// `Sync` and `'static`.
pub type GenericStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
/// A pinned, boxed, type-erased [`Sink`] accepting `T` with error type `E`
/// that is `Send`, `Sync` and `'static`.
pub type GenericSink<T, E> = Pin<Box<dyn Sink<T, Error = E> + Send + Sync + 'static>>;
89
90pub struct Transport {
91 input: GenericStream<Result<RpcFrame>>,
92 output: GenericSink<RpcFrame, Error>,
93}
94
95impl Transport {
96 pub fn from_streamed<T>(io: T) -> Self
97 where
98 T: AsyncRead + AsyncWrite + Send + Sync + 'static,
99 {
100 let (reader, writer) = split(io);
101 Self::from_streamed_pair(reader, writer)
102 }
103
104 pub fn from_streamed_pair<R, W>(reader: R, writer: W) -> Self
105 where
106 R: AsyncRead + Send + Sync + 'static,
107 W: AsyncWrite + Send + Sync + 'static,
108 {
109 let stream = FramedRead::new(reader, LengthDelimitedCodec::default())
110 .map(|buf| buf.map(BytesMut::freeze).map(RpcFrame).map_err(Error::from));
111 let sink = FramedWrite::new(writer, LengthDelimitedCodec::default())
112 .with(|frame: RpcFrame| ready(Ok(frame.0)));
113 Self::from_framed_pair(stream, sink)
114 }
115
116 pub fn from_framed<T>(io: T) -> Self
117 where
118 T: Stream<Item = Result<RpcFrame>> + Sink<RpcFrame, Error = Error> + Send + Sync + 'static,
119 {
120 let (sink, stream) = io.split();
121 Self::from_framed_pair(stream, sink)
122 }
123
124 pub fn from_framed_pair<T, U>(stream: T, sink: U) -> Self
125 where
126 T: Stream<Item = Result<RpcFrame>> + Send + Sync + 'static,
127 U: Sink<RpcFrame, Error = Error> + Send + Sync + 'static,
128 {
129 Self {
130 input: Box::pin(stream),
131 output: Box::pin(sink),
132 }
133 }
134
135 pub fn new_local() -> (Self, Self) {
136 let (tx1, rx1) = mpsc::unbounded::<RpcFrame>();
137 let (tx2, rx2) = mpsc::unbounded::<RpcFrame>();
138
139 let tx1 = tx1.sink_map_err(|_| Error::Io(std::io::ErrorKind::ConnectionAborted.into()));
140 let tx2 = tx2.sink_map_err(|_| Error::Io(std::io::ErrorKind::ConnectionAborted.into()));
141 let rx1 = rx1.map(Ok);
142 let rx2 = rx2.map(Ok);
143
144 let transport_l = Self::from_framed_pair(rx1, tx2);
145 let transport_r = Self::from_framed_pair(rx2, tx1);
146 (transport_l, transport_r)
147 }
148
149 pub fn split(
150 self,
151 ) -> (
152 GenericStream<Result<RpcFrame>>,
153 GenericSink<RpcFrame, Error>,
154 ) {
155 (self.input, self.output)
156 }
157}