aria2_rs_yet/
ws.rs

1use futures_util::{stream::SplitSink, SinkExt, StreamExt};
2use tokio::sync::{mpsc, oneshot};
3use tokio::time::{timeout, Duration};
4use tokio_tungstenite::tungstenite;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use crate::call::Call;
9use crate::error::Error;
10use crate::jsonrpc;
11use crate::Result;
12
13type WSMessage = tokio_tungstenite::tungstenite::Message;
14type WSStream =
15    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
16
/// A download event pushed by the aria2 server over the WebSocket.
///
/// Every variant carries the GID of the download the event refers to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Notification {
    /// `aria2.onDownloadStart`
    DownloadStart(String),
    /// `aria2.onDownloadPause`
    DownloadPause(String),
    /// `aria2.onDownloadStop`
    DownloadStop(String),
    /// `aria2.onDownloadComplete`
    DownloadComplete(String),
    /// `aria2.onDownloadError`
    DownloadError(String),
    /// `aria2.onBtDownloadComplete`
    BtDownloadComplete(String),
}

impl Notification {
    /// Builds a notification from an aria2 notification method name and a GID.
    ///
    /// # Panics
    ///
    /// Panics if `method` is not one of the six `aria2.on*` methods recognised by
    /// [`Notification::try_new`]. Prefer `try_new` when the method name comes from
    /// the network.
    pub fn new(method: &str, gid: String) -> Self {
        Self::try_new(method, gid)
            .unwrap_or_else(|| panic!("unknown aria2 notification method: {method}"))
    }

    /// Like [`Notification::new`], but returns `None` for an unrecognised method name
    /// instead of panicking.
    pub fn try_new(method: &str, gid: String) -> Option<Self> {
        let notification = match method {
            "aria2.onDownloadStart" => Self::DownloadStart(gid),
            "aria2.onDownloadPause" => Self::DownloadPause(gid),
            "aria2.onDownloadStop" => Self::DownloadStop(gid),
            "aria2.onDownloadComplete" => Self::DownloadComplete(gid),
            "aria2.onDownloadError" => Self::DownloadError(gid),
            "aria2.onBtDownloadComplete" => Self::BtDownloadComplete(gid),
            _ => return None,
        };
        Some(notification)
    }

    /// Returns the GID of the download this notification refers to.
    pub fn gid(&self) -> &str {
        match self {
            Self::DownloadStart(gid)
            | Self::DownloadPause(gid)
            | Self::DownloadStop(gid)
            | Self::DownloadComplete(gid)
            | Self::DownloadError(gid)
            | Self::BtDownloadComplete(gid) => gid.as_str(),
        }
    }
}
40
41#[derive(serde::Deserialize)]
42struct NotificationParam {
43    gid: String,
44}
45
46struct RPCRequest {
47    params: Option<serde_json::Value>,
48    method: &'static str,
49    handler: oneshot::Sender<RPCReponse>,
50}
51
52enum RPCReponse {
53    Success(serde_json::Value),
54    Error(jsonrpc::Error),
55}
56
/// Settings needed to open (and later re-open) the aria2 WebSocket connection.
pub struct ConnectionMeta {
    /// WebSocket endpoint URL passed to the handshake.
    pub url: String,
    /// RPC secret in `token:<secret>` form, or `None` when no secret is configured.
    pub token: Option<String>,
}

impl ConnectionMeta {
    /// Creates connection settings for `url`, turning a raw RPC secret into the
    /// `token:<secret>` form aria2 expects.
    pub fn new(url: &str, token: Option<&str>) -> Self {
        let url = String::from(url);
        let token = token.map(|secret| ["token:", secret].concat());
        Self { url, token }
    }
}
70
71impl tungstenite::client::IntoClientRequest for &ConnectionMeta{
72    fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
73        // add header here if needed
74        self.url.as_str().into_client_request()
75    }
76}
77
78#[derive(Clone)]
79pub struct Client {
80    inner: Arc<ClientInner>,
81}
82
83impl Deref for Client {
84    type Target = ClientInner;
85
86    fn deref(&self) -> &Self::Target {
87        &self.inner
88    }
89}
90
91impl Client {
92    pub async fn connect(meta: ConnectionMeta) -> Result<(Self, mpsc::UnboundedReceiver<Notification>)>{
93        let (inner, notify_rx) = ClientInner::connect(meta).await?;
94        let client = Client {
95            inner: Arc::new(inner),
96        };
97        Ok((client, notify_rx))
98    }
99}
100
101pub struct ClientInner {
102    message_tx: mpsc::Sender<RPCRequest>,
103    token: Option<String>,
104    _drop_rx: oneshot::Receiver<()>,
105}
106
107impl ClientInner {
108    async fn connect(
109        meta: ConnectionMeta,
110    ) -> Result<(Self, mpsc::UnboundedReceiver<Notification>)> {
111        let (ws, _) = tokio_tungstenite::connect_async(&meta)
112            .await
113            .map_err(Error::Connect)?;
114        let (message_tx, message_rx) = mpsc::channel(32);
115        let (notification_tx, notification_rx) = mpsc::unbounded_channel();
116        let (drop_tx, _drop_rx) = oneshot::channel();
117        let token = meta.token.clone();
118        tokio::spawn(Self::background(
119            ws,
120            meta,
121            message_rx,
122            drop_tx,
123            notification_tx,
124        ));
125        Ok((
126            Self {
127                message_tx,
128                token,
129                _drop_rx,
130            },
131            notification_rx,
132        ))
133    }
134
135    pub async fn call<C: Call>(&self, call: C) -> Result<C::Response> {
136        let (tx, rx) = oneshot::channel();
137
138        let method = call.method();
139        let params = match call.to_params(self.token.as_ref().map(AsRef::as_ref)) {
140            Some(params) => Some(serde_json::to_value(params).map_err(Error::Encode)?),
141            None => None,
142        };
143        
144        tracing::debug!("call method: {}, params: {:?}", method, params);
145
146        let request = RPCRequest {
147            params,
148            method,
149            handler: tx,
150        };
151        self.message_tx
152            .send(request)
153            .await
154            .map_err(|_| Error::ChannelSend)?;
155        match rx.await.map_err(Error::ChannelRecv)? {
156            RPCReponse::Success(value) => {
157                serde_json::from_value(value).map_err(Error::Decode)
158            }
159            RPCReponse::Error(err) => Err(err.into()),
160        }
161    }
162
163    async fn background(
164        ws: WSStream,
165        meta: ConnectionMeta,
166        mut message_rx: mpsc::Receiver<RPCRequest>,
167        mut drop_tx: oneshot::Sender<()>,
168        notification_tx: mpsc::UnboundedSender<Notification>,
169    ) {
170        let (mut ws_tx, mut ws_rx) = ws.split();
171        let mut shutdown = tokio::spawn({
172            let notification_tx = notification_tx.clone();
173            async move {
174                tokio::join!(drop_tx.closed(), notification_tx.closed());
175            }
176        });
177
178        let mut request_id = 1i64;
179        let mut pending_requests = std::collections::HashMap::new();
180
181        loop {
182            loop {
183                if notification_tx.is_closed() && message_rx.is_closed() {
184                    tracing::info!("background task shutdown");
185                    return;
186                }
187                tokio::select! {
188                    _ = &mut shutdown => {
189                        tracing::info!("background task shutdown");
190                        return;
191                    }
192                    Some(msg) = message_rx.recv() => {
193                        request_id += 1;
194                        pending_requests.insert(request_id, msg.handler);
195
196                        if let Err(e) = timeout(
197                            Duration::from_secs(10),
198                           Self::send_request(&mut ws_tx, request_id, msg.method, msg.params,)
199                        ).await {
200                            tracing::error!("send request error: {e}");
201                            break;
202                        }
203                    }
204                    Some(msg) = ws_rx.next() => {
205                        let text = match msg {
206                            Ok(WSMessage::Text(text)) => text,
207                            Ok(WSMessage::Close(_)) => {
208                                tracing::info!("websocket closed");
209                                break;
210                            }
211                            Ok(_) => {
212                                continue;
213                            }
214                            Err(e) => {
215                                tracing::error!("websocket error: {e}");
216                                break;
217                            }
218                        };
219                        Self::handle_response(&text, &mut pending_requests, notification_tx.clone());
220                    }
221                }
222            }
223            pending_requests.clear();
224
225            // reconnect
226            loop {
227                if notification_tx.is_closed() && message_rx.is_closed() {
228                    tracing::info!("background task shutdown");
229                    return;
230                }
231                match timeout(
232                    Duration::from_secs(10),
233                    tokio_tungstenite::connect_async(&meta),
234                )
235                .await
236                {
237                    Err(e) => {
238                        tracing::error!("reconnect error: {e}, will retry in 10 seconds");
239                        tokio::time::sleep(Duration::from_secs(10)).await;
240                    }
241                    Ok(Err(e)) => {
242                        tracing::error!("reconnect timeout: {e}, will retry in 10 seconds");
243                        tokio::time::sleep(Duration::from_secs(10)).await;
244                    }
245                    Ok(Ok((new_ws, _))) => {
246                        let (tx, rx) = new_ws.split();
247                        ws_tx = tx;
248                        ws_rx = rx;
249                        break;
250                    }
251                }
252            }
253        }
254    }
255
256    async fn send_request(
257        sink: &mut SplitSink<WSStream, WSMessage>,
258        id: i64,
259        method: &str,
260        params: Option<serde_json::Value>,
261    ) -> Result<()> {
262        let rpc_req = jsonrpc::Request {
263            id: Some(id),
264            jsonrpc: "2.0",
265            method,
266            params,
267        };
268        sink.send(WSMessage::Text(
269            serde_json::to_string(&rpc_req)
270                .map_err(Error::Encode)?
271                .into(),
272        ))
273        .await
274        .map_err(Error::Websocket)
275    }
276
277    fn handle_response(
278        text: &str,
279        pending_requests: &mut std::collections::HashMap<i64, oneshot::Sender<RPCReponse>>,
280        notification_tx: mpsc::UnboundedSender<Notification>,
281    ) {
282        if let Ok(resp) = serde_json::from_str::<
283            jsonrpc::Response<i64, serde_json::Value, Vec<NotificationParam>>,
284        >(text)
285        {
286            match resp {
287                jsonrpc::Response::Err { id, error } => {
288                    if let Some(tx) = pending_requests.remove(&id) {
289                        let _ = tx.send(RPCReponse::Error(error));
290                    }
291                }
292                jsonrpc::Response::Resp { id, result } => {
293                    if let Some(tx) = pending_requests.remove(&id) {
294                        let _ = tx.send(RPCReponse::Success(result));
295                    }
296                }
297                jsonrpc::Response::Notification { method, params } => {
298                    tokio::spawn(async move {
299                        let method = method;
300                        for param in params {
301                            if notification_tx.send(Notification::new(&method, param.gid)).is_err()
302                            {
303                                break;
304                            }
305                        }
306                    });
307                }
308            }
309        }
310    }
311}