// hyperlocal_with_windows/server_windows.rs

1use hyper::{
2    body::{Body, Incoming},
3    service::service_fn,
4    Request, Response,
5};
6use hyper_util::rt::TokioIo;
7use std::{future::Future, io, path::Path};
8use uds_windows::UnixListener;
9
10use crate::windows::convert_unix_stream_to_nb_tcp_stream;
11
12/// A cross-platform wrapper around a [`tokio::net::UnixListener`] or a Windows
13/// equivalent. Using this type allows code using Unix sockets to be written
14/// once and run on both Unix and Windows.
15///
16/// [`tokio::net::UnixListener`]:
17///     https://docs.rs/tokio/1.39.1/tokio/net/struct.UnixListener.html
18#[derive(Debug)]
19pub struct CommonUnixListener(UnixListener);
20
21impl CommonUnixListener {
22    /// Open a Unix socket.
23    ///
24    /// # Errors
25    ///
26    /// This function will return any errors that occur while trying to open the
27    /// provided path.
28    pub fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
29        UnixListener::bind(path).map(Self)
30    }
31}
32
33/// Extension trait for provisioning a hyper HTTP server over a Unix domain
34/// socket.
35///
36/// # Example
37///
38/// ```rust
39/// use hyper::Response;
40/// use hyperlocal_with_windows::{
41///     remove_unix_socket_if_present, CommonUnixListener, UnixListenerExt,
42/// };
43///
44/// let future = async move {
45///     let path = std::env::temp_dir().join("hyperlocal.sock");
46///     remove_unix_socket_if_present(&path)
47///         .await
48///         .expect("removed any existing unix socket");
49///     let listener = CommonUnixListener::bind(path).expect("parsed unix path");
50///
51///     listener
52///         .serve(|| {
53///             |_request| async {
54///                 Ok::<_, hyper::Error>(Response::new("Hello, world.".to_string()))
55///             }
56///         })
57///         .await
58///         .expect("failed to serve a connection")
59/// };
60/// ```
61pub trait UnixListenerExt {
62    /// Indefinitely accept and respond to connections.
63    ///
64    /// Pass a function which will generate the function which responds to
65    /// all requests for an individual connection.
66    fn serve<MakeResponseFn, ResponseFn, ResponseFuture, B, E>(
67        self,
68        f: MakeResponseFn,
69    ) -> impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>>
70    where
71        MakeResponseFn: Fn() -> ResponseFn,
72        ResponseFn: Fn(Request<Incoming>) -> ResponseFuture,
73        ResponseFuture: Future<Output = Result<Response<B>, E>>,
74        B: Body + 'static,
75        <B as Body>::Error: std::error::Error + Send + Sync,
76        E: std::error::Error + Send + Sync + 'static;
77}
78
79impl UnixListenerExt for UnixListener {
80    fn serve<MakeServiceFn, ResponseFn, ResponseFuture, B, E>(
81        self,
82        f: MakeServiceFn,
83    ) -> impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>>
84    where
85        MakeServiceFn: Fn() -> ResponseFn,
86        ResponseFn: Fn(Request<Incoming>) -> ResponseFuture,
87        ResponseFuture: Future<Output = Result<Response<B>, E>>,
88        B: Body + 'static,
89        <B as Body>::Error: std::error::Error + Send + Sync,
90        E: std::error::Error + Send + Sync + 'static,
91    {
92        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
93
94        // TODO We aren't fully handling closing the socket. Ideally when the
95        // SocketIncoming is dropped, we would abort the current accept() call
96        // and then close the socket. Currently we only close the socket once we
97        // receive a connection after the SocketIncoming was dropped.
98        std::thread::spawn(move || {
99            loop {
100                let result = self.accept();
101                let result_was_err = result.is_err();
102                if tx.blocking_send(result).is_err() {
103                    // End if the receiver closed.
104                    break;
105                }
106                if result_was_err {
107                    // If there was an error, we should stop trying to accept
108                    // connections.
109                    break;
110                }
111            }
112        });
113
114        async move {
115            while let Some(result) = rx.recv().await {
116                let (stream, _addr) = result?;
117                let stream =
118                    tokio::net::TcpStream::from_std(convert_unix_stream_to_nb_tcp_stream(stream))
119                        .unwrap();
120
121                let io = TokioIo::new(stream);
122
123                let svc_fn = service_fn(f());
124
125                hyper::server::conn::http1::Builder::new()
126                    // On OSX, disabling keep alive prevents serve_connection from
127                    // blocking and later returning an Err derived from E_NOTCONN.
128                    .keep_alive(false)
129                    .serve_connection(io, svc_fn)
130                    .await?;
131            }
132            Err("UnixListener closed".into())
133        }
134    }
135}
136
137impl UnixListenerExt for CommonUnixListener {
138    fn serve<MakeServiceFn, ResponseFn, ResponseFuture, B, E>(
139        self,
140        f: MakeServiceFn,
141    ) -> impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>>
142    where
143        MakeServiceFn: Fn() -> ResponseFn,
144        ResponseFn: Fn(Request<Incoming>) -> ResponseFuture,
145        ResponseFuture: Future<Output = Result<Response<B>, E>>,
146        B: Body + 'static,
147        <B as Body>::Error: std::error::Error + Send + Sync,
148        E: std::error::Error + Send + Sync + 'static,
149    {
150        self.0.serve(f)
151    }
152}