use std::{
collections::{hash_map::Entry, HashMap},
fmt,
sync::Weak,
time::Duration,
};
use futures::future::BoxFuture;
use log::{debug, info};
use serde::Deserialize;
use snafu::ResultExt;
use tokio::{
select, spawn,
sync::{broadcast, mpsc},
time::timeout,
};
use crate::{error, utils::print_error, Event, InnerClient, Notification, Result};
/// A single user callback: a boxed `'static` future that is handed to
/// `tokio::spawn` when its event is observed. Kept in an `Option` so the
/// dispatcher can `take()` it and run it at most once.
type Callback = Option<BoxFuture<'static, ()>>;

/// Callbacks attached to one aria2 task.
#[derive(Default)]
pub struct Callbacks {
    /// Spawned when aria2 reports the download as complete
    /// (`Complete` / `BtComplete` notification).
    pub on_download_complete: Callback,
    /// Spawned when aria2 reports the task as failed (`Error` notification).
    pub on_error: Callback,
}
impl Callbacks {
    /// Returns `true` when no callback is set, i.e. there is nothing left to
    /// run for this task.
    pub(crate) fn is_empty(&self) -> bool {
        !(self.on_download_complete.is_some() || self.on_error.is_some())
    }
}
impl fmt::Debug for Callbacks {
    /// Shows only whether each callback is present; the boxed futures
    /// themselves have no printable representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_complete = self.on_download_complete.is_some();
        let has_error = self.on_error.is_some();
        let mut builder = f.debug_struct("Callbacks");
        builder.field("on_download_complete", &has_complete);
        builder.field("on_error", &has_error);
        builder.finish()
    }
}
async fn on_reconnect(
inner: &InnerClient,
callbacks_map: &mut HashMap<String, Callbacks>,
) -> Result<()> {
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct TaskStatus {
status: String,
total_length: String,
completed_length: String,
gid: String,
}
if callbacks_map.is_empty() {
return Ok(());
}
let mut tasks = HashMap::new();
let req = inner.custom_tell_stopped(
0,
1000,
Some(
["status", "totalLength", "completedLength", "gid"]
.into_iter()
.map(|x| x.to_string())
.collect(),
),
);
for map in timeout(Duration::from_secs(10), req)
.await
.context(error::ReconnectTaskTimeoutSnafu)??
{
let task: TaskStatus =
serde_json::from_value(serde_json::Value::Object(map)).context(error::JsonSnafu)?;
tasks.insert(task.gid.clone(), task);
}
for (gid, callbacks) in callbacks_map {
if let Some(status) = tasks.get(gid) {
debug!("checking callbacks for gid {} after reconnected", gid);
if status.total_length == status.completed_length {
if let Some(h) = callbacks.on_download_complete.take() {
spawn(h);
}
} else if status.status == "error" {
if let Some(h) = callbacks.on_error.take() {
spawn(h);
}
}
}
}
Ok(())
}
fn invoke_callbacks_on_event(event: Event, callbacks: &mut Callbacks) -> bool {
match event {
Event::Complete | Event::BtComplete => {
if let Some(callback) = callbacks.on_download_complete.take() {
spawn(callback);
}
}
Event::Error => {
if let Some(callback) = callbacks.on_error.take() {
spawn(callback);
}
}
_ => return false,
}
true
}
/// A registration message consumed by `callback_worker`: attach `callbacks`
/// to the aria2 task identified by `gid`.
#[derive(Debug)]
pub(crate) struct TaskCallbacks {
    /// aria2 GID of the task the callbacks belong to.
    pub gid: String,
    /// The callbacks to run for that task.
    pub callbacks: Callbacks,
}
pub(crate) async fn callback_worker(
weak: Weak<InnerClient>,
mut rx_notification: broadcast::Receiver<Notification>,
mut rx_callback: mpsc::UnboundedReceiver<TaskCallbacks>,
) {
use broadcast::error::RecvError;
let mut is_first_notification = true;
let mut callbacks_map = HashMap::new();
let mut yet_processed_notifications: HashMap<String, Vec<Event>> = HashMap::new();
loop {
select! {
r = rx_notification.recv() => {
match r {
Ok(notification) => {
match notification {
Notification::WebSocketConnected => {
if is_first_notification {
is_first_notification = false;
continue;
}
if let Some(inner) = weak.upgrade() {
print_error(on_reconnect(inner.as_ref(), &mut callbacks_map).await);
}
},
Notification::Aria2 { gid, event } => {
match callbacks_map.entry(gid.clone()) {
Entry::Occupied(mut e) => {
let invoked = invoke_callbacks_on_event(event, e.get_mut());
if invoked {
e.remove();
}
}
_ => {
yet_processed_notifications
.entry(gid.clone())
.or_insert_with(Vec::new)
.push(event);
}
}
},
_ => {}
}
}
Err(RecvError::Closed) => {
return;
}
Err(RecvError::Lagged(_)) => {
info!("unexpected lag in notifications");
}
}
},
r = rx_callback.recv() => {
match r {
Some(TaskCallbacks { gid, mut callbacks }) => {
if let Some(events) = yet_processed_notifications.remove(&gid) {
let mut invoked = false;
for event in events {
invoked = invoke_callbacks_on_event(event, &mut callbacks);
if invoked {
break;
}
}
if !invoked {
callbacks_map.insert(gid, callbacks);
}
} else {
callbacks_map.insert(gid, callbacks);
}
}
None => {
return;
}
}
},
}
}
}