endpoint_libs/libs/
toolbox.rs

1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use tokio_tungstenite::tungstenite::Message;
11use tracing::*;
12
13use super::error_code::ErrorCode;
14use super::log::LogLevel;
15use super::ws::{
16    internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse,
17    WsResponseValue, WsStreamState, WsSuccessResponse,
18};
19
20#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct NoResponseError;
22
23impl Display for NoResponseError {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        f.write_str("NoResp")
26    }
27}
28
29impl std::error::Error for NoResponseError {}
30
31#[derive(Debug, Serialize, Deserialize, Clone)]
32pub struct CustomError {
33    pub code: ErrorCode,
34    pub params: Value,
35}
36
37impl CustomError {
38    pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
39        Self {
40            code: code.into(),
41            params: serde_json::to_value(reason)
42                .context("Failed to serialize error reason")
43                .unwrap(),
44        }
45    }
46    pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
47        let code = u32::from_str_radix(err, 36)?;
48        let error_code = ErrorCode::new(code);
49        let this = Self {
50            code: error_code,
51            params: msg.to_string().into(),
52        };
53
54        Ok(this)
55    }
56}
57
58impl Display for CustomError {
59    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        f.write_str(&self.params.to_string())
61    }
62}
63
64impl std::error::Error for CustomError {}
65
66#[derive(Copy, Clone)]
67pub struct RequestContext {
68    pub connection_id: ConnectionId,
69    pub user_id: u64,
70    pub seq: u32,
71    pub method: u32,
72    pub log_id: u64,
73    pub role: u32,
74    pub ip_addr: IpAddr,
75}
76impl RequestContext {
77    pub fn empty() -> Self {
78        Self {
79            connection_id: 0,
80            user_id: 0,
81            seq: 0,
82            method: 0,
83            log_id: 0,
84            role: 0,
85            ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
86        }
87    }
88    pub fn from_conn(conn: &WsConnection) -> Self {
89        Self {
90            connection_id: conn.connection_id,
91            user_id: conn.get_user_id(),
92            seq: 0,
93            method: 0,
94            log_id: conn.log_id,
95            role: conn.role.load(Ordering::Relaxed),
96            ip_addr: conn.address.ip(),
97        }
98    }
99}
100
101pub struct Toolbox {
102    pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
103}
104pub type ArcToolbox = Arc<Toolbox>;
105impl Toolbox {
106    pub fn new() -> Arc<Self> {
107        Arc::new(Self {
108            send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
109        })
110    }
111
112    pub fn set_ws_states(
113        &self,
114        states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>,
115        oneshot: bool,
116    ) {
117        *self.send_msg.write() = Arc::new(move |conn_id, msg| {
118            let state = if let Some(state) = states.get(&conn_id) {
119                state
120            } else {
121                return false;
122            };
123            Self::send_ws_msg(&state.message_queue, msg, oneshot);
124            true
125        });
126    }
127
128    pub fn send_ws_msg(
129        sender: &tokio::sync::mpsc::Sender<Message>,
130        resp: WsResponseValue,
131        oneshot: bool,
132    ) {
133        let resp = serde_json::to_string(&resp).unwrap();
134        if let Err(err) = sender.try_send(resp.into()) {
135            warn!("Failed to send websocket message: {:?}", err)
136        }
137        if oneshot {
138            let _ = sender.try_send(Message::Close(None));
139        }
140    }
141    pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
142        self.send_msg.read()(conn_id, resp)
143    }
144    pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
145        self.send(
146            ctx.connection_id,
147            WsResponseValue::Immediate(WsSuccessResponse {
148                method: ctx.method,
149                seq: ctx.seq,
150                params: serde_json::to_value(&resp).unwrap(),
151            }),
152        );
153    }
154    pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
155        self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
156    }
157    pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
158        self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
159    }
160    pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
161        self.send(
162            ctx.connection_id,
163            WsResponseValue::Log(WsLogResponse {
164                seq: ctx.seq,
165                log_id: ctx.log_id,
166                level,
167                message: msg.into(),
168            }),
169        );
170    }
171    pub fn encode_ws_response<Resp: Serialize>(
172        ctx: RequestContext,
173        resp: Result<Resp>,
174    ) -> Option<WsResponseValue> {
175        #[allow(unused_variables)]
176        let RequestContext {
177            connection_id,
178            user_id,
179            seq,
180            method,
181            log_id,
182            ..
183        } = ctx;
184        let resp = match resp {
185            Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
186                method,
187                seq,
188                params: serde_json::to_value(ok).expect("Failed to serialize response"),
189            }),
190            Err(err) if err.is::<NoResponseError>() => {
191                return None;
192            }
193
194            Err(err) if err.is::<CustomError>() => {
195                error!("CustomError: {:?}", err);
196                let err = err.downcast::<CustomError>().unwrap();
197                request_error_to_resp(&ctx, err.code, err.params)
198            }
199            Err(err) => internal_error_to_resp(
200                &ctx,
201                ErrorCode::new(100500), // Internal Error
202                err,
203            ),
204        };
205        Some(resp)
206    }
207}
208tokio::task_local! {
209    pub static TOOLBOX: ArcToolbox;
210}