use super::config::Config;
use super::error::CamoError;
use crate::utils::crypto::verify_digest;
use crate::utils::encoding::decode_url;
#[cfg(feature = "server")]
use crate::server::http_client::ReqwestClient;
#[cfg(feature = "worker")]
use crate::server::http_client::WorkerFetchClient;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Extension, Router,
};
use std::collections::HashMap;
use std::sync::Arc;
/// State shared with the request handlers.
///
/// `create_router` receives it wrapped in an `Arc` and attaches it to the
/// router with `with_state`, so handlers read it through `State<Arc<AppState>>`.
#[derive(Clone)]
pub struct AppState {
    /// Service configuration; the handlers in this file read its `key` and
    /// `metrics` fields.
    pub config: Config,
}
impl AppState {
pub fn from_config(config: &Config) -> Self {
AppState {
config: config.clone(),
}
}
}
/// Builds the application router.
///
/// Routes registered for every build:
/// - `/` and `/health` -> `health_check`
/// - `/favicon.ico` -> `favicon`
/// - `/{digest}` -> `proxy_query` (target URL taken from the query string)
/// - `/{digest}/{*encoded_url}` -> `proxy_path` (target URL taken from the rest of the path)
///
/// Feature-specific wiring:
/// - `worker`: adds a `WorkerFetchClient` as a request extension.
/// - `server`: optionally adds `/metrics` (when `config.metrics` is set), adds a
///   `ReqwestClient` as a request extension, and wraps everything in
///   `tower_http`'s `TraceLayer`.
///
/// The handlers pick the client up through the `Extension` extractor, so the
/// extension layer is applied after all routes have been registered.
pub fn create_router(state: Arc<AppState>) -> Router {
    let routes = Router::new()
        .route("/", get(health_check))
        .route("/health", get(health_check))
        .route("/favicon.ico", get(favicon))
        .route("/{digest}", get(proxy_query))
        .route("/{digest}/{*encoded_url}", get(proxy_path))
        .with_state(Arc::clone(&state));

    #[cfg(feature = "worker")]
    {
        return routes.layer(Extension(WorkerFetchClient::new(&state.config)));
    }

    #[cfg(feature = "server")]
    {
        // The metrics route is added before the layers so it is wrapped by them too.
        let routes = if state.config.metrics {
            routes.route("/metrics", get(metrics_handler))
        } else {
            routes
        };
        return routes
            .layer(Extension(ReqwestClient::new(&state.config)))
            .layer(tower_http::trace::TraceLayer::new_for_http());
    }
}
/// Liveness probe: always answers with the plain-text body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
/// Answers favicon requests with `404 Not Found`; no icon is served.
async fn favicon() -> StatusCode {
    const NO_ICON: StatusCode = StatusCode::NOT_FOUND;
    NO_ICON
}
/// Handles `GET /{digest}?url=<target>`.
///
/// Reads the target URL from the `url` query parameter and hands it, together
/// with the digest from the path, to `proxy_request`.
///
/// Responds with `400 Bad Request` and the body `Missing url parameter` when
/// the query string has no `url` entry.
async fn proxy_query(
    Path(digest): Path<String>,
    Query(params): Query<HashMap<String, String>>,
    State(state): State<Arc<AppState>>,
    #[cfg(feature = "worker")] Extension(http_client): Extension<WorkerFetchClient>,
    #[cfg(feature = "server")] Extension(http_client): Extension<ReqwestClient>,
) -> Response {
    // Borrow the value straight out of the map: `params` outlives the call
    // below, so copying it into a fresh `String` is unnecessary.
    let Some(url) = params.get("url") else {
        return (StatusCode::BAD_REQUEST, "Missing url parameter").into_response();
    };
    proxy_request(&state, &digest, url, &http_client).await
}
/// Handles `GET /{digest}/{*encoded_url}`.
///
/// The remainder of the path carries the encoded target URL. It is decoded
/// with `decode_url` and then forwarded, with the digest, to `proxy_request`.
///
/// Responds with `400 Bad Request` and the body `Invalid URL encoding` when
/// `decode_url` returns `None`.
async fn proxy_path(
    Path((digest, encoded_url)): Path<(String, String)>,
    State(state): State<Arc<AppState>>,
    #[cfg(feature = "worker")] Extension(http_client): Extension<WorkerFetchClient>,
    #[cfg(feature = "server")] Extension(http_client): Extension<ReqwestClient>,
) -> Response {
    let Some(target) = decode_url(&encoded_url) else {
        return (StatusCode::BAD_REQUEST, "Invalid URL encoding").into_response();
    };
    proxy_request(&state, &digest, &target, &http_client).await
}
async fn proxy_request(
state: &Arc<AppState>,
digest: &str,
url: &str,
#[cfg(feature = "worker")] http_client: &WorkerFetchClient,
#[cfg(feature = "server")] http_client: &ReqwestClient,
) -> Response {
let key = state.config.key.as_ref().expect("key must be set");
if !verify_digest(key, url, digest) {
return CamoError::DigestMismatch.into_response();
}
let url = match url::Url::parse(url) {
Ok(u) => u,
Err(_) => {
return CamoError::InvalidUrl("Malformed URL".into()).into_response();
}
};
if url.scheme() != "http" && url.scheme() != "https" {
return CamoError::InvalidUrl("Only http/https schemes allowed".into()).into_response();
}
match http_client.get(url).await {
Ok(response) => {
response.into_response()
}
Err(e) => {
if state.config.metrics {
let _error_type = match &e {
CamoError::ContentTypeNotAllowed(_) => "content_type",
CamoError::ContentTooLarge(_) => "content_size",
CamoError::Timeout => "timeout",
CamoError::PrivateNetworkNotAllowed => "private_network",
_ => "upstream",
};
}
e.into_response()
}
}
}
/// Serves the `/metrics` route; currently returns a fixed placeholder body.
async fn metrics_handler() -> impl IntoResponse {
    const PLACEHOLDER: &str = "# Metrics endpoint\n";
    PLACEHOLDER
}