use serde::Serialize;
use std::collections::HashMap;
use crate::state::state_update_script;
/// Collects every element that carries a non-empty `id` attribute, keyed by that id and
/// mapped to the element's outer HTML (opening tag through its matching closing tag).
///
/// This is a lightweight scanner rather than a full HTML parser:
/// - once an element with an id is captured, scanning resumes after its closing tag, so
///   elements with ids nested inside it are not reported separately (the outer fragment
///   already contains them);
/// - only `<` immediately followed by an ASCII letter opens a tag, so text such as
///   `a < b`, closing tags and `<!doctype ...>` never start a fragment;
/// - `<!-- ... -->` comments are skipped, so markup inside them is not counted;
/// - quoted attribute values may contain `>` without ending the tag;
/// - truncated or malformed markup never panics: an unterminated tag or element simply
///   extends to the end of the input;
/// - if the same id occurs more than once, the last occurrence wins.
fn extract_fragments(html: &str) -> HashMap<String, String> {
    let bytes = html.as_bytes();
    let len = bytes.len();
    let mut fragments = HashMap::new();
    let mut i = 0;
    while i < len {
        if let Some(after_comment) = skip_comment(bytes, i) {
            i = after_comment;
            continue;
        }
        if !starts_open_tag(bytes, i) {
            i += 1;
            continue;
        }
        let name_end = tag_name_end(bytes, i + 1);
        let tag_name = &html[i + 1..name_end];
        let (open_end, self_closing) = scan_tag_end(bytes, name_end);
        match extract_id_from_tag(&html[i..open_end]) {
            Some(id) => {
                let end = if self_closing || is_void_element(tag_name) {
                    open_end
                } else {
                    find_element_end(html, open_end, tag_name)
                };
                fragments.insert(id, html[i..end].to_string());
                i = end;
            }
            // Keep scanning inside elements without an id: their children may have one.
            None => i = open_end,
        }
    }
    fragments
}

/// Advances from `from` past every byte matching `pred`, returning the first offset that
/// does not match (or the input length).
fn skip_while(bytes: &[u8], from: usize, pred: impl Fn(u8) -> bool) -> usize {
    bytes[from..]
        .iter()
        .position(|&b| !pred(b))
        .map_or(bytes.len(), |p| from + p)
}

/// Returns the offset where a tag name starting at `from` ends: the first ASCII
/// whitespace, `>` or `/`, or the end of input.
fn tag_name_end(bytes: &[u8], from: usize) -> usize {
    skip_while(bytes, from, |b| !b.is_ascii_whitespace() && b != b'>' && b != b'/')
}

/// True if `bytes[at]` is `<` immediately followed by an ASCII letter, i.e. the start of
/// an opening tag.
fn starts_open_tag(bytes: &[u8], at: usize) -> bool {
    bytes.get(at) == Some(&b'<')
        && matches!(bytes.get(at + 1), Some(b) if b.is_ascii_alphabetic())
}

/// If an HTML comment starts at `at` (which must be `<= bytes.len()`), returns the offset
/// just past its closing `-->`, or the input length when it is unterminated. Returns
/// `None` when no comment starts at `at`.
fn skip_comment(bytes: &[u8], at: usize) -> Option<usize> {
    if !bytes[at..].starts_with(b"<!--") {
        return None;
    }
    // Search from the second `-` so the degenerate forms `<!-->` and `<!--->` close too.
    let search_from = at + 2;
    Some(
        bytes[search_from..]
            .windows(3)
            .position(|w| w == b"-->")
            .map_or(bytes.len(), |p| search_from + p + 3),
    )
}

/// Scans the attributes of an opening tag starting at `from` and returns the offset just
/// past its closing `>` together with whether it ended in `/>`.
///
/// Quoted attribute values are skipped so a `>` inside them does not end the tag. If the
/// input ends first (e.g. an unterminated quote), returns `(bytes.len(), false)`.
fn scan_tag_end(bytes: &[u8], from: usize) -> (usize, bool) {
    let len = bytes.len();
    let mut pos = from;
    while pos < len {
        match bytes[pos] {
            b'"' | b'\'' => {
                let quote = bytes[pos];
                // Jump past the matching quote; may land past `len`, which ends the loop.
                pos = skip_while(bytes, pos + 1, |b| b != quote) + 1;
            }
            b'/' if bytes.get(pos + 1) == Some(&b'>') => return (pos + 2, true),
            b'>' => return (pos + 1, false),
            _ => pos += 1,
        }
    }
    (len, false)
}

/// Given the offset just past the opening tag of an element named `tag_name`, returns the
/// offset just past its matching closing tag (or the input length if it is never closed).
///
/// Nested elements with the same name (ignoring ASCII case) increase the nesting depth
/// unless they are self-closing or void; comments are skipped.
fn find_element_end(html: &str, from: usize, tag_name: &str) -> usize {
    let bytes = html.as_bytes();
    let len = bytes.len();
    let mut depth = 1u32;
    let mut pos = from;
    while pos < len {
        if let Some(after_comment) = skip_comment(bytes, pos) {
            pos = after_comment;
        } else if bytes[pos] == b'<' && bytes.get(pos + 1) == Some(&b'/') {
            let name_start = pos + 2;
            let name_end = tag_name_end(bytes, name_start);
            // Clamp so a final `</div` with no `>` ends at the input length.
            let close_end = (skip_while(bytes, name_end, |b| b != b'>') + 1).min(len);
            if html[name_start..name_end].eq_ignore_ascii_case(tag_name) {
                depth -= 1;
                if depth == 0 {
                    return close_end;
                }
            }
            pos = close_end;
        } else if starts_open_tag(bytes, pos) {
            let name_end = tag_name_end(bytes, pos + 1);
            let name = &html[pos + 1..name_end];
            let (open_end, self_closing) = scan_tag_end(bytes, name_end);
            if !self_closing && !is_void_element(name) && name.eq_ignore_ascii_case(tag_name) {
                depth += 1;
            }
            pos = open_end;
        } else {
            pos += 1;
        }
    }
    len
}

/// Returns the value of the `id` attribute in an opening tag such as
/// `<span id="count" class="x">`, or `None` if the tag has no non-empty id.
///
/// Attributes are parsed one at a time, so `id` is matched by attribute name (ignoring
/// ASCII case) rather than by substring. Any ASCII whitespace may surround the name and
/// `=`. Values may be double-quoted, single-quoted or unquoted. Text like ` id="..."`
/// inside another attribute's value is not mistaken for an id. As in HTML, the first `id`
/// attribute wins.
fn extract_id_from_tag(tag: &str) -> Option<String> {
    let bytes = tag.as_bytes();
    let len = bytes.len();
    let name_start = if tag.starts_with('<') { 1 } else { 0 };
    let mut pos = tag_name_end(bytes, name_start);
    loop {
        pos = skip_while(bytes, pos, |b| b.is_ascii_whitespace() || b == b'/');
        if pos >= len || bytes[pos] == b'>' {
            return None;
        }
        let attr_start = pos;
        // Always advances unless the byte is `=`, which is consumed below.
        pos = skip_while(bytes, pos, |b| {
            !b.is_ascii_whitespace() && !matches!(b, b'=' | b'>' | b'/')
        });
        let attr_name = &tag[attr_start..pos];
        pos = skip_while(bytes, pos, |b| b.is_ascii_whitespace());
        let mut value = None;
        if bytes.get(pos) == Some(&b'=') {
            pos = skip_while(bytes, pos + 1, |b| b.is_ascii_whitespace());
            match bytes.get(pos) {
                Some(&quote) if quote == b'"' || quote == b'\'' => {
                    let end = skip_while(bytes, pos + 1, |b| b != quote);
                    value = Some(&tag[pos + 1..end]);
                    pos = (end + 1).min(len);
                }
                _ => {
                    let end = skip_while(bytes, pos, |b| !b.is_ascii_whitespace() && b != b'>');
                    value = Some(&tag[pos..end]);
                    pos = end;
                }
            }
        }
        if attr_name.eq_ignore_ascii_case("id") {
            // An empty id cannot be targeted by an out-of-band swap.
            return value.filter(|v| !v.is_empty()).map(str::to_string);
        }
    }
}

/// HTML elements that never have content or a closing tag.
const VOID_ELEMENTS: &[&str] = &[
    "area", "base", "br", "col", "embed", "hr", "img", "input", "link", "meta", "param",
    "source", "track", "wbr",
];

/// True if `tag` (compared ignoring ASCII case) is a void element.
fn is_void_element(tag: &str) -> bool {
    VOID_ELEMENTS.iter().any(|v| tag.eq_ignore_ascii_case(v))
}
/// Returns `html` with ` hx-swap-oob="true"` appended as the last attribute of its first
/// opening tag, marking the fragment for an htmx out-of-band swap.
///
/// The end of the opening tag is the first `>` outside a quoted attribute value, so values
/// such as `x-show="n > 0"` are not split. For a self-closing tag (`.../>`) the attribute
/// is placed before the `/`. If no `>` appears outside quotes, `html` is returned
/// unchanged.
fn inject_oob_attr(html: &str) -> String {
    const OOB_ATTR: &str = " hx-swap-oob=\"true\"";
    let bytes = html.as_bytes();
    let mut open_quote: Option<u8> = None;
    for (pos, &b) in bytes.iter().enumerate() {
        match open_quote {
            Some(quote) => {
                if b == quote {
                    open_quote = None;
                }
            }
            None if b == b'"' || b == b'\'' => open_quote = Some(b),
            None if b == b'>' => {
                let insert_at = if pos > 0 && bytes[pos - 1] == b'/' {
                    pos - 1
                } else {
                    pos
                };
                let mut out = String::with_capacity(html.len() + OOB_ATTR.len());
                out.push_str(&html[..insert_at]);
                out.push_str(OOB_ATTR);
                out.push_str(&html[insert_at..]);
                return out;
            }
            None => {}
        }
    }
    html.to_string()
}
/// Builds an htmx out-of-band response that brings a page rendered as `old_html` up to
/// date with `new_html`.
///
/// A fragment is emitted for each element with an `id` in `new_html` that either differs
/// from the element with the same id in `old_html` or did not exist there. Each emitted
/// fragment has `hx-swap-oob="true"` injected into its opening tag.
///
/// Fragments are emitted in ascending id order. This keeps the response deterministic,
/// since `HashMap` iteration order is arbitrary. The output of
/// `state_update_script(new_state)` is appended last.
///
/// Ids present only in `old_html` produce no output.
pub fn diff_html<T: Serialize>(old_html: &str, new_html: &str, new_state: &T) -> String {
    let old_frags = extract_fragments(old_html);
    let new_frags = extract_fragments(new_html);
    let mut changed: Vec<(&String, &String)> = new_frags
        .iter()
        .filter(|(id, fragment)| old_frags.get(*id) != Some(*fragment))
        .collect();
    changed.sort_unstable_by(|a, b| a.0.cmp(b.0));
    let mut out = String::new();
    for (_, fragment) in changed {
        out.push_str(&inject_oob_attr(fragment));
    }
    out.push_str(&state_update_script(new_state));
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- inject_oob_attr -------------------------------------------------------------

    #[test]
    fn inject_oob_basic() {
        let input = r#"<span id="count" class="x">5</span>"#;
        let expected = r#"<span id="count" class="x" hx-swap-oob="true">5</span>"#;
        assert_eq!(inject_oob_attr(input), expected);
    }

    #[test]
    fn inject_oob_self_closing() {
        let input = r#"<input id="x" type="hidden"/>"#;
        let expected = r#"<input id="x" type="hidden" hx-swap-oob="true"/>"#;
        assert_eq!(inject_oob_attr(input), expected);
    }

    // --- extract_fragments -----------------------------------------------------------

    #[test]
    fn extract_fragments_basic() {
        let frags = extract_fragments(
            r#"<div><span id="count">5</span><ul id="list"><li>a</li></ul></div>"#,
        );
        assert_eq!(frags["count"], r#"<span id="count">5</span>"#);
        assert_eq!(frags["list"], r#"<ul id="list"><li>a</li></ul>"#);
    }

    #[test]
    fn extract_nested_same_tag() {
        // The whole input is the outer element, so its fragment is the input itself.
        let markup = r#"<div id="outer"><div><div>inner</div></div></div>"#;
        assert_eq!(extract_fragments(markup)["outer"], markup);
    }

    #[test]
    fn extract_void_element() {
        let frags = extract_fragments(r#"<input id="myinput" type="text"><p>after</p>"#);
        assert_eq!(frags["myinput"], r#"<input id="myinput" type="text">"#);
    }

    // --- diff_html -------------------------------------------------------------------

    #[test]
    fn diff_html_returns_only_changed() {
        use serde::Serialize;

        #[derive(Serialize)]
        struct S {
            v: u32,
        }

        let before = r#"<div><span id="a">old</span><span id="b">same</span></div>"#;
        let after = r#"<div><span id="a">new</span><span id="b">same</span></div>"#;
        let patch = diff_html(before, after, &S { v: 1 });

        assert!(patch.contains("new"));
        assert!(patch.contains("hx-swap-oob"));
        assert!(!patch.contains(">same<"));
    }
}