endpoint_libs/libs/
toolbox.rs

1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use tokio_tungstenite::tungstenite::Message;
11use tracing::*;
12
13use super::error_code::ErrorCode;
14use super::log::LogLevel;
15use super::ws::{internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse, WsResponseValue, WsStreamState, WsSuccessResponse};
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct NoResponseError;
19
20impl Display for NoResponseError {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        f.write_str("NoResp")
23    }
24}
25
26impl std::error::Error for NoResponseError {}
27
28#[derive(Debug, Serialize, Deserialize, Clone)]
29pub struct CustomError {
30    pub code: ErrorCode,
31    pub params: Value,
32}
33
34impl CustomError {
35    pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
36        Self {
37            code: code.into(),
38            params: serde_json::to_value(reason)
39                .context("Failed to serialize error reason")
40                .unwrap(),
41        }
42    }
43    pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
44        let code = u32::from_str_radix(err, 36)?;
45        let error_code = ErrorCode::new(code);
46        let this = Self {
47            code: error_code,
48            params: msg.to_string().into(),
49        };
50
51        Ok(this)
52    }
53}
54
55impl Display for CustomError {
56    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57        f.write_str(&self.params.to_string())
58    }
59}
60
61impl std::error::Error for CustomError {}
62
63#[derive(Copy, Clone)]
64pub struct RequestContext {
65    pub connection_id: ConnectionId,
66    pub user_id: i64,
67    pub seq: u32,
68    pub method: u32,
69    pub log_id: u64,
70    pub role: u32,
71    pub ip_addr: IpAddr,
72}
73impl RequestContext {
74    pub fn empty() -> Self {
75        Self {
76            connection_id: 0,
77            user_id: 0,
78            seq: 0,
79            method: 0,
80            log_id: 0,
81            role: 0,
82            ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
83        }
84    }
85    pub fn from_conn(conn: &WsConnection) -> Self {
86        Self {
87            connection_id: conn.connection_id,
88            user_id: conn.get_user_id(),
89            seq: 0,
90            method: 0,
91            log_id: conn.log_id,
92            role: conn.role.load(Ordering::Relaxed),
93            ip_addr: conn.address.ip(),
94        }
95    }
96}
97
98pub struct Toolbox {
99    pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
100}
101pub type ArcToolbox = Arc<Toolbox>;
102impl Toolbox {
103    pub fn new() -> Arc<Self> {
104        Arc::new(Self {
105            send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
106        })
107    }
108
109    pub fn set_ws_states(&self, states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>, oneshot: bool) {
110        *self.send_msg.write() = Arc::new(move |conn_id, msg| {
111            let state = if let Some(state) = states.get(&conn_id) {
112                state
113            } else {
114                return false;
115            };
116            Self::send_ws_msg(&state.message_queue, msg, oneshot);
117            true
118        });
119    }
120
121    pub fn send_ws_msg(sender: &tokio::sync::mpsc::Sender<Message>, resp: WsResponseValue, oneshot: bool) {
122        let resp = serde_json::to_string(&resp).unwrap();
123        if let Err(err) = sender.try_send(resp.into()) {
124            warn!("Failed to send websocket message: {:?}", err)
125        }
126        if oneshot {
127            let _ = sender.try_send(Message::Close(None));
128        }
129    }
130    pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
131        self.send_msg.read()(conn_id, resp)
132    }
133    pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
134        self.send(
135            ctx.connection_id,
136            WsResponseValue::Immediate(WsSuccessResponse {
137                method: ctx.method,
138                seq: ctx.seq,
139                params: serde_json::to_value(&resp).unwrap(),
140            }),
141        );
142    }
143    pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
144        self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
145    }
146    pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
147        self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
148    }
149    pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
150        self.send(
151            ctx.connection_id,
152            WsResponseValue::Log(WsLogResponse {
153                seq: ctx.seq,
154                log_id: ctx.log_id,
155                level,
156                message: msg.into(),
157            }),
158        );
159    }
160    pub fn encode_ws_response<Resp: Serialize>(ctx: RequestContext, resp: Result<Resp>) -> Option<WsResponseValue> {
161        #[allow(unused_variables)]
162        let RequestContext {
163            connection_id,
164            user_id,
165            seq,
166            method,
167            log_id,
168            ..
169        } = ctx;
170        let resp = match resp {
171            Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
172                method,
173                seq,
174                params: serde_json::to_value(ok).expect("Failed to serialize response"),
175            }),
176            Err(err) if err.is::<NoResponseError>() => {
177                return None;
178            }
179
180            Err(err) if err.is::<CustomError>() => {
181                error!("CustomError: {:?}", err);
182                let err = err.downcast::<CustomError>().unwrap();
183                request_error_to_resp(&ctx, err.code, err.params)
184            }
185            Err(err) => internal_error_to_resp(
186                &ctx,
187                ErrorCode::new(100500), // Internal Error
188                err,
189            ),
190        };
191        Some(resp)
192    }
193}
194tokio::task_local! {
195    pub static TOOLBOX: ArcToolbox;
196}