rhombus/internal/
local_upload_provider.rs1use std::{fmt::Write, sync::Arc, task::Poll};
2
3use async_hash::{Digest, Sha256};
4use axum::{
5 body::Body,
6 extract::{Request, State},
7 response::IntoResponse,
8 Extension,
9};
10use pin_project_lite::pin_project;
11use reqwest::StatusCode;
12use tokio::io::AsyncRead;
13use tower_http::services::ServeFile;
14
15use crate::{
16 internal::{auth::MaybeUser, upload_provider::validate_simple_filename},
17 LocalUploadProvider,
18};
19
/// Reference-counted handle to the [`LocalUploadProvider`]; this is the state
/// type extracted by [`route_local_download`].
pub type LocalUploadProviderState = Arc<LocalUploadProvider>;
21
22pub async fn route_local_download(
23 state: State<LocalUploadProviderState>,
24 Extension(maybe_user): Extension<MaybeUser>,
25 axum::extract::Path(path): axum::extract::Path<String>,
26 req: Request<Body>,
27) -> impl IntoResponse {
28 if !validate_simple_filename(&path) {
29 return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
30 }
31
32 let (_, filename) = if let Some(parts) = path.split_once('-') {
33 parts
34 } else {
35 return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
36 };
37
38 let filepath = std::path::Path::new(&state.base_path).join(&path);
39
40 let filepath = match filepath.canonicalize() {
41 Ok(path) => path,
42 Err(_) => {
43 return (StatusCode::NOT_FOUND, "Not Found").into_response();
44 }
45 };
46
47 tracing::info!(
48 path = filepath.to_str().unwrap(),
49 user_id = maybe_user.map(|u| u.id),
50 "Downloading from local file"
51 );
52
53 let mut response = ServeFile::new(&filepath).try_call(req).await.unwrap();
54 response.headers_mut().insert(
55 "Content-Disposition",
56 format!("attachment; filename={}", &filename)
57 .parse()
58 .unwrap(),
59 );
60 response.into_response()
61}
62
/// Renders `slice` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
pub fn slice_to_hex_string(slice: &[u8]) -> String {
    // Each byte becomes exactly two hex digits.
    let mut hex = String::with_capacity(slice.len() * 2);
    for byte in slice {
        // The result of formatting into the `String` is deliberately ignored.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
69
pin_project! {
    /// Reader wrapper that feeds the bytes read through it into a SHA-256
    /// hasher; call `hash` afterwards to obtain the digest.
    pub struct HashRead<T> {
        // The wrapped reader. Marked `#[pin]` so `poll_read` can be forwarded
        // to it through the pinned projection.
        #[pin]
        read: T,

        // Running hash state, updated with each newly read chunk.
        hasher: Sha256,
    }
}
78
79impl<T> HashRead<T> {
80 pub fn new(read: T) -> Self {
81 Self {
82 read,
83 hasher: Sha256::new(),
84 }
85 }
86
87 pub fn hash(self) -> Vec<u8> {
88 self.hasher.finalize().as_slice().into()
89 }
90}
91
92impl<T> AsyncRead for HashRead<T>
93where
94 T: AsyncRead,
95{
96 fn poll_read(
97 self: std::pin::Pin<&mut Self>,
98 cx: &mut std::task::Context<'_>,
99 buf: &mut tokio::io::ReadBuf<'_>,
100 ) -> std::task::Poll<std::io::Result<()>> {
101 let this = self.project();
102 let before_len = buf.filled().len();
103
104 match this.read.poll_read(cx, buf) {
105 Poll::Pending => Poll::Pending,
106 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
107 Poll::Ready(Ok(())) => {
108 let filled = buf.filled();
109 let after_len = filled.len();
110
111 if after_len > before_len {
112 let new = &filled[before_len..];
113 this.hasher.update(new);
114 }
115
116 Poll::Ready(Ok(()))
117 }
118 }
119 }
120}