Skip to main content

karbon_framework/hmr/
hmr_handler.rs

1use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
2use axum::response::Response;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::broadcast;
7
8/// Watches files and notifies connected browsers via WebSocket.
9///
10/// ```ignore
11/// use framework::hmr::HmrServer;
12///
13/// let hmr = HmrServer::new()
14///     .watch("templates/")
15///     .watch("static/");
16///
17/// // Mount the HMR endpoints
18/// let app = Router::new()
19///     .route("/_hmr/ws", get({
20///         let hmr = hmr.clone();
21///         move |ws: WebSocketUpgrade| hmr.ws_handler(ws)
22///     }))
23///     .layer(hmr.inject_script_layer());
24///
25/// // Start watching (spawns a background task)
26/// hmr.start();
27/// ```
28#[derive(Clone)]
29pub struct HmrServer {
30    tx: broadcast::Sender<HmrEvent>,
31    watch_paths: Arc<Vec<PathBuf>>,
32}
33
/// A change notification fanned out to every connected browser.
///
/// `path` is the changed file's path as produced by the watcher.
#[derive(Clone, Debug)]
enum HmrEvent {
    /// Some non-stylesheet file changed; the browser must reload the page.
    FullReload { path: String },
    /// A stylesheet changed; the browser can re-fetch styles without reloading.
    CssChanged { path: String },
}
41
42impl HmrServer {
43    pub fn new() -> Self {
44        let (tx, _) = broadcast::channel(64);
45        Self {
46            tx,
47            watch_paths: Arc::new(Vec::new()),
48        }
49    }
50
51    /// Add a directory to watch for changes
52    pub fn watch(mut self, path: &str) -> Self {
53        Arc::get_mut(&mut self.watch_paths)
54            .expect("watch() must be called before start()")
55            .push(PathBuf::from(path));
56        self
57    }
58
59    /// Start watching files in a background task.
60    /// Uses polling (checks every 500ms) for maximum cross-platform compatibility.
61    pub fn start(&self) {
62        let tx = self.tx.clone();
63        let paths = self.watch_paths.clone();
64
65        tokio::spawn(async move {
66            let mut last_modified = std::collections::HashMap::new();
67            let mut scan_count = 0u64;
68
69            loop {
70                let mut current_paths = std::collections::HashSet::new();
71
72                for watch_path in paths.iter() {
73                    if let Ok(entries) = walk_dir(watch_path) {
74                        for entry in entries {
75                            let modified = entry
76                                .metadata()
77                                .ok()
78                                .and_then(|m| m.modified().ok());
79
80                            let Some(modified) = modified else { continue };
81                            let path_str = entry.path().display().to_string();
82                            current_paths.insert(path_str.clone());
83
84                            let changed = match last_modified.get(&path_str) {
85                                Some(prev) => *prev != modified,
86                                None => false, // first scan, don't trigger
87                            };
88
89                            last_modified.insert(path_str.clone(), modified);
90
91                            if changed {
92                                let event = if path_str.ends_with(".css") {
93                                    HmrEvent::CssChanged { path: path_str }
94                                } else {
95                                    HmrEvent::FullReload { path: path_str }
96                                };
97                                let _ = tx.send(event);
98                            }
99                        }
100                    }
101                }
102
103                // Purge deleted files every 20 scans (~10s) to prevent unbounded growth
104                scan_count += 1;
105                if scan_count % 20 == 0 {
106                    last_modified.retain(|k, _| current_paths.contains(k));
107                }
108
109                tokio::time::sleep(Duration::from_millis(500)).await;
110            }
111        });
112    }
113
114    /// WebSocket handler for HMR clients
115    pub async fn ws_handler(self, ws: WebSocketUpgrade) -> Response {
116        ws.on_upgrade(move |socket| self.handle_ws(socket))
117    }
118
119    async fn handle_ws(self, mut socket: WebSocket) {
120        let mut rx = self.tx.subscribe();
121
122        loop {
123            tokio::select! {
124                msg = rx.recv() => {
125                    match msg {
126                        Ok(HmrEvent::CssChanged { path }) => {
127                            let json = serde_json::json!({ "type": "css", "path": path });
128                            if socket.send(Message::Text(json.to_string().into())).await.is_err() {
129                                break;
130                            }
131                        }
132                        Ok(HmrEvent::FullReload { path }) => {
133                            let json = serde_json::json!({ "type": "reload", "path": path });
134                            if socket.send(Message::Text(json.to_string().into())).await.is_err() {
135                                break;
136                            }
137                        }
138                        Err(broadcast::error::RecvError::Lagged(_)) => continue,
139                        Err(_) => break,
140                    }
141                }
142                msg = socket.recv() => {
143                    match msg {
144                        Some(Ok(Message::Close(_))) | None => break,
145                        _ => {}
146                    }
147                }
148            }
149        }
150    }
151
152    /// Returns the client-side script tag to inject into HTML pages in dev mode.
153    ///
154    /// ```ignore
155    /// // In your HTML template or layout:
156    /// if cfg!(debug_assertions) {
157    ///     format!("{}\n{}", body_html, hmr.client_script())
158    /// }
159    /// ```
160    pub fn client_script(&self) -> String {
161        format!("<script>{}</script>", CLIENT_JS)
162    }
163}
164
165impl Default for HmrServer {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
/// Maximum directory nesting below a watch root that [`walk_dir`] descends into.
const MAX_WALK_DEPTH: usize = 10;
/// Maximum number of files [`walk_dir`] collects for a single watch root.
const MAX_WALK_FILES: usize = 10_000;

/// Recursively collects the regular files under `path`.
///
/// Hidden entries (names starting with `.`, e.g. `.git/` or editor swap files
/// such as `.page.html.swp`) and `node_modules`/`target` directories are
/// skipped, and symlinks are not followed. The walk stops after
/// [`MAX_WALK_DEPTH`] levels of nesting or [`MAX_WALK_FILES`] files. A `path`
/// that is not a directory (including one that does not exist) yields an
/// empty list.
///
/// # Errors
///
/// Returns the I/O error from reading `path` itself. Failures further down —
/// an entry deleted mid-walk, an unreadable subdirectory — skip just that
/// entry instead of discarding everything collected so far.
fn walk_dir(path: &std::path::Path) -> Result<Vec<std::fs::DirEntry>, std::io::Error> {
    let mut entries = Vec::new();
    walk_dir_inner(path, 0, &mut entries)?;
    Ok(entries)
}

/// Depth-first worker for [`walk_dir`]: appends files under `path` (at nesting
/// level `depth`) to `entries`, never growing it beyond [`MAX_WALK_FILES`].
fn walk_dir_inner(
    path: &std::path::Path,
    depth: usize,
    entries: &mut Vec<std::fs::DirEntry>,
) -> Result<(), std::io::Error> {
    if depth > MAX_WALK_DEPTH || entries.len() >= MAX_WALK_FILES || !path.is_dir() {
        return Ok(());
    }
    for entry in std::fs::read_dir(path)? {
        // Checked per entry so that a subtree which filled the budget also stops
        // this level, keeping `entries.len() <= MAX_WALK_FILES`.
        if entries.len() >= MAX_WALK_FILES {
            break;
        }
        // Entries can vanish between `read_dir` and inspection (editors' temp
        // files do this constantly); skip them rather than abort the whole walk.
        let Ok(entry) = entry else { continue };
        let Ok(file_type) = entry.file_type() else { continue };

        let name = entry.file_name();
        let name = name.to_string_lossy();
        if name.starts_with('.') {
            continue;
        }

        if file_type.is_dir() {
            if name == "node_modules" || name == "target" {
                continue;
            }
            // Best-effort: an unreadable subdirectory is skipped, keeping whatever
            // it yielded before failing.
            let _ = walk_dir_inner(&entry.path(), depth + 1, entries);
        } else if file_type.is_file() {
            entries.push(entry);
        }
        // Anything else (symlinks, sockets, ...) is ignored: `DirEntry::file_type`
        // does not follow symlinks, so linked files and directories are not visited.
    }
    Ok(())
}
207
/// Client-side HMR runtime embedded by [`HmrServer::client_script`].
///
/// Connects to `/_hmr/ws` on the page's own host (`wss` on HTTPS pages),
/// re-fetches same-origin stylesheets on `css` messages, reloads the page on
/// `reload` messages, and retries the connection 2s after it closes.
/// Cache-busting keeps each stylesheet's existing query string and leaves
/// cross-origin stylesheets (CDNs, web fonts) alone, since rewriting their
/// URLs can break them (e.g. dropping a `?family=` parameter).
const CLIENT_JS: &str = r#"
(function() {
    if (typeof window === 'undefined') return;
    const url = (location.protocol === 'https:' ? 'wss' : 'ws') + '://' + location.host + '/_hmr/ws';
    let ws;
    function connect() {
        ws = new WebSocket(url);
        ws.onmessage = function(e) {
            try {
                const msg = JSON.parse(e.data);
                if (msg.type === 'css') {
                    // Hot-swap CSS: bump a cache-busting param on each same-origin
                    // stylesheet, keeping any query params it already has.
                    document.querySelectorAll('link[rel="stylesheet"]').forEach(function(link) {
                        if (!link.href) return;
                        const u = new URL(link.href, location.href);
                        if (u.origin !== location.origin) return;
                        u.searchParams.set('_hmr', String(Date.now()));
                        link.href = u.toString();
                    });
                    console.log('[HMR] CSS updated:', msg.path);
                } else if (msg.type === 'reload') {
                    console.log('[HMR] Reloading:', msg.path);
                    location.reload();
                }
            } catch {}
        };
        ws.onclose = function() { setTimeout(connect, 2000); };
    }
    connect();
})();
"#;