use crate::rule_config_serde::RuleConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Tag names that the `gfm` shorthand in `disallowed-elements` expands to.
///
/// The list matches the "Disallowed Raw HTML" extension of the GitHub Flavored
/// Markdown spec.
pub const GFM_DISALLOWED_TAGS: &[&str] = &[
    "title", "textarea", "style", "xmp", "iframe", "noembed", "noframes", "script", "plaintext",
];

/// Tags for which `MD033Config::is_safe_fixable_tag` answers `true`.
pub const SAFE_FIXABLE_TAGS: &[&str] = &["em", "i", "strong", "b", "code", "br", "hr", "a", "img"];

/// Tags for which `MD033Config::requires_attribute_extraction` answers `true`.
pub const ATTRIBUTE_FIXABLE_TAGS: &[&str] = &["a", "img"];

/// URL prefixes that `MD033Config::is_safe_url` accepts outright.
pub const SAFE_URL_SCHEMES: &[&str] = &[
    "http://", "https://", "mailto:", "tel:", "ftp://", "ftps://",
];

/// URL schemes that `MD033Config::is_safe_url` always rejects.
pub const DANGEROUS_URL_SCHEMES: &[&str] = &[
    "javascript:", "vbscript:", "data:", "about:", "blob:", "file:",
];
/// Value of the `br-style` option in [`MD033Config`].
///
/// Serialized in kebab-case (`trailing-spaces`, `backslash`); defaults to
/// `TrailingSpaces`. NOTE(review): presumably selects which Markdown hard line
/// break a `<br>` is rewritten to — confirm in the rule implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum BrStyle {
/// Default variant; serialized as `trailing-spaces`.
#[default]
TrailingSpaces,
/// Serialized as `backslash`.
Backslash,
}
/// Value of the `fix-mode` option in [`MD033Config`].
///
/// Serialized in kebab-case (`conservative`, `relaxed`); defaults to
/// `Conservative`. NOTE(review): how the two modes differ is not visible in this
/// file — confirm against the rule implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum MD033FixMode {
/// Default variant; serialized as `conservative`.
#[default]
Conservative,
/// Serialized as `relaxed`.
Relaxed,
}
/// Options for rule MD033: HTML element allow/deny lists and fix behaviour.
///
/// Multi-word options are serialized in kebab-case and also accept a snake_case
/// alias; `allowed-elements` and `disallowed-elements` additionally accept the
/// short aliases `allowed` and `disallowed`. Every omitted option falls back to
/// the same value that [`MD033Config::default`] uses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MD033Config {
/// Element names listed under `allowed-elements` (empty by default). Lowercased
/// for lookup by `allowed_set`.
#[serde(default, rename = "allowed-elements", alias = "allowed_elements", alias = "allowed")]
pub allowed: Vec<String>,
/// Element names listed under `disallowed-elements` (empty by default). The
/// special entry `gfm` is expanded to `GFM_DISALLOWED_TAGS` by `disallowed_set`,
/// and any non-empty list makes `is_disallowed_mode` return `true`.
#[serde(
default,
rename = "disallowed-elements",
alias = "disallowed_elements",
alias = "disallowed"
)]
pub disallowed: Vec<String>,
/// Opt-in flag, `false` by default. NOTE(review): presumably enables automatic
/// fixes — confirm in the rule implementation.
#[serde(default)]
pub fix: bool,
/// Fix strategy (`fix-mode`); defaults to `MD033FixMode::Conservative`.
#[serde(default, rename = "fix-mode", alias = "fix_mode")]
pub fix_mode: MD033FixMode,
/// Attribute names under `drop-attributes`; defaults to `target`, `rel`,
/// `width`, `height`, `align`, `class`, `id` and `style`. Lowercased for lookup
/// by `drop_attributes_set`.
#[serde(
default = "default_drop_attributes",
rename = "drop-attributes",
alias = "drop_attributes"
)]
pub drop_attributes: Vec<String>,
/// Element names under `strip-wrapper-elements`; defaults to `["p"]`.
/// Lowercased for lookup by `strip_wrapper_elements_set`.
#[serde(
default = "default_strip_wrapper_elements",
rename = "strip-wrapper-elements",
alias = "strip_wrapper_elements"
)]
pub strip_wrapper_elements: Vec<String>,
/// Line-break style (`br-style`); defaults to `BrStyle::TrailingSpaces`.
#[serde(default, rename = "br-style", alias = "br_style")]
pub br_style: BrStyle,
}
impl Default for MD033Config {
    /// Builds the configuration used when no options are given.
    ///
    /// Each field mirrors the `#[serde(default ...)]` attribute on the struct, so a
    /// deserialized config with every option omitted equals this value.
    fn default() -> Self {
        let drop_attributes = default_drop_attributes();
        let strip_wrapper_elements = default_strip_wrapper_elements();
        Self {
            allowed: Vec::default(),
            disallowed: Vec::default(),
            fix: bool::default(),
            fix_mode: Default::default(),
            drop_attributes,
            strip_wrapper_elements,
            br_style: Default::default(),
        }
    }
}
/// Default value of `MD033Config::drop_attributes` (serde fallback and `Default`).
fn default_drop_attributes() -> Vec<String> {
    const NAMES: [&str; 8] = [
        "target", "rel", "width", "height", "align", "class", "id", "style",
    ];
    NAMES.iter().copied().map(String::from).collect()
}
/// Default value of `MD033Config::strip_wrapper_elements`: only the `p` element.
fn default_strip_wrapper_elements() -> Vec<String> {
    vec![String::from("p")]
}
impl MD033Config {
pub fn allowed_set(&self) -> HashSet<String> {
self.allowed.iter().map(|s| s.to_lowercase()).collect()
}
pub fn disallowed_set(&self) -> HashSet<String> {
let mut set = HashSet::new();
for tag in &self.disallowed {
let lower = tag.to_lowercase();
if lower == "gfm" {
for gfm_tag in GFM_DISALLOWED_TAGS {
set.insert((*gfm_tag).to_string());
}
} else {
set.insert(lower);
}
}
set
}
pub fn is_disallowed_mode(&self) -> bool {
!self.disallowed.is_empty()
}
pub fn is_safe_fixable_tag(tag_name: &str) -> bool {
SAFE_FIXABLE_TAGS.contains(&tag_name.to_ascii_lowercase().as_str())
}
pub fn requires_attribute_extraction(tag_name: &str) -> bool {
ATTRIBUTE_FIXABLE_TAGS.contains(&tag_name.to_ascii_lowercase().as_str())
}
pub fn drop_attributes_set(&self) -> HashSet<String> {
self.drop_attributes.iter().map(|s| s.to_lowercase()).collect()
}
pub fn strip_wrapper_elements_set(&self) -> HashSet<String> {
self.strip_wrapper_elements.iter().map(|s| s.to_lowercase()).collect()
}
fn decode_percent_encoding(url: &str) -> String {
let mut result = String::with_capacity(url.len());
let mut chars = url.chars().peekable();
while let Some(c) = chars.next() {
if c == '%' {
let hex: String = chars.by_ref().take(2).collect();
if hex.len() == 2
&& let Ok(byte) = u8::from_str_radix(&hex, 16)
{
result.push(byte as char);
continue;
}
result.push('%');
result.push_str(&hex);
} else {
result.push(c);
}
}
result
}
fn decode_html_entities(url: &str) -> String {
url.replace("&", "&")
.replace("<", "<")
.replace(">", ">")
.replace(""", "\"")
.replace("'", "'")
.replace(":", ":")
.replace(":", ":")
.replace(":", ":")
.replace("/", "/")
.replace("/", "/")
.replace("/", "/")
}
pub fn is_safe_url(url: &str) -> bool {
let decoded = Self::decode_percent_encoding(url);
let decoded = Self::decode_html_entities(&decoded);
let url_lower = decoded.to_ascii_lowercase();
let trimmed = url_lower.trim();
if trimmed.is_empty() {
return true;
}
for scheme in DANGEROUS_URL_SCHEMES {
if trimmed.starts_with(scheme) {
return false;
}
}
let dangerous_prefixes: &[&str] = &["javascript", "vbscript", "data", "about", "blob", "file"];
for prefix in dangerous_prefixes {
if let Some(rest) = trimmed.strip_prefix(prefix) {
if rest.starts_with(':') || rest.starts_with("%3a") || rest.starts_with("&#") {
return false;
}
}
}
if trimmed.starts_with('/') || trimmed.starts_with('.') || trimmed.starts_with('#') || trimmed.starts_with('?')
{
return true;
}
for scheme in SAFE_URL_SCHEMES {
if trimmed.starts_with(scheme) {
return true;
}
}
if trimmed.starts_with("//") {
return true;
}
if let Some(colon_pos) = trimmed.find(':') {
if let Some(slash_pos) = trimmed.find('/') {
if colon_pos > slash_pos {
return true;
}
}
false
} else {
true
}
}
}
/// Associates [`MD033Config`] with the rule name `MD033` through the crate's
/// `RuleConfig` trait. NOTE(review): presumably used to find this rule's section
/// in user configuration — confirm in `rule_config_serde`.
impl RuleConfig for MD033Config {
const RULE_NAME: &'static str = "MD033";
}