use std::{
error::Error,
path::PathBuf,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use axum::{
Router,
extract::{Path, State},
http::{StatusCode, header::CONTENT_TYPE},
response::{IntoResponse, Response},
routing::{get, post},
};
use clap::Parser;
use tailscale::{Config, Device};
use tracing::level_filters::LevelFilter;
/// Static site served by [`assets`], embedded into the binary at compile time from
/// `examples/axum/www` under the crate's manifest directory.
static WWW: include_dir::Dir = include_dir::include_dir!("$CARGO_MANIFEST_DIR/examples/axum/www");
/// Serve a file from the embedded [`WWW`] directory.
///
/// `path` is the wildcard tail of the request URL. An exact file match is preferred;
/// failing that, `path` is treated as a directory and its `index.html` is tried.
/// The `Content-Type` is guessed from the resolved file's path, falling back to
/// `application/octet-stream`. If neither lookup succeeds the response is
/// `404` with the body `not found`.
async fn assets(Path(path): Path<String>) -> Response {
    // Only build the `index.html` fallback path when the direct lookup misses.
    let found = match WWW.get_file(&path) {
        Some(file) => Some(file),
        None => WWW.get_file(format!("{path}/index.html")),
    };

    match found {
        Some(file) => {
            let content_type = mime_guess::from_path(file.path()).first_or_octet_stream();
            ([(CONTENT_TYPE, content_type.as_ref())], file.contents()).into_response()
        }
        None => (StatusCode::NOT_FOUND, "not found").into_response(),
    }
}
async fn count(count: State<Arc<AtomicUsize>>) -> impl IntoResponse {
let new = count.0.fetch_add(1, Ordering::SeqCst);
format!(r#"{{"count": {new}}}"#)
}
// CLI options for the example. This is a plain comment rather than `///` so it does not
// interfere with `#[command(about)]`, which takes its text from the Cargo package description.
// The field-level `///` comments below become the per-flag help shown by `--help`.
#[derive(clap::Parser)]
#[command(version, about)]
struct Args {
    /// Path to the key file used to build the tailscale config.
    #[arg(short = 'c', long, default_value = "tsrs_keys.json")]
    key_file: PathBuf,
    /// Tailscale auth key for this device.
    #[arg(short = 'k', long, env = "TS_AUTH_KEY")]
    auth_key: Option<String>,
    /// Hostname to request for this device.
    // `Option` because it is assigned straight into the config's `requested_hostname`;
    // with the default value it is always `Some` when parsed from the command line.
    #[arg(short = 'H', long, default_value = "axum-example")]
    hostname: Option<String>,
    /// Control server URL; overrides the config's default when set.
    #[arg(long, env = "TS_CONTROL_URL")]
    control_url: Option<url::Url>,
    /// TCP port to serve HTTP on, bound to the device's IPv4 address.
    #[arg(short, long, default_value_t = 80)]
    port: u16,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy(),
)
.init();
let args = Args::parse();
let mut config = Config::default_with_key_file(&args.key_file).await?;
config.requested_hostname = args.hostname;
if let Some(url) = args.control_url {
config.control_server_url = url;
}
let dev = Device::new(&config, args.auth_key).await?;
let listener = dev
.tcp_listen((dev.ipv4_addr().await?, args.port).into())
.await?;
let router = Router::new()
.route("/count", post(count))
.with_state(Arc::new(AtomicUsize::new(0)))
.route("/{*path}", get(assets));
let url = format!("http://{}/index.html", listener.local_addr());
tracing::info!(%url, "http server listening");
axum::serve(tailscale::axum::Listener::from(listener), router).await?;
Ok(())
}