use std::{
io,
path::{Path, PathBuf},
result::Result,
};
use axum::{body::Bytes, extract::Request, http::StatusCode, response::IntoResponse, BoxError};
use axum_extra::{headers::Range, TypedHeader};
use axum_range::{KnownSize, Ranged};
use futures::{Stream, TryStreamExt};
use futures_util::pin_mut;
use tokio::io::AsyncWrite;
use tokio_util::io::StreamReader;
use crate::{
acl::AccessType,
auth::BasicAuthFromRequest,
error::{ApiErrorKind, ApiResult},
handlers::{access_check::check_auth_and_acl, file_helpers::Finalizer},
storage::STORAGE,
typed_path::{PathParts, TpeKind},
};
/// Uploads a file: validates the target via [`get_save_file`] and streams the
/// request body into it with [`save_body`].
///
/// # Errors
///
/// Propagates any error from name/ACL validation, file creation, writing the
/// body, or finalizing the file.
pub async fn add_file<P: PathParts>(
    path: P,
    auth: BasicAuthFromRequest,
    request: Request,
) -> ApiResult<impl IntoResponse> {
    let (path, tpe, name) = path.parts();
    // Label must match this handler so upload traces are not confused with downloads.
    tracing::debug!("[add_file] path: {path:?}, tpe: {tpe:?}, name: {name:?}");

    // Take ownership of the path directly instead of borrowing and re-copying it.
    let path = PathBuf::from(path.unwrap_or_default());
    let file = get_save_file(auth.user, path, tpe, name).await?;

    let stream = request.into_body().into_data_stream();
    save_body(file, stream).await?;
    Ok(())
}
/// Deletes a repository file after checking its name and the caller's access rights.
///
/// The ACL check is performed with [`AccessType::Append`].
/// NOTE(review): deletion only requires `Append` access here — confirm this is intended.
///
/// # Errors
///
/// Returns an error when the name is not allowed, the ACL check fails, no type is
/// given (`InternalError("tpe is not valid")`), or the storage removal fails.
pub async fn delete_file<P: PathParts>(
    path: P,
    auth: BasicAuthFromRequest,
) -> ApiResult<impl IntoResponse> {
    let (path, tpe, name) = path.parts();
    tracing::debug!("[delete_file] path: {path:?}, tpe: {tpe:?}, name: {name:?}");

    let owned_path = path.unwrap_or_default();
    let repo_path = Path::new(&owned_path);

    check_name(tpe, name.as_deref())?;
    check_auth_and_acl(auth.user, tpe, repo_path, AccessType::Append)?;

    let Some(kind) = tpe else {
        return Err(ApiErrorKind::InternalError("tpe is not valid".to_string()));
    };
    let kind_str = kind.into_str();

    STORAGE
        .get()
        .unwrap()
        .remove_file(repo_path, kind_str, name.as_deref())
        .await?;
    Ok(())
}
/// Streams a repository file to the client, honouring an optional HTTP `Range` header.
///
/// The status code is chosen by [`Ranged`] itself: `206 Partial Content` with a
/// `Content-Range` header when a range is actually served, and `200 OK` for a full
/// body. An unusable range is reported by `Ranged` as well. The response is returned
/// as-is, because wrapping it in a `(StatusCode, _)` tuple would overwrite the status.
/// For example, it would label a full body or an unsatisfiable range as `206`.
///
/// # Errors
///
/// Returns an error when the name is not allowed, the ACL check fails, no type is
/// given, the file cannot be opened, or its metadata cannot be read.
pub async fn get_file<P: PathParts>(
    path: P,
    auth: BasicAuthFromRequest,
    range: Option<TypedHeader<Range>>,
) -> ApiResult<impl IntoResponse> {
    let (path, tpe, name) = path.parts();
    tracing::debug!(?path, "type" = ?tpe, ?name, "[get_file]");

    let _ = check_name(tpe, name.as_deref())?;
    let path_str = path.unwrap_or_default();
    let path = Path::new(&path_str);
    let _ = check_auth_and_acl(auth.user, tpe, path, AccessType::Read)?;

    let tpe = if let Some(tpe) = tpe {
        tpe.into_str()
    } else {
        return Err(ApiErrorKind::InternalError("tpe is not valid".to_string()));
    };

    let storage = STORAGE.get().unwrap();
    let file = storage.open_file(path, tpe, name.as_deref()).await?;
    // KnownSize reads the file length so Ranged can validate and slice the range.
    let body = KnownSize::file(file)
        .await
        .map_err(|err| ApiErrorKind::GettingFileMetadataFailed(format!("{err:?}")))?;

    let range = range.map(|TypedHeader(range)| range);
    let response = Ranged::new(range, body).into_response();
    if response.status() == StatusCode::RANGE_NOT_SATISFIABLE {
        tracing::debug!("[get_file] requested range is not satisfiable");
    }
    Ok(response)
}
/// Validates an upload target and creates the file to write it into.
///
/// The name is checked with [`check_name`], and the ACL with [`AccessType::Append`].
/// The file is then created through the global storage.
///
/// # Errors
///
/// Returns an error when the name is not allowed, the ACL check fails, `tpe` is
/// `None` (`InternalError("tpe is not valid")`), or file creation fails.
pub async fn get_save_file(
    user: String,
    path: PathBuf,
    tpe: Option<TpeKind>,
    name: Option<String>,
) -> ApiResult<impl AsyncWrite + Unpin + Finalizer> {
    tracing::debug!("[get_save_file] path: {path:?}, tpe: {tpe:?}, name: {name:?}");

    check_name(tpe, name.as_deref())?;
    check_auth_and_acl(user, tpe, &path, AccessType::Append)?;

    let kind_str = tpe
        .map(|kind| kind.into_str())
        .ok_or_else(|| ApiErrorKind::InternalError("tpe is not valid".to_string()))?;

    STORAGE
        .get()
        .unwrap()
        .create_file(&path, kind_str, name.as_deref())
        .await
}
pub async fn save_body<S, E>(
mut write_stream: impl AsyncWrite + Unpin + Finalizer + Send,
stream: S,
) -> ApiResult<impl IntoResponse>
where
S: Stream<Item = Result<Bytes, E>> + Send,
E: Into<BoxError>,
{
let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
let body_reader = StreamReader::new(body_with_io_error);
pin_mut!(body_reader);
let byte_count = match tokio::io::copy(&mut body_reader, &mut write_stream).await {
Ok(b) => b,
Err(err) => return Err(ApiErrorKind::FinalizingFileFailed(format!("{:?}", err))),
};
tracing::debug!("[file written] bytes: {byte_count}");
write_stream.finalize().await.map_err(|err| {
ApiErrorKind::FinalizingFileFailed(format!("Could not finalize file: {}", err))
})
}
/// Test-only stand-in for the SHA-256 file name check: accepts every name.
///
/// The tests in this file upload files with descriptive, non-hash names
/// (e.g. `__add_file_test_adds_this_one__`), so the real check is bypassed under `cfg(test)`.
#[cfg(test)]
const fn check_string_sha256(_name: &str) -> bool {
    true
}
/// Returns `true` when `name` looks like a lowercase hex SHA-256 digest.
///
/// A valid name is exactly 64 bytes, each one of `0-9` or `a-f`.
#[cfg(not(test))]
fn check_string_sha256(name: &str) -> bool {
    // Any non-ASCII byte is >= 0x80 and therefore rejected by the byte pattern.
    name.len() == 64 && name.bytes().all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}
/// Validates a repository file name for the given file type.
///
/// `Config` entries are accepted whatever their name. Every other case
/// (including a missing type) requires a name that passes `check_string_sha256`.
///
/// # Errors
///
/// Returns [`ApiErrorKind::FilenameNotAllowed`] with the offending name, or an
/// empty string when no name was supplied.
pub fn check_name(
    tpe: impl Into<Option<TpeKind>>,
    name: Option<&str>,
) -> ApiResult<impl IntoResponse> {
    let kind: Option<TpeKind> = tpe.into();
    if let Some(TpeKind::Config) = kind {
        return Ok(());
    }

    match name {
        Some(candidate) if check_string_sha256(candidate) => Ok(()),
        _ => Err(ApiErrorKind::FilenameNotAllowed(
            name.unwrap_or_default().to_string(),
        )),
    }
}
#[cfg(test)]
mod test {
    use crate::{
        handlers::file_exchange::{add_file, delete_file, get_file},
        log::print_request_response,
        testing::{
            basic_auth_header_value, init_test_environment, request_uri_for_test, server_config,
        },
        typed_path::RepositoryTpeNamePath,
    };
    use std::{fs, path::PathBuf};

    use axum::{
        body::Body,
        http::{header, Method, Request, StatusCode},
        middleware, Router,
    };
    use axum_extra::routing::RouterExt;
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Uploads a key file via `add_file`, checks its on-disk contents, then removes it
    /// via `delete_file`.
    #[tokio::test]
    async fn test_add_delete_file_passes() {
        init_test_environment(server_config());

        let file_name = "__add_file_test_adds_this_one__";
        let path = PathBuf::new()
            .join("tests")
            .join("generated")
            .join("test_storage")
            .join("test_repo")
            .join("keys")
            .join(file_name);
        if path.exists() {
            fs::remove_file(&path).unwrap();
            assert!(!path.exists());
        }

        // Upload.
        let app = Router::new()
            .typed_post(add_file::<RepositoryTpeNamePath>)
            .layer(middleware::from_fn(print_request_response));
        let test_vec = "Hello World".to_string();
        let body = Body::new(test_vec.clone());
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = Request::builder()
            .uri(uri)
            .method(Method::POST)
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("rustic")),
            )
            .body(body)
            .unwrap();
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(path.exists());
        let body = fs::read_to_string(&path).unwrap();
        assert_eq!(body, test_vec);

        // Delete.
        let app = Router::new()
            .typed_delete(delete_file::<RepositoryTpeNamePath>)
            .layer(middleware::from_fn(print_request_response));
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = Request::builder()
            .uri(uri)
            .method(Method::DELETE)
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("rustic")),
            )
            // A DELETE carries no payload; previously the file contents were sent here.
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(!path.exists());
    }

    /// Uploads a file, then reads it back both in full and by byte range, and deletes it.
    #[tokio::test]
    async fn test_get_file_passes() {
        init_test_environment(server_config());

        let file_name = "__get_file_test_adds_this_two__";
        let path = PathBuf::new()
            .join("tests")
            .join("generated")
            .join("test_storage")
            .join("test_repo")
            .join("keys")
            .join(file_name);
        if path.exists() {
            tracing::debug!("[server_get_file_tester] test file found and removed");
            fs::remove_file(&path).unwrap();
            assert!(!path.exists());
        }

        // Upload.
        let app = Router::new()
            .typed_post(add_file::<RepositoryTpeNamePath>)
            .layer(middleware::from_fn(print_request_response));
        let test_vec = "Hello Sweet World".to_string();
        let body = Body::new(test_vec.clone());
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = Request::builder()
            .uri(uri)
            .method(Method::POST)
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("rustic")),
            )
            .body(body)
            .unwrap();
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(path.exists());
        let body = fs::read_to_string(&path).unwrap();
        assert_eq!(body, test_vec);

        // Full download: 200 without a Content-Range header.
        let app = Router::new()
            .typed_get(get_file::<RepositoryTpeNamePath>)
            .layer(middleware::from_fn(print_request_response));
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = request_uri_for_test(&uri, Method::GET);
        let resp = app.clone().oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let (parts, body) = resp.into_parts();
        assert!(!parts.headers.contains_key(header::CONTENT_RANGE));
        let byte_vec = body.collect().await.unwrap().to_bytes();
        let body_str = String::from_utf8(byte_vec.to_vec()).unwrap();
        assert_eq!(body_str, test_vec);

        // Ranged download: 206 with a Content-Range header and only the requested bytes.
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = Request::builder()
            .uri(uri)
            .method(Method::GET)
            .header(header::RANGE, "bytes=6-12")
            .header(
                "Authorization",
                basic_auth_header_value("rustic", Some("rustic")),
            )
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(request).await.unwrap();
        let test_vec = "Sweet W".to_string();
        assert_eq!(resp.status(), StatusCode::PARTIAL_CONTENT);
        let (parts, body) = resp.into_parts();
        assert!(parts.headers.contains_key(header::CONTENT_RANGE));
        let byte_vec = body.collect().await.unwrap().to_bytes();
        let body_str = String::from_utf8(byte_vec.to_vec()).unwrap();
        assert_eq!(body_str, test_vec);

        // Delete.
        let app = Router::new()
            .typed_delete(delete_file::<RepositoryTpeNamePath>)
            .layer(middleware::from_fn(print_request_response));
        let uri = ["/test_repo/keys/", file_name].concat();
        let request = request_uri_for_test(&uri, Method::DELETE);
        let resp = app.oneshot(request).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(!path.exists());
    }
}