aria2_ws/
client.rs

1use crate::{
2    callback::{callback_worker, TaskCallbacks},
3    error,
4    utils::print_error,
5    Callbacks, Notification, Result, RpcRequest, RpcResponse,
6};
7use futures::prelude::*;
8use log::{debug, info};
9use serde::de::DeserializeOwned;
10use serde_json::Value;
11use snafu::prelude::*;
12use std::{
13    collections::HashMap,
14    ops::Deref,
15    sync::{
16        atomic::{AtomicI32, Ordering},
17        Arc,
18    },
19    time::Duration,
20};
21use tokio::{
22    select, spawn,
23    sync::{broadcast, mpsc, oneshot, Notify},
24    time::sleep,
25};
26use tokio_tungstenite::tungstenite::Message;
27type WebSocket =
28    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
29
30#[derive(Debug)]
31pub(crate) struct Subscription {
32    pub id: i32,
33    pub tx: oneshot::Sender<RpcResponse>,
34}
35pub struct InnerClient {
36    token: Option<String>,
37    id: AtomicI32,
38    /// Channel for sending messages to the websocket.
39    tx_ws_sink: mpsc::Sender<Message>,
40    tx_notification: broadcast::Sender<Notification>,
41    tx_subscription: mpsc::Sender<Subscription>,
42    /// On notified, all spawned tasks shut down.
43    shutdown: Arc<Notify>,
44}
45
46/// An aria2 websocket rpc client.
47///
48/// # Example
49///
50/// ```
51/// use aria2_ws::Client;
52///
53/// #[tokio::main]
54/// async fn main() {
55///     let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None)
56///         .await
57///         .unwrap();
58///     let version = client.get_version().await.unwrap();
59///     println!("{:?}", version);
60/// }
61/// ```
62#[derive(Clone)]
63pub struct Client {
64    inner: Arc<InnerClient>,
65    // The sender can be cloned like `Arc`.
66    tx_callback: mpsc::UnboundedSender<TaskCallbacks>,
67}
68
69impl Drop for InnerClient {
70    fn drop(&mut self) {
71        // notify all spawned tasks to shutdown
72        debug!("InnerClient dropped, notify shutdown");
73        self.shutdown.notify_waiters();
74    }
75}
76
77async fn process_ws(
78    ws: WebSocket,
79    rx_ws_sink: &mut mpsc::Receiver<Message>,
80    tx_notification: broadcast::Sender<Notification>,
81    rx_subscription: &mut mpsc::Receiver<Subscription>,
82) {
83    let (mut sink, mut stream) = ws.split();
84    let mut subscriptions = HashMap::<i32, oneshot::Sender<RpcResponse>>::new();
85
86    let on_stream = |msg: String,
87                     subscriptions: &mut HashMap<i32, oneshot::Sender<RpcResponse>>|
88     -> Result<()> {
89        let v: Value = serde_json::from_str(&msg).context(error::JsonSnafu)?;
90        if let Value::Object(obj) = &v {
91            if obj.contains_key("method") {
92                // The message should be a notification.
93                // https://aria2.github.io/manual/en/html/aria2c.html#notifications
94                let req: RpcRequest = serde_json::from_value(v).context(error::JsonSnafu)?;
95                let notification = (&req).try_into()?;
96                let _ = tx_notification.send(notification);
97                return Ok(());
98            }
99        }
100
101        // The message should be a response.
102        let res: RpcResponse = serde_json::from_value(v).context(error::JsonSnafu)?;
103        if let Some(ref id) = res.id {
104            let tx = subscriptions.remove(id);
105            if let Some(tx) = tx {
106                let _ = tx.send(res);
107            }
108        }
109        Ok(())
110    };
111
112    loop {
113        select! {
114            msg = stream.try_next() => {
115                debug!("websocket received message: {:?}", msg);
116                let Ok(msg) = msg else {
117                    break;
118                };
119                if let Some(Message::Text(s)) = msg {
120                    print_error(on_stream(s.to_string(), &mut subscriptions));
121                }
122            },
123            msg = rx_ws_sink.recv() => {
124                debug!("writing message to websocket: {:?}", msg);
125                let Some(msg) = msg else {
126                    break;
127                };
128                print_error(sink.send(msg).await);
129            },
130            subscription = rx_subscription.recv() => {
131                if let Some(subscription) = subscription {
132                    subscriptions.insert(subscription.id, subscription.tx);
133                }
134            }
135        }
136    }
137}
138
139impl InnerClient {
140    pub(crate) async fn connect(url: &str, token: Option<&str>) -> Result<Self> {
141        let (tx_ws_sink, mut rx_ws_sink) = mpsc::channel(1);
142        let (tx_subscription, mut rx_subscription) = mpsc::channel(1);
143        let shutdown = Arc::new(Notify::new());
144        // Broadcast notifications to all subscribers.
145        // The receiver is dropped cause there is no subscriber for now.
146        let (tx_notification, _) = broadcast::channel(1);
147
148        let inner = InnerClient {
149            tx_ws_sink,
150            id: AtomicI32::new(0),
151            token: token.map(|t| "token:".to_string() + t),
152            tx_subscription,
153            tx_notification: tx_notification.clone(),
154            shutdown: shutdown.clone(),
155        };
156
157        async fn connect_ws(url: &str) -> Result<WebSocket> {
158            debug!("connecting to {}", url);
159            let (ws, res) = tokio_tungstenite::connect_async(url)
160                .await
161                .context(error::WebsocketIoSnafu)?;
162            debug!("connected to {}, {:?}", url, res);
163            Ok(ws)
164        }
165
166        let ws = connect_ws(url).await?;
167        let url = url.to_string();
168        // spawn a task to process websocket messages
169        spawn(async move {
170            let mut ws = Some(ws);
171            loop {
172                if let Some(ws) = ws.take() {
173                    let _ = tx_notification.send(Notification::WebSocketConnected);
174
175                    let fut = process_ws(
176                        ws,
177                        &mut rx_ws_sink,
178                        tx_notification.clone(),
179                        &mut rx_subscription,
180                    );
181
182                    select! {
183                        _ = fut => {},
184                        _ = shutdown.notified() => {
185                            return;
186                        },
187                    }
188
189                    let _ = tx_notification.send(Notification::WebsocketClosed);
190                } else {
191                    let r = select! {
192                        r = connect_ws(&url) => r,
193                        _ = shutdown.notified() => return,
194                    };
195                    match r {
196                        Ok(ws_) => {
197                            ws.replace(ws_);
198                        }
199                        Err(err) => {
200                            info!("{}", err);
201                            sleep(Duration::from_secs(3)).await;
202                        }
203                    }
204                }
205            }
206        });
207
208        Ok(inner)
209    }
210
211    fn id(&self) -> i32 {
212        self.id.fetch_add(1, Ordering::Relaxed)
213    }
214
215    async fn wait_for_id<T>(&self, id: i32, rx: oneshot::Receiver<RpcResponse>) -> Result<T>
216    where
217        T: DeserializeOwned + Send,
218    {
219        let res = rx.await.map_err(|err| {
220            error::WebsocketClosedSnafu {
221                message: format!("receiving response for id {}: {}", id, err),
222            }
223            .build()
224        })?;
225
226        if let Some(err) = res.error {
227            return Err(err).context(error::Aria2Snafu);
228        }
229
230        if let Some(v) = res.result {
231            Ok(serde_json::from_value::<T>(v).context(error::JsonSnafu)?)
232        } else {
233            error::ParseSnafu {
234                value: format!("{:?}", res),
235                to: "RpcResponse with result",
236            }
237            .fail()
238        }
239    }
240
241    /// Send a rpc request to websocket without waiting for response.
242    pub async fn call(&self, id: i32, method: &str, mut params: Vec<Value>) -> Result<()> {
243        if let Some(ref token) = self.token {
244            params.insert(0, Value::String(token.clone()))
245        }
246        let req = RpcRequest {
247            id: Some(id),
248            jsonrpc: "2.0".to_string(),
249            method: "aria2.".to_string() + method,
250            params,
251        };
252        self.tx_ws_sink
253            .send(Message::Text(
254                serde_json::to_string(&req)
255                    .context(error::JsonSnafu)?
256                    .into(),
257            ))
258            .await
259            .expect("tx_ws_sink: receiver has been dropped");
260        Ok(())
261    }
262
263    /// Send a rpc request to websocket and wait for corresponding response.
264    pub async fn call_and_wait<T>(&self, method: &str, params: Vec<Value>) -> Result<T>
265    where
266        T: DeserializeOwned + Send,
267    {
268        let id = self.id();
269        let (tx, rx) = oneshot::channel();
270        self.tx_subscription
271            .send(Subscription { id, tx })
272            .await
273            .expect("tx_subscription: receiver has been closed");
274
275        self.call(id, method, params).await?;
276        self.wait_for_id::<T>(id, rx).await
277    }
278
279    /// Subscribe to notifications.
280    ///
281    /// Returns a instance of `broadcast::Receiver` which can be used to receive notifications.
282    pub fn subscribe_notifications(&self) -> broadcast::Receiver<Notification> {
283        self.tx_notification.subscribe()
284    }
285}
286
287impl Client {
288    /// Create a new `Client` that connects to the given url.
289    ///
290    /// # Example
291    ///
292    /// ```
293    /// use aria2_ws::Client;
294    ///
295    /// #[tokio::main]
296    /// async fn main() {
297    ///     let client = Client::connect("ws://127.0.0.1:6800/jsonrpc", None)
298    ///         .await
299    ///         .unwrap();
300    ///     let gid = client
301    ///         .add_uri(
302    ///             vec!["https://go.dev/dl/go1.17.6.windows-amd64.msi".to_string()],
303    ///             None,
304    ///             None,
305    ///             None,
306    ///         )
307    ///         .await
308    ///         .unwrap();
309    ///     client.force_remove(&gid).await.unwrap();
310    /// }
311    /// ```
312    pub async fn connect(url: &str, token: Option<&str>) -> Result<Self> {
313        let inner = Arc::new(InnerClient::connect(url, token).await?);
314
315        let weak = Arc::downgrade(&inner);
316        let rx_notification = inner.subscribe_notifications();
317        let (tx_callback, rx_callback) = mpsc::unbounded_channel();
318        // hold a weak reference to `inner` to prevent not shutting down when `Client` is dropped
319        spawn(callback_worker(weak, rx_notification, rx_callback));
320
321        Ok(Self { inner, tx_callback })
322    }
323
324    pub(crate) fn add_callbacks(&self, gid: String, callbacks: Callbacks) {
325        if callbacks.is_empty() {
326            return;
327        }
328        self.tx_callback
329            .send(TaskCallbacks { gid, callbacks })
330            .expect("tx_callback: receiver has been dropped");
331    }
332}
333
334impl Deref for Client {
335    type Target = InnerClient;
336
337    fn deref(&self) -> &Self::Target {
338        &self.inner
339    }
340}