rpc_toolkit/server/socket.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use futures::{Future, Stream, StreamExt, TryStreamExt};
5use imbl_value::Value;
6use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
7use tokio::net::{TcpListener, ToSocketAddrs, UnixListener};
8use tokio::sync::Notify;
9use yajrc::RpcError;
10
11use crate::util::{parse_error, JobRunner, StreamUntil};
12use crate::Server;
13
14#[derive(Clone)]
15pub struct ShutdownHandle(Arc<Notify>);
16impl ShutdownHandle {
17    pub fn shutdown(self) {
18        self.0.notify_one();
19    }
20}
21
22impl<Context: crate::Context> Server<Context> {
23    pub fn run_socket<'a, T: AsyncRead + AsyncWrite + Send>(
24        &'a self,
25        listener: impl Stream<Item = std::io::Result<T>> + 'a,
26        error_handler: impl Fn(std::io::Error) + Sync + 'a,
27    ) -> (ShutdownHandle, impl Future<Output = ()> + 'a) {
28        let shutdown = Arc::new(Notify::new());
29        (ShutdownHandle(shutdown.clone()), async move {
30            let mut runner = JobRunner::<std::io::Result<()>>::new();
31            let jobs = StreamUntil::new(listener, shutdown.notified()).map(|pipe| async {
32                let pipe = pipe?;
33                let (r, mut w) = tokio::io::split(pipe);
34                let stream = self.stream(
35                    tokio_stream::wrappers::LinesStream::new(BufReader::new(r).lines())
36                        .map_err(|e| RpcError {
37                            data: Some(e.to_string().into()),
38                            ..yajrc::INTERNAL_ERROR
39                        })
40                        .try_filter_map(|a| async move {
41                            Ok(if a.is_empty() {
42                                None
43                            } else {
44                                Some(serde_json::from_str::<Value>(&a).map_err(parse_error)?)
45                            })
46                        }),
47                );
48                tokio::pin!(stream);
49                while let Some(res) = stream.next().await {
50                    if let Err(e) = async {
51                        let mut buf = serde_json::to_vec(
52                            &res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
53                        )
54                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
55                        buf.push(b'\n');
56                        w.write_all(&buf).await
57                    }
58                    .await
59                    {
60                        error_handler(e)
61                    }
62                }
63                Ok(())
64            });
65            tokio::pin!(jobs);
66            while let Some(res) = runner.next_result(&mut jobs).await {
67                if let Err(e) = res {
68                    error_handler(e)
69                }
70            }
71        })
72    }
73    pub fn run_unix<'a>(
74        &'a self,
75        path: impl AsRef<Path> + 'a,
76        error_handler: impl Fn(std::io::Error) + Sync + 'a,
77    ) -> std::io::Result<(ShutdownHandle, impl Future<Output = ()> + 'a)> {
78        let listener = UnixListener::bind(path)?;
79        Ok(self.run_socket(
80            tokio_stream::wrappers::UnixListenerStream::new(listener),
81            error_handler,
82        ))
83    }
84    pub async fn run_tcp<'a>(
85        &'a self,
86        addr: impl ToSocketAddrs + 'a,
87        error_handler: impl Fn(std::io::Error) + Sync + 'a,
88    ) -> std::io::Result<(ShutdownHandle, impl Future<Output = ()> + 'a)> {
89        let listener = TcpListener::bind(addr).await?;
90        Ok(self.run_socket(
91            tokio_stream::wrappers::TcpListenerStream::new(listener),
92            error_handler,
93        ))
94    }
95}