1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use std::{collections::HashMap, io, mem::size_of, result, sync::Arc};
use arpy::FnRemote;
use ciborium::de;
use futures::future::BoxFuture;
use thiserror::Error;
use crate::FnRemoteBody;
#[derive(Default)]
pub struct WebSocketRouter(HashMap<Id, RpcHandler>);
impl WebSocketRouter {
pub fn new() -> Self {
Self::default()
}
pub fn handle<F, FSig>(mut self, f: F) -> Self
where
F: FnRemoteBody<FSig> + Send + Sync + 'static,
FSig: FnRemote + Send + Sync + 'static,
{
let id = FSig::ID.as_bytes().to_vec();
let f = Arc::new(f);
self.0.insert(
id,
Box::new(move |body| Box::pin(Self::run(f.clone(), body))),
);
self
}
async fn run<F, FSig>(f: Arc<F>, input: &[u8]) -> Result<Vec<u8>>
where
F: FnRemoteBody<FSig> + Send + Sync + 'static,
FSig: FnRemote + Send + Sync + 'static,
{
let args: FSig = ciborium::de::from_reader(input).map_err(Error::Deserialization)?;
let result = f.run(args).await;
let mut body = Vec::new();
ciborium::ser::into_writer(&result, &mut body).unwrap();
Ok(body)
}
}
pub struct WebSocketHandler(HashMap<Id, RpcHandler>);
impl WebSocketHandler {
pub fn new(router: WebSocketRouter) -> Self {
Self(router.0)
}
pub async fn handle_msg(&self, msg: &[u8]) -> Result<Vec<u8>> {
let (id, msg) = split_message(msg, size_of::<u32>(), "ID len")?;
let id_len = u32::from_le_bytes(id.try_into().unwrap());
let (id, args) = split_message(msg, id_len as usize, "ID")?;
let Some(function) = self.0.get(id)
else { return Err(Error::FunctionNotFound) };
function(args).await
}
}
fn split_message<'a>(msg: &'a [u8], mid: usize, name: &str) -> Result<(&'a [u8], &'a [u8])> {
if mid > msg.len() {
return Err(Error::Protocol(format!("Not enought bytes for {name}")));
}
Ok(msg.split_at(mid))
}
#[derive(Error, Debug)]
pub enum Error {
#[error("Function not found")]
FunctionNotFound,
#[error("Error unpacking message: {0}")]
Protocol(String),
#[error("Deserialization: {0}")]
Deserialization(de::Error<io::Error>),
}
pub type Result<T> = result::Result<T, Error>;
type Id = Vec<u8>;
type RpcHandler =
Box<dyn for<'a> Fn(&'a [u8]) -> BoxFuture<'a, Result<Vec<u8>>> + Send + Sync + 'static>;