rhombus 0.2.21

Next generation extendable CTF framework with batteries included
Documentation
use std::{io, sync::Arc};

use async_hash::{Digest, Sha256};
use axum::{
    body::{Body, Bytes},
    extract::State,
    http::Response,
    response::IntoResponse,
    routing::{get, post},
    Extension, Router,
};
use futures::{Stream, TryStreamExt};

use reqwest::StatusCode;
use tokio::io::AsyncReadExt;
use tokio_util::{bytes::BytesMut, io::StreamReader};

use crate::{
    internal::{
        auth::MaybeUser,
        database::provider::Connection,
        local_upload_provider::slice_to_hex_string,
        upload_provider::{route_upload_file, validate_simple_filename},
    },
    upload_provider::UploadProvider,
    Result,
};

/// Upload provider that stores and serves uploaded files through a database
/// connection rather than a filesystem or external object store.
#[derive(Clone)]
pub struct DatabaseUploadProvider {
    /// Database connection used to persist (`upload_file`) and fetch
    /// (`download_file`) file contents.
    pub db: Connection,
}

impl DatabaseUploadProvider {
    pub async fn new(db: Connection) -> DatabaseUploadProvider {
        DatabaseUploadProvider { db }
    }
}

impl UploadProvider for DatabaseUploadProvider {
    /// Builds the router exposing this provider.
    ///
    /// Registers `GET /uploads/:hash_filename` (served by
    /// `route_database_download`) and `POST /upload/:path` (served by
    /// `route_upload_file`), both sharing an `Arc` clone of this provider as
    /// state.
    fn routes(&self) -> Result<Router> {
        let provider_state = Arc::new(self.clone());
        let router = Router::new()
            .route("/uploads/:hash_filename", get(route_database_download))
            .route("/upload/:path", post(route_upload_file::<Self>))
            .with_state(provider_state);
        Ok(router)
    }

    /// Reads the whole `stream`, stores it in the database under the SHA-256
    /// hex digest of its contents, and returns the download URL
    /// `/uploads/<hash>-<filename>`.
    ///
    /// # Errors
    ///
    /// Returns an error if reading from `stream` fails or if the database
    /// rejects the upload.
    async fn upload<S, E>(&self, filename: &str, stream: S) -> Result<String>
    where
        S: Stream<Item = std::result::Result<Bytes, E>> + Send,
        E: Into<axum::BoxError>,
    {
        // `StreamReader` requires `io::Error` items, so wrap the stream's error type.
        let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
        let body_reader = StreamReader::new(body_with_io_error);

        futures::pin_mut!(body_reader);

        // A single `read_buf` call performs only one read and therefore yields
        // at most one chunk of the stream. Keep reading until EOF (a read of
        // zero bytes) so that larger uploads are not silently truncated.
        let mut buffer = BytesMut::with_capacity(1024);
        while body_reader.read_buf(&mut buffer).await? != 0 {}

        let mut hasher = Sha256::new();
        hasher.update(&buffer);
        let hash = slice_to_hex_string(hasher.finalize().as_slice());

        self.db.upload_file(&hash, filename, &buffer).await?;

        let url = format!("/uploads/{}-{}", hash, filename);

        tracing::info!(url, "uploaded to database");

        Ok(url)
    }
}

pub async fn route_database_download(
    state: State<Arc<DatabaseUploadProvider>>,
    Extension(maybe_user): Extension<MaybeUser>,
    axum::extract::Path(path): axum::extract::Path<String>,
) -> impl IntoResponse {
    if !validate_simple_filename(&path) {
        return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
    }

    let (hash, filename) = if let Some(parts) = path.split_once('-') {
        parts
    } else {
        return (StatusCode::BAD_REQUEST, "Invalid path".to_owned()).into_response();
    };

    let (contents, db_filename) = state.db.download_file(hash).await.unwrap();

    if filename != db_filename {
        return (StatusCode::NOT_FOUND, "Not Found").into_response();
    }

    tracing::info!(
        path = filename,
        user_id = maybe_user.map(|u| u.id),
        "Downloading from database"
    );

    let body = Body::from(contents);
    let mut response = Response::new(body);
    response.headers_mut().insert(
        "Content-Disposition",
        format!("attachment; filename={}", &filename)
            .parse()
            .unwrap(),
    );
    response.into_response()
}