use std::sync::Arc;
use tantivy::collector::TopDocs;
use tantivy::query::{BooleanQuery, Occur, Query, RegexQuery, TermQuery};
use tantivy::schema::{
Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, Value, STORED, STRING,
};
use tantivy::{Index, ReloadPolicy, TantivyDocument, Term};
use slotmap::Key;
use crate::context::ImHashMap;
use crate::symbol::{FileId, FileRegistry, SymbolId, SymbolRegistry};
use ryo_source::pure::PureFile;
use super::{LiteralCollector, LiteralInfo, LiteralKind};
/// Errors surfaced while building or querying the literal index.
#[derive(Debug)]
pub enum LiteralSearchError {
/// Wraps a `tantivy::TantivyError`; produced through the `From` impl, e.g. by `?` on tantivy calls.
Index(tantivy::TantivyError),
/// Wraps a `tantivy::query::QueryParserError`; produced through the matching `From` impl.
QueryParse(tantivy::query::QueryParserError),
/// A search pattern could not be turned into a regex query; the payload is a human-readable description.
InvalidPattern(String),
}
impl std::fmt::Display for LiteralSearchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Index(e) => write!(f, "Index error: {}", e),
Self::QueryParse(e) => write!(f, "Query parse error: {}", e),
Self::InvalidPattern(s) => write!(f, "Invalid pattern: {}", s),
}
}
}
// Marker impl: every `std::error::Error` method keeps its default, so no `source()` is
// overridden here; the wrapped error's text is already included by the `Display` impl.
impl std::error::Error for LiteralSearchError {}
impl From<tantivy::TantivyError> for LiteralSearchError {
fn from(e: tantivy::TantivyError) -> Self {
Self::Index(e)
}
}
impl From<tantivy::query::QueryParserError> for LiteralSearchError {
fn from(e: tantivy::query::QueryParserError) -> Self {
Self::QueryParse(e)
}
}
/// Parameters for one search over the literal index.
#[derive(Debug, Clone)]
pub struct LiteralQuery {
    /// Glob-style pattern matched against the literal value (`*` = any run of
    /// characters, `?` = exactly one). An empty pattern adds no value constraint.
    pub pattern: String,
    /// When set, only literals of this kind are returned.
    pub kind: Option<LiteralKind>,
    /// Maximum number of matches to return.
    pub limit: usize,
}

impl Default for LiteralQuery {
    /// Returns an unconstrained query (empty pattern, no kind filter) capped at
    /// 100 results.
    ///
    /// A derived `Default` would set `limit` to 0, which disagrees with
    /// [`LiteralQuery::new`] and is rejected by tantivy's `TopDocs::with_limit`,
    /// so the cap is spelled out here instead.
    fn default() -> Self {
        Self {
            pattern: String::new(),
            kind: None,
            limit: 100,
        }
    }
}
impl LiteralQuery {
pub fn new(pattern: impl Into<String>) -> Self {
Self {
pattern: pattern.into(),
kind: None,
limit: 100,
}
}
pub fn with_kind(mut self, kind: LiteralKind) -> Self {
self.kind = Some(kind);
self
}
pub fn with_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
}
/// One hit returned by `LiteralIndex::search`, decoded from a stored index document.
#[derive(Debug, Clone)]
pub struct LiteralMatch {
/// Literal text read back from the stored `value` field.
pub value: String,
/// Literal kind, parsed from the stored `kind` field.
pub kind: LiteralKind,
/// Symbol id rebuilt from the numeric key text in the stored `symbol_id` field.
pub symbol_id: SymbolId,
/// File id rebuilt from the numeric key text in the stored `file_id` field.
pub file_id: FileId,
/// Stored file path; empty when the document has no readable `file_path` value.
pub file_path: String,
/// Score reported by the search collector for this document.
pub score: f32,
}
/// The tantivy schema for literal documents together with a handle to each of its fields.
struct LiteralSchema {
    /// Complete schema; cloned when the in-memory index is created, so it is
    /// live code and needs no `dead_code` suppression.
    schema: Schema,
    /// Literal text (the field pattern queries run against).
    value: Field,
    /// Literal kind, stored as text.
    kind: Field,
    /// Owning symbol id, stored as the decimal form of its FFI key.
    symbol_id: Field,
    /// File id, stored as the decimal form of its FFI key.
    file_id: Field,
    /// Path of the file the literal came from.
    file_path: Field,
}
impl LiteralSchema {
    /// Builds the five-field literal schema and records a handle to every field.
    ///
    /// Fields are registered in the order `value`, `kind`, `symbol_id`,
    /// `file_id`, `file_path`.
    fn new() -> Self {
        let mut builder = Schema::builder();

        // `value` uses the "raw" tokenizer, records frequencies and positions,
        // and is stored so it can be read back from search hits.
        let value_indexing = TextFieldIndexing::default()
            .set_tokenizer("raw")
            .set_index_option(IndexRecordOption::WithFreqsAndPositions);
        let value_options = TextOptions::default()
            .set_indexing_options(value_indexing)
            .set_stored();
        let value = builder.add_text_field("value", value_options);

        // Every remaining field shares the same `STRING | STORED` options.
        let mut keyword_field = |name: &str| builder.add_text_field(name, STRING | STORED);
        let kind = keyword_field("kind");
        let symbol_id = keyword_field("symbol_id");
        let file_id = keyword_field("file_id");
        let file_path = keyword_field("file_path");

        Self {
            schema: builder.build(),
            value,
            kind,
            symbol_id,
            file_id,
            file_path,
        }
    }
}
/// Searchable index of literals backed by tantivy.
pub struct LiteralIndex {
/// Underlying tantivy index (`new` creates it in RAM).
index: Index,
/// Schema and field handles used to build documents and queries.
schema: LiteralSchema,
}
impl LiteralIndex {
    /// Memory budget, in bytes, given to the tantivy writer while building an index.
    const WRITER_HEAP_BYTES: usize = 50_000_000;

    /// Creates an empty, RAM-backed index laid out according to `LiteralSchema`.
    ///
    /// # Errors
    /// The current body has no failing step; the `Result` return type is kept
    /// because it is part of the signature callers already handle.
    pub fn new() -> Result<Self, LiteralSearchError> {
        let schema = LiteralSchema::new();
        let index = Index::create_in_ram(schema.schema.clone());
        // The `value` field names the "raw" tokenizer, so register it on this index.
        index
            .tokenizers()
            .register("raw", tantivy::tokenizer::RawTokenizer::default());
        Ok(Self { index, schema })
    }

    /// Builds an index from every literal collected out of `files`.
    ///
    /// Each file's path is looked up in `file_registry`; a file id without a
    /// registered path is indexed with an empty path string.
    ///
    /// # Errors
    /// Returns [`LiteralSearchError::Index`] if the writer cannot be created, a
    /// document cannot be added, or the final commit fails.
    pub fn build_from_files(
        files: &ImHashMap<FileId, Arc<PureFile>>,
        registry: &SymbolRegistry,
        file_registry: &FileRegistry,
    ) -> Result<Self, LiteralSearchError> {
        let index = Self::new()?;
        let mut writer: tantivy::IndexWriter = index.index.writer(Self::WRITER_HEAP_BYTES)?;
        for (&file_id, file) in files.iter() {
            let file_path_str = file_registry
                .path(file_id)
                .map(|p| p.as_relative().display().to_string())
                .unwrap_or_default();
            index.index_file(&writer, file.as_ref(), file_id, &file_path_str, registry)?;
        }
        writer.commit()?;
        Ok(index)
    }

    /// Builds an index from files keyed by workspace path instead of `FileId`.
    ///
    /// No per-file `FileId` is available on this path, so the same placeholder
    /// id is handed to the collector for every file; the stored `file_path` is
    /// what distinguishes documents from different files.
    ///
    /// # Errors
    /// Returns [`LiteralSearchError::Index`] if the writer cannot be created, a
    /// document cannot be added, or the final commit fails.
    pub fn build_from_workspace_files(
        files: &ImHashMap<ryo_symbol::WorkspaceFilePath, Arc<PureFile>>,
        registry: &SymbolRegistry,
    ) -> Result<Self, LiteralSearchError> {
        let index = Self::new()?;
        let mut writer: tantivy::IndexWriter = index.index.writer(Self::WRITER_HEAP_BYTES)?;
        let default_file_id = FileId::from(slotmap::KeyData::from_ffi(0));
        for (path, file) in files.iter() {
            let file_path_str = path.as_relative().display().to_string();
            index.index_file(
                &writer,
                file.as_ref(),
                default_file_id,
                &file_path_str,
                registry,
            )?;
        }
        writer.commit()?;
        Ok(index)
    }

    /// Collects the literals of one file and adds a document for each to `writer`.
    ///
    /// The collector callback cannot return an error, so the first
    /// `add_document` failure is parked in a cell and reported once collection
    /// of the file has finished, instead of being silently dropped.
    fn index_file(
        &self,
        writer: &tantivy::IndexWriter,
        file: &PureFile,
        file_id: FileId,
        file_path: &str,
        registry: &SymbolRegistry,
    ) -> Result<(), LiteralSearchError> {
        // Placeholder symbol passed as the collector's default symbol argument.
        // NOTE(review): presumably used for literals with no resolvable owner —
        // confirm against `LiteralCollector::collect_file`.
        let default_symbol = SymbolId::from(slotmap::KeyData::from_ffi(0));
        let first_error: std::cell::Cell<Option<tantivy::TantivyError>> =
            std::cell::Cell::new(None);

        LiteralCollector::collect_file(
            file,
            file_id,
            default_symbol,
            |name| registry.lookup_by_name(name),
            |info| {
                let doc = self.create_document(&info, file_path);
                if let Err(err) = writer.add_document(doc) {
                    // Keep the earliest failure; later ones are usually consequences of it.
                    let earlier = first_error.take();
                    first_error.set(earlier.or(Some(err)));
                }
            },
        );

        match first_error.into_inner() {
            Some(err) => Err(err.into()),
            None => Ok(()),
        }
    }

    /// Converts one collected literal into a tantivy document.
    ///
    /// Symbol and file ids are written as the decimal text of their FFI key so
    /// that `parse_symbol_id` / `parse_file_id` can rebuild them from a hit.
    fn create_document(&self, info: &LiteralInfo, file_path: &str) -> TantivyDocument {
        let mut doc = TantivyDocument::new();
        doc.add_text(self.schema.value, &info.value);
        doc.add_text(self.schema.kind, info.kind.as_str());
        doc.add_text(
            self.schema.symbol_id,
            info.symbol_id.data().as_ffi().to_string(),
        );
        doc.add_text(
            self.schema.file_id,
            info.file_id.data().as_ffi().to_string(),
        );
        doc.add_text(self.schema.file_path, file_path);
        doc
    }

    /// Runs `query` against the index and returns at most `query.limit` matches,
    /// ordered by score.
    ///
    /// A fresh reader is opened on every call. Hits whose stored fields cannot
    /// be decoded back into a [`LiteralMatch`] are skipped. A `limit` of 0
    /// yields an empty result (after the pattern has been validated) rather
    /// than reaching tantivy, whose `TopDocs::with_limit` rejects 0.
    ///
    /// # Errors
    /// Returns [`LiteralSearchError::InvalidPattern`] if the pattern cannot be
    /// compiled, or [`LiteralSearchError::Index`] if tantivy fails to open a
    /// reader, execute the search, or load a document.
    pub fn search(&self, query: &LiteralQuery) -> Result<Vec<LiteralMatch>, LiteralSearchError> {
        let reader = self
            .index
            .reader_builder()
            .reload_policy(ReloadPolicy::Manual)
            .try_into()?;
        let searcher = reader.searcher();
        let tantivy_query = self.build_query(query)?;

        if query.limit == 0 {
            return Ok(Vec::new());
        }

        let top_docs = searcher.search(
            &tantivy_query,
            &TopDocs::with_limit(query.limit).order_by_score(),
        )?;
        let mut results = Vec::with_capacity(top_docs.len());
        for (score, doc_address) in top_docs {
            let doc: TantivyDocument = searcher.doc(doc_address)?;
            if let Some(found) = self.doc_to_match(&doc, score) {
                results.push(found);
            }
        }
        Ok(results)
    }

    /// Combines the optional pattern and kind constraints into one tantivy query.
    ///
    /// No constraint at all matches every document; a single constraint is
    /// returned as-is; two constraints are joined so that both must match.
    fn build_query(&self, query: &LiteralQuery) -> Result<Box<dyn Query>, LiteralSearchError> {
        let mut subqueries: Vec<(Occur, Box<dyn Query>)> = Vec::new();
        if !query.pattern.is_empty() {
            let pattern_query = self.build_pattern_query(&query.pattern)?;
            subqueries.push((Occur::Must, pattern_query));
        }
        if let Some(kind) = &query.kind {
            let kind_term = Term::from_field_text(self.schema.kind, kind.as_str());
            let kind_query = TermQuery::new(kind_term, IndexRecordOption::Basic);
            subqueries.push((Occur::Must, Box::new(kind_query)));
        }
        match subqueries.len() {
            0 => Ok(Box::new(tantivy::query::AllQuery)),
            1 => Ok(subqueries.remove(0).1),
            _ => Ok(Box::new(BooleanQuery::new(subqueries))),
        }
    }

    /// Turns a glob pattern into a query on the `value` field.
    ///
    /// A lone `*` matches every document without building a regex; any other
    /// pattern is translated by `glob_to_regex_tantivy` and compiled.
    ///
    /// # Errors
    /// Returns [`LiteralSearchError::InvalidPattern`] (original pattern plus the
    /// compiler's message) if the translated regex is rejected.
    fn build_pattern_query(&self, pattern: &str) -> Result<Box<dyn Query>, LiteralSearchError> {
        if pattern == "*" {
            return Ok(Box::new(tantivy::query::AllQuery));
        }
        let regex_pattern = glob_to_regex_tantivy(pattern);
        let regex_query = RegexQuery::from_pattern(&regex_pattern, self.schema.value)
            .map_err(|e| LiteralSearchError::InvalidPattern(format!("{}: {}", pattern, e)))?;
        Ok(Box::new(regex_query))
    }

    /// Decodes a stored document into a [`LiteralMatch`].
    ///
    /// Returns `None` if the value, kind, symbol id or file id field is missing,
    /// is not text, or does not parse; a missing file path becomes `""`.
    fn doc_to_match(&self, doc: &tantivy::TantivyDocument, score: f32) -> Option<LiteralMatch> {
        let value = doc.get_first(self.schema.value)?.as_str()?.to_string();
        let kind_str = doc.get_first(self.schema.kind)?.as_str()?;
        let kind = LiteralKind::parse_kind(kind_str)?;
        let symbol_id_str = doc.get_first(self.schema.symbol_id)?.as_str()?;
        let file_id_str = doc.get_first(self.schema.file_id)?.as_str()?;
        let file_path = doc
            .get_first(self.schema.file_path)
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        let symbol_id = parse_symbol_id(symbol_id_str)?;
        let file_id = parse_file_id(file_id_str)?;
        Some(LiteralMatch {
            value,
            kind,
            symbol_id,
            file_id,
            file_path,
            score,
        })
    }

    /// Reports summary figures for the index.
    ///
    /// `doc_count` is 0 when a reader cannot be opened; the failure itself is
    /// not reported.
    pub fn stats(&self) -> LiteralIndexStats {
        let reader = self
            .index
            .reader_builder()
            .reload_policy(ReloadPolicy::Manual)
            .try_into()
            .ok();
        let doc_count = reader
            .as_ref()
            .map(|r| r.searcher().num_docs() as usize)
            .unwrap_or(0);
        LiteralIndexStats { doc_count }
    }
}
/// Summary figures for a `LiteralIndex`, as produced by `LiteralIndex::stats`.
#[derive(Debug, Clone)]
pub struct LiteralIndexStats {
/// Document count reported by a searcher at the time of the call; 0 when no reader could be opened.
pub doc_count: usize,
}
/// Translates a glob pattern into the regex syntax handed to tantivy's `RegexQuery`.
///
/// * `*` becomes `.*` (any run of characters, wherever it appears in the glob)
/// * `?` becomes `.` (exactly one character)
/// * `.`, `+`, `(`, `)`, `[`, `]`, `{`, `}`, `|`, `^`, `$` and `\` are
///   backslash-escaped so they match literally
/// * every other character is copied through unchanged
///
/// An empty glob yields an empty regex.
fn glob_to_regex_tantivy(glob: &str) -> String {
    // Worst case every character is escaped, doubling the length.
    let mut regex = String::with_capacity(glob.len() * 2);
    for c in glob.chars() {
        match c {
            '*' => regex.push_str(".*"),
            '?' => regex.push('.'),
            '.' | '+' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '^' | '$' | '\\' => {
                regex.push('\\');
                regex.push(c);
            }
            _ => regex.push(c),
        }
    }
    regex
}
/// Rebuilds a `SymbolId` from the decimal FFI key text written by `create_document`.
///
/// Returns `None` when `s` is not a valid `u64`.
fn parse_symbol_id(s: &str) -> Option<SymbolId> {
    s.parse::<u64>()
        .ok()
        .map(|raw| SymbolId::from(slotmap::KeyData::from_ffi(raw)))
}
/// Rebuilds a `FileId` from the decimal FFI key text written by `create_document`.
///
/// Returns `None` when `s` is not a valid `u64`.
fn parse_file_id(s: &str) -> Option<FileId> {
    s.parse::<u64>()
        .ok()
        .map(|raw| FileId::from(slotmap::KeyData::from_ffi(raw)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty in-memory index, failing the test if construction errors.
    fn fresh_index() -> LiteralIndex {
        LiteralIndex::new().expect("Failed to create index")
    }

    /// Glob wildcards are translated and regex metacharacters are escaped.
    #[test]
    fn test_glob_to_regex_tantivy() {
        let expectations = [
            ("*error*", ".*error.*"),
            ("hello", "hello"),
            ("test?", "test."),
            ("a.b", "a\\.b"),
        ];
        for &(glob, regex) in expectations.iter() {
            assert_eq!(glob_to_regex_tantivy(glob), regex);
        }
    }

    /// A freshly created index reports zero documents.
    #[test]
    fn test_create_empty_index() {
        assert_eq!(fresh_index().stats().doc_count, 0);
    }

    /// Searching an empty index succeeds and returns nothing.
    #[test]
    fn test_search_empty_index() {
        let hits = fresh_index()
            .search(&LiteralQuery::new("*"))
            .expect("Search failed");
        assert!(hits.is_empty());
    }

    /// Documents written by hand can be found by pattern and by kind.
    #[test]
    fn test_index_and_search() {
        let index = fresh_index();
        let symbol = SymbolId::from(slotmap::KeyData::from_ffi(1));
        let file = FileId::from(slotmap::KeyData::from_ffi(1));
        let fixtures = vec![
            ("\"hello world\"", LiteralKind::String),
            ("\"error: connection failed\"", LiteralKind::String),
            ("42", LiteralKind::Int),
            ("3.14", LiteralKind::Float),
            ("true", LiteralKind::Bool),
        ];

        let mut writer = index
            .index
            .writer(50_000_000)
            .expect("Failed to create writer");
        for (text, kind) in fixtures {
            let info = LiteralInfo::with_kind(text.to_string(), kind, symbol, file);
            let doc = index.create_document(&info, "test/file.rs");
            writer.add_document(doc).expect("Failed to add document");
        }
        writer.commit().expect("Failed to commit");

        assert_eq!(index.stats().doc_count, 5);

        // Pattern-only search.
        let by_pattern = index
            .search(&LiteralQuery::new("*error*"))
            .expect("Search failed");
        assert_eq!(by_pattern.len(), 1);
        assert!(by_pattern[0].value.contains("error"));

        // Match-all pattern narrowed by kind.
        let by_kind = index
            .search(&LiteralQuery::new("*").with_kind(LiteralKind::Int))
            .expect("Search failed");
        assert_eq!(by_kind.len(), 1);
        assert_eq!(by_kind[0].value, "42");
    }

    /// The builder methods set exactly the fields they name.
    #[test]
    fn test_literal_query_builder() {
        let built = LiteralQuery::new("*test*")
            .with_kind(LiteralKind::String)
            .with_limit(50);
        assert_eq!(built.pattern, "*test*");
        assert_eq!(built.kind, Some(LiteralKind::String));
        assert_eq!(built.limit, 50);
    }
}