use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::sync::broadcast::error::RecvError;
use tokio::task::AbortHandle;
use tokio::time::{Instant, sleep};
use crate::cdp::core::CdpCore;
use crate::protocol::Connection;
use crate::{Error, Result};
/// Lifecycle state of a single browser download.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DownloadState {
    /// The browser is still receiving data.
    InProgress,
    /// The download finished successfully.
    Completed,
    /// The download was canceled before it finished.
    Canceled,
}

/// One download observed through CDP download events.
#[derive(Clone, Debug)]
pub struct DownloadMission {
    /// Browser-assigned identifier of the download (the CDP `guid`).
    pub guid: String,
    /// URL the download was requested from.
    pub url: String,
    /// File name suggested by the browser for this download.
    pub suggested_filename: String,
    /// Expected location of the downloaded file on disk.
    pub path: PathBuf,
    /// Most recently reported state.
    pub state: DownloadState,
    /// Bytes received so far.
    pub received_bytes: u64,
    /// Total size in bytes as reported by the browser; `0` until (or unless) reported.
    pub total_bytes: u64,
}
impl DownloadMission {
    /// Reports whether the download has reached a terminal state
    /// ([`DownloadState::Completed`] or [`DownloadState::Canceled`]).
    pub fn is_finished(&self) -> bool {
        match self.state {
            DownloadState::Completed | DownloadState::Canceled => true,
            DownloadState::InProgress => false,
        }
    }

    /// Reports whether the download completed successfully.
    pub fn succeeded(&self) -> bool {
        self.state == DownloadState::Completed
    }

    /// Number of bytes received so far, as last reported by the browser.
    pub fn downloaded_bytes(&self) -> u64 {
        self.received_bytes
    }

    /// Moves the downloaded file to `dest` and returns the destination path.
    ///
    /// Missing parent directories of `dest` are created first. A plain rename is
    /// tried; if it fails (for example when source and destination live on
    /// different filesystems) the file is copied instead and the source is removed
    /// on a best-effort basis.
    ///
    /// `self.path` is not updated, because this method only borrows `self`.
    ///
    /// # Errors
    /// Returns an error if the parent directory cannot be created, or if the
    /// rename fails and the fallback copy fails too.
    pub async fn save_as(&self, dest: impl AsRef<Path>) -> Result<PathBuf> {
        let target = PathBuf::from(dest.as_ref());
        match target.parent() {
            Some(dir) if !dir.as_os_str().is_empty() => {
                tokio::fs::create_dir_all(dir).await?;
            }
            _ => {}
        }
        let moved = tokio::fs::rename(&self.path, &target).await.is_ok();
        if !moved {
            tokio::fs::copy(&self.path, &target).await?;
            // The copy already succeeded; a leftover source file is not fatal.
            let _ = tokio::fs::remove_file(&self.path).await;
        }
        Ok(target)
    }
}
pub(crate) struct DownloadShared {
pub(crate) missions: Arc<Mutex<Vec<DownloadMission>>>,
pub(crate) new_returned: Arc<Mutex<HashSet<String>>>,
pub(crate) done_returned: Arc<Mutex<HashSet<String>>>,
pub(crate) running: bool,
pub(crate) abort: Option<AbortHandle>,
}
impl Default for DownloadShared {
fn default() -> Self {
Self {
missions: Arc::new(Mutex::new(Vec::new())),
new_returned: Arc::new(Mutex::new(HashSet::new())),
done_returned: Arc::new(Mutex::new(HashSet::new())),
running: false,
abort: None,
}
}
}
pub struct ChromiumDownloads {
core: Arc<CdpCore>,
}
impl ChromiumDownloads {
pub(crate) fn new(core: Arc<CdpCore>) -> Self {
Self { core }
}
pub async fn start(&self) -> Result<()> {
let dir = self.core.download_dir().ok_or_else(|| {
Error::Other(
"downloads(): 需先用 ChromiumOptions::download_path 或 tab.set_download_path 设置下载目录"
.into(),
)
})?;
self.stop().await?;
let _ = std::fs::create_dir_all(&dir);
self.core
.send(
"Browser.setDownloadBehavior",
json!({ "behavior": "allow", "downloadPath": dir.display().to_string(), "eventsEnabled": true }),
)
.await?;
let missions = {
let g = self.core.downloads.lock().await;
g.missions.lock().await.clear();
g.new_returned.lock().await.clear();
g.done_returned.lock().await.clear();
g.missions.clone()
};
let task = tokio::spawn(download_pump(
self.core.conn.clone(),
self.core.session_id.clone(),
dir,
missions,
));
let mut g = self.core.downloads.lock().await;
g.running = true;
g.abort = Some(task.abort_handle());
Ok(())
}
pub async fn listening(&self) -> bool {
self.core.downloads.lock().await.running
}
pub async fn missions(&self) -> Vec<DownloadMission> {
let m = self.core.downloads.lock().await.missions.clone();
m.lock().await.clone()
}
pub async fn wait_new(&self, timeout: Duration) -> Result<Option<DownloadMission>> {
self.ensure_active().await?;
let deadline = Instant::now() + timeout;
loop {
for m in self.missions().await {
let seen = self.core.downloads.lock().await.new_returned.clone();
if seen.lock().await.insert(m.guid.clone()) {
return Ok(Some(m));
}
}
if self.expired(deadline).await {
return Ok(None);
}
}
}
pub async fn wait_done(&self, timeout: Duration) -> Result<Option<DownloadMission>> {
self.ensure_active().await?;
let deadline = Instant::now() + timeout;
loop {
for m in self.missions().await {
if !m.succeeded() {
continue;
}
let returned = self.core.downloads.lock().await.done_returned.clone();
if returned.lock().await.insert(m.guid.clone()) {
return Ok(Some(m));
}
}
if self.expired(deadline).await {
return Ok(None);
}
}
}
pub async fn wait_count_done(
&self,
count: usize,
timeout: Duration,
) -> Result<Vec<DownloadMission>> {
let deadline = Instant::now() + timeout;
let mut out = Vec::with_capacity(count);
while out.len() < count {
let remain = deadline.saturating_duration_since(Instant::now());
if remain.is_zero() {
break;
}
match self.wait_done(remain).await? {
Some(m) => out.push(m),
None => break,
}
}
Ok(out)
}
pub async fn stop(&self) -> Result<()> {
let (abort, missions, new_r, done_r) = {
let mut g = self.core.downloads.lock().await;
g.running = false;
(
g.abort.take(),
g.missions.clone(),
g.new_returned.clone(),
g.done_returned.clone(),
)
};
missions.lock().await.clear();
new_r.lock().await.clear();
done_r.lock().await.clear();
if let Some(a) = abort {
a.abort();
}
Ok(())
}
async fn ensure_active(&self) -> Result<()> {
if self.core.downloads.lock().await.running {
Ok(())
} else {
Err(Error::Other("尚未调用 downloads().start()".into()))
}
}
async fn expired(&self, deadline: Instant) -> bool {
if Instant::now() >= deadline {
return true;
}
sleep(Duration::from_millis(60)).await;
false
}
}
/// Builds the background future that records download events for one session.
///
/// The event subscription is created eagerly, when this function is called, not
/// on the future's first poll. A broadcast receiver only sees messages sent after
/// it was created. Subscribing before the caller spawns the future therefore
/// closes the window in which an early download event could be missed.
///
/// Both the deprecated `Page.*` and the `Browser.*` download events are accepted.
/// Both are keyed by `guid`. A repeated `downloadWillBegin` for a known guid is
/// ignored, and progress events for unknown guids are dropped.
///
/// The future completes when the event channel is closed.
fn download_pump(
    conn: Connection,
    session_id: String,
    dir: PathBuf,
    missions: Arc<Mutex<Vec<DownloadMission>>>,
) -> impl std::future::Future<Output = ()> + Send + 'static {
    let mut events = conn.subscribe();
    async move {
        // Keep the connection handle alive for the pump's whole lifetime, as the
        // previous `async fn` form did by owning its arguments.
        let _conn = conn;
        loop {
            let ev = match events.recv().await {
                Ok(ev) => ev,
                // Skipped events cannot be recovered; keep consuming newer ones.
                Err(RecvError::Lagged(_)) => continue,
                Err(RecvError::Closed) => break,
            };
            if ev.session_id.as_deref() != Some(session_id.as_str()) {
                continue;
            }
            match ev.method.as_str() {
                "Page.downloadWillBegin" | "Browser.downloadWillBegin" => {
                    let guid = ev.params["guid"].as_str().unwrap_or_default().to_string();
                    if guid.is_empty() {
                        continue;
                    }
                    let name = ev.params["suggestedFilename"]
                        .as_str()
                        .unwrap_or_default()
                        .to_string();
                    let url = ev.params["url"].as_str().unwrap_or_default().to_string();
                    let mut g = missions.lock().await;
                    if !g.iter().any(|m| m.guid == guid) {
                        g.push(DownloadMission {
                            path: dir.join(&name),
                            guid,
                            url,
                            suggested_filename: name,
                            state: DownloadState::InProgress,
                            received_bytes: 0,
                            total_bytes: 0,
                        });
                    }
                }
                "Page.downloadProgress" | "Browser.downloadProgress" => {
                    let guid = ev.params["guid"].as_str().unwrap_or_default();
                    if guid.is_empty() {
                        continue;
                    }
                    let state = map_state(ev.params["state"].as_str().unwrap_or(""));
                    let received = ev.params["receivedBytes"].as_f64().unwrap_or(0.0) as u64;
                    let total = ev.params["totalBytes"].as_f64().unwrap_or(0.0) as u64;
                    // Newer `Browser.downloadProgress` events may report the actual
                    // on-disk path; prefer it over the `dir/suggestedFilename` guess.
                    let file_path = ev.params["filePath"]
                        .as_str()
                        .filter(|p| !p.is_empty())
                        .map(PathBuf::from);
                    let mut g = missions.lock().await;
                    if let Some(m) = g.iter_mut().find(|m| m.guid == guid) {
                        m.received_bytes = received;
                        m.total_bytes = total;
                        m.state = state;
                        if let Some(p) = file_path {
                            m.path = p;
                        }
                    }
                }
                _ => {}
            }
        }
    }
}
/// Translates a CDP download `state` string into a [`DownloadState`].
///
/// Only `"completed"` and `"canceled"` are terminal; any other value (including
/// `"inProgress"` and the empty string) is treated as still in progress.
fn map_state(raw: &str) -> DownloadState {
    if raw == "completed" {
        DownloadState::Completed
    } else if raw == "canceled" {
        DownloadState::Canceled
    } else {
        DownloadState::InProgress
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully downloaded 10-byte mission in the given state.
    fn mission_in(state: DownloadState) -> DownloadMission {
        DownloadMission {
            guid: "g".into(),
            url: "u".into(),
            suggested_filename: "f.bin".into(),
            path: PathBuf::from("/tmp/f.bin"),
            state,
            received_bytes: 10,
            total_bytes: 10,
        }
    }

    #[test]
    fn state_mapping() {
        let cases = [
            ("completed", DownloadState::Completed),
            ("canceled", DownloadState::Canceled),
            ("inProgress", DownloadState::InProgress),
            ("", DownloadState::InProgress),
        ];
        for (raw, expected) in cases {
            assert_eq!(map_state(raw), expected, "state string {:?}", raw);
        }
    }

    #[test]
    fn mission_finished_and_succeeded() {
        let done = mission_in(DownloadState::Completed);
        assert!(done.is_finished());
        assert!(done.succeeded());
        assert_eq!(done.downloaded_bytes(), 10);

        let canceled = mission_in(DownloadState::Canceled);
        assert!(canceled.is_finished());
        assert!(!canceled.succeeded());

        let in_progress = mission_in(DownloadState::InProgress);
        assert!(!in_progress.is_finished());
        assert!(!in_progress.succeeded());
    }
}