onebot_api/communication/ws_utils.rs

1use crate::api::APISender;
2use crate::api::arg_type::MessageType;
3use crate::api::return_type::{
4	CanSendResponse, GetCookiesResponse, GetCredentialsResponse, GetCsrfTokenResponse,
5	GetDataResponse, GetFriendListResponse, GetGroupInfoResponse, GetGroupMemberInfoResponse,
6	GetLoginInfoResponse, GetMsgResponse, GetStatusResponse, GetStrangerInfoResponse,
7	GetVersionInfoResponse, SendMsgResponse,
8};
9use crate::event::message::GroupMessageAnonymous;
10use crate::event::{Event, EventReceiver, EventTrait};
11use crate::message::receive_segment::ReceiveSegment;
12use crate::message::send_segment::SendSegment;
13use anyhow::anyhow;
14use async_trait::async_trait;
15use flume::{Receiver, Sender};
16use onebot_api_macro::generate_json;
17use serde::Deserialize;
18use serde::de::DeserializeOwned;
19use serde_json::{Value, json};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::broadcast;
23use uuid::Uuid;
24
25#[derive(Deserialize, Debug, Clone)]
26pub struct APICallResponse {
27	pub status: String,
28	pub retcode: u32,
29	#[serde(default)]
30	pub data: Value,
31	pub echo: Option<String>,
32}
33
34#[derive(Deserialize, Debug, Clone)]
35#[serde(untagged)]
36pub enum WsEvent {
37	Event(Event),
38	Response(APICallResponse),
39}
40
41impl EventTrait for WsEvent {}
42
43#[async_trait]
44pub trait WebSocketService: Send + Sync {
45	fn register_api_receiver(&mut self, api_receiver: Receiver<String>);
46	fn register_msg_sender(&mut self, msg_sender: Sender<String>);
47	async fn start(&self) -> anyhow::Result<()>;
48}
49
50pub struct WsClient {
51	service: Box<dyn WebSocketService>, // 使用DI解耦
52	msg_sender: Sender<String>,
53	// msg_receiver: Receiver<String>,
54	api_sender: Sender<String>,
55	api_receiver: Receiver<String>,
56	broadcast_sender: Arc<broadcast::Sender<WsEvent>>,
57	close_sender: Sender<()>,
58	// close_receiver: Receiver<()>,
59	max_waiting_times: Option<i32>,
60}
61
62impl WsClient {
63	pub fn with_service(mut service: Box<dyn WebSocketService>) -> Self {
64		let (msg_sender, msg_receiver) = flume::unbounded();
65		let (api_sender, api_receiver) = flume::unbounded();
66		let (close_sender, close_receiver) = flume::unbounded();
67		service.register_api_receiver(api_receiver.clone());
68		service.register_msg_sender(msg_sender.clone());
69		let (broadcast_sender, _) = broadcast::channel(16);
70		let broadcast_sender = Arc::new(broadcast_sender);
71		Self::spawn_event_listener(
72			msg_receiver.clone(),
73			Arc::clone(&broadcast_sender),
74			close_receiver.clone(),
75		);
76		Self {
77			service,
78			msg_sender,
79			// msg_receiver,
80			api_sender,
81			api_receiver,
82			broadcast_sender,
83			close_sender,
84			// close_receiver,
85			max_waiting_times: Some(10),
86		}
87	}
88
89	pub async fn start_service(&self) -> anyhow::Result<()> {
90		self.service.start().await
91	}
92
93	pub fn change_service(&mut self, mut service: Box<dyn WebSocketService>) {
94		service.register_api_receiver(self.api_receiver.clone());
95		service.register_msg_sender(self.msg_sender.clone());
96		self.service = service;
97	}
98
99	pub fn get_service(&self) -> &dyn WebSocketService {
100		&*self.service
101	}
102
103	pub fn get_service_mut(&mut self) -> &mut dyn WebSocketService {
104		&mut *self.service
105	}
106}
107
108impl WsClient {
109	fn spawn_event_listener(
110		msg_receiver: Receiver<String>,
111		broadcast_sender: Arc<broadcast::Sender<WsEvent>>,
112		close_receiver: Receiver<()>,
113	) {
114		tokio::spawn(async move {
115			Self::ws_stream_handler(&msg_receiver, &broadcast_sender, &close_receiver).await;
116		});
117	}
118
119	async fn ws_stream_handler(
120		msg_receiver: &Receiver<String>,
121		broadcast_sender: &broadcast::Sender<WsEvent>,
122		close_receiver: &Receiver<()>,
123	) {
124		loop {
125			tokio::select! {
126				msg = msg_receiver.recv_async() => {
127					if let Ok(msg) = msg {
128						Self::msg_handler(msg, broadcast_sender);
129					}
130				}
131				_ = close_receiver.recv_async() => {
132					return
133				}
134			}
135		}
136	}
137
138	fn msg_handler(msg: String, broadcast_sender: &broadcast::Sender<WsEvent>) {
139		if let Ok(event) = serde_json::from_str(&msg) {
140			let _ = broadcast_sender.send(event);
141		}
142	}
143
144	fn generate_api_call_json(
145		action: String,
146		params: HashMap<String, Value>,
147		echo: String,
148	) -> String {
149		serde_json::to_string(&json!({
150			"action": action,
151			"params": params,
152			"echo": echo
153		}))
154		.unwrap()
155	}
156
157	async fn wait_for_echo(
158		rx: &mut broadcast::Receiver<WsEvent>,
159		echo: String,
160		max: Option<i32>,
161	) -> Option<APICallResponse> {
162		let mut count = 0;
163		let target_echo = Some(echo.clone());
164		while let Ok(event) = rx.recv().await {
165			if let WsEvent::Response(res) = event
166				&& res.echo == target_echo
167			{
168				return Some(res);
169			}
170			count += 1;
171			if let Some(max) = max
172				&& count >= max
173			{
174				return None;
175			}
176		}
177		None
178	}
179
180	async fn send_json(&self, json: String, echo: String) -> anyhow::Result<APICallResponse> {
181		let receiver = self.get_receiver();
182		let max_waiting_times = self.max_waiting_times;
183		let task = tokio::spawn(async move {
184			let mut receiver = receiver;
185			Self::wait_for_echo(&mut receiver, echo, max_waiting_times).await
186		});
187
188		self.api_sender.send_async(json).await?;
189		let res = task.await?;
190		if let Some(data) = res {
191			Ok(data)
192		} else {
193			Err(anyhow!("No response!"))
194		}
195	}
196
197	fn verify_response(res: APICallResponse) -> anyhow::Result<Value> {
198		if res.status == "ok" {
199			Ok(res.data)
200		} else {
201			Err(anyhow!("the request failed with code: {}", res.retcode))
202		}
203	}
204
205	async fn get_request_value<T: DeserializeOwned>(
206		&self,
207		json: String,
208		echo: String,
209	) -> anyhow::Result<T> {
210		let res = self.send_json(json, echo).await?;
211		let value = Self::verify_response(res)?;
212		let data = serde_json::from_value::<T>(value)?;
213		Ok(data)
214	}
215}
216
217impl Drop for WsClient {
218	fn drop(&mut self) {
219		let _ = self.close_sender.send(());
220	}
221}
222
223#[async_trait]
224impl EventReceiver<WsEvent> for WsClient {
225	fn get_receiver(&self) -> broadcast::Receiver<WsEvent> {
226		self.broadcast_sender.subscribe()
227	}
228}
229
230#[allow(unused_variables)]
231#[async_trait]
232impl APISender for WsClient {
233	#[generate_json]
234	async fn send_private_msg(
235		&self,
236		user_id: i64,
237		message: Vec<SendSegment>,
238		auto_escape: Option<bool>,
239	) -> anyhow::Result<i32> {
240		let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
241		Ok(res.message_id)
242	}
243
244	#[generate_json]
245	async fn send_group_msg(
246		&self,
247		group_id: i64,
248		message: Vec<SendSegment>,
249		auto_escape: Option<bool>,
250	) -> anyhow::Result<i32> {
251		let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
252		Ok(res.message_id)
253	}
254
255	#[generate_json]
256	async fn send_msg(
257		&self,
258		message_type: Option<MessageType>,
259		user_id: i64,
260		group_id: i64,
261		message: Vec<SendSegment>,
262		auto_escape: Option<bool>,
263	) -> anyhow::Result<i32> {
264		let res: SendMsgResponse = self.get_request_value(__json, __echo).await?;
265		Ok(res.message_id)
266	}
267
268	#[generate_json]
269	async fn delete_msg(&self, message_id: i32) -> anyhow::Result<()> {
270		self.get_request_value::<Value>(__json, __echo).await?;
271		Ok(())
272	}
273
274	#[generate_json]
275	async fn get_msg(&self, message_id: i32) -> anyhow::Result<GetMsgResponse> {
276		self.get_request_value(__json, __echo).await
277	}
278
279	#[generate_json]
280	async fn get_forward_msg(&self, id: String) -> anyhow::Result<Vec<ReceiveSegment>> {
281		self.get_request_value(__json, __echo).await
282	}
283
284	#[generate_json]
285	async fn send_like(&self, user_id: i64, times: Option<i32>) -> anyhow::Result<()> {
286		self.get_request_value::<Value>(__json, __echo).await?;
287		Ok(())
288	}
289
290	#[generate_json]
291	async fn set_group_kick(
292		&self,
293		group_id: i32,
294		user_id: i32,
295		reject_add_request: Option<bool>,
296	) -> anyhow::Result<()> {
297		self.get_request_value::<Value>(__json, __echo).await?;
298		Ok(())
299	}
300
301	#[generate_json]
302	async fn set_group_ban(
303		&self,
304		group_id: i32,
305		user_id: i32,
306		duration: Option<i32>,
307	) -> anyhow::Result<()> {
308		self.get_request_value::<Value>(__json, __echo).await?;
309		Ok(())
310	}
311
312	#[generate_json]
313	async fn set_group_anonymous_ban(
314		&self,
315		group_id: i32,
316		anonymous: Option<GroupMessageAnonymous>,
317		flag: Option<String>,
318		duration: Option<i32>,
319	) -> anyhow::Result<()> {
320		self.get_request_value::<Value>(__json, __echo).await?;
321		Ok(())
322	}
323
324	#[generate_json]
325	async fn set_group_whole_ban(&self, group_id: i32, enable: Option<bool>) -> anyhow::Result<()> {
326		self.get_request_value::<Value>(__json, __echo).await?;
327		Ok(())
328	}
329
330	#[generate_json]
331	async fn set_group_admin(
332		&self,
333		group_id: i32,
334		user_id: i32,
335		enable: Option<bool>,
336	) -> anyhow::Result<()> {
337		self.get_request_value::<Value>(__json, __echo).await?;
338		Ok(())
339	}
340
341	#[generate_json]
342	async fn set_group_anonymous(&self, group_id: i32, enable: Option<bool>) -> anyhow::Result<()> {
343		self.get_request_value::<Value>(__json, __echo).await?;
344		Ok(())
345	}
346
347	#[generate_json]
348	async fn set_group_card(
349		&self,
350		group_id: i32,
351		user_id: i32,
352		card: Option<String>,
353	) -> anyhow::Result<()> {
354		self.get_request_value::<Value>(__json, __echo).await?;
355		Ok(())
356	}
357
358	#[generate_json]
359	async fn set_group_name(&self, group_id: i32, group_name: String) -> anyhow::Result<()> {
360		self.get_request_value::<Value>(__json, __echo).await?;
361		Ok(())
362	}
363
364	#[generate_json]
365	async fn set_group_leave(&self, group_id: i32, is_dismiss: Option<bool>) -> anyhow::Result<()> {
366		self.get_request_value::<Value>(__json, __echo).await?;
367		Ok(())
368	}
369
370	#[generate_json]
371	async fn set_group_special_title(
372		&self,
373		group_id: i32,
374		user_id: i32,
375		special_title: Option<String>,
376		duration: Option<i32>,
377	) -> anyhow::Result<()> {
378		self.get_request_value::<Value>(__json, __echo).await?;
379		Ok(())
380	}
381
382	#[generate_json]
383	async fn set_friend_add_request(
384		&self,
385		flag: String,
386		approve: Option<bool>,
387		remark: Option<String>,
388	) -> anyhow::Result<()> {
389		self.get_request_value::<Value>(__json, __echo).await?;
390		Ok(())
391	}
392
393	#[generate_json]
394	async fn set_group_add_request(
395		&self,
396		flag: String,
397		sub_type: String,
398		approve: Option<bool>,
399		reason: Option<String>,
400	) -> anyhow::Result<()> {
401		self.get_request_value::<Value>(__json, __echo).await?;
402		Ok(())
403	}
404
405	#[generate_json]
406	async fn get_login_info(&self) -> anyhow::Result<GetLoginInfoResponse> {
407		self.get_request_value(__json, __echo).await
408	}
409
410	#[generate_json]
411	async fn get_stranger_info(
412		&self,
413		user_id: i32,
414		no_cache: Option<bool>,
415	) -> anyhow::Result<GetStrangerInfoResponse> {
416		self.get_request_value(__json, __echo).await
417	}
418
419	#[generate_json]
420	async fn get_friend_list(&self) -> anyhow::Result<Vec<GetFriendListResponse>> {
421		self.get_request_value(__json, __echo).await
422	}
423
424	#[generate_json]
425	async fn get_group_info(
426		&self,
427		group_id: i32,
428		no_cache: Option<bool>,
429	) -> anyhow::Result<GetGroupInfoResponse> {
430		self.get_request_value(__json, __echo).await
431	}
432
433	#[generate_json]
434	async fn get_group_list(&self) -> anyhow::Result<Vec<GetGroupInfoResponse>> {
435		self.get_request_value(__json, __echo).await
436	}
437
438	#[generate_json]
439	async fn get_group_member_info(
440		&self,
441		group_id: i32,
442		user_id: i32,
443		no_cache: Option<bool>,
444	) -> anyhow::Result<GetGroupMemberInfoResponse> {
445		self.get_request_value(__json, __echo).await
446	}
447
448	#[generate_json]
449	async fn get_group_member_list(
450		&self,
451		group_id: i32,
452	) -> anyhow::Result<Vec<GetGroupMemberInfoResponse>> {
453		self.get_request_value(__json, __echo).await
454	}
455
456	async fn get_group_honor_info(
457		&self,
458		group_id: i64,
459		honor_type: String,
460	) -> anyhow::Result<GetGroupMemberInfoResponse> {
461		let echo = Uuid::new_v4().to_string();
462		let mut map = HashMap::new();
463		map.insert("group_id".to_string(), serde_json::to_value(group_id)?);
464		map.insert("type".to_string(), serde_json::to_value(honor_type)?);
465		let json = Self::generate_api_call_json("get_group_honor_info".to_string(), map, echo.clone());
466		self.get_request_value(json, echo).await
467	}
468
469	#[generate_json]
470	async fn get_cookies(&self, domain: Option<String>) -> anyhow::Result<String> {
471		let res: GetCookiesResponse = self.get_request_value(__json, __echo).await?;
472		Ok(res.cookies)
473	}
474
475	#[generate_json]
476	async fn get_csrf_token(&self) -> anyhow::Result<i32> {
477		let res: GetCsrfTokenResponse = self.get_request_value(__json, __echo).await?;
478		Ok(res.token)
479	}
480
481	#[generate_json]
482	async fn get_credentials(
483		&self,
484		domain: Option<String>,
485	) -> anyhow::Result<GetCredentialsResponse> {
486		self.get_request_value(__json, __echo).await
487	}
488
489	#[generate_json]
490	async fn get_record(&self, file: String, out_format: String) -> anyhow::Result<String> {
491		let res: GetDataResponse = self.get_request_value(__json, __echo).await?;
492		Ok(res.file)
493	}
494
495	#[generate_json]
496	async fn get_image(&self, file: String) -> anyhow::Result<String> {
497		let res: GetDataResponse = self.get_request_value(__json, __echo).await?;
498		Ok(res.file)
499	}
500
501	#[generate_json]
502	async fn can_send_image(&self) -> anyhow::Result<bool> {
503		let res: CanSendResponse = self.get_request_value(__json, __echo).await?;
504		Ok(res.yes)
505	}
506
507	#[generate_json]
508	async fn can_send_record(&self) -> anyhow::Result<bool> {
509		let res: CanSendResponse = self.get_request_value(__json, __echo).await?;
510		Ok(res.yes)
511	}
512
513	#[generate_json]
514	async fn get_status(&self) -> anyhow::Result<GetStatusResponse> {
515		self.get_request_value(__json, __echo).await
516	}
517
518	#[generate_json]
519	async fn get_version_info(&self) -> anyhow::Result<GetVersionInfoResponse> {
520		self.get_request_value(__json, __echo).await
521	}
522
523	#[generate_json]
524	async fn set_restart(&self, delay: Option<i32>) -> anyhow::Result<()> {
525		self.get_request_value::<Value>(__json, __echo).await?;
526		Ok(())
527	}
528
529	#[generate_json]
530	async fn clean_cache(&self) -> anyhow::Result<()> {
531		self.get_request_value::<Value>(__json, __echo).await?;
532		Ok(())
533	}
534}