1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::Arc;
9use tokio_tungstenite::tungstenite::Message;
10use tracing::*;
11
12use super::error_code::ErrorCode;
13use super::log::LogLevel;
14use super::ws::{
15 internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse,
16 WsResponseValue, WsStreamState, WsSuccessResponse,
17};
18
/// Marker error meaning "do not send any response for this request".
///
/// When a handler's result is this error, `Toolbox::encode_ws_response` returns
/// `None` instead of building a response value.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NoResponseError;
21
22impl Display for NoResponseError {
23 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
24 f.write_str("NoResp")
25 }
26}
27
28impl std::error::Error for NoResponseError {}
29
/// An error carrying an explicit error code and a JSON payload.
///
/// `Toolbox::encode_ws_response` turns this error into a request-error response
/// (via `request_error_to_resp`) rather than an internal-error response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CustomError {
    /// Error code passed to `request_error_to_resp`.
    pub code: ErrorCode,
    /// JSON payload describing the error; its JSON text is also the `Display` output.
    pub params: Value,
}
35
36impl CustomError {
37 pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
38 Self {
39 code: code.into(),
40 params: serde_json::to_value(reason)
41 .context("Failed to serialize error reason")
42 .unwrap(),
43 }
44 }
45 pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
46 let code = u32::from_str_radix(err, 36)?;
47 let error_code = ErrorCode::new(code);
48 let this = Self {
49 code: error_code,
50 params: msg.to_string().into(),
51 };
52
53 Ok(this)
54 }
55}
56
57impl Display for CustomError {
58 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59 f.write_str(&self.params.to_string())
60 }
61}
62
63impl std::error::Error for CustomError {}
64
/// Per-request metadata used to address responses and log messages.
#[derive(Clone)]
pub struct RequestContext {
    /// Connection the request arrived on; `Toolbox` sends responses to this id.
    pub connection_id: ConnectionId,
    /// User id taken from the connection (`conn.get_user_id()`); `0` in `empty()`.
    pub user_id: u64,
    /// Sequence number echoed back in success, error, and log responses.
    pub seq: u32,
    /// Method id echoed back in success responses.
    pub method: u32,
    /// Log id attached to log responses sent by `Toolbox::send_log`.
    pub log_id: u64,
    /// Roles snapshotted from the connection; presumably role ids — TODO confirm.
    /// Shared through `Arc` so cloning the context does not copy the list.
    pub roles: Arc<Vec<u32>>,
    /// IP taken from the connection's `address` (presumably the peer address — confirm).
    pub ip_addr: IpAddr,
}
75
76impl RequestContext {
77 pub fn empty() -> Self {
78 Self {
79 connection_id: 0,
80 user_id: 0,
81 seq: 0,
82 method: 0,
83 log_id: 0,
84 roles: Arc::new(Vec::new()),
85 ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
86 }
87 }
88 pub fn from_conn(conn: &WsConnection) -> Self {
89 let roles = conn.roles.read().clone();
90 Self {
91 connection_id: conn.connection_id,
92 user_id: conn.get_user_id(),
93 seq: 0,
94 method: 0,
95 log_id: conn.log_id,
96 roles,
97 ip_addr: conn.address.ip(),
98 }
99 }
100}
101
/// Helpers for sending websocket responses and logs to connections by id.
pub struct Toolbox {
    /// Callback that delivers a response to a connection. It returns `false` when the
    /// message could not be routed. It sits behind a lock so it can be replaced at
    /// runtime (`Toolbox::set_ws_states` installs a real one).
    pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
}
/// Reference-counted handle to a [`Toolbox`].
pub type ArcToolbox = Arc<Toolbox>;
106impl Toolbox {
107 pub fn new() -> Arc<Self> {
108 Arc::new(Self {
109 send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
110 })
111 }
112
113 pub fn set_ws_states(
114 &self,
115 states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>,
116 oneshot: bool,
117 ) {
118 *self.send_msg.write() = Arc::new(move |conn_id, msg| {
119 let state = if let Some(state) = states.get(&conn_id) {
120 state
121 } else {
122 return false;
123 };
124 Self::send_ws_msg(&state.message_queue, msg, oneshot);
125 true
126 });
127 }
128
129 pub fn send_ws_msg(
130 sender: &tokio::sync::mpsc::Sender<Message>,
131 resp: WsResponseValue,
132 oneshot: bool,
133 ) {
134 let resp = serde_json::to_string(&resp).unwrap();
135 if let Err(err) = sender.try_send(resp.into()) {
136 warn!("Failed to send websocket message: {:?}", err)
137 }
138 if oneshot {
139 let _ = sender.try_send(Message::Close(None));
140 }
141 }
142 pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
143 self.send_msg.read()(conn_id, resp)
144 }
145 pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
146 self.send(
147 ctx.connection_id,
148 WsResponseValue::Immediate(WsSuccessResponse {
149 method: ctx.method,
150 seq: ctx.seq,
151 params: serde_json::to_value(&resp).unwrap(),
152 }),
153 );
154 }
155 pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
156 self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
157 }
158 pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
159 self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
160 }
161 pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
162 self.send(
163 ctx.connection_id,
164 WsResponseValue::Log(WsLogResponse {
165 seq: ctx.seq,
166 log_id: ctx.log_id,
167 level,
168 message: msg.into(),
169 }),
170 );
171 }
172 pub fn encode_ws_response<Resp: Serialize>(
173 ctx: RequestContext,
174 resp: Result<Resp>,
175 ) -> Option<WsResponseValue> {
176 #[allow(unused_variables)]
177 let RequestContext {
178 connection_id,
179 user_id,
180 seq,
181 method,
182 log_id,
183 ..
184 } = ctx;
185 let resp = match resp {
186 Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
187 method,
188 seq,
189 params: serde_json::to_value(ok).expect("Failed to serialize response"),
190 }),
191 Err(err) if err.is::<NoResponseError>() => {
192 return None;
193 }
194
195 Err(err) if err.is::<CustomError>() => {
196 error!("CustomError: {:?}", err);
197 let err = err.downcast::<CustomError>().unwrap();
198 request_error_to_resp(&ctx, err.code, err.params)
199 }
200 Err(err) => internal_error_to_resp(
201 &ctx,
202 ErrorCode::new(100500), err,
204 ),
205 };
206 Some(resp)
207 }
208}
// Task-local toolbox handle for code running inside a task scoped with it.
tokio::task_local! {
    /// Task-local [`ArcToolbox`]. It is only available inside `TOOLBOX.scope(..)` /
    /// `TOOLBOX.sync_scope(..)`. There, `TOOLBOX.with` panics if the value is not
    /// set, and `TOOLBOX.try_with` returns an error instead.
    pub static TOOLBOX: ArcToolbox;
}