use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use chrono::Utc;
use convert_case::Case;
use convert_case::Casing;
use eyre::{bail, Context, ContextCompat, Result};
use futures::future::LocalBoxFuture;
use futures::FutureExt;
use tokio_tungstenite::tungstenite::handshake::server::{
    Callback, ErrorResponse, Request, Response,
};
use tokio_tungstenite::tungstenite::http::{HeaderValue, StatusCode};
use tracing::*;

use crate::libs::toolbox::ArcToolbox;
use crate::libs::toolbox::RequestContext;
use crate::libs::toolbox::Toolbox;
use crate::model::EndpointSchema;
use crate::model::Type;

use super::WsConnection;
22
23pub struct VerifyProtocol<'a> {
24 pub addr: SocketAddr,
25 pub tx: tokio::sync::mpsc::Sender<String>,
26 pub allow_cors_domains: &'a Option<Vec<String>>,
27}
28
29impl Callback for VerifyProtocol<'_> {
30 fn on_request(
31 self,
32 request: &Request,
33 mut response: Response,
34 ) -> Result<Response, ErrorResponse> {
35 let addr = self.addr;
36 debug!(?addr, "handshake request: {:?}", request);
37
38 let protocol = request
39 .headers()
40 .get("Sec-WebSocket-Protocol")
41 .or_else(|| request.headers().get("sec-websocket-protocol"));
42
43 let protocol_str = match protocol {
44 Some(protocol) => protocol
45 .to_str()
46 .map_err(|_| {
47 ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned()))
48 })?
49 .to_string(),
50 None => "".to_string(),
51 };
52
53 self.tx.try_send(protocol_str.clone()).unwrap();
54
55 response
56 .headers_mut()
57 .append("Date", Utc::now().to_rfc2822().parse().unwrap());
58 if !protocol_str.is_empty() {
59 response.headers_mut().insert(
60 "Sec-WebSocket-Protocol",
61 protocol_str
62 .split(',')
63 .next()
64 .unwrap_or("")
65 .parse()
66 .unwrap(),
67 );
68 }
69
70 response
71 .headers_mut()
72 .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
73
74 if let Some(allow_cors_domains) = self.allow_cors_domains {
75 if let Some(origin) = request.headers().get("Origin") {
76 let origin = origin.to_str().unwrap();
77 if allow_cors_domains.iter().any(|x| x == origin) {
78 response
79 .headers_mut()
80 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
81 response
82 .headers_mut()
83 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
84 }
85 }
86 } else {
87 if let Some(origin) = request.headers().get("Origin") {
89 let origin = origin.to_str().unwrap();
90 response
91 .headers_mut()
92 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
93 response
94 .headers_mut()
95 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
96 }
97 }
98
99 debug!(?addr, "Responding handshake with: {:?}", response);
100
101 Ok(response)
102 }
103}
104
105pub trait AuthController: Sync + Send {
106 fn auth(
107 self: Arc<Self>,
108 toolbox: &ArcToolbox,
109 header: String,
110 conn: Arc<WsConnection>,
111 ) -> LocalBoxFuture<'static, Result<()>>;
112}
113
114pub struct SimpleAuthController;
115
116impl AuthController for SimpleAuthController {
117 fn auth(
118 self: Arc<Self>,
119 _toolbox: &ArcToolbox,
120 _header: String,
121 _conn: Arc<WsConnection>,
122 ) -> LocalBoxFuture<'static, Result<()>> {
123 async move { Ok(()) }.boxed()
124 }
125}
126
127pub trait SubAuthController: Sync + Send {
128 fn auth(
129 self: Arc<Self>,
130 toolbox: &ArcToolbox,
131 param: serde_json::Value,
132 ctx: RequestContext,
133 conn: Arc<WsConnection>,
134 ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
135}
136pub struct EndpointAuthController {
137 pub auth_endpoints: HashMap<String, WsAuthController>,
138}
139pub struct WsAuthController {
140 pub schema: EndpointSchema,
141 pub handler: Arc<dyn SubAuthController>,
142}
143
144impl Default for EndpointAuthController {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl EndpointAuthController {
151 pub fn new() -> Self {
152 Self {
153 auth_endpoints: Default::default(),
154 }
155 }
156 pub fn add_auth_endpoint(
157 &mut self,
158 schema: EndpointSchema,
159 handler: impl SubAuthController + 'static,
160 ) {
161 self.auth_endpoints.insert(
162 schema.name.to_ascii_lowercase(),
163 WsAuthController {
164 schema,
165 handler: Arc::new(handler),
166 },
167 );
168 }
169}
170fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
171 Ok(match &ty {
172 Type::String => {
173 let decoded = urlencoding::decode(value)?;
174 serde_json::Value::String(decoded.to_string())
175 }
176 Type::Int => serde_json::Value::Number(
177 value
178 .parse::<i64>()
179 .with_context(|| format!("Failed to parse integer: {value}"))?
180 .into(),
181 ),
182 Type::Boolean => serde_json::Value::Bool(
183 value
184 .parse::<bool>()
185 .with_context(|| format!("Failed to parse boolean: {value}"))?,
186 ),
187 Type::Enum { .. } => serde_json::Value::String(value.to_string()),
188 Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
189 Type::UUID => serde_json::Value::String(value.to_string()),
190 Type::Optional(ty) => parse_ty(ty, value)?,
191 Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
192 ty => bail!("Not implemented {:?}", ty),
193 })
194}
195
196impl AuthController for EndpointAuthController {
197 fn auth(
198 self: Arc<Self>,
199 toolbox: &ArcToolbox,
200 header: String,
201 conn: Arc<WsConnection>,
202 ) -> LocalBoxFuture<'static, Result<()>> {
203 let toolbox = toolbox.clone();
204
205 async move {
206 let splits = header
207 .split(',')
208 .map(|x| x.trim())
209 .filter(|x| !x.is_empty())
210 .map(|x| (&x[..1], &x[1..]))
211 .collect::<HashMap<&str, &str>>();
212
213 let method = splits.get("0").context("Could not find method")?;
214 let endpoint = self
216 .auth_endpoints
217 .get(*method)
218 .with_context(|| format!("Could not find endpoint for method {method}"))?;
219 let mut params = serde_json::Map::new();
220 for (index, param) in endpoint.schema.parameters.iter().enumerate() {
221 let index = index + 1;
222 match splits.get(&index.to_string().as_str()) {
223 Some(value) => {
224 params.insert(param.name.to_case(Case::Camel), parse_ty(¶m.ty, value)?);
225 }
226 None if !matches!(¶m.ty, Type::Optional(_)) => {
227 bail!("Could not find param {} {}", param.name, index);
228 }
229 _ => {}
230 }
231 }
232 let roles = conn.roles.read().clone();
233 let ctx = RequestContext {
234 connection_id: conn.connection_id,
235 user_id: 0,
236 seq: 0,
237 method: endpoint.schema.code,
238 log_id: conn.log_id,
239 roles: roles.clone(),
240 ip_addr: conn.address.ip(),
241 };
242 let resp = endpoint
243 .handler
244 .clone()
245 .auth(
246 &toolbox,
247 serde_json::Value::Object(params),
248 ctx.clone(),
249 conn,
250 )
251 .await;
252 debug!("Auth response: {:?}", resp);
253 let conn_id = ctx.connection_id;
254 if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
255 toolbox.send(conn_id, resp);
256 }
257 Ok(())
258 }
259 .boxed_local()
260 }
261}