use log::info;
use std::{
borrow::Cow,
collections::{HashMap, HashSet, hash_map::Entry},
path::Path,
sync::Arc,
};
use crate::{
Result,
extract::{html::html5gum::extract_html_fragments, markdown::extract_markdown_fragments},
types::{ErrorKind, FileType},
};
use percent_encoding::percent_decode_str;
use tokio::{fs, sync::Mutex};
use url::Url;
/// Document content paired with its file type, as consumed by
/// `FragmentChecker::check`.
pub(crate) struct FragmentInput<'a> {
/// Text of the document; may be borrowed or owned.
pub content: Cow<'a, str>,
/// File type of the document; `FragmentChecker::check` uses it to pick the
/// fragment extractor (or to skip the check).
pub file_type: FileType,
}
impl FragmentInput<'_> {
    /// Reads the file at `path` into an owned string and pairs it with the
    /// file type derived from that same path.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::ReadFileInput`], carrying the I/O error and a copy
    /// of `path`, when the file cannot be read as a string.
    pub(crate) async fn from_path(path: &Path) -> Result<Self> {
        let read_outcome = fs::read_to_string(path).await;
        let text =
            read_outcome.map_err(|io_err| ErrorKind::ReadFileInput(io_err, path.to_path_buf()))?;

        Ok(Self {
            file_type: FileType::from(path),
            content: Cow::Owned(text),
        })
    }
}
/// Candidate spellings of a URL fragment that are tried against the set of
/// fragments found in a document.
struct FragmentBuilder {
/// The fragment exactly as given, plus a `user-content-` prefixed form when
/// the URL's host is recognised as GitHub.
variants: Vec<String>,
/// Percent-decoded (and, for Markdown, lowercased) forms of `variants`.
/// An entry is stored only when decoding produced an owned string or
/// lowercasing changed the text.
decoded: Vec<String>,
}
impl FragmentBuilder {
    /// Builds every candidate spelling of `fragment` for the document behind
    /// `url`.
    ///
    /// The raw fragment is always a candidate. For GitHub hosts a
    /// `user-content-` prefixed variant is added as well. Each variant is then
    /// percent-decoded and, for Markdown files, lowercased; the processed form
    /// is kept only when it is not simply the raw variant again.
    ///
    /// # Errors
    ///
    /// Returns an error when a percent-decoded variant is not valid UTF-8.
    fn new(fragment: &str, url: &Url, file_type: FileType) -> Result<Self> {
        let mut variants = vec![fragment.into()];
        if url.host_str().is_some_and(Self::is_github_host) {
            variants.push(format!("user-content-{fragment}"));
        }

        let mut decoded = Vec::new();
        for frag in &variants {
            let mut fragment_decoded = percent_decode_str(frag).decode_utf8()?;
            // An owned result means decoding had to build a new string instead
            // of borrowing the input, so it is worth storing separately.
            let mut differs_from_raw = matches!(fragment_decoded, Cow::Owned(_));

            // NOTE(review): Markdown candidates are lowercased, presumably
            // because the Markdown extractor emits lowercase anchors — confirm.
            if file_type == FileType::Markdown {
                let lowercase = fragment_decoded.to_lowercase();
                if lowercase != fragment_decoded {
                    fragment_decoded = Cow::Owned(lowercase);
                    differs_from_raw = true;
                }
            }

            if differs_from_raw {
                decoded.push(fragment_decoded.into_owned());
            }
        }

        Ok(Self { variants, decoded })
    }

    /// Returns `true` when `host` is `github.com` itself or one of its
    /// subdomains.
    ///
    /// A bare suffix test is not enough: it would also accept unrelated hosts
    /// such as `notgithub.com`, so the match is anchored at a domain boundary.
    fn is_github_host(host: &str) -> bool {
        host == "github.com" || host.ends_with(".github.com")
    }

    /// Returns `true` when any raw or decoded candidate is contained in
    /// `fragments`.
    fn any_matches(&self, fragments: &HashSet<String>) -> bool {
        self.variants
            .iter()
            .chain(self.decoded.iter())
            .any(|frag| fragments.contains(frag))
    }
}
/// Checks whether the fragment of a URL exists in the document it points to,
/// caching the fragments extracted from each document.
#[derive(Default, Clone, Debug)]
pub(crate) struct FragmentChecker {
/// Maps a URL (with its fragment removed) to the set of fragments extracted
/// from that document. Held in an `Arc`, so clones of the checker share it.
cache: Arc<Mutex<HashMap<String, HashSet<String>>>>,
}
impl FragmentChecker {
pub(crate) fn new() -> Self {
Self {
cache: Arc::default(),
}
}
pub(crate) async fn check(&self, input: FragmentInput<'_>, url: &Url) -> Result<bool> {
let Some(fragment) = url.fragment() else {
return Ok(true);
};
if fragment.is_empty() || fragment.eq_ignore_ascii_case("top") {
return Ok(true);
}
let url_without_frag = Self::remove_fragment(url.clone());
let FragmentInput { content, file_type } = input;
let extractor = match file_type {
FileType::Markdown => extract_markdown_fragments,
FileType::Html => extract_html_fragments,
FileType::Css | FileType::Plaintext => {
info!("Skipping fragment check for {url} within a {file_type} file");
return Ok(true);
}
};
let fragment_candidates = FragmentBuilder::new(fragment, url, file_type)?;
match self.cache.lock().await.entry(url_without_frag) {
Entry::Vacant(entry) => {
let file_frags = extractor(&content);
let contains_fragment = fragment_candidates.any_matches(&file_frags);
entry.insert(file_frags);
Ok(contains_fragment)
}
Entry::Occupied(entry) => {
let file_frags = entry.get();
Ok(fragment_candidates.any_matches(file_frags))
}
}
}
fn remove_fragment(mut url: Url) -> String {
url.set_fragment(None);
url.into()
}
}