1use dashmap::DashMap;
2use eyre::{Context, Result};
3use parking_lot::RwLock;
4use serde::*;
5use serde_json::Value;
6use std::fmt::{Debug, Display, Formatter};
7use std::net::{IpAddr, Ipv4Addr};
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use tokio_tungstenite::tungstenite::Message;
11use tracing::*;
12
13use super::error_code::ErrorCode;
14use super::log::LogLevel;
15use super::ws::{internal_error_to_resp, request_error_to_resp, ConnectionId, WsConnection, WsLogResponse, WsResponseValue, WsStreamState, WsSuccessResponse};
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct NoResponseError;
19
20impl Display for NoResponseError {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 f.write_str("NoResp")
23 }
24}
25
26impl std::error::Error for NoResponseError {}
27
28#[derive(Debug, Serialize, Deserialize, Clone)]
29pub struct CustomError {
30 pub code: ErrorCode,
31 pub params: Value,
32}
33
34impl CustomError {
35 pub fn new(code: impl Into<ErrorCode>, reason: impl Serialize) -> Self {
36 Self {
37 code: code.into(),
38 params: serde_json::to_value(reason)
39 .context("Failed to serialize error reason")
40 .unwrap(),
41 }
42 }
43 pub fn from_sql_error(err: &str, msg: impl Display) -> Result<Self> {
44 let code = u32::from_str_radix(err, 36)?;
45 let error_code = ErrorCode::new(code);
46 let this = Self {
47 code: error_code,
48 params: msg.to_string().into(),
49 };
50
51 Ok(this)
52 }
53}
54
55impl Display for CustomError {
56 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
57 f.write_str(&self.params.to_string())
58 }
59}
60
61impl std::error::Error for CustomError {}
62
63#[derive(Copy, Clone)]
64pub struct RequestContext {
65 pub connection_id: ConnectionId,
66 pub user_id: i64,
67 pub seq: u32,
68 pub method: u32,
69 pub log_id: u64,
70 pub role: u32,
71 pub ip_addr: IpAddr,
72}
73impl RequestContext {
74 pub fn empty() -> Self {
75 Self {
76 connection_id: 0,
77 user_id: 0,
78 seq: 0,
79 method: 0,
80 log_id: 0,
81 role: 0,
82 ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
83 }
84 }
85 pub fn from_conn(conn: &WsConnection) -> Self {
86 Self {
87 connection_id: conn.connection_id,
88 user_id: conn.get_user_id(),
89 seq: 0,
90 method: 0,
91 log_id: conn.log_id,
92 role: conn.role.load(Ordering::Relaxed),
93 ip_addr: conn.address.ip(),
94 }
95 }
96}
97
98pub struct Toolbox {
99 pub send_msg: RwLock<Arc<dyn Fn(ConnectionId, WsResponseValue) -> bool + Send + Sync>>,
100}
101pub type ArcToolbox = Arc<Toolbox>;
102impl Toolbox {
103 pub fn new() -> Arc<Self> {
104 Arc::new(Self {
105 send_msg: RwLock::new(Arc::new(|_conn_id, _msg| false)),
106 })
107 }
108
109 pub fn set_ws_states(&self, states: Arc<DashMap<ConnectionId, Arc<WsStreamState>>>, oneshot: bool) {
110 *self.send_msg.write() = Arc::new(move |conn_id, msg| {
111 let state = if let Some(state) = states.get(&conn_id) {
112 state
113 } else {
114 return false;
115 };
116 Self::send_ws_msg(&state.message_queue, msg, oneshot);
117 true
118 });
119 }
120
121 pub fn send_ws_msg(sender: &tokio::sync::mpsc::Sender<Message>, resp: WsResponseValue, oneshot: bool) {
122 let resp = serde_json::to_string(&resp).unwrap();
123 if let Err(err) = sender.try_send(resp.into()) {
124 warn!("Failed to send websocket message: {:?}", err)
125 }
126 if oneshot {
127 let _ = sender.try_send(Message::Close(None));
128 }
129 }
130 pub fn send(&self, conn_id: ConnectionId, resp: WsResponseValue) -> bool {
131 self.send_msg.read()(conn_id, resp)
132 }
133 pub fn send_response(&self, ctx: &RequestContext, resp: impl Serialize) {
134 self.send(
135 ctx.connection_id,
136 WsResponseValue::Immediate(WsSuccessResponse {
137 method: ctx.method,
138 seq: ctx.seq,
139 params: serde_json::to_value(&resp).unwrap(),
140 }),
141 );
142 }
143 pub fn send_internal_error(&self, ctx: &RequestContext, code: ErrorCode, err: eyre::Error) {
144 self.send(ctx.connection_id, internal_error_to_resp(ctx, code, err));
145 }
146 pub fn send_request_error(&self, ctx: &RequestContext, code: ErrorCode, err: impl Into<Value>) {
147 self.send(ctx.connection_id, request_error_to_resp(ctx, code, err));
148 }
149 pub fn send_log(&self, ctx: &RequestContext, level: LogLevel, msg: impl Into<String>) {
150 self.send(
151 ctx.connection_id,
152 WsResponseValue::Log(WsLogResponse {
153 seq: ctx.seq,
154 log_id: ctx.log_id,
155 level,
156 message: msg.into(),
157 }),
158 );
159 }
160 pub fn encode_ws_response<Resp: Serialize>(ctx: RequestContext, resp: Result<Resp>) -> Option<WsResponseValue> {
161 #[allow(unused_variables)]
162 let RequestContext {
163 connection_id,
164 user_id,
165 seq,
166 method,
167 log_id,
168 ..
169 } = ctx;
170 let resp = match resp {
171 Ok(ok) => WsResponseValue::Immediate(WsSuccessResponse {
172 method,
173 seq,
174 params: serde_json::to_value(ok).expect("Failed to serialize response"),
175 }),
176 Err(err) if err.is::<NoResponseError>() => {
177 return None;
178 }
179
180 Err(err) if err.is::<CustomError>() => {
181 error!("CustomError: {:?}", err);
182 let err = err.downcast::<CustomError>().unwrap();
183 request_error_to_resp(&ctx, err.code, err.params)
184 }
185 Err(err) => internal_error_to_resp(
186 &ctx,
187 ErrorCode::new(100500), err,
189 ),
190 };
191 Some(resp)
192 }
193}
194tokio::task_local! {
195 pub static TOOLBOX: ArcToolbox;
196}