mod filters;
mod languages;
mod patterns;
mod symbols;
use crate::types::ExtractedEntities;
#[must_use]
pub fn extract_entities(input: &str) -> ExtractedEntities {
let extractor = EntityExtractor::new();
let quoted_spans: Vec<String> = extract_quoted_from_input(input);
extractor.extract(input, "ed_spans)
}
/// Collects the contents of every closed, non-empty `"..."` span in `input`,
/// in order of appearance.
///
/// Quotes are paired left to right. An opening quote without a matching
/// closing quote contributes nothing, and empty pairs (`""`) are skipped.
fn extract_quoted_from_input(input: &str) -> Vec<String> {
    // Splitting on `"` yields pieces that alternate outside/inside quotes,
    // so the odd-indexed pieces are the quoted ones.
    let mut pieces = input.split('"');

    // The final piece is never followed by a quote: it is either trailing
    // unquoted text or the body of an unterminated quote. Drop it so an
    // unclosed span is not reported.
    let _ = pieces.next_back();

    pieces
        .skip(1)
        .step_by(2)
        .filter(|span| !span.is_empty())
        .map(str::to_owned)
        .collect()
}
/// Holds the settings used when extracting entities from input text.
pub struct EntityExtractor {
/// Value returned by `default_limit()`; `new()` sets it to 100.
default_limit: u32,
/// Upper bound applied by `extract` to any extracted depth; `new()` sets it to 20.
max_depth: u32,
}
/// Makes `EntityExtractor::default()` equivalent to `EntityExtractor::new()`.
impl Default for EntityExtractor {
fn default() -> Self {
// Delegate to `new` so the default settings are defined in one place.
Self::new()
}
}
impl EntityExtractor {
    /// Creates an extractor with the stock settings: a default limit of 100
    /// and a maximum depth of 20.
    #[must_use]
    pub fn new() -> Self {
        Self::with_defaults(100, 20)
    }

    /// Creates an extractor with caller-supplied `default_limit` and
    /// `max_depth` settings.
    #[must_use]
    pub fn with_defaults(default_limit: u32, max_depth: u32) -> Self {
        Self {
            default_limit,
            max_depth,
        }
    }

    /// Runs the symbol, language, pattern and filter extractors over `input`
    /// and gathers their results into a single [`ExtractedEntities`].
    ///
    /// `quoted_spans` is forwarded to symbol extraction and to trace-path
    /// extraction. Any extracted depth is capped at this extractor's
    /// `max_depth`.
    #[must_use]
    pub fn extract(&self, input: &str, quoted_spans: &[String]) -> ExtractedEntities {
        let depth_cap = self.max_depth;
        let mut found = ExtractedEntities::new();

        found.symbols = symbols::extract_symbols(input, quoted_spans);
        found.languages = languages::extract_languages(input);
        found.paths = patterns::extract_paths(input);

        found.kind = filters::extract_kind(input);
        found.limit = filters::extract_limit(input);
        // A requested depth never exceeds the configured maximum.
        found.depth = filters::extract_depth(input).map(|requested| requested.min(depth_cap));
        found.format = filters::extract_format(input);

        let endpoints = patterns::extract_trace_path(input, quoted_spans);
        found.from_symbol = endpoints.0;
        found.to_symbol = endpoints.1;
        found.relation = patterns::extract_relation(input);

        found.predicate_type = filters::extract_predicate_type(input);
        found.impl_trait = filters::extract_impl_trait(input);
        found.predicate_arg = filters::extract_predicate_arg(input);
        found.visibility = filters::extract_visibility(input);
        found.is_async = filters::extract_async(input);
        found.is_unsafe = filters::extract_unsafe(input);

        found
    }

    /// Returns the configured default limit.
    #[must_use]
    pub const fn default_limit(&self) -> u32 {
        self.default_limit
    }

    /// Returns the configured maximum depth.
    #[must_use]
    pub const fn max_depth(&self) -> u32 {
        self.max_depth
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ExtractedEntities, SymbolKind};

    /// Extracts entities from `query` with a fresh extractor and no quoted spans.
    fn run(query: &str) -> ExtractedEntities {
        EntityExtractor::new().extract(query, &[])
    }

    /// Reports whether `name` is among the extracted symbols.
    fn has_symbol(entities: &ExtractedEntities, name: &str) -> bool {
        entities.symbols.contains(&name.to_string())
    }

    #[test]
    fn test_extract_basic() {
        assert!(has_symbol(&run("find authentication"), "authentication"));
    }

    #[test]
    fn test_extract_with_quoted() {
        // The quoted span is supplied explicitly, as a caller of `extract` would.
        let spans = [String::from("UserAuth::login")];
        let result = EntityExtractor::new().extract("find \"UserAuth::login\"", &spans);
        assert!(has_symbol(&result, "UserAuth::login"));
    }

    #[test]
    fn test_extract_with_language() {
        let result = run("find foo in rust");
        assert!(result.languages.contains(&String::from("rust")));
    }

    #[test]
    fn test_extract_with_kind() {
        let result = run("find all functions named foo");
        assert_eq!(result.kind, Some(SymbolKind::Function));
    }

    #[test]
    fn test_extract_with_limit() {
        let result = run("find first 5 functions");
        assert_eq!(result.limit, Some(5));
    }

    #[test]
    fn test_verb_for_patterns() {
        assert!(has_symbol(&run("grep for TODO comments"), "TODO"));
        assert!(has_symbol(&run("look for FIXME markers"), "FIXME"));
    }

    #[test]
    fn test_of_patterns() {
        assert!(has_symbol(&run("find all imports of serde"), "serde"));
        assert!(has_symbol(&run("find usages of parse_json"), "parse_json"));
    }

    #[test]
    fn test_verb_symbol_patterns() {
        assert!(has_symbol(&run("who uses the validate method"), "validate"));
        assert!(has_symbol(&run("what invokes send_email"), "send_email"));
        assert!(has_symbol(&run("what does main call"), "main"));
    }

    #[test]
    fn test_grep_direct_pattern() {
        let result = run("grep unsafe blocks");
        assert_eq!(
            result.is_unsafe,
            Some(true),
            "is_unsafe should be true for 'grep unsafe blocks'"
        );
    }

    #[test]
    fn test_kind_only_extraction() {
        let result = run("list all traits");
        assert_eq!(result.kind, Some(SymbolKind::Trait));
    }
}
#[cfg(test)]
mod predicate_tests {
    use super::*;
    use crate::types::{ExtractedEntities, SymbolKind, Visibility};

    /// Asserts the expectations shared by every test here: the query targets
    /// functions and carries at least one predicate.
    fn assert_function_with_predicate(result: &ExtractedEntities) {
        assert_eq!(result.kind, Some(SymbolKind::Function));
        assert!(result.has_predicate());
    }

    #[test]
    fn test_async_functions_extraction() {
        let result = extract_entities("find async functions");
        assert_eq!(result.is_async, Some(true));
        assert_function_with_predicate(&result);
    }

    #[test]
    fn test_unsafe_functions_extraction() {
        let result = extract_entities("find unsafe functions");
        assert_eq!(result.is_unsafe, Some(true));
        assert_function_with_predicate(&result);
    }

    #[test]
    fn test_public_async_functions_extraction() {
        let result = extract_entities("find public async functions");
        assert_eq!(result.visibility, Some(Visibility::Public));
        assert_eq!(result.is_async, Some(true));
        assert_function_with_predicate(&result);
    }
}