1use crate::api::APISender;
2use crate::api::arg_type::MessageType;
3use crate::api::return_type::{
4 CanSendResponse, GetCookiesResponse, GetCredentialsResponse, GetCsrfTokenResponse,
5 GetDataResponse, GetFriendListResponse, GetGroupInfoResponse, GetGroupMemberInfoResponse,
6 GetLoginInfoResponse, GetMsgResponse, GetStatusResponse, GetStrangerInfoResponse,
7 GetVersionInfoResponse, SendMsgResponse,
8};
9use crate::event::message::GroupMessageAnonymous;
10use crate::event::{Event, EventReceiver, EventTrait};
11use crate::message::receive_segment::ReceiveSegment;
12use crate::message::send_segment::SendSegment;
13use anyhow::anyhow;
14use async_trait::async_trait;
15use flume::{Receiver, Sender};
16use onebot_api_macro::generate_json;
17use serde::Deserialize;
18use serde::de::DeserializeOwned;
19use serde_json::{Value, json};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::broadcast;
23use uuid::Uuid;
24
25#[derive(Deserialize, Debug, Clone)]
26pub struct APICallResponse {
27 pub status: String,
28 pub retcode: u32,
29 #[serde(default)]
30 pub data: Value,
31 pub echo: Option<String>,
32}
33
34#[derive(Deserialize, Debug, Clone)]
35#[serde(untagged)]
36pub enum WsEvent {
37 Event(Event),
38 Response(APICallResponse),
39}
40
41impl EventTrait for WsEvent {}
42
43#[async_trait]
44pub trait WebSocketService: Send + Sync {
45 fn register_api_receiver(&mut self, api_receiver: Receiver<String>);
46 fn register_msg_sender(&mut self, msg_sender: Sender<String>);
47 async fn start(&self) -> anyhow::Result<()>;
48}
49
50pub struct WsClient {
51 service: Box<dyn WebSocketService>, msg_sender: Sender<String>,
53 api_sender: Sender<String>,
55 api_receiver: Receiver<String>,
56 broadcast_sender: Arc<broadcast::Sender<WsEvent>>,
57 close_sender: Sender<()>,
58 max_waiting_times: Option<i32>,
60}
61
62impl WsClient {
63 pub fn with_service(mut service: Box<dyn WebSocketService>) -> Self {
64 let (msg_sender, msg_receiver) = flume::unbounded();
65 let (api_sender, api_receiver) = flume::unbounded();
66 let (close_sender, close_receiver) = flume::unbounded();
67 service.register_api_receiver(api_receiver.clone());
68 service.register_msg_sender(msg_sender.clone());
69 let (broadcast_sender, _) = broadcast::channel(16);
70 let broadcast_sender = Arc::new(broadcast_sender);
71 Self::spawn_event_listener(
72 msg_receiver.clone(),
73 Arc::clone(&broadcast_sender),
74 close_receiver.clone(),
75 );
76 Self {
77 service,
78 msg_sender,
79 api_sender,
81 api_receiver,
82 broadcast_sender,
83 close_sender,
84 max_waiting_times: Some(10),
86 }
87 }
88
89 pub async fn start_service(&self) -> anyhow::Result<()> {
90 self.service.start().await
91 }
92
93 pub fn change_service(&mut self, mut service: Box<dyn WebSocketService>) {
94 service.register_api_receiver(self.api_receiver.clone());
95 service.register_msg_sender(self.msg_sender.clone());
96 self.service = service;
97 }
98
99 pub fn get_service(&self) -> &dyn WebSocketService {
100 &*self.service
101 }
102
103 pub fn get_service_mut(&mut self) -> &mut dyn WebSocketService {
104 &mut *self.service
105 }
106}
107
108impl WsClient {
109 fn spawn_event_listener(
110 msg_receiver: Receiver<String>,
111 broadcast_sender: Arc<broadcast::Sender<WsEvent>>,
112 close_receiver: Receiver<()>,
113 ) {
114 tokio::spawn(async move {
115 Self::ws_stream_handler(&msg_receiver, &broadcast_sender, &close_receiver).await;
116 });
117 }
118
119 async fn ws_stream_handler(
120 msg_receiver: &Receiver<String>,
121 broadcast_sender: &broadcast::Sender<WsEvent>,
122 close_receiver: &Receiver<()>,
123 ) {
124 loop {
125 tokio::select! {
126 msg = msg_receiver.recv_async() => {
127 if let Ok(msg) = msg {
128 Self::msg_handler(msg, broadcast_sender);
129 }
130 }
131 _ = close_receiver.recv_async() => {
132 return
133 }
134 }
135 }
136 }
137
138 fn msg_handler(msg: String, broadcast_sender: &broadcast::Sender<WsEvent>) {
139 if let Ok(event) = serde_json::from_str(&msg) {
140 let _ = broadcast_sender.send(event);
141 }
142 }
143
144 fn generate_api_call_json(
145 action: String,
146 params: HashMap<String, Value>,
147 echo: String,
148 ) -> String {
149 serde_json::to_string(&json!({
150 "action": action,
151 "params": params,
152 "echo": echo
153 }))
154 .unwrap()
155 }
156
157 async fn wait_for_echo(
158 rx: &mut broadcast::Receiver<WsEvent>,
159 echo: String,
160 max: Option<i32>,
161 ) -> Option<APICallResponse> {
162 let mut count = 0;
163 let target_echo = Some(echo.clone());
164 while let Ok(event) = rx.recv().await {
165 if let WsEvent::Response(res) = event
166 && res.echo == target_echo
167 {
168 return Some(res);
169 }
170 count += 1;
171 if let Some(max) = max
172 && count >= max
173 {
174 return None;
175 }
176 }
177 None
178 }
179
180 async fn send_json(&self, json: String, echo: String) -> anyhow::Result<APICallResponse> {
181 let receiver = self.get_receiver();
182 let max_waiting_times = self.max_waiting_times;
183 let task = tokio::spawn(async move {
184 let mut receiver = receiver;
185 Self::wait_for_echo(&mut receiver, echo, max_waiting_times).await
186 });
187
188 self.api_sender.send_async(json).await?;
189 let res = task.await?;
190 if let Some(data) = res {
191 Ok(data)
192 } else {
193 Err(anyhow!("No response!"))
194 }
195 }
196
197 fn verify_response(res: APICallResponse) -> anyhow::Result<Value> {
198 if res.status == "ok" {
199 Ok(res.data)
200 } else {
201 Err(anyhow!("the request failed with code: {}", res.retcode))
202 }
203 }
204
205 async fn get_request_value<T: DeserializeOwned>(
206 &self,
207 json: String,
208 echo: String,
209 ) -> anyhow::Result<T> {
210 let res = self.send_json(json, echo).await?;
211 let value = Self::verify_response(res)?;
212 let data = serde_json::from_value::<T>(value)?;
213 Ok(data)
214 }
215}
216
217impl Drop for WsClient {
218 fn drop(&mut self) {
219 let _ = self.close_sender.send(());
220 }
221}
222
223#[async_trait]
224impl EventReceiver<WsEvent> for WsClient {
225 fn get_receiver(&self) -> broadcast::Receiver<WsEvent> {
226 self.broadcast_sender.subscribe()
227 }
228}
229
// OneBot API calls, implemented over the WebSocket request/echo round trip.
//
// Most methods are annotated with `#[generate_json]`. Their bodies use `__json` and
// `__echo` without defining them, so the macro presumably builds the request JSON
// from the method name and parameters and binds those two names — see
// `onebot_api_macro` to confirm. Because the parameters are consumed only by the
// macro, the bodies never mention them, hence `allow(unused_variables)`.
//
// Every call goes through `get_request_value`, which sends the request, waits for
// the reply with the matching echo, requires status "ok", and deserializes `data`.
// Calls with no meaningful payload parse `data` as a generic `Value` and discard it.
//
// NOTE(review): ids are `i64` in `send_private_msg`/`send_group_msg`/`send_msg`/
// `send_like`/`get_group_honor_info` but `i32` in the other methods. The types come
// from the `APISender` trait — confirm the narrower type is intended.
#[allow(unused_variables)]
#[async_trait]
impl APISender for WsClient {
    // ---- Messages ----

    // Returns the `message_id` of the sent message.
    #[generate_json]
    async fn send_private_msg(
        &self,
        user_id: i64,
        message: Vec<SendSegment>,
        auto_escape: Option<bool>,
    ) -> anyhow::Result<i32> {
        let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.message_id)
    }

    // Returns the `message_id` of the sent message.
    #[generate_json]
    async fn send_group_msg(
        &self,
        group_id: i64,
        message: Vec<SendSegment>,
        auto_escape: Option<bool>,
    ) -> anyhow::Result<i32> {
        let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.message_id)
    }

    // Returns the `message_id` of the sent message.
    #[generate_json]
    async fn send_msg(
        &self,
        message_type: Option<MessageType>,
        user_id: i64,
        group_id: i64,
        message: Vec<SendSegment>,
        auto_escape: Option<bool>,
    ) -> anyhow::Result<i32> {
        let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.message_id)
    }

    // Payload ignored; only success/failure is reported.
    #[generate_json]
    async fn delete_msg(&self, message_id: i32) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn get_msg(&self, message_id: i32) -> anyhow::Result<GetMsgResponse> {
        self.get_request_value(__json, __echo).await
    }

    // `data` is deserialized directly as a list of received segments.
    #[generate_json]
    async fn get_forward_msg(&self, id: String) -> anyhow::Result<Vec<ReceiveSegment>> {
        self.get_request_value(__json, __echo).await
    }

    // Payload ignored; only success/failure is reported.
    #[generate_json]
    async fn send_like(&self, user_id: i64, times: Option<i32>) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    // ---- Group administration (payloads ignored; only success/failure is reported) ----

    #[generate_json]
    async fn set_group_kick(
        &self,
        group_id: i32,
        user_id: i32,
        reject_add_request: Option<bool>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_ban(
        &self,
        group_id: i32,
        user_id: i32,
        duration: Option<i32>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_anonymous_ban(
        &self,
        group_id: i32,
        anonymous: Option<GroupMessageAnonymous>,
        flag: Option<String>,
        duration: Option<i32>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_whole_ban(&self, group_id: i32, enable: Option<bool>) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_admin(
        &self,
        group_id: i32,
        user_id: i32,
        enable: Option<bool>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_anonymous(&self, group_id: i32, enable: Option<bool>) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_card(
        &self,
        group_id: i32,
        user_id: i32,
        card: Option<String>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_name(&self, group_id: i32, group_name: String) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_leave(&self, group_id: i32, is_dismiss: Option<bool>) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_special_title(
        &self,
        group_id: i32,
        user_id: i32,
        special_title: Option<String>,
        duration: Option<i32>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    // ---- Friend / group requests (payloads ignored) ----

    #[generate_json]
    async fn set_friend_add_request(
        &self,
        flag: String,
        approve: Option<bool>,
        remark: Option<String>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    #[generate_json]
    async fn set_group_add_request(
        &self,
        flag: String,
        sub_type: String,
        approve: Option<bool>,
        reason: Option<String>,
    ) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    // ---- Account, friend and group information ----

    #[generate_json]
    async fn get_login_info(&self) -> anyhow::Result<GetLoginInfoResponse> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_stranger_info(
        &self,
        user_id: i32,
        no_cache: Option<bool>,
    ) -> anyhow::Result<GetStrangerInfoResponse> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_friend_list(&self) -> anyhow::Result<Vec<GetFriendListResponse>> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_group_info(
        &self,
        group_id: i32,
        no_cache: Option<bool>,
    ) -> anyhow::Result<GetGroupInfoResponse> {
        self.get_request_value(__json, __echo).await
    }

    // Each list entry reuses the single-group response type.
    #[generate_json]
    async fn get_group_list(&self) -> anyhow::Result<Vec<GetGroupInfoResponse>> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_group_member_info(
        &self,
        group_id: i32,
        user_id: i32,
        no_cache: Option<bool>,
    ) -> anyhow::Result<GetGroupMemberInfoResponse> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_group_member_list(
        &self,
        group_id: i32,
    ) -> anyhow::Result<Vec<GetGroupMemberInfoResponse>> {
        self.get_request_value(__json, __echo).await
    }

    // Built by hand rather than with `#[generate_json]`: the wire parameter is named
    // `type`, a reserved word in Rust, so the Rust parameter is `honor_type` and is
    // re-keyed as "type" below.
    // NOTE(review): the return type `GetGroupMemberInfoResponse` looks like it may
    // not match the honor-info payload; it is fixed by the `APISender` trait — confirm.
    async fn get_group_honor_info(
        &self,
        group_id: i64,
        honor_type: String,
    ) -> anyhow::Result<GetGroupMemberInfoResponse> {
        // Fresh per-call token used to match the reply to this request.
        let echo = Uuid::new_v4().to_string();
        let mut map = HashMap::new();
        map.insert("group_id".to_string(), serde_json::to_value(group_id)?);
        map.insert("type".to_string(), serde_json::to_value(honor_type)?);
        let json = Self::generate_api_call_json("get_group_honor_info".to_string(), map, echo.clone());
        self.get_request_value(json, echo).await
    }

    // ---- Credentials ----

    // Returns only the `cookies` field of the reply.
    #[generate_json]
    async fn get_cookies(&self, domain: Option<String>) -> anyhow::Result<String> {
        let res: GetCookiesResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.cookies)
    }

    // Returns only the `token` field of the reply.
    #[generate_json]
    async fn get_csrf_token(&self) -> anyhow::Result<i32> {
        let res: GetCsrfTokenResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.token)
    }

    #[generate_json]
    async fn get_credentials(
        &self,
        domain: Option<String>,
    ) -> anyhow::Result<GetCredentialsResponse> {
        self.get_request_value(__json, __echo).await
    }

    // ---- Media ----

    // Returns only the `file` field of the reply.
    #[generate_json]
    async fn get_record(&self, file: String, out_format: String) -> anyhow::Result<String> {
        let res: GetDataResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.file)
    }

    // Returns only the `file` field of the reply.
    #[generate_json]
    async fn get_image(&self, file: String) -> anyhow::Result<String> {
        let res: GetDataResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.file)
    }

    // Returns only the `yes` field of the reply.
    #[generate_json]
    async fn can_send_image(&self) -> anyhow::Result<bool> {
        let res: CanSendResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.yes)
    }

    // Returns only the `yes` field of the reply.
    #[generate_json]
    async fn can_send_record(&self) -> anyhow::Result<bool> {
        let res: CanSendResponse = self.get_request_value(__json, __echo).await?;
        Ok(res.yes)
    }

    // ---- Implementation status and maintenance ----

    #[generate_json]
    async fn get_status(&self) -> anyhow::Result<GetStatusResponse> {
        self.get_request_value(__json, __echo).await
    }

    #[generate_json]
    async fn get_version_info(&self) -> anyhow::Result<GetVersionInfoResponse> {
        self.get_request_value(__json, __echo).await
    }

    // Payload ignored; only success/failure is reported.
    #[generate_json]
    async fn set_restart(&self, delay: Option<i32>) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }

    // Payload ignored; only success/failure is reported.
    #[generate_json]
    async fn clean_cache(&self) -> anyhow::Result<()> {
        self.get_request_value::<Value>(__json, __echo).await?;
        Ok(())
    }
}