1use crate::{
2 callback::{callback_worker, TaskCallbacks},
3 error,
4 utils::print_error,
5 Callbacks, Notification, Result, RpcRequest, RpcResponse,
6};
7use futures::prelude::*;
8use log::{debug, info};
9use serde::de::DeserializeOwned;
10use serde_json::Value;
11use snafu::prelude::*;
12use std::{
13 collections::HashMap,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicI32, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21use tokio::{
22 select, spawn,
23 sync::{broadcast, mpsc, oneshot, Notify},
24 time::sleep,
25};
26use tokio_tungstenite::tungstenite::Message;
27type WebSocket =
28 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
29
30#[derive(Debug)]
31pub(crate) struct Subscription {
32 pub id: i32,
33 pub tx: oneshot::Sender<RpcResponse>,
34}
35pub struct InnerClient {
36 token: Option<String>,
37 id: AtomicI32,
38 tx_ws_sink: mpsc::Sender<Message>,
40 tx_notification: broadcast::Sender<Notification>,
41 tx_subscription: mpsc::Sender<Subscription>,
42 shutdown: Arc<Notify>,
44}
45
46#[derive(Clone)]
63pub struct Client {
64 inner: Arc<InnerClient>,
65 tx_callback: mpsc::UnboundedSender<TaskCallbacks>,
67}
68
69impl Drop for InnerClient {
70 fn drop(&mut self) {
71 debug!("InnerClient dropped, notify shutdown");
73 self.shutdown.notify_waiters();
74 }
75}
76
77async fn process_ws(
78 ws: WebSocket,
79 rx_ws_sink: &mut mpsc::Receiver<Message>,
80 tx_notification: broadcast::Sender<Notification>,
81 rx_subscription: &mut mpsc::Receiver<Subscription>,
82) {
83 let (mut sink, mut stream) = ws.split();
84 let mut subscriptions = HashMap::<i32, oneshot::Sender<RpcResponse>>::new();
85
86 let on_stream = |msg: String,
87 subscriptions: &mut HashMap<i32, oneshot::Sender<RpcResponse>>|
88 -> Result<()> {
89 let v: Value = serde_json::from_str(&msg).context(error::JsonSnafu)?;
90 if let Value::Object(obj) = &v {
91 if obj.contains_key("method") {
92 let req: RpcRequest = serde_json::from_value(v).context(error::JsonSnafu)?;
95 let notification = (&req).try_into()?;
96 let _ = tx_notification.send(notification);
97 return Ok(());
98 }
99 }
100
101 let res: RpcResponse = serde_json::from_value(v).context(error::JsonSnafu)?;
103 if let Some(ref id) = res.id {
104 let tx = subscriptions.remove(id);
105 if let Some(tx) = tx {
106 let _ = tx.send(res);
107 }
108 }
109 Ok(())
110 };
111
112 loop {
113 select! {
114 msg = stream.try_next() => {
115 debug!("websocket received message: {:?}", msg);
116 let Ok(msg) = msg else {
117 break;
118 };
119 if let Some(Message::Text(s)) = msg {
120 print_error(on_stream(s.to_string(), &mut subscriptions));
121 }
122 },
123 msg = rx_ws_sink.recv() => {
124 debug!("writing message to websocket: {:?}", msg);
125 let Some(msg) = msg else {
126 break;
127 };
128 print_error(sink.send(msg).await);
129 },
130 subscription = rx_subscription.recv() => {
131 if let Some(subscription) = subscription {
132 subscriptions.insert(subscription.id, subscription.tx);
133 }
134 }
135 }
136 }
137}
138
139impl InnerClient {
140 pub(crate) async fn connect(url: &str, token: Option<&str>) -> Result<Self> {
141 let (tx_ws_sink, mut rx_ws_sink) = mpsc::channel(1);
142 let (tx_subscription, mut rx_subscription) = mpsc::channel(1);
143 let shutdown = Arc::new(Notify::new());
144 let (tx_notification, _) = broadcast::channel(1);
147
148 let inner = InnerClient {
149 tx_ws_sink,
150 id: AtomicI32::new(0),
151 token: token.map(|t| "token:".to_string() + t),
152 tx_subscription,
153 tx_notification: tx_notification.clone(),
154 shutdown: shutdown.clone(),
155 };
156
157 async fn connect_ws(url: &str) -> Result<WebSocket> {
158 debug!("connecting to {}", url);
159 let (ws, res) = tokio_tungstenite::connect_async(url)
160 .await
161 .context(error::WebsocketIoSnafu)?;
162 debug!("connected to {}, {:?}", url, res);
163 Ok(ws)
164 }
165
166 let ws = connect_ws(url).await?;
167 let url = url.to_string();
168 spawn(async move {
170 let mut ws = Some(ws);
171 loop {
172 if let Some(ws) = ws.take() {
173 let _ = tx_notification.send(Notification::WebSocketConnected);
174
175 let fut = process_ws(
176 ws,
177 &mut rx_ws_sink,
178 tx_notification.clone(),
179 &mut rx_subscription,
180 );
181
182 select! {
183 _ = fut => {},
184 _ = shutdown.notified() => {
185 return;
186 },
187 }
188
189 let _ = tx_notification.send(Notification::WebsocketClosed);
190 } else {
191 let r = select! {
192 r = connect_ws(&url) => r,
193 _ = shutdown.notified() => return,
194 };
195 match r {
196 Ok(ws_) => {
197 ws.replace(ws_);
198 }
199 Err(err) => {
200 info!("{}", err);
201 sleep(Duration::from_secs(3)).await;
202 }
203 }
204 }
205 }
206 });
207
208 Ok(inner)
209 }
210
211 fn id(&self) -> i32 {
212 self.id.fetch_add(1, Ordering::Relaxed)
213 }
214
215 async fn wait_for_id<T>(&self, id: i32, rx: oneshot::Receiver<RpcResponse>) -> Result<T>
216 where
217 T: DeserializeOwned + Send,
218 {
219 let res = rx.await.map_err(|err| {
220 error::WebsocketClosedSnafu {
221 message: format!("receiving response for id {}: {}", id, err),
222 }
223 .build()
224 })?;
225
226 if let Some(err) = res.error {
227 return Err(err).context(error::Aria2Snafu);
228 }
229
230 if let Some(v) = res.result {
231 Ok(serde_json::from_value::<T>(v).context(error::JsonSnafu)?)
232 } else {
233 error::ParseSnafu {
234 value: format!("{:?}", res),
235 to: "RpcResponse with result",
236 }
237 .fail()
238 }
239 }
240
241 pub async fn call(&self, id: i32, method: &str, mut params: Vec<Value>) -> Result<()> {
243 if let Some(ref token) = self.token {
244 params.insert(0, Value::String(token.clone()))
245 }
246 let req = RpcRequest {
247 id: Some(id),
248 jsonrpc: "2.0".to_string(),
249 method: "aria2.".to_string() + method,
250 params,
251 };
252 self.tx_ws_sink
253 .send(Message::Text(
254 serde_json::to_string(&req)
255 .context(error::JsonSnafu)?
256 .into(),
257 ))
258 .await
259 .expect("tx_ws_sink: receiver has been dropped");
260 Ok(())
261 }
262
263 pub async fn call_and_wait<T>(&self, method: &str, params: Vec<Value>) -> Result<T>
265 where
266 T: DeserializeOwned + Send,
267 {
268 let id = self.id();
269 let (tx, rx) = oneshot::channel();
270 self.tx_subscription
271 .send(Subscription { id, tx })
272 .await
273 .expect("tx_subscription: receiver has been closed");
274
275 self.call(id, method, params).await?;
276 self.wait_for_id::<T>(id, rx).await
277 }
278
279 pub fn subscribe_notifications(&self) -> broadcast::Receiver<Notification> {
283 self.tx_notification.subscribe()
284 }
285}
286
287impl Client {
288 pub async fn connect(url: &str, token: Option<&str>) -> Result<Self> {
313 let inner = Arc::new(InnerClient::connect(url, token).await?);
314
315 let weak = Arc::downgrade(&inner);
316 let rx_notification = inner.subscribe_notifications();
317 let (tx_callback, rx_callback) = mpsc::unbounded_channel();
318 spawn(callback_worker(weak, rx_notification, rx_callback));
320
321 Ok(Self { inner, tx_callback })
322 }
323
324 pub(crate) fn add_callbacks(&self, gid: String, callbacks: Callbacks) {
325 if callbacks.is_empty() {
326 return;
327 }
328 self.tx_callback
329 .send(TaskCallbacks { gid, callbacks })
330 .expect("tx_callback: receiver has been dropped");
331 }
332}
333
334impl Deref for Client {
335 type Target = InnerClient;
336
337 fn deref(&self) -> &Self::Target {
338 &self.inner
339 }
340}