// endpoint_libs/libs/ws/headers.rs

1use chrono::Utc;
2use convert_case::Case;
3use convert_case::Casing;
4use eyre::{bail, Context, ContextCompat, Result};
5use futures::future::LocalBoxFuture;
6use futures::FutureExt;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::atomic::Ordering;
10use std::sync::Arc;
11use tokio_tungstenite::tungstenite::handshake::server::{
12    Callback, ErrorResponse, Request, Response,
13};
14use tracing::*;
15
16use crate::libs::toolbox::ArcToolbox;
17use crate::libs::toolbox::RequestContext;
18use crate::libs::toolbox::Toolbox;
19use crate::model::EndpointSchema;
20use crate::model::Type;
21
22use super::WsConnection;
23
24pub struct VerifyProtocol<'a> {
25    pub addr: SocketAddr,
26    pub tx: tokio::sync::mpsc::Sender<String>,
27    pub allow_cors_domains: &'a Option<Vec<String>>,
28}
29
30impl Callback for VerifyProtocol<'_> {
31    fn on_request(
32        self,
33        request: &Request,
34        mut response: Response,
35    ) -> Result<Response, ErrorResponse> {
36        let addr = self.addr;
37        debug!(?addr, "handshake request: {:?}", request);
38
39        let protocol = request
40            .headers()
41            .get("Sec-WebSocket-Protocol")
42            .or_else(|| request.headers().get("sec-websocket-protocol"));
43
44        let protocol_str = match protocol {
45            Some(protocol) => protocol
46                .to_str()
47                .map_err(|_| {
48                    ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned()))
49                })?
50                .to_string(),
51            None => "".to_string(),
52        };
53
54        self.tx.try_send(protocol_str.clone()).unwrap();
55
56        response
57            .headers_mut()
58            .append("Date", Utc::now().to_rfc2822().parse().unwrap());
59        if !protocol_str.is_empty() {
60            response.headers_mut().insert(
61                "Sec-WebSocket-Protocol",
62                protocol_str
63                    .split(',')
64                    .next()
65                    .unwrap_or("")
66                    .parse()
67                    .unwrap(),
68            );
69        }
70
71        response
72            .headers_mut()
73            .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
74
75        if let Some(allow_cors_domains) = self.allow_cors_domains {
76            if let Some(origin) = request.headers().get("Origin") {
77                let origin = origin.to_str().unwrap();
78                if allow_cors_domains.iter().any(|x| x == origin) {
79                    response
80                        .headers_mut()
81                        .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
82                    response
83                        .headers_mut()
84                        .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
85                }
86            }
87        } else {
88            // Allow all domains
89            if let Some(origin) = request.headers().get("Origin") {
90                let origin = origin.to_str().unwrap();
91                response
92                    .headers_mut()
93                    .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
94                response
95                    .headers_mut()
96                    .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
97            }
98        }
99
100        debug!(?addr, "Responding handshake with: {:?}", response);
101
102        Ok(response)
103    }
104}
105
106pub trait AuthController: Sync + Send {
107    fn auth(
108        self: Arc<Self>,
109        toolbox: &ArcToolbox,
110        header: String,
111        conn: Arc<WsConnection>,
112    ) -> LocalBoxFuture<'static, Result<()>>;
113}
114
115pub struct SimpleAuthController;
116
117impl AuthController for SimpleAuthController {
118    fn auth(
119        self: Arc<Self>,
120        _toolbox: &ArcToolbox,
121        _header: String,
122        _conn: Arc<WsConnection>,
123    ) -> LocalBoxFuture<'static, Result<()>> {
124        async move { Ok(()) }.boxed()
125    }
126}
127
128pub trait SubAuthController: Sync + Send {
129    fn auth(
130        self: Arc<Self>,
131        toolbox: &ArcToolbox,
132        param: serde_json::Value,
133        ctx: RequestContext,
134        conn: Arc<WsConnection>,
135    ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
136}
137pub struct EndpointAuthController {
138    pub auth_endpoints: HashMap<String, WsAuthController>,
139}
140pub struct WsAuthController {
141    pub schema: EndpointSchema,
142    pub handler: Arc<dyn SubAuthController>,
143}
144
145impl Default for EndpointAuthController {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151impl EndpointAuthController {
152    pub fn new() -> Self {
153        Self {
154            auth_endpoints: Default::default(),
155        }
156    }
157    pub fn add_auth_endpoint(
158        &mut self,
159        schema: EndpointSchema,
160        handler: impl SubAuthController + 'static,
161    ) {
162        self.auth_endpoints.insert(
163            schema.name.to_ascii_lowercase(),
164            WsAuthController {
165                schema,
166                handler: Arc::new(handler),
167            },
168        );
169    }
170}
171fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
172    Ok(match &ty {
173        Type::String => {
174            let decoded = urlencoding::decode(value)?;
175            serde_json::Value::String(decoded.to_string())
176        }
177        Type::Int => serde_json::Value::Number(
178            value
179                .parse::<i64>()
180                .with_context(|| format!("Failed to parse integer: {value}"))?
181                .into(),
182        ),
183        Type::Boolean => serde_json::Value::Bool(
184            value
185                .parse::<bool>()
186                .with_context(|| format!("Failed to parse boolean: {value}"))?,
187        ),
188        Type::Enum { .. } => serde_json::Value::String(value.to_string()),
189        Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
190        Type::UUID => serde_json::Value::String(value.to_string()),
191        Type::Optional(ty) => parse_ty(ty, value)?,
192        Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
193        ty => bail!("Not implemented {:?}", ty),
194    })
195}
196
197impl AuthController for EndpointAuthController {
198    fn auth(
199        self: Arc<Self>,
200        toolbox: &ArcToolbox,
201        header: String,
202        conn: Arc<WsConnection>,
203    ) -> LocalBoxFuture<'static, Result<()>> {
204        let toolbox = toolbox.clone();
205
206        async move {
207            let splits = header
208                .split(',')
209                .map(|x| x.trim())
210                .filter(|x| !x.is_empty())
211                .map(|x| (&x[..1], &x[1..]))
212                .collect::<HashMap<&str, &str>>();
213
214            let method = splits.get("0").context("Could not find method")?;
215            // info!("method: {:?}", method);
216            let endpoint = self
217                .auth_endpoints
218                .get(*method)
219                .with_context(|| format!("Could not find endpoint for method {method}"))?;
220            let mut params = serde_json::Map::new();
221            for (index, param) in endpoint.schema.parameters.iter().enumerate() {
222                let index = index + 1;
223                match splits.get(&index.to_string().as_str()) {
224                    Some(value) => {
225                        params.insert(param.name.to_case(Case::Camel), parse_ty(&param.ty, value)?);
226                    }
227                    None if !matches!(&param.ty, Type::Optional(_)) => {
228                        bail!("Could not find param {} {}", param.name, index);
229                    }
230                    _ => {}
231                }
232            }
233            let ctx = RequestContext {
234                connection_id: conn.connection_id,
235                user_id: 0,
236                seq: 0,
237                method: endpoint.schema.code,
238                log_id: conn.log_id,
239                role: conn.role.load(Ordering::Relaxed),
240                ip_addr: conn.address.ip(),
241            };
242            let resp = endpoint
243                .handler
244                .clone()
245                .auth(&toolbox, serde_json::Value::Object(params), ctx, conn)
246                .await;
247            debug!("Auth response: {:?}", resp);
248            if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
249                toolbox.send(ctx.connection_id, resp);
250            }
251            Ok(())
252        }
253        .boxed_local()
254    }
255}