use chrono::Utc;
use convert_case::Case;
use convert_case::Casing;
use eyre::{bail, Context, ContextCompat, Result};
use futures::future::LocalBoxFuture;
use futures::FutureExt;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::handshake::server::{
    Callback, ErrorResponse, Request, Response,
};
use tokio_tungstenite::tungstenite::http::{HeaderValue, StatusCode};
use tracing::*;

use crate::libs::toolbox::ArcToolbox;
use crate::libs::toolbox::RequestContext;
use crate::libs::toolbox::Toolbox;
use crate::model::EndpointSchema;
use crate::model::Type;

use super::WsConnection;
23
24pub struct VerifyProtocol<'a> {
25 pub addr: SocketAddr,
26 pub tx: tokio::sync::mpsc::Sender<String>,
27 pub allow_cors_domains: &'a Option<Vec<String>>,
28}
29
30impl Callback for VerifyProtocol<'_> {
31 fn on_request(
32 self,
33 request: &Request,
34 mut response: Response,
35 ) -> Result<Response, ErrorResponse> {
36 let addr = self.addr;
37 debug!(?addr, "handshake request: {:?}", request);
38
39 let protocol = request
40 .headers()
41 .get("Sec-WebSocket-Protocol")
42 .or_else(|| request.headers().get("sec-websocket-protocol"));
43
44 let protocol_str = match protocol {
45 Some(protocol) => protocol
46 .to_str()
47 .map_err(|_| {
48 ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned()))
49 })?
50 .to_string(),
51 None => "".to_string(),
52 };
53
54 self.tx.try_send(protocol_str.clone()).unwrap();
55
56 response
57 .headers_mut()
58 .append("Date", Utc::now().to_rfc2822().parse().unwrap());
59 if !protocol_str.is_empty() {
60 response.headers_mut().insert(
61 "Sec-WebSocket-Protocol",
62 protocol_str
63 .split(',')
64 .next()
65 .unwrap_or("")
66 .parse()
67 .unwrap(),
68 );
69 }
70
71 response
72 .headers_mut()
73 .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
74
75 if let Some(allow_cors_domains) = self.allow_cors_domains {
76 if let Some(origin) = request.headers().get("Origin") {
77 let origin = origin.to_str().unwrap();
78 if allow_cors_domains.iter().any(|x| x == origin) {
79 response
80 .headers_mut()
81 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
82 response
83 .headers_mut()
84 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
85 }
86 }
87 } else {
88 if let Some(origin) = request.headers().get("Origin") {
90 let origin = origin.to_str().unwrap();
91 response
92 .headers_mut()
93 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
94 response
95 .headers_mut()
96 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
97 }
98 }
99
100 debug!(?addr, "Responding handshake with: {:?}", response);
101
102 Ok(response)
103 }
104}
105
/// Authenticates a WebSocket connection from a header string.
///
/// NOTE(review): `header` appears to be the raw `Sec-WebSocket-Protocol` value
/// forwarded by the handshake callback — confirm against the caller.
pub trait AuthController: Sync + Send {
    /// Runs authentication for `conn` using `header`.
    ///
    /// The returned future is a `LocalBoxFuture<'static, _>`: it need not be
    /// `Send`, and it cannot borrow from `self` or `toolbox`, so implementations
    /// clone whatever they need before building it.
    fn auth(
        self: Arc<Self>,
        toolbox: &ArcToolbox,
        header: String,
        conn: Arc<WsConnection>,
    ) -> LocalBoxFuture<'static, Result<()>>;
}
114
115pub struct SimpleAuthController;
116
117impl AuthController for SimpleAuthController {
118 fn auth(
119 self: Arc<Self>,
120 _toolbox: &ArcToolbox,
121 _header: String,
122 _conn: Arc<WsConnection>,
123 ) -> LocalBoxFuture<'static, Result<()>> {
124 async move { Ok(()) }.boxed()
125 }
126}
127
/// Handler behind a single auth endpoint registered on an
/// `EndpointAuthController`.
pub trait SubAuthController: Sync + Send {
    /// Handles one authentication request.
    ///
    /// When invoked by `EndpointAuthController`, `param` is a JSON object whose
    /// keys are the endpoint's parameter names in camelCase, and `ctx` carries the
    /// connection id, log id, current role and IP of `conn` (with `user_id` and
    /// `seq` set to 0). The result is encoded with `Toolbox::encode_ws_response`
    /// and sent back to the client.
    fn auth(
        self: Arc<Self>,
        toolbox: &ArcToolbox,
        param: serde_json::Value,
        ctx: RequestContext,
        conn: Arc<WsConnection>,
    ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
}
/// [`AuthController`] that dispatches to a per-endpoint [`SubAuthController`]
/// chosen by the method named in the header.
pub struct EndpointAuthController {
    /// Registered endpoints. `add_auth_endpoint` keys them by the ASCII-lowercased
    /// schema name.
    pub auth_endpoints: HashMap<String, WsAuthController>,
}
/// One registered auth endpoint: its schema plus the handler serving it.
pub struct WsAuthController {
    /// Endpoint schema; its `parameters` define the order of positional values.
    pub schema: EndpointSchema,
    /// Handler invoked with the decoded parameters.
    pub handler: Arc<dyn SubAuthController>,
}
144
145impl Default for EndpointAuthController {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl EndpointAuthController {
152 pub fn new() -> Self {
153 Self {
154 auth_endpoints: Default::default(),
155 }
156 }
157 pub fn add_auth_endpoint(
158 &mut self,
159 schema: EndpointSchema,
160 handler: impl SubAuthController + 'static,
161 ) {
162 self.auth_endpoints.insert(
163 schema.name.to_ascii_lowercase(),
164 WsAuthController {
165 schema,
166 handler: Arc::new(handler),
167 },
168 );
169 }
170}
171fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
172 Ok(match &ty {
173 Type::String => {
174 let decoded = urlencoding::decode(value)?;
175 serde_json::Value::String(decoded.to_string())
176 }
177 Type::Int => serde_json::Value::Number(
178 value
179 .parse::<i64>()
180 .with_context(|| format!("Failed to parse integer: {value}"))?
181 .into(),
182 ),
183 Type::Boolean => serde_json::Value::Bool(
184 value
185 .parse::<bool>()
186 .with_context(|| format!("Failed to parse boolean: {value}"))?,
187 ),
188 Type::Enum { .. } => serde_json::Value::String(value.to_string()),
189 Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
190 Type::UUID => serde_json::Value::String(value.to_string()),
191 Type::Optional(ty) => parse_ty(ty, value)?,
192 Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
193 ty => bail!("Not implemented {:?}", ty),
194 })
195}
196
197impl AuthController for EndpointAuthController {
198 fn auth(
199 self: Arc<Self>,
200 toolbox: &ArcToolbox,
201 header: String,
202 conn: Arc<WsConnection>,
203 ) -> LocalBoxFuture<'static, Result<()>> {
204 let toolbox = toolbox.clone();
205
206 async move {
207 let splits = header
208 .split(',')
209 .map(|x| x.trim())
210 .filter(|x| !x.is_empty())
211 .map(|x| (&x[..1], &x[1..]))
212 .collect::<HashMap<&str, &str>>();
213
214 let method = splits.get("0").context("Could not find method")?;
215 let endpoint = self
217 .auth_endpoints
218 .get(*method)
219 .with_context(|| format!("Could not find endpoint for method {method}"))?;
220 let mut params = serde_json::Map::new();
221 for (index, param) in endpoint.schema.parameters.iter().enumerate() {
222 let index = index + 1;
223 match splits.get(&index.to_string().as_str()) {
224 Some(value) => {
225 params.insert(param.name.to_case(Case::Camel), parse_ty(¶m.ty, value)?);
226 }
227 None if !matches!(¶m.ty, Type::Optional(_)) => {
228 bail!("Could not find param {} {}", param.name, index);
229 }
230 _ => {}
231 }
232 }
233 let ctx = RequestContext {
234 connection_id: conn.connection_id,
235 user_id: 0,
236 seq: 0,
237 method: endpoint.schema.code,
238 log_id: conn.log_id,
239 role: conn.role.load(Ordering::Relaxed),
240 ip_addr: conn.address.ip(),
241 };
242 let resp = endpoint
243 .handler
244 .clone()
245 .auth(&toolbox, serde_json::Value::Object(params), ctx, conn)
246 .await;
247 debug!("Auth response: {:?}", resp);
248 if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
249 toolbox.send(ctx.connection_id, resp);
250 }
251 Ok(())
252 }
253 .boxed_local()
254 }
255}