Skip to main content

rustpress_dev/
lib.rs

1use std::fs::File;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4use std::sync::mpsc;
5use std::sync::{
6    atomic::{AtomicU64, Ordering},
7    Arc,
8};
9use std::thread;
10use std::time::{Duration, Instant};
11
12use anyhow::Result;
13use notify::{Config as NotifyConfig, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
14use rustpress_core::{build_site, BuildOptions, Config};
15use tiny_http::{Header, Response, Server, StatusCode};
16
/// Options shared by the preview server and the dev server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServeOptions {
    /// Path to the site's config file; its parent directory is treated as the
    /// project root when resolving the source and output directories.
    pub config_path: PathBuf,
    /// Host or interface the HTTP server binds to.
    pub host: String,
    /// TCP port the HTTP server listens on.
    pub port: u16,
}
23
24pub fn serve_preview(options: ServeOptions) -> Result<()> {
25    let config = Config::load(&options.config_path)?;
26    let root = config_root(&options.config_path).join(config.out_dir);
27    serve_dir(&root, &options.host, options.port, None)
28}
29
30pub fn serve_dev(options: ServeOptions) -> Result<()> {
31    let build_options = BuildOptions::new(options.config_path.clone());
32    let result = build_site(build_options.clone())?;
33    let root = result.out_dir;
34    let config_path = options.config_path.clone();
35    let refresh_version = Arc::new(AtomicU64::new(1));
36
37    let (tx, rx) = mpsc::channel();
38    let mut watcher = RecommendedWatcher::new(tx, NotifyConfig::default())?;
39    let config = Config::load(&config_path)?;
40    let project_root = config_root(&config_path);
41    watcher.watch(
42        &project_root.join(&config.src_dir),
43        RecursiveMode::Recursive,
44    )?;
45    watcher.watch(&config_path, RecursiveMode::NonRecursive)?;
46
47    let rebuild_config_path = config_path.clone();
48    let rebuild_refresh_version = Arc::clone(&refresh_version);
49    thread::spawn(move || {
50        let mut last = Instant::now() - Duration::from_secs(2);
51        while let Ok(event) = rx.recv() {
52            let Ok(event) = event else { continue };
53            if !matches!(
54                event.kind,
55                EventKind::Create(_) | EventKind::Modify(_) | EventKind::Remove(_)
56            ) {
57                continue;
58            }
59            if last.elapsed() < Duration::from_millis(250) {
60                continue;
61            }
62            last = Instant::now();
63            match build_site(BuildOptions::new(rebuild_config_path.clone())) {
64                Ok(result) => {
65                    rebuild_refresh_version.fetch_add(1, Ordering::SeqCst);
66                    eprintln!("rebuilt {} page(s)", result.page_count);
67                }
68                Err(err) => eprintln!("rebuild failed: {err:?}"),
69            }
70        }
71    });
72
73    println!(
74        "RustPress dev server: http://{}:{}/",
75        options.host, options.port
76    );
77    serve_dir(&root, &options.host, options.port, Some(refresh_version))
78}
79
80fn serve_dir(
81    root: &Path,
82    host: &str,
83    port: u16,
84    refresh_version: Option<Arc<AtomicU64>>,
85) -> Result<()> {
86    let address = format!("{host}:{port}");
87    let server = Server::http(&address).map_err(|err| anyhow::anyhow!("{err}"))?;
88    println!("Serving {} at http://{address}/", root.display());
89
90    for request in server.incoming_requests() {
91        let url = request.url().split('?').next().unwrap_or("/");
92        if url == "/__rustpress/version" {
93            let version = refresh_version
94                .as_ref()
95                .map(|version| version.load(Ordering::SeqCst))
96                .unwrap_or(0);
97            let mut response = Response::from_string(version.to_string());
98            if let Some(header) = header("Cache-Control", "no-store") {
99                response = response.with_header(header);
100            }
101            if let Some(header) = header("Content-Type", "text/plain; charset=utf-8") {
102                response = response.with_header(header);
103            }
104            request.respond(response)?;
105            continue;
106        }
107
108        let path = resolve_path(root, url);
109        let response = match File::open(&path) {
110            Ok(mut file) => {
111                let mut bytes = Vec::new();
112                file.read_to_end(&mut bytes)?;
113                if refresh_version.is_some()
114                    && path.extension().and_then(|value| value.to_str()) == Some("html")
115                {
116                    bytes = inject_live_reload(bytes);
117                }
118                let mut response = Response::from_data(bytes);
119                if let Some(header) = content_type_header(&path) {
120                    response = response.with_header(header);
121                }
122                response
123            }
124            Err(_) => Response::from_string("Not Found").with_status_code(StatusCode(404)),
125        };
126        request.respond(response)?;
127    }
128    Ok(())
129}
130
/// Client-side snippet that the dev server injects into served HTML pages.
///
/// Behaviour of the script:
/// - It polls `/__rustpress/version` every 700 ms with HTTP caching disabled.
/// - It remembers the first version it sees.
/// - When a later poll returns a different value, it dispatches a
///   `rustpress:refresh` event on `window` and then reloads the page.
/// - Fetch failures (e.g. while the server is restarting) are swallowed, so
///   polling simply continues.
const LIVE_RELOAD_SCRIPT: &str = r#"<script>
(() => {
  let current = null;
  async function check() {
    try {
      const response = await fetch("/__rustpress/version", { cache: "no-store" });
      const next = await response.text();
      if (current === null) current = next;
      else if (next !== current) {
        window.dispatchEvent(new CustomEvent("rustpress:refresh"));
        location.reload();
      }
    } catch (_) {}
  }
  setInterval(check, 700);
  check();
})();
</script>"#;
149
150fn inject_live_reload(bytes: Vec<u8>) -> Vec<u8> {
151    let mut html = match String::from_utf8(bytes) {
152        Ok(html) => html,
153        Err(err) => return err.into_bytes(),
154    };
155
156    if let Some(index) = html.rfind("</body>") {
157        html.insert_str(index, LIVE_RELOAD_SCRIPT);
158    } else {
159        html.push_str(LIVE_RELOAD_SCRIPT);
160    }
161    html.into_bytes()
162}
163
/// Maps a request path to a file under `root`.
///
/// The caller is expected to have stripped the query string already. Steps:
/// 1. The path is percent-decoded, since browsers escape spaces and non-ASCII
///    file names.
/// 2. It is normalised segment by segment:
///    - `.` and empty segments are dropped;
///    - `..` removes the previous segment but can never climb above `root`;
///    - root and drive prefixes are ignored.
///
///    This keeps requests such as `/../../etc/passwd`, or the encoded
///    `/%2e%2e/...`, confined to `root`.
/// 3. Paths ending in `/`, the bare root, and existing directories resolve to
///    their `index.html`.
fn resolve_path(root: &Path, url: &str) -> PathBuf {
    let decoded = percent_decode(url);

    let mut segments = Vec::new();
    for component in Path::new(&decoded).components() {
        match component {
            std::path::Component::Normal(segment) => segments.push(segment),
            std::path::Component::ParentDir => {
                // Popping an empty list is a no-op, which clamps `..` at `root`.
                segments.pop();
            }
            std::path::Component::CurDir
            | std::path::Component::RootDir
            | std::path::Component::Prefix(_) => {}
        }
    }

    let wants_index = decoded.ends_with('/') || segments.is_empty();
    let mut candidate = root.to_path_buf();
    candidate.extend(segments);
    if wants_index || candidate.is_dir() {
        candidate.push("index.html");
    }
    candidate
}

/// Decodes `%XX` escapes in a URL path.
///
/// Malformed escapes are kept verbatim, and invalid UTF-8 in the result is
/// replaced with U+FFFD.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let hex_digit = |index: usize| {
        bytes
            .get(index)
            .and_then(|byte| char::from(*byte).to_digit(16))
    };

    let mut decoded = Vec::with_capacity(bytes.len());
    let mut index = 0;
    while index < bytes.len() {
        if bytes[index] == b'%' {
            if let (Some(high), Some(low)) = (hex_digit(index + 1), hex_digit(index + 2)) {
                // Two hex digits encode at most 0xFF, so the cast is lossless.
                decoded.push((high * 16 + low) as u8);
                index += 3;
                continue;
            }
        }
        decoded.push(bytes[index]);
        index += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
176
177fn content_type_header(path: &Path) -> Option<Header> {
178    let content_type = match path.extension().and_then(|value| value.to_str()) {
179        Some("html") => "text/html; charset=utf-8",
180        Some("css") => "text/css; charset=utf-8",
181        Some("js") => "text/javascript; charset=utf-8",
182        Some("json") => "application/json; charset=utf-8",
183        Some("wasm") => "application/wasm",
184        Some("br") => "application/octet-stream",
185        Some("svg") => "image/svg+xml",
186        Some("png") => "image/png",
187        Some("jpg") | Some("jpeg") => "image/jpeg",
188        Some("webp") => "image/webp",
189        _ => return None,
190    };
191    header("Content-Type", content_type)
192}
193
194fn header(name: &str, value: &str) -> Option<Header> {
195    Header::from_bytes(name, value).ok()
196}
197
/// Returns the project root: the directory containing the config file.
///
/// Two cases fall back to `.`:
/// - a bare file name such as `rustpress.toml`, whose parent is the empty path;
/// - a path with no parent at all, such as `""` or `/`.
///
/// Callers therefore always get a usable directory to join paths onto or to watch.
fn config_root(config_path: &Path) -> PathBuf {
    config_path
        .parent()
        .filter(|parent| !parent.as_os_str().is_empty())
        .map_or_else(|| PathBuf::from("."), Path::to_path_buf)
}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory-style URLs resolve to `index.html`; file URLs map straight through.
    #[test]
    fn resolves_directory_routes() {
        let site = Path::new("/tmp/site");
        let routes = [
            ("/", "/tmp/site/index.html"),
            ("/guide/", "/tmp/site/guide/index.html"),
            ("/assets/rustpress.css", "/tmp/site/assets/rustpress.css"),
        ];
        for (url, expected) in routes {
            assert_eq!(resolve_path(site, url), Path::new(expected), "route {url}");
        }
    }

    /// The live-reload snippet lands right before `</body>` and references both
    /// the version endpoint and the custom refresh event.
    #[test]
    fn injects_live_reload_before_body_close() {
        let page = b"<html><body>Docs</body></html>".to_vec();
        let injected = String::from_utf8(inject_live_reload(page)).unwrap();

        for needle in ["/__rustpress/version", "rustpress:refresh", "</script></body>"] {
            assert!(injected.contains(needle), "missing {needle:?}");
        }
    }
}