pub mod discovery;
pub mod known_peers;
pub mod quic;
pub mod rpc;
pub mod speedometer;
use crate::{
connections::{
discovery::{DiscoveredPeer, DiscoveryMethod, PeerDiscovery},
quic::{
generate_certificate, get_certificate_from_connection, make_server_endpoint,
make_server_endpoint_basic_socket,
},
rpc::Rpc,
},
errors::UiServerErrorWrapper,
peer::Peer,
subtree_names::{CONFIG, KNOWN_PEERS},
ui_messages::{UiEvent, UiServerError},
wire_messages::{AnnouncePeer, Request},
SharedState,
};
use harddrive_party_shared::wire_messages::{AnnounceAddress, PeerConnectionDetails};
use log::{debug, error, info, warn};
use quinn::Endpoint;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
use tokio::{
net::UdpSocket,
select,
sync::{mpsc, Mutex},
};
use x509_parser::prelude::{FromDer, X509Certificate};
/// Length in bytes of a raw Ed25519 public key.
const PUBLIC_KEY_LENGTH: usize = 32;
/// A peer's raw public key, as extracted from its TLS certificate.
type PublicKey = [u8; PUBLIC_KEY_LENGTH];
/// Largest request, in bytes, read from a peer on a single stream (1 KiB).
const MAX_REQUEST_SIZE: usize = 1 << 10;
/// A running node: owns the state shared with connection tasks, the request
/// handler, the QUIC server connection and peer discovery, and drives them all
/// from [`Hdp::run`].
pub struct Hdp {
    /// State shared (by clone) with spawned connection tasks: connected peers,
    /// shares, event broadcaster, known peers, etc.
    pub shared_state: SharedState,
    /// Handles requests arriving on peer connections.
    rpc: Rpc,
    /// Either a listening QUIC endpoint, or (behind a symmetric NAT) the
    /// certificate and key used to build an endpoint per outgoing connection.
    pub server_connection: ServerConnection,
    /// Yields discovered peers and details of peers expected to connect in.
    peer_discovery: PeerDiscovery,
    /// Receives a message when a graceful shutdown has been requested.
    graceful_shutdown_rx: mpsc::Receiver<()>,
}
impl Hdp {
pub async fn new(
storage: impl AsRef<Path>,
share_dirs: Vec<String>,
download_dir: PathBuf,
use_mdns: bool,
) -> anyhow::Result<Self> {
let mut db_dir = storage.as_ref().to_owned();
db_dir.push("db");
let db = sled::open(db_dir)?;
let config_db = db.open_tree(CONFIG)?;
let (cert_der, priv_key_der) = {
let existing_cert = config_db.get(b"cert");
let existing_priv = config_db.get(b"priv");
match (existing_cert, existing_priv) {
(Ok(Some(cert_der)), Ok(Some(priv_key_der))) => (
CertificateDer::from(cert_der.to_vec()),
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(priv_key_der.to_vec())),
),
_ => {
let (cert_der, priv_key_der) = generate_certificate()?;
config_db.insert(b"cert", cert_der.as_ref())?;
config_db.insert(b"priv", priv_key_der.secret_der())?;
(cert_der, priv_key_der)
}
}
};
let (name, pk_hash) =
certificate_to_name(CertificateDer::from_slice(&cert_der.clone()).into_owned())?;
let peers: Arc<Mutex<HashMap<String, Peer>>> = Default::default();
let port = config_db
.get(b"port")
.ok()
.flatten()
.and_then(|bytes| bytes.to_vec().try_into().ok())
.map(u16::from_be_bytes);
let known_peers_db = db.open_tree(KNOWN_PEERS)?;
let (socket_option, peer_discovery) =
PeerDiscovery::new(use_mdns, pk_hash, peers.clone(), port, known_peers_db).await?;
let (graceful_shutdown_tx, graceful_shutdown_rx) = mpsc::channel(1);
let shared_state = SharedState::new(
db,
share_dirs,
download_dir,
name,
peer_discovery.peer_announce_tx.clone(),
peers,
peer_discovery.announce_address.clone(),
graceful_shutdown_tx,
peer_discovery.known_peers.clone(),
)
.await?;
let server_connection = match socket_option {
Some(socket) => {
if let Ok(port) = socket.get_port() {
let port_bytes = port.to_be_bytes();
config_db.insert(b"port", &port_bytes)?;
}
ServerConnection::WithEndpoint(
make_server_endpoint(
socket,
cert_der,
priv_key_der,
shared_state.known_peers.clone(),
peer_discovery.use_client_verification(),
)
.await?,
)
}
None => {
ServerConnection::Symmetric(cert_der, priv_key_der)
}
};
Ok(Self {
shared_state: shared_state.clone(),
rpc: Rpc::new(
shared_state.shares,
shared_state.event_broadcaster,
peer_discovery.peer_announce_tx.clone(),
),
server_connection,
peer_discovery,
graceful_shutdown_rx,
})
}
pub async fn run(&mut self) {
let (incoming_connection_tx, mut incoming_connection_rx) = mpsc::channel(1024);
if let ServerConnection::WithEndpoint(endpoint) = self.server_connection.clone() {
tokio::spawn(async move {
loop {
if let Some(incoming_conn) = endpoint.accept().await {
if incoming_connection_tx.send(incoming_conn).await.is_err() {
warn!("Cannot handle incoming connections - channel closed");
}
}
}
});
}
for announce_address in self.shared_state.known_peers.iter() {
if let PeerConnectionDetails::NoNat(socket_address) =
announce_address.connection_details
{
let peer = DiscoveredPeer {
socket_address,
socket_option: None,
discovery_method: DiscoveryMethod::Direct,
announce_address,
};
info!("Connecting to known peer... {}", peer.announce_address.name);
if let Err(err) = self.connect_to_peer(peer).await {
error!("Cannot connect to peer from known_peers {err:?}");
};
}
}
loop {
select! {
Some(incoming_conn) = incoming_connection_rx.recv() => {
let maybe_peer_details = self.peer_discovery.get_pending_peer(&incoming_conn.remote_address());
if let Err(err) = self.handle_incoming_connection(maybe_peer_details.clone(), incoming_conn).await {
error!("Error when handling incoming peer connection {err:?}");
if let Some((_, announce_address)) = maybe_peer_details {
let name = announce_address.name;
self.shared_state.send_event(UiEvent::PeerConnectionFailed { name, error: err.to_string() }).await;
}
}
}
Some(peer) = self.peer_discovery.peers_rx.recv() => {
debug!("Discovered peer {peer:?}");
let name = peer.announce_address.name.clone();
if let Err(err) = self.connect_to_peer(peer).await {
error!("Cannot connect to discovered peer {err:?}");
self.shared_state.send_event(UiEvent::PeerConnectionFailed { name, error: err.to_string() }).await;
};
}
Some(()) = self.graceful_shutdown_rx.recv() => {
debug!("Shutting down");
if let ServerConnection::WithEndpoint(endpoint) = self.server_connection.clone() {
endpoint.wait_idle().await;
}
std::process::exit(0);
}
}
}
}
async fn handle_connection(
&mut self,
conn: quinn::Connection,
incoming: bool,
maybe_peer_details: Option<(DiscoveryMethod, AnnounceAddress)>,
remote_cert: CertificateDer<'static>,
) -> Result<(), UiServerError> {
let (peer_name, peer_public_key) = certificate_to_name(remote_cert)
.map_err(|err| UiServerError::PeerDiscovery(err.to_string()))?;
debug!(
"[{}] Connected to peer {}",
self.shared_state.name, peer_name
);
let announce_address = if let Some(peer_details) = maybe_peer_details {
Some(peer_details.1)
} else {
None
};
let rpc = self.rpc.clone();
let shared_state = self.shared_state.clone();
tokio::spawn(async move {
{
let peer = Peer::new(
conn.clone(),
shared_state.event_broadcaster.clone(),
shared_state.download_dir.clone(),
peer_public_key,
shared_state.wishlist.clone(),
announce_address.clone(),
);
let mut peers = shared_state.peers.lock().await;
if let Some(ref announce_address) = announce_address {
let announce_peer = AnnouncePeer {
announce_address: announce_address.clone(),
};
for other_peer in peers.values() {
let request = Request::AnnouncePeer(announce_peer.clone());
if let Err(err) = SharedState::request_peer(request, other_peer).await {
error!("Failed to send announce message to {other_peer:?} - {err:?}");
}
if let Some(ref announce_address_other) = peer.announce_address {
let announce_other_peer = AnnouncePeer {
announce_address: announce_address_other.clone(),
};
let request = Request::AnnouncePeer(announce_other_peer);
if let Err(err) = SharedState::request_peer(request, &peer).await {
error!("Failed to send announce message to {peer:?} - {err:?}");
}
}
}
}
if let Some(_existing_peer) = peers.insert(peer_name.clone(), peer) {
warn!("Adding connection for already connected peer!");
};
let direction = if incoming { "incoming" } else { "outgoing" };
info!("[{}] connected to {} peers", direction, peers.len());
}
shared_state
.send_event(UiEvent::PeerConnected {
name: peer_name.clone(),
})
.await;
let err = loop {
match accept_incoming_request(&conn).await {
Ok((send, buf)) => {
rpc.request(buf, send, peer_name.clone()).await;
}
Err(err) => {
warn!("Failed to handle request: {err:?}");
break err;
}
}
};
{
let mut peers = shared_state.peers.lock().await;
if peers.remove(&peer_name).is_none() {
warn!("Connection closed but peer not present in map");
}
}
debug!("Connection closed - removed peer");
shared_state
.send_event(UiEvent::PeerDisconnected {
name: peer_name.clone(),
error: err.to_string(),
})
.await;
if let Some(announce_address) = announce_address {
if let Err(err) = shared_state.connect_to_peer(announce_address).await {
warn!("Could not reconnect to peer following disconnect: {err}");
}
}
});
Ok(())
}
async fn handle_incoming_connection(
&mut self,
maybe_peer_details: Option<(DiscoveryMethod, AnnounceAddress)>,
incoming_conn: quinn::Incoming,
) -> Result<(), UiServerErrorWrapper> {
let conn = incoming_conn.await?;
debug!(
"Incoming QUIC connection accepted {}",
conn.remote_address()
);
if let Some(i) = conn.handshake_data() {
if let Ok(handshake_data) = i.downcast::<quinn::crypto::rustls::HandshakeData>() {
debug!(
"Server name of connecting peer {:?}",
handshake_data.server_name
);
}
}
let c = conn.clone();
let remote_cert = get_certificate_from_connection(&c)?;
self.handle_connection(conn, true, maybe_peer_details, remote_cert)
.await?;
Ok(())
}
async fn connect_to_peer(&mut self, peer: DiscoveredPeer) -> Result<(), UiServerError> {
let endpoint = match self.server_connection.clone() {
ServerConnection::WithEndpoint(endpoint) => endpoint,
ServerConnection::Symmetric(cert_der, priv_key_der) => {
let socket = match peer.socket_option {
Some(socket) => socket,
None => UdpSocket::bind("0.0.0.0:0")
.await
.map_err(|e| UiServerError::PeerDiscovery(e.to_string()))?,
};
make_server_endpoint_basic_socket(
socket,
cert_der,
priv_key_der,
self.shared_state.known_peers.clone(),
)
.await
.map_err(|err| {
UiServerError::ConnectionError(format!("When creating endpoint: {err:?}"))
})?
}
};
let connection = endpoint
.connect(peer.socket_address, "peer")
.map_err(|err| UiServerError::ConnectionError(format!("When connecting: {err:?}")))?
.await
.map_err(|err| UiServerError::ConnectionError(format!("After connecting: {err:?}")))?;
let remote_cert = get_certificate_from_connection(&connection).map_err(|err| {
UiServerError::ConnectionError(format!("When getting certificate: {err:?}"))
})?;
self.handle_connection(
connection,
false,
Some((peer.discovery_method, peer.announce_address)),
remote_cert,
)
.await?;
Ok(())
}
}
pub fn certificate_to_name(
cert: CertificateDer<'static>,
) -> Result<(String, PublicKey), rustls::Error> {
let (_, cert) = X509Certificate::from_der(&cert)
.map_err(|_| rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding))?;
cert.verify_signature(None)
.map_err(|_| rustls::Error::InvalidCertificate(rustls::CertificateError::BadSignature))?;
let public_key = cert.public_key();
if public_key.algorithm.algorithm.to_string() != "1.3.101.112" {
return Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::BadEncoding,
));
}
let public_key: [u8; 32] = public_key
.subject_public_key
.data
.as_ref()
.try_into()
.map_err(|_| rustls::Error::InvalidCertificate(rustls::CertificateError::BadEncoding))?;
Ok((key_to_animal::key_to_name(&public_key), public_key))
}
async fn accept_incoming_request(
conn: &quinn::Connection,
) -> anyhow::Result<(quinn::SendStream, Vec<u8>)> {
let (send, mut recv) = conn.accept_bi().await?;
let buf = recv.read_to_end(MAX_REQUEST_SIZE).await?;
Ok((send, buf))
}
/// How this node makes and accepts QUIC connections.
#[derive(Debug)]
pub enum ServerConnection {
    /// A single bound QUIC endpoint, used both to accept incoming connections
    /// and to dial peers.
    WithEndpoint(Endpoint),
    /// No shared endpoint (displayed as "Behind symmetric NAT"): this certificate
    /// and private key are used to build a fresh endpoint for each outgoing
    /// connection.
    Symmetric(CertificateDer<'static>, PrivateKeyDer<'static>),
}
/// Human-readable description of where this node is listening.
impl std::fmt::Display for ServerConnection {
    /// Writes the endpoint's local socket address. If no address is available, or
    /// the node is in symmetric-NAT mode, a short description is written instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ServerConnection::WithEndpoint(endpoint) => match endpoint.local_addr() {
                Ok(local_addr) => write!(f, "{local_addr}"),
                // Previously misspelled "adddress".
                Err(_) => f.write_str("No local address"),
            },
            ServerConnection::Symmetric(_, _) => f.write_str("Behind symmetric NAT"),
        }
    }
}
impl Clone for ServerConnection {
fn clone(&self) -> Self {
match self {
ServerConnection::WithEndpoint(endpoint) => {
ServerConnection::WithEndpoint(endpoint.clone())
}
ServerConnection::Symmetric(cert, key) => {
ServerConnection::Symmetric(cert.clone(), key.clone_key())
}
}
}
}
/// Current wall-clock time as a duration since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub fn get_timestamp() -> Duration {
    // `elapsed` is `SystemTime::now().duration_since(self)`.
    SystemTime::UNIX_EPOCH
        .elapsed()
        .expect("Time went backwards")
}