rpc_toolkit/server/
socket.rs1use std::path::Path;
2use std::sync::Arc;
3
4use futures::{Future, Stream, StreamExt, TryStreamExt};
5use imbl_value::Value;
6use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
7use tokio::net::{TcpListener, ToSocketAddrs, UnixListener};
8use tokio::sync::Notify;
9use yajrc::RpcError;
10
11use crate::util::{parse_error, JobRunner, StreamUntil};
12use crate::Server;
13
14#[derive(Clone)]
15pub struct ShutdownHandle(Arc<Notify>);
16impl ShutdownHandle {
17 pub fn shutdown(self) {
18 self.0.notify_one();
19 }
20}
21
22impl<Context: crate::Context> Server<Context> {
23 pub fn run_socket<'a, T: AsyncRead + AsyncWrite + Send>(
24 &'a self,
25 listener: impl Stream<Item = std::io::Result<T>> + 'a,
26 error_handler: impl Fn(std::io::Error) + Sync + 'a,
27 ) -> (ShutdownHandle, impl Future<Output = ()> + 'a) {
28 let shutdown = Arc::new(Notify::new());
29 (ShutdownHandle(shutdown.clone()), async move {
30 let mut runner = JobRunner::<std::io::Result<()>>::new();
31 let jobs = StreamUntil::new(listener, shutdown.notified()).map(|pipe| async {
32 let pipe = pipe?;
33 let (r, mut w) = tokio::io::split(pipe);
34 let stream = self.stream(
35 tokio_stream::wrappers::LinesStream::new(BufReader::new(r).lines())
36 .map_err(|e| RpcError {
37 data: Some(e.to_string().into()),
38 ..yajrc::INTERNAL_ERROR
39 })
40 .try_filter_map(|a| async move {
41 Ok(if a.is_empty() {
42 None
43 } else {
44 Some(serde_json::from_str::<Value>(&a).map_err(parse_error)?)
45 })
46 }),
47 );
48 tokio::pin!(stream);
49 while let Some(res) = stream.next().await {
50 if let Err(e) = async {
51 let mut buf = serde_json::to_vec(
52 &res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
53 )
54 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55 buf.push(b'\n');
56 w.write_all(&buf).await
57 }
58 .await
59 {
60 error_handler(e)
61 }
62 }
63 Ok(())
64 });
65 tokio::pin!(jobs);
66 while let Some(res) = runner.next_result(&mut jobs).await {
67 if let Err(e) = res {
68 error_handler(e)
69 }
70 }
71 })
72 }
73 pub fn run_unix<'a>(
74 &'a self,
75 path: impl AsRef<Path> + 'a,
76 error_handler: impl Fn(std::io::Error) + Sync + 'a,
77 ) -> std::io::Result<(ShutdownHandle, impl Future<Output = ()> + 'a)> {
78 let listener = UnixListener::bind(path)?;
79 Ok(self.run_socket(
80 tokio_stream::wrappers::UnixListenerStream::new(listener),
81 error_handler,
82 ))
83 }
84 pub async fn run_tcp<'a>(
85 &'a self,
86 addr: impl ToSocketAddrs + 'a,
87 error_handler: impl Fn(std::io::Error) + Sync + 'a,
88 ) -> std::io::Result<(ShutdownHandle, impl Future<Output = ()> + 'a)> {
89 let listener = TcpListener::bind(addr).await?;
90 Ok(self.run_socket(
91 tokio_stream::wrappers::TcpListenerStream::new(listener),
92 error_handler,
93 ))
94 }
95}