use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use nab::content::ContentRouter;
use nab::rate_limit::DomainRateLimiter;
use nab::site::SiteRouter;
use super::fetch::{build_client, non_empty};
/// Maximum number of URL fetches allowed in flight at once (global cap).
const DEFAULT_CONCURRENCY: usize = 10;
/// Minimum spacing between successive requests to the same domain,
/// enforced by `DomainRateLimiter`.
const DOMAIN_DELAY_MS: u64 = 200;
/// Outcome of fetching a single URL, successful or not. When `ok` is
/// `false`, `body` carries an inline error notice instead of page content.
struct FetchedSource {
    /// The URL that was requested.
    url: String,
    /// Extracted page title; falls back to the URL itself when no title
    /// could be determined.
    title: String,
    /// Markdown body (possibly truncated to the per-source budget), or an
    /// error message on failure.
    body: String,
    /// Whether fetch + conversion succeeded.
    ok: bool,
}
/// Fetch every URL concurrently and print one combined markdown document to
/// stdout; progress and status lines go to stderr.
///
/// * `urls` — sources to fetch; must be non-empty.
/// * `cookies` — cookie spec forwarded per-domain through
///   `resolve_cookie_header` inside each fetch.
/// * `max_tokens` — total context budget, split evenly across sources at
///   ~4 chars/token by `budget_per_source`.
///
/// # Errors
/// Fails only when `urls` is empty. Individual fetch failures do not abort
/// the run; they are rendered inline in the emitted markdown.
pub async fn cmd_context(urls: &[String], cookies: &str, max_tokens: usize) -> Result<()> {
    anyhow::ensure!(!urls.is_empty(), "at least one URL is required");
    let total_start = Instant::now();
    let n = urls.len();
    eprintln!("Fetching {n} URL(s) …");
    // Global concurrency cap; per-domain pacing is layered on top of it.
    let semaphore = Arc::new(Semaphore::new(DEFAULT_CONCURRENCY));
    let limiter = Arc::new(DomainRateLimiter::new(Duration::from_millis(
        DOMAIN_DELAY_MS,
    )));
    let cookies_owned = cookies.to_owned();
    let char_budget = budget_per_source(max_tokens, n);
    // Each task returns its original index so output order matches input
    // order even though completion order is arbitrary.
    let mut set: JoinSet<(usize, FetchedSource)> = JoinSet::new();
    for (idx, url) in urls.iter().enumerate() {
        let url = url.clone();
        let sem = semaphore.clone();
        let lim = limiter.clone();
        let cookies = cookies_owned.clone();
        set.spawn(async move {
            // The semaphore is never closed while tasks hold an Arc to it,
            // so acquire() failing would be a bug.
            let _permit = sem.acquire().await.expect("semaphore closed");
            let domain = super::extract_domain(&url);
            // NOTE: the concurrency permit is held during the domain-delay
            // wait as well as the fetch itself.
            lim.wait(&domain).await;
            let source = fetch_one(&url, &cookies, char_budget).await;
            (idx, source)
        });
    }
    // Slots are filled by original index; a panicked task leaves `None`,
    // which renders as an "(unavailable)" section in the output.
    let mut ordered: Vec<Option<FetchedSource>> = (0..n).map(|_| None).collect();
    let mut completed = 0usize;
    while let Some(res) = set.join_next().await {
        completed += 1;
        match res {
            Ok((idx, source)) => {
                let status = if source.ok { "ok" } else { "err" };
                eprintln!(" [{completed}/{n}] {status}: {}", source.url);
                ordered[idx] = Some(source);
            }
            // Err here means the task panicked (join error), not a fetch error.
            Err(e) => eprintln!(" [{completed}/{n}] panic: {e}"),
        }
    }
    let elapsed = total_start.elapsed();
    emit_combined_markdown(ordered, n, elapsed);
    eprintln!("Done in {:.2}s", elapsed.as_secs_f64());
    Ok(())
}
/// Split the overall token budget evenly across sources, expressed in
/// characters (~4 chars per token). A source count of zero is treated as
/// one so the division is always defined.
fn budget_per_source(max_tokens: usize, n_sources: usize) -> usize {
    let divisor = n_sources.max(1);
    // Saturate rather than overflow when multiplying out the char budget.
    let total_chars = match max_tokens.checked_mul(4) {
        Some(chars) => chars,
        None => usize::MAX,
    };
    total_chars / divisor
}
/// Fetch one URL and always produce a `FetchedSource`: on success the body
/// is truncated to `char_budget`; on failure the error is rendered inline
/// as the body and `ok` is set to `false`.
async fn fetch_one(url: &str, cookies: &str, char_budget: usize) -> FetchedSource {
    let owned_url = url.to_owned();
    match fetch_one_inner(url, cookies).await {
        Ok((title, body)) => {
            let body = truncate_to_budget(body, char_budget);
            FetchedSource {
                url: owned_url,
                title,
                body,
                ok: true,
            }
        }
        Err(e) => {
            // No title is available on failure; reuse the URL as the label.
            let message = format!("*Error fetching this source: {e}*");
            FetchedSource {
                title: owned_url.clone(),
                url: owned_url,
                body: message,
                ok: false,
            }
        }
    }
}
/// Fetch a single URL and convert it to `(title, markdown)`.
///
/// Tries the site-specific extractor path first (`SiteRouter`); when no
/// handler claims the URL, falls back to a plain GET with the client's
/// profile headers followed by generic content conversion
/// (`ContentRouter`) on a blocking thread.
///
/// # Errors
/// Propagates client construction, network, and conversion failures.
async fn fetch_one_inner(url: &str, cookies: &str) -> Result<(String, String)> {
    // NOTE(review): a fresh client is built per URL; presumably cheap
    // enough — confirm, or hoist to the caller if connection reuse matters.
    let client = build_client(false, None, false)?;
    let domain = super::extract_domain(url);
    let cookie_header = super::resolve_cookie_header(cookies, &domain);
    let cookie_opt = non_empty(&cookie_header);
    let site_router = SiteRouter::new();
    // Site-specific path: if a dedicated extractor handles this URL, use
    // its markdown directly, preferring its own metadata title.
    if let Some(content) = site_router.try_extract(url, &client, cookie_opt).await {
        let title = content
            .metadata
            .title
            .clone()
            .or_else(|| extract_title_from_markdown(&content.markdown))
            .unwrap_or_else(|| url.to_owned());
        return Ok((title, content.markdown));
    }
    // Generic path: plain GET carrying the client's browser-profile headers.
    let mut request = client
        .inner()
        .get(url)
        .headers(client.profile().await.to_headers());
    if let Some(cv) = cookie_opt {
        request = request.header("Cookie", cv);
    }
    let response = request.send().await?;
    // Default to HTML when the server omits or mangles the content type.
    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("text/html")
        .to_owned();
    let bytes = response.bytes().await?;
    let content_router = ContentRouter::new();
    let fetch_url = url.to_owned();
    // Conversion is CPU-bound; run it off the async executor. `content_type`
    // is moved into the closure directly (the original cloned it into a
    // temporary even though it was never used again — redundant_clone).
    let result = tokio::task::spawn_blocking(move || {
        content_router.convert_with_url(&bytes, &content_type, Some(&fetch_url))
    })
    .await??;
    let title = extract_title_from_markdown(&result.markdown).unwrap_or_else(|| url.to_owned());
    Ok((title, result.markdown))
}
/// Truncate `body` to at most `budget` bytes, preferring to cut just after
/// the last complete line inside the budget, and append a truncation
/// notice. A `budget` of 0 means "unlimited": the body passes through
/// untouched.
fn truncate_to_budget(body: String, budget: usize) -> String {
    if budget == 0 || body.len() <= budget {
        return body;
    }
    // Largest index <= budget that lands on a UTF-8 char boundary. This is
    // a stable replacement for `str::floor_char_boundary`, which is a
    // nightly-only unstable API (feature `round_char_boundary`). A char is
    // at most 4 bytes, so this loop runs at most 3 times; index 0 is
    // always a boundary, so it cannot underflow.
    let mut safe_budget = budget;
    while !body.is_char_boundary(safe_budget) {
        safe_budget -= 1;
    }
    // Cut after the last newline within budget; if the first line alone
    // exceeds the budget, cut mid-line at the boundary instead.
    let cut = body[..safe_budget]
        .rfind('\n')
        .map_or(safe_budget, |pos| pos + 1);
    let mut out = body[..cut].to_owned();
    out.push_str("\n\n*[content truncated for context budget]*");
    out
}
/// Return the text of the first ATX h1 heading (a line starting with
/// `"# "`) in `md`, with leading hashes and surrounding whitespace
/// stripped, or `None` when no such line exists.
fn extract_title_from_markdown(md: &str) -> Option<String> {
    for line in md.lines() {
        if line.starts_with("# ") {
            return Some(line.trim_start_matches('#').trim().to_owned());
        }
    }
    None
}
/// Print the assembled context document to stdout: a header, a timing
/// line, then one `---`-separated section per source slot. `None` slots
/// (panicked tasks) render as an "(unavailable)" section.
fn emit_combined_markdown(
    sources: Vec<Option<FetchedSource>>,
    n: usize,
    elapsed: std::time::Duration,
) {
    let plural = if n == 1 { "" } else { "s" };
    println!("# Context: {n} source{plural}");
    println!();
    println!(
        "*Assembled in {:.2}s from {n} URL{plural}*",
        elapsed.as_secs_f64()
    );
    for (i, slot) in sources.into_iter().enumerate() {
        println!("\n---\n");
        let source_num = i + 1;
        match slot {
            None => {
                println!("## Source {source_num}: (unavailable)");
                println!();
                println!("*This source could not be retrieved.*");
            }
            Some(src) => {
                let status = if src.ok { "" } else { " [ERROR]" };
                // Avoid printing the URL twice when it doubles as the title.
                let label = if src.title == src.url {
                    src.url.clone()
                } else {
                    format!("{} ({})", src.title, src.url)
                };
                println!("## Source {source_num}{status}: {label}");
                println!();
                println!("{}", src.body);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // budget_per_source: 8000 tokens -> 32000 chars, split 4 ways.
    #[test]
    fn budget_per_source_divides_evenly_across_sources() {
        assert_eq!(budget_per_source(8_000, 4), 8_000);
    }
    #[test]
    fn budget_per_source_single_source_gets_full_budget() {
        assert_eq!(budget_per_source(8_000, 1), 32_000);
    }
    #[test]
    fn budget_per_source_zero_sources_avoids_division_by_zero() {
        let result = budget_per_source(8_000, 0);
        assert_eq!(
            result, 32_000,
            "zero sources treated as 1 to avoid div-by-zero"
        );
    }
    #[test]
    fn budget_per_source_zero_max_tokens_returns_zero() {
        assert_eq!(budget_per_source(0, 4), 0);
    }
    #[test]
    fn truncate_to_budget_short_body_unchanged() {
        let body = "short".to_owned();
        let result = truncate_to_budget(body.clone(), 1_000);
        assert_eq!(result, body);
    }
    // budget == 0 is the documented "unlimited" sentinel.
    #[test]
    fn truncate_to_budget_zero_budget_passes_through() {
        let body = "x".repeat(100_000);
        let result = truncate_to_budget(body.clone(), 0);
        assert_eq!(result, body, "budget=0 means unlimited");
    }
    #[test]
    fn truncate_to_budget_appends_notice_when_truncated() {
        let body = "line one\nline two\nline three\n".to_owned();
        let result = truncate_to_budget(body, 15);
        assert!(
            result.contains("[content truncated"),
            "should contain truncation notice, got: {result}"
        );
    }
    // Truncation prefers cutting just after a newline within the budget.
    #[test]
    fn truncate_to_budget_cuts_at_newline_boundary() {
        let body = "aaaa\nbbbb\ncccc\n".to_owned();
        let result = truncate_to_budget(body, 8);
        assert!(result.starts_with("aaaa\n"), "should cut at line boundary");
        assert!(!result.contains("bbbb"), "should not include next line");
    }
    #[test]
    fn extract_title_finds_h1_heading() {
        let md = "# My Page Title\n\nSome paragraph text.";
        assert_eq!(
            extract_title_from_markdown(md),
            Some("My Page Title".to_owned())
        );
    }
    // Only "# " (h1) lines count as titles; deeper headings are skipped.
    #[test]
    fn extract_title_ignores_h2_and_lower() {
        let md = "## Subtitle\n\nNo h1 here.";
        assert_eq!(extract_title_from_markdown(md), None);
    }
    #[test]
    fn extract_title_returns_none_for_empty_markdown() {
        assert_eq!(extract_title_from_markdown(""), None);
    }
    #[test]
    fn extract_title_strips_leading_hashes_and_whitespace() {
        let md = "# Padded Title \n\nBody.";
        assert_eq!(
            extract_title_from_markdown(md),
            Some("Padded Title".to_owned())
        );
    }
    // The only hard error path in cmd_context: an empty URL list.
    #[tokio::test]
    async fn cmd_context_empty_urls_returns_error() {
        let result = cmd_context(&[], "none", 8_000).await;
        assert!(result.is_err(), "empty URL list should be an error");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("at least one URL"),
            "error message should be descriptive, got: {msg}"
        );
    }
}