usdpl_back/
websockets.rs

1use ratchet_rs::deflate::DeflateExtProvider;
2use ratchet_rs::{Error as RatchetError, ProtocolRegistry, WebSocketConfig};
3use tokio::net::{TcpListener, TcpStream};
4
5use nrpc::_helpers::futures::StreamExt;
6
7use crate::rpc::StaticServiceRegistry;
8
/// The two components of a websocket request path: the nRPC service to call and the
/// method on it to invoke. Both fields borrow from the request path string.
struct MethodDescriptor<'a> {
    /// Service name (first path segment).
    service: &'a str,
    /// Method name on that service (second path segment).
    method: &'a str,
}
13
14/// Handler for communication to and from the front-end
15pub struct WebsocketServer {
16    services: StaticServiceRegistry,
17    port: u16,
18}
19
20impl WebsocketServer {
21    /// Initialise an instance of the back-end websocket server
22    pub fn new(port_usdpl: u16) -> Self {
23        Self {
24            services: StaticServiceRegistry::with_builtins(),
25            port: port_usdpl,
26        }
27    }
28
29    /// Get the service registry that the server handles
30    pub fn registry(&mut self) -> &'_ mut StaticServiceRegistry {
31        &mut self.services
32    }
33
34    /// Register a nRPC service for this server to handle
35    pub fn register<S: nrpc::ServerService<'static> + Send + 'static>(mut self, service: S) -> Self {
36        self.services.register(service);
37        self
38    }
39
40    /// Run the web server forever, asynchronously
41    pub async fn run(&self) -> std::io::Result<()> {
42        #[cfg(debug_assertions)]
43        let addr = (std::net::Ipv4Addr::UNSPECIFIED, self.port);
44        #[cfg(not(debug_assertions))]
45        let addr = (std::net::Ipv4Addr::LOCALHOST, self.port);
46
47        let tcp = TcpListener::bind(addr).await?;
48
49        while let Ok((stream, _addr_do_not_use)) = tcp.accept().await {
50            tokio::spawn(error_logger("USDPL websocket server error", Self::connection_handler(self.services.clone(), stream)));
51        }
52
53        Ok(())
54    }
55
56    #[cfg(feature = "blocking")]
57    /// Run the server forever, blocking the current thread
58    pub fn run_blocking(self) -> std::io::Result<()> {
59        let runner = tokio::runtime::Builder::new_multi_thread()
60            .enable_all()
61            .build()?;
62        runner.block_on(self.run())
63    }
64
65    async fn connection_handler(
66        mut services: StaticServiceRegistry,
67        stream: TcpStream,
68    ) -> Result<(), RatchetError> {
69        log::debug!("connection_handler invoked!");
70        let upgraded = ratchet_rs::accept_with(
71            stream,
72            WebSocketConfig::default(),
73            DeflateExtProvider::default(),
74            ProtocolRegistry::new(["usdpl-nrpc"])?,
75        )
76        .await?
77        .upgrade()
78        .await?;
79
80        let request_path = upgraded.request.uri().path();
81
82        log::debug!("accepted new connection on uri {}", request_path);
83
84        let websocket = std::sync::Arc::new(tokio::sync::Mutex::new(upgraded.websocket));
85
86        let descriptor = Self::parse_uri_path(request_path)
87            .map_err(|e| RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e))?;
88
89        let input_stream = Box::new(nrpc::_helpers::futures::stream::StreamExt::boxed(crate::rpc::ws_stream(websocket.clone())));
90        let output_stream = services
91            .call_descriptor(
92                descriptor.service,
93                descriptor.method,
94                input_stream,
95            )
96            .await
97            .map_err(|e| {
98                RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e.to_string())
99            })?;
100
101        output_stream.for_each(|result| async {
102            match result {
103                Ok(msg) => {
104                    let mut ws_lock = websocket.lock().await;
105                    if let Err(e) = ws_lock.write_binary(msg).await {
106                        log::error!("websocket error while writing response on uri {}: {}", request_path, e);
107                    }
108                },
109                Err(e) => {
110                    log::error!("service error while writing response on uri {}: {}", request_path, e);
111                }
112            }
113        }).await;
114
115        websocket.lock().await.close(ratchet_rs::CloseReason {
116            code: ratchet_rs::CloseCode::Normal,
117            description: None,
118        }).await?;
119
120        /*let mut buf = BytesMut::new();
121        loop {
122            match websocket.read(&mut buf).await? {
123                Message::Text => {
124                    return Err(RatchetError::with_cause(
125                        ratchet_rs::ErrorKind::Protocol,
126                        "Websocket text messages are not accepted",
127                    ))
128                }
129                Message::Binary => {
130                    log::debug!("got binary ws message on uri {}", request_path);
131                    let response = services
132                        .call_descriptor(
133                            descriptor.service,
134                            descriptor.method,
135                            buf.clone().freeze(),
136                        )
137                        .await
138                        .map_err(|e| {
139                            RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e.to_string())
140                        })?;
141                    log::debug!("service completed response on uri {}", request_path);
142                    websocket.write_binary(response).await?;
143                }
144                Message::Ping(x) => websocket.write_pong(x).await?,
145                Message::Pong(_) => {}
146                Message::Close(_) => break,
147            }
148        }*/
149        log::debug!("ws connection {} closed", request_path);
150        Ok(())
151    }
152
153    fn parse_uri_path<'a>(path: &'a str) -> Result<MethodDescriptor<'a>, &'static str> {
154        let mut iter = path.trim_matches('/').split('/');
155        if let Some(service) = iter.next() {
156            if let Some(method) = iter.next() {
157                if iter.next().is_none() {
158                    return Ok(MethodDescriptor { service, method });
159                } else {
160                    Err("URL path has too many separators")
161                }
162            } else {
163                Err("URL path has no method")
164            }
165        } else {
166            Err("URL path has no service")
167        }
168    }
169}
170
171async fn error_logger<E: std::error::Error>(msg: &'static str, f: impl core::future::Future<Output=Result<(), E>>) {
172    if let Err(e) = f.await {
173        log::error!("{}: {}", msg, e);
174    }
175}