use reinhardt_core::security::xss::escape_html_attr;
/// A collection of stylesheet and script assets, for example the assets a widget
/// declares through `HasMedia`.
///
/// Stylesheets are grouped by their HTML `media` attribute value; scripts are kept
/// in a single ordered list.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Media {
    /// Stylesheet paths keyed by media type (e.g. `"all"`, `"print"`), each list in
    /// insertion order.
    css: std::collections::HashMap<String, Vec<String>>,
    /// Script paths in insertion order.
    js: Vec<String>,
}
impl Media {
    /// Creates an empty [`Media`] with no stylesheets and no scripts.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the stylesheet `path` to the list for `media_type` (e.g. `"all"`, `"print"`).
    ///
    /// Paths are not deduplicated here: adding the same path twice renders it twice.
    /// Use [`Media::merge`] to combine collections without duplicates.
    pub fn add_css(&mut self, media_type: impl Into<String>, path: impl Into<String>) {
        let media_type = media_type.into();
        let path = path.into();
        self.css.entry(media_type).or_default().push(path);
    }

    /// Appends the script `path` to the end of the script list.
    ///
    /// Paths are not deduplicated here; see [`Media::merge`].
    pub fn add_js(&mut self, path: impl Into<String>) {
        self.js.push(path.into());
    }

    /// Merges `other` into `self`, skipping entries that are already present.
    ///
    /// Stylesheets are compared within each media type, and scripts across the whole
    /// script list. Existing entries keep their position. New entries are appended in
    /// the order they appear in `other`.
    pub fn merge(&mut self, other: &Media) {
        for (media_type, files) in &other.css {
            let entry = self.css.entry(media_type.clone()).or_default();
            for file in files {
                if !entry.contains(file) {
                    entry.push(file.clone());
                }
            }
        }
        for file in &other.js {
            if !self.js.contains(file) {
                self.js.push(file.clone());
            }
        }
    }

    /// Renders one `<link rel="stylesheet">` tag per stylesheet, each followed by `\n`.
    ///
    /// Tags are ordered by media type (lexicographically), then by insertion order, so
    /// the output does not depend on `HashMap` iteration order. Both the `href` and
    /// `media` values are passed through `escape_html_attr`.
    pub fn render_css(&self) -> String {
        use std::fmt::Write as _;

        let mut output = String::new();
        for (media_type, file) in self.sorted_css() {
            writeln!(
                output,
                "<link rel=\"stylesheet\" href=\"{}\" media=\"{}\">",
                escape_html_attr(file),
                escape_html_attr(media_type)
            )
            .expect("writing to a String cannot fail");
        }
        output
    }

    /// Renders one `<script src="...">` tag per script in insertion order, each
    /// followed by `\n`. The `src` value is passed through `escape_html_attr`.
    pub fn render_js(&self) -> String {
        use std::fmt::Write as _;

        let mut output = String::new();
        for file in &self.js {
            writeln!(output, "<script src=\"{}\"></script>", escape_html_attr(file))
                .expect("writing to a String cannot fail");
        }
        output
    }

    /// Returns every stylesheet as a `(media_type, path)` pair.
    ///
    /// Pairs use the same deterministic order as [`Media::render_css`]: by media type,
    /// then by insertion order.
    pub fn get_css_files(&self) -> Vec<(String, String)> {
        self.sorted_css()
            .map(|(media_type, path)| (media_type.clone(), path.clone()))
            .collect()
    }

    /// Returns a copy of all script paths in insertion order.
    pub fn get_js_files(&self) -> Vec<String> {
        self.js.clone()
    }

    /// Iterates over `(media_type, path)` pairs. Media types are sorted
    /// lexicographically, and paths keep their insertion order within each media type.
    ///
    /// Shared by `render_css` and `get_css_files` so both expose the same order.
    fn sorted_css(&self) -> impl Iterator<Item = (&String, &String)> + '_ {
        let mut groups: Vec<(&String, &Vec<String>)> = self.css.iter().collect();
        // Map keys are unique, so an unstable sort yields a fully determined order.
        groups.sort_unstable_by_key(|&(media_type, _)| media_type);
        groups
            .into_iter()
            .flat_map(|(media_type, paths)| paths.iter().map(move |path| (media_type, path)))
    }
}
/// Implemented by types (such as widgets) that declare the stylesheets and scripts
/// they need.
pub trait HasMedia {
    /// Returns the [`Media`] assets required by this value.
    fn media(&self) -> Media;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_css() {
        let mut media = Media::new();
        media.add_css("all", "css/style.css");
        media.add_css("print", "css/print.css");
        assert_eq!(media.css.get("all").unwrap()[0], "css/style.css");
        assert_eq!(media.css.get("print").unwrap()[0], "css/print.css");
    }

    #[test]
    fn test_add_js() {
        let mut media = Media::new();
        media.add_js("js/script1.js");
        media.add_js("js/script2.js");
        assert_eq!(media.js[0], "js/script1.js");
        assert_eq!(media.js[1], "js/script2.js");
    }

    #[test]
    fn test_merge_media() {
        let mut media1 = Media::new();
        media1.add_css("all", "css/base.css");
        media1.add_js("js/base.js");
        let mut media2 = Media::new();
        media2.add_css("all", "css/forms.css");
        media2.add_js("js/forms.js");
        media1.merge(&media2);
        assert_eq!(media1.css.get("all").unwrap().len(), 2);
        assert_eq!(media1.js.len(), 2);
    }

    #[test]
    fn test_merge_avoids_duplicates() {
        let mut media1 = Media::new();
        media1.add_css("all", "css/common.css");
        media1.add_js("js/common.js");
        let mut media2 = Media::new();
        media2.add_css("all", "css/common.css");
        media2.add_js("js/common.js");
        media1.merge(&media2);
        assert_eq!(media1.css.get("all").unwrap().len(), 1);
        assert_eq!(media1.js.len(), 1);
    }

    #[test]
    fn test_merge_adds_new_media_type() {
        let mut media1 = Media::new();
        media1.add_css("all", "css/base.css");
        let mut media2 = Media::new();
        media2.add_css("print", "css/print.css");
        media1.merge(&media2);
        assert_eq!(
            media1.css.get("all").unwrap(),
            &vec!["css/base.css".to_string()]
        );
        assert_eq!(
            media1.css.get("print").unwrap(),
            &vec!["css/print.css".to_string()]
        );
    }

    #[test]
    fn test_render_css() {
        let mut media = Media::new();
        media.add_css("all", "/static/css/style.css");
        media.add_css("print", "/static/css/print.css");
        let html = media.render_css();
        assert!(html.contains("<link rel=\"stylesheet\""));
        assert!(html.contains("href=\"/static/css/style.css\""));
        assert!(html.contains("media=\"all\""));
        assert!(html.contains("href=\"/static/css/print.css\""));
        assert!(html.contains("media=\"print\""));
        // Media types are rendered in sorted order, so "all" precedes "print".
        let all_pos = html.find("media=\"all\"").unwrap();
        let print_pos = html.find("media=\"print\"").unwrap();
        assert!(
            all_pos < print_pos,
            "media types should be rendered in sorted order. Got: {}",
            html
        );
    }

    #[test]
    fn test_render_js() {
        let mut media = Media::new();
        media.add_js("/static/js/app.js");
        media.add_js("/static/js/widgets.js");
        let html = media.render_js();
        assert!(html.contains("<script src=\"/static/js/app.js\"></script>"));
        assert!(html.contains("<script src=\"/static/js/widgets.js\"></script>"));
    }

    #[test]
    fn test_has_media_trait() {
        struct TestWidget;
        impl HasMedia for TestWidget {
            fn media(&self) -> Media {
                let mut media = Media::new();
                media.add_css("all", "widget.css");
                media.add_js("widget.js");
                media
            }
        }
        let widget = TestWidget;
        let media = widget.media();
        assert_eq!(media.css.get("all").unwrap()[0], "widget.css");
        assert_eq!(media.js[0], "widget.js");
    }

    #[test]
    fn test_get_css_files() {
        let mut media = Media::new();
        media.add_css("all", "css/a.css");
        media.add_css("print", "css/b.css");
        let mut files = media.get_css_files();
        // Sort so this assertion checks contents without relying on iteration order.
        files.sort();
        assert_eq!(
            files,
            vec![
                ("all".to_string(), "css/a.css".to_string()),
                ("print".to_string(), "css/b.css".to_string()),
            ]
        );
    }

    #[test]
    fn test_get_js_files() {
        let mut media = Media::new();
        media.add_js("js/a.js");
        media.add_js("js/b.js");
        let files = media.get_js_files();
        assert_eq!(files, vec!["js/a.js", "js/b.js"]);
    }

    #[test]
    fn test_render_css_escapes_xss_in_paths() {
        let mut media = Media::new();
        media.add_css("all", "\"><script>alert(1)</script><link href=\"");
        let html = media.render_css();
        assert!(
            !html.contains("<script>"),
            "CSS rendering must not contain unescaped script tags. Got: {}",
            html
        );
        assert!(
            html.contains("&quot;"),
            "CSS rendering should contain escaped quotes. Got: {}",
            html
        );
    }

    #[test]
    fn test_render_js_escapes_xss_in_paths() {
        let mut media = Media::new();
        media.add_js("\"><script>alert(1)</script><script src=\"");
        let html = media.render_js();
        assert!(
            !html.contains("<script>alert(1)</script>"),
            "JS rendering must not contain unescaped script tags. Got: {}",
            html
        );
        assert!(
            html.contains("&quot;"),
            "JS rendering should contain escaped quotes. Got: {}",
            html
        );
    }
}