1use std::{
2 collections::{hash_map::Entry, HashMap},
3 fmt,
4 sync::Weak,
5 time::Duration,
6};
7
8use futures::future::BoxFuture;
9use log::{debug, info};
10use serde::Deserialize;
11use snafu::ResultExt;
12use tokio::{
13 select, spawn,
14 sync::{broadcast, mpsc},
15 time::timeout,
16};
17
18use crate::{error, utils::print_error, Event, InnerClient, Notification, Result};
19
20type Callback = Option<BoxFuture<'static, ()>>;
21
22#[derive(Default)]
32pub struct Callbacks {
33 pub on_download_complete: Callback,
35 pub on_error: Callback,
37}
38
39impl Callbacks {
40 pub(crate) fn is_empty(&self) -> bool {
41 self.on_download_complete.is_none() && self.on_error.is_none()
42 }
43}
44
45impl fmt::Debug for Callbacks {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("Callbacks")
48 .field("on_download_complete", &self.on_download_complete.is_some())
49 .field("on_error", &self.on_error.is_some())
50 .finish()
51 }
52}
53
54async fn on_reconnect(
56 inner: &InnerClient,
57 callbacks_map: &mut HashMap<String, Callbacks>,
58) -> Result<()> {
59 #[derive(Debug, Clone, Deserialize)]
61 #[serde(rename_all = "camelCase")]
62 struct TaskStatus {
63 status: String,
64 total_length: String,
65 completed_length: String,
66 gid: String,
67 }
68
69 if callbacks_map.is_empty() {
70 return Ok(());
71 }
72 let mut tasks = HashMap::new();
73 let req = inner.custom_tell_stopped(
74 0,
75 1000,
76 Some(
77 ["status", "totalLength", "completedLength", "gid"]
78 .into_iter()
79 .map(|x| x.to_string())
80 .collect(),
81 ),
82 );
83 for map in timeout(Duration::from_secs(10), req)
85 .await
86 .context(error::ReconnectTaskTimeoutSnafu)??
87 {
88 let task: TaskStatus =
89 serde_json::from_value(serde_json::Value::Object(map)).context(error::JsonSnafu)?;
90 tasks.insert(task.gid.clone(), task);
91 }
92
93 for (gid, callbacks) in callbacks_map {
94 if let Some(status) = tasks.get(gid) {
95 debug!("checking callbacks for gid {} after reconnected", gid);
96 if status.total_length == status.completed_length {
98 if let Some(h) = callbacks.on_download_complete.take() {
99 spawn(h);
100 }
101 } else if status.status == "error" {
102 if let Some(h) = callbacks.on_error.take() {
103 spawn(h);
104 }
105 }
106 }
107 }
108
109 Ok(())
110}
111
112fn invoke_callbacks_on_event(event: Event, callbacks: &mut Callbacks) -> bool {
113 match event {
114 Event::Complete | Event::BtComplete => {
115 if let Some(callback) = callbacks.on_download_complete.take() {
116 spawn(callback);
118 }
119 }
120 Event::Error => {
121 if let Some(callback) = callbacks.on_error.take() {
122 spawn(callback);
123 }
124 }
125 _ => return false,
126 }
127 true
128}
129
130#[derive(Debug)]
131pub(crate) struct TaskCallbacks {
132 pub gid: String,
133 pub callbacks: Callbacks,
134}
135
136pub(crate) async fn callback_worker(
137 weak: Weak<InnerClient>,
138 mut rx_notification: broadcast::Receiver<Notification>,
139 mut rx_callback: mpsc::UnboundedReceiver<TaskCallbacks>,
140) {
141 use broadcast::error::RecvError;
142
143 let mut is_first_notification = true;
144 let mut callbacks_map = HashMap::new();
145 let mut yet_processed_notifications: HashMap<String, Vec<Event>> = HashMap::new();
146
147 loop {
148 select! {
149 r = rx_notification.recv() => {
150 match r {
151 Ok(notification) => {
152 match notification {
153 Notification::WebSocketConnected => {
154 if is_first_notification {
155 is_first_notification = false;
156 continue;
157 }
159 if let Some(inner) = weak.upgrade() {
162 print_error(on_reconnect(inner.as_ref(), &mut callbacks_map).await);
163 }
164 },
165 Notification::Aria2 { gid, event } => {
166 match callbacks_map.entry(gid.clone()) {
167 Entry::Occupied(mut e) => {
168 let invoked = invoke_callbacks_on_event(event, e.get_mut());
169 if invoked {
170 e.remove();
171 }
172 }
173 _ => {
174 yet_processed_notifications
176 .entry(gid.clone())
177 .or_insert_with(Vec::new)
178 .push(event);
179 }
180 }
181 },
182 _ => {}
183 }
184 }
185 Err(RecvError::Closed) => {
186 return;
187 }
188 Err(RecvError::Lagged(_)) => {
189 info!("unexpected lag in notifications");
190 }
191 }
192 },
193 r = rx_callback.recv() => {
194 match r {
195 Some(TaskCallbacks { gid, mut callbacks }) => {
196 if let Some(events) = yet_processed_notifications.remove(&gid) {
197 let mut invoked = false;
198 for event in events {
199 invoked = invoke_callbacks_on_event(event, &mut callbacks);
200 if invoked {
201 break;
202 }
203 }
204 if !invoked {
205 callbacks_map.insert(gid, callbacks);
206 }
207 } else {
208 callbacks_map.insert(gid, callbacks);
209 }
210 }
211 None => {
212 return;
213 }
214 }
215 },
216 }
217 }
218}