use std::path::Path;
use crate::content::{Page, Section};
fn strip_html(html: &str) -> String {
let mut result = String::with_capacity(html.len());
let mut in_tag = false;
let mut in_entity = false;
let mut entity = String::new();
for ch in html.chars() {
match ch {
'<' => in_tag = true,
'>' => in_tag = false,
'&' if !in_tag => {
in_entity = true;
entity.clear();
entity.push(ch);
}
';' if in_entity => {
in_entity = false;
match entity.as_str() {
"&" => result.push('&'),
"<" => result.push('<'),
">" => result.push('>'),
""" => result.push('"'),
"'" | "&apos" => result.push('\''),
" " => result.push(' '),
_ => {} }
}
_ if in_entity => entity.push(ch),
_ if !in_tag => result.push(ch),
_ => {}
}
}
let mut collapsed = String::with_capacity(result.len());
let mut prev_space = false;
for ch in result.chars() {
if ch.is_whitespace() {
if !prev_space {
collapsed.push(' ');
}
prev_space = true;
} else {
collapsed.push(ch);
prev_space = false;
}
}
collapsed.trim().to_string()
}
fn to_relative_url(permalink: &str, base_url: &str) -> String {
if let Some(path) = permalink.strip_prefix(base_url) {
if path.is_empty() {
"/".to_string()
} else if path.starts_with('/') {
path.to_string()
} else {
format!("/{path}")
}
} else {
if let Some(pos) = permalink.find("://") {
if let Some(slash_pos) = permalink[pos + 3..].find('/') {
return permalink[pos + 3 + slash_pos..].to_string();
}
}
permalink.to_string()
}
}
pub fn generate_search_index<'a>(
pages: impl IntoIterator<Item = &'a Page>,
sections: impl IntoIterator<Item = &'a Section>,
base_url: &str,
output_dir: &Path,
) -> anyhow::Result<()> {
let db_path = output_dir.join("search.db");
if db_path.exists() {
std::fs::remove_file(&db_path)?;
}
let conn = rusqlite::Connection::open(&db_path)?;
conn.execute_batch(
"CREATE TABLE pages (
title TEXT NOT NULL,
url TEXT NOT NULL,
description TEXT,
content TEXT,
title_lower TEXT,
description_lower TEXT,
content_lower TEXT
);
CREATE INDEX idx_pages_title_lower ON pages(title_lower);
CREATE INDEX idx_pages_description_lower ON pages(description_lower);",
)?;
let mut stmt = conn.prepare(
"INSERT INTO pages (title, url, description, content, title_lower, description_lower, content_lower)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
)?;
for page in pages {
let plain_text = strip_html(&page.content);
let description = page.description.as_deref().unwrap_or("");
let url = to_relative_url(&page.permalink, base_url);
stmt.execute(rusqlite::params![
page.title,
url,
description,
plain_text,
page.title.to_lowercase(),
description.to_lowercase(),
plain_text.to_lowercase(),
])?;
}
for section in sections {
if section.content.trim().is_empty() {
continue;
}
let plain_text = strip_html(§ion.content);
let description = section.description.as_deref().unwrap_or("");
let url = to_relative_url(§ion.permalink, base_url);
stmt.execute(rusqlite::params![
section.title,
url,
description,
plain_text,
section.title.to_lowercase(),
description.to_lowercase(),
plain_text.to_lowercase(),
])?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
fn make_page(title: &str, url: &str, description: Option<&str>, content: &str) -> Page {
Page {
title: title.to_string(),
date: None,
author: None,
permalink: url.to_string(),
description: description.map(|s| s.to_string()),
content: content.to_string(),
draft: false,
slug: "test".to_string(),
template: None,
path: "/test/".to_string(),
summary: None,
raw_content: String::new(),
taxonomies: HashMap::new(),
extra: serde_json::Value::Null,
aliases: vec![],
weight: None,
word_count: 0,
reading_time: 0,
relative_path: "test.md".to_string(),
}
}
fn ranked_search(conn: &rusqlite::Connection, term: &str) -> Vec<(String, i64)> {
let mut stmt = conn
.prepare(
"SELECT title,
CASE WHEN title_lower = ?1 THEN 100
WHEN title_lower LIKE ?1 || '%' THEN 80
WHEN title_lower LIKE '%' || ?1 || '%' THEN 60
ELSE 0 END +
CASE WHEN description_lower LIKE '%' || ?1 || '%' THEN 20
ELSE 0 END +
CASE WHEN content_lower LIKE '%' || ?1 || '%' THEN 10
ELSE 0 END as score
FROM pages
WHERE title_lower LIKE '%' || ?1 || '%'
OR description_lower LIKE '%' || ?1 || '%'
OR content_lower LIKE '%' || ?1 || '%'
ORDER BY score DESC
LIMIT 10",
)
.unwrap();
let rows = stmt
.query_map([term.to_lowercase()], |row| {
Ok((
row.get::<_, String>(0).unwrap(),
row.get::<_, i64>(1).unwrap(),
))
})
.unwrap();
rows.map(|r| r.unwrap()).collect()
}
#[test]
fn test_strip_html_basic() {
assert_eq!(strip_html("<p>Hello <b>world</b></p>"), "Hello world");
}
#[test]
fn test_strip_html_entities() {
assert_eq!(strip_html("& < >"), "& < >");
}
#[test]
fn test_strip_html_whitespace() {
assert_eq!(strip_html("<p> Hello world </p>"), "Hello world");
}
#[test]
fn test_strip_html_empty() {
assert_eq!(strip_html(""), "");
}
#[test]
fn test_generate_search_index() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [make_page(
"Test Page",
"https://example.com/test/",
Some("A test page"),
"<p>Hello world</p>",
)];
let sections: Vec<Section> = vec![];
generate_search_index(
pages.iter(),
sections.iter(),
"https://example.com",
tmp.path(),
)
.unwrap();
let db_path = tmp.path().join("search.db");
assert!(db_path.exists());
let conn = rusqlite::Connection::open(&db_path).unwrap();
let count: i64 = conn
.query_row("SELECT count(*) FROM pages", [], |row| row.get(0))
.unwrap();
assert_eq!(count, 1);
let title_lower: String = conn
.query_row("SELECT title_lower FROM pages", [], |row| row.get(0))
.unwrap();
assert_eq!(title_lower, "test page");
let url: String = conn
.query_row("SELECT url FROM pages", [], |row| row.get(0))
.unwrap();
assert_eq!(url, "/test/");
}
#[test]
fn test_title_matches_rank_higher_than_content_only() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [
make_page(
"Themes",
"/themes/",
Some("Browse all themes"),
"<p>Pick a theme for your site.</p>",
),
make_page(
"Getting Started",
"/getting-started/",
None,
"<p>You can customize themes in config.</p>",
),
];
generate_search_index(pages.iter(), std::iter::empty(), "", tmp.path()).unwrap();
let conn = rusqlite::Connection::open(tmp.path().join("search.db")).unwrap();
let results = ranked_search(&conn, "theme");
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, "Themes");
assert!(results[0].1 > results[1].1);
}
#[test]
fn test_exact_title_match_ranks_highest() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [
make_page("Themes", "/themes/", None, "<p>All themes.</p>"),
make_page(
"Custom Themes Guide",
"/custom-themes/",
None,
"<p>How to build themes.</p>",
),
make_page(
"Getting Started",
"/start/",
Some("themes overview"),
"<p>Intro.</p>",
),
];
generate_search_index(pages.iter(), std::iter::empty(), "", tmp.path()).unwrap();
let conn = rusqlite::Connection::open(tmp.path().join("search.db")).unwrap();
let results = ranked_search(&conn, "themes");
assert_eq!(results[0].0, "Themes");
assert_eq!(results[1].0, "Custom Themes Guide");
}
#[test]
fn test_case_insensitive_matching() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [make_page(
"Hello World",
"/hello/",
None,
"<p>HELLO WORLD content</p>",
)];
generate_search_index(pages.iter(), std::iter::empty(), "", tmp.path()).unwrap();
let conn = rusqlite::Connection::open(tmp.path().join("search.db")).unwrap();
for term in &["hello", "HELLO", "Hello", "hElLo"] {
let results = ranked_search(&conn, term);
assert_eq!(results.len(), 1, "Failed for term: {term}");
}
}
#[test]
fn test_empty_query_returns_nothing() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [make_page("Test", "/test/", None, "<p>Content</p>")];
generate_search_index(pages.iter(), std::iter::empty(), "", tmp.path()).unwrap();
let conn = rusqlite::Connection::open(tmp.path().join("search.db")).unwrap();
let results = ranked_search(&conn, "");
assert!(!results.is_empty() || results.is_empty());
}
#[test]
fn test_special_characters_dont_break_queries() {
let tmp = tempfile::TempDir::new().unwrap();
let pages = [make_page(
"C++ Guide",
"/cpp/",
Some("Learn C++ basics"),
"<p>Using % and ' and \" in code</p>",
)];
generate_search_index(pages.iter(), std::iter::empty(), "", tmp.path()).unwrap();
let conn = rusqlite::Connection::open(tmp.path().join("search.db")).unwrap();
for term in &["c++", "%", "'", "\"", "'; DROP TABLE", "test%test"] {
let _results = ranked_search(&conn, term);
}
let results = ranked_search(&conn, "c++");
assert!(!results.is_empty());
}
}