use crate::docset::{Docset, RefreshStatus};
use crate::indexing::IndexManager;
use crate::retrieval::HybridRetriever;
use crate::storage::{AccessContext, AccessLevel, Storage};
use crate::{
Chunk, Document, DocumentContent, DocumentType, EmbeddingIds, Error, Result, Source, SourceType,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{BTreeMap, HashSet, VecDeque};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::task::JoinSet;
use tokio::time::Duration;
use url::Url;
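/// Ingests a documentation website ("docset") into the hybrid retriever:
/// discovers page URLs (sitemaps first, link crawl as a fallback), fetches and
/// cleans each page, converts it to Markdown, chunks it, and indexes it.
///
/// Minimal usage sketch (the `Docset` value is assumed to come from the docset
/// module elsewhere in this crate; the path is illustrative):
/// ```ignore
/// let retriever = open_default_docset_retriever("/tmp/rk-docsets".into()).await?;
/// let ingestor = DocsetIngestor::new(retriever)?;
/// let report = ingestor
///     .ingest_docset(&mut docset, &DocsetIngestOptions::default())
///     .await?;
/// ```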
pub struct DocsetIngestor {
http: reqwest::Client,
retriever: HybridRetriever,
}
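/// Tuning knobs for a single docset ingestion run: crawl limits, HTTP
/// behavior, chunking sizes, and optional manifest-based caching and GC.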
#[derive(Debug, Clone)]
pub struct DocsetIngestOptions {
pub max_pages: usize,
pub concurrency: usize,
pub request_timeout: Duration,
pub user_agent: String,
pub min_main_text_len: usize,
pub chunk_target_chars: usize,
pub chunk_overlap_chars: usize,
pub manifest_dir: Option<PathBuf>,
pub gc_removed_pages: bool,
pub refresh_due_only: bool,
}
impl Default for DocsetIngestOptions {
fn default() -> Self {
Self {
max_pages: 10_000,
concurrency: 12,
request_timeout: Duration::from_secs(30),
user_agent: format!("reasonkit-mem-docset/{}", env!("CARGO_PKG_VERSION")),
min_main_text_len: 200,
chunk_target_chars: 2000,
chunk_overlap_chars: 200,
manifest_dir: None,
gc_removed_pages: true,
refresh_due_only: false,
}
}
}
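/// Summary of a completed (or skipped) ingestion run.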
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocsetIngestReport {
pub docset_id: uuid::Uuid,
pub docset_name: String,
pub started_at: DateTime<Utc>,
pub finished_at: DateTime<Utc>,
pub discovery_method: String,
pub discovered_urls: usize,
pub fetched_pages: usize,
pub indexed_pages: usize,
pub skipped_unchanged: usize,
pub removed_pages: usize,
pub failures: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum UrlDiscoveryMethod {
Sitemap,
LinkCrawl,
}
impl UrlDiscoveryMethod {
fn as_str(&self) -> &'static str {
match self {
UrlDiscoveryMethod::Sitemap => "sitemap",
UrlDiscoveryMethod::LinkCrawl => "link_crawl",
}
}
}
#[derive(Debug)]
struct UrlDiscovery {
urls: Vec<Url>,
method: UrlDiscoveryMethod,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct DocsetHttpCacheEntry {
#[serde(default)]
etag: Option<String>,
#[serde(default)]
last_modified: Option<String>,
#[serde(default)]
last_checked_at: Option<DateTime<Utc>>,
}
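/// On-disk record of the last ingestion: the discovered URLs plus per-URL HTTP
/// validators (ETag / Last-Modified) used for conditional re-fetches.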
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DocsetManifest {
docset_id: uuid::Uuid,
updated_at: DateTime<Utc>,
discovery_method: UrlDiscoveryMethod,
urls: Vec<String>,
#[serde(default)]
http_cache: BTreeMap<String, DocsetHttpCacheEntry>,
}
impl DocsetIngestor {
pub fn new(retriever: HybridRetriever) -> Result<Self> {
let http = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::network(format!("Failed to build HTTP client: {e}")))?;
Ok(Self { http, retriever })
}
pub fn with_http(mut self, client: reqwest::Client) -> Self {
self.http = client;
self
}
pub fn retriever(&self) -> &HybridRetriever {
&self.retriever
}
pub fn into_retriever(self) -> HybridRetriever {
self.retriever
}
fn admin_context(&self, operation: &str) -> AccessContext {
AccessContext::new(
"docset".to_string(),
AccessLevel::Admin,
operation.to_string(),
)
}
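/// Runs a full ingestion pass for `docset`: URL discovery, concurrent fetching
/// with conditional GETs, indexing of changed pages, optional GC of pages that
/// disappeared from the sitemap, and manifest persistence. Updates
/// `docset.status` to reflect success or partial failure.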
pub async fn ingest_docset(
&self,
docset: &mut Docset,
opts: &DocsetIngestOptions,
) -> Result<DocsetIngestReport> {
let started_at = Utc::now();
if opts.refresh_due_only && !docset.is_due(started_at) {
return Ok(DocsetIngestReport {
docset_id: docset.id,
docset_name: docset.name.clone(),
started_at,
finished_at: started_at,
discovery_method: "skipped".to_string(),
discovered_urls: 0,
fetched_pages: 0,
indexed_pages: 0,
skipped_unchanged: 0,
removed_pages: 0,
failures: 0,
});
}
let docset_start = Url::parse(&docset.start_url)
.map_err(|e| Error::validation(format!("Invalid docset start_url: {e}")))?;
let docset_start = canonicalize_url(docset_start);
let allowed_prefixes: Vec<Url> = docset
.allowed_prefixes
.iter()
.map(|s| Url::parse(s).map(canonicalize_url))
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::validation(format!("Invalid allowed_prefix in docset: {e}")))?;
let discovery = self
.discover_urls(
&docset_start,
&allowed_prefixes,
opts.max_pages,
&opts.user_agent,
opts.request_timeout,
)
.await?;
let discovered_url_strings: Vec<String> = discovery
.urls
.iter()
.map(|u| u.as_str().to_string())
.collect();
let use_manifest_cache = opts.manifest_dir.is_some();
let manifest_path = opts
.manifest_dir
.as_deref()
.map(|dir| docset_manifest_path(dir, docset.id));
let mut http_cache: BTreeMap<String, DocsetHttpCacheEntry> = BTreeMap::new();
if let Some(path) = manifest_path.as_deref() {
if let Some(existing) = load_docset_manifest(path).await? {
http_cache = existing.http_cache;
}
}
let removed_pages = if opts.gc_removed_pages {
self.maybe_gc_removed_pages(
docset.id,
opts.manifest_dir.as_deref(),
discovery.method,
&discovered_url_strings,
)
.await?
} else {
0
};
let mut report = DocsetIngestReport {
docset_id: docset.id,
docset_name: docset.name.clone(),
started_at,
finished_at: started_at,
discovery_method: discovery.method.as_str().to_string(),
discovered_urls: discovery.urls.len(),
fetched_pages: 0,
indexed_pages: 0,
skipped_unchanged: 0,
removed_pages,
failures: 0,
};
let mut set: JoinSet<Result<PageFetchOutcome>> = JoinSet::new();
let mut in_flight = 0usize;
let http = self.http.clone();
let user_agent = Arc::<str>::from(opts.user_agent.clone());
let request_timeout = opts.request_timeout;
let min_main_text_len = opts.min_main_text_len;
let mut iter = discovery.urls.into_iter();
while in_flight < opts.concurrency {
if let Some(url) = iter.next() {
let http = http.clone();
let user_agent = user_agent.clone();
let prior_cache = http_cache.get(url.as_str()).cloned();
set.spawn(async move {
fetch_and_prepare_page_with_http(
http,
url,
min_main_text_len,
user_agent,
request_timeout,
prior_cache,
)
.await
});
in_flight += 1;
} else {
break;
}
}
while let Some(res) = set.join_next().await {
in_flight = in_flight.saturating_sub(1);
match res {
Ok(Ok(outcome)) => {
if use_manifest_cache {
let requested_key = outcome.requested_url.as_str().to_string();
let final_key = outcome.url.as_str().to_string();
http_cache.insert(requested_key, outcome.cache.clone());
http_cache.insert(final_key, outcome.cache);
}
if outcome.not_modified {
report.skipped_unchanged += 1;
} else if let Some(page_doc) = outcome.page {
report.fetched_pages += 1;
match self.index_page(docset, page_doc, opts).await {
Ok(IndexOutcome::Indexed) => report.indexed_pages += 1,
Ok(IndexOutcome::SkippedUnchanged) => report.skipped_unchanged += 1,
Err(_) => report.failures += 1,
}
} else {
report.fetched_pages += 1;
}
}
Ok(Err(_)) | Err(_) => {
report.failures += 1;
}
}
if let Some(url) = iter.next() {
let http = http.clone();
let user_agent = user_agent.clone();
let prior_cache = http_cache.get(url.as_str()).cloned();
set.spawn(async move {
fetch_and_prepare_page_with_http(
http,
url,
min_main_text_len,
user_agent,
request_timeout,
prior_cache,
)
.await
});
in_flight += 1;
}
}
let finished_at = Utc::now();
report.finished_at = finished_at;
self.maybe_write_manifest(
docset.id,
opts.manifest_dir.as_deref(),
discovery.method,
&discovered_url_strings,
&http_cache,
)
.await?;
if report.failures == 0 {
docset.status = RefreshStatus::Ok { at: finished_at };
} else {
docset.status = RefreshStatus::Error {
at: finished_at,
message: format!("{} page(s) failed during ingestion", report.failures),
};
}
Ok(report)
}
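/// Deletes documents for URLs that appeared in the previous sitemap-based
/// manifest but are missing from the current discovery. Skipped when there is
/// no manifest, when discovery fell back to link crawling, or when the new URL
/// set shrank to less than half of a previously large (>= 50 URL) set.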
async fn maybe_gc_removed_pages(
&self,
docset_id: uuid::Uuid,
manifest_dir: Option<&Path>,
method: UrlDiscoveryMethod,
discovered_urls: &[String],
) -> Result<usize> {
let Some(manifest_dir) = manifest_dir else {
return Ok(0);
};
if method != UrlDiscoveryMethod::Sitemap {
return Ok(0);
}
let manifest_path = docset_manifest_path(manifest_dir, docset_id);
let Some(existing) = load_docset_manifest(&manifest_path).await? else {
return Ok(0);
};
if existing.discovery_method != UrlDiscoveryMethod::Sitemap {
return Ok(0);
}
if !existing.urls.is_empty()
&& discovered_urls.len() * 2 < existing.urls.len()
&& existing.urls.len() >= 50
{
return Ok(0);
}
let new_set: HashSet<&str> = discovered_urls.iter().map(|s| s.as_str()).collect();
let mut removed = 0usize;
for old_url in &existing.urls {
if new_set.contains(old_url.as_str()) {
continue;
}
if let Ok(url) = Url::parse(old_url) {
let doc_id = doc_id_for_url(&url);
let _ = self.retriever.delete_document(&doc_id).await;
removed += 1;
}
}
Ok(removed)
}
async fn maybe_write_manifest(
&self,
docset_id: uuid::Uuid,
manifest_dir: Option<&Path>,
method: UrlDiscoveryMethod,
discovered_urls: &[String],
http_cache: &BTreeMap<String, DocsetHttpCacheEntry>,
) -> Result<()> {
let Some(manifest_dir) = manifest_dir else {
return Ok(());
};
tokio::fs::create_dir_all(manifest_dir)
.await
.map_err(|e| Error::io(format!("Failed to create docset manifest dir: {e}")))?;
let manifest_path = docset_manifest_path(manifest_dir, docset_id);
let mut urls = discovered_urls.to_vec();
if method == UrlDiscoveryMethod::Sitemap {
if let Some(existing) = load_docset_manifest(&manifest_path).await? {
if existing.discovery_method == UrlDiscoveryMethod::Sitemap
&& !existing.urls.is_empty()
&& discovered_urls.len() * 2 < existing.urls.len()
&& existing.urls.len() >= 50
{
urls = existing.urls;
}
}
}
urls.sort();
urls.dedup();
let mut filtered_cache: BTreeMap<String, DocsetHttpCacheEntry> = BTreeMap::new();
for url in &urls {
if let Some(entry) = http_cache.get(url) {
filtered_cache.insert(url.clone(), entry.clone());
}
}
let manifest = DocsetManifest {
docset_id,
updated_at: Utc::now(),
discovery_method: method,
urls,
http_cache: filtered_cache,
};
write_docset_manifest(&manifest_path, &manifest).await
}
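/// Discovers page URLs for the docset: sitemap-based discovery (robots.txt
/// `Sitemap:` entries plus common sitemap paths) is preferred; if it yields
/// nothing, falls back to a bounded same-prefix link crawl from `start_url`.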
async fn discover_urls(
&self,
start_url: &Url,
allowed_prefixes: &[Url],
max_pages: usize,
user_agent: &str,
request_timeout: Duration,
) -> Result<UrlDiscovery> {
let mut urls = self
.discover_urls_via_sitemaps(start_url, allowed_prefixes, user_agent, request_timeout)
.await?;
if !urls.is_empty() {
urls.sort_by(|a, b| a.as_str().cmp(b.as_str()));
urls.dedup();
urls.truncate(max_pages);
return Ok(UrlDiscovery {
urls,
method: UrlDiscoveryMethod::Sitemap,
});
}
let urls = self
.discover_urls_via_link_crawl(
start_url,
allowed_prefixes,
max_pages,
user_agent,
request_timeout,
)
.await?;
Ok(UrlDiscovery {
urls,
method: UrlDiscoveryMethod::LinkCrawl,
})
}
async fn discover_urls_via_sitemaps(
&self,
start_url: &Url,
allowed_prefixes: &[Url],
user_agent: &str,
request_timeout: Duration,
) -> Result<Vec<Url>> {
let mut sitemap_urls: Vec<Url> = Vec::new();
if let Ok(robots) = self
.fetch_robots_txt(start_url, user_agent, request_timeout)
.await
{
let root = start_url.join("/").ok();
for sitemap in parse_robots_sitemaps(&robots) {
if let Ok(url) = Url::parse(&sitemap) {
sitemap_urls.push(url);
} else if let Some(root) = &root {
if let Ok(url) = root.join(&sitemap) {
sitemap_urls.push(url);
}
}
}
}
if let Ok(root) = start_url.join("/") {
if let Ok(url) = root.join("sitemap.xml") {
sitemap_urls.push(url);
}
if let Ok(url) = root.join("sitemap_index.xml") {
sitemap_urls.push(url);
}
}
sitemap_urls.sort_by(|a, b| a.as_str().cmp(b.as_str()));
sitemap_urls.dedup();
if sitemap_urls.is_empty() {
return Ok(Vec::new());
}
let mut discovered: Vec<Url> = Vec::new();
let mut seen_sitemaps: HashSet<String> = HashSet::new();
let mut queue: VecDeque<Url> = sitemap_urls.into_iter().collect();
while let Some(sitemap_url) = queue.pop_front() {
if !seen_sitemaps.insert(sitemap_url.as_str().to_string()) {
continue;
}
let xml = match self
.fetch_text(&sitemap_url, user_agent, request_timeout)
.await
{
Ok(v) => v,
Err(_) => continue,
};
for loc in extract_sitemap_locs(&xml) {
if let Ok(url) = Url::parse(&loc) {
if url.path().ends_with(".xml") {
queue.push_back(url);
continue;
}
let url = canonicalize_url(url);
if is_allowed(&url, allowed_prefixes) {
discovered.push(url);
}
}
}
}
Ok(discovered)
}
async fn discover_urls_via_link_crawl(
&self,
start_url: &Url,
allowed_prefixes: &[Url],
max_pages: usize,
user_agent: &str,
request_timeout: Duration,
) -> Result<Vec<Url>> {
let mut queue: VecDeque<Url> = VecDeque::new();
let mut seen: HashSet<String> = HashSet::new();
let mut out: Vec<Url> = Vec::new();
let start = canonicalize_url(start_url.clone());
queue.push_back(start.clone());
seen.insert(start.as_str().to_string());
while let Some(url) = queue.pop_front() {
if out.len() >= max_pages {
break;
}
out.push(url.clone());
let html = match self.fetch_text(&url, user_agent, request_timeout).await {
Ok(v) => v,
Err(_) => continue,
};
for link in extract_links(&html, &url) {
let link = canonicalize_url(link);
if !is_allowed(&link, allowed_prefixes) {
continue;
}
let key = link.as_str().to_string();
if seen.insert(key) {
queue.push_back(link);
}
}
}
Ok(out)
}
async fn fetch_robots_txt(
&self,
start_url: &Url,
user_agent: &str,
request_timeout: Duration,
) -> Result<String> {
let root = start_url
.join("/")
.map_err(|e| Error::network(format!("Failed to compute site root: {e}")))?;
let robots_url = root
.join("robots.txt")
.map_err(|e| Error::network(format!("Failed to compute robots.txt URL: {e}")))?;
self.fetch_text(&robots_url, user_agent, request_timeout)
.await
}
async fn fetch_text(
&self,
url: &Url,
user_agent: &str,
request_timeout: Duration,
) -> Result<String> {
let req = self
.http
.get(url.clone())
.header(reqwest::header::USER_AGENT, user_agent);
let res = tokio::time::timeout(request_timeout, req.send())
.await
.map_err(|_| Error::network(format!("Timed out fetching {url}")))?
.map_err(|e| Error::network(format!("Failed to fetch {url}: {e}")))?;
let status = res.status();
if !status.is_success() {
return Err(Error::network(format!(
"Failed to fetch {url}: HTTP {status}"
)));
}
res.text()
.await
.map_err(|e| Error::network(format!("Failed to read body for {url}: {e}")))
}
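/// Indexes a single fetched page. If a document with the same URL-derived id
/// already exists and its stored content hash (in `source.version`) matches,
/// the page is skipped; otherwise the old document is deleted and replaced.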
async fn index_page(
&self,
docset: &Docset,
page: PageDoc,
opts: &DocsetIngestOptions,
) -> Result<IndexOutcome> {
let doc_id = doc_id_for_url(&page.url);
let context = self.admin_context("docset_index_page");
if let Some(existing) = self
.retriever
.storage()
.get_document(&doc_id, &context)
.await?
{
if existing.source.version.as_deref() == Some(page.content_hash.as_str()) {
return Ok(IndexOutcome::SkippedUnchanged);
}
self.retriever.delete_document(&doc_id).await?;
}
let source = Source {
source_type: SourceType::Website,
url: Some(page.url.as_str().to_string()),
path: None,
arxiv_id: None,
github_repo: None,
retrieved_at: Utc::now(),
version: Some(page.content_hash.clone()),
};
let mut doc = Document::new(DocumentType::Documentation, source);
doc.id = doc_id;
doc.metadata.tags.push(format!("docset:{}", docset.name));
doc.metadata.tags.push(format!("docset_id:{}", docset.id));
if let Some(title) = page.title {
doc.metadata.title = Some(title);
}
doc.content = DocumentContent {
raw: page.markdown.clone(),
format: crate::ContentFormat::Markdown,
language: "en".to_string(),
word_count: page.markdown.split_whitespace().count(),
char_count: page.markdown.len(),
};
doc.chunks = chunk_markdown(
&page.markdown,
opts.chunk_target_chars,
opts.chunk_overlap_chars,
);
self.retriever.add_document(&doc).await?;
Ok(IndexOutcome::Indexed)
}
}
#[derive(Debug)]
enum IndexOutcome {
Indexed,
SkippedUnchanged,
}
#[derive(Debug)]
struct PageDoc {
url: Url,
title: Option<String>,
markdown: String,
content_hash: String,
}
#[derive(Debug)]
struct PageFetchOutcome {
requested_url: Url,
url: Url,
page: Option<PageDoc>,
cache: DocsetHttpCacheEntry,
not_modified: bool,
}
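/// Fetches one page with a conditional GET (If-None-Match / If-Modified-Since
/// from the prior cache entry), extracts the main content, converts it to
/// Markdown, and hashes it. Returns `page: None` both for 304 responses and
/// for pages whose main text is shorter than `min_main_text_len`.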
async fn fetch_and_prepare_page_with_http(
http: reqwest::Client,
url: Url,
min_main_text_len: usize,
user_agent: Arc<str>,
request_timeout: Duration,
prior_cache: Option<DocsetHttpCacheEntry>,
) -> Result<PageFetchOutcome> {
let mut cache = prior_cache.unwrap_or_default();
let requested_url = url.clone();
let mut req = http
.get(url.clone())
.header(reqwest::header::USER_AGENT, user_agent.as_ref());
if let Some(etag) = cache
.etag
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
{
req = req.header(reqwest::header::IF_NONE_MATCH, etag);
}
if let Some(last_modified) = cache
.last_modified
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
{
req = req.header(reqwest::header::IF_MODIFIED_SINCE, last_modified);
}
let res = tokio::time::timeout(request_timeout, req.send())
.await
.map_err(|_| Error::network(format!("Timed out fetching {url}")))?
.map_err(|e| Error::network(format!("Failed to fetch {url}: {e}")))?;
let final_url = canonicalize_url(res.url().clone());
if let Some(v) = res.headers().get(reqwest::header::ETAG) {
if let Ok(s) = v.to_str() {
let s = s.trim();
if !s.is_empty() {
cache.etag = Some(s.to_string());
}
}
}
if let Some(v) = res.headers().get(reqwest::header::LAST_MODIFIED) {
if let Ok(s) = v.to_str() {
let s = s.trim();
if !s.is_empty() {
cache.last_modified = Some(s.to_string());
}
}
}
cache.last_checked_at = Some(Utc::now());
let status = res.status();
if status == reqwest::StatusCode::NOT_MODIFIED {
return Ok(PageFetchOutcome {
requested_url,
url: final_url,
page: None,
cache,
not_modified: true,
});
}
if !status.is_success() {
return Err(Error::network(format!(
"Failed to fetch {url}: HTTP {status}"
)));
}
let html = res
.text()
.await
.map_err(|e| Error::network(format!("Failed to read body for {url}: {e}")))?;
let (title, main_html, main_text_len) = extract_title_and_main_html(&html, min_main_text_len);
if main_text_len < min_main_text_len {
return Ok(PageFetchOutcome {
requested_url,
url: final_url,
page: None,
cache,
not_modified: false,
});
}
let markdown = html_to_markdown(&main_html);
let content_hash = sha256_hex(markdown.as_bytes());
Ok(PageFetchOutcome {
requested_url,
url: final_url.clone(),
page: Some(PageDoc {
url: final_url,
title,
markdown,
content_hash,
}),
cache,
not_modified: false,
})
}
fn is_allowed(url: &Url, allowed_prefixes: &[Url]) -> bool {
allowed_prefixes
.iter()
.any(|prefix| url_matches_prefix(url, prefix))
}
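/// Prefix match on scheme, host, port, and path, requiring a path-segment
/// boundary so that `/reference` does not match `/reference-old`.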
fn url_matches_prefix(url: &Url, prefix: &Url) -> bool {
if url.scheme() != prefix.scheme() {
return false;
}
if url.host_str() != prefix.host_str() {
return false;
}
if url.port_or_known_default() != prefix.port_or_known_default() {
return false;
}
let url_path = url.path();
let prefix_path = prefix.path();
if prefix_path == "/" {
return true;
}
if url_path == prefix_path {
return true;
}
if !url_path.starts_with(prefix_path) {
return false;
}
let boundary_idx = prefix_path.len();
url_path
.as_bytes()
.get(boundary_idx)
.is_some_and(|b| *b == b'/')
}
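/// Normalizes a URL for deduplication and identity: drops fragments and
/// default ports, trims trailing slashes (except the root path), removes
/// common tracking query parameters, and sorts the remaining query pairs.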
fn canonicalize_url(mut url: Url) -> Url {
url.set_fragment(None);
if let Some(port) = url.port() {
let default = match url.scheme() {
"http" => 80,
"https" => 443,
_ => port,
};
if port == default {
let _ = url.set_port(None);
}
}
let path = url.path().to_string();
if path.len() > 1 && path.ends_with('/') {
url.set_path(path.trim_end_matches('/'));
}
if url.query().is_some() {
let mut pairs: Vec<(String, String)> = url
.query_pairs()
.into_owned()
.filter(|(k, _)| !is_tracking_query_key(k))
.collect();
pairs.sort();
url.set_query(None);
if !pairs.is_empty() {
let mut qp = url.query_pairs_mut();
for (k, v) in pairs {
qp.append_pair(&k, &v);
}
}
}
url
}
fn is_tracking_query_key(key: &str) -> bool {
let key = key.trim();
if key.is_empty() {
return false;
}
let k = key.to_ascii_lowercase();
if k.starts_with("utm_") {
return true;
}
matches!(
k.as_str(),
"gclid" | "fbclid" | "mc_cid" | "mc_eid" | "_hsenc" | "_hsmi" | "ref" | "source" | "spm"
)
}
fn docset_manifest_path(manifest_dir: &Path, docset_id: uuid::Uuid) -> PathBuf {
manifest_dir.join(format!("{docset_id}.json"))
}
async fn load_docset_manifest(path: &Path) -> Result<Option<DocsetManifest>> {
if !path.exists() {
return Ok(None);
}
let bytes = tokio::fs::read(path)
.await
.map_err(|e| Error::io(format!("Failed to read docset manifest {:?}: {e}", path)))?;
let manifest: DocsetManifest = serde_json::from_slice(&bytes)
.map_err(|e| Error::parse(format!("Failed to parse docset manifest {:?}: {e}", path)))?;
Ok(Some(manifest))
}
async fn write_docset_manifest(path: &Path, manifest: &DocsetManifest) -> Result<()> {
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
Error::io(format!(
"Failed to create docset manifest directory {:?}: {e}",
parent
))
})?;
}
let bytes = serde_json::to_vec_pretty(manifest)
.map_err(|e| Error::parse(format!("Failed to serialize docset manifest: {e}")))?;
let tmp_path = path.with_extension("json.tmp");
tokio::fs::write(&tmp_path, bytes)
.await
.map_err(|e| Error::io(format!("Failed to write temp docset manifest: {e}")))?;
tokio::fs::rename(&tmp_path, path)
.await
.map_err(|e| Error::io(format!("Failed to replace docset manifest: {e}")))?;
Ok(())
}
fn parse_robots_sitemaps(robots: &str) -> Vec<String> {
robots
.lines()
.filter_map(|line| {
let line = line.trim();
let lower = line.to_ascii_lowercase();
if !lower.starts_with("sitemap:") {
return None;
}
Some(line["sitemap:".len()..].trim().to_string())
})
.collect()
}
fn extract_sitemap_locs(xml: &str) -> Vec<String> {
let mut out = Vec::new();
let mut rest = xml;
while let Some(start) = rest.find("<loc>") {
let after_start = &rest[start + "<loc>".len()..];
if let Some(end) = after_start.find("</loc>") {
let loc = &after_start[..end];
out.push(loc.trim().to_string());
rest = &after_start[end + "</loc>".len()..];
} else {
break;
}
}
out
}
fn extract_links(html: &str, base: &Url) -> Vec<Url> {
let mut out = Vec::new();
let document = scraper::Html::parse_document(html);
let selector = match scraper::Selector::parse("a[href]") {
Ok(s) => s,
Err(_) => return out,
};
for el in document.select(&selector) {
if let Some(href) = el.value().attr("href") {
let href_trim = href.trim();
if href_trim.starts_with("javascript:") || href_trim.starts_with("mailto:") {
continue;
}
if let Ok(u) = base.join(href_trim) {
out.push(u);
}
}
}
out
}
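/// Extracts the page title and the HTML of the main content region, trying a
/// list of common content selectors before falling back to `<body>` and then
/// the whole document, together with an approximate text length of the chosen
/// region.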
fn extract_title_and_main_html(
html: &str,
min_main_text_len: usize,
) -> (Option<String>, String, usize) {
let doc = scraper::Html::parse_document(html);
let title = scraper::Selector::parse("title")
.ok()
.and_then(|sel| doc.select(&sel).next())
.map(|t| t.text().collect::<String>().trim().to_string())
.filter(|t| !t.is_empty());
let selectors = [
"article",
"main",
"[role=\"main\"]",
"[role=\"article\"]",
".article",
".post",
".content",
".entry-content",
".post-content",
"#content",
"#main-content",
".main-content",
];
for sel in selectors {
if let Ok(selector) = scraper::Selector::parse(sel) {
for el in doc.select(&selector) {
let text_len = el
.text()
.map(|t| t.trim())
.filter(|t| !t.is_empty())
.map(|t| t.len())
.sum();
if text_len >= min_main_text_len {
return (title, el.inner_html(), text_len);
}
}
}
}
if let Ok(body_sel) = scraper::Selector::parse("body") {
if let Some(body) = doc.select(&body_sel).next() {
let text_len = body
.text()
.map(|t| t.trim())
.filter(|t| !t.is_empty())
.map(|t| t.len())
.sum();
return (title, body.inner_html(), text_len);
}
}
(title, html.to_string(), html.len())
}
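/// Best-effort, regex-based HTML-to-Markdown conversion: strips scripts and
/// styles, rewrites headings, paragraphs, emphasis, links, code, and list
/// items, then removes any remaining tags and decodes common entities. It is
/// intentionally lossy and is not a full HTML parser.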
fn html_to_markdown(html: &str) -> String {
static SCRIPT_RE: once_cell::sync::Lazy<regex::Regex> = once_cell::sync::Lazy::new(|| {
regex::Regex::new(r"<script[^>]*>[\s\S]*?</script>").unwrap()
});
static STYLE_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<style[^>]*>[\s\S]*?</style>").unwrap());
static BR_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<br\s*/?>").unwrap());
static P_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<p[^>]*>(.*?)</p>").unwrap());
static A_RE: once_cell::sync::Lazy<regex::Regex> = once_cell::sync::Lazy::new(|| {
regex::Regex::new(r#"<a[^>]*href=["']([^"']+)["'][^>]*>(.*?)</a>"#).unwrap()
});
static B_RE: once_cell::sync::Lazy<regex::Regex> = once_cell::sync::Lazy::new(|| {
regex::Regex::new(r"<(b|strong)[^>]*>(.*?)</(b|strong)>").unwrap()
});
static I_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<(i|em)[^>]*>(.*?)</(i|em)>").unwrap());
static CODE_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<code[^>]*>(.*?)</code>").unwrap());
static PRE_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<pre[^>]*>([\s\S]*?)</pre>").unwrap());
static LI_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<li[^>]*>(.*?)</li>").unwrap());
static TAG_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"<[^>]+>").unwrap());
static WS_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| regex::Regex::new(r"\n{3,}").unwrap());
let mut md = html.to_string();
md = SCRIPT_RE.replace_all(&md, "").to_string();
md = STYLE_RE.replace_all(&md, "").to_string();
for i in (1..=6).rev() {
let re = regex::Regex::new(&format!(r"<h{}[^>]*>(.*?)</h{}>", i, i)).unwrap();
let prefix = "#".repeat(i);
md = re.replace_all(&md, format!("{prefix} $1\n\n")).to_string();
}
md = P_RE.replace_all(&md, "$1\n\n").to_string();
md = BR_RE.replace_all(&md, "\n").to_string();
md = B_RE.replace_all(&md, "**$2**").to_string();
md = I_RE.replace_all(&md, "*$2*").to_string();
md = A_RE.replace_all(&md, "[$2]($1)").to_string();
md = CODE_RE.replace_all(&md, "`$1`").to_string();
md = PRE_RE.replace_all(&md, "```\n$1\n```").to_string();
md = LI_RE.replace_all(&md, "- $1\n").to_string();
md = TAG_RE.replace_all(&md, "").to_string();
md = decode_html_entities(&md);
md = WS_RE.replace_all(&md, "\n\n").to_string();
md.trim().to_string()
}
fn decode_html_entities(text: &str) -> String {
// Minimal decoder for the HTML entities commonly left over after tag stripping.
text.replace("&nbsp;", " ")
.replace("&lt;", "<")
.replace("&gt;", ">")
.replace("&amp;", "&")
.replace("&quot;", "\"")
.replace("&#39;", "'")
.replace("&#x27;", "'")
.replace("&apos;", "'")
.replace("&#x2F;", "/")
.replace("&copy;", "(c)")
.replace("&reg;", "(R)")
.replace("&trade;", "(TM)")
.replace("&ndash;", "-")
.replace("&mdash;", "--")
.replace("&hellip;", "...")
.replace("&lsquo;", "'")
.replace("&rsquo;", "'")
.replace("&ldquo;", "\"")
.replace("&rdquo;", "\"")
}
fn sha256_hex(bytes: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(bytes);
hex::encode(hasher.finalize())
}
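/// Derives a stable document id from the page URL (with a fixed namespace
/// prefix) so re-ingesting the same page updates the existing document rather
/// than duplicating it.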
fn doc_id_for_url(url: &Url) -> uuid::Uuid {
let mut hasher = Sha256::new();
hasher.update(b"reasonkit-mem:docset:url:");
hasher.update(url.as_str().as_bytes());
let digest = hasher.finalize();
let mut bytes = [0u8; 16];
bytes.copy_from_slice(&digest[..16]);
uuid::Uuid::from_bytes(bytes)
}
fn chunk_markdown(markdown: &str, target_chars: usize, overlap_chars: usize) -> Vec<Chunk> {
chunk_markdown_semantic(markdown, target_chars, overlap_chars)
}
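/// Heading-aware chunker: starts a new chunk at each Markdown heading, also
/// splits once a chunk reaches `target_chars` (avoiding a split right at a
/// code-fence line), and prepends up to `overlap_chars` of the previous chunk
/// when two consecutive chunks belong to the same section.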
fn chunk_markdown_semantic(
markdown: &str,
target_chars: usize,
overlap_chars: usize,
) -> Vec<Chunk> {
if markdown.trim().is_empty() {
return Vec::new();
}
let mut chunks = Vec::new();
let mut lines = markdown.lines().enumerate().peekable();
let mut current_section: Option<String> = None;
let mut buf = String::new();
let mut buf_start_char = 0usize;
let mut buf_section: Option<String> = None;
let mut global_char_offset = 0usize;
let mut index = 0usize;
let flush_buf = |buf: &mut String,
buf_start_char: &mut usize,
buf_section: &mut Option<String>,
global_char_offset: usize,
index: &mut usize,
chunks: &mut Vec<Chunk>,
current_section: &Option<String>| {
let text = buf.trim().to_string();
if text.is_empty() {
buf.clear();
return;
}
let end_char = global_char_offset.saturating_sub(1);
chunks.push(Chunk {
id: uuid::Uuid::new_v4(),
text,
index: *index,
start_char: *buf_start_char,
end_char,
token_count: None,
section: current_section.clone(),
page: None,
embedding_ids: EmbeddingIds::default(),
});
*index += 1;
buf.clear();
*buf_start_char = global_char_offset;
*buf_section = None;
};
while let Some((line_idx, line)) = lines.next() {
if line_idx > 0 {
global_char_offset += 1;
}
let line_start_offset = global_char_offset;
global_char_offset += line.len();
if line.starts_with('#') {
let heading_level = line.chars().take_while(|&c| c == '#').count();
let heading_text = line[heading_level..].trim();
if !buf.is_empty() {
flush_buf(
&mut buf,
&mut buf_start_char,
&mut buf_section,
line_start_offset,
&mut index,
&mut chunks,
&current_section,
);
}
current_section = Some(heading_text.to_string());
buf_section = Some(heading_text.to_string());
buf_start_char = line_start_offset;
}
if !buf.is_empty() {
buf.push('\n');
}
buf.push_str(line);
let is_code_fence = line.trim_start().starts_with("```");
let next_is_code_fence = lines
.peek()
.is_some_and(|(_, next_line)| next_line.trim_start().starts_with("```"));
if buf.len() >= target_chars && (!is_code_fence || next_is_code_fence) {
flush_buf(
&mut buf,
&mut buf_start_char,
&mut buf_section,
global_char_offset,
&mut index,
&mut chunks,
&current_section,
);
}
}
if !buf.is_empty() {
flush_buf(
&mut buf,
&mut buf_start_char,
&mut buf_section,
global_char_offset,
&mut index,
&mut chunks,
&current_section,
);
}
if overlap_chars > 0 && chunks.len() > 1 {
let mut overlapped = Vec::with_capacity(chunks.len());
for (i, chunk) in chunks.into_iter().enumerate() {
if i == 0 {
overlapped.push(chunk);
continue;
}
let prev = &overlapped[i - 1];
let should_overlap = prev.section == chunk.section;
if should_overlap {
let tail = if prev.text.len() > overlap_chars {
&prev.text[prev.text.len() - overlap_chars..]
} else {
prev.text.as_str()
};
let mut text = String::new();
text.push_str(tail);
text.push('\n');
text.push_str(&chunk.text);
let mut new_chunk = chunk;
new_chunk.text = text;
overlapped.push(new_chunk);
} else {
overlapped.push(chunk);
}
}
chunks = overlapped;
}
chunks
}
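/// Simpler size-based chunker that ignores headings and only avoids splitting
/// inside fenced code blocks; kept alongside the heading-aware chunker above.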
fn chunk_markdown_simple(markdown: &str, target_chars: usize, overlap_chars: usize) -> Vec<Chunk> {
if markdown.trim().is_empty() {
return Vec::new();
}
let mut chunks = Vec::new();
let mut buf = String::new();
let mut in_code_block = false;
let mut chunk_start_char = 0usize;
let mut cursor_char_exclusive = 0usize;
let mut index = 0usize;
let flush = |buf: &mut String,
chunk_start_char: &mut usize,
cursor_char_exclusive: usize,
index: &mut usize,
chunks: &mut Vec<Chunk>| {
let text = buf.trim().to_string();
if text.is_empty() {
buf.clear();
return;
}
let end_char = cursor_char_exclusive.saturating_sub(1);
chunks.push(Chunk {
id: uuid::Uuid::new_v4(),
text,
index: *index,
start_char: *chunk_start_char,
end_char,
token_count: None,
section: None,
page: None,
embedding_ids: EmbeddingIds::default(),
});
*index += 1;
buf.clear();
*chunk_start_char = cursor_char_exclusive;
};
for (line_idx, line) in markdown.lines().enumerate() {
if line.trim_start().starts_with("```") {
in_code_block = !in_code_block;
}
if line_idx > 0 {
cursor_char_exclusive += 1;
}
cursor_char_exclusive += line.len();
if !buf.is_empty() {
buf.push('\n');
}
buf.push_str(line);
if !in_code_block && buf.len() >= target_chars {
flush(
&mut buf,
&mut chunk_start_char,
cursor_char_exclusive,
&mut index,
&mut chunks,
);
}
}
flush(
&mut buf,
&mut chunk_start_char,
cursor_char_exclusive,
&mut index,
&mut chunks,
);
if overlap_chars == 0 || chunks.len() < 2 {
return chunks;
}
let mut overlapped = Vec::with_capacity(chunks.len());
for (i, chunk) in chunks.into_iter().enumerate() {
if i == 0 {
overlapped.push(chunk);
continue;
}
let prev = &overlapped[i - 1].text;
let tail = if prev.len() > overlap_chars {
&prev[prev.len() - overlap_chars..]
} else {
prev.as_str()
};
let mut text = String::new();
text.push_str(tail);
text.push('\n');
text.push_str(&chunk.text);
let mut new_chunk = chunk;
new_chunk.text = text;
overlapped.push(new_chunk);
}
overlapped
}
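/// Opens (creating directories as needed) file-backed storage and a search
/// index under `base_dir` and wires them into a `HybridRetriever` suitable for
/// docset ingestion.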
pub async fn open_default_docset_retriever(base_dir: PathBuf) -> Result<HybridRetriever> {
let storage_dir = base_dir.join("storage");
let index_dir = base_dir.join("index");
tokio::fs::create_dir_all(&storage_dir).await.map_err(|e| {
Error::io(format!(
"Failed to create storage dir {:?}: {e}",
storage_dir
))
})?;
tokio::fs::create_dir_all(&index_dir)
.await
.map_err(|e| Error::io(format!("Failed to create index dir {:?}: {e}", index_dir)))?;
let storage = Storage::file(storage_dir).await?;
let index = IndexManager::open(index_dir)?;
Ok(HybridRetriever::new(storage, index))
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
async fn start_etag_test_server() -> (Url, Arc<Mutex<Vec<String>>>, tokio::task::JoinHandle<()>)
{
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = Url::parse(&format!("http://{addr}/page")).unwrap();
let requests: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let requests_task = requests.clone();
let handle = tokio::spawn(async move {
let etag = "\"rk-test-etag\"";
let body = "<html><head><title>Test</title></head><body><main>Hello world</main></body></html>";
for _ in 0..2 {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = Vec::new();
let mut tmp = [0u8; 1024];
loop {
let n = socket.read(&mut tmp).await.unwrap();
if n == 0 {
break;
}
buf.extend_from_slice(&tmp[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
if buf.len() > 32 * 1024 {
break;
}
}
let req_text = String::from_utf8_lossy(&buf).to_string();
requests_task.lock().await.push(req_text.clone());
let lower = req_text.to_ascii_lowercase();
let is_conditional =
lower.contains("if-none-match:") && lower.contains("rk-test-etag");
if is_conditional {
let res = format!(
"HTTP/1.1 304 Not Modified\r\nETag: {etag}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
);
socket.write_all(res.as_bytes()).await.unwrap();
} else {
let res = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nETag: {etag}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
);
socket.write_all(res.as_bytes()).await.unwrap();
}
let _ = socket.shutdown().await;
}
});
(url, requests, handle)
}
#[tokio::test]
async fn conditional_get_uses_etag_and_handles_304() {
let (url, requests, server) = start_etag_test_server().await;
let http = reqwest::Client::builder().build().unwrap();
let ua = Arc::<str>::from("rk-test");
let first = fetch_and_prepare_page_with_http(
http.clone(),
url.clone(),
1,
ua.clone(),
Duration::from_secs(2),
None,
)
.await
.unwrap();
assert!(!first.not_modified);
assert!(first.page.is_some());
assert_eq!(first.cache.etag.as_deref(), Some("\"rk-test-etag\""));
let second = fetch_and_prepare_page_with_http(
http,
url,
1,
ua,
Duration::from_secs(2),
Some(first.cache),
)
.await
.unwrap();
assert!(second.not_modified);
assert!(second.page.is_none());
server.await.unwrap();
let seen = requests.lock().await;
assert_eq!(seen.len(), 2);
assert!(
!seen[0].to_ascii_lowercase().contains("if-none-match:"),
"first request unexpectedly contained If-None-Match"
);
assert!(
seen[1].to_ascii_lowercase().contains("if-none-match:"),
"second request did not contain If-None-Match"
);
}
#[test]
fn test_canonicalize_url_removes_fragment() {
let url = Url::parse("https://example.com/path#fragment").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path");
}
#[test]
fn test_canonicalize_url_removes_default_ports() {
let url = Url::parse("https://example.com:443/path").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path");
let url = Url::parse("http://example.com:80/path").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "http://example.com/path");
}
#[test]
fn test_canonicalize_url_normalizes_trailing_slashes() {
let url = Url::parse("https://example.com/path/").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path");
let url = Url::parse("https://example.com/").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/");
}
#[test]
fn test_canonicalize_url_removes_tracking_params() {
let url =
Url::parse("https://example.com/path?utm_source=test&utm_medium=email¶m=value")
.unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path?param=value");
let url =
Url::parse("https://example.com/path?utm_source=test&gclid=123&fbclid=456").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path");
}
#[test]
fn test_canonicalize_url_sorts_query_params() {
let url = Url::parse("https://example.com/path?b=2&a=1&c=3").unwrap();
let canonical = canonicalize_url(url);
assert_eq!(canonical.as_str(), "https://example.com/path?a=1&b=2&c=3");
}
#[test]
fn test_is_allowed_basic_matching() {
let prefix = Url::parse("https://example.com/reference/").unwrap();
let url = Url::parse("https://example.com/reference/hooks/usestate").unwrap();
assert!(is_allowed(&url, &[prefix]));
}
#[test]
fn test_is_allowed_prevents_false_positives() {
let prefix = Url::parse("https://example.com/reference/").unwrap();
let url = Url::parse("https://example.com/reference-old/page").unwrap();
assert!(!is_allowed(&url, &[prefix]));
}
#[test]
fn test_is_allowed_scheme_mismatch() {
let prefix = Url::parse("https://example.com/path/").unwrap();
let url = Url::parse("http://example.com/path/subpage").unwrap();
assert!(!is_allowed(&url, &[prefix]));
}
#[test]
fn test_is_allowed_root_prefix() {
let prefix = Url::parse("https://example.com/").unwrap();
let url = Url::parse("https://example.com/anything/goes").unwrap();
assert!(is_allowed(&url, &[prefix]));
}
#[test]
fn test_chunk_markdown_semantic_respects_headings() {
let markdown = r#"# Introduction
This is the introduction.
## Main Content
This is the main content with more details.
### Subsection
Even more detailed content here.
# Another Section
New section with different content."#;
let chunks = chunk_markdown_semantic(markdown, 100, 20);
assert!(!chunks.is_empty());
let intro_chunks: Vec<_> = chunks
.iter()
.filter(|c| c.section.as_deref() == Some("Introduction"))
.collect();
let main_chunks: Vec<_> = chunks
.iter()
.filter(|c| c.section.as_deref() == Some("Main Content"))
.collect();
assert!(!intro_chunks.is_empty());
assert!(!main_chunks.is_empty());
}
#[test]
fn test_chunk_markdown_semantic_handles_empty() {
let chunks = chunk_markdown_semantic("", 1000, 100);
assert!(chunks.is_empty());
}
#[test]
fn test_chunk_markdown_semantic_single_chunk() {
let markdown = "This is a short document that fits in one chunk.";
let chunks = chunk_markdown_semantic(markdown, 1000, 100);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].text, markdown);
}
}