1use eyre::{bail, eyre, ContextCompat, Result};
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fs::File;
6use std::net::SocketAddr;
7use std::path::PathBuf;
8use std::sync::atomic::AtomicU32;
9use std::sync::Arc;
10use tokio::io::{AsyncRead, AsyncWrite};
11use tokio::sync::mpsc;
12use tokio::task::LocalSet;
13use tokio_tungstenite::tungstenite::Error as WsError;
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::WebSocketStream;
16use tracing::*;
17
18use crate::libs::error_code::ErrorCode;
19use crate::libs::handler::{RequestHandler, RequestHandlerErased};
20use crate::libs::listener::{ConnectionListener, TcpListener, TlsListener};
21use crate::libs::toolbox::{ArcToolbox, RequestContext, Toolbox, TOOLBOX};
22use crate::libs::utils::{get_conn_id, get_log_id};
23use crate::libs::ws::{VerifyProtocol, WsClientSession, WsConnection};
24use crate::model::EndpointSchema;
25use crate::libs::ws::client::WsRequest;
26
27use super::{AuthController, ConnectionId, SimpleAuthController, WebsocketStates, WsEndpoint};
28
/// WebSocket server: owns the registered endpoint handlers, the authentication hook and the
/// shared [`Toolbox`] used while serving connections.
pub struct WebsocketServer {
    /// Invoked once per connection, after the WebSocket handshake, to authenticate it.
    pub auth_controller: Arc<dyn AuthController>,
    /// Registered endpoints, keyed by `EndpointSchema::code`.
    pub handlers: HashMap<u32, WsEndpoint>,
    /// Not read in this part of the file; `new` initialises it to `None`.
    /// NOTE(review): confirm whether anything still consumes this receiver.
    pub message_receiver: Option<mpsc::Receiver<ConnectionId>>,
    /// Toolbox shared with handlers; scoped as `TOOLBOX` for each connection task.
    pub toolbox: ArcToolbox,
    /// Server configuration (bind address, TLS material, CORS list, ...).
    pub config: WsServerConfig,
}
36
37impl WebsocketServer {
38 pub fn new(config: WsServerConfig) -> Self {
39 Self {
40 auth_controller: Arc::new(SimpleAuthController),
41 handlers: Default::default(),
42 message_receiver: None,
43 toolbox: Toolbox::new(),
44 config,
45 }
46 }
47 pub fn set_auth_controller(&mut self, controller: impl AuthController + 'static) {
48 self.auth_controller = Arc::new(controller);
49 }
50 pub fn add_handler<T: RequestHandler + 'static>(&mut self, handler: T) {
51 let schema = serde_json::from_str(T::Request::SCHEMA).expect("Invalid schema");
52 check_handler::<T>(&schema).expect("Invalid handler");
53 self.add_handler_erased(schema, Arc::new(handler))
54 }
55 pub fn add_handler_erased(&mut self, schema: EndpointSchema, handler: Arc<dyn RequestHandlerErased>) {
56 let old = self.handlers.insert(schema.code, WsEndpoint { schema, handler });
57 if let Some(old) = old {
58 panic!(
59 "Overwriting handler for endpoint {} {}",
60 old.schema.code, old.schema.name
61 );
62 }
63 }
64 async fn handle_ws_handshake_and_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
65 self: Arc<Self>,
66 addr: SocketAddr,
67 states: Arc<WebsocketStates>,
68 stream: S,
69 ) -> Result<()> {
70 let (tx, mut rx) = mpsc::channel(1);
71 let hs = tokio_tungstenite::accept_hdr_async(
72 stream,
73 VerifyProtocol {
74 addr,
75 tx,
76 allow_cors_domains: &self.config.allow_cors_urls,
77 },
78 )
79 .await;
80
81 tracing::warn!("handle new WS connection");
83
84 let stream = wrap_ws_error(hs)?;
85 let conn = Arc::new(WsConnection {
86 connection_id: get_conn_id(),
87 user_id: Default::default(),
88 role: AtomicU32::new(0),
89 address: addr,
90 log_id: get_log_id(),
91 });
92 debug!(?addr, "New connection handshaken {:?}", conn);
93 let headers = rx.recv().await.ok_or_else(|| eyre!("Failed to receive ws headers"))?;
94
95 let (tx, rx) = mpsc::channel(100);
96 let conn = Arc::clone(&conn);
97 states.insert(conn.connection_id, tx, conn.clone());
98
99 let auth_result = Arc::clone(&self.auth_controller)
100 .auth(&self.toolbox, headers, Arc::clone(&conn))
101 .await;
102 let raw_ctx = RequestContext::from_conn(&conn);
103 if let Err(err) = auth_result {
104 self.toolbox.send_request_error(
105 &raw_ctx,
106 ErrorCode::new(100400), err.to_string(),
108 );
109 return Err(err);
110 }
111 self.handle_session_connection(conn, states, stream, rx).await;
112
113 Ok(())
114 }
115
116 pub async fn handle_session_connection<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
117 self: Arc<Self>,
118 conn: Arc<WsConnection>,
119 states: Arc<WebsocketStates>,
120 stream: WebSocketStream<S>,
121 rx: mpsc::Receiver<Message>,
122 ) {
123 let addr = conn.address;
124 let context = RequestContext::from_conn(&conn);
125
126 let session = WsClientSession::new(conn, stream, rx, self);
127 session.run().await;
128
129 states.remove(context.connection_id);
130 debug!(?addr, "Connection closed");
131 }
132
133 pub async fn listen(self) -> Result<()> {
134 info!("Listening on {}", self.config.address);
135
136 let addr = tokio::net::lookup_host(&self.config.address)
138 .await?
139 .next()
140 .with_context(|| format!("Failed to lookup host to bind: {}", self.config.address))?;
141
142 let listener = TcpListener::bind(addr).await?;
143 if self.config.insecure {
144 self.listen_impl(Arc::new(listener)).await
145 } else if self.config.pub_certs.is_some() && self.config.priv_key.is_some() {
146 let listener = TlsListener::bind(
148 listener,
149 self.config.pub_certs.clone().unwrap(),
150 self.config.priv_key.clone().unwrap(),
151 )
152 .await?;
153 self.listen_impl(Arc::new(listener)).await
154 } else {
155 bail!("pub_certs and priv_key should be set")
156 }
157 }
158
159 async fn listen_impl<T: ConnectionListener + 'static>(self, listener: Arc<T>) -> Result<()> {
160 let states = Arc::new(WebsocketStates::new());
161 self.toolbox
162 .set_ws_states(states.clone_states(), self.config.header_only);
163 let this = Arc::new(self);
164 let local_set = LocalSet::new();
165 let (mut sigterm, mut sigint) = crate::libs::signal::init_signals()?;
166 local_set
167 .run_until(async {
168 loop {
169 tokio::select! {
170 _ = crate::libs::signal::wait_for_signals(&mut sigterm, &mut sigint) => break,
171 accepted = listener.accept() => {
172 let (stream, addr) = match accepted {
173 Ok(x) => x,
174 Err(err) => {
175 error!("Error while accepting stream: {:?}", err);
176 continue;
177 }
178 };
179 let listener = Arc::clone(&listener);
180 let this = Arc::clone(&this);
181 let states = Arc::clone(&states);
182 local_set.spawn_local(async move {
183 let stream = match listener.handshake(stream).await {
184 Ok(channel) => {
185 info!("Accepted stream from {}", addr);
186 channel
187 }
188 Err(err) => {
189 error!("Error while handshaking stream: {:?}", err);
190 return;
191 }
192 };
193
194 let future = TOOLBOX.scope(this.toolbox.clone(), this.handle_ws_handshake_and_connection(addr, states, stream));
195 if let Err(err) = future.await {
196 error!("Error while handling connection: {:?}", err);
197 }
198 });
199 }
200 }
201 }
202 Ok(())
203 })
204 .await
205 }
206
207 pub fn dump_schemas(&self) -> Result<()> {
208 let _ = std::fs::create_dir_all("docs");
209 let file = format!("docs/{}_alive_endpoints.json", self.config.name);
210 let available_schemas: Vec<String> = self.handlers.values().map(|x| x.schema.name.clone()).sorted().collect();
211 info!("Dumping {} endpoint names to {}", available_schemas.len(), file);
212 serde_json::to_writer_pretty(File::create(file)?, &available_schemas)?;
213 Ok(())
214 }
215}
216
217pub fn wrap_ws_error<T>(err: Result<T, WsError>) -> Result<T> {
218 err.map_err(|x| eyre!(x))
219}
220
221pub fn check_name(cat: &str, be_name: &str, should_name: &str) -> Result<()> {
222 if !be_name.contains(should_name) {
223 bail!("{} name should be {} but got {}", cat, should_name, be_name);
224 } else {
225 Ok(())
226 }
227}
228
229pub fn check_handler<T: RequestHandler + 'static>(schema: &EndpointSchema) -> Result<()> {
230 let handler_name = std::any::type_name::<T>();
231 let should_handler_name = format!("Method{}", schema.name);
232 check_name("Method", handler_name, &should_handler_name)?;
233 let request_name = std::any::type_name::<T::Request>();
234 let should_req_name = format!("{}Request", schema.name);
235 check_name("Request", request_name, &should_req_name)?;
236
237 Ok(())
238}
239
/// Configuration for [`WebsocketServer`]. Field order is kept as-is because it determines
/// the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WsServerConfig {
    /// Server name; used in the file name written by `dump_schemas`. Empty if absent.
    #[serde(default)]
    pub name: String,
    /// Address to bind. It is resolved with `tokio::net::lookup_host` and the first result is
    /// used. Required when deserializing.
    pub address: String,
    /// TLS certificate files; required (with `priv_key`) unless `insecure` is set.
    #[serde(default)]
    pub pub_certs: Option<Vec<PathBuf>>,
    /// TLS private key file; required (with `pub_certs`) unless `insecure` is set.
    #[serde(default)]
    pub priv_key: Option<PathBuf>,
    /// Serve plain TCP instead of TLS.
    #[serde(default)]
    pub insecure: bool,
    /// Not read in this file. NOTE(review): confirm where this flag is consumed.
    #[serde(default)]
    pub debug: bool,
    /// Passed to `Toolbox::set_ws_states`; never (de)serialized.
    #[serde(skip)]
    pub header_only: bool,
    /// Passed to the handshake verifier as `allow_cors_domains`; never (de)serialized.
    #[serde(skip)]
    pub allow_cors_urls: Arc<Option<Vec<String>>>,
}