Skip to main content

onebot_api/communication/
ws.rs

1use super::utils::*;
2use async_trait::async_trait;
3use futures::stream::{SplitSink, SplitStream};
4use futures::{SinkExt, StreamExt};
5use reqwest::IntoUrl;
6use std::sync::Arc;
7use tokio::net::TcpStream;
8use tokio::select;
9use tokio::sync::broadcast;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
12use url::Url;
13
14#[derive(Clone, Debug)]
15pub struct WsService {
16	url: Url,
17	access_token: Option<String>,
18	api_receiver: Option<APIReceiver>,
19	event_sender: Option<EventSender>,
20	close_signal_sender: broadcast::Sender<()>,
21}
22
// NOTE: `WsService` deliberately has no `Drop` impl.
//
// It derives `Clone`, and every clone shares the same `broadcast::Sender`. Sending the close
// signal from `drop` would stop the reader/writer tasks as soon as *any* clone (even a
// temporary) was dropped, while other clones were still in use.
//
// Shutdown still happens once the service is really gone. When the last clone drops the last
// `broadcast::Sender`, every pending `close_signal.recv()` resolves with `RecvError::Closed`.
// The `_` pattern used for that branch in the tasks' `select!` accepts any result, so the tasks
// exit.
28
29impl WsService {
30	pub fn new(url: impl IntoUrl, access_token: Option<String>) -> reqwest::Result<Self> {
31		let (close_signal_sender, _) = broadcast::channel(1);
32		Ok(Self {
33			url: url.into_url()?,
34			access_token,
35			api_receiver: None,
36			event_sender: None,
37			close_signal_sender,
38		})
39	}
40}
41
42impl WsService {
43	pub fn get_url(&self) -> Url {
44		let mut url = self.url.clone();
45		if let Some(token) = &self.access_token {
46			let mut query_pairs = url.query_pairs_mut();
47			query_pairs.append_pair("access_token", token);
48		}
49		url
50	}
51
52	pub async fn connect(
53		url: impl ToString,
54	) -> anyhow::Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
55		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
56		Ok(stream)
57	}
58
59	pub async fn send_processor(
60		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
61		api_receiver: APIReceiver,
62		mut close_signal: broadcast::Receiver<()>,
63	) -> anyhow::Result<()> {
64		loop {
65			select! {
66				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
67				Ok(data) = api_receiver.recv_async() => {
68					let str = serde_json::to_string(&data);
69					if str.is_err() {
70						continue
71					}
72					let _ = send_side.send(Message::Text(str?.into())).await;
73				}
74			}
75		}
76	}
77
78	pub async fn read_processor(
79		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
80		event_sender: EventSender,
81		mut close_signal: broadcast::Receiver<()>,
82	) -> anyhow::Result<()> {
83		loop {
84			select! {
85				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
86				Some(Ok(Message::Text(data))) = read_side.next() => {
87					let str = data.as_str();
88					let event = serde_json::from_str::<Event>(str);
89					if event.is_err() {
90						continue
91					}
92					let event = Arc::new(event?);
93					let _ = event_sender.send(event);
94				}
95			}
96		}
97	}
98}
99
100#[async_trait]
101impl CommunicationService for WsService {
102	fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
103		self.api_receiver = Some(api_receiver);
104		self.event_sender = Some(event_sender);
105	}
106
107	async fn start_service(&self) -> anyhow::Result<()> {
108		if self.api_receiver.is_none() || self.event_sender.is_none() {
109			return Err(anyhow::anyhow!("api receiver or event sender is none"));
110		}
111
112		let api_receiver = self.api_receiver.clone().unwrap();
113		let event_sender = self.event_sender.clone().unwrap();
114
115		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
116
117		tokio::spawn(Self::read_processor(
118			read_side,
119			event_sender,
120			self.close_signal_sender.subscribe(),
121		));
122		tokio::spawn(Self::send_processor(
123			send_side,
124			api_receiver,
125			self.close_signal_sender.subscribe(),
126		));
127		Ok(())
128	}
129}