Skip to main content

harddrive_party_shared/client/
mod.rs

1#[cfg(feature = "native")]
2mod events;
3
4#[cfg(feature = "native")]
5pub use events::EventStream;
6
7use crate::{
8    ui_messages::{FilesQuery, Info, PeerPath, UiDownloadRequest, UiRequestedFile, UiServerError},
9    wire_messages::{IndexQuery, LsResponse, ReadQuery},
10};
11use bincode::{deserialize, serialize};
12use bytes::{Buf, Bytes, BytesMut};
13use futures::Stream;
14use reqwest::{Response, Url};
15use serde::de::DeserializeOwned;
16use std::task::{Context, Poll};
17use std::{num::ParseIntError, pin::Pin};
18use thiserror::Error;
19
20/// A result for which the error is UiServerError
21type UiResult<T> = Result<T, UiServerError>;
22
23#[derive(Clone)]
24pub struct Client {
25    http_client: reqwest::Client,
26    ui_url: Url,
27}
28
29// TODO error handle for all these methods
30impl Client {
31    pub fn new(ui_url: Url) -> Self {
32        Self {
33            http_client: reqwest::Client::new(),
34            ui_url,
35        }
36    }
37
38    #[cfg(feature = "native")]
39    pub async fn event_stream(&self) -> Result<EventStream, ClientError> {
40        EventStream::new(self.ui_url.clone()).await
41    }
42
43    pub async fn shares(
44        &self,
45        query: IndexQuery,
46    ) -> Result<impl Stream<Item = Result<LsResponse, UiServerError>>, ClientError> {
47        let res = self
48            .http_client
49            .post(
50                self.ui_url
51                    .join("api/shares")
52                    .map_err(|_| ClientError::InvalidUrl)?,
53            )
54            .body(serialize(&query)?)
55            .send()
56            .await?;
57
58        if !res.status().is_success() {
59            return Err(ClientError::from_response(res).await);
60        }
61
62        Ok(LengthPrefixedStream::new(res))
63    }
64
65    pub async fn files(
66        &self,
67        query: FilesQuery,
68    ) -> Result<impl Stream<Item = UiResult<(LsResponse, String)>>, ClientError> {
69        let res = self
70            .http_client
71            .post(
72                self.ui_url
73                    .join("api/files")
74                    .map_err(|_| ClientError::InvalidUrl)?,
75            )
76            .body(serialize(&query)?)
77            .send()
78            .await?;
79
80        if !res.status().is_success() {
81            return Err(ClientError::from_response(res).await);
82        }
83
84        Ok(LengthPrefixedStream::new(res))
85    }
86
87    pub async fn download(&self, peer_path: &PeerPath) -> Result<u32, ClientError> {
88        let res = self
89            .http_client
90            .post(
91                self.ui_url
92                    .join("api/download")
93                    .map_err(|_| ClientError::InvalidUrl)?,
94            )
95            .body(serialize(peer_path)?)
96            .send()
97            .await?;
98
99        if !res.status().is_success() {
100            return Err(ClientError::from_response(res).await);
101        }
102
103        Ok(res.text().await?.parse()?)
104    }
105
106    pub async fn connect(&self, announce_address: String) -> Result<(), ClientError> {
107        let res = self
108            .http_client
109            .post(
110                self.ui_url
111                    .join("api/connect")
112                    .map_err(|_| ClientError::InvalidUrl)?,
113            )
114            .body(announce_address)
115            .send()
116            .await?;
117
118        if !res.status().is_success() {
119            return Err(ClientError::from_response(res).await);
120        }
121        Ok(())
122    }
123
124    pub async fn read(
125        &self,
126        peer_name: String,
127        read_query: ReadQuery,
128    ) -> Result<impl Stream<Item = Result<Bytes, reqwest::Error>>, ClientError> {
129        // payload is (read_query, peer_name)
130        let res = self
131            .http_client
132            .post(
133                self.ui_url
134                    .join("api/read")
135                    .map_err(|_| ClientError::InvalidUrl)?,
136            )
137            .body(serialize(&(read_query, peer_name))?)
138            .send()
139            .await?;
140
141        if !res.status().is_success() {
142            return Err(ClientError::from_response(res).await);
143        }
144
145        let stream = res.bytes_stream();
146        Ok(stream)
147    }
148
149    pub async fn info(&self) -> Result<Info, ClientError> {
150        let res = self
151            .http_client
152            .get(
153                self.ui_url
154                    .join("api/info")
155                    .map_err(|_| ClientError::InvalidUrl)?,
156            )
157            .send()
158            .await?;
159
160        if !res.status().is_success() {
161            return Err(ClientError::from_response(res).await);
162        }
163
164        Ok(bincode::deserialize(&res.bytes().await?)?)
165    }
166
167    pub async fn requested_files(
168        &self,
169        id: u32,
170    ) -> Result<impl Stream<Item = Result<Vec<UiRequestedFile>, UiServerError>>, ClientError> {
171        let res = self
172            .http_client
173            .get(
174                self.ui_url
175                    .join("api/request")
176                    .map_err(|_| ClientError::InvalidUrl)?,
177            )
178            .query(&[("id", id.to_string())])
179            .send()
180            .await?;
181
182        if !res.status().is_success() {
183            return Err(ClientError::from_response(res).await);
184        }
185
186        Ok(LengthPrefixedStream::new(res))
187    }
188
189    pub async fn requests(
190        &self,
191    ) -> Result<impl Stream<Item = Result<Vec<UiDownloadRequest>, UiServerError>>, ClientError>
192    {
193        let res = self
194            .http_client
195            .get(
196                self.ui_url
197                    .join("api/requests")
198                    .map_err(|_| ClientError::InvalidUrl)?,
199            )
200            .send()
201            .await?;
202
203        if !res.status().is_success() {
204            return Err(ClientError::from_response(res).await);
205        }
206
207        Ok(LengthPrefixedStream::new(res))
208    }
209
210    pub async fn add_share(&self, share_dir: String) -> Result<u32, ClientError> {
211        let res = self
212            .http_client
213            .put(
214                self.ui_url
215                    .join("api/shares")
216                    .map_err(|_| ClientError::InvalidUrl)?,
217            )
218            .body(share_dir)
219            .send()
220            .await?;
221
222        if !res.status().is_success() {
223            return Err(ClientError::from_response(res).await);
224        }
225
226        Ok(res.text().await?.parse()?)
227    }
228
229    pub async fn remove_share(&self, share_dir: String) -> Result<(), ClientError> {
230        let res = self
231            .http_client
232            .delete(
233                self.ui_url
234                    .join("api/shares")
235                    .map_err(|_| ClientError::InvalidUrl)?,
236            )
237            .body(share_dir)
238            .send()
239            .await?;
240
241        if !res.status().is_success() {
242            return Err(ClientError::from_response(res).await);
243        }
244
245        Ok(())
246    }
247
248    pub async fn shut_down(&self) -> Result<(), ClientError> {
249        let res = self
250            .http_client
251            .post(
252                self.ui_url
253                    .join("api/close")
254                    .map_err(|_| ClientError::InvalidUrl)?,
255            )
256            .send()
257            .await?;
258
259        if !res.status().is_success() {
260            return Err(ClientError::from_response(res).await);
261        }
262
263        Ok(())
264    }
265}
266
267/// For deserializing chunked byte HTTP responses
268// If using this with tokio spawned tasks, we also need `+ Send` - but adding send wont compile on
269// wasm, so we need conditional comilation
270struct LengthPrefixedStream<T>
271where
272    T: DeserializeOwned + 'static + Send,
273{
274    inner: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>>>>,
275    buffer: BytesMut,
276    _marker: std::marker::PhantomData<T>,
277}
278
279impl<T> LengthPrefixedStream<T>
280where
281    T: DeserializeOwned + 'static + Send,
282{
283    pub fn new(response: Response) -> Self {
284        let stream = response.bytes_stream();
285        LengthPrefixedStream {
286            inner: Box::pin(stream),
287            buffer: BytesMut::with_capacity(64 * 1024),
288            _marker: std::marker::PhantomData,
289        }
290    }
291}
292
293impl<T> Stream for LengthPrefixedStream<T>
294where
295    T: DeserializeOwned + 'static + std::marker::Unpin + Send,
296{
297    type Item = UiResult<T>;
298
299    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300        let this = self.get_mut();
301
302        loop {
303            // Try to parse a complete message
304            if this.buffer.len() >= 4 {
305                let mut len_buf = &this.buffer[..4];
306                let msg_len = len_buf.get_u32() as usize;
307
308                if this.buffer.len() >= 4 + msg_len {
309                    this.buffer.advance(4); // discard the length prefix
310                    let msg = this.buffer.split_to(msg_len);
311                    let res: UiResult<T> = match deserialize(&msg) {
312                        Ok(res) => res,
313                        Err(err) => Err(UiServerError::Serialization(err.to_string())),
314                    };
315                    return Poll::Ready(Some(res));
316                }
317            }
318
319            // Not enough data - try to pull the next chunk
320            match this.inner.as_mut().poll_next(cx) {
321                Poll::Ready(Some(Ok(chunk))) => {
322                    this.buffer.extend_from_slice(&chunk);
323                }
324                Poll::Ready(Some(Err(e))) => {
325                    return Poll::Ready(Some(Err(UiServerError::RequestError(e.to_string()))));
326                }
327                Poll::Ready(None) => {
328                    // End of stream
329                    if this.buffer.is_empty() {
330                        return Poll::Ready(None);
331                    } else {
332                        // Incomplete trailing data
333                        return Poll::Ready(Some(Err(UiServerError::RequestError(
334                            "Incomlete message at end of stream".to_string(),
335                        ))));
336                    }
337                }
338                Poll::Pending => return Poll::Pending,
339            }
340        }
341    }
342}
343
344/// An error from the client
345#[derive(PartialEq, Debug, Error)]
346pub enum ClientError {
347    #[error("Cannot connect: {0}")]
348    ConnectionError(String),
349    #[error("Invalid Url")]
350    InvalidUrl,
351    #[error("Unexpected message type on websocket")]
352    UnexpectedMessageType,
353    #[error("Serialization or deserialization: {0}")]
354    Serialization(String),
355    #[error("HTTP client: {0}")]
356    HttpRequest(String),
357    #[error("Cannot parse integer: {0}")]
358    ParseInt(#[from] ParseIntError),
359    #[error("Server: {0}")]
360    ServerError(#[from] UiServerError),
361}
362
363impl From<bincode::Error> for ClientError {
364    fn from(value: bincode::Error) -> Self {
365        ClientError::Serialization(value.to_string())
366    }
367}
368
369impl From<reqwest::Error> for ClientError {
370    fn from(value: reqwest::Error) -> Self {
371        ClientError::Serialization(value.to_string())
372    }
373}
374
375impl ClientError {
376    pub async fn from_response(response: reqwest::Response) -> Self {
377        match response.status() {
378            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
379                let err: UiServerError =
380                    serde_json::from_str(&response.text().await.unwrap_or_default()).unwrap();
381                err.into()
382            }
383            _ => ClientError::HttpRequest(format!(
384                "Request failed: {} {}",
385                response.status(),
386                response.text().await.unwrap_or_default()
387            )),
388        }
389    }
390}