workers-rsx 0.1.0

A JSX-like templating engine for Cloudflare Workers
Documentation
use serde::Serialize;
use std::collections::HashMap;

use crate::state::state_update_script;

/// Extract all top-level elements with `id="..."` from an HTML string.
/// Returns a map of id → full element HTML (including the tag itself).
///
/// "Top-level" means outermost among id'd elements: once an element with an
/// id has been captured, scanning resumes after its end, so id'd descendants
/// stay inside the parent's fragment and get no entry of their own. If the
/// same id occurs more than once, the last occurrence wins.
///
/// This is a lightweight scanner, not a full HTML parser. It handles:
/// - `id="value"` and `id='value'` attributes
/// - Nested elements with proper depth tracking
/// - Self-closing tags (void elements)
/// - Quoted attributes containing `>` characters
/// - Opening tags spread over several lines (any ASCII whitespace separates
///   the tag name from its attributes)
///
/// Truncated or otherwise malformed input never panics: an element whose
/// closing tag is missing simply extends to the end of the input.
fn extract_fragments(html: &str) -> HashMap<String, String> {
    let mut fragments = HashMap::new();
    let bytes = html.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    while i < len {
        // Only `<` followed by something other than `/` (closing tag) or `!`
        // (comment / doctype) can start an element.
        let starts_element =
            bytes[i] == b'<' && i + 1 < len && bytes[i + 1] != b'/' && bytes[i + 1] != b'!';
        if !starts_element {
            i += 1;
            continue;
        }

        let tag_start = i;
        let mut j = i + 1;
        while j < len && bytes[j].is_ascii_whitespace() {
            j += 1;
        }
        let tag_name_start = j;
        j = tag_name_end(bytes, j);
        let tag_name = &html[tag_name_start..j];
        if tag_name.is_empty() {
            i += 1;
            continue;
        }

        let (k, self_closing) = scan_opening_tag_end(bytes, j);
        let opening_tag = &html[tag_start..k];

        let id = match extract_id_from_tag(opening_tag) {
            Some(id) => id,
            None => {
                // No id: keep scanning inside this element for id'd children.
                i = k;
                continue;
            }
        };

        if self_closing || is_void_element(tag_name) {
            // No content and no closing tag: the fragment is the tag itself.
            fragments.insert(id, opening_tag.to_string());
            i = k;
            continue;
        }

        let end = find_element_end(html, k, tag_name);
        fragments.insert(id, html[tag_start..end].to_string());
        i = end;
    }

    fragments
}

/// Return the index of the first byte at or after `from` that ends a tag
/// name (ASCII whitespace, `>` or `/`), or `bytes.len()` if there is none.
fn tag_name_end(bytes: &[u8], from: usize) -> usize {
    let mut p = from;
    while p < bytes.len()
        && !bytes[p].is_ascii_whitespace()
        && bytes[p] != b'>'
        && bytes[p] != b'/'
    {
        p += 1;
    }
    p
}

/// Find the end of an opening tag, skipping over quoted attribute values.
///
/// Returns the index just past the closing `>` and whether the tag ended in
/// `/>`. When the tag is never closed (including an unterminated quote) the
/// result is `(bytes.len(), false)`, so the index is always safe to slice with.
fn scan_opening_tag_end(bytes: &[u8], from: usize) -> (usize, bool) {
    let len = bytes.len();
    let mut k = from;
    while k < len {
        match bytes[k] {
            b'"' | b'\'' => {
                let quote = bytes[k];
                k += 1;
                while k < len && bytes[k] != quote {
                    k += 1;
                }
                // `k` is now on the closing quote, or at `len` if unterminated.
            }
            b'/' if k + 1 < len && bytes[k + 1] == b'>' => return (k + 2, true),
            b'>' => return (k + 1, false),
            _ => {}
        }
        k += 1;
    }
    (len, false)
}

/// Starting just after an opening `<tag_name ...>`, return the index just past
/// the matching `</tag_name>`, counting nested elements of the same name.
///
/// The result is clamped to `html.len()`, which is also what is returned when
/// no matching closing tag exists.
fn find_element_end(html: &str, from: usize, tag_name: &str) -> usize {
    let bytes = html.as_bytes();
    let len = bytes.len();
    let mut depth = 1u32;
    let mut end = from;

    while end < len && depth > 0 {
        if bytes[end] != b'<' || end + 1 >= len {
            end += 1;
            continue;
        }
        match bytes[end + 1] {
            b'/' => {
                // Closing tag: only one with our own name changes the depth.
                let name_start = end + 2;
                let mut p = name_start;
                while p < len && bytes[p] != b'>' && !bytes[p].is_ascii_whitespace() {
                    p += 1;
                }
                if html[name_start..p].eq_ignore_ascii_case(tag_name) {
                    depth -= 1;
                }
                while p < len && bytes[p] != b'>' {
                    p += 1;
                }
                // Clamp: a closing tag cut off before its `>` must not push
                // the index past the end of the input.
                end = (p + 1).min(len);
            }
            // Comments / doctype are not tags; step over the `<` only.
            b'!' => end += 1,
            _ => {
                // Opening tag: a nested element of the same name adds a level.
                let name_start = end + 1;
                let mut p = tag_name_end(bytes, name_start);
                let inner_name = &html[name_start..p];
                while p < len && bytes[p] != b'>' {
                    p += 1;
                }
                // `p >= 1` because `name_start == end + 1`.
                let is_self_closing = bytes[p - 1] == b'/';
                if inner_name.eq_ignore_ascii_case(tag_name)
                    && !is_self_closing
                    && !is_void_element(inner_name)
                {
                    depth += 1;
                }
                end = (p + 1).min(len);
            }
        }
    }

    end
}

/// Extract the `id` value from an opening HTML tag string.
///
/// `tag` is expected to be a single opening tag such as `<div id="x" class="y">`.
/// The attributes are walked one by one, so:
/// - the attribute name is matched case-insensitively (`ID="x"` works),
/// - any ASCII whitespace may separate attributes and surround the `=`,
/// - text that merely looks like ` id="..."` inside another attribute's
///   quoted value is not mistaken for the real attribute.
///
/// Only quoted values (`"..."` or `'...'`) are recognised. Returns `None` when
/// there is no quoted `id` attribute or when a quoted value is not terminated.
/// The returned value keeps its original case.
fn extract_id_from_tag(tag: &str) -> Option<String> {
    let bytes = tag.as_bytes();
    let len = bytes.len();
    let is_name_end = |b: u8| b.is_ascii_whitespace() || matches!(b, b'=' | b'>' | b'/');

    // Skip `<`, optional whitespace and the tag name itself.
    let mut i = 0;
    if i < len && bytes[i] == b'<' {
        i += 1;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    while i < len && !is_name_end(bytes[i]) {
        i += 1;
    }

    while i < len {
        // Skip separators between attributes (and the `/` of `/>`).
        while i < len && (bytes[i].is_ascii_whitespace() || bytes[i] == b'/') {
            i += 1;
        }
        if i >= len || bytes[i] == b'>' {
            break;
        }

        let name_start = i;
        while i < len && !is_name_end(bytes[i]) {
            i += 1;
        }
        let name = &tag[name_start..i];

        // An attribute only has a value if the next non-blank byte is `=`.
        let mut v = i;
        while v < len && bytes[v].is_ascii_whitespace() {
            v += 1;
        }
        if v >= len || bytes[v] != b'=' {
            // Valueless attribute such as `disabled`; `i` is already past it.
            continue;
        }
        v += 1;
        while v < len && bytes[v].is_ascii_whitespace() {
            v += 1;
        }

        if v < len && (bytes[v] == b'"' || bytes[v] == b'\'') {
            let quote = bytes[v];
            let value_start = v + 1;
            // An unterminated quote swallows the rest of the tag: give up.
            let value_len = bytes[value_start..].iter().position(|&b| b == quote)?;
            if name.eq_ignore_ascii_case("id") {
                return Some(tag[value_start..value_start + value_len].to_string());
            }
            i = value_start + value_len + 1;
        } else {
            // Unquoted value: skip it without interpreting it.
            while v < len && !bytes[v].is_ascii_whitespace() && bytes[v] != b'>' {
                v += 1;
            }
            i = v;
        }
    }

    None
}

/// Tag names that never have content or a closing tag, so an opening tag
/// with one of these names is a complete element on its own.
const VOID_ELEMENTS: &[&str] = &[
    "area", "base", "br", "col", "embed", "hr", "img", "input", "link", "meta", "param",
    "source", "track", "wbr",
];

/// Report whether `tag` names a void element, ignoring ASCII case.
fn is_void_element(tag: &str) -> bool {
    for name in VOID_ELEMENTS {
        if name.eq_ignore_ascii_case(tag) {
            return true;
        }
    }
    false
}

/// Inject `hx-swap-oob="true"` into the first opening HTML tag.
///
/// The attribute is inserted immediately before the `>` that closes the first
/// tag (before the `/>` for a self-closing tag). The search for that `>`
/// starts at the first `<` and skips quoted attribute values, so a tag such as
/// `<div title="a > b">` is handled correctly. If the quote-aware search finds
/// nothing (for example because of an unterminated quote), the first `>` in
/// the string is used instead. Input without any `>` is returned unchanged.
fn inject_oob_attr(html: &str) -> String {
    const OOB_ATTR: &str = " hx-swap-oob=\"true\"";

    let bytes = html.as_bytes();
    let scan_from = html.find('<').unwrap_or(0);

    // Locate the first `>` that is not inside a quoted attribute value.
    let mut open_quote: Option<u8> = None;
    let mut tag_end: Option<usize> = None;
    for (offset, &b) in bytes[scan_from..].iter().enumerate() {
        match open_quote {
            Some(q) => {
                if b == q {
                    open_quote = None;
                }
            }
            None => match b {
                b'"' | b'\'' => open_quote = Some(b),
                b'>' => {
                    tag_end = Some(scan_from + offset);
                    break;
                }
                _ => {}
            },
        }
    }

    let pos = match tag_end.or_else(|| html.find('>')) {
        Some(pos) => pos,
        None => return html.to_string(),
    };

    let mut result = String::with_capacity(html.len() + OOB_ATTR.len());
    if pos > 0 && bytes[pos - 1] == b'/' {
        // Self-closing tag: keep the `/>` together after the new attribute.
        result.push_str(&html[..pos - 1]);
        result.push_str(OOB_ATTR);
        result.push_str("/>");
        result.push_str(&html[pos + 1..]);
    } else {
        result.push_str(&html[..pos]);
        result.push_str(OOB_ATTR);
        result.push_str(&html[pos..]);
    }
    result
}

/// Diff two full HTML pages and return only the changed id'd elements
/// with `hx-swap-oob="true"` injected, followed by the output of
/// `state_update_script(new_state)`.
///
/// Both `old_html` and `new_html` should be the full rendered page.
/// The function extracts all elements with `id` attributes, compares them,
/// and returns only the ones that changed. An element counts as changed when
/// its id is absent from `old_html` or its HTML differs. Elements that exist
/// only in `old_html` produce no output.
///
/// Changed fragments are emitted in ascending id order, so the same pair of
/// pages always yields byte-identical output regardless of hash-map
/// iteration order.
pub fn diff_html<T: Serialize>(old_html: &str, new_html: &str, new_state: &T) -> String {
    let old_frags = extract_fragments(old_html);
    let new_frags = extract_fragments(new_html);

    let mut changed: Vec<(&String, &String)> = new_frags
        .iter()
        .filter(|(id, fragment)| old_frags.get(*id) != Some(*fragment))
        .collect();
    // Ids are unique map keys, so an unstable sort is sufficient.
    changed.sort_unstable_by(|a, b| a.0.cmp(b.0));

    let mut out = String::new();
    for (_, fragment) in changed {
        out.push_str(&inject_oob_attr(fragment));
    }
    out.push_str(&state_update_script(new_state));
    out
}

// Unit tests for the fragment scanner, the out-of-band attribute injection
// and the page diff.
#[cfg(test)]
mod tests {
    use super::*;

    // The attribute is appended after the existing attributes of a normal tag.
    #[test]
    fn inject_oob_basic() {
        let html = "<span id=\"count\" class=\"x\">5</span>";
        let result = inject_oob_attr(html);
        assert_eq!(
            result,
            "<span id=\"count\" class=\"x\" hx-swap-oob=\"true\">5</span>"
        );
    }

    // For a self-closing tag the attribute goes before the trailing `/>`.
    #[test]
    fn inject_oob_self_closing() {
        let html = "<input id=\"x\" type=\"hidden\"/>";
        let result = inject_oob_attr(html);
        assert_eq!(
            result,
            "<input id=\"x\" type=\"hidden\" hx-swap-oob=\"true\"/>"
        );
    }

    // Id'd elements nested inside an element without an id are each extracted
    // with their full markup.
    #[test]
    fn extract_fragments_basic() {
        let html = r#"<div><span id="count">5</span><ul id="list"><li>a</li></ul></div>"#;
        let frags = extract_fragments(html);
        assert_eq!(frags.get("count").unwrap(), "<span id=\"count\">5</span>");
        assert_eq!(
            frags.get("list").unwrap(),
            "<ul id=\"list\"><li>a</li></ul>"
        );
    }

    // Nested elements with the same tag name must not end the outer fragment
    // early.
    #[test]
    fn extract_nested_same_tag() {
        let html = r#"<div id="outer"><div><div>inner</div></div></div>"#;
        let frags = extract_fragments(html);
        assert_eq!(
            frags.get("outer").unwrap(),
            "<div id=\"outer\"><div><div>inner</div></div></div>"
        );
    }

    // A void element's fragment is just its opening tag; following markup is
    // not included.
    #[test]
    fn extract_void_element() {
        let html = r#"<input id="myinput" type="text"><p>after</p>"#;
        let frags = extract_fragments(html);
        assert_eq!(
            frags.get("myinput").unwrap(),
            "<input id=\"myinput\" type=\"text\">"
        );
    }

    // Only the fragment whose content differs appears in the diff output.
    #[test]
    fn diff_html_returns_only_changed() {
        use serde::Serialize;

        #[derive(Serialize)]
        struct S {
            v: u32,
        }

        let old = r#"<div><span id="a">old</span><span id="b">same</span></div>"#;
        let new = r#"<div><span id="a">new</span><span id="b">same</span></div>"#;
        let result = diff_html(old, new, &S { v: 1 });
        assert!(result.contains("new"));
        assert!(result.contains("hx-swap-oob"));
        assert!(!result.contains(">same<"));
    }
}