Skip to main content

onebot_api/communication/
ws.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use tokio::net::TcpStream;
9use tokio::select;
10use tokio::sync::broadcast;
11use tokio_tungstenite::tungstenite::Message;
12use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
13use url::Url;
14
15#[derive(Clone, Debug)]
16pub struct WsService {
17	url: Url,
18	access_token: Option<String>,
19	api_receiver: Option<APIReceiver>,
20	event_sender: Option<EventSender>,
21	close_signal_sender: broadcast::Sender<()>,
22}
23
24impl Drop for WsService {
25	fn drop(&mut self) {
26		let _ = self.close_signal_sender.send(());
27	}
28}
29
30impl WsService {
31	pub fn new(url: impl IntoUrl, access_token: Option<String>) -> reqwest::Result<Self> {
32		let (close_signal_sender, _) = broadcast::channel(1);
33		Ok(Self {
34			url: url.into_url()?,
35			access_token,
36			api_receiver: None,
37			event_sender: None,
38			close_signal_sender,
39		})
40	}
41}
42
43impl WsService {
44	pub fn get_url(&self) -> Url {
45		let mut url = self.url.clone();
46		if let Some(token) = &self.access_token {
47			let mut query_pairs = url.query_pairs_mut();
48			query_pairs.append_pair("access_token", token);
49		}
50		url
51	}
52
53	pub async fn connect(
54		url: impl ToString,
55	) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
56		let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
57		Ok(stream)
58	}
59
60	pub async fn send_processor(
61		mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
62		api_receiver: APIReceiver,
63		mut close_signal: broadcast::Receiver<()>,
64	) -> anyhow::Result<()> {
65		loop {
66			select! {
67				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
68				Ok(data) = api_receiver.recv_async() => {
69					let str = serde_json::to_string(&data);
70					if str.is_err() {
71						continue
72					}
73					let _ = send_side.send(Message::Text(str?.into())).await;
74				}
75			}
76		}
77	}
78
79	pub async fn read_processor(
80		mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
81		event_sender: EventSender,
82		mut close_signal: broadcast::Receiver<()>,
83	) -> anyhow::Result<()> {
84		loop {
85			select! {
86				_ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
87				Some(Ok(Message::Text(data))) = read_side.next() => {
88					let str = data.as_str();
89					let event = serde_json::from_str::<Event>(str);
90					if event.is_err() {
91						continue
92					}
93					let event = Arc::new(event?);
94					let _ = event_sender.send(event);
95				}
96			}
97		}
98	}
99}
100
101#[async_trait]
102impl CommunicationService for WsService {
103	fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
104		self.api_receiver = Some(api_receiver);
105		self.event_sender = Some(event_sender);
106	}
107
108	async fn start_service(&self) -> ServiceStartResult<()> {
109		if self.api_receiver.is_none() && self.event_sender.is_none() {
110			return Err(ServiceStartError::NotInjected);
111		} else if self.event_sender.is_none() {
112			return Err(ServiceStartError::NotInjectedEventSender);
113		} else if self.api_receiver.is_none() {
114			return Err(ServiceStartError::NotInjectedAPIReceiver);
115		}
116
117		let api_receiver = self.api_receiver.clone().unwrap();
118		let event_sender = self.event_sender.clone().unwrap();
119
120		let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
121
122		tokio::spawn(Self::read_processor(
123			read_side,
124			event_sender,
125			self.close_signal_sender.subscribe(),
126		));
127		tokio::spawn(Self::send_processor(
128			send_side,
129			api_receiver,
130			self.close_signal_sender.subscribe(),
131		));
132		Ok(())
133	}
134}