kovi/bot/
connect.rs

1use super::Server;
2use super::{ApiReturn, Bot, Host};
3use crate::bot::handler::InternalInternalEvent;
4use crate::event::InternalEvent;
5use crate::types::ApiAndOneshot;
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, StreamExt};
8use http::HeaderValue;
9use log::{debug, error, warn};
10use parking_lot::{Mutex, RwLock};
11use std::error::Error;
12use std::fmt::Display;
13use std::{net::IpAddr, sync::Arc};
14use tokio::net::TcpStream;
15use tokio::sync::mpsc::Sender;
16use tokio::sync::{mpsc, oneshot};
17use tokio_tungstenite::tungstenite::Message;
18use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
19use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest};
20
21type ApiTxMap = Arc<Mutex<ahash::HashMap<String, ApiAndOneshot>>>;
22
23impl Bot {
24    pub(crate) async fn ws_connect(
25        server: Server,
26        api_rx: mpsc::Receiver<ApiAndOneshot>,
27        event_tx: mpsc::Sender<InternalInternalEvent>,
28        bot: Arc<RwLock<Bot>>,
29    ) -> Result<(), Box<dyn Error + Send + Sync>> {
30        #[allow(clippy::type_complexity)]
31        let (event_connected_tx, event_connected_rx): (
32            oneshot::Sender<Result<(), Box<dyn std::error::Error + Send + Sync>>>,
33            oneshot::Receiver<Result<(), Box<dyn std::error::Error + Send + Sync>>>,
34        ) = oneshot::channel();
35
36        #[allow(clippy::type_complexity)]
37        let (api_connected_tx, api_connected_rx): (
38            oneshot::Sender<Result<(), Box<dyn Error + Send + Sync>>>,
39            oneshot::Receiver<Result<(), Box<dyn Error + Send + Sync>>>,
40        ) = oneshot::channel();
41
42        {
43            let mut bot_write = bot.write();
44            bot_write.spawn(Self::ws_event_connect(
45                server.clone(),
46                event_tx.clone(),
47                event_connected_tx,
48                bot.clone(),
49            ));
50            bot_write.spawn(Self::ws_send_api(
51                server,
52                api_rx,
53                event_tx,
54                api_connected_tx,
55                bot.clone(),
56            ));
57        }
58
59        let (res1, res2) = tokio::join!(event_connected_rx, api_connected_rx);
60        let (res1, res2) = (res1.expect("unreachable"), res2.expect("unreachable"));
61        match (res1, res2) {
62            (Ok(_), Ok(_)) => Ok(()),
63            (Err(e), _) | (_, Err(e)) => Err(e),
64        }
65    }
66
67    pub(crate) async fn ws_event_connect(
68        server: Server,
69        event_tx: mpsc::Sender<InternalInternalEvent>,
70        connected_tx: oneshot::Sender<Result<(), Box<dyn Error + Send + Sync>>>,
71        bot: Arc<RwLock<Bot>>,
72    ) {
73        let (host, port, access_token, secure) =
74            (server.host, server.port, server.access_token, server.secure);
75
76        let protocol = if secure { "wss" } else { "ws" };
77        let mut request = match host {
78            Host::IpAddr(ip) => match ip {
79                IpAddr::V4(ip) => format!("{protocol}://{ip}:{port}/event")
80                    .into_client_request()
81                    .expect("The domain name is invalid"),
82                IpAddr::V6(ip) => format!("{protocol}://[{ip}]:{port}/event")
83                    .into_client_request()
84                    .expect("The domain name is invalid"),
85            },
86            Host::Domain(domain) => format!("{protocol}://{domain}:{port}/event")
87                .into_client_request()
88                .expect("The domain name is invalid"),
89        };
90
91        //增加Authorization头
92        if !access_token.is_empty() {
93            request.headers_mut().insert(
94                "Authorization",
95                HeaderValue::from_str(&format!("Bearer {access_token}")).expect("unreachable"),
96            );
97        }
98
99        let (ws_stream, _) = match connect_async(request).await {
100            Ok(v) => v,
101            Err(e) => {
102                connected_tx
103                    .send(Err(e.into()))
104                    .expect("The OneBot connect channel has been established");
105                return;
106            }
107        };
108
109        connected_tx
110            .send(Ok(()))
111            .expect("The OneBot connect channel has been established");
112
113        let (_, read) = ws_stream.split();
114
115        let mut bot_write = bot.write();
116        bot_write.spawn(ws_event_connect_read(read, event_tx));
117    }
118
119    pub(crate) async fn ws_send_api(
120        server: Server,
121        api_rx: mpsc::Receiver<ApiAndOneshot>,
122        event_tx: mpsc::Sender<InternalInternalEvent>,
123        connected_tx: oneshot::Sender<Result<(), Box<dyn std::error::Error + Send + Sync>>>,
124        bot: Arc<RwLock<Bot>>,
125    ) {
126        let (host, port, access_token, secure) =
127            (server.host, server.port, server.access_token, server.secure);
128
129        let protocol = if secure { "wss" } else { "ws" };
130        let mut request = match host {
131            Host::IpAddr(ip) => match ip {
132                IpAddr::V4(ip) => format!("{protocol}://{ip}:{port}/api")
133                    .into_client_request()
134                    .expect("The domain name is invalid"),
135                IpAddr::V6(ip) => format!("{protocol}://[{ip}]:{port}/api")
136                    .into_client_request()
137                    .expect("The domain name is invalid"),
138            },
139            Host::Domain(domain) => format!("{protocol}://{domain}:{port}/api")
140                .into_client_request()
141                .expect("The domain name is invalid"),
142        };
143
144        //增加Authorization头
145        if !access_token.is_empty() {
146            request.headers_mut().insert(
147                "Authorization",
148                HeaderValue::from_str(&format!("Bearer {access_token}")).expect("unreachable"),
149            );
150        }
151
152        let (ws_stream, _) = match connect_async(request).await {
153            Ok(v) => v,
154            Err(e) => {
155                connected_tx
156                    .send(Err(e.into()))
157                    .expect("The OneBot connect channel has been established");
158                return;
159            }
160        };
161
162        connected_tx
163            .send(Ok(()))
164            .expect("The OneBot connect channel has been established");
165
166        let (write, read) = ws_stream.split();
167        let api_tx_map: ApiTxMap = Arc::new(Mutex::new(ahash::HashMap::<_, _>::default()));
168
169        let mut bot_write = bot.write();
170
171        //读
172        bot_write.spawn(ws_send_api_read(
173            read,
174            event_tx.clone(),
175            Arc::clone(&api_tx_map),
176        ));
177
178        //写
179        bot_write.spawn(ws_send_api_write(
180            write,
181            api_rx,
182            event_tx,
183            api_tx_map.clone(),
184        ));
185    }
186}
187
188async fn ws_event_connect_read(
189    read: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
190    event_tx: Sender<InternalInternalEvent>,
191) {
192    read.for_each(|msg| {
193        let event_tx = event_tx.clone();
194        async {
195            match msg {
196                Ok(msg) => handle_msg(msg, event_tx).await,
197                Err(e) => connection_failed_eprintln(e, event_tx).await,
198            }
199        }
200    })
201    .await;
202
203    async fn handle_msg(
204        msg: tokio_tungstenite::tungstenite::Message,
205        event_tx: Sender<InternalInternalEvent>,
206    ) {
207        if !msg.is_text() {
208            return;
209        }
210
211        let text = msg.to_text().expect("unreachable");
212        if let Err(e) = event_tx
213            .send(InternalInternalEvent::OneBotEvent(
214                InternalEvent::OneBotEvent(text.to_string()),
215            ))
216            .await
217        {
218            debug!("通道关闭:{e}")
219        }
220    }
221}
222
223async fn ws_send_api_read(
224    read: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
225    event_tx: Sender<InternalInternalEvent>,
226    api_tx_map: ApiTxMap,
227) {
228    read.for_each(|msg| {
229        let event_tx = event_tx.clone();
230        async {
231            match msg {
232                Ok(msg) => handle_msg(msg, event_tx, api_tx_map.clone()).await,
233                Err(e) => connection_failed_eprintln(e, event_tx).await,
234            }
235        }
236    })
237    .await;
238
239    async fn handle_msg(
240        msg: tokio_tungstenite::tungstenite::Message,
241        event_tx: Sender<InternalInternalEvent>,
242        api_tx_map: ApiTxMap,
243    ) {
244        if msg.is_close() {
245            connection_failed_eprintln(format!("{msg}\nBot api connection failed"), event_tx).await;
246            return;
247        }
248        if !msg.is_text() {
249            return;
250        }
251
252        let text = msg.to_text().expect("unreachable");
253
254        debug!("{text}");
255
256        let return_value: ApiReturn = match serde_json::from_str(text) {
257            Ok(v) => v,
258            Err(_) => {
259                debug!("Unknow api return: {text}");
260                return;
261            }
262        };
263
264        if return_value.status != "ok" {
265            warn!("Api return error: {text}")
266        }
267
268        let api_tx_cache = {
269            let mut api_tx_map = api_tx_map.lock();
270            match api_tx_map.remove(&return_value.echo) {
271                Some(v) => v,
272                None => {
273                    log::error!("Api return echo not found from api_tx_map: {text}");
274                    return;
275                }
276            }
277        };
278
279        let return_value = if return_value.status.to_lowercase() == "ok" {
280            Ok(return_value)
281        } else {
282            Err(return_value)
283        };
284
285        if let Some(tx) = api_tx_cache.1
286            && tx.send(return_value.clone()).is_err() {
287                log::debug!("Return Api to plugin failed, the receiver has been closed")
288            };
289
290        event_tx
291            .send(InternalInternalEvent::OneBotEvent(
292                InternalEvent::OneBotApiEvent((api_tx_cache.0, return_value)),
293            ))
294            .await
295            .expect("The event_tx is closed");
296    }
297}
298
299async fn ws_send_api_write(
300    mut write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
301    mut api_rx: mpsc::Receiver<ApiAndOneshot>,
302    event_tx: Sender<InternalInternalEvent>,
303    api_tx_map: ApiTxMap,
304) {
305    while let Some((api_msg, return_api_tx)) = api_rx.recv().await {
306        let event_tx = event_tx.clone();
307        debug!("{api_msg}");
308
309        api_tx_map
310            .lock()
311            .insert(api_msg.echo.clone(), (api_msg.clone(), return_api_tx));
312
313        let msg = tokio_tungstenite::tungstenite::Message::text(api_msg.to_string());
314
315        if let Err(e) = write.send(msg).await {
316            connection_failed_eprintln(e, event_tx).await;
317        }
318    }
319}
320
321async fn connection_failed_eprintln<E>(e: E, event_tx: Sender<InternalInternalEvent>)
322where
323    E: Display,
324{
325    log::error!("{e}\nBot connection failed, please check the configuration and restart.");
326    if let Err(e) = event_tx
327        .send(InternalInternalEvent::KoviEvent(
328            crate::bot::handler::KoviEvent::Drop,
329        ))
330        .await
331    {
332        error!("通道关闭,{e}")
333    };
334}