1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use tokio_tungstenite::tungstenite::Message;
11use tracing::*;
12
13use super::error_code::ErrorCode;
14use super::log::LogLevel;
15use super::ws::{
16 internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse,
17 WsResponseValue, WsStreamState, WsSuccessResponse,
18};
19
20#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct NoResponseError;
22
23impl Display for NoResponseError {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 f.write_str("NoResp")
26 }
27}
28
29impl std::error::Error for NoResponseError {}
30
31#[derive(Debug, Serialize, Deserialize, Clone)]
32pub struct CustomError {
33 pub code: ErrorCode,
34 pub params: Value,
35}
36
37impl CustomError {
38 pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
39 Self {
40 code: code.into(),
41 params: serde_json::to_value(reason)
42 .context("Failed to serialize error reason")
43 .unwrap(),
44 }
45 }
46 pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
47 let code = u32::from_str_radix(err, 36)?;
48 let error_code = ErrorCode::new(code);
49 let this = Self {
50 code: error_code,
51 params: msg.to_string().into(),
52 };
53
54 Ok(this)
55 }
56}
57
58impl Display for CustomError {
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 f.write_str(&self.params.to_string())
61 }
62}
63
64impl std::error::Error for CustomError {}
65
66#[derive(Copy, Clone)]
67pub struct RequestContext {
68 pub connection_id: ConnectionId,
69 pub user_id: u64,
70 pub seq: u32,
71 pub method: u32,
72 pub log_id: u64,
73 pub role: u32,
74 pub ip_addr: IpAddr,
75}
76impl RequestContext {
77 pub fn empty() -> Self {
78 Self {
79 connection_id: 0,
80 user_id: 0,
81 seq: 0,
82 method: 0,
83 log_id: 0,
84 role: 0,
85 ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
86 }
87 }
88 pub fn from_conn(conn: &WsConnection) -> Self {
89 Self {
90 connection_id: conn.connection_id,
91 user_id: conn.get_user_id(),
92 seq: 0,
93 method: 0,
94 log_id: conn.log_id,
95 role: conn.role.load(Ordering::Relaxed),
96 ip_addr: conn.address.ip(),
97 }
98 }
99}
100
101pub struct Toolbox {
102 pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
103}
104pub type ArcToolbox = Arc<Toolbox>;
105impl Toolbox {
106 pub fn new() -> Arc<Self> {
107 Arc::new(Self {
108 send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
109 })
110 }
111
112 pub fn set_ws_states(
113 &self,
114 states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>,
115 oneshot: bool,
116 ) {
117 *self.send_msg.write() = Arc::new(move |conn_id, msg| {
118 let state = if let Some(state) = states.get(&conn_id) {
119 state
120 } else {
121 return false;
122 };
123 Self::send_ws_msg(&state.message_queue, msg, oneshot);
124 true
125 });
126 }
127
128 pub fn send_ws_msg(
129 sender: &tokio::sync::mpsc::Sender<Message>,
130 resp: WsResponseValue,
131 oneshot: bool,
132 ) {
133 let resp = serde_json::to_string(&resp).unwrap();
134 if let Err(err) = sender.try_send(resp.into()) {
135 warn!("Failed to send websocket message: {:?}", err)
136 }
137 if oneshot {
138 let _ = sender.try_send(Message::Close(None));
139 }
140 }
141 pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
142 self.send_msg.read()(conn_id, resp)
143 }
144 pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
145 self.send(
146 ctx.connection_id,
147 WsResponseValue::Immediate(WsSuccessResponse {
148 method: ctx.method,
149 seq: ctx.seq,
150 params: serde_json::to_value(&resp).unwrap(),
151 }),
152 );
153 }
154 pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
155 self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
156 }
157 pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
158 self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
159 }
160 pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
161 self.send(
162 ctx.connection_id,
163 WsResponseValue::Log(WsLogResponse {
164 seq: ctx.seq,
165 log_id: ctx.log_id,
166 level,
167 message: msg.into(),
168 }),
169 );
170 }
171 pub fn encode_ws_response<Resp: Serialize>(
172 ctx: RequestContext,
173 resp: Result<Resp>,
174 ) -> Option<WsResponseValue> {
175 #[allow(unused_variables)]
176 let RequestContext {
177 connection_id,
178 user_id,
179 seq,
180 method,
181 log_id,
182 ..
183 } = ctx;
184 let resp = match resp {
185 Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
186 method,
187 seq,
188 params: serde_json::to_value(ok).expect("Failed to serialize response"),
189 }),
190 Err(err) if err.is::<NoResponseError>() => {
191 return None;
192 }
193
194 Err(err) if err.is::<CustomError>() => {
195 error!("CustomError: {:?}", err);
196 let err = err.downcast::<CustomError>().unwrap();
197 request_error_to_resp(&ctx, err.code, err.params)
198 }
199 Err(err) => internal_error_to_resp(
200 &ctx,
201 ErrorCode::new(100500), err,
203 ),
204 };
205 Some(resp)
206 }
207}
// Task-local handle to the shared `Toolbox`, so code running inside a task that has this
// key set can reach the toolbox without threading it through every call.
// NOTE(review): presumably installed by the server via `TOOLBOX.scope(...)`; confirm.
tokio::task_local! {
    pub static TOOLBOX: ArcToolbox;
}