calepin 0.0.12

A Rust CLI for preprocessing Typst documents with executable code chunks
use std::fs;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::Duration;

use anyhow::{anyhow, Context, Result};
use tiny_http::{Header, Method, Request, Response, Server, StatusCode};

use crate::cli::ServeArgs;

/// First port tried when the caller does not request a specific port.
const DEFAULT_PORT: u16 = 8000;
/// How many ports the automatic scan covers, starting at `DEFAULT_PORT`.
const AUTO_PORT_ATTEMPTS: u16 = 50;
/// URL that returns the live-reload status JSON; the injected reload script
/// fetches this same path.
const STATUS_ENDPOINT: &str = "/__calepin/status";

/// Shared build status polled by the injected reload script: a version that
/// bumps on successful rebuilds, and the latest rebuild error, if any.
pub(crate) struct LiveReload {
    /// Counter that starts at 1 and is incremented by `rebuilt`.
    version: AtomicU64,
    /// Message stored by `set_error`; reset to `None` by `rebuilt`.
    error: Mutex<Option<String>>,
}

impl LiveReload {
    /// Creates a shared status handle at version 1 with no recorded error.
    pub(crate) fn new() -> Arc<Self> {
        let state = Self {
            version: AtomicU64::new(1),
            error: Mutex::new(None),
        };
        Arc::new(state)
    }

    /// Records a successful rebuild: forgets any stored error, then advances
    /// the version.
    ///
    /// # Panics
    ///
    /// Panics if the error mutex is poisoned.
    pub(crate) fn rebuilt(&self) {
        {
            // Scope the guard so the lock is released before the version moves.
            let mut slot = self.error.lock().unwrap();
            *slot = None;
        }
        self.version.fetch_add(1, Ordering::Relaxed);
    }

    /// Stores `message` as the latest rebuild error without touching the version.
    ///
    /// # Panics
    ///
    /// Panics if the error mutex is poisoned.
    pub(crate) fn set_error(&self, message: String) {
        let mut slot = self.error.lock().unwrap();
        *slot = Some(message);
    }

    /// Renders the current version and error as a JSON object with the keys
    /// `version` and `error`.
    ///
    /// # Panics
    ///
    /// Panics if the error mutex is poisoned.
    fn status_json(&self) -> String {
        let version = self.version.load(Ordering::Relaxed);
        let error = self.error.lock().unwrap();
        let payload = serde_json::json!({
            "version": version,
            "error": *error,
        });
        payload.to_string()
    }
}

/// Handle to a server running on a background thread, as returned by `start`.
pub(crate) struct ServeHandle {
    /// Flag shared with the server thread; the request loop exits once it is set.
    stop: Arc<AtomicBool>,
    /// Join handle of the server thread; taken (left as `None`) when `stop` joins it.
    join: Option<JoinHandle<()>>,
}

impl ServeHandle {
    /// Raises the stop flag and blocks until the server thread has exited.
    ///
    /// The outcome of joining the thread (including a panic inside it) is
    /// deliberately discarded.
    pub(crate) fn stop(mut self) {
        self.stop.store(true, Ordering::Relaxed);
        let Some(worker) = self.join.take() else {
            return;
        };
        // The join result only reports whether the thread panicked.
        let _ = worker.join();
    }
}

/// Serves `args.dir` over HTTP on the calling thread, without live reload.
///
/// The stop flag handed to the request loop is created here and never set,
/// so this call does not return through the normal stop path.
///
/// # Errors
///
/// Fails if the directory cannot be validated or the address cannot be bound.
pub(crate) fn serve(args: ServeArgs) -> Result<()> {
    let root = validate_root(&args.dir)?;
    let (server, bind) = bind_server(&args.host, args.port)?;

    eprintln!("Serving {} at http://{bind}/", root.display());
    eprintln!("Press Ctrl+C to stop.");

    let never_stopped = Arc::new(AtomicBool::new(false));
    let live_reload = None;
    run_server(server, root, live_reload, never_stopped)
}

/// Starts serving `dir` on a background thread with live reload enabled and
/// returns a handle that can stop it.
///
/// A failure of the request loop on the background thread is reported as a
/// warning rather than returned.
///
/// # Errors
///
/// Fails if the directory cannot be validated or the address cannot be bound.
pub(crate) fn start(
    dir: &Path,
    host: &str,
    port: Option<u16>,
    live: Arc<LiveReload>,
) -> Result<ServeHandle> {
    let root = validate_root(dir)?;
    let (server, bind) = bind_server(host, port)?;
    eprintln!("Serving {} at http://{bind}/", root.display());

    let stop = Arc::new(AtomicBool::new(false));
    let worker = {
        // The thread gets its own clone; the original stays with the handle.
        let stop = Arc::clone(&stop);
        thread::spawn(move || match run_server(server, root, Some(live), stop) {
            Ok(()) => {}
            Err(error) => {
                cwarn!("serve failed: {}", error);
            }
        })
    };

    Ok(ServeHandle {
        stop,
        join: Some(worker),
    })
}

/// Canonicalizes `dir` and checks that the result is a directory.
///
/// # Errors
///
/// Fails if the path cannot be resolved or does not name a directory.
fn validate_root(dir: &Path) -> Result<PathBuf> {
    let resolved = fs::canonicalize(dir);
    let root = resolved.with_context(|| format!("failed to resolve {}", dir.display()))?;
    if root.is_dir() {
        Ok(root)
    } else {
        Err(anyhow!("serve path is not a directory: {}", root.display()))
    }
}

/// Formats `host` and `port` as a socket address string.
///
/// A bare IPv6 literal (one that contains `:` and is not already bracketed)
/// is wrapped in brackets, because `::1:8000` is not a parseable socket
/// address whereas `[::1]:8000` is. Other hosts are formatted as `host:port`.
fn format_bind_address(host: &str, port: u16) -> String {
    if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}

/// Binds the requested port, or scans for a free one starting at
/// `DEFAULT_PORT` when no port was requested. Returns the bound address.
///
/// # Errors
///
/// Fails if the address cannot be bound; when the requested port is already
/// taken the error suggests choosing another one with `--port`.
fn bind_server(host: &str, port: Option<u16>) -> Result<(Server, String)> {
    let Some(port) = port else {
        return auto_bind(host, DEFAULT_PORT, AUTO_PORT_ATTEMPTS);
    };

    let bind = format_bind_address(host, port);
    let server = Server::http(&bind).map_err(|error| {
        let message = error.to_string();
        if is_port_in_use(&message) {
            anyhow!(
                "failed to bind {bind}: address already in use; choose another port with `--port`"
            )
        } else {
            anyhow!("failed to bind {bind}: {message}")
        }
    })?;
    Ok((server, bind))
}

/// Tries up to `attempts` consecutive ports beginning at `start` and returns
/// the first server that binds, together with its address.
///
/// Ports that are already in use are skipped; any other bind failure aborts
/// the scan immediately. The scan never goes past port 65535.
///
/// # Errors
///
/// Fails on the first bind error that is not "address in use", or when every
/// scanned port is busy.
fn auto_bind(host: &str, start: u16, attempts: u16) -> Result<(Server, String)> {
    let mut last_message = String::new();
    // An inclusive range capped with `take` cannot overflow and, unlike an
    // exclusive saturating upper bound, can still reach port 65535.
    for port in (start..=u16::MAX).take(usize::from(attempts)) {
        let bind = format_bind_address(host, port);
        match Server::http(&bind) {
            Ok(server) => return Ok((server, bind)),
            Err(error) => {
                last_message = error.to_string();
                if !is_port_in_use(&last_message) {
                    return Err(anyhow!("failed to bind {bind}: {last_message}"));
                }
            }
        }
    }
    let last_port = start.saturating_add(attempts.saturating_sub(1));
    Err(anyhow!(
        "failed to bind {host}: no free port between {start} and {last_port}: {last_message}"
    ))
}

/// Reports whether a bind error message looks like "address already in use".
///
/// Matching is done on the message text: either the phrase `in use` or one
/// of the platform error codes listed below.
fn is_port_in_use(message: &str) -> bool {
    // EADDRINUSE: macOS (48), Linux (98), Windows (10048).
    const MARKERS: [&str; 4] = [
        "in use",
        "os error 48",
        "os error 98",
        "os error 10048",
    ];
    MARKERS.iter().any(|marker| message.contains(*marker))
}

fn run_server(
    server: Server,
    root: PathBuf,
    live: Option<Arc<LiveReload>>,
    stop: Arc<AtomicBool>,
) -> Result<()> {
    while !stop.load(Ordering::Relaxed) {
        match server.recv_timeout(Duration::from_millis(200)) {
            Ok(Some(request)) => {
                if let Err(error) = respond(request, &root, live.as_deref()) {
                    cwarn!("serve request failed: {}", error);
                }
            }
            Ok(None) => {}
            Err(error) => {
                cwarn!("serve connection failed: {}", error);
            }
        }
    }
    Ok(())
}

/// Answers a single HTTP request.
///
/// In order: the status endpoint returns the live-reload JSON (or a fixed
/// placeholder when live reload is off); methods other than GET and HEAD get
/// 405; paths rejected by `resolve_request_path` get 403; directories fall
/// back to their `index.html`; missing files get 404. Otherwise the file is
/// served with a guessed content type, and HTML gets the reload script
/// injected when live reload is enabled. Every response carries
/// `Cache-Control: no-store`.
///
/// # Errors
///
/// Fails if the file cannot be read, a header cannot be built, or the
/// response cannot be written.
fn respond(request: Request, root: &Path, live: Option<&LiveReload>) -> Result<()> {
    const PLAIN_TEXT: &str = "text/plain; charset=utf-8";

    if request.url() == STATUS_ENDPOINT {
        let status = match live {
            Some(live) => live.status_json(),
            None => "{\"version\":0,\"error\":null}".to_string(),
        };
        return send_text(request, 200, "application/json; charset=utf-8", &status);
    }

    let is_head = request.method() == &Method::Head;
    let is_get = request.method() == &Method::Get;
    if !(is_get || is_head) {
        return send_text(request, 405, PLAIN_TEXT, "Method Not Allowed");
    }

    let path = match resolve_request_path(root, request.url()) {
        Some(path) if path.is_dir() => path.join("index.html"),
        Some(path) => path,
        None => return send_text(request, 403, PLAIN_TEXT, "Forbidden"),
    };
    if !path.is_file() {
        return send_text(request, 404, PLAIN_TEXT, "Not Found");
    }

    let mime = mime_guess::from_path(&path)
        .first_or_octet_stream()
        .essence_str()
        .to_string();
    let mut body = fs::read(&path).with_context(|| format!("failed to read {}", path.display()))?;
    if live.is_some() && mime == "text/html" {
        body = inject_reload_script(body);
    }

    let content_type = content_type_header(&mime)?;
    let cache_control = no_store_header()?;
    // The two response types differ, so each branch sends its own response.
    if is_head {
        let response = Response::empty(StatusCode(200))
            .with_header(content_type)
            .with_header(cache_control);
        request.respond(response)?;
    } else {
        let response = Response::from_data(body)
            .with_status_code(StatusCode(200))
            .with_header(content_type)
            .with_header(cache_control);
        request.respond(response)?;
    }
    Ok(())
}

/// Sends `body` as a text response with the given status code and content
/// type, plus `Cache-Control: no-store`.
///
/// # Errors
///
/// Fails if a header cannot be built or the response cannot be written.
fn send_text(request: Request, status: u16, content_type: &str, body: &str) -> Result<()> {
    let content_type = content_type_header(content_type)?;
    let cache_control = no_store_header()?;
    let response = Response::from_string(body)
        .with_status_code(StatusCode(status))
        .with_header(content_type)
        .with_header(cache_control);
    request.respond(response)?;
    Ok(())
}

/// Builds a `Content-Type` header carrying `value`.
///
/// # Errors
///
/// Fails if the header cannot be constructed from `value`.
fn content_type_header(value: &str) -> Result<Header> {
    let Ok(header) = Header::from_bytes("Content-Type", value) else {
        return Err(anyhow!("failed to create content-type header"));
    };
    Ok(header)
}

/// Builds the `Cache-Control: no-store` header attached to every response.
///
/// # Errors
///
/// Fails if the header cannot be constructed.
fn no_store_header() -> Result<Header> {
    let Ok(header) = Header::from_bytes("Cache-Control", "no-store") else {
        return Err(anyhow!("failed to create cache-control header"));
    };
    Ok(header)
}

/// Inserts the live-reload `<script>` into an HTML document.
///
/// The script goes immediately before the first `</body>` tag (matched
/// without regard to ASCII case); when there is no such tag it is appended
/// to the end of the document.
fn inject_reload_script(mut body: Vec<u8>) -> Vec<u8> {
    const SCRIPT: &[u8] = br#"<script>
(() => {
  let version = null;
  const overlayId = "calepin-error-overlay";
  function setOverlay(message) {
    let overlay = document.getElementById(overlayId);
    if (!message) {
      if (overlay) overlay.remove();
      return;
    }
    if (!overlay) {
      overlay = document.createElement("pre");
      overlay.id = overlayId;
      overlay.style.cssText =
        "position:fixed;inset:0;z-index:2147483647;margin:0;padding:2rem;" +
        "overflow:auto;background:rgba(15,15,15,0.95);color:#ff8a8a;" +
        "font:14px/1.5 ui-monospace,monospace;white-space:pre-wrap;";
      document.body.appendChild(overlay);
    }
    overlay.textContent = "calepin build failed\n\n" + message;
  }
  async function poll() {
    try {
      const response = await fetch("/__calepin/status", { cache: "no-store" });
      if (!response.ok) return;
      const status = await response.json();
      setOverlay(status.error);
      if (version === null) version = status.version;
      else if (!status.error && status.version !== version) window.location.reload();
    } catch {}
  }
  window.setInterval(poll, 1000);
  poll();
})();
</script>"#;

    // With no closing tag the insertion point is the end, i.e. a plain append.
    let insert_at = find_case_insensitive(&body, b"</body>").unwrap_or(body.len());
    // Splicing over an empty range removes nothing and inserts the script there.
    drop(body.splice(insert_at..insert_at, SCRIPT.iter().copied()));
    body
}

/// Returns the byte offset of the first occurrence of `needle` in
/// `haystack`, comparing ASCII letters without regard to case.
///
/// An empty `needle` matches at offset 0. Returns `None` when `needle` does
/// not occur, including when it is longer than `haystack`.
fn find_case_insensitive(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    // `slice::windows` panics on a window size of zero, so answer the
    // empty-needle case before building the window iterator.
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window.eq_ignore_ascii_case(needle))
}

/// Maps a request target onto a path below `root`.
///
/// Everything from the first `?` onwards is dropped, the remainder is
/// percent-decoded, and its components are appended to `root` one by one.
/// Returns `None` when decoding fails or when any component is `..`, a root,
/// or a path prefix, so the result cannot climb out of `root` through its
/// components.
fn resolve_request_path(root: &Path, target: &str) -> Option<PathBuf> {
    let raw_path = match target.split_once('?') {
        Some((path, _query)) => path,
        None => target,
    };
    let decoded = percent_decode(raw_path)?;
    let relative = Path::new(decoded.trim_start_matches('/'));

    relative
        .components()
        .try_fold(root.to_path_buf(), |mut resolved, component| match component {
            Component::Normal(part) => {
                resolved.push(part);
                Some(resolved)
            }
            Component::CurDir => Some(resolved),
            Component::ParentDir | Component::RootDir | Component::Prefix(_) => None,
        })
}

/// Decodes `%XX` escapes in `value`.
///
/// Returns `None` when a `%` is not followed by two hexadecimal digits or
/// when the decoded bytes are not valid UTF-8. All other bytes are copied
/// through unchanged.
fn percent_decode(value: &str) -> Option<String> {
    let mut decoded = Vec::with_capacity(value.len());
    let mut input = value.bytes();
    while let Some(byte) = input.next() {
        if byte != b'%' {
            decoded.push(byte);
            continue;
        }
        // An escape needs exactly two more bytes, both hexadecimal digits.
        let high = hex_value(input.next()?)?;
        let low = hex_value(input.next()?)?;
        decoded.push(high * 16 + low);
    }
    String::from_utf8(decoded).ok()
}

/// Returns the numeric value (0-15) of an ASCII hexadecimal digit, accepting
/// both lowercase and uppercase letters, or `None` for any other byte.
fn hex_value(byte: u8) -> Option<u8> {
    // Subtracting the offset maps '0' to 0, 'a'/'A' to 10, and so on.
    let offset = match byte {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(byte - offset)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the status JSON currently reported by `live`.
    fn status_of(live: &LiveReload) -> serde_json::Value {
        serde_json::from_str(&live.status_json()).unwrap()
    }

    /// Plain paths resolve; literal and percent-encoded `..` are refused.
    #[test]
    fn request_path_rejects_traversal() {
        let root = Path::new("/tmp/site");
        let allowed = resolve_request_path(root, "/index.html");
        assert!(allowed.is_some());
        for target in ["/../secret", "/%2e%2e/secret"] {
            assert!(resolve_request_path(root, target).is_none());
        }
    }

    /// The script mentions the status endpoint and overlay id and lands
    /// directly before the closing body tag.
    #[test]
    fn injects_reload_script_before_body_close() {
        let page = b"<html><body>Hello</body></html>".to_vec();
        let text = String::from_utf8(inject_reload_script(page)).unwrap();
        for expected in [STATUS_ENDPOINT, "calepin-error-overlay", "</script></body>"] {
            assert!(text.contains(expected));
        }
    }

    /// An error shows up in the status, and a rebuild clears it while
    /// advancing the version.
    #[test]
    fn live_reload_reports_error_then_clears_on_rebuild() {
        let live = LiveReload::new();
        let before = status_of(&live);
        assert!(before["error"].is_null());
        let initial_version = before["version"].as_u64().unwrap();

        live.set_error("boom".to_string());
        assert_eq!(status_of(&live)["error"], "boom");

        live.rebuilt();
        let after = status_of(&live);
        assert!(after["error"].is_null());
        assert!(after["version"].as_u64().unwrap() > initial_version);
    }

    /// Starting the scan on an occupied port must yield a different address.
    #[test]
    fn auto_bind_skips_busy_port() {
        let occupied = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let occupied_port = occupied.local_addr().unwrap().port();

        let (_server, bind) = auto_bind("127.0.0.1", occupied_port, 10).unwrap();

        assert_ne!(bind, format!("127.0.0.1:{occupied_port}"));
    }
}