use axum::http::{HeaderValue, Method};
use eyre::{ensure, Context};
use std::net::SocketAddr;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
/// Starts the example JSON-RPC server on an OS-assigned localhost port.
///
/// The first command-line argument, if given, is the CORS origin
/// specification handed to `make_cors`; when absent it defaults to `*`.
#[tokio::main]
async fn main() -> eyre::Result<()> {
    let cors = std::env::args()
        .nth(1)
        .unwrap_or_else(|| String::from("*"));

    let app = make_router().into_axum("/");
    let app = app.layer(make_cors(&cors)?);

    // Port 0 asks the OS for any free port; the bound address is printed below.
    let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
    let bound = listener.local_addr()?;

    println!("Listening on {}", bound);
    println!("CORS allowed for: {}", cors);
    if cors == "*" {
        println!("(specify cors domains as a comma-separated list to restrict origins)");
    }
    println!("use Ctrl-C to stop");

    axum::serve(listener, app).await?;
    Ok(())
}
/// Parses the CORS origin specification into an [`AllowOrigin`] policy.
///
/// `cors` is either `*` (allow any origin) or a comma-separated list of
/// origins, e.g. `http://localhost:3000,https://example.com`. Whitespace
/// around each entry is ignored and empty entries (such as the one produced
/// by a trailing comma) are skipped.
///
/// # Errors
///
/// Returns an error if the list contains no origins, if `*` appears inside a
/// list, or if any entry is not a valid HTTP header value.
fn get_allowed(cors: &str) -> eyre::Result<AllowOrigin> {
    let cors = cors.trim();
    if cors == "*" {
        return Ok(AllowOrigin::any());
    }
    // Trim each entry: `a.com, b.com` must yield `b.com`, not ` b.com`, which
    // would parse as a header value but never match a browser's Origin header.
    let domains: Vec<&str> = cors
        .split(',')
        .map(str::trim)
        .filter(|domain| !domain.is_empty())
        .collect();
    ensure!(!domains.is_empty(), "No CORS domains specified");
    ensure!(
        !domains.contains(&"*"),
        "Wildcard '*' is not allowed in CORS domains"
    );
    // The wrapped error keeps the parse failure as its source, so it is
    // reported once when it propagates out of `main`.
    domains
        .into_iter()
        .map(|domain| {
            domain
                .parse::<HeaderValue>()
                .wrap_err_with(|| format!("Invalid CORS domain: {}", domain))
        })
        .collect::<Result<Vec<_>, _>>()
        .map(Into::into)
}
/// Builds the CORS middleware for the server.
///
/// The layer allows `GET` and `POST` requests carrying any request headers
/// from the origins described by `cors` (parsed by `get_allowed`).
///
/// # Errors
///
/// Propagates any error produced while parsing the origin specification.
fn make_cors(cors: &str) -> eyre::Result<CorsLayer> {
    let methods = [Method::GET, Method::POST];
    get_allowed(cors).map(|origins| {
        CorsLayer::new()
            .allow_methods(methods)
            .allow_origin(origins)
            .allow_headers(Any)
    })
}
/// Builds the JSON-RPC router served by this example.
///
/// Registered methods:
/// - `helloWorld`: no parameters; returns the string `"Hello, world!"`.
/// - `addNumbers`: parameters deserialized as `(u32, u32)`; returns their sum.
fn make_router() -> ajj::Router<()> {
    ajj::Router::<()>::new()
        .route("helloWorld", || async {
            tracing::info!("serving hello world");
            Ok::<_, ()>("Hello, world!")
        })
        .route("addNumbers", |(a, b): (u32, u32)| async move {
            tracing::info!("serving addNumbers");
            // Widen before adding: `u32 + u32` on client-supplied values can
            // overflow (panicking in debug builds, wrapping in release). The
            // sum of two `u32`s always fits in a `u64`.
            Ok::<_, ()>(u64::from(a) + u64::from(b))
        })
}