endpoint_libs/libs/ws/headers.rs

1use chrono::Utc;
2use convert_case::Case;
3use convert_case::Casing;
4use eyre::{bail, Context, ContextCompat, Result};
5use futures::future::LocalBoxFuture;
6use futures::FutureExt;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::atomic::Ordering;
10use std::sync::Arc;
11use tokio_tungstenite::tungstenite::handshake::server::{Callback, ErrorResponse, Request, Response};
12use tracing::*;
13
14use crate::libs::toolbox::ArcToolbox;
15use crate::libs::toolbox::RequestContext;
16use crate::libs::toolbox::Toolbox;
17use crate::model::EndpointSchema;
18use crate::model::Type;
19
20use super::WsConnection;
21
22pub struct VerifyProtocol<'a> {
23    pub addr: SocketAddr,
24    pub tx: tokio::sync::mpsc::Sender<String>,
25    pub allow_cors_domains: &'a Option<Vec<String>>,
26}
27
28impl<'a> Callback for VerifyProtocol<'a> {
29    fn on_request(self, request: &Request, mut response: Response) -> Result<Response, ErrorResponse> {
30        let addr = self.addr;
31        debug!(?addr, "handshake request: {:?}", request);
32
33        let protocol = request
34            .headers()
35            .get("Sec-WebSocket-Protocol")
36            .or_else(|| request.headers().get("sec-websocket-protocol"));
37
38        let protocol_str = match protocol {
39            Some(protocol) => protocol
40                .to_str()
41                .map_err(|_| ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned())))?
42                .to_string(),
43            None => "".to_string(),
44        };
45
46        self.tx.try_send(protocol_str.clone()).unwrap();
47
48        response
49            .headers_mut()
50            .append("Date", Utc::now().to_rfc2822().parse().unwrap());
51        if !protocol_str.is_empty() {
52            response.headers_mut().insert(
53                "Sec-WebSocket-Protocol",
54                protocol_str.split(',').next().unwrap_or("").parse().unwrap(),
55            );
56        }
57
58        response
59            .headers_mut()
60            .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
61
62        if let Some(allow_cors_domains) = self.allow_cors_domains {
63            if let Some(origin) = request.headers().get("Origin") {
64                let origin = origin.to_str().unwrap();
65                if allow_cors_domains.iter().any(|x| x == origin) {
66                    response
67                        .headers_mut()
68                        .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
69                    response
70                        .headers_mut()
71                        .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
72                }
73            }
74        } else {
75            // Allow all domains
76            if let Some(origin) = request.headers().get("Origin") {
77                let origin = origin.to_str().unwrap();
78                response
79                    .headers_mut()
80                    .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
81                response
82                    .headers_mut()
83                    .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
84            }
85        }
86
87        debug!(?addr, "Responding handshake with: {:?}", response);
88
89        Ok(response)
90    }
91}
92
93pub trait AuthController: Sync + Send {
94    fn auth(
95        self: Arc<Self>,
96        toolbox: &ArcToolbox,
97        header: String,
98        conn: Arc<WsConnection>,
99    ) -> LocalBoxFuture<'static, Result<()>>;
100}
101
102pub struct SimpleAuthController;
103
104impl AuthController for SimpleAuthController {
105    fn auth(
106        self: Arc<Self>,
107        _toolbox: &ArcToolbox,
108        _header: String,
109        _conn: Arc<WsConnection>,
110    ) -> LocalBoxFuture<'static, Result<()>> {
111        async move { Ok(()) }.boxed()
112    }
113}
114
115pub trait SubAuthController: Sync + Send {
116    fn auth(
117        self: Arc<Self>,
118        toolbox: &ArcToolbox,
119        param: serde_json::Value,
120        ctx: RequestContext,
121        conn: Arc<WsConnection>,
122    ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
123}
124pub struct EndpointAuthController {
125    pub auth_endpoints: HashMap<String, WsAuthController>,
126}
127pub struct WsAuthController {
128    pub schema: EndpointSchema,
129    pub handler: Arc<dyn SubAuthController>,
130}
131
132impl Default for EndpointAuthController {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl EndpointAuthController {
139    pub fn new() -> Self {
140        Self {
141            auth_endpoints: Default::default(),
142        }
143    }
144    pub fn add_auth_endpoint(&mut self, schema: EndpointSchema, handler: impl SubAuthController + 'static) {
145        self.auth_endpoints.insert(
146            schema.name.to_ascii_lowercase(),
147            WsAuthController {
148                schema,
149                handler: Arc::new(handler),
150            },
151        );
152    }
153}
154fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
155    Ok(match &ty {
156        Type::String => {
157            let decoded = urlencoding::decode(value)?;
158            serde_json::Value::String(decoded.to_string())
159        }
160        Type::Int => serde_json::Value::Number(
161            value
162                .parse::<i64>()
163                .with_context(|| format!("Failed to parse integer: {}", value))?
164                .into(),
165        ),
166        Type::Boolean => serde_json::Value::Bool(
167            value
168                .parse::<bool>()
169                .with_context(|| format!("Failed to parse boolean: {}", value))?,
170        ),
171        Type::Enum { .. } => serde_json::Value::String(value.to_string()),
172        Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
173        Type::UUID => serde_json::Value::String(value.to_string()),
174        Type::Optional(ty) => parse_ty(ty, value)?,
175        Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
176        ty => bail!("Not implemented {:?}", ty),
177    })
178}
179
180impl AuthController for EndpointAuthController {
181    fn auth(
182        self: Arc<Self>,
183        toolbox: &ArcToolbox,
184        header: String,
185        conn: Arc<WsConnection>,
186    ) -> LocalBoxFuture<'static, Result<()>> {
187        let toolbox = toolbox.clone();
188
189        async move {
190            let splits = header
191                .split(',')
192                .map(|x| x.trim())
193                .filter(|x| !x.is_empty())
194                .map(|x| (&x[..1], &x[1..]))
195                .collect::<HashMap<&str, &str>>();
196
197            let method = splits.get("0").context("Could not find method")?;
198            // info!("method: {:?}", method);
199            let endpoint = self
200                .auth_endpoints
201                .get(*method)
202                .with_context(|| format!("Could not find endpoint for method {}", method))?;
203            let mut params = serde_json::Map::new();
204            for (index, param) in endpoint.schema.parameters.iter().enumerate() {
205                let index = index + 1;
206                match splits.get(&index.to_string().as_str()) {
207                    Some(value) => {
208                        params.insert(param.name.to_case(Case::Camel), parse_ty(&param.ty, value)?);
209                    }
210                    None if !matches!(&param.ty, Type::Optional(_)) => {
211                        bail!("Could not find param {} {}", param.name, index);
212                    }
213                    _ => {}
214                }
215            }
216            let ctx = RequestContext {
217                connection_id: conn.connection_id,
218                user_id: 0,
219                seq: 0,
220                method: endpoint.schema.code,
221                log_id: conn.log_id,
222                role: conn.role.load(Ordering::Relaxed),
223                ip_addr: conn.address.ip(),
224            };
225            let resp = endpoint
226                .handler
227                .clone()
228                .auth(&toolbox, serde_json::Value::Object(params), ctx, conn)
229                .await;
230            debug!("Auth response: {:?}", resp);
231            if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
232                toolbox.send(ctx.connection_id, resp);
233            }
234            Ok(())
235        }
236        .boxed_local()
237    }
238}