// endpoint_libs/libs/ws/headers.rs

1use chrono::Utc;
2use convert_case::Case;
3use convert_case::Casing;
4use eyre::{bail, Context, ContextCompat, Result};
5use futures::future::LocalBoxFuture;
6use futures::FutureExt;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio_tungstenite::tungstenite::handshake::server::{
11    Callback, ErrorResponse, Request, Response,
12};
13use tracing::*;
14
15use crate::libs::toolbox::ArcToolbox;
16use crate::libs::toolbox::RequestContext;
17use crate::libs::toolbox::Toolbox;
18use crate::model::EndpointSchema;
19use crate::model::Type;
20
21use super::WsConnection;
22
23pub struct VerifyProtocol<'a> {
24    pub addr: SocketAddr,
25    pub tx: tokio::sync::mpsc::Sender<String>,
26    pub allow_cors_domains: &'a Option<Vec<String>>,
27}
28
29impl Callback for VerifyProtocol<'_> {
30    fn on_request(
31        self,
32        request: &Request,
33        mut response: Response,
34    ) -> Result<Response, ErrorResponse> {
35        let addr = self.addr;
36        debug!(?addr, "handshake request: {:?}", request);
37
38        let protocol = request
39            .headers()
40            .get("Sec-WebSocket-Protocol")
41            .or_else(|| request.headers().get("sec-websocket-protocol"));
42
43        let protocol_str = match protocol {
44            Some(protocol) => protocol
45                .to_str()
46                .map_err(|_| {
47                    ErrorResponse::new(Some("Sec-WebSocket-Protocol is not valid utf-8".to_owned()))
48                })?
49                .to_string(),
50            None => "".to_string(),
51        };
52
53        self.tx.try_send(protocol_str.clone()).unwrap();
54
55        response
56            .headers_mut()
57            .append("Date", Utc::now().to_rfc2822().parse().unwrap());
58        if !protocol_str.is_empty() {
59            response.headers_mut().insert(
60                "Sec-WebSocket-Protocol",
61                protocol_str
62                    .split(',')
63                    .next()
64                    .unwrap_or("")
65                    .parse()
66                    .unwrap(),
67            );
68        }
69
70        response
71            .headers_mut()
72            .insert("Server", "RustWebsocketServer/1.0".parse().unwrap());
73
74        if let Some(allow_cors_domains) = self.allow_cors_domains {
75            if let Some(origin) = request.headers().get("Origin") {
76                let origin = origin.to_str().unwrap();
77                if allow_cors_domains.iter().any(|x| x == origin) {
78                    response
79                        .headers_mut()
80                        .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
81                    response
82                        .headers_mut()
83                        .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
84                }
85            }
86        } else {
87            // Allow all domains
88            if let Some(origin) = request.headers().get("Origin") {
89                let origin = origin.to_str().unwrap();
90                response
91                    .headers_mut()
92                    .insert("Access-Control-Allow-Origin", origin.parse().unwrap());
93                response
94                    .headers_mut()
95                    .insert("Access-Control-Allow-Credentials", "true".parse().unwrap());
96            }
97        }
98
99        debug!(?addr, "Responding handshake with: {:?}", response);
100
101        Ok(response)
102    }
103}
104
105pub trait AuthController: Sync + Send {
106    fn auth(
107        self: Arc<Self>,
108        toolbox: &ArcToolbox,
109        header: String,
110        conn: Arc<WsConnection>,
111    ) -> LocalBoxFuture<'static, Result<()>>;
112}
113
114pub struct SimpleAuthController;
115
116impl AuthController for SimpleAuthController {
117    fn auth(
118        self: Arc<Self>,
119        _toolbox: &ArcToolbox,
120        _header: String,
121        _conn: Arc<WsConnection>,
122    ) -> LocalBoxFuture<'static, Result<()>> {
123        async move { Ok(()) }.boxed()
124    }
125}
126
127pub trait SubAuthController: Sync + Send {
128    fn auth(
129        self: Arc<Self>,
130        toolbox: &ArcToolbox,
131        param: serde_json::Value,
132        ctx: RequestContext,
133        conn: Arc<WsConnection>,
134    ) -> LocalBoxFuture<'static, Result<serde_json::Value>>;
135}
136pub struct EndpointAuthController {
137    pub auth_endpoints: HashMap<String, WsAuthController>,
138}
139pub struct WsAuthController {
140    pub schema: EndpointSchema,
141    pub handler: Arc<dyn SubAuthController>,
142}
143
144impl Default for EndpointAuthController {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl EndpointAuthController {
151    pub fn new() -> Self {
152        Self {
153            auth_endpoints: Default::default(),
154        }
155    }
156    pub fn add_auth_endpoint(
157        &mut self,
158        schema: EndpointSchema,
159        handler: impl SubAuthController + 'static,
160    ) {
161        self.auth_endpoints.insert(
162            schema.name.to_ascii_lowercase(),
163            WsAuthController {
164                schema,
165                handler: Arc::new(handler),
166            },
167        );
168    }
169}
170fn parse_ty(ty: &Type, value: &str) -> Result<serde_json::Value> {
171    Ok(match &ty {
172        Type::String => {
173            let decoded = urlencoding::decode(value)?;
174            serde_json::Value::String(decoded.to_string())
175        }
176        Type::Int => serde_json::Value::Number(
177            value
178                .parse::<i64>()
179                .with_context(|| format!("Failed to parse integer: {value}"))?
180                .into(),
181        ),
182        Type::Boolean => serde_json::Value::Bool(
183            value
184                .parse::<bool>()
185                .with_context(|| format!("Failed to parse boolean: {value}"))?,
186        ),
187        Type::Enum { .. } => serde_json::Value::String(value.to_string()),
188        Type::EnumRef(_) => serde_json::Value::String(value.to_string()),
189        Type::UUID => serde_json::Value::String(value.to_string()),
190        Type::Optional(ty) => parse_ty(ty, value)?,
191        Type::BlockchainAddress => serde_json::Value::String(value.to_string()),
192        ty => bail!("Not implemented {:?}", ty),
193    })
194}
195
196impl AuthController for EndpointAuthController {
197    fn auth(
198        self: Arc<Self>,
199        toolbox: &ArcToolbox,
200        header: String,
201        conn: Arc<WsConnection>,
202    ) -> LocalBoxFuture<'static, Result<()>> {
203        let toolbox = toolbox.clone();
204
205        async move {
206            let splits = header
207                .split(',')
208                .map(|x| x.trim())
209                .filter(|x| !x.is_empty())
210                .map(|x| (&x[..1], &x[1..]))
211                .collect::<HashMap<&str, &str>>();
212
213            let method = splits.get("0").context("Could not find method")?;
214            // info!("method: {:?}", method);
215            let endpoint = self
216                .auth_endpoints
217                .get(*method)
218                .with_context(|| format!("Could not find endpoint for method {method}"))?;
219            let mut params = serde_json::Map::new();
220            for (index, param) in endpoint.schema.parameters.iter().enumerate() {
221                let index = index + 1;
222                match splits.get(&index.to_string().as_str()) {
223                    Some(value) => {
224                        params.insert(param.name.to_case(Case::Camel), parse_ty(&param.ty, value)?);
225                    }
226                    None if !matches!(&param.ty, Type::Optional(_)) => {
227                        bail!("Could not find param {} {}", param.name, index);
228                    }
229                    _ => {}
230                }
231            }
232            let roles = conn.roles.read().clone();
233            let ctx = RequestContext {
234                connection_id: conn.connection_id,
235                user_id: 0,
236                seq: 0,
237                method: endpoint.schema.code,
238                log_id: conn.log_id,
239                roles: roles.clone(),
240                ip_addr: conn.address.ip(),
241            };
242            let resp = endpoint
243                .handler
244                .clone()
245                .auth(
246                    &toolbox,
247                    serde_json::Value::Object(params),
248                    ctx.clone(),
249                    conn,
250                )
251                .await;
252            debug!("Auth response: {:?}", resp);
253            let conn_id = ctx.connection_id;
254            if let Some(resp) = Toolbox::encode_ws_response(ctx, resp) {
255                toolbox.send(conn_id, resp);
256            }
257            Ok(())
258        }
259        .boxed_local()
260    }
261}