use std::collections::{BinaryHeap, HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore, broadcast};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use crate::download::config::speed_profile::SpeedProfile;
pub use crate::download::types::{DownloadPriority, DownloadStatus, ManagerConfig, ProgressUpdate};
use crate::download::types::{DownloadTask, ProgressCallback, ProgressCounters};
use crate::download::worker::{
WorkerContext, build_progress_callback, emit_bus_event, prepare_task_fetcher, run_download_task,
};
use crate::error::Result;
use crate::model::format::HttpHeaders;
/// Priority-queued download manager: tasks are enqueued into a [`BinaryHeap`],
/// drained by a single background worker, and executed with bounded
/// concurrency via a semaphore. Progress and completion are fanned out over
/// broadcast channels.
pub struct DownloadManager {
    // Immutable runtime configuration (concurrency, segmenting, retries, ...).
    config: ManagerConfig,
    // Shared HTTP client reused by all download tasks.
    client: Arc<reqwest::Client>,
    // Pending tasks ordered by priority (max-heap by DownloadTask's Ord impl).
    queue: Arc<Mutex<BinaryHeap<DownloadTask>>>,
    // Caps concurrent downloads at config.max_concurrent_downloads.
    semaphore: Arc<Semaphore>,
    // Monotonic id generator for download ids.
    next_id: Arc<Mutex<u64>>,
    // Last known status per download id.
    statuses: Arc<Mutex<HashMap<u64, DownloadStatus>>>,
    // Join handles of in-flight downloads, keyed by id (used for abort).
    tasks: Arc<Mutex<HashMap<u64, JoinHandle<Result<()>>>>>,
    // Ids the user cancelled; checked before a queued task is started.
    cancelled: Arc<Mutex<HashSet<u64>>>,
    // Broadcasts (id, terminal status) so wait_for_completion() can observe.
    completion_tx: broadcast::Sender<(u64, DownloadStatus)>,
    // Broadcasts per-download progress updates for progress_stream[_all]().
    progress_tx: broadcast::Sender<ProgressUpdate>,
    // Optional application-wide event bus for lifecycle events.
    event_bus: Option<crate::events::EventBus>,
    // Shared atomic (downloaded, total) byte counters per download id.
    progress_counters: ProgressCounters,
    // Wakes the worker loop when new work is enqueued.
    worker_notify: Arc<tokio::sync::Notify>,
    // Ensures the worker loop is spawned at most once.
    worker_started: Arc<AtomicBool>,
    // Cancels the worker loop on shutdown/drop.
    shutdown_token: CancellationToken,
}
impl DownloadManager {
/// Number of parallel segments used for segmented downloads.
pub fn parallel_segments(&self) -> usize {
    self.config.parallel_segments
}
/// Size in bytes of each download segment.
pub fn segment_size(&self) -> usize {
    self.config.segment_size
}
/// How many times a failed download is retried.
pub fn retry_attempts(&self) -> usize {
    self.config.retry_attempts
}
/// Configured speed/throttling profile.
pub fn speed_profile(&self) -> SpeedProfile {
    self.config.speed_profile
}
/// Creates a manager with [`ManagerConfig::default`] and no event bus.
pub fn new() -> Self {
    tracing::debug!("⚙️ Creating download manager with default config");
    Self::with_config(ManagerConfig::default())
}
/// Creates a manager with the given configuration and no event bus.
pub fn with_config(config: ManagerConfig) -> Self {
    tracing::debug!(config = %config, "⚙️ Creating download manager with config");
    Self::with_config_and_event_bus(config, None)
}
/// Creates a manager from `config`, optionally wiring an [`crate::events::EventBus`]
/// that will receive download lifecycle events.
pub fn with_config_and_event_bus(config: ManagerConfig, event_bus: Option<crate::events::EventBus>) -> Self {
    tracing::debug!(config = %config, has_event_bus = event_bus.is_some(), "⚙️ Initializing download manager");
    // Shared HTTP client plus the two broadcast channels used to fan out
    // completion notifications (capacity 100) and progress updates (capacity 1000).
    let client = Self::build_shared_client(&config);
    let (completion_tx, _completion_rx) = broadcast::channel(100);
    let (progress_tx, _progress_rx) = broadcast::channel(1000);
    let max_concurrent = config.max_concurrent_downloads;
    Self {
        config: config.clone(),
        client,
        semaphore: Arc::new(Semaphore::new(max_concurrent)),
        queue: Arc::new(Mutex::new(BinaryHeap::new())),
        next_id: Arc::new(Mutex::new(0)),
        statuses: Arc::new(Mutex::new(HashMap::new())),
        tasks: Arc::new(Mutex::new(HashMap::new())),
        cancelled: Arc::new(Mutex::new(HashSet::new())),
        completion_tx,
        progress_tx,
        event_bus,
        progress_counters: Arc::new(std::sync::Mutex::new(HashMap::new())),
        worker_notify: Arc::new(tokio::sync::Notify::new()),
        worker_started: Arc::new(AtomicBool::new(false)),
        shutdown_token: CancellationToken::new(),
    }
}
/// Shared HTTP client used for all downloads.
pub fn client(&self) -> &Arc<reqwest::Client> {
    &self.client
}
fn build_shared_client(config: &ManagerConfig) -> Arc<reqwest::Client> {
let default_headers = config
.user_agent
.as_ref()
.map(|ua| crate::model::format::HttpHeaders::browser_defaults(ua.clone()).to_header_map());
let http_config = crate::utils::http::HttpClientConfig {
proxy: config.proxy.as_ref(),
user_agent: config.user_agent.clone(),
default_headers,
http2_adaptive_window: true,
..Default::default()
};
crate::utils::http::build_http_client(http_config).unwrap_or_else(|e| {
tracing::warn!(error = %e, "Failed to build configured HTTP client, falling back to default");
Arc::new(reqwest::Client::new())
})
}
/// Enqueues a download of `url` to `destination` and returns its id.
/// `priority` defaults to [`DownloadPriority::Normal`] when `None`.
pub async fn enqueue(
    &self,
    url: impl AsRef<str>,
    destination: impl Into<PathBuf>,
    priority: Option<DownloadPriority>,
) -> u64 {
    self.enqueue_internal(
        url.as_ref().to_string(),
        destination.into(),
        priority.unwrap_or(DownloadPriority::Normal),
        None, // no progress callback
        None, // no extra headers
        None, // no byte-range constraint
    )
    .await
}
/// Like [`Self::enqueue`], but sends `http_headers` with the request.
pub async fn enqueue_with_headers(
    &self,
    url: impl AsRef<str>,
    destination: impl Into<PathBuf>,
    priority: Option<DownloadPriority>,
    http_headers: Option<crate::model::format::HttpHeaders>,
) -> u64 {
    self.enqueue_internal(
        url.as_ref().to_string(),
        destination.into(),
        priority.unwrap_or(DownloadPriority::Normal),
        None, // no progress callback
        http_headers,
        None, // no byte-range constraint
    )
    .await
}
/// Like [`Self::enqueue`], but invokes `progress_callback(downloaded, total)`
/// as bytes arrive.
pub async fn enqueue_with_progress<F>(
    &self,
    url: impl AsRef<str>,
    destination: impl Into<PathBuf>,
    priority: Option<DownloadPriority>,
    progress_callback: F,
) -> u64
where
    F: Fn(u64, u64) + Send + Sync + 'static,
{
    self.enqueue_internal(
        url.as_ref().to_string(),
        destination.into(),
        priority.unwrap_or(DownloadPriority::Normal),
        Some(Arc::new(progress_callback)),
        None, // no extra headers
        None, // no byte-range constraint
    )
    .await
}
/// Combines [`Self::enqueue_with_progress`] and [`Self::enqueue_with_headers`]:
/// progress callback plus custom request headers.
pub async fn enqueue_with_progress_and_headers<F>(
    &self,
    url: impl AsRef<str>,
    destination: impl Into<PathBuf>,
    priority: Option<DownloadPriority>,
    progress_callback: F,
    http_headers: Option<HttpHeaders>,
) -> u64
where
    F: Fn(u64, u64) + Send + Sync + 'static,
{
    self.enqueue_internal(
        url.as_ref().to_string(),
        destination.into(),
        priority.unwrap_or(DownloadPriority::Normal),
        Some(Arc::new(progress_callback)),
        http_headers,
        None, // no byte-range constraint
    )
    .await
}
/// Enqueues a ranged download restricted to bytes `byte_start..=byte_end`
/// (passed through to the fetcher as a range constraint).
pub async fn enqueue_range(
    &self,
    url: impl AsRef<str>,
    destination: impl Into<PathBuf>,
    byte_start: u64,
    byte_end: u64,
    priority: Option<DownloadPriority>,
    http_headers: Option<HttpHeaders>,
) -> u64 {
    self.enqueue_internal(
        url.as_ref().to_string(),
        destination.into(),
        priority.unwrap_or(DownloadPriority::Normal),
        None, // no progress callback
        http_headers,
        Some((byte_start, byte_end)),
    )
    .await
}
/// Returns the current status of download `id`, or `None` if unknown.
///
/// For in-flight downloads the stored `Downloading` placeholder is refreshed
/// from the live atomic progress counters so callers see up-to-date byte counts.
pub async fn get_status(&self, id: u64) -> Option<DownloadStatus> {
    tracing::debug!(download_id = id, "⚙️ Getting download status");
    let statuses = self.statuses.lock().await;
    let status = statuses.get(&id)?;
    if matches!(status, DownloadStatus::Downloading { .. }) {
        // Recover from a poisoned lock instead of panicking, consistent with
        // cleanup_finished(): the counters are plain atomics, so the data is
        // still valid even if a holder panicked.
        let counters = self.progress_counters.lock().unwrap_or_else(|e| e.into_inner());
        if let Some((dl, total)) = counters.get(&id) {
            return Some(DownloadStatus::Downloading {
                downloaded_bytes: dl.load(AtomicOrdering::Relaxed),
                total_bytes: total.load(AtomicOrdering::Relaxed),
            });
        }
    }
    Some(status.clone())
}
pub async fn cleanup_finished(&self) {
tracing::debug!("⚙️ Cleaning up finished downloads");
let mut statuses = self.statuses.lock().await;
let mut cancelled = self.cancelled.lock().await;
let ids_to_remove: Vec<u64> = statuses
.iter()
.filter_map(|(id, status)| match status {
DownloadStatus::Completed | DownloadStatus::Failed { .. } | DownloadStatus::Canceled => Some(*id),
_ => None,
})
.collect();
for id in &ids_to_remove {
statuses.remove(id);
cancelled.remove(id);
}
drop(statuses);
drop(cancelled);
let mut tasks = self.tasks.lock().await;
let mut counters = self.progress_counters.lock().unwrap_or_else(|e| e.into_inner());
for id in &ids_to_remove {
tasks.remove(id);
counters.remove(id);
}
}
/// Cancels download `id`.
///
/// Marks the id as cancelled, drops its progress counters, then either aborts
/// the in-flight task or removes it from the pending queue. Returns `true`
/// when something was actually cancelled, `false` if the id was unknown or
/// already finished.
pub async fn cancel(&self, id: u64) -> bool {
    tracing::debug!(download_id = id, "📥 Cancelling download");
    self.cancelled.lock().await.insert(id);
    // Recover from a poisoned lock instead of panicking, consistent with
    // cleanup_finished(); a panic here would leave the download half-cancelled.
    self.progress_counters
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .remove(&id);
    // Take the handle out under the lock, abort outside of it.
    let task_handle = { self.tasks.lock().await.remove(&id) };
    if let Some(handle) = task_handle {
        handle.abort();
        self.mark_cancelled_and_emit(id, "Cancelled by user").await;
        return true;
    }
    // Not running yet: try to pull it out of the pending queue instead.
    if self.remove_from_queue(id).await {
        self.mark_cancelled_and_emit(id, "Cancelled before download started")
            .await;
        return true;
    }
    false
}
/// Blocks until download `id` reaches a terminal state and returns it, or
/// `None` if the completion channel closes before that happens.
pub async fn wait_for_completion(&self, id: u64) -> Option<DownloadStatus> {
    tracing::debug!(download_id = id, "📥 Waiting for download completion");
    // Subscribe BEFORE checking the current status so a completion that lands
    // in between is still observed on the channel (no lost-wakeup window).
    let mut rx = self.completion_tx.subscribe();
    if let Some(status) = self.get_status(id).await
        && is_terminal_status(&status)
    {
        return Some(status);
    }
    loop {
        match rx.recv().await {
            // Message for our download: return only if it is terminal.
            Ok((download_id, status)) if download_id == id => {
                if is_terminal_status(&status) {
                    return Some(status);
                }
                continue;
            }
            // Message for a different download: keep waiting.
            Ok(_) => continue,
            // Receiver lagged and dropped messages — we may have missed our
            // completion, so fall back to polling the status map once.
            Err(broadcast::error::RecvError::Lagged(_)) => {
                if let Some(status) = self.get_status(id).await {
                    if is_terminal_status(&status) {
                        return Some(status);
                    }
                    continue;
                } else {
                    return None;
                }
            }
            // Channel closed: no more completions will ever arrive.
            Err(_) => return None,
        }
    }
}
/// Returns a stream of progress updates for download `id` only.
/// Broadcast lag errors are silently skipped.
pub fn progress_stream(&self, id: u64) -> impl Stream<Item = ProgressUpdate> + Send + 'static {
    tracing::debug!(download_id = id, "📥 Subscribing to progress stream");
    let receiver = self.progress_tx.subscribe();
    BroadcastStream::new(receiver)
        .filter_map(move |item| item.ok().filter(|update| update.download_id == id))
}
/// Returns a stream of progress updates for every download.
/// Broadcast lag errors are silently skipped.
pub fn progress_stream_all(&self) -> impl Stream<Item = ProgressUpdate> + Send + 'static {
    tracing::debug!("📥 Subscribing to all progress streams");
    let receiver = self.progress_tx.subscribe();
    BroadcastStream::new(receiver).filter_map(|item| match item {
        Ok(update) => Some(update),
        Err(_) => None,
    })
}
/// Forwards `event` to the event bus, if one was configured; otherwise a no-op.
fn emit_event(&self, event: crate::events::DownloadEvent) {
    tracing::trace!(event = ?event, "🔔 Emitting download event");
    if let Some(ref bus) = self.event_bus {
        bus.emit(event);
    }
}
/// Records download `id` as cancelled, notifies completion waiters, and emits
/// a `DownloadCanceled` event with the given human-readable `reason`.
async fn mark_cancelled_and_emit(&self, id: u64, reason: &str) {
    let mut statuses = self.statuses.lock().await;
    statuses.insert(id, DownloadStatus::Canceled);
    drop(statuses);
    // Broadcast the terminal status so wait_for_completion() callers wake up.
    // Without this, cancellation only updated the status map and anyone already
    // blocked in wait_for_completion() would never receive a message for this
    // id (the aborted task can't send one) and could wait forever.
    let _ = self.completion_tx.send((id, DownloadStatus::Canceled));
    self.emit_event(crate::events::DownloadEvent::DownloadCanceled {
        download_id: id,
        reason: reason.to_string(),
    });
}
async fn remove_from_queue(&self, id: u64) -> bool {
let mut queue = self.queue.lock().await;
let len_before = queue.len();
let mut new_queue = BinaryHeap::new();
for task in queue.drain() {
if task.id != id {
new_queue.push(task);
}
}
*queue = new_queue;
len_before > queue.len()
}
/// Common enqueue path used by every public `enqueue*` method.
///
/// Allocates a fresh id, pushes the task onto the priority queue, records a
/// `Queued` status, emits a `DownloadQueued` event, and kicks the worker.
/// Returns the new download id.
async fn enqueue_internal(
    &self,
    url: String,
    destination: PathBuf,
    priority: DownloadPriority,
    progress_callback: Option<ProgressCallback>,
    http_headers: Option<crate::model::format::HttpHeaders>,
    range_constraint: Option<(u64, u64)>,
) -> u64 {
    // Allocate the id under the lock, then release it immediately.
    let mut id_guard = self.next_id.lock().await;
    let id = *id_guard;
    *id_guard += 1;
    drop(id_guard);
    let task = DownloadTask {
        url: url.clone(),
        destination: destination.clone(),
        priority,
        id,
        progress_callback,
        http_headers,
        range_constraint,
    };
    tracing::debug!(id = id, url = url, destination = ?destination, priority = ?priority, "📥 Enqueuing download");
    // Queue and status are updated in separate short-lived lock scopes so the
    // two locks are never held at the same time.
    {
        let mut queue = self.queue.lock().await;
        queue.push(task);
    }
    {
        let mut statuses = self.statuses.lock().await;
        statuses.insert(id, DownloadStatus::Queued);
    }
    self.emit_event(crate::events::DownloadEvent::DownloadQueued {
        download_id: id,
        url,
        priority,
        output_path: destination,
    });
    // Wake the worker (Notify stores the permit if the worker isn't waiting yet)
    // and lazily spawn it on first use.
    self.worker_notify.notify_one();
    self.ensure_worker();
    // Amortized housekeeping: every 100th id, drop bookkeeping for finished
    // downloads once the status map grows past the configured threshold.
    if id % 100 == 0 {
        let status_count = {
            let statuses = self.statuses.lock().await;
            statuses.len()
        };
        if status_count > self.config.cleanup_threshold {
            self.cleanup_finished().await;
        }
    }
    id
}
/// Spawns the background queue worker exactly once.
///
/// The compare-exchange on `worker_started` guarantees that only the first
/// caller proceeds; every later call returns immediately.
fn ensure_worker(&self) {
    if self
        .worker_started
        .compare_exchange(false, true, AtomicOrdering::AcqRel, AtomicOrdering::Acquire)
        .is_err()
    {
        // Another caller already won the race and spawned the worker.
        return;
    }
    tracing::debug!(
        max_concurrent = self.config.max_concurrent_downloads,
        "⚙️ Starting download queue worker"
    );
    // Hand the worker clones of every shared handle it needs; the loop itself
    // owns no &self and outlives this call.
    let ctx = WorkerLoopCtx {
        queue: self.queue.clone(),
        semaphore: self.semaphore.clone(),
        statuses: self.statuses.clone(),
        tasks: self.tasks.clone(),
        config: self.config.clone(),
        cancelled: self.cancelled.clone(),
        completion_tx: self.completion_tx.clone(),
        progress_tx: self.progress_tx.clone(),
        event_bus: self.event_bus.clone(),
        notify: self.worker_notify.clone(),
        progress_counters: self.progress_counters.clone(),
        shared_client: Arc::clone(&self.client),
        shutdown: self.shutdown_token.clone(),
    };
    tokio::spawn(run_worker_loop(ctx));
}
/// Signals the worker loop to stop; idempotent.
pub fn shutdown(&self) {
    self.shutdown_token.cancel();
}
}
impl std::fmt::Debug for DownloadManager {
    // Manual impl: most fields (channels, handles, callbacks) are not Debug.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DownloadManager")
            .field("config", &self.config)
            // Surfaced separately for convenience even though it is part of config.
            .field("max_concurrent_downloads", &self.config.max_concurrent_downloads)
            .finish_non_exhaustive()
    }
}
impl Default for DownloadManager {
    /// Equivalent to [`DownloadManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Drop for DownloadManager {
    /// Best-effort teardown: stop the worker and abort in-flight downloads.
    fn drop(&mut self) {
        self.shutdown_token.cancel();
        // try_lock because Drop cannot await; if the lock is contended we skip
        // aborting here and rely on the cancelled shutdown token instead.
        if let Ok(tasks) = self.tasks.try_lock() {
            // values() — the keys were previously iterated and discarded.
            for handle in tasks.values() {
                handle.abort();
            }
        }
    }
}
/// True once a download can no longer change state
/// (completed, failed, or cancelled).
fn is_terminal_status(status: &DownloadStatus) -> bool {
    matches!(
        status,
        DownloadStatus::Completed | DownloadStatus::Failed { .. } | DownloadStatus::Canceled
    )
}
/// Everything the background worker loop needs, cloned out of the manager so
/// the loop owns its state and can outlive any single `&DownloadManager`.
#[derive(Debug)]
struct WorkerLoopCtx {
    // Shared priority queue of pending tasks.
    queue: Arc<Mutex<BinaryHeap<DownloadTask>>>,
    // Bounds the number of concurrently running downloads.
    semaphore: Arc<Semaphore>,
    // Status map shared with the manager.
    statuses: Arc<Mutex<HashMap<u64, DownloadStatus>>>,
    // Join handles for spawned download tasks.
    tasks: Arc<Mutex<HashMap<u64, JoinHandle<Result<()>>>>>,
    // Snapshot of the manager configuration.
    config: ManagerConfig,
    // Ids cancelled before their task started.
    cancelled: Arc<Mutex<HashSet<u64>>>,
    // Broadcasts terminal (id, status) pairs.
    completion_tx: broadcast::Sender<(u64, DownloadStatus)>,
    // Broadcasts progress updates.
    progress_tx: broadcast::Sender<ProgressUpdate>,
    // Optional application event bus.
    event_bus: Option<crate::events::EventBus>,
    // Signals that new work was enqueued.
    notify: Arc<tokio::sync::Notify>,
    // Shared per-download byte counters.
    progress_counters: ProgressCounters,
    // HTTP client reused across downloads.
    shared_client: Arc<reqwest::Client>,
    // Cooperative shutdown signal.
    shutdown: CancellationToken,
}
/// Background loop: drains the priority queue, dispatching each task once a
/// concurrency permit is available, then sleeps until notified or shut down.
async fn run_worker_loop(ctx: WorkerLoopCtx) {
    loop {
        if ctx.shutdown.is_cancelled() {
            tracing::debug!("🛑 Worker shutting down");
            return;
        }
        // Inner loop: drain the queue until it is empty.
        loop {
            // Acquire the permit BEFORE popping, so a task is only removed from
            // the queue when it can actually start running.
            let permit = match ctx.semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                // Semaphore closed — nothing can ever run again.
                Err(_) => return,
            };
            let Some(task) = ctx.queue.lock().await.pop() else {
                // Queue empty: return the unused permit and go back to waiting.
                drop(permit);
                break;
            };
            tracing::debug!(
                task_id = task.id,
                url = %task.url,
                destination = ?task.destination,
                priority = ?task.priority,
                "⚙️ Popped task from download queue"
            );
            let worker_ctx = WorkerContext {
                statuses: ctx.statuses.clone(),
                tasks: ctx.tasks.clone(),
                cancelled: ctx.cancelled.clone(),
                completion_tx: ctx.completion_tx.clone(),
                event_bus: ctx.event_bus.clone(),
                progress_counters: ctx.progress_counters.clone(),
            };
            process_queued_task(
                task,
                &ctx.config,
                &ctx.shared_client,
                permit,
                ctx.progress_tx.clone(),
                worker_ctx,
            )
            .await;
        }
        // Sleep until either new work arrives or shutdown is requested.
        // Notify stores a permit from notify_one(), so an enqueue that raced
        // with the drain above is not lost.
        tokio::select! {
            _ = ctx.notify.notified() => {}
            _ = ctx.shutdown.cancelled() => {
                tracing::debug!("🛑 Worker shutting down");
                return;
            }
        }
    }
}
/// Starts one queued task: marks it `Downloading`, builds its fetcher, wires
/// progress reporting, and spawns the actual download. The semaphore `permit`
/// is moved into the spawned task so it is released when the download ends.
async fn process_queued_task(
    task: DownloadTask,
    config: &ManagerConfig,
    client: &Arc<reqwest::Client>,
    permit: OwnedSemaphorePermit,
    progress_tx: broadcast::Sender<ProgressUpdate>,
    ctx: WorkerContext,
) {
    // Cancelled while still queued: drop silently (cancel() already updated
    // the status and emitted the event). Dropping `permit` frees the slot.
    if ctx.cancelled.lock().await.contains(&task.id) {
        return;
    }
    // Placeholder status; live byte counts come from the progress counters.
    ctx.statuses.lock().await.insert(
        task.id,
        DownloadStatus::Downloading {
            downloaded_bytes: 0,
            total_bytes: 0,
        },
    );
    emit_bus_event(
        &ctx.event_bus,
        crate::events::DownloadEvent::DownloadStarted {
            download_id: task.id,
            url: task.url.clone(),
            total_bytes: 0, // unknown until the fetcher reports it
            format_id: None,
        },
    );
    let fetcher = match prepare_task_fetcher(&task, config, client) {
        Ok(f) => f,
        Err(e) => {
            // Fetcher construction failed: record failure, notify both the
            // event bus and any wait_for_completion() callers, and bail.
            let reason = e.to_string();
            ctx.statuses
                .lock()
                .await
                .insert(task.id, DownloadStatus::Failed { reason: reason.clone() });
            emit_bus_event(
                &ctx.event_bus,
                crate::events::DownloadEvent::DownloadFailed {
                    download_id: task.id,
                    url: task.url.clone(),
                    error: reason.clone(),
                    retry_count: 0,
                },
            );
            let _ = ctx.completion_tx.send((task.id, DownloadStatus::Failed { reason }));
            return;
        }
    };
    // Route progress through the shared counters, broadcast channel, event bus,
    // and the caller-provided callback (if any).
    let fetcher = fetcher.with_progress_callback(build_progress_callback(
        task.id,
        &ctx.progress_counters,
        progress_tx,
        ctx.event_bus.clone(),
        task.progress_callback.clone(),
    ));
    let task_id = task.id;
    // Clone the tasks map handle before `ctx` is moved into the spawned future.
    let tasks = ctx.tasks.clone();
    let handle = tokio::spawn(run_download_task(
        task_id,
        task.url.clone(),
        task.destination.clone(),
        fetcher,
        permit,
        ctx,
    ));
    // Store the handle so cancel() can abort the running download.
    tasks.lock().await.insert(task_id, handle);
}