harddrive_party/
lib.rs

1pub mod connections;
2pub mod errors;
3pub mod peer;
4pub mod shares;
5pub mod ui_server;
6pub mod wishlist;
7
8pub use connections::Hdp;
9pub use harddrive_party_shared::ui_messages;
10pub use harddrive_party_shared::wire_messages;
11
12use crate::{
13    connections::{
14        discovery::{DiscoveryMethod, PeerConnect},
15        known_peers::KnownPeers,
16    },
17    errors::UiServerErrorWrapper,
18    peer::Peer,
19    shares::Shares,
20    ui_messages::{PeerPath, UiEvent, UiServerError},
21    wire_messages::{AnnounceAddress, Request},
22    wishlist::{DownloadRequest, RequestedFile, WishList},
23};
24use async_stream::try_stream;
25use bincode::serialize;
26use futures::{pin_mut, StreamExt};
27use harddrive_party_shared::wire_messages::{IndexQuery, LsResponse};
28use log::{debug, error, warn};
29use quinn::RecvStream;
30use rand::{rngs::OsRng, Rng};
31use std::{collections::HashMap, path::PathBuf, sync::Arc};
32use thiserror::Error;
33use tokio::sync::{broadcast, mpsc::Sender, oneshot, Mutex};
34
/// Key-value store sub-tree names
///
/// Every name is a distinct single byte. The groupings below follow the constant names only;
/// the sub-trees themselves are opened elsewhere in the crate.
pub mod subtree_names {
    // Configuration
    pub const CONFIG: &[u8; 1] = b"c";

    // Known peers
    pub const KNOWN_PEERS: &[u8; 1] = b"k";

    // Index of shared files
    pub const FILES: &[u8; 1] = b"f";
    pub const DIRS: &[u8; 1] = b"d";
    pub const SHARE_NAMES: &[u8; 1] = b"s";

    // Download requests
    pub const REQUESTS: &[u8; 1] = b"r";
    pub const REQUESTS_BY_TIMESTAMP: &[u8; 1] = b"R";
    pub const REQUESTS_PROGRESS: &[u8; 1] = b"P";

    // Files belonging to download requests
    pub const REQUESTED_FILES_BY_PEER: &[u8; 1] = b"p";
    pub const REQUESTED_FILES_BY_REQUEST_ID: &[u8; 1] = b"C";
}
48
49/// Shared state used by both the peer connections and user interface server
50#[derive(Clone)]
51pub struct SharedState {
52    /// A map of peer names to active peer connections
53    pub peers: Arc<Mutex<HashMap<String, Peer>>>,
54    /// A list of known peer names
55    pub known_peers: KnownPeers,
56    /// The index of shared files
57    pub shares: Shares,
58    /// Maintains lists of requested/downloaded files
59    pub wishlist: WishList,
60    /// Channel for sending events to the UI
61    pub event_broadcaster: broadcast::Sender<UiEvent>,
62    /// Channel for announcing peers to connect to
63    peer_announce_tx: Sender<PeerConnect>,
64    /// Download directory
65    pub download_dir: PathBuf,
66    /// A name derived from our public key
67    pub name: String,
68    /// Our own connection details
69    pub announce_address: AnnounceAddress,
70    /// Our OS home directory path
71    pub os_home_dir: Option<String>,
72    /// Channel for graceful shutdown signal
73    graceful_shutdown_tx: tokio::sync::mpsc::Sender<()>,
74}
75
76impl SharedState {
77    #[allow(clippy::too_many_arguments)]
78    pub async fn new(
79        db: sled::Db,
80        share_dirs: Vec<String>,
81        download_dir: PathBuf,
82        name: String,
83        peer_announce_tx: Sender<PeerConnect>,
84        peers: Arc<Mutex<HashMap<String, Peer>>>,
85        announce_address: AnnounceAddress,
86        graceful_shutdown_tx: tokio::sync::mpsc::Sender<()>,
87        known_peers: KnownPeers,
88    ) -> anyhow::Result<Self> {
89        let shares = Shares::new(db.clone(), share_dirs).await?;
90
91        // Set home dir - this is used in the UI as a placeholder when choosing a directory to
92        // share
93        // TODO for cross platform support we should use the `home` crate
94        let os_home_dir = match std::env::var_os("HOME") {
95            Some(o) => o.to_str().map(|s| s.to_string()),
96            None => None,
97        };
98
99        // For sending events to UI clients over websocket
100        let (event_broadcaster, _rx) = broadcast::channel(65536);
101
102        Ok(Self {
103            peers,
104            known_peers,
105            shares,
106            wishlist: WishList::new(&db)?,
107            event_broadcaster,
108            peer_announce_tx,
109            download_dir,
110            name,
111            announce_address,
112            os_home_dir,
113            graceful_shutdown_tx,
114        })
115    }
116
117    /// Send an event to the UI
118    pub async fn send_event(&self, event: UiEvent) {
119        if self.event_broadcaster.send(event).is_err() {
120            warn!("UI response channel closed");
121        }
122    }
123
124    /// Open a request stream and write a request to the peer with the given name
125    pub async fn request(&self, request: Request, name: &str) -> Result<RecvStream, RequestError> {
126        let peers = self.peers.lock().await;
127        let peer = peers.get(name).ok_or(RequestError::PeerNotFound)?;
128        Self::request_peer(request, peer).await
129    }
130
131    /// Static method to open a request stream and write a request to the given peer
132    pub async fn request_peer(request: Request, peer: &Peer) -> Result<RecvStream, RequestError> {
133        let (mut send, recv) = peer.connection.open_bi().await?;
134        let buf = serialize(&request).map_err(|_| RequestError::SerializationError)?;
135        debug!("Message serialized, writing...");
136        send.write_all(&buf).await?;
137        send.finish()?;
138        debug!("Message sent");
139        Ok(recv)
140    }
141
142    pub fn get_ui_announce_address(&self) -> String {
143        self.announce_address.to_string()
144    }
145
146    pub async fn connect_to_peer(
147        &self,
148        announce_address: AnnounceAddress,
149    ) -> Result<(), UiServerErrorWrapper> {
150        let discovery_method = DiscoveryMethod::Direct;
151
152        let (response_tx, response_rx) = oneshot::channel();
153        let peer_connect = PeerConnect {
154            discovery_method,
155            announce_address,
156            response_tx: Some(response_tx),
157        };
158        self.peer_announce_tx
159            .send(peer_connect)
160            .await
161            .map_err(|_| {
162                UiServerError::PeerDiscovery("Peer announce channel closed".to_string())
163            })?;
164
165        // TODO this could take a very long time as the other peer may not show up
166        // add a timeout here
167        response_rx.await?
168    }
169
170    pub async fn download(&self, peer_path: PeerPath) -> Result<u32, UiServerErrorWrapper> {
171        // Get details of the file / dir
172        let ls_request = Request::Ls(IndexQuery {
173            path: Some(peer_path.path.clone()),
174            searchterm: None,
175            recursive: true,
176        });
177        //             // let mut cache = self.ls_cache.lock().await;
178        //             //
179        //             // if let hash_map::Entry::Occupied(mut peer_cache_entry) =
180        //             //     cache.entry(peer_name.clone())
181        //             // {
182        //             //     let peer_cache = peer_cache_entry.get_mut();
183        //             //     if let Some(responses) = peer_cache.get(&ls_request) {
184        //             //         debug!("Found existing responses in cache");
185        //             //         for entries in responses.iter() {
186        //             //             for entry in entries.iter() {
187        //             //                 debug!("Adding {} to wishlist dir: {}", entry.name, entry.is_dir);
188        //             //             }
189        //             //         }
190        //             //     } else {
191        //             //         debug!("Found nothing in cache");
192        //             //     }
193        //             // }
194        //
195        let recv = self.request(ls_request, &peer_path.peer_name).await?;
196
197        let peer_public_key = {
198            let peers = self.peers.lock().await;
199            match peers.get(&peer_path.peer_name) {
200                Some(peer) => peer.public_key,
201                None => {
202                    warn!("Handling request to download a file from a peer who is not connected");
203                    // TODO return an error
204                    return Err(
205                        UiServerError::ConnectionError("Peer not connected".to_string()).into(),
206                    );
207                }
208            }
209        };
210        let mut rng = OsRng;
211        let id: u32 = rng.gen();
212
213        let ls_response_stream = process_length_prefix(recv).await?;
214        pin_mut!(ls_response_stream);
215        while let Some(Ok(ls_response)) = ls_response_stream.next().await {
216            if let LsResponse::Success(entries) = ls_response {
217                for entry in entries.iter() {
218                    if entry.name == peer_path.path {
219                        if let Err(err) = self.wishlist.add_request(&DownloadRequest::new(
220                            entry.name.clone(),
221                            entry.size,
222                            id,
223                            peer_public_key,
224                        )) {
225                            error!("Cannot add download request {err:?}");
226                        }
227                    }
228                    if !entry.is_dir {
229                        debug!("Adding {} to wishlist", entry.name);
230
231                        if let Err(err) = self.wishlist.add_requested_file(&RequestedFile {
232                            path: entry.name.clone(),
233                            size: entry.size,
234                            request_id: id,
235                            downloaded: false,
236                        }) {
237                            error!("Cannot make download request {err:?}");
238                        };
239                    }
240                }
241            }
242        }
243        Ok(id)
244    }
245
246    /// Gracefully shut down the process
247    pub async fn shut_down(&self) {
248        // TODO tidy up peer discovery / active transfers
249        self.shares.flush().await;
250        self.wishlist.flush().await;
251        // This sends a signal to shutdown the Quic endpoint
252        if self.graceful_shutdown_tx.send(()).await.is_err() {
253            std::process::exit(0);
254        };
255    }
256}
257
258/// Error on making a request to a given remote peer
259#[derive(Error, Debug, PartialEq)]
260pub enum RequestError {
261    #[error("Peer not found")]
262    PeerNotFound,
263    #[error(transparent)]
264    ConnectionError(#[from] quinn::ConnectionError),
265    #[error("Cannot serialize message")]
266    SerializationError,
267    #[error(transparent)]
268    WriteError(#[from] quinn::WriteError),
269    #[error("Attempted to close an already closed stream")]
270    ClosedStream(#[from] quinn::ClosedStream),
271}
272
273/// A stream of Ls responses
274pub type LsResponseStream = futures::stream::BoxStream<'static, anyhow::Result<LsResponse>>;
275
276/// Process responses from a remote peer that are prefixed with their length in bytes
277pub async fn process_length_prefix(
278    mut recv: quinn::RecvStream,
279) -> Result<LsResponseStream, UiServerErrorWrapper> {
280    // Read the length prefix
281    let mut length_buf: [u8; 4] = [0; 4];
282    let stream = try_stream! {
283        while let Ok(()) = recv.read_exact(&mut length_buf).await {
284            let length: u32 = u32::from_be_bytes(length_buf);
285            debug!("Read prefix {length}");
286
287            // Read a message
288            let length_usize: usize = length.try_into()?;
289            let mut msg_buf = vec![Default::default(); length_usize];
290            match recv.read_exact(&mut msg_buf).await {
291                Ok(()) => {
292                    let ls_response: LsResponse = bincode::deserialize(&msg_buf)?;
293                    yield ls_response;
294                }
295                Err(_) => {
296                    warn!("Bad prefix / read error");
297                    break;
298                }
299            }
300        }
301    };
302    Ok(stream.boxed())
303}