use base64::{prelude::BASE64_STANDARD, Engine};
use dom_query::Document;
use http::{
header::{HeaderValue, CONTENT_SECURITY_POLICY, CONTENT_TYPE},
Response as HttpResponse,
};
use sha2::{Digest, Sha256};
use std::borrow::Cow;
use crate::InitializationScript;
pub fn inject_scripts_into_html(
mut response: HttpResponse<Cow<'static, [u8]>>,
scripts: &[InitializationScript],
) -> HttpResponse<Cow<'static, [u8]>> {
if scripts.is_empty() {
return response;
}
let should_inject_scripts = response
.headers()
.get(CONTENT_TYPE)
.and_then(|content_type| content_type.to_str().ok())
.map(|content_type_str| content_type_str.to_lowercase().starts_with("text/html"))
.unwrap_or_default();
if !should_inject_scripts {
return response;
}
let document = Document::from(String::from_utf8_lossy(response.body()).as_ref());
let csp = response.headers_mut().get_mut(CONTENT_SECURITY_POLICY);
let head = document.head().unwrap_or_else(|| {
let html = document.html_root();
let head = document.tree.new_element("head");
html.prepend_child(&head);
head
});
let mut hashes = Vec::new();
for script in scripts.iter().rev().map(|s| &s.script) {
let script_tag = document.tree.new_element("script");
script_tag.set_text(script.as_str());
head.prepend_child(&script_tag);
if csp.is_some() {
hashes.push(hash_script(script));
}
}
if let Some(csp) = csp {
let csp_string = csp.to_str().unwrap().to_string();
let csp_string = if csp_string.contains("script-src") {
csp_string.replace("script-src", &format!("script-src {}", hashes.join(" ")))
} else {
format!("{csp_string} script-src {}", hashes.join(" "))
};
*csp = HeaderValue::from_str(&csp_string).unwrap();
}
*response.body_mut() = Cow::Owned(document.html().as_bytes().to_vec());
response
}
fn hash_script(script: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(script);
let hash = hasher.finalize();
format!("'sha256-{}'", BASE64_STANDARD.encode(hash))
}
#[cfg(test)]
mod tests {
use super::*;
use http::StatusCode;
#[test]
fn test_no_scripts_returns_original_response() {
let body = "<html><head></head><body>Test</body></html>";
let result = run(body, "text/html", vec![]);
assert_eq!(result, body);
}
#[test]
fn test_non_html_response_not_modified() {
let body = r#"{"key": "value"}"#;
let scripts = vec!["console.log('test');".to_string()];
let result = run(body, "application/json", scripts);
assert_eq!(result, body);
}
#[test]
fn test_inject_single_script() {
let body = "<html><head></head><body>Content</body></html>";
let scripts = vec!["console.log('injected');".to_string()];
let result = run(body, "text/html", scripts);
assert_eq!(
result,
"<html><head><script>console.log('injected');</script></head><body>Content</body></html>"
);
}
#[test]
fn test_inject_multiple_scripts() {
let body = "<html><head></head><body>Content</body></html>";
let scripts = vec![
"var first = 1;".to_owned(),
"let second = 2;".to_owned(),
"const third = 3;".to_owned(),
"window.test = () => console.log('test');".to_owned(),
];
let result = run(body, "text/html", scripts);
assert_eq!(
result,
"<html><head><script>var first = 1;</script><script>let second = 2;</script><script>const third = 3;</script><script>window.test = () => console.log('test');</script></head><body>Content</body></html>"
);
}
#[test]
fn test_inject_script_creates_head_if_missing() {
let body = "<html><body>Content</body></html>";
let scripts = vec!["console.log('test');".to_string()];
let result = run(body, "text/html", scripts);
assert_eq!(
result,
"<html><head><script>console.log('test');</script></head><body>Content</body></html>"
);
}
#[test]
fn test_inject_script_creates_html_structure_if_missing() {
let body = "Just some text";
let scripts = vec!["console.log('test');".to_string()];
let result = run(body, "text/html", scripts);
assert_eq!(
result,
"<html><head><script>console.log('test');</script></head><body>Just some text</body></html>"
);
}
#[test]
fn test_csp_header_updated_with_script_hashes() {
let body = "<html><head></head><body>Content</body></html>";
let mut response = create_response(body, "text/html");
response.headers_mut().insert(
CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'self'"),
);
let script_code = "console.log('test');";
let scripts = vec![script_code.to_string()];
let scripts: Vec<InitializationScript> = scripts
.into_iter()
.map(|script| InitializationScript {
script,
for_main_frame_only: true,
})
.collect();
let result = inject_scripts_into_html(response, &scripts);
let result_body = String::from_utf8_lossy(result.body()).to_string();
let csp = result.headers().get(CONTENT_SECURITY_POLICY).unwrap();
let csp_str = csp.to_str().unwrap();
assert_eq!(
result_body,
"<html><head><script>console.log('test');</script></head><body>Content</body></html>"
);
assert_eq!(
csp_str,
"default-src 'self' script-src 'sha256-3x8DE279hr8o/Aq0dEdH4WApIwn5rbRKhugPzn6Bofw='"
);
}
fn create_response(body: &str, content_type: &'static str) -> HttpResponse<Cow<'static, [u8]>> {
let mut response = HttpResponse::builder()
.status(StatusCode::OK)
.body(Cow::Owned(body.as_bytes().to_vec()))
.unwrap();
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
response
}
fn run(body: &str, content_type: &'static str, scripts: Vec<String>) -> String {
let response = create_response(body, content_type);
let scripts: Vec<InitializationScript> = scripts
.into_iter()
.map(|script| InitializationScript {
script,
for_main_frame_only: true,
})
.collect();
let result = inject_scripts_into_html(response, &scripts);
String::from_utf8_lossy(result.body()).to_string()
}
}