rhombus/internal/
local_upload_provider.rs

1use std::{fmt::Write, sync::Arc, task::Poll};
2
3use async_hash::{Digest, Sha256};
4use axum::{
5    body::Body,
6    extract::{Request, State},
7    response::IntoResponse,
8    Extension,
9};
10use pin_project_lite::pin_project;
11use reqwest::StatusCode;
12use tokio::io::AsyncRead;
13use tower_http::services::ServeFile;
14
15use crate::{
16    internal::{auth::MaybeUser, upload_provider::validate_simple_filename},
17    LocalUploadProvider,
18};
19
/// Shared, reference-counted handle to the [`LocalUploadProvider`], used as the axum
/// `State` extracted by `route_local_download`.
pub type LocalUploadProviderState = Arc<LocalUploadProvider>;
21
22pub async fn route_local_download(
23    state: State<LocalUploadProviderState>,
24    Extension(maybe_user): Extension<MaybeUser>,
25    axum::extract::Path(path): axum::extract::Path<String>,
26    req: Request<Body>,
27) -> impl IntoResponse {
28    if !validate_simple_filename(&path) {
29        return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
30    }
31
32    let (_, filename) = if let Some(parts) = path.split_once('-') {
33        parts
34    } else {
35        return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
36    };
37
38    let filepath = std::path::Path::new(&state.base_path).join(&path);
39
40    let filepath = match filepath.canonicalize() {
41        Ok(path) => path,
42        Err(_) => {
43            return (StatusCode::NOT_FOUND, "Not Found").into_response();
44        }
45    };
46
47    tracing::info!(
48        path = filepath.to_str().unwrap(),
49        user_id = maybe_user.map(|u| u.id),
50        "Downloading from local file"
51    );
52
53    let mut response = ServeFile::new(&filepath).try_call(req).await.unwrap();
54    response.headers_mut().insert(
55        "Content-Disposition",
56        format!("attachment; filename={}", &filename)
57            .parse()
58            .unwrap(),
59    );
60    response.into_response()
61}
62
/// Renders `slice` as lowercase hexadecimal, two digits per byte and no separators
/// (e.g. `[0x0f, 0xab]` becomes `"0fab"`).
pub fn slice_to_hex_string(slice: &[u8]) -> String {
    let mut hex = String::with_capacity(slice.len() * 2);
    for byte in slice {
        // Formatting into a `String` cannot fail, so the `fmt::Result` is discarded.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
69
pin_project! {
    /// An `AsyncRead` adapter that forwards reads from the wrapped reader while feeding
    /// every byte it yields into a SHA-256 hasher.
    ///
    /// Once reading is finished, call [`HashRead::hash`] to obtain the digest of
    /// everything that passed through.
    pub struct HashRead<T> {
        // The wrapped reader. It is pinned because `poll_read` has to poll it through a
        // projected `Pin<&mut T>`.
        #[pin]
        read: T,

        // Running SHA-256 state over all bytes returned by `read` so far (not pinned).
        hasher: Sha256,
    }
}
78
79impl<T> HashRead<T> {
80    pub fn new(read: T) -> Self {
81        Self {
82            read,
83            hasher: Sha256::new(),
84        }
85    }
86
87    pub fn hash(self) -> Vec<u8> {
88        self.hasher.finalize().as_slice().into()
89    }
90}
91
92impl<T> AsyncRead for HashRead<T>
93where
94    T: AsyncRead,
95{
96    fn poll_read(
97        self: std::pin::Pin<&mut Self>,
98        cx: &mut std::task::Context<'_>,
99        buf: &mut tokio::io::ReadBuf<'_>,
100    ) -> std::task::Poll<std::io::Result<()>> {
101        let this = self.project();
102        let before_len = buf.filled().len();
103
104        match this.read.poll_read(cx, buf) {
105            Poll::Pending => Poll::Pending,
106            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
107            Poll::Ready(Ok(())) => {
108                let filled = buf.filled();
109                let after_len = filled.len();
110
111                if after_len > before_len {
112                    let new = &filled[before_len..];
113                    this.hasher.update(new);
114                }
115
116                Poll::Ready(Ok(()))
117            }
118        }
119    }
120}