use std::collections::HashSet;
use std::fs;
use std::path::PathBuf;
use tempfile::TempDir;
use tldr_core::get_code_structure;
use tldr_core::search::bm25::Bm25Index;
use tldr_core::search::enriched::{
enriched_search,
enriched_search_with_index,
enriched_search_with_structure_cache,
read_structure_cache,
search_with_inner,
write_structure_cache,
EnrichedSearchOptions,
SearchMode,
StructureLookup,
};
use tldr_core::types::Language;
/// Builds a throwaway Python project on disk for the search tests.
///
/// A `project` directory is created inside a fresh temporary directory and
/// filled with four source files: `auth.py`, `utils.py`, `handlers.py` and
/// `models.py`.
///
/// Returns the [`TempDir`] guard together with the path of the `project`
/// directory. Callers bind the guard (as `_dir`) for the whole test so the
/// files stay available while the searches run.
///
/// # Panics
///
/// Panics if the temporary directory, the `project` directory, or any of the
/// fixture files cannot be created.
fn create_test_project() -> (TempDir, PathBuf) {
    // (file name, file contents) pairs, written in this order.
    let sources: [(&str, &str); 4] = [
        (
            "auth.py",
            r#"
TOKEN_SECRET = "my-secret-key"
def verify_token(request):
"""Verify authentication token from request headers."""
token = request.headers.get("Authorization")
if not token:
raise AuthError("Missing token")
claims = decode_token(token)
check_expiry(claims)
return claims
def decode_token(token):
"""Decode a JWT token string into claims."""
import jwt
return jwt.decode(token, key=TOKEN_SECRET)
def refresh_token(old_token):
"""Refresh an expired token."""
claims = decode_token(old_token)
claims["exp"] = new_expiry()
return encode_token(claims)
def check_expiry(claims):
"""Check if token claims have expired."""
if claims["exp"] < time.time():
raise AuthError("Token expired")
class AuthMiddleware:
"""Middleware for authentication."""
def __init__(self, app):
self.app = app
def process_request(self, request):
"""Process incoming request for auth."""
verify_token(request)
return self.app(request)
"#,
        ),
        (
            "utils.py",
            r#"
def parse_json(text):
"""Parse JSON string into a dictionary."""
import json
return json.loads(text)
def parse_csv(text):
"""Parse CSV string into rows."""
import csv
return list(csv.reader(text.splitlines()))
def format_date(dt):
"""Format a datetime object as ISO string."""
return dt.strftime("%Y-%m-%d")
def validate_email(email):
"""Validate an email address format."""
import re
return re.match(r"^[\w.]+@[\w.]+$", email) is not None
"#,
        ),
        (
            "handlers.py",
            r#"
def handle_login(request):
"""Handle user login request."""
username = request.data["username"]
password = request.data["password"]
token = create_token(username)
return {"token": token}
def handle_logout(request):
"""Handle user logout request."""
invalidate_token(request.token)
return {"status": "logged_out"}
def handle_refresh(request):
"""Handle token refresh request."""
old_token = request.headers.get("Authorization")
new_token = refresh_token(old_token)
return {"token": new_token}
"#,
        ),
        (
            "models.py",
            r#"
class User:
"""User model."""
def __init__(self, name, email):
self.name = name
self.email = email
def to_dict(self):
return {"name": self.name, "email": self.email}
class Token:
"""Token model for authentication."""
def __init__(self, value, expires_at):
self.value = value
self.expires_at = expires_at
def is_expired(self):
import time
return self.expires_at < time.time()
"#,
        ),
    ];

    let dir = TempDir::new().unwrap();
    let project = dir.path().join("project");
    fs::create_dir(&project).unwrap();

    for &(file_name, contents) in &sources {
        fs::write(project.join(file_name), contents).unwrap();
    }

    (dir, project)
}
/// Options for a BM25-mode search capped at `top_k` results, with
/// `include_callgraph` set to `false`.
fn bm25_opts(top_k: usize) -> EnrichedSearchOptions {
    let search_mode = SearchMode::Bm25;
    EnrichedSearchOptions {
        search_mode,
        include_callgraph: false,
        top_k,
    }
}
/// Options for a regex-mode search using `pattern`, capped at `top_k`
/// results, with `include_callgraph` set to `false`.
fn regex_opts(pattern: &str, top_k: usize) -> EnrichedSearchOptions {
    let search_mode = SearchMode::Regex(String::from(pattern));
    EnrichedSearchOptions {
        search_mode,
        include_callgraph: false,
        top_k,
    }
}
/// Produces a [`StructureLookup`] for the Python project at `root` by
/// round-tripping its code structure through an on-disk cache file.
///
/// The structure is written to `structure_cache.json` inside a scratch
/// temporary directory with `write_structure_cache` and then loaded back with
/// `read_structure_cache`, so the tests exercise the same path a real cache
/// would take.
///
/// # Panics
///
/// Panics if structure extraction, the scratch directory, or either cache
/// operation fails.
fn build_structure_lookup(root: &std::path::Path) -> StructureLookup {
    // NOTE(review): the trailing `0, None` arguments are passed through as-is;
    // their meaning is defined by `get_code_structure` — confirm there.
    let structure = get_code_structure(root, Language::Python, 0, None).unwrap();

    let scratch = tempfile::TempDir::new().unwrap();
    let cache_file = scratch.path().join("structure_cache.json");

    write_structure_cache(&structure, &cache_file).unwrap();
    read_structure_cache(&cache_file).unwrap()
}
/// Smoke test: a cold BM25 `enriched_search` for "token verify" yields a
/// non-empty, well-formed report that contains at least one function result.
#[test]
fn test_enriched_search_returns_results() {
    let (_dir, root) = create_test_project();
    let report = enriched_search("token verify", &root, Language::Python, bm25_opts(10)).unwrap();

    // Report-level expectations.
    assert!(
        !report.results.is_empty(),
        "enriched_search should return results for 'token verify'"
    );
    assert_eq!(report.query, "token verify");
    assert!(
        report.total_files_searched > 0,
        "Should have searched files"
    );
    assert!(
        report.search_mode.starts_with("bm25"),
        "Search mode should start with 'bm25', got '{}'",
        report.search_mode
    );

    // Every individual hit must carry sane metadata.
    report.results.iter().for_each(|hit| {
        assert!(!hit.name.is_empty(), "Result name must not be empty");
        assert!(!hit.kind.is_empty(), "Result kind must not be empty");
        assert!(
            !hit.file.as_os_str().is_empty(),
            "Result file must not be empty"
        );

        let start = hit.line_range.0;
        let end = hit.line_range.1;
        assert!(
            start > 0,
            "Line range start should be 1-indexed, got {}",
            start
        );
        assert!(
            end >= start,
            "Line range end ({}) should be >= start ({})",
            end,
            start
        );

        assert!(hit.score > 0.0, "Score should be positive");
    });

    let function_hits = report
        .results
        .iter()
        .filter(|hit| hit.kind == "function")
        .count();
    assert!(
        function_hits > 0,
        "Should find at least one function-level result for 'token verify'"
    );
}
/// A cold `enriched_search` and `enriched_search_with_index` (fed a prebuilt
/// [`Bm25Index`]) must agree on result count, result names, per-position
/// scores and `total_files_searched`.
#[test]
fn test_all_variants_return_same_results_for_bm25() {
    let (_dir, root) = create_test_project();
    let query = "token decode";

    let cold_report = enriched_search(query, &root, Language::Python, bm25_opts(10)).unwrap();

    let index = Bm25Index::from_project(&root, Language::Python).unwrap();
    let cached_report =
        enriched_search_with_index(query, &root, Language::Python, bm25_opts(10), &index).unwrap();

    let cold_hits = &cold_report.results;
    let cached_hits = &cached_report.results;

    assert_eq!(
        cold_hits.len(),
        cached_hits.len(),
        "Cold and cached BM25 should return same result count"
    );

    // Compare names as sorted lists so ordering differences do not matter here.
    let mut cold_names: Vec<String> = cold_hits.iter().map(|hit| hit.name.clone()).collect();
    cold_names.sort();
    let mut cached_names: Vec<String> = cached_hits.iter().map(|hit| hit.name.clone()).collect();
    cached_names.sort();
    assert_eq!(
        cold_names, cached_names,
        "Cold and cached BM25 should return same result names"
    );

    // Scores are compared position by position.
    cold_hits
        .iter()
        .zip(cached_hits.iter())
        .for_each(|(cold, cached)| {
            let delta = (cold.score - cached.score).abs();
            assert!(
                delta < f64::EPSILON,
                "Scores should match for '{}': cold={}, cached={}",
                cold.name,
                cold.score,
                cached.score
            );
        });

    assert_eq!(
        cold_report.total_files_searched, cached_report.total_files_searched,
        "total_files_searched should match between cold and cached"
    );
}
/// Searching with a structure cache must return the same names and
/// `(name, kind)` pairs as a live search, and each report's `search_mode`
/// must mention which structure source was used.
#[test]
fn test_structure_cache_matches_live_parse() {
    let (_dir, root) = create_test_project();
    let query = "parse json";

    let lookup = build_structure_lookup(&root);
    let live_report = enriched_search(query, &root, Language::Python, bm25_opts(10)).unwrap();
    let cached_report = enriched_search_with_structure_cache(
        query,
        &root,
        Language::Python,
        bm25_opts(10),
        &lookup,
    )
    .unwrap();

    let live_hits = &live_report.results;
    let cached_hits = &cached_report.results;

    // Names, order-insensitive.
    let mut live_names: Vec<String> = live_hits.iter().map(|hit| hit.name.clone()).collect();
    live_names.sort();
    let mut cached_names: Vec<String> = cached_hits.iter().map(|hit| hit.name.clone()).collect();
    cached_names.sort();
    assert_eq!(
        live_names, cached_names,
        "Live tree-sitter and cached structure should return same result names.\n Live: {:?}\n Cached: {:?}",
        live_names, cached_names
    );

    // (name, kind) pairs, order-insensitive.
    let mut live_kinds: Vec<(String, String)> = live_hits
        .iter()
        .map(|hit| (hit.name.clone(), hit.kind.clone()))
        .collect();
    live_kinds.sort();
    let mut cached_kinds: Vec<(String, String)> = cached_hits
        .iter()
        .map(|hit| (hit.name.clone(), hit.kind.clone()))
        .collect();
    cached_kinds.sort();
    assert_eq!(
        live_kinds, cached_kinds,
        "Result kinds should match between live and cached"
    );

    let live_mode = &live_report.search_mode;
    assert!(
        live_mode.contains("structure"),
        "Live report search_mode should contain 'structure', got '{}'",
        live_mode
    );
    let cached_mode = &cached_report.search_mode;
    assert!(
        cached_mode.contains("cached-structure"),
        "Cached report search_mode should contain 'cached-structure', got '{}'",
        cached_mode
    );
}
/// Regex mode (`handle_\w+`, empty query string) must produce the same set of
/// result names through `enriched_search`, `enriched_search_with_index` and
/// `enriched_search_with_structure_cache`, and every report's `search_mode`
/// must mention "regex".
#[test]
fn test_regex_mode_works_through_all_variants() {
    let (_dir, root) = create_test_project();
    let pattern = r"handle_\w+";

    // Variant 1: plain enriched_search.
    let report_base =
        enriched_search("", &root, Language::Python, regex_opts(pattern, 20)).unwrap();
    assert!(
        !report_base.results.is_empty(),
        "Regex 'handle_\\w+' should find results via enriched_search"
    );
    assert!(
        report_base.search_mode.contains("regex"),
        "Search mode should indicate regex, got '{}'",
        report_base.search_mode
    );
    let base_names: HashSet<String> = report_base
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();

    // Variant 2: prebuilt BM25 index supplied.
    let index = Bm25Index::from_project(&root, Language::Python).unwrap();
    let report_index =
        enriched_search_with_index("", &root, Language::Python, regex_opts(pattern, 20), &index)
            .unwrap();
    let index_names: HashSet<String> = report_index
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    assert_eq!(
        base_names, index_names,
        "Regex results should be identical between enriched_search and enriched_search_with_index.\n Base: {:?}\n Index: {:?}",
        base_names, index_names
    );

    // Variant 3: structure cache supplied.
    let lookup = build_structure_lookup(&root);
    let report_cached = enriched_search_with_structure_cache(
        "",
        &root,
        Language::Python,
        regex_opts(pattern, 20),
        &lookup,
    )
    .unwrap();
    let cached_names: HashSet<String> = report_cached
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    assert_eq!(
        base_names, cached_names,
        "Regex results should be identical between enriched_search and enriched_search_with_structure_cache.\n Base: {:?}\n Cached: {:?}",
        base_names, cached_names
    );

    // Both cache-backed variants must still report regex mode.
    assert!(
        report_index.search_mode.contains("regex"),
        "Index variant regex search_mode should contain 'regex', got '{}'",
        report_index.search_mode
    );
    assert!(
        report_cached.search_mode.contains("regex"),
        "Cached variant regex search_mode should contain 'regex', got '{}'",
        report_cached.search_mode
    );
}
/// No search variant may return more than `top_k` results, in BM25 mode or
/// in regex mode.
#[test]
fn test_top_k_respected() {
    let (_dir, root) = create_test_project();
    let query = "def";
    let top_k = 3;

    // Plain BM25 search.
    let report_base = enriched_search(query, &root, Language::Python, bm25_opts(top_k)).unwrap();
    let base_count = report_base.results.len();
    assert!(
        base_count <= top_k,
        "enriched_search should return at most {} results, got {}",
        top_k,
        base_count
    );

    // BM25 search over a prebuilt index.
    let index = Bm25Index::from_project(&root, Language::Python).unwrap();
    let report_index =
        enriched_search_with_index(query, &root, Language::Python, bm25_opts(top_k), &index)
            .unwrap();
    let index_count = report_index.results.len();
    assert!(
        index_count <= top_k,
        "enriched_search_with_index should return at most {} results, got {}",
        top_k,
        index_count
    );

    // BM25 search with a structure cache.
    let lookup = build_structure_lookup(&root);
    let report_cached = enriched_search_with_structure_cache(
        query,
        &root,
        Language::Python,
        bm25_opts(top_k),
        &lookup,
    )
    .unwrap();
    let cached_count = report_cached.results.len();
    assert!(
        cached_count <= top_k,
        "enriched_search_with_structure_cache should return at most {} results, got {}",
        top_k,
        cached_count
    );

    // Regex mode through the plain entry point.
    let report_regex =
        enriched_search("", &root, Language::Python, regex_opts("def \\w+", top_k)).unwrap();
    let regex_count = report_regex.results.len();
    assert!(
        regex_count <= top_k,
        "enriched_search (regex) should return at most {} results, got {}",
        top_k,
        regex_count
    );
}
/// Every non-module result must score strictly higher than every module
/// result for the query "token secret".
///
/// The comparison only runs when the report contains both kinds of result;
/// if either group is empty the test passes without asserting anything.
#[test]
fn test_module_penalty_applied() {
    let (_dir, root) = create_test_project();
    let report = enriched_search("token secret", &root, Language::Python, bm25_opts(20)).unwrap();

    // Split hits by kind; "functions" here means everything that is not a module.
    let (modules, functions): (Vec<_>, Vec<_>) = report
        .results
        .iter()
        .partition(|hit| hit.kind == "module");

    if functions.is_empty() || modules.is_empty() {
        return;
    }

    let min_function_score = functions
        .iter()
        .map(|hit| hit.score)
        .fold(f64::INFINITY, |lowest, score| lowest.min(score));
    let max_module_score = modules
        .iter()
        .map(|hit| hit.score)
        .fold(f64::NEG_INFINITY, |highest, score| highest.max(score));

    let function_summary = functions
        .iter()
        .map(|hit| (&hit.name, hit.score))
        .collect::<Vec<_>>();
    let module_summary = modules
        .iter()
        .map(|hit| (&hit.name, hit.score))
        .collect::<Vec<_>>();

    assert!(
        min_function_score > max_module_score,
        "Function results should rank above module results after penalty.\n\
         Min function score: {:.4}, Max module score: {:.4}\n\
         Functions: {:?}\n\
         Modules: {:?}",
        min_function_score,
        max_module_score,
        function_summary,
        module_summary,
    );
}
/// `search_with_inner` called with all three optional arguments set to `None`
/// must match `enriched_search` on names, per-position scores, `query`,
/// `total_files_searched` and `search_mode`.
#[test]
fn test_search_with_inner_matches_enriched_search() {
    let (_dir, root) = create_test_project();
    let query = "token verify";

    let existing_report = enriched_search(query, &root, Language::Python, bm25_opts(10)).unwrap();
    // No optional inputs supplied.
    let inner_report = search_with_inner(
        query,
        &root,
        Language::Python,
        bm25_opts(10),
        None,
        None,
        None,
    )
    .unwrap();

    let existing_hits = &existing_report.results;
    let inner_hits = &inner_report.results;

    let mut existing_names: Vec<String> =
        existing_hits.iter().map(|hit| hit.name.clone()).collect();
    existing_names.sort();
    let mut inner_names: Vec<String> = inner_hits.iter().map(|hit| hit.name.clone()).collect();
    inner_names.sort();
    assert_eq!(
        existing_names, inner_names,
        "search_with_inner (no caches) should produce same results as enriched_search"
    );

    // Scores are compared position by position.
    existing_hits
        .iter()
        .zip(inner_hits.iter())
        .for_each(|(existing, inner)| {
            let delta = (existing.score - inner.score).abs();
            assert!(
                delta < f64::EPSILON,
                "Scores should be identical for '{}': existing={}, inner={}",
                existing.name,
                existing.score,
                inner.score
            );
        });

    assert_eq!(existing_report.query, inner_report.query);
    assert_eq!(
        existing_report.total_files_searched,
        inner_report.total_files_searched
    );
    assert_eq!(existing_report.search_mode, inner_report.search_mode);
}
/// `search_with_inner` given a prebuilt BM25 index must return the same
/// result names as `enriched_search_with_index`.
#[test]
fn test_search_with_inner_cached_bm25_matches_with_index() {
    let (_dir, root) = create_test_project();
    let query = "parse json";
    let index = Bm25Index::from_project(&root, Language::Python).unwrap();

    let existing_report =
        enriched_search_with_index(query, &root, Language::Python, bm25_opts(10), &index).unwrap();
    // Only the BM25 index is supplied; the other two optional inputs stay `None`.
    let inner_report = search_with_inner(
        query,
        &root,
        Language::Python,
        bm25_opts(10),
        Some(&index),
        None,
        None,
    )
    .unwrap();

    // Names are compared as sorted lists.
    let mut existing_names: Vec<String> = existing_report
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    existing_names.sort();
    let mut inner_names: Vec<String> = inner_report
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    inner_names.sort();

    assert_eq!(
        existing_names, inner_names,
        "search_with_inner(bm25_index=Some) should match enriched_search_with_index"
    );
}
/// `search_with_inner` given a structure lookup must return the same result
/// names as `enriched_search_with_structure_cache`, and its `search_mode`
/// must mention "cached-structure".
#[test]
fn test_search_with_inner_structure_cache_matches_with_structure_cache() {
    let (_dir, root) = create_test_project();
    let query = "handle login";
    let lookup = build_structure_lookup(&root);

    let existing_report = enriched_search_with_structure_cache(
        query,
        &root,
        Language::Python,
        bm25_opts(10),
        &lookup,
    )
    .unwrap();
    // Only the structure lookup is supplied; the other two optional inputs stay `None`.
    let inner_report = search_with_inner(
        query,
        &root,
        Language::Python,
        bm25_opts(10),
        None,
        Some(&lookup),
        None,
    )
    .unwrap();

    // Names are compared as sorted lists.
    let mut existing_names: Vec<String> = existing_report
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    existing_names.sort();
    let mut inner_names: Vec<String> = inner_report
        .results
        .iter()
        .map(|hit| hit.name.clone())
        .collect();
    inner_names.sort();
    assert_eq!(
        existing_names, inner_names,
        "search_with_inner(structure_cache=Some) should match enriched_search_with_structure_cache"
    );

    let inner_mode = &inner_report.search_mode;
    assert!(
        inner_mode.contains("cached-structure"),
        "Inner with structure cache should indicate 'cached-structure' in search_mode, got '{}'",
        inner_mode
    );
}