1use futures_util::{stream::SplitSink, SinkExt, StreamExt};
2use tokio::sync::{mpsc, oneshot};
3use tokio::time::{timeout, Duration};
4use tokio_tungstenite::tungstenite;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use crate::call::Call;
9use crate::error::Error;
10use crate::jsonrpc;
11use crate::Result;
12
/// A single websocket message as read from / written to the connection.
type WSMessage = tokio_tungstenite::tungstenite::Message;
/// The websocket connection type produced by `tokio_tungstenite::connect_async` in this file.
type WSStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
16
/// An event pushed by aria2 over the websocket, carrying the GID of the affected download.
#[derive(Debug, Clone)]
pub enum Notification {
    /// `aria2.onDownloadStart`
    DownloadStart(String),
    /// `aria2.onDownloadPause`
    DownloadPause(String),
    /// `aria2.onDownloadStop`
    DownloadStop(String),
    /// `aria2.onDownloadComplete`
    DownloadComplete(String),
    /// `aria2.onDownloadError`
    DownloadError(String),
    /// `aria2.onBtDownloadComplete`
    BtDownloadComplete(String),
}

impl Notification {
    /// Builds a notification from an aria2 notification method name and a download GID.
    ///
    /// # Panics
    ///
    /// Panics if `method` is not one of the six `aria2.on*` methods listed on the variants.
    /// Use [`Notification::try_new`] when the method name comes from the network.
    pub fn new(method: &str, gid: String) -> Self {
        Self::try_new(method, gid)
            .unwrap_or_else(|| panic!("unknown aria2 notification method: {method}"))
    }

    /// Builds a notification from an aria2 notification method name and a download GID.
    ///
    /// Returns `None` for an unrecognised `method` instead of panicking.
    pub fn try_new(method: &str, gid: String) -> Option<Self> {
        let notification = match method {
            "aria2.onDownloadStart" => Self::DownloadStart(gid),
            "aria2.onDownloadPause" => Self::DownloadPause(gid),
            "aria2.onDownloadStop" => Self::DownloadStop(gid),
            "aria2.onDownloadComplete" => Self::DownloadComplete(gid),
            "aria2.onDownloadError" => Self::DownloadError(gid),
            "aria2.onBtDownloadComplete" => Self::BtDownloadComplete(gid),
            _ => return None,
        };
        Some(notification)
    }
}
40
/// One entry of the `params` list of an incoming notification; only `gid` is kept.
#[derive(serde::Deserialize)]
struct NotificationParam {
    /// GID of the download the notification refers to.
    gid: String,
}
45
/// A call queued by `ClientInner::call` for the background task to send over the websocket.
struct RPCRequest {
    /// JSON-RPC `params`, as produced by `Call::to_params` and serialized to JSON.
    params: Option<serde_json::Value>,
    /// JSON-RPC method name.
    method: &'static str,
    /// Completed by the background task when a response with the matching id arrives.
    handler: oneshot::Sender<RPCReponse>,
}
51
/// Outcome of a JSON-RPC request, handed back to the waiting caller.
enum RPCReponse {
    /// The `result` of a successful response.
    Success(serde_json::Value),
    /// The `error` of a failed response.
    Error(jsonrpc::Error),
}
56
/// Connection settings for an aria2 JSON-RPC websocket endpoint.
pub struct ConnectionMeta {
    /// Endpoint URL; this is the only field used to build the websocket handshake request.
    pub url: String,
    /// RPC secret with the `token:` prefix applied, or `None` when no secret is configured.
    ///
    /// [`ConnectionMeta::new`] adds the prefix; constructing the struct directly does not.
    pub token: Option<String>,
}

impl ConnectionMeta {
    /// Builds connection settings from `url` and an optional raw RPC secret.
    ///
    /// A given secret is stored as `token:<secret>`.
    pub fn new(url: &str, token: Option<&str>) -> Self {
        Self {
            token: token.map(|secret| ["token:", secret].concat()),
            url: String::from(url),
        }
    }
}
70
71impl tungstenite::client::IntoClientRequest for &ConnectionMeta{
72 fn into_client_request(self) -> tungstenite::Result<tungstenite::handshake::client::Request> {
73 self.url.as_str().into_client_request()
75 }
76}
77
78#[derive(Clone)]
79pub struct Client {
80 inner: Arc<ClientInner>,
81}
82
83impl Deref for Client {
84 type Target = ClientInner;
85
86 fn deref(&self) -> &Self::Target {
87 &self.inner
88 }
89}
90
91impl Client {
92 pub async fn connect(meta: ConnectionMeta) -> Result<(Self, mpsc::UnboundedReceiver<Notification>)>{
93 let (inner, notify_rx) = ClientInner::connect(meta).await?;
94 let client = Client {
95 inner: Arc::new(inner),
96 };
97 Ok((client, notify_rx))
98 }
99}
100
/// State shared by all clones of a `Client`: a handle to the background websocket task.
pub struct ClientInner {
    /// Queue of requests consumed by the background task.
    message_tx: mpsc::Sender<RPCRequest>,
    /// Token handed to `Call::to_params` on every call, if configured.
    token: Option<String>,
    /// Never read. Dropping it closes the paired sender held by the background task, which
    /// is part of that task's shutdown condition.
    _drop_rx: oneshot::Receiver<()>,
}
106
107impl ClientInner {
108 async fn connect(
109 meta: ConnectionMeta,
110 ) -> Result<(Self, mpsc::UnboundedReceiver<Notification>)> {
111 let (ws, _) = tokio_tungstenite::connect_async(&meta)
112 .await
113 .map_err(Error::Connect)?;
114 let (message_tx, message_rx) = mpsc::channel(32);
115 let (notification_tx, notification_rx) = mpsc::unbounded_channel();
116 let (drop_tx, _drop_rx) = oneshot::channel();
117 let token = meta.token.clone();
118 tokio::spawn(Self::background(
119 ws,
120 meta,
121 message_rx,
122 drop_tx,
123 notification_tx,
124 ));
125 Ok((
126 Self {
127 message_tx,
128 token,
129 _drop_rx,
130 },
131 notification_rx,
132 ))
133 }
134
135 pub async fn call<C: Call>(&self, call: C) -> Result<C::Response> {
136 let (tx, rx) = oneshot::channel();
137
138 let method = call.method();
139 let params = match call.to_params(self.token.as_ref().map(AsRef::as_ref)) {
140 Some(params) => Some(serde_json::to_value(params).map_err(Error::Encode)?),
141 None => None,
142 };
143
144 tracing::debug!("call method: {}, params: {:?}", method, params);
145
146 let request = RPCRequest {
147 params,
148 method,
149 handler: tx,
150 };
151 self.message_tx
152 .send(request)
153 .await
154 .map_err(|_| Error::ChannelSend)?;
155 match rx.await.map_err(Error::ChannelRecv)? {
156 RPCReponse::Success(value) => {
157 serde_json::from_value(value).map_err(Error::Decode)
158 }
159 RPCReponse::Error(err) => Err(err.into()),
160 }
161 }
162
163 async fn background(
164 ws: WSStream,
165 meta: ConnectionMeta,
166 mut message_rx: mpsc::Receiver<RPCRequest>,
167 mut drop_tx: oneshot::Sender<()>,
168 notification_tx: mpsc::UnboundedSender<Notification>,
169 ) {
170 let (mut ws_tx, mut ws_rx) = ws.split();
171 let mut shutdown = tokio::spawn({
172 let notification_tx = notification_tx.clone();
173 async move {
174 tokio::join!(drop_tx.closed(), notification_tx.closed());
175 }
176 });
177
178 let mut request_id = 1i64;
179 let mut pending_requests = std::collections::HashMap::new();
180
181 loop {
182 loop {
183 if notification_tx.is_closed() && message_rx.is_closed() {
184 tracing::info!("background task shutdown");
185 return;
186 }
187 tokio::select! {
188 _ = &mut shutdown => {
189 tracing::info!("background task shutdown");
190 return;
191 }
192 Some(msg) = message_rx.recv() => {
193 request_id += 1;
194 pending_requests.insert(request_id, msg.handler);
195
196 if let Err(e) = timeout(
197 Duration::from_secs(10),
198 Self::send_request(&mut ws_tx, request_id, msg.method, msg.params,)
199 ).await {
200 tracing::error!("send request error: {e}");
201 break;
202 }
203 }
204 Some(msg) = ws_rx.next() => {
205 let text = match msg {
206 Ok(WSMessage::Text(text)) => text,
207 Ok(WSMessage::Close(_)) => {
208 tracing::info!("websocket closed");
209 break;
210 }
211 Ok(_) => {
212 continue;
213 }
214 Err(e) => {
215 tracing::error!("websocket error: {e}");
216 break;
217 }
218 };
219 Self::handle_response(&text, &mut pending_requests, notification_tx.clone());
220 }
221 }
222 }
223 pending_requests.clear();
224
225 loop {
227 if notification_tx.is_closed() && message_rx.is_closed() {
228 tracing::info!("background task shutdown");
229 return;
230 }
231 match timeout(
232 Duration::from_secs(10),
233 tokio_tungstenite::connect_async(&meta),
234 )
235 .await
236 {
237 Err(e) => {
238 tracing::error!("reconnect error: {e}, will retry in 10 seconds");
239 tokio::time::sleep(Duration::from_secs(10)).await;
240 }
241 Ok(Err(e)) => {
242 tracing::error!("reconnect timeout: {e}, will retry in 10 seconds");
243 tokio::time::sleep(Duration::from_secs(10)).await;
244 }
245 Ok(Ok((new_ws, _))) => {
246 let (tx, rx) = new_ws.split();
247 ws_tx = tx;
248 ws_rx = rx;
249 break;
250 }
251 }
252 }
253 }
254 }
255
256 async fn send_request(
257 sink: &mut SplitSink<WSStream, WSMessage>,
258 id: i64,
259 method: &str,
260 params: Option<serde_json::Value>,
261 ) -> Result<()> {
262 let rpc_req = jsonrpc::Request {
263 id: Some(id),
264 jsonrpc: "2.0",
265 method,
266 params,
267 };
268 sink.send(WSMessage::Text(
269 serde_json::to_string(&rpc_req)
270 .map_err(Error::Encode)?
271 .into(),
272 ))
273 .await
274 .map_err(Error::Websocket)
275 }
276
277 fn handle_response(
278 text: &str,
279 pending_requests: &mut std::collections::HashMap<i64, oneshot::Sender<RPCReponse>>,
280 notification_tx: mpsc::UnboundedSender<Notification>,
281 ) {
282 if let Ok(resp) = serde_json::from_str::<
283 jsonrpc::Response<i64, serde_json::Value, Vec<NotificationParam>>,
284 >(text)
285 {
286 match resp {
287 jsonrpc::Response::Err { id, error } => {
288 if let Some(tx) = pending_requests.remove(&id) {
289 let _ = tx.send(RPCReponse::Error(error));
290 }
291 }
292 jsonrpc::Response::Resp { id, result } => {
293 if let Some(tx) = pending_requests.remove(&id) {
294 let _ = tx.send(RPCReponse::Success(result));
295 }
296 }
297 jsonrpc::Response::Notification { method, params } => {
298 tokio::spawn(async move {
299 let method = method;
300 for param in params {
301 if notification_tx.send(Notification::new(&method, param.gid)).is_err()
302 {
303 break;
304 }
305 }
306 });
307 }
308 }
309 }
310 }
311}