use axum::Router;
use axum::extract::State;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use axum::routing::{any, get, post};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::broadcast;
mod assets;
mod build;
mod config;
mod dashboard;
mod html;
mod onboarding;
mod pages;
mod preview;
mod sections;
/// Port the webapp tries to bind first; `run_webapp` falls back to an OS-assigned port
/// when this one is already in use.
const DEFAULT_WEBAPP_PORT: u16 = 1112;
/// URL path prefix of the site preview. Used to build `AppState::preview_base_url` in
/// `run_webapp` and returned by `AppState::site_base_url`.
const PREVIEW_MOUNT: &str = "/preview";
/// Buffer size of the live-reload broadcast channel created in `run_webapp`.
const RELOAD_CHANNEL_CAPACITY: usize = 16;
/// Client-side live-reload snippet (a complete `<script>` element).
///
/// Opens a WebSocket to `/__livereload` (`wss:` on HTTPS pages, `ws:` otherwise) and
/// reloads the page when it receives the text message `reload`. When the socket closes it
/// retries after a delay that starts at 1000 ms and grows by 1.5x per attempt, capped at
/// 30000 ms. The delay is reset to 1000 ms once a connection opens.
pub(crate) const LIVERELOAD_JS: &str = r#"
<script>
(function() {
var reconnectInterval = 1000;
var maxReconnect = 30000;
function connect() {
var proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
var ws = new WebSocket(proto + '//' + location.host + '/__livereload');
ws.onmessage = function(event) {
if (event.data === 'reload') { location.reload(); }
};
ws.onclose = function() {
setTimeout(connect, reconnectInterval);
reconnectInterval = Math.min(reconnectInterval * 1.5, maxReconnect);
};
ws.onopen = function() { reconnectInterval = 1000; };
}
connect();
})();
</script>
"#;
/// Shared state handed to every route handler (wrapped in an `Arc` by `run_webapp`).
pub(crate) struct AppState {
    /// Site source directory; `config.toml` is looked up directly inside it.
    pub root: PathBuf,
    /// Directory the site is built into (passed to `Site::load` by `rebuild_site`).
    pub output_dir: PathBuf,
    /// Optional sandbox path copied onto the loaded site before each build.
    pub sandbox: Option<PathBuf>,
    /// Live-reload notifier: `rebuild_site` sends `()` after a successful build and each
    /// `/__livereload` WebSocket subscribes to it.
    pub reload_tx: broadcast::Sender<()>,
    /// Absolute preview URL (`http://127.0.0.1:<port>/preview`) used as the site's base URL
    /// during rebuilds.
    pub preview_base_url: String,
}
impl AppState {
    /// Name of the site configuration file, relative to `root`.
    const CONFIG_FILE: &'static str = "config.toml";

    /// The top-level `title` string from `<root>/config.toml`.
    ///
    /// Falls back to `"Zorto Site"` when the file cannot be read, is not valid TOML, or
    /// has no string-valued `title` key.
    fn site_title(&self) -> String {
        std::fs::read_to_string(self.root.join(Self::CONFIG_FILE))
            .ok()
            .and_then(|raw| toml::from_str::<toml::Value>(&raw).ok())
            .and_then(|parsed| parsed.get("title")?.as_str().map(String::from))
            .unwrap_or_else(|| "Zorto Site".to_string())
    }

    /// Whether `root` already holds a site, judged by the presence of `config.toml`.
    fn site_exists(&self) -> bool {
        let config = self.root.join(Self::CONFIG_FILE);
        config.exists()
    }

    /// Base URL for links into the preview: the bare mount path, not an absolute URL.
    fn site_base_url(&self) -> String {
        String::from(PREVIEW_MOUNT)
    }
}
/// Build the webapp router: dashboard, content editors, config, assets, build actions,
/// site preview, live-reload socket and first-run setup wizard, all sharing `state`.
///
/// NOTE(review): `/pages/delete/{*path}` (POST) and `/pages/{*path}` (GET/POST) overlap.
/// A page whose relative path starts with `delete/` will presumably be routed to the
/// more specific delete route rather than `pages::edit`/`pages::save` — confirm.
pub(crate) fn app(state: Arc<AppState>) -> Router {
    Router::new()
        // Dashboard.
        .route("/", get(dashboard::index))
        // Page CRUD; `{*path}` captures the page path including slashes.
        .route("/pages", get(pages::list))
        .route("/pages/new", get(pages::new_form).post(pages::create))
        .route("/pages/{*path}", get(pages::edit).post(pages::save))
        .route("/pages/delete/{*path}", post(pages::delete))
        // Section CRUD, mirroring the page routes.
        .route("/sections", get(sections::list))
        .route(
            "/sections/new",
            get(sections::new_form).post(sections::create),
        )
        .route("/sections/delete/{*path}", post(sections::delete))
        .route(
            "/sections/{*path}",
            get(sections::edit).post(sections::save),
        )
        // Site configuration and static assets.
        .route("/config", get(config::edit).post(config::save))
        .route("/assets", get(assets::list))
        .route("/assets/upload", post(assets::upload))
        .route("/assets/delete", post(assets::delete))
        // Build trigger and markdown preview rendering.
        .route("/build", post(build::trigger))
        .route("/_render-markdown", post(build::render_preview))
        // Built-site preview; the bare mount, trailing-slash and nested forms all go to
        // the same handler.
        .route("/preview", get(preview::serve))
        .route("/preview/", get(preview::serve))
        .route("/preview/{*path}", get(preview::serve))
        // Embedded htmx script and the live-reload WebSocket endpoint.
        .route("/static/htmx.min.js", get(serve_htmx))
        .route("/__livereload", any(livereload_ws))
        // First-run setup wizard.
        .route("/setup", get(onboarding::welcome))
        .route(
            "/setup/template",
            get(onboarding::template).post(onboarding::template_submit),
        )
        .route(
            "/setup/theme",
            get(onboarding::theme).post(onboarding::theme_submit),
        )
        .route("/setup/configure", get(onboarding::configure))
        .route("/setup/create", post(onboarding::create))
        .with_state(state)
}
/// Loopback-only (`127.0.0.1`) socket address for the webapp on `port`.
///
/// Pass `0` to let the OS pick a free port.
pub(crate) fn webapp_bind_addr(port: u16) -> SocketAddr {
    let loopback = std::net::Ipv4Addr::LOCALHOST;
    SocketAddr::from((loopback, port))
}
/// Run the local authoring webapp until Ctrl-C is pressed.
///
/// Binds `127.0.0.1:DEFAULT_WEBAPP_PORT`, or an OS-assigned port when that one is taken.
/// If `root` already contains `config.toml`, the site is built once up front (a failed
/// build is only reported). A browser is then opened at the dashboard, or at `/setup` when
/// there is no site yet, and requests are served until Ctrl-C triggers a graceful shutdown.
///
/// # Errors
/// Fails if the Tokio runtime cannot be created, if binding fails for a reason other than
/// the preferred port being in use, if the bound address cannot be read, or if the server
/// itself errors.
pub fn run_webapp(root: &Path, output_dir: &Path, sandbox: Option<&Path>) -> anyhow::Result<()> {
    let runtime = tokio::runtime::Runtime::new()?;
    runtime.block_on(async {
        // Prefer the well-known port; if another process owns it, let the OS choose.
        let preferred = webapp_bind_addr(DEFAULT_WEBAPP_PORT);
        let listener = match tokio::net::TcpListener::bind(preferred).await {
            Ok(bound) => bound,
            Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
                eprintln!(
                    "Port {} is in use, using a random available port...",
                    DEFAULT_WEBAPP_PORT
                );
                tokio::net::TcpListener::bind(webapp_bind_addr(0)).await?
            }
            Err(err) => return Err(err.into()),
        };
        let port = listener.local_addr()?.port();

        let (reload_tx, _) = broadcast::channel::<()>(RELOAD_CHANNEL_CAPACITY);
        let state = Arc::new(AppState {
            root: root.to_path_buf(),
            output_dir: output_dir.to_path_buf(),
            sandbox: sandbox.map(Path::to_path_buf),
            reload_tx,
            preview_base_url: format!("http://127.0.0.1:{port}{PREVIEW_MOUNT}"),
        });

        if state.site_exists() {
            if let Err(err) = rebuild_site(&state) {
                eprintln!("initial rebuild failed: {err}");
            }
        }
        // Send first-time users to the setup wizard instead of an empty dashboard.
        let landing = match state.site_exists() {
            true => "/",
            false => "/setup",
        };

        let router = app(state);
        println!("zorto webapp: http://localhost:{port}");
        // Opening a browser is best-effort; the server runs regardless.
        let _ = open::that(format!("http://localhost:{port}{landing}"));

        let shutdown = async {
            tokio::signal::ctrl_c().await.ok();
            println!("\nshutting down...");
        };
        axum::serve(listener, router)
            .with_graceful_shutdown(shutdown)
            .await?;
        Ok(())
    })
}
/// Serve the htmx library, embedded into the binary at compile time from `htmx.min.js`.
async fn serve_htmx() -> impl axum::response::IntoResponse {
    const HTMX_SOURCE: &str = include_str!("htmx.min.js");
    let headers = [("content-type", "application/javascript")];
    (headers, HTMX_SOURCE)
}
async fn livereload_ws(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_livereload(socket, state))
}
async fn handle_livereload(mut socket: WebSocket, state: Arc<AppState>) {
let mut rx = state.reload_tx.subscribe();
while rx.recv().await.is_ok() {
if socket
.send(Message::Text(String::from("reload").into()))
.await
.is_err()
{
break;
}
}
}
/// Reload the site from `state.root`, build it into `state.output_dir` with the preview
/// base URL and sandbox applied, then notify live-reload subscribers.
///
/// The third `Site::load` argument is passed as `true`; its meaning is defined in
/// `zorto_core` (NOTE(review): confirm what the flag toggles).
///
/// # Errors
/// Returns the stringified load or build error. No notification is sent in that case.
pub(crate) fn rebuild_site(state: &AppState) -> Result<(), String> {
    let mut site = zorto_core::site::Site::load(&state.root, &state.output_dir, true)
        .map_err(|e| e.to_string())?;
    site.sandbox = state.sandbox.clone();
    site.set_base_url(state.preview_base_url.clone());
    site.build().map_err(|e| e.to_string())?;
    // The result of the broadcast send is deliberately ignored.
    let _ = state.reload_tx.send(());
    Ok(())
}
/// HTML-escape `s` for embedding in generated markup; delegates to
/// `zorto_core::content::escape_html`.
pub(crate) fn escape(s: &str) -> String {
    use zorto_core::content::escape_html;
    escape_html(s)
}
/// Resolve the user-supplied `user_path` against `base` and ensure the result stays
/// inside `base`.
///
/// Existing paths, including symlinks, are fully canonicalized, so links that point
/// outside `base` are rejected. For a path that does not exist yet, its parent directory
/// must exist; the parent is canonicalized and the final file name is appended.
///
/// # Errors
/// - `base` cannot be canonicalized ("Base directory does not exist").
/// - An existing entry cannot be resolved, e.g. a dangling symlink ("Cannot resolve path").
/// - The parent of a new path does not exist, or the path has no parent or file name.
/// - The resolved path lies outside `base` ("Path traversal detected").
pub(crate) fn validate_path(base: &Path, user_path: &str) -> Result<PathBuf, String> {
    let joined = base.join(user_path);
    let canonical_base = base
        .canonicalize()
        .map_err(|e| format!("Base directory does not exist: {e}"))?;
    // `symlink_metadata` does not follow symlinks. A dangling symlink therefore counts as
    // existing and is sent through `canonicalize` (which fails on it). It is not accepted
    // by its lexical in-base name, which would let a later write land on the link's
    // outside target. `Path::exists` follows links and would misclassify it.
    let canonical = if joined.symlink_metadata().is_ok() {
        joined
            .canonicalize()
            .map_err(|e| format!("Cannot resolve path: {e}"))?
    } else {
        let parent = joined.parent().ok_or("Invalid path")?;
        let canonical_parent = parent
            .canonicalize()
            .map_err(|e| format!("Parent directory does not exist: {e}"))?;
        canonical_parent.join(joined.file_name().ok_or("Invalid filename")?)
    };
    if !canonical.starts_with(&canonical_base) {
        return Err("Path traversal detected".to_string());
    }
    Ok(canonical)
}
#[cfg(test)]
mod integration_tests;
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A fresh temporary directory, removed when the returned guard is dropped.
    fn scratch() -> TempDir {
        TempDir::new().unwrap()
    }

    #[test]
    fn test_webapp_bind_addr_defaults_to_localhost() {
        // Both the preferred port and the "any free port" fallback bind loopback only.
        for port in [DEFAULT_WEBAPP_PORT, 0] {
            let addr = webapp_bind_addr(port);
            assert_eq!(addr.ip().to_string(), "127.0.0.1");
            assert_eq!(addr.port(), port);
        }
    }

    #[test]
    fn test_validate_path_normal() {
        let dir = scratch();
        std::fs::write(dir.path().join("file.txt"), "hello").unwrap();
        assert!(validate_path(dir.path(), "file.txt").is_ok());
    }

    #[test]
    fn test_validate_path_traversal_blocked() {
        let dir = scratch();
        let site = dir.path().join("subdir");
        std::fs::create_dir_all(&site).unwrap();
        // Depending on the temp dir's depth this either escapes or hits a missing parent.
        let err = validate_path(&site, "../../../etc/passwd").unwrap_err();
        assert!(err.contains("Path traversal detected") || err.contains("does not exist"));
    }

    #[test]
    fn test_validate_path_dotdot_traversal() {
        let dir = scratch();
        let site = dir.path().join("site");
        let secret = dir.path().join("secret");
        for d in [&site, &secret] {
            std::fs::create_dir_all(d).unwrap();
        }
        std::fs::write(secret.join("data.txt"), "secret").unwrap();
        assert!(validate_path(&site, "../secret/data.txt").is_err());
    }

    #[test]
    fn test_validate_path_new_file_in_base() {
        let dir = scratch();
        assert!(validate_path(dir.path(), "new_file.txt").is_ok());
    }

    #[test]
    fn test_validate_path_subdirectory() {
        let dir = scratch();
        let nested = dir.path().join("sub");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(nested.join("file.txt"), "data").unwrap();
        assert!(validate_path(dir.path(), "sub/file.txt").is_ok());
    }

    #[test]
    fn test_validate_path_nonexistent_base() {
        let err = validate_path(Path::new("/nonexistent/base/dir"), "file.txt").unwrap_err();
        assert!(err.contains("does not exist"));
    }

    #[test]
    fn test_validate_path_symlink_escape() {
        let dir = scratch();
        let site = dir.path().join("site");
        let outside = dir.path().join("outside");
        std::fs::create_dir_all(&site).unwrap();
        std::fs::create_dir_all(&outside).unwrap();
        std::fs::write(outside.join("secret.txt"), "secret data").unwrap();
        #[cfg(unix)]
        {
            // A directory symlink inside the site that points outside must not be followed.
            std::os::unix::fs::symlink(&outside, site.join("escape")).unwrap();
            assert!(
                validate_path(&site, "escape/secret.txt").is_err(),
                "symlink escape should be blocked"
            );
        }
    }
}