pub mod connections;
pub mod errors;
pub mod peer;
pub mod shares;
pub mod ui_server;
pub mod wishlist;
pub use connections::Hdp;
pub use harddrive_party_shared::ui_messages;
pub use harddrive_party_shared::wire_messages;
use crate::{
connections::{
discovery::{DiscoveryMethod, PeerConnect},
known_peers::KnownPeers,
},
errors::UiServerErrorWrapper,
peer::Peer,
shares::Shares,
ui_messages::{PeerPath, UiEvent, UiServerError},
wire_messages::{AnnounceAddress, Request},
wishlist::{DownloadRequest, RequestedFile, WishList},
};
use async_stream::try_stream;
use bincode::serialize;
use futures::{pin_mut, StreamExt};
use harddrive_party_shared::wire_messages::{IndexQuery, LsResponse};
use log::{debug, error, warn};
use quinn::RecvStream;
use rand::{rngs::OsRng, Rng};
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio::sync::{broadcast, mpsc::Sender, oneshot, Mutex};
pub mod subtree_names {
pub const CONFIG: &[u8; 1] = b"c";
pub const FILES: &[u8; 1] = b"f";
pub const DIRS: &[u8; 1] = b"d";
pub const SHARE_NAMES: &[u8; 1] = b"s";
pub const REQUESTS: &[u8; 1] = b"r";
pub const REQUESTS_BY_TIMESTAMP: &[u8; 1] = b"R";
pub const REQUESTS_PROGRESS: &[u8; 1] = b"P";
pub const REQUESTED_FILES_BY_PEER: &[u8; 1] = b"p";
pub const REQUESTED_FILES_BY_REQUEST_ID: &[u8; 1] = b"C";
pub const KNOWN_PEERS: &[u8; 1] = b"k";
}
#[derive(Clone)]
pub struct SharedState {
pub peers: Arc<Mutex<HashMap<String, Peer>>>,
pub known_peers: KnownPeers,
pub shares: Shares,
pub wishlist: WishList,
pub event_broadcaster: broadcast::Sender<UiEvent>,
peer_announce_tx: Sender<PeerConnect>,
pub download_dir: PathBuf,
pub name: String,
pub announce_address: AnnounceAddress,
pub os_home_dir: Option<String>,
graceful_shutdown_tx: tokio::sync::mpsc::Sender<()>,
}
impl SharedState {
#[allow(clippy::too_many_arguments)]
pub async fn new(
db: sled::Db,
share_dirs: Vec<String>,
download_dir: PathBuf,
name: String,
peer_announce_tx: Sender<PeerConnect>,
peers: Arc<Mutex<HashMap<String, Peer>>>,
announce_address: AnnounceAddress,
graceful_shutdown_tx: tokio::sync::mpsc::Sender<()>,
known_peers: KnownPeers,
) -> anyhow::Result<Self> {
let shares = Shares::new(db.clone(), share_dirs).await?;
let os_home_dir = match std::env::var_os("HOME") {
Some(o) => o.to_str().map(|s| s.to_string()),
None => None,
};
let (event_broadcaster, _rx) = broadcast::channel(65536);
Ok(Self {
peers,
known_peers,
shares,
wishlist: WishList::new(&db)?,
event_broadcaster,
peer_announce_tx,
download_dir,
name,
announce_address,
os_home_dir,
graceful_shutdown_tx,
})
}
pub async fn send_event(&self, event: UiEvent) {
if self.event_broadcaster.send(event).is_err() {
warn!("UI response channel closed");
}
}
pub async fn request(&self, request: Request, name: &str) -> Result<RecvStream, RequestError> {
let peers = self.peers.lock().await;
let peer = peers.get(name).ok_or(RequestError::PeerNotFound)?;
Self::request_peer(request, peer).await
}
pub async fn request_peer(request: Request, peer: &Peer) -> Result<RecvStream, RequestError> {
let (mut send, recv) = peer.connection.open_bi().await?;
let buf = serialize(&request).map_err(|_| RequestError::SerializationError)?;
debug!("Message serialized, writing...");
send.write_all(&buf).await?;
send.finish()?;
debug!("Message sent");
Ok(recv)
}
pub fn get_ui_announce_address(&self) -> String {
self.announce_address.to_string()
}
pub async fn connect_to_peer(
&self,
announce_address: AnnounceAddress,
) -> Result<(), UiServerErrorWrapper> {
let discovery_method = DiscoveryMethod::Direct;
let (response_tx, response_rx) = oneshot::channel();
let peer_connect = PeerConnect {
discovery_method,
announce_address,
response_tx: Some(response_tx),
};
self.peer_announce_tx
.send(peer_connect)
.await
.map_err(|_| {
UiServerError::PeerDiscovery("Peer announce channel closed".to_string())
})?;
response_rx.await?
}
pub async fn download(&self, peer_path: PeerPath) -> Result<u32, UiServerErrorWrapper> {
let ls_request = Request::Ls(IndexQuery {
path: Some(peer_path.path.clone()),
searchterm: None,
recursive: true,
});
let recv = self.request(ls_request, &peer_path.peer_name).await?;
let peer_public_key = {
let peers = self.peers.lock().await;
match peers.get(&peer_path.peer_name) {
Some(peer) => peer.public_key,
None => {
warn!("Handling request to download a file from a peer who is not connected");
return Err(
UiServerError::ConnectionError("Peer not connected".to_string()).into(),
);
}
}
};
let mut rng = OsRng;
let id: u32 = rng.gen();
let ls_response_stream = process_length_prefix(recv).await?;
pin_mut!(ls_response_stream);
while let Some(Ok(ls_response)) = ls_response_stream.next().await {
if let LsResponse::Success(entries) = ls_response {
for entry in entries.iter() {
if entry.name == peer_path.path {
if let Err(err) = self.wishlist.add_request(&DownloadRequest::new(
entry.name.clone(),
entry.size,
id,
peer_public_key,
)) {
error!("Cannot add download request {err:?}");
}
}
if !entry.is_dir {
debug!("Adding {} to wishlist", entry.name);
if let Err(err) = self.wishlist.add_requested_file(&RequestedFile {
path: entry.name.clone(),
size: entry.size,
request_id: id,
downloaded: false,
}) {
error!("Cannot make download request {err:?}");
};
}
}
}
}
Ok(id)
}
pub async fn shut_down(&self) {
self.shares.flush().await;
self.wishlist.flush().await;
if self.graceful_shutdown_tx.send(()).await.is_err() {
std::process::exit(0);
};
}
}
#[derive(Error, Debug, PartialEq)]
pub enum RequestError {
#[error("Peer not found")]
PeerNotFound,
#[error(transparent)]
ConnectionError(#[from] quinn::ConnectionError),
#[error("Cannot serialize message")]
SerializationError,
#[error(transparent)]
WriteError(#[from] quinn::WriteError),
#[error("Attempted to close an already closed stream")]
ClosedStream(#[from] quinn::ClosedStream),
}
pub type LsResponseStream = futures::stream::BoxStream<'static, anyhow::Result<LsResponse>>;
pub async fn process_length_prefix(
mut recv: quinn::RecvStream,
) -> Result<LsResponseStream, UiServerErrorWrapper> {
let mut length_buf: [u8; 4] = [0; 4];
let stream = try_stream! {
while let Ok(()) = recv.read_exact(&mut length_buf).await {
let length: u32 = u32::from_be_bytes(length_buf);
debug!("Read prefix {length}");
let length_usize: usize = length.try_into()?;
let mut msg_buf = vec![Default::default(); length_usize];
match recv.read_exact(&mut msg_buf).await {
Ok(()) => {
let ls_response: LsResponse = bincode::deserialize(&msg_buf)?;
yield ls_response;
}
Err(_) => {
warn!("Bad prefix / read error");
break;
}
}
}
};
Ok(stream.boxed())
}