endpoint_libs/libs/
toolbox.rs

1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::Arc;
9use tokio_tungstenite::tungstenite::Message;
10use tracing::*;
11
12use super::error_code::ErrorCode;
13use super::log::LogLevel;
14use super::ws::{
15    internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse,
16    WsResponseValue, WsStreamState, WsSuccessResponse,
17};
18
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct NoResponseError;
21
22impl Display for NoResponseError {
23    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
24        f.write_str("NoResp")
25    }
26}
27
28impl std::error::Error for NoResponseError {}
29
30#[derive(Debug, Serialize, Deserialize, Clone)]
31pub struct CustomError {
32    pub code: ErrorCode,
33    pub params: Value,
34}
35
36impl CustomError {
37    pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
38        Self {
39            code: code.into(),
40            params: serde_json::to_value(reason)
41                .context("Failed to serialize error reason")
42                .unwrap(),
43        }
44    }
45    pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
46        let code = u32::from_str_radix(err, 36)?;
47        let error_code = ErrorCode::new(code);
48        let this = Self {
49            code: error_code,
50            params: msg.to_string().into(),
51        };
52
53        Ok(this)
54    }
55}
56
57impl Display for CustomError {
58    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59        f.write_str(&self.params.to_string())
60    }
61}
62
63impl std::error::Error for CustomError {}
64
65#[derive(Clone)]
66pub struct RequestContext {
67    pub connection_id: ConnectionId,
68    pub user_id: u64,
69    pub seq: u32,
70    pub method: u32,
71    pub log_id: u64,
72    pub roles: Arc<Vec<u32>>,
73    pub ip_addr: IpAddr,
74}
75
76impl RequestContext {
77    pub fn empty() -> Self {
78        Self {
79            connection_id: 0,
80            user_id: 0,
81            seq: 0,
82            method: 0,
83            log_id: 0,
84            roles: Arc::new(Vec::new()),
85            ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
86        }
87    }
88    pub fn from_conn(conn: &WsConnection) -> Self {
89        let roles = conn.roles.read().clone();
90        Self {
91            connection_id: conn.connection_id,
92            user_id: conn.get_user_id(),
93            seq: 0,
94            method: 0,
95            log_id: conn.log_id,
96            roles,
97            ip_addr: conn.address.ip(),
98        }
99    }
100}
101
102pub struct Toolbox {
103    pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
104}
105pub type ArcToolbox = Arc<Toolbox>;
106impl Toolbox {
107    pub fn new() -> Arc<Self> {
108        Arc::new(Self {
109            send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
110        })
111    }
112
113    pub fn set_ws_states(
114        &self,
115        states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>,
116        oneshot: bool,
117    ) {
118        *self.send_msg.write() = Arc::new(move |conn_id, msg| {
119            let state = if let Some(state) = states.get(&conn_id) {
120                state
121            } else {
122                return false;
123            };
124            Self::send_ws_msg(&state.message_queue, msg, oneshot);
125            true
126        });
127    }
128
129    pub fn send_ws_msg(
130        sender: &tokio::sync::mpsc::Sender<Message>,
131        resp: WsResponseValue,
132        oneshot: bool,
133    ) {
134        let resp = serde_json::to_string(&resp).unwrap();
135        if let Err(err) = sender.try_send(resp.into()) {
136            warn!("Failed to send websocket message: {:?}", err)
137        }
138        if oneshot {
139            let _ = sender.try_send(Message::Close(None));
140        }
141    }
142    pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
143        self.send_msg.read()(conn_id, resp)
144    }
145    pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
146        self.send(
147            ctx.connection_id,
148            WsResponseValue::Immediate(WsSuccessResponse {
149                method: ctx.method,
150                seq: ctx.seq,
151                params: serde_json::to_value(&resp).unwrap(),
152            }),
153        );
154    }
155    pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
156        self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
157    }
158    pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
159        self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
160    }
161    pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
162        self.send(
163            ctx.connection_id,
164            WsResponseValue::Log(WsLogResponse {
165                seq: ctx.seq,
166                log_id: ctx.log_id,
167                level,
168                message: msg.into(),
169            }),
170        );
171    }
172    pub fn encode_ws_response<Resp: Serialize>(
173        ctx: RequestContext,
174        resp: Result<Resp>,
175    ) -> Option<WsResponseValue> {
176        #[allow(unused_variables)]
177        let RequestContext {
178            connection_id,
179            user_id,
180            seq,
181            method,
182            log_id,
183            ..
184        } = ctx;
185        let resp = match resp {
186            Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
187                method,
188                seq,
189                params: serde_json::to_value(ok).expect("Failed to serialize response"),
190            }),
191            Err(err) if err.is::<NoResponseError>() => {
192                return None;
193            }
194
195            Err(err) if err.is::<CustomError>() => {
196                error!("CustomError: {:?}", err);
197                let err = err.downcast::<CustomError>().unwrap();
198                request_error_to_resp(&ctx, err.code, err.params)
199            }
200            Err(err) => internal_error_to_resp(
201                &ctx,
202                ErrorCode::new(100500), // Internal Error
203                err,
204            ),
205        };
206        Some(resp)
207    }
208}
209tokio::task_local! {
210    pub static TOOLBOX: ArcToolbox;
211}