Skip to main content

live_server/http_layer/
server.rs

1use axum::{
2    Router,
3    body::Body,
4    extract::{
5        Request, State, WebSocketUpgrade,
6        ws::{Message, Utf8Bytes, WebSocket},
7    },
8    http::{HeaderMap, HeaderValue, StatusCode, header},
9    routing::get,
10};
11use futures::{sink::SinkExt, stream::StreamExt};
12use mime_guess::mime;
13use std::{
14    fs,
15    io::ErrorKind,
16    path::{Path, PathBuf},
17    sync::Arc,
18};
19use tokio::{net::TcpListener, sync::broadcast};
20
21use crate::{
22    http_layer::template::{error_html, index_html},
23    utils::is_ignored,
24};
25
26/// JS script containing a function that takes in the address and connects to the websocket.
27const WEBSOCKET_FUNCTION: &str = include_str!("../templates/websocket.js");
28
29/// JS script to inject to the HTML on reload so the client
30/// knows it's a successful reload.
31const RELOAD_PAYLOAD: &str = include_str!("../templates/reload.js");
32
33pub(crate) async fn serve(tcp_listener: TcpListener, router: Router) {
34    axum::serve(tcp_listener, router).await.unwrap();
35}
36
/// User-facing options controlling how the live server serves files.
///
/// The `Default` impl enables `index_listing` and leaves `hard_reload` and
/// `auto_ignore` disabled.
#[derive(Debug, Clone)]
pub struct Options {
    /// Always hard reload the page instead of hot-reload
    pub hard_reload: bool,
    /// Show page list of the current URL if `index.html` does not exist
    pub index_listing: bool,
    /// Ignore hidden and ignored files
    pub auto_ignore: bool,
}
45
46pub(crate) struct AppState {
47    /// Always hard reload the page instead of hot-reload
48    pub(crate) hard_reload: bool,
49    /// Show page list of the current URL if `index.html` does not exist
50    pub(crate) index_listing: bool,
51    /// Ignore hidden and ignored files
52    pub(crate) auto_ignore: bool,
53    pub(crate) tx: Arc<broadcast::Sender<()>>,
54    pub(crate) root: PathBuf,
55}
56
57impl Default for Options {
58    fn default() -> Self {
59        Self {
60            hard_reload: false,
61            index_listing: true,
62            auto_ignore: false,
63        }
64    }
65}
66
67pub(crate) fn create_server(state: AppState) -> Router {
68    let tx = state.tx.clone();
69    Router::new()
70        .route("/", get(static_assets))
71        .route("/{*path}", get(static_assets))
72        .route(
73            "/live-server-ws",
74            get(|ws: WebSocketUpgrade| async move {
75                ws.on_failed_upgrade(|error| {
76                    log::error!("Failed to upgrade websocket: {error}");
77                })
78                .on_upgrade(|socket: WebSocket| on_websocket_upgrade(socket, tx))
79            }),
80        )
81        .with_state(Arc::new(state))
82}
83
84async fn on_websocket_upgrade(socket: WebSocket, tx: Arc<broadcast::Sender<()>>) {
85    let (mut sender, mut receiver) = socket.split();
86    let mut rx = tx.subscribe();
87    let mut send_task = tokio::spawn(async move {
88        while rx.recv().await.is_ok() {
89            sender
90                .send(Message::Text(Utf8Bytes::default()))
91                .await
92                .unwrap();
93        }
94    });
95    let mut recv_task =
96        tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
97    tokio::select! {
98        _ = (&mut send_task) => recv_task.abort(),
99        _ = (&mut recv_task) => send_task.abort(),
100    };
101}
102
103fn get_index_listing(uri_path: &str, root: &Path, auto_ignore: bool) -> String {
104    let is_root = uri_path == "/";
105    let path = root.join(&uri_path[1..]);
106    let entries = fs::read_dir(path).unwrap();
107    let mut entry_names = entries
108        .into_iter()
109        .filter_map(|e| {
110            if let Ok(entry) = e {
111                if auto_ignore {
112                    match is_ignored(root, &entry.path()) {
113                        Ok(ignored) => {
114                            if ignored {
115                                return None;
116                            }
117                        }
118                        Err(err) => {
119                            log::error!("Failed to check ignore files: {err}");
120                            // Do nothing if we cannot know if it's an ignored entry
121                            return None;
122                        }
123                    }
124                }
125                let is_dir = entry.metadata().ok()?.is_dir();
126                let trailing = if is_dir { "/" } else { "" };
127                entry
128                    .file_name()
129                    .to_str()
130                    .map(|name| format!("{name}{trailing}"))
131            } else {
132                None
133            }
134        })
135        .collect::<Vec<String>>();
136    entry_names.sort();
137    if !is_root {
138        entry_names.insert(0, "..".to_string());
139    }
140    entry_names
141        .into_iter()
142        .map(|en| format!("<li><a href=\"{en}\">{en}</a></li>"))
143        .collect::<Vec<String>>()
144        .join("\n")
145}
146
147async fn static_assets(
148    state: State<Arc<AppState>>,
149    req: Request<Body>,
150) -> (StatusCode, HeaderMap, Body) {
151    let is_reload = req.uri().query().is_some_and(|x| x == "reload");
152
153    // Get the path and mime of the static file.
154    let uri_path = req.uri().path();
155    // Avoid [directory traversal attack](https://en.wikipedia.org/wiki/Directory_traversal_attack).
156    if uri_path.starts_with("//") {
157        let redirect = format!("/{}", uri_path.trim_start_matches("/"));
158        let mut headers = HeaderMap::new();
159        headers.append(header::LOCATION, HeaderValue::from_str(&redirect).unwrap());
160        return (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty());
161    }
162    let mut path = state.root.join(&uri_path[1..]);
163    let is_accessing_dir = path.is_dir();
164    if is_accessing_dir {
165        if !uri_path.ends_with('/') {
166            // redirect so parent links work correctly
167            let redirect = format!("{uri_path}/");
168            let mut headers = HeaderMap::new();
169            headers.append(header::LOCATION, HeaderValue::from_str(&redirect).unwrap());
170            return (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty());
171        }
172        path.push("index.html");
173    }
174    let mime = mime_guess::from_path(&path).first_or_text_plain();
175
176    let mut headers = HeaderMap::new();
177    let content_type = if mime.type_() == mime::TEXT {
178        format!("{}; charset=utf-8", mime.as_ref())
179    } else {
180        mime.as_ref().to_string()
181    };
182    headers.append(
183        header::CONTENT_TYPE,
184        HeaderValue::from_str(&content_type).unwrap(),
185    );
186
187    if state.auto_ignore {
188        match is_ignored(&state.root, &path) {
189            Ok(ignored) => {
190                if ignored {
191                    let err_msg =
192                        "Unable to access ignored or hidden file, because `--ignore` is enabled";
193                    let body = generate_error_body(err_msg, state.hard_reload, is_reload);
194
195                    return (StatusCode::FORBIDDEN, HeaderMap::new(), body);
196                }
197            }
198            Err(err) => {
199                let err_msg = format!("Failed to check ignore files: {err}");
200                let body = generate_error_body(&err_msg, state.hard_reload, is_reload);
201                log::error!("{err_msg}");
202
203                return (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new(), body);
204            }
205        }
206    }
207
208    // Read the file.
209    let mut file = match fs::read(&path) {
210        Ok(file) => file,
211        Err(err) => {
212            let status_code = match err.kind() {
213                ErrorKind::NotFound => {
214                    if state.index_listing && is_accessing_dir {
215                        let script = format_script(state.hard_reload, is_reload, false);
216                        let html = index_html(
217                            uri_path,
218                            &script,
219                            &get_index_listing(uri_path, &state.root, state.auto_ignore),
220                        );
221                        return (StatusCode::OK, headers, html);
222                    }
223                    StatusCode::NOT_FOUND
224                }
225                _ => StatusCode::INTERNAL_SERVER_ERROR,
226            };
227            match path.to_str() {
228                Some(path) => log::warn!("Failed to read \"{path}\": {err}"),
229                None => log::warn!("Failed to read file with invalid path: {err}"),
230            }
231            return (
232                status_code,
233                headers,
234                if mime == "text/html" {
235                    generate_error_body(&err.to_string(), state.hard_reload, is_reload)
236                } else {
237                    Body::from(err.to_string())
238                },
239            );
240        }
241    };
242
243    // Construct the response.
244    if mime == "text/html" {
245        let text = match String::from_utf8(file) {
246            Ok(text) => text,
247            Err(err) => {
248                log::error!("Failed to read {path:?} as utf-8: {err}");
249                let html = generate_error_body(&err.to_string(), state.hard_reload, is_reload);
250                return (StatusCode::INTERNAL_SERVER_ERROR, headers, html);
251            }
252        };
253        let script = format_script(state.hard_reload, is_reload, false);
254        file = format!("{text}{script}").into_bytes();
255    } else if state.hard_reload {
256        // allow client to cache assets for a smoother reload.
257        // client handles preloading to refresh cache before reloading.
258        headers.append(
259            header::CACHE_CONTROL,
260            HeaderValue::from_str("max-age=30").unwrap(),
261        );
262    }
263
264    (StatusCode::OK, headers, Body::from(file))
265}
266
267/// Inject the address into the websocket script and wrap it in a script tag
268fn format_script(hard_reload: bool, is_reload: bool, is_error: bool) -> String {
269    match (is_reload, is_error) {
270        // successful reload, inject the reload payload
271        (true, false) => format!("<script>{RELOAD_PAYLOAD}</script>"),
272        // failed reload, don't inject anything so the client polls again
273        (true, true) => String::new(),
274        // normal connection, inject the websocket client
275        _ => {
276            let hard = if hard_reload { "true" } else { "false" };
277            format!(r#"<script>{WEBSOCKET_FUNCTION}({hard})</script>"#)
278        }
279    }
280}
281
282fn generate_error_body(err_msg: &str, hard_reload: bool, is_reload: bool) -> Body {
283    let script = format_script(hard_reload, is_reload, true);
284    error_html(&script, err_msg)
285}