onebot_api/communication/
ws.rs1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use futures::stream::{SplitSink, SplitStream};
5use futures::{SinkExt, StreamExt};
6use reqwest::IntoUrl;
7use std::sync::Arc;
8use tokio::net::TcpStream;
9use tokio::select;
10use tokio::sync::broadcast;
11use tokio_tungstenite::tungstenite::Message;
12use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
13use url::Url;
14
/// Communication service that talks to the server over a client-side WebSocket.
///
/// Usage: call `inject` to provide the API receiver and event sender, then
/// `start_service` to connect and spawn the read/send processor tasks.
///
/// NOTE(review): this type derives `Clone`, and clones share `close_signal_sender`'s
/// broadcast channel. The `Drop` impl sends the close signal on that channel, so dropping
/// ANY clone stops the processor tasks started through any other clone. Confirm this is
/// intended.
#[derive(Clone, Debug)]
pub struct WsService {
    /// Base WebSocket URL; `get_url` appends the access token to a copy of it.
    url: Url,
    /// Optional token, sent as the `access_token` query parameter when connecting.
    access_token: Option<String>,
    /// Source of outgoing API requests; set by `inject`, required by `start_service`.
    api_receiver: Option<APIReceiver>,
    /// Destination for incoming events; set by `inject`, required by `start_service`.
    event_sender: Option<EventSender>,
    /// Broadcasts the stop signal to the processor tasks (each subscribes on start);
    /// the signal is sent when this value is dropped.
    close_signal_sender: broadcast::Sender<()>,
}
23
24impl Drop for WsService {
25 fn drop(&mut self) {
26 let _ = self.close_signal_sender.send(());
27 }
28}
29
30impl WsService {
31 pub fn new(url: impl IntoUrl, access_token: Option<String>) -> reqwest::Result<Self> {
32 let (close_signal_sender, _) = broadcast::channel(1);
33 Ok(Self {
34 url: url.into_url()?,
35 access_token,
36 api_receiver: None,
37 event_sender: None,
38 close_signal_sender,
39 })
40 }
41}
42
43impl WsService {
44 pub fn get_url(&self) -> Url {
45 let mut url = self.url.clone();
46 if let Some(token) = &self.access_token {
47 let mut query_pairs = url.query_pairs_mut();
48 query_pairs.append_pair("access_token", token);
49 }
50 url
51 }
52
53 pub async fn connect(
54 url: impl ToString,
55 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, tokio_tungstenite::tungstenite::Error> {
56 let (stream, _) = tokio_tungstenite::connect_async(url.to_string()).await?;
57 Ok(stream)
58 }
59
60 pub async fn send_processor(
61 mut send_side: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
62 api_receiver: APIReceiver,
63 mut close_signal: broadcast::Receiver<()>,
64 ) -> anyhow::Result<()> {
65 loop {
66 select! {
67 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
68 Ok(data) = api_receiver.recv_async() => {
69 let str = serde_json::to_string(&data);
70 if str.is_err() {
71 continue
72 }
73 let _ = send_side.send(Message::Text(str?.into())).await;
74 }
75 }
76 }
77 }
78
79 pub async fn read_processor(
80 mut read_side: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
81 event_sender: EventSender,
82 mut close_signal: broadcast::Receiver<()>,
83 ) -> anyhow::Result<()> {
84 loop {
85 select! {
86 _ = close_signal.recv() => return Err(anyhow::anyhow!("close")),
87 Some(Ok(Message::Text(data))) = read_side.next() => {
88 let str = data.as_str();
89 let event = serde_json::from_str::<Event>(str);
90 if event.is_err() {
91 continue
92 }
93 let event = Arc::new(event?);
94 let _ = event_sender.send(event);
95 }
96 }
97 }
98 }
99}
100
101#[async_trait]
102impl CommunicationService for WsService {
103 fn inject(&mut self, api_receiver: APIReceiver, event_sender: EventSender) {
104 self.api_receiver = Some(api_receiver);
105 self.event_sender = Some(event_sender);
106 }
107
108 async fn start_service(&self) -> ServiceStartResult<()> {
109 if self.api_receiver.is_none() && self.event_sender.is_none() {
110 return Err(ServiceStartError::NotInjected);
111 } else if self.event_sender.is_none() {
112 return Err(ServiceStartError::NotInjectedEventSender);
113 } else if self.api_receiver.is_none() {
114 return Err(ServiceStartError::NotInjectedAPIReceiver);
115 }
116
117 let api_receiver = self.api_receiver.clone().unwrap();
118 let event_sender = self.event_sender.clone().unwrap();
119
120 let (send_side, read_side) = Self::connect(self.get_url()).await?.split();
121
122 tokio::spawn(Self::read_processor(
123 read_side,
124 event_sender,
125 self.close_signal_sender.subscribe(),
126 ));
127 tokio::spawn(Self::send_processor(
128 send_side,
129 api_receiver,
130 self.close_signal_sender.subscribe(),
131 ));
132 Ok(())
133 }
134}