use anyhow::Result;
use binprot::macros::{BinProtRead, BinProtWrite};
use binprot::{BinProtRead, BinProtSize, BinProtWrite};
use std::collections::BTreeMap;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
/// Handshake payload exchanged right after a connection is established: a bare
/// list of `i64` values.
///
/// In this file the server sends `Handshake(vec![4411474, 1])` after accepting a
/// connection and then reads a handshake back; the client reads the server's
/// handshake and echoes the same value.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
struct Handshake(Vec<i64>);
/// S-expression tree: either a single string atom or a list of nested expressions.
///
/// Used as the payload of several `RpcError` variants. `Debug` is implemented by
/// hand rather than derived so that values print in s-expression syntax.
#[derive(BinProtRead, BinProtWrite, Clone, PartialEq)]
enum Sexp {
/// Leaf node holding raw text.
Atom(String),
/// Ordered sequence of child expressions (may be empty).
List(Vec<Sexp>),
}
/// Renders a [`Sexp`] in s-expression syntax.
///
/// * Atoms made only of alphanumeric characters are printed bare.
/// * Any other atom, including the empty atom, is wrapped in double quotes with
///   its contents escaped via `str::escape_default`. Quoting the empty atom keeps
///   it visible: otherwise a list holding one empty atom would print exactly like
///   an empty list.
/// * Lists are printed as `(` + space-separated children + `)`.
impl std::fmt::Debug for Sexp {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Sexp::Atom(atom) => {
                let needs_quotes =
                    atom.is_empty() || atom.contains(|c: char| !c.is_alphanumeric());
                if needs_quotes {
                    write!(fmt, "\"{}\"", atom.escape_default())
                } else {
                    fmt.write_str(atom)
                }
            }
            Sexp::List(list) => {
                fmt.write_str("(")?;
                for (index, sexp) in list.iter().enumerate() {
                    if index > 0 {
                        fmt.write_str(" ")?;
                    }
                    // Fully-qualified call: works whether or not the `Debug` trait
                    // has been imported into this module for method syntax.
                    std::fmt::Debug::fmt(sexp, fmt)?;
                }
                fmt.write_str(")")
            }
        }
    }
}
/// An RPC request as sent by the client.
///
/// `T` is the type of the request payload.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
struct Query<T> {
/// Name of the RPC being invoked (the client fills this from `JRpc::NAME`).
rpc_tag: String,
/// Version of the RPC (the client fills this from `JRpc::VERSION`).
version: i64,
/// Request identifier chosen by the sender; the server copies it into the response.
id: i64,
/// Request payload wrapped in `binprot::WithLen`.
/// NOTE(review): presumably encoded with a length prefix — `ServerQuery` reads a
/// `Nat0` at this position before the payload; confirm against the binprot docs.
data: binprot::WithLen<T>,
}
/// RPC version number as carried inside `RpcError::UnimplementedRpc`.
///
/// NOTE(review): `#[polymorphic_variant]` presumably switches the derive to
/// binprot's polymorphic-variant encoding — confirm against the binprot docs.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
#[polymorphic_variant]
enum Version {
Version(i64),
}
/// Error half of an RPC result.
///
/// Only two variants are produced by code in this file: `UncaughtExn` (when an RPC
/// implementation returns an error) and `UnimplementedRpc` (when no implementation
/// is registered for the requested tag). The remaining variants can still be
/// decoded when received from a peer.
///
/// NOTE(review): the variant order presumably determines the wire encoding, so do
/// not reorder without checking the binprot derive semantics.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
enum RpcError {
BinIoExn(Sexp),
ConnectionClosed,
WriteError(Sexp),
/// The RPC implementation failed; the payload describes the failure.
UncaughtExn(Sexp),
/// No implementation is available for the given (rpc tag, version) pair.
UnimplementedRpc((String, Version)),
UnknownQueryId(String),
}
/// Outcome of an RPC call: either a payload of type `T` (wrapped in
/// `binprot::WithLen`) or an [`RpcError`].
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
enum RpcResult<T> {
/// Successful result carrying the response payload.
Ok(binprot::WithLen<T>),
/// Failed call.
Error(RpcError),
}
/// Reply to a [`Query`].
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
struct Response<T> {
/// Identifier of the query this response answers (the server copies the query's `id`).
id: i64,
/// Success payload or error.
data: RpcResult<T>,
}
/// Top-level message exchanged after the handshake.
///
/// `Q` is the query payload type and `R` the response payload type. Callers in
/// this file instantiate the side they do not use with `()`.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
enum Message<Q, R> {
/// Keep-alive message; the client's dispatch loop ignores it.
Heartbeat,
/// An RPC request.
Query(Query<Q>),
/// An RPC reply.
Response(Response<R>),
}
/// Reads one length-prefixed frame from `stream` and decodes it as a `T`.
///
/// A frame is an 8-byte little-endian signed length followed by that many payload
/// bytes. `buffer` is scratch space owned by the caller: it is resized to the
/// payload length and overwritten, so it can be reused across calls.
///
/// # Errors
///
/// Returns an error if the stream cannot supply the header or the full payload,
/// if the announced length is negative (or does not fit in `usize`), or if the
/// payload cannot be decoded as `T`.
fn read_bin_prot<T: BinProtRead>(stream: &mut TcpStream, buffer: &mut Vec<u8>) -> Result<T> {
    use std::convert::TryFrom;

    let mut recv_bytes = [0u8; 8];
    stream.read_exact(&mut recv_bytes)?;
    let recv_len = i64::from_le_bytes(recv_bytes);
    // The length comes from the peer. A plain `as usize` cast would turn a negative
    // value into an enormous size and make `resize` abort, so reject it explicitly.
    let payload_len = usize::try_from(recv_len).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("invalid frame length {}", recv_len),
        )
    })?;
    buffer.resize(payload_len, 0u8);
    stream.read_exact(buffer)?;
    let mut slice = buffer.as_slice();
    let data = T::binprot_read(&mut slice)?;
    Ok(data)
}
/// Writes `v` to `stream` as one frame: the value's `binprot_size()` as an 8-byte
/// little-endian `i64`, immediately followed by the value's bin_prot encoding.
///
/// Any I/O error from either write is returned to the caller.
fn write_bin_prot<T: BinProtWrite>(stream: &mut TcpStream, v: &T) -> Result<()> {
    let header = (v.binprot_size() as i64).to_le_bytes();
    stream.write_all(&header)?;
    Ok(v.binprot_write(stream)?)
}
/// Compile-time description of one RPC: its name, version and payload types.
///
/// Implementors are marker types; `RpcClient::dispatch` is parameterised by one
/// to know which tag/version to send and how to type the payloads.
trait JRpc {
/// Name sent as the query's `rpc_tag`.
const NAME: &'static str;
/// Version sent as the query's `version`.
const VERSION: i64;
/// Query (request) payload type.
type Q;
/// Response payload type.
type R;
}
/// Marker type for the `get-unique-id` RPC (version 0): takes `()` and returns an `i64`.
///
/// This is the only RPC name the server in this file registers an implementation for.
struct RpcGetUniqueId;
impl JRpc for RpcGetUniqueId {
const NAME: &'static str = "get-unique-id";
const VERSION: i64 = 0;
type Q = ();
type R = i64;
}
/// Marker type for an RPC named `get-unique-id2` (version 0) with the same payload
/// types as `RpcGetUniqueId`.
///
/// The server in this file registers no implementation under this name, so
/// dispatching it against that server takes the unimplemented-RPC path.
/// NOTE(review): judging by the type name, the misspelled tag looks intentional
/// (to exercise that path) — confirm.
struct RpcGetUniqueIdTypo;
impl JRpc for RpcGetUniqueIdTypo {
const NAME: &'static str = "get-unique-id2";
const VERSION: i64 = 0;
type Q = ();
type R = i64;
}
/// Marker type for the `set-id-counter` RPC (version 1): takes an `i64` and returns `()`.
///
/// The server in this file does not register an implementation for it.
struct RpcSetIdCounter;
impl JRpc for RpcSetIdCounter {
const NAME: &'static str = "set-id-counter";
const VERSION: i64 = 1;
type Q = i64;
type R = ();
}
/// Blocking RPC client over a single TCP connection.
struct RpcClient {
/// Connection to the server.
stream: std::net::TcpStream,
/// Scratch buffer reused for every incoming frame.
buffer: Vec<u8>,
/// Identifier of the most recently sent query; incremented before each dispatch.
id: i64,
}
impl RpcClient {
    /// Connects to `address` and performs the handshake.
    ///
    /// The handshake consists of reading the server's [`Handshake`] frame and
    /// echoing the same value back.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established or if reading/writing
    /// the handshake fails.
    fn connect(address: &str) -> Result<Self> {
        let mut stream = TcpStream::connect(address)?;
        let mut buffer = vec![0u8; 256];
        println!("Successfully connected to {}", address);
        let handshake: Handshake = read_bin_prot(&mut stream, &mut buffer)?;
        println!("Received {:?}", handshake);
        write_bin_prot(&mut stream, &handshake)?;
        Ok(RpcClient { stream, buffer, id: 0 })
    }

    /// Sends `query` as an invocation of the RPC described by `T` and blocks until
    /// a response message arrives.
    ///
    /// Heartbeats and queries received while waiting are ignored. The returned
    /// [`Response`] may itself carry an `RpcResult::Error`; that is not turned into
    /// an `Err` here.
    ///
    /// # Errors
    ///
    /// Fails if the query cannot be written, or if an incoming frame cannot be
    /// read or decoded (for example because the server closed the connection).
    fn dispatch<T: JRpc>(&mut self, query: T::Q) -> Result<Response<T::R>>
    where
        T::Q: BinProtWrite,
        T::R: BinProtRead,
    {
        self.id += 1;
        let query = Query {
            rpc_tag: T::NAME.to_owned(),
            version: T::VERSION,
            id: self.id,
            data: binprot::WithLen(query),
        };
        write_bin_prot(&mut self.stream, &Message::Query::<T::Q, ()>(query))?;
        loop {
            // Network reads can fail at any time; report that to the caller
            // instead of panicking.
            let received: Message<(), T::R> = read_bin_prot(&mut self.stream, &mut self.buffer)?;
            match received {
                Message::Response(r) => return Ok(r),
                Message::Heartbeat | Message::Query(_) => {}
            }
        }
    }
}
/// Server-side implementation of a single RPC.
///
/// `Q` is the query payload type, `R` the response payload type and `E` the error
/// type the implementation may fail with.
trait JRpcImpl {
type Q; type R; type E;
/// Handles one query and returns either the response payload or an error.
fn rpc_impl(&mut self, q: Self::Q) -> std::result::Result<Self::R, Self::E>;
}
/// Variant of [`JRpcImpl`] with no associated types, so implementations for
/// different RPCs can be stored together as `Box<dyn ErasedJRpcImpl>`.
trait ErasedJRpcImpl {
/// Serves one query whose payload has not yet been read from `stream`; `id` is
/// the identifier of that query. The blanket implementation in this file reads
/// the payload from `stream` and writes the response back to it.
fn erased_rpc_impl(&mut self, stream: &mut TcpStream, id: i64) -> Result<()>;
}
impl<T> ErasedJRpcImpl for T
where
T: JRpcImpl,
T::Q: BinProtRead,
T::R: BinProtWrite,
T::E: std::error::Error,
{
fn erased_rpc_impl(&mut self, stream: &mut TcpStream, id: i64) -> Result<()> {
let query = T::Q::binprot_read(stream)?;
let rpc_result = match self.rpc_impl(query) {
Ok(response) => RpcResult::Ok(binprot::WithLen(response)),
Err(error) => {
let sexp = Sexp::Atom(error.to_string());
RpcResult::Error(RpcError::UncaughtExn(sexp))
}
};
let response = Response { id, data: rpc_result };
write_bin_prot(stream, &Message::Response::<(), T::R>(response))?;
Ok(())
}
}
/// Blocking RPC server that serves one connection at a time.
// `dead_code` is allowed because `id` is initialised in `bind` but never read by
// the code in this file.
#[allow(dead_code)]
struct RpcServer {
/// Listening socket.
listener: TcpListener,
/// Scratch buffer reused for incoming frames and for skipping unknown payloads.
buffer: Vec<u8>,
/// Currently unused (see the note above).
id: i64,
/// Registered RPC implementations, keyed by RPC tag.
rpc_impls: BTreeMap<String, Box<dyn ErasedJRpcImpl>>,
}
/// Counter-backed RPC implementation: each call hands out the current counter
/// value and then advances the counter by one. The wrapped `i64` is the next
/// value to hand out.
struct GetUniqueIdImpl(i64);

impl JRpcImpl for GetUniqueIdImpl {
    type Q = ();
    type R = i64;
    type E = std::convert::Infallible;

    /// Returns the current counter value and increments the stored counter.
    /// This implementation never returns an error.
    fn rpc_impl(&mut self, _q: Self::Q) -> std::result::Result<Self::R, Self::E> {
        let following = self.0 + 1;
        Ok(std::mem::replace(&mut self.0, following))
    }
}
/// Server-side view of an incoming query: the same leading fields as `Query`, but
/// the payload itself is not decoded here.
///
/// Only the `Nat0` found at the payload position is read. The server's run loop
/// uses that value as the number of payload bytes to skip when the RPC is
/// unknown; for known RPCs the registered implementation reads the payload from
/// the stream itself.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
struct ServerQuery {
/// Name of the requested RPC; used as the lookup key for the implementation.
rpc_tag: String,
/// Requested RPC version (only echoed back in the unimplemented-RPC error).
version: i64,
/// Query identifier, copied into the response.
id: i64,
/// NOTE(review): presumably the byte length of the payload that follows —
/// that is how the run loop treats it; confirm against `binprot::WithLen`.
data: binprot::Nat0,
}
/// Server-side counterpart of `Message`, with the query payload left undecoded
/// (see `ServerQuery`).
///
/// The server decodes incoming messages as `ServerMessage<()>` and builds the
/// unimplemented-RPC reply as a `ServerMessage` wrapping a `Response<()>`.
#[derive(BinProtRead, BinProtWrite, Debug, Clone, PartialEq)]
enum ServerMessage<R> {
/// Keep-alive message; ignored by the server loop.
Heartbeat,
/// Incoming RPC request header.
Query(ServerQuery),
/// Response payload of type `R`.
Response(R),
}
impl RpcServer {
fn bind(address: &str) -> Result<Self> {
let listener = TcpListener::bind(address)?;
let buffer = vec![0u8; 256];
println!("Successfully bound to {}", address);
let mut rpc_impls: BTreeMap<String, Box<dyn ErasedJRpcImpl>> = BTreeMap::new();
let get_unique_id_impl: Box<dyn ErasedJRpcImpl> = Box::new(GetUniqueIdImpl(0));
rpc_impls.insert("get-unique-id".to_string(), get_unique_id_impl);
Ok(RpcServer { listener, buffer, id: 0, rpc_impls })
}
fn run(&mut self) -> Result<()> {
for stream in self.listener.incoming() {
let mut stream = stream?;
println!("Got connection {:?}.", stream);
write_bin_prot(&mut stream, &Handshake(vec![4411474, 1]))?;
let handshake: Handshake = read_bin_prot(&mut stream, &mut self.buffer)?;
println!("Received handshake {:?}", handshake);
let mut recv_bytes = [0u8; 8];
loop {
stream.read_exact(&mut recv_bytes)?;
let _recv_len = i64::from_le_bytes(recv_bytes);
let query = ServerMessage::<()>::binprot_read(&mut stream)?;
println!("Received rpc query {:?}", query);
match query {
ServerMessage::Heartbeat => {}
ServerMessage::Query(query) => match self.rpc_impls.get_mut(&query.rpc_tag) {
None => {
let err = RpcError::UnimplementedRpc((
query.rpc_tag,
Version::Version(query.version),
));
let message = ServerMessage::Response(Response::<()> {
id: query.id,
data: RpcResult::Error(err),
});
self.buffer.resize(query.data.0 as usize, 0u8);
stream.read_exact(&mut self.buffer)?;
write_bin_prot(&mut stream, &message)?
}
Some(r) => r.erased_rpc_impl(&mut stream, query.id)?,
},
ServerMessage::Response(()) => unimplemented!(),
};
}
}
Ok(())
}
}
/// Entry point. The first command-line argument selects the mode:
///
/// * `client` — connects to `localhost:8080`, issues a fixed sequence of RPC calls
///   and prints each response.
/// * `server` — binds `localhost:8080` and serves connections.
///
/// Panics when the argument is missing or is anything else.
fn main() -> Result<()> {
    let mode = std::env::args().nth(1);
    match mode.as_deref() {
        None => {
            panic!("missing argument")
        }
        Some("server") => RpcServer::bind("localhost:8080")?.run()?,
        Some("client") => {
            let mut client = RpcClient::connect("localhost:8080")?;

            let reply = client.dispatch::<RpcGetUniqueId>(())?;
            println!(">> {:?}", reply);

            let reply = client.dispatch::<RpcGetUniqueId>(())?;
            println!(">> {:?}", reply);

            let reply = client.dispatch::<RpcSetIdCounter>(42)?;
            println!(">> {:?}", reply);

            let reply = client.dispatch::<RpcGetUniqueId>(())?;
            println!(">> {:?}", reply);

            let reply = client.dispatch::<RpcSetIdCounter>(0)?;
            println!(">> {:?}", reply);

            let reply = client.dispatch::<RpcGetUniqueIdTypo>(())?;
            println!(">> {:?}", reply);
        }
        Some(_) => {
            panic!("unexpected argument, try client or server")
        }
    }
    Ok(())
}