aria2_ws/
callback.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    fmt,
4    sync::Weak,
5    time::Duration,
6};
7
8use futures::future::BoxFuture;
9use log::{debug, info};
10use serde::Deserialize;
11use snafu::ResultExt;
12use tokio::{
13    select, spawn,
14    sync::{broadcast, mpsc},
15    time::timeout,
16};
17
18use crate::{error, utils::print_error, Event, InnerClient, Notification, Result};
19
20type Callback = Option<BoxFuture<'static, ()>>;
21
22/// Callbacks that will be executed on notifications.
23///
24/// If the connection lost, all callbacks will be checked whether they need to be executed once reconnected.
25///
26/// It executes at most once for each task. That means a task can either be completed or failed.
27///
28/// If you need to customize the behavior, you can use `Client::subscribe_notifications`
29/// to receive notifications and handle them yourself,
30/// or use `tell_status` to check the status of the task.
31#[derive(Default)]
32pub struct Callbacks {
33    /// Will trigger on `Event::Complete` or `Event::BtComplete`.
34    pub on_download_complete: Callback,
35    /// Will trigger on `Event::Error`.
36    pub on_error: Callback,
37}
38
39impl Callbacks {
40    pub(crate) fn is_empty(&self) -> bool {
41        self.on_download_complete.is_none() && self.on_error.is_none()
42    }
43}
44
45impl fmt::Debug for Callbacks {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_struct("Callbacks")
48            .field("on_download_complete", &self.on_download_complete.is_some())
49            .field("on_error", &self.on_error.is_some())
50            .finish()
51    }
52}
53
54/// Check whether the callback is ready to be executed after reconnected.
55async fn on_reconnect(
56    inner: &InnerClient,
57    callbacks_map: &mut HashMap<String, Callbacks>,
58) -> Result<()> {
59    // Response from `custom_tell_stopped` call
60    #[derive(Debug, Clone, Deserialize)]
61    #[serde(rename_all = "camelCase")]
62    struct TaskStatus {
63        status: String,
64        total_length: String,
65        completed_length: String,
66        gid: String,
67    }
68
69    if callbacks_map.is_empty() {
70        return Ok(());
71    }
72    let mut tasks = HashMap::new();
73    let req = inner.custom_tell_stopped(
74        0,
75        1000,
76        Some(
77            ["status", "totalLength", "completedLength", "gid"]
78                .into_iter()
79                .map(|x| x.to_string())
80                .collect(),
81        ),
82    );
83    // Cancel if takes too long
84    for map in timeout(Duration::from_secs(10), req)
85        .await
86        .context(error::ReconnectTaskTimeoutSnafu)??
87    {
88        let task: TaskStatus =
89            serde_json::from_value(serde_json::Value::Object(map)).context(error::JsonSnafu)?;
90        tasks.insert(task.gid.clone(), task);
91    }
92
93    for (gid, callbacks) in callbacks_map {
94        if let Some(status) = tasks.get(gid) {
95            debug!("checking callbacks for gid {} after reconnected", gid);
96            // Check if the task is finished by checking the length.
97            if status.total_length == status.completed_length {
98                if let Some(h) = callbacks.on_download_complete.take() {
99                    spawn(h);
100                }
101            } else if status.status == "error" {
102                if let Some(h) = callbacks.on_error.take() {
103                    spawn(h);
104                }
105            }
106        }
107    }
108
109    Ok(())
110}
111
112fn invoke_callbacks_on_event(event: Event, callbacks: &mut Callbacks) -> bool {
113    match event {
114        Event::Complete | Event::BtComplete => {
115            if let Some(callback) = callbacks.on_download_complete.take() {
116                // Spawn a new task to avoid blocking the notification receiver.
117                spawn(callback);
118            }
119        }
120        Event::Error => {
121            if let Some(callback) = callbacks.on_error.take() {
122                spawn(callback);
123            }
124        }
125        _ => return false,
126    }
127    true
128}
129
130#[derive(Debug)]
131pub(crate) struct TaskCallbacks {
132    pub gid: String,
133    pub callbacks: Callbacks,
134}
135
136pub(crate) async fn callback_worker(
137    weak: Weak<InnerClient>,
138    mut rx_notification: broadcast::Receiver<Notification>,
139    mut rx_callback: mpsc::UnboundedReceiver<TaskCallbacks>,
140) {
141    use broadcast::error::RecvError;
142
143    let mut is_first_notification = true;
144    let mut callbacks_map = HashMap::new();
145    let mut yet_processed_notifications: HashMap<String, Vec<Event>> = HashMap::new();
146
147    loop {
148        select! {
149            r = rx_notification.recv() => {
150                match r {
151                    Ok(notification) => {
152                        match notification {
153                            Notification::WebSocketConnected => {
154                                if is_first_notification {
155                                    is_first_notification = false;
156                                    continue;
157                                    // Skip the first connected notification
158                                }
159                                // We might miss some notifications when the connection is lost.
160                                // So we need to check whether the callbacks need to be executed after reconnected.
161                                if let Some(inner) = weak.upgrade() {
162                                    print_error(on_reconnect(inner.as_ref(), &mut callbacks_map).await);
163                                }
164                            },
165                            Notification::Aria2 { gid, event } => {
166                                match callbacks_map.entry(gid.clone()) {
167                                    Entry::Occupied(mut e) => {
168                                        let invoked = invoke_callbacks_on_event(event, e.get_mut());
169                                        if invoked {
170                                            e.remove();
171                                        }
172                                    }
173                                    _ => {
174                                        // If the task is not in the map, we need to store it for possible later processing.
175                                        yet_processed_notifications
176                                            .entry(gid.clone())
177                                            .or_insert_with(Vec::new)
178                                            .push(event);
179                                    }
180                                }
181                            },
182                            _ => {}
183                        }
184                    }
185                    Err(RecvError::Closed) => {
186                        return;
187                    }
188                    Err(RecvError::Lagged(_)) => {
189                        info!("unexpected lag in notifications");
190                    }
191                }
192            },
193            r = rx_callback.recv() => {
194                match r {
195                    Some(TaskCallbacks { gid, mut callbacks }) => {
196                        if let Some(events) = yet_processed_notifications.remove(&gid) {
197                            let mut invoked = false;
198                            for event in events {
199                                invoked = invoke_callbacks_on_event(event, &mut callbacks);
200                                if invoked {
201                                    break;
202                                }
203                            }
204                            if !invoked {
205                                callbacks_map.insert(gid, callbacks);
206                            }
207                        } else {
208                            callbacks_map.insert(gid, callbacks);
209                        }
210                    }
211                    None => {
212                        return;
213                    }
214                }
215            },
216        }
217    }
218}