use std::fs::create_dir_all;
use std::io::{Error, ErrorKind};
use std::iter;
use std::path::{Path, PathBuf};
use std::result::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use crossbeam::deque::{Injector, Stealer, Worker};
use flume::{unbounded, Receiver};
use once_cell::sync::OnceCell;
use serde::Deserialize;
use threadpool::ThreadPool;
use yenc::DecodeOptions;
use crate::article::{Body, MessageId};
use crate::connection::{self, Command, Connection, StatusCode};
/// A multi-connection article download session.
///
/// Workers are started by `run`; fetched articles are then consumed one at a
/// time through `next`, which also decodes them into `download_dir`.
pub struct Download<'a> {
    /// Receiving end of the channel the workers report fetched articles (or
    /// fetch errors) on. Initialised exactly once, when `run` starts the workers.
    on_progress: OnceCell<Receiver<Result<(PathBuf, MessageId, Body), DownloadError>>>,
    /// Base directory; each article is decoded into `download_dir/<output_dir>`.
    download_dir: &'a Path,
    /// Cancellation flag shared with every worker; checked before each article.
    do_cancel: Arc<AtomicBool>,
    /// Thread pool sized from `server.connections`; `fetch` spawns one worker per thread.
    pool: ThreadPool,
    /// Server settings; cloned into each worker thread by `fetch`.
    server: &'a Server,
}
impl<'a> Download<'a> {
    /// Creates a download session that fetches from `server` and decodes into
    /// subdirectories of `download_dir`.
    ///
    /// One worker thread is reserved per configured connection. A configured
    /// count of `0` is treated as `1`, because `ThreadPool::new` panics when
    /// asked for zero threads.
    pub fn new(download_dir: &'a Path, server: &'a Server) -> Self {
        Download {
            on_progress: OnceCell::new(),
            download_dir,
            do_cancel: Arc::new(AtomicBool::new(false)),
            pool: ThreadPool::new(server.connections.max(1)),
            server,
        }
    }

    /// Asks every worker to stop before it starts on its next article.
    pub fn cancel(&self) {
        self.do_cancel.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once [`Download::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.do_cancel.load(Ordering::Relaxed)
    }

    /// Waits for the next fetched article and yEnc-decodes its body into
    /// `download_dir/<output_dir>`, creating that directory if necessary.
    ///
    /// Blocks until [`Download::run`] has been called and a result is available.
    ///
    /// # Errors
    /// - `NoMoreElements` once all workers have finished and every result was consumed.
    /// - `IoError` if fetching the article failed or the output directory
    ///   could not be created.
    /// - `DecodeFailed` if the body could not be decoded.
    pub fn next(&self) -> Result<(PathBuf, MessageId), DownloadError> {
        // A disconnected channel means every worker has exited.
        let (output_dir, message_id, body) = self
            .on_progress
            .wait()
            .recv()
            .unwrap_or(Err(DownloadError::NoMoreElements))?;
        let final_output_dir = self.download_dir.join(&output_dir);
        // Surface filesystem problems to the caller instead of panicking.
        create_dir_all(&final_output_dir).map_err(DownloadError::IoError)?;
        DecodeOptions::new(final_output_dir)
            .decode_stream(&body.bytes()[..])
            .map_err(DownloadError::DecodeFailed)?;
        Ok((output_dir, message_id))
    }

    /// Starts the workers, which take `(output_dir, message_id)` tasks from
    /// `queue` until it is drained or the download is cancelled.
    ///
    /// # Errors
    /// Returns `AlreadyRunning` if the workers were already started. In that
    /// case no additional workers are spawned and `queue` is not touched.
    pub fn run(&self, queue: Arc<Injector<(PathBuf, MessageId)>>) -> Result<(), DownloadError> {
        let mut started = false;
        // Spawn only from inside the initialiser. Otherwise a repeated `run` would
        // start workers that drain `queue` into a receiver that is thrown away.
        self.on_progress.get_or_init(|| {
            started = true;
            self.fetch(queue)
        });
        if started {
            Ok(())
        } else {
            Err(DownloadError::AlreadyRunning)
        }
    }

    /// Spawns one worker per pool thread and returns the channel on which the
    /// workers report each fetched article or failure.
    ///
    /// The channel disconnects once every worker has exited.
    fn fetch(
        &self,
        queue: Arc<Injector<(PathBuf, MessageId)>>,
    ) -> Receiver<Result<(PathBuf, MessageId, Body), DownloadError>> {
        let (segment_complete, on_segment_complete) = unbounded();
        // Build every local queue up front so each worker can steal from all of
        // its siblings, not just from those created before it.
        let workers: Vec<Worker<(PathBuf, MessageId)>> = (0..self.pool.max_count())
            .map(|_| Worker::new_fifo())
            .collect();
        let stealers: Vec<Stealer<(PathBuf, MessageId)>> =
            workers.iter().map(Worker::stealer).collect();
        for worker in workers {
            let server = self.server.clone();
            let global = queue.clone();
            let segment_complete = segment_complete.clone();
            let stealers = stealers.clone();
            let do_cancel = self.do_cancel.clone();
            self.pool.execute(move || {
                // Sending only fails when the receiver has been dropped; nobody
                // is left to report to, so the result is discarded instead of
                // panicking this pool thread.
                let report = |result: Result<(PathBuf, MessageId, Body), DownloadError>| {
                    let _ = segment_complete.send(result);
                };
                // 563 is the NNTP-over-TLS port, 119 the plain NNTP port.
                let port = if server.secure { 563 } else { 119 };
                let addr = format!("{}:{}", &server.host, port);
                if server.secure {
                    match connection::secure(&addr, &server.host) {
                        Ok(mut connection) => Self::execute_tasks(
                            connection.stream(),
                            &server,
                            &worker,
                            &global,
                            &stealers,
                            do_cancel,
                            &report,
                        ),
                        Err(e) => report(Err(DownloadError::IoError(e))),
                    }
                } else {
                    match connection::insecure(&addr) {
                        Ok(connection) => Self::execute_tasks(
                            connection,
                            &server,
                            &worker,
                            &global,
                            &stealers,
                            do_cancel,
                            &report,
                        ),
                        Err(e) => report(Err(DownloadError::IoError(e))),
                    }
                }
            });
        }
        on_segment_complete
    }

    /// Authenticates on `connection`, then fetches articles until no task is
    /// left or the download is cancelled, passing every outcome to `on_complete`.
    ///
    /// An authentication failure is reported through `on_complete` as an
    /// `IoError` (just like a failed connect), and no tasks are taken.
    fn execute_tasks<T: Connection, F>(
        mut connection: T,
        server: &Server,
        local: &Worker<(PathBuf, MessageId)>,
        global: &Injector<(PathBuf, MessageId)>,
        stealers: &[Stealer<(PathBuf, MessageId)>],
        do_cancel: Arc<AtomicBool>,
        on_complete: F,
    ) where
        F: Fn(Result<(PathBuf, MessageId, Body), DownloadError>),
    {
        if connection
            .authenticate(&server.username, &server.password)
            .is_err()
        {
            on_complete(Err(DownloadError::IoError(Error::new(
                ErrorKind::PermissionDenied,
                format!("Authentication with {} failed.", server.host),
            ))));
            return;
        }
        while let Some((output_dir, message_id)) = Self::find_task(local, global, stealers) {
            if do_cancel.load(Ordering::Relaxed) {
                break;
            }
            let result = Body::by_id(&mut connection, &message_id).map_err(DownloadError::IoError);
            on_complete(result.map(|body| (output_dir, message_id, body)));
        }
        if let Err(e) = connection.quit() {
            eprintln!("Failed to send QUIT command to the server: {e:?}");
        }
    }

    /// Returns the next task for a worker.
    ///
    /// Order: the worker's own queue, then a batch from the global injector,
    /// then any sibling's queue. Steal attempts that report `Retry` are repeated.
    /// Returns `None` once all sources report empty.
    fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
        local.pop().or_else(|| {
            iter::repeat_with(|| {
                global
                    .steal_batch_and_pop(local)
                    .or_else(|| stealers.iter().map(|s| s.steal()).collect())
            })
            .find(|s| !s.is_retry())
            .and_then(|s| s.success())
        })
    }
}
/// Connection settings for a news server.
///
/// `Debug` is implemented by hand so the password never ends up in logs or
/// panic messages.
#[derive(Clone, Deserialize)]
pub struct Server {
    /// Host name; also passed to `connection::secure` as the expected host name.
    pub host: String,
    /// User name sent when authenticating.
    pub username: String,
    /// Password sent when authenticating. Redacted in `Debug` output.
    pub password: String,
    /// Number of parallel connections; `Download` runs one worker thread per connection.
    pub connections: usize,
    /// Connect via `connection::secure` on port 563 (presumably TLS) instead of
    /// `connection::insecure` on port 119.
    pub secure: bool,
}

impl std::fmt::Debug for Server {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Server")
            .field("host", &self.host)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("connections", &self.connections)
            .field("secure", &self.secure)
            .finish()
    }
}
impl Body {
    /// Fetches the body of the article with `message_id` by issuing the NNTP
    /// `BODY <message-id>` command on `connection`.
    ///
    /// # Errors
    /// - `ErrorKind::NotFound` when the server answers `430` (no such article).
    /// - `ErrorKind::Other` for any status other than `222` (body follows). This
    ///   keeps an error response from being mistaken for article content.
    /// - Any error returned by `connection.execute`.
    pub fn by_id<T: Connection>(connection: &mut T, message_id: &MessageId) -> Result<Body, Error> {
        connection
            .execute(&Command::new(&[
                "BODY".to_string(),
                format!("<{message_id}>"),
            ]))
            .and_then(|(status, text)| {
                if status.status_code == StatusCode::from(430) {
                    Err(Error::new(
                        ErrorKind::NotFound,
                        format!("No article with message-id <{message_id}> exists."),
                    ))
                } else if status.status_code == StatusCode::from(222) {
                    Ok(Body::new(text))
                } else {
                    Err(Error::new(
                        ErrorKind::Other,
                        format!("Server rejected BODY <{message_id}> with an unexpected status."),
                    ))
                }
            })
    }
}
#[derive(Debug)]
pub enum DownloadError {
CannotOpenFile,
InvalidFile,
DecodeFailed(yenc::DecodeError),
IoError(std::io::Error),
AlreadyRunning,
NoMoreElements,
}