use chrono::Utc;
use convert_case::Case;
use convert_case::Casing;
use eyre::{bail, Context, ContextCompat, Result};
use futures::future::LocalBoxFuture;
use futures::FutureExt;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Request, Response};
use tokio_tungstenite::tungstenite::http::{HeaderValue, StatusCode};
use tracing::*;
13
14use crate::libs::toolbox::ArcToolbox;
15use crate::libs::toolbox::RequestContext;
16use crate::libs::toolbox::Toolbox;
17use crate::model::EndpointSchema;
18use crate::model::Type;
19
20use super::WsConnection;
21
22pub struct VerifyProtocol<'a> {
23 pub addr: SocketAddr,
24 pub tx: tokio::sync::mpsc::Sender<String>,
25 pub allow_cors_domains: &'a Option<Vec<String>>,
26}
27
28impl<'a> Callback for VerifyProtocol<'a> {
29 fn on_request(self, request: &Request, mut response: Response) -> Result<Response, ErrorResponse> {
30 let addr = self.addr;
31 debug!(?addr, "handshake request: {:?}", request);
32
33 let protocol = request
34 .headers()
35 .get("Sec-WebSocket-Protocol")
36 .or_else(|| request.headers().get("sec-websocket-protocol"));
37
38 let protocol_str = match protocol {
39 Some(protocol) => protocol
40 .to_str()
41 .map_err(|_| ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned())))?
42 .to_string(),
43 None => "".to_string(),
44 };
45
46 self.tx.try_send(protocol_str.clone()).unwrap();
47
48 response
49 .headers_mut()
50 .append("Date", Utc::now().to_rfc2822().parse().unwrap());
51 if !protocol_str.is_empty() {
52 response.headers_mut().insert(
53 "Sec-WebSocket-Protocol",
54 protocol_str.split(',').next().unwrap_or("").parse().unwrap(),
55 );
56 }
57
58 response
59 .headers_mut()
60 .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
61
62 if let Some(allow_cors_domains) = self.allow_cors_domains {
63 if let Some(origin) = request.headers().get("Origin") {
64 let origin = origin.to_str().unwrap();
65 if allow_cors_domains.iter().any(|x| x == origin) {
66 response
67 .headers_mut()
68 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
69 response
70 .headers_mut()
71 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
72 }
73 }
74 } else {
75 if let Some(origin) = request.headers().get("Origin") {
77 let origin = origin.to_str().unwrap();
78 response
79 .headers_mut()
80 .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
81 response
82 .headers_mut()
83 .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
84 }
85 }
86
87 debug!(?addr, "Responding handshake with: {:?}", response);
88
89 Ok(response)
90 }
91}
92
93pub trait AuthController: Sync + Send {
94 fn auth(
95 self: Arc<Self>,
96 toolbox: &ArcToolbox,
97 header: String,
98 conn: Arc<WsConnection>,
99 ) -> LocalBoxFuture<'static, Result<()>>;
100}
101
/// An [`AuthController`] that accepts every connection without inspecting the auth header.
///
/// Stateless, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimpleAuthController;
103
104impl AuthController for SimpleAuthController {
105 fn auth(
106 self: Arc<Self>,
107 _toolbox: &ArcToolbox,
108 _header: String,
109 _conn: Arc<WsConnection>,
110 ) -> LocalBoxFuture<'static, Result<()>> {
111 async move { Ok(()) }.boxed()
112 }
113}
114
115pub trait SubAuthController: Sync + Send {
116 fn auth(
117 self: Arc<Self>,
118 toolbox: &ArcToolbox,
119 param: serde_json::Value,
120 ctx: RequestContext,
121 conn: Arc<WsConnection>,
122 ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
123}
124pub struct EndpointAuthController {
125 pub auth_endpoints: HashMap<String, WsAuthController>,
126}
127pub struct WsAuthController {
128 pub schema: EndpointSchema,
129 pub handler: Arc<dyn SubAuthController>,
130}
131
132impl Default for EndpointAuthController {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl EndpointAuthController {
139 pub fn new() -> Self {
140 Self {
141 auth_endpoints: Default::default(),
142 }
143 }
144 pub fn add_auth_endpoint(&mut self, schema: EndpointSchema, handler: impl SubAuthController + 'static) {
145 self.auth_endpoints.insert(
146 schema.name.to_ascii_lowercase(),
147 WsAuthController {
148 schema,
149 handler: Arc::new(handler),
150 },
151 );
152 }
153}
154fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
155 Ok(match &ty {
156 Type::String => {
157 let decoded = urlencoding::decode(value)?;
158 serde_json::Value::String(decoded.to_string())
159 }
160 Type::Int => serde_json::Value::Number(
161 value
162 .parse::<i64>()
163 .with_context(|| format!("Failed to parse integer: {}", value))?
164 .into(),
165 ),
166 Type::Boolean => serde_json::Value::Bool(
167 value
168 .parse::<bool>()
169 .with_context(|| format!("Failed to parse boolean: {}", value))?,
170 ),
171 Type::Enum { .. } => serde_json::Value::String(value.to_string()),
172 Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
173 Type::UUID => serde_json::Value::String(value.to_string()),
174 Type::Optional(ty) => parse_ty(ty, value)?,
175 Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
176 ty => bail!("Not implemented {:?}", ty),
177 })
178}
179
180impl AuthController for EndpointAuthController {
181 fn auth(
182 self: Arc<Self>,
183 toolbox: &ArcToolbox,
184 header: String,
185 conn: Arc<WsConnection>,
186 ) -> LocalBoxFuture<'static, Result<()>> {
187 let toolbox = toolbox.clone();
188
189 async move {
190 let splits = header
191 .split(',')
192 .map(|x| x.trim())
193 .filter(|x| !x.is_empty())
194 .map(|x| (&x[..1], &x[1..]))
195 .collect::<HashMap<&str, &str>>();
196
197 let method = splits.get("0").context("Could not find method")?;
198 let endpoint = self
200 .auth_endpoints
201 .get(*method)
202 .with_context(|| format!("Could not find endpoint for method {}", method))?;
203 let mut params = serde_json::Map::new();
204 for (index, param) in endpoint.schema.parameters.iter().enumerate() {
205 let index = index + 1;
206 match splits.get(&index.to_string().as_str()) {
207 Some(value) => {
208 params.insert(param.name.to_case(Case::Camel), parse_ty(¶m.ty, value)?);
209 }
210 None if !matches!(¶m.ty, Type::Optional(_)) => {
211 bail!("Could not find param {} {}", param.name, index);
212 }
213 _ => {}
214 }
215 }
216 let ctx = RequestContext {
217 connection_id: conn.connection_id,
218 user_id: 0,
219 seq: 0,
220 method: endpoint.schema.code,
221 log_id: conn.log_id,
222 role: conn.role.load(Ordering::Relaxed),
223 ip_addr: conn.address.ip(),
224 };
225 let resp = endpoint
226 .handler
227 .clone()
228 .auth(&toolbox, serde_json::Value::Object(params), ctx, conn)
229 .await;
230 debug!("Auth response: {:?}", resp);
231 if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
232 toolbox.send(ctx.connection_id, resp);
233 }
234 Ok(())
235 }
236 .boxed_local()
237 }
238}