free_storage/
lib.rs

1#![warn(clippy::nursery, clippy::pedantic)]
2#![allow(clippy::missing_panics_doc, clippy::must_use_candidate)]
3
4use serde_json::json;
5use std::io::{ErrorKind, Read};
6use uuid::Uuid;
7
8use reqwest::{header, Client, Url};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13    #[error("Network Error: {0}")]
14    Reqwest(#[from] reqwest::Error),
15    #[error("Error Parsing JSON: {0}")]
16    JSON(#[from] serde_json::Error),
17    #[error("I/O Error: {0}")]
18    Io(#[from] std::io::Error),
19    #[error("Error Parsing URL: {0}")]
20    Url(#[from] url::ParseError),
21    #[error("Invalid Repository OR The Token is Invalid")]
22    InvalidRepoOrInvalidToken,
23    #[error("Unauthorized")]
24    Unauthorized,
25}
26
27type Result<T> = std::result::Result<T, Error>;
28
29/// A struct that holds the data for a single file.
30#[allow(clippy::unsafe_derive_deserialize)]
31#[derive(Clone, Debug, Serialize, Deserialize)]
32pub struct FileId {
33    asset_ids: Vec<u32>,
34    repo: String,
35}
36
/// The part of GitHub's response to an asset upload that this crate reads.
///
/// Any other fields in the response body are ignored during deserialization.
#[derive(Deserialize)]
struct AssetsResponse {
    /// The ID GitHub assigned to the uploaded asset; stored in [`FileId`] for later download.
    id: u32,
}
41
42impl FileId {
43    /// Creates a new [`FileId`] from raw `asset_ids` and a `repo`.
44    ///
45    /// This usually isn't used, instead use [`Self::upload`].
46    pub fn from_raw(asset_ids: Vec<u32>, repo: String) -> Self {
47        Self { asset_ids, repo }
48    }
49
50    /// Uploads a file to the GitHub repository's releases.
51    ///
52    /// The token must have read and write access to the repository.
53    /// `repo` must be in the format `owner/repo`.
54    ///
55    /// # Errors
56    /// Returns an [`Error::InvalidRepo`] if `repo` is not in the correct format, it doesn't exist,
57    /// or if the token does not have `read`/`write` access to the repository.
58    pub async fn upload<S: Into<String> + Send + Sync>(
59        file_name: S,
60        mut file_data: impl Read + Send + Sync,
61        repo: impl Into<String> + Send + Sync,
62        token: impl AsRef<str> + Send + Sync,
63    ) -> Result<Self> {
64        let file_name = <S as Into<String>>::into(file_name)
65            .chars()
66            .filter(|&c| c != '?' && c != '!')
67            .collect::<String>();
68        let repo = repo.into();
69
70        if repo.split('/').count() != 2 {
71            return Err(Error::InvalidRepoOrInvalidToken);
72        }
73
74        tracing::debug!("Uploading file {file_name} to GitHub repo {repo}");
75
76        let client = client(Some(token));
77
78        let (_, uploads_url) = create_or_get_release(&repo, "files", client.clone()).await?;
79
80        let uuid = Uuid::new_v4();
81
82        let mut threads = Vec::new();
83
84        let mut chunks = 0;
85
86        loop {
87            let mut url = uploads_url.clone();
88            url.set_query(Some(&format!("name={uuid}-chunk{chunks}")));
89
90            let client = client.clone();
91
92            tracing::trace!("Reading chunk {chunks}");
93
94            let mut chunk = {
95                // We're only using 100 megabytes because of the time it takes to upload to GitHub
96                let mut chunk = vec![0; 100_000_000];
97
98                let read = loop {
99                    match file_data.read(&mut chunk) {
100                        Ok(a) => break a,
101                        Err(e) => {
102                            if e.kind() == ErrorKind::WouldBlock {
103                                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
104                            } else {
105                                return Err(e.into());
106                            }
107                        }
108                    };
109                };
110
111                if read == 0 {
112                    break;
113                }
114                if read < 100_000_000 {
115                    tracing::trace!("Resizing chunk {chunks} from 100,000,000 to {read}");
116                    // Don't keep all the trailing NULL bytes
117                    chunk.splice(..read, []).collect()
118                } else {
119                    chunk
120                }
121            };
122
123            if chunks == 0 {
124                unsafe { prepend_slice(&mut chunk, format!("{file_name}?").as_bytes()) }
125            }
126
127            threads.push(tokio::spawn(async move {
128                tracing::debug!(
129                    "Uploading chunk {chunks} with {} bytes to {url}",
130                    chunk.len()
131                );
132                client
133                    .post(url)
134                    .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
135                    .body(chunk)
136                    .send()
137                    .await
138            }));
139
140            chunks += 1;
141        }
142
143        let mut asset_ids = Vec::with_capacity(chunks);
144        for thread in threads {
145            let json = thread.await.unwrap()?.json::<AssetsResponse>().await?;
146
147            asset_ids.push(json.id);
148        }
149
150        Ok(Self { asset_ids, repo })
151    }
152
153    /// Downloads the file from the GitHub repository's releases.
154    ///
155    /// The token must have read access to the repository.
156    ///
157    /// # Errors
158    /// Returns an [`Error::Unauthorized`] if the token does not have read access to the repository
159    /// or if the file doesn't exist.
160    ///
161    /// Returns an [`Error::Reqwest`] if there was a network error.
162    pub async fn get<T: Into<String> + Sync + Send>(
163        &self,
164        token: Option<T>,
165    ) -> Result<(Vec<u8>, String)> {
166        let chunks = self.asset_ids.len();
167
168        tracing::debug!("Downloading {chunks} chunks");
169
170        let mut file = Vec::<u8>::new();
171        let mut threads = Vec::with_capacity(chunks);
172
173        let client = client(token.map(Into::into));
174
175        for asset_id in &self.asset_ids {
176            let url = format!(
177                "https://api.github.com/repos/{}/releases/assets/{asset_id}",
178                self.repo
179            );
180
181            let client = client.clone();
182
183            threads.push(tokio::spawn(async move {
184                client
185                    .get(url)
186                    .header(header::ACCEPT, "application/octet-stream")
187                    .send()
188                    .await
189            }));
190        }
191
192        for thread in threads {
193            let res = thread.await.unwrap()?;
194
195            if res.status().as_u16() == 404 {
196                return Err(Error::Unauthorized);
197            }
198
199            let chunk = res.bytes().await?;
200            file.extend(&chunk);
201        }
202
203        let file = file.into_iter();
204
205        let file_name = file
206            .clone()
207            .map(|b| b as char)
208            .take_while(|&c| c != '?')
209            .collect::<String>();
210
211        let file = file.skip(file_name.len() + 1).collect::<Vec<_>>();
212
213        Ok((file, file_name))
214    }
215}
216
217#[derive(Clone, Debug, Serialize, Deserialize)]
218struct ReleaseResponse {
219    upload_url: Option<String>,
220    assets_url: Option<String>,
221}
222
223/// Creates a new release on GitHub and returns the `assets_url`.
224/// If the release exists, it will only return the `assets_url`.
225async fn create_or_get_release(repo: &str, tag: &str, client: Client) -> Result<(Url, Url)> {
226    let get_release = || async {
227        let url = format!("https://api.github.com/repos/{repo}/releases/tags/{tag}");
228
229        tracing::trace!("Getting release at {url}");
230
231        let release = client
232            .get(url)
233            .send()
234            .await?
235            .json::<ReleaseResponse>()
236            .await?;
237
238        Result::Ok(
239            release
240                .assets_url
241                .and_then(|a| release.upload_url.map(|u| (a, u)))
242                .map(|(a, u)| {
243                    let url = parse_url(&a).unwrap();
244                    let upload_url = parse_url(&u).unwrap();
245                    (url, upload_url)
246                }),
247        )
248    };
249    let create_release = || async {
250        let url = format!("https://api.github.com/repos/{repo}/releases");
251
252        tracing::trace!("Creating release at {url} with tag {tag}");
253
254        let release = client
255            .post(url)
256            .json(&json!({
257                "tag_name": tag,
258            }))
259            .send()
260            .await?
261            .json::<ReleaseResponse>()
262            .await?;
263
264        Result::Ok(
265            release
266                .assets_url
267                .and_then(|a| release.upload_url.map(|u| (a, u)))
268                .map(|(a, u)| {
269                    let url = parse_url(&a).unwrap();
270                    let upload_url = parse_url(&u).unwrap();
271                    (url, upload_url)
272                }),
273        )
274    };
275
276    if let Ok(Some(urls)) = get_release().await {
277        Ok(urls)
278    } else if let Ok(Some(urls)) = create_release().await {
279        Ok(urls)
280    } else {
281        // at this point, the repo probably has no commits.
282        let url = format!("https://api.github.com/repos/{repo}/contents/__no_empty_repo__",);
283        client
284            .put(url)
285            .json(&json!({
286                "message": "add a commit to allow creation of a release.",
287                "content": "",
288                "sha254": "",
289            }))
290            .send()
291            .await?
292            .text()
293            .await?;
294
295        if let Ok(Some(urls)) = create_release().await {
296            Ok(urls)
297        } else {
298            tracing::debug!(
299                "Could not create release. This could be because:
300                                                            *   The repo doesn't exist
301                                                            *   The token is invalid"
302            );
303            Err(Error::InvalidRepoOrInvalidToken)
304        }
305    }
306}
307
/// Inserts a copy of `slice` at the front of `vec`, shifting the existing elements back.
///
/// After the call `vec` holds the elements of `slice` followed by its previous contents.
///
/// # Safety
/// The implementation uses only safe code, so there is nothing the caller has to uphold
/// beyond what the argument types already enforce. The function is still declared `unsafe`
/// so that existing call sites, which wrap the call in an `unsafe` block, keep compiling
/// without changes.
unsafe fn prepend_slice<T: Copy>(vec: &mut Vec<T>, slice: &[T]) {
    // Splicing over the empty range `..0` removes nothing and inserts at index 0. The
    // replacement happens when the returned `Splice` is dropped at the end of the statement.
    vec.splice(..0, slice.iter().copied());
}
317
318fn client(token: Option<impl AsRef<str>>) -> Client {
319    let client = Client::builder().user_agent("Rust").default_headers({
320        let mut map = header::HeaderMap::new();
321        if let Some(token) = token {
322            map.insert(header::AUTHORIZATION, {
323                let mut header =
324                    header::HeaderValue::from_str(&format!("token {}", token.as_ref())).unwrap();
325                header.set_sensitive(true);
326                header
327            });
328        }
329        map
330    });
331    client.build().unwrap()
332}
333
334fn parse_url(url: &str) -> Result<Url> {
335    let mut url = url.parse::<Url>()?;
336    url.set_query(None);
337
338    if let Some(path) = url.clone().path().strip_suffix("%7B") {
339        url.set_path(path);
340    }
341
342    Ok(url)
343}