//! rews 0.4.5
//!
//! A binary client for Usenet.
//!
//! Documentation
use std::fs::create_dir_all;
use std::io::{Error, ErrorKind};
use std::iter;
use std::path::{Path, PathBuf};
use std::result::Result;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crossbeam::deque::{Injector, Stealer, Worker};
use flume::{unbounded, Receiver};
use once_cell::sync::OnceCell;
use serde::Deserialize;
use threadpool::ThreadPool;
use yenc::DecodeOptions;

use crate::article::{Body, MessageId};
use crate::connection::{self, Command, Connection, StatusCode};

/// Downloads articles over several server connections in parallel.
///
/// Workers are started by `Download::run`. Each fetched article body is then decoded and
/// written to disk, one result at a time, by `Download::next`.
pub struct Download<'a> {
  /// Receiving end of the channel on which workers report fetched bodies (or errors).
  /// It is initialised once, by `run`.
  on_progress: OnceCell<Receiver<Result<(PathBuf, MessageId, Body), DownloadError>>>,
  /// Base directory; each task's output directory is created below it.
  download_dir: &'a Path,
  /// Cancellation flag shared with every worker closure.
  do_cancel: Arc<AtomicBool>,
  /// Thread pool sized from `server.connections`; one worker runs per pool thread.
  pool: ThreadPool,
  /// Server settings used to connect and authenticate.
  server: &'a Server,
}

impl<'a> Download<'a> {
  pub fn new(download_dir: &'a Path, server: &'a Server) -> Self {
    Download {
      on_progress: OnceCell::new(),
      download_dir,
      do_cancel: Arc::new(AtomicBool::new(false)),
      pool: ThreadPool::new(server.connections),
      server,
    }
  }

  pub fn cancel(&self) {
    self.do_cancel.store(true, Ordering::Relaxed);
  }

  pub fn is_cancelled(&self) -> bool {
    self.do_cancel.load(Ordering::Relaxed)
  }

  pub fn next(&self) -> Result<(PathBuf, MessageId), DownloadError> {
    self
      .on_progress
      .wait()
      .recv()
      .unwrap_or(Err(DownloadError::NoMoreElements))
      .and_then(|(output_dir, message_id, body)| {
        let mut final_output_dir = PathBuf::from(self.download_dir);
        final_output_dir.push(&output_dir);

        // Ensure that the output directory exists.
        create_dir_all(&final_output_dir).unwrap();

        // Set up decode options.
        let options = DecodeOptions::new(final_output_dir);

        // Decode the body.
        options
          .decode_stream(&body.bytes()[..])
          .map_err(DownloadError::DecodeFailed)?;

        Ok((output_dir, message_id))
      })
  }

  pub fn run(&self, queue: Arc<Injector<(PathBuf, MessageId)>>) -> Result<(), DownloadError> {
    self
      .on_progress
      .set(self.fetch(queue))
      .map_err(|_| DownloadError::AlreadyRunning)
  }

  fn fetch(
    &self,
    queue: Arc<Injector<(PathBuf, MessageId)>>,
  ) -> Receiver<Result<(PathBuf, MessageId, Body), DownloadError>> {
    // Construct the global task queue.
    let mut stealers: Vec<Stealer<(PathBuf, MessageId)>> = Vec::new();

    // We need a channel to return the downloaded data.
    let (segment_complete, on_segment_complete) = unbounded();

    // Spin up all the threads with a connection each.
    for _ in 0..self.pool.max_count() {
      // Clone objects which will be moved into threads.
      let server = self.server.clone();
      let global = queue.clone();
      let segment_complete = segment_complete.clone();

      // Create a local task queue and a stealer.
      let worker: Worker<_> = Worker::new_fifo();
      stealers.push(worker.stealer());

      // Clone stealers reference.
      let stealers = stealers.clone();

      // Clone cancel signal reference.
      let do_cancel = self.do_cancel.clone();

      // Execute a worker on a thread.
      self.pool.execute(move || {
        let addr = format!("{}:{}", &server.host, if server.secure { 563 } else { 119 });

        if server.secure {
          let hostname = &server.host;

          match connection::secure(&addr, hostname) {
            Ok(mut connection) => {
              let stream = connection.stream();

              Self::execute_tasks(
                stream,
                &server,
                &worker,
                &global,
                &stealers,
                do_cancel,
                |result| segment_complete.send(result).unwrap(),
              );
            }
            Err(e) => {
              segment_complete
                .send(Err(DownloadError::IoError(e)))
                .unwrap();
            }
          }
        } else {
          match connection::insecure(&addr) {
            Ok(connection) => {
              Self::execute_tasks(
                connection,
                &server,
                &worker,
                &global,
                &stealers,
                do_cancel,
                |result| segment_complete.send(result).unwrap(),
              );
            }
            Err(e) => {
              segment_complete
                .send(Err(DownloadError::IoError(e)))
                .unwrap();
            }
          }
        }
      });
    }

    on_segment_complete
  }

  fn execute_tasks<T: Connection, F>(
    mut connection: T,
    server: &Server,
    local: &Worker<(PathBuf, MessageId)>,
    global: &Injector<(PathBuf, MessageId)>,
    stealers: &[Stealer<(PathBuf, MessageId)>],
    do_cancel: Arc<AtomicBool>,
    on_complete: F,
  ) where
    F: Fn(Result<(PathBuf, MessageId, Body), DownloadError>),
  {
    if connection
      .authenticate(&server.username, &server.password)
      .is_ok()
    {
      while let Some((output_dir, message_id)) = Self::find_task(local, global, stealers) {
        // Break the loop if a cancellation signal has been received.
        if do_cancel.load(Ordering::Relaxed) {
          break;
        }

        // Get the body for the message_id specified in the task.
        let result = Body::by_id(&mut connection, &message_id).map_err(DownloadError::IoError);

        // Return the result.
        on_complete(result.map(|body| (output_dir, message_id, body)));
      }

      if let Err(e) = connection.quit() {
        println!("Failed to send QUIT command to the server: {e:?}");
      }
    }
  }

  fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
    // Pop a task from the local queue, if not empty.
    local.pop().or_else(|| {
      // Otherwise, we need to look for a task elsewhere.
      iter::repeat_with(|| {
        // Try stealing a batch of tasks from the global queue.
        global
          .steal_batch_and_pop(local)
          // Or try stealing a task from one of the other threads.
          .or_else(|| stealers.iter().map(|s| s.steal()).collect())
      })
      // Loop while no task was stolen and any steal operation needs to be retried.
      .find(|s| !s.is_retry())
      // Extract the stolen task, if there is one.
      .and_then(|s| s.success())
    })
  }
}

/// Settings for the Usenet server to download from.
///
/// Note: `Debug` is derived, so `{:?}` output includes the password in clear text.
#[derive(Clone, Debug, Deserialize)]
pub struct Server {
  /// Host name of the server. It is used in the connection address and passed as the
  /// host name to `connection::secure`.
  pub host: String,
  /// User name passed to `Connection::authenticate`.
  pub username: String,
  /// Password passed to `Connection::authenticate`.
  pub password: String,
  /// Number of simultaneous connections. It sizes the download thread pool, and one
  /// connection is opened per pool thread.
  pub connections: usize,
  /// `true`: connect with `connection::secure` on port 563.
  /// `false`: connect with `connection::insecure` on port 119.
  pub secure: bool,
}

impl Body {
  /// Fetches the body of the article with the given `message_id` using the NNTP `BODY`
  /// command.
  ///
  /// # Errors
  ///
  /// - Any error returned by `connection.execute`.
  /// - [`ErrorKind::NotFound`] if the server answers 430 (no article with that message-id).
  /// - [`ErrorKind::Other`] for any response other than 222 (body follows), such as an
  ///   authentication or server error. Without this check the response text would be
  ///   returned as if it were article data.
  pub fn by_id<T: Connection>(connection: &mut T, message_id: &MessageId) -> Result<Body, Error> {
    let (status, text) = connection.execute(&Command::new(&[
      "BODY".to_string(),
      format!("<{message_id}>"),
    ]))?;

    if status.status_code == StatusCode::from(222) {
      Ok(Body::new(text))
    } else if status.status_code == StatusCode::from(430) {
      Err(Error::new(
        ErrorKind::NotFound,
        format!("No article with message-id <{message_id}> exists."),
      ))
    } else {
      Err(Error::new(
        ErrorKind::Other,
        format!("The server did not return a body for message-id <{message_id}>."),
      ))
    }
  }
}

/// Errors produced while downloading and decoding articles.
#[derive(Debug)]
pub enum DownloadError {
  /// NOTE(review): not constructed in this module; presumably raised elsewhere when an
  /// input file cannot be opened — confirm.
  CannotOpenFile,
  /// NOTE(review): not constructed in this module; presumably raised elsewhere for
  /// malformed input files — confirm.
  InvalidFile,
  /// yEnc decoding of an article body failed.
  DecodeFailed(yenc::DecodeError),
  /// An I/O error, e.g. from connecting to or communicating with the server.
  IoError(std::io::Error),
  /// Returned by `Download::run` when the download has already been started.
  AlreadyRunning,
  /// Returned by `Download::next` once the result channel has been closed.
  NoMoreElements,
}

impl std::fmt::Display for DownloadError {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      DownloadError::CannotOpenFile => write!(f, "cannot open file"),
      DownloadError::InvalidFile => write!(f, "invalid file"),
      // `Debug` is the only formatting guaranteed here by the derive above.
      DownloadError::DecodeFailed(e) => write!(f, "failed to decode article body: {e:?}"),
      DownloadError::IoError(e) => write!(f, "I/O error: {e}"),
      DownloadError::AlreadyRunning => write!(f, "download is already running"),
      DownloadError::NoMoreElements => write!(f, "no more elements to download"),
    }
  }
}

impl std::error::Error for DownloadError {
  fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
    match self {
      DownloadError::IoError(e) => Some(e),
      DownloadError::CannotOpenFile
      | DownloadError::InvalidFile
      | DownloadError::DecodeFailed(_)
      | DownloadError::AlreadyRunning
      | DownloadError::NoMoreElements => None,
    }
  }
}