use crate::parser::Document;
use std::collections::HashMap;
use std::ops::RangeInclusive;
use std::sync::Arc;
use super::error::QueryError;
use super::eval::EvalContext;
use super::value::Value;
/// Shared, thread-safe handle to the closure that implements a query function.
///
/// The closure is given the call's argument values plus the current [`EvalContext`]
/// and yields zero or more result values, or a [`QueryError`].
pub type FunctionFn =
    Arc<dyn Fn(&[Value], &EvalContext) -> Result<Vec<Value>, QueryError> + Send + Sync>;

/// Shared, thread-safe handle to a closure that pulls values out of a [`Document`].
///
/// Unlike [`FunctionFn`], its input is a whole [`Document`] rather than a slice of
/// argument values; it likewise yields zero or more values or a [`QueryError`].
pub type ExtractorFn =
    Arc<dyn Fn(&Document, &EvalContext) -> Result<Vec<Value>, QueryError> + Send + Sync>;
/// A callable registered in a [`Registry`], together with its call metadata.
///
/// Cloning is cheap for the callable itself (`func` is an `Arc`); `description`
/// is a `String` and is copied.
#[derive(Clone)]
pub struct Function {
    /// The implementation invoked by [`Function::call`].
    pub func: FunctionFn,
    /// Inclusive range of argument counts accepted, as checked by
    /// [`Function::accepts_arity`].
    pub arity: RangeInclusive<usize>,
    /// Human-readable description; empty unless set via `with_description`.
    pub description: String,
    /// Flag set via `with_takes_input` (defaults to `true` in `new`).
    /// NOTE(review): presumably marks whether the function consumes the current
    /// pipeline input — confirm against the evaluator, which is not shown here.
    pub takes_input: bool,
}
impl Function {
pub fn new<F>(func: F, arity: RangeInclusive<usize>) -> Self
where
F: Fn(&[Value], &EvalContext) -> Result<Vec<Value>, QueryError> + Send + Sync + 'static,
{
Self {
func: Arc::new(func),
arity,
description: String::new(),
takes_input: true,
}
}
pub fn with_description(mut self, desc: impl Into<String>) -> Self {
self.description = desc.into();
self
}
pub fn with_takes_input(mut self, takes: bool) -> Self {
self.takes_input = takes;
self
}
pub fn accepts_arity(&self, count: usize) -> bool {
self.arity.contains(&count)
}
pub fn call(&self, args: &[Value], ctx: &EvalContext) -> Result<Vec<Value>, QueryError> {
(self.func)(args, ctx)
}
}
/// Debug view of a [`Function`]; the closure itself is omitted because
/// `dyn Fn` has no `Debug` implementation.
impl std::fmt::Debug for Function {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            arity,
            description,
            takes_input,
            ..
        } = self;
        let mut view = f.debug_struct("Function");
        view.field("arity", arity);
        view.field("description", description);
        view.field("takes_input", takes_input);
        view.finish()
    }
}
/// Name-indexed table of query functions, document extractors, and function aliases.
#[derive(Default)]
pub struct Registry {
    /// Functions keyed by their registered name.
    functions: HashMap<String, Function>,
    /// Extractors keyed by their registered name.
    extractors: HashMap<String, ExtractorFn>,
    /// Alias name -> name of the function it stands for (resolved at lookup time).
    aliases: HashMap<String, String>,
}
impl Registry {
    /// Creates an empty registry with no functions, extractors, or aliases.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a registry populated by `super::builtins::register_all`.
    pub fn with_builtins() -> Self {
        let mut registry = Self::new();
        super::builtins::register_all(&mut registry);
        registry
    }

    /// Registers `func` under `name`, replacing any function already registered
    /// under that name.
    pub fn register_function(&mut self, name: impl Into<String>, func: Function) {
        self.functions.insert(name.into(), func);
    }

    /// Registers `alias` as another name for the function `target`.
    ///
    /// The target is looked up when the alias is resolved, so it may be registered
    /// before or after the alias. Only one level of indirection is followed: an alias
    /// pointing at another alias does not resolve. A function registered directly
    /// under the same name as an alias takes precedence over the alias.
    pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
        self.aliases.insert(alias.into(), target.into());
    }

    /// Looks up `name` as a function, falling back to resolving it as an alias.
    ///
    /// Returns `None` if `name` is unknown or is an alias whose target is not a
    /// registered function.
    pub fn get_function(&self, name: &str) -> Option<&Function> {
        self.functions.get(name).or_else(|| {
            self.aliases
                .get(name)
                .and_then(|target| self.functions.get(target))
        })
    }

    /// Returns `true` iff [`Registry::get_function`] would return `Some` for `name`.
    ///
    /// A dangling alias (one whose target is not registered) is reported as missing,
    /// so a `true` result always means the function can actually be fetched.
    pub fn has_function(&self, name: &str) -> bool {
        self.get_function(name).is_some()
    }

    /// Returns the names of all directly registered functions (aliases excluded),
    /// sorted so the listing is stable across runs.
    pub fn function_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.functions.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Suggests up to three registered function names resembling `name`.
    ///
    /// Matching is case-insensitive. A candidate qualifies when either name contains
    /// the other, or their Levenshtein distance is at most 2. Results are ordered by
    /// distance, with ties broken alphabetically so the output is deterministic.
    /// An empty `name` yields no suggestions.
    pub fn suggest_function(&self, name: &str) -> Vec<&str> {
        const MAX_SUGGESTIONS: usize = 3;
        const MAX_EDIT_DISTANCE: usize = 2;

        let query = name.to_lowercase();
        // The empty string is a substring of every name, so every function would
        // "match"; suggesting arbitrary names for it is only noise.
        if query.is_empty() {
            return Vec::new();
        }

        // Lowercase and score each candidate exactly once; the score doubles as the
        // sort key so nothing is recomputed during sorting.
        let mut ranked: Vec<(usize, &str)> = self
            .functions
            .keys()
            .filter_map(|candidate| {
                let lower = candidate.to_lowercase();
                let distance = levenshtein(&lower, &query);
                // `contains` subsumes `starts_with`, so prefix checks are unnecessary.
                let related = lower.contains(query.as_str())
                    || query.contains(lower.as_str())
                    || distance <= MAX_EDIT_DISTANCE;
                if related {
                    Some((distance, candidate.as_str()))
                } else {
                    None
                }
            })
            .collect();

        // Names are unique map keys, so (distance, name) is a total order and the
        // result no longer depends on HashMap iteration order.
        ranked.sort_unstable();
        ranked.truncate(MAX_SUGGESTIONS);
        ranked.into_iter().map(|(_, candidate)| candidate).collect()
    }

    /// Registers `extractor` under `name`, replacing any extractor already
    /// registered under that name.
    pub fn register_extractor(&mut self, name: impl Into<String>, extractor: ExtractorFn) {
        self.extractors.insert(name.into(), extractor);
    }

    /// Looks up the extractor registered under `name` (aliases do not apply).
    pub fn get_extractor(&self, name: &str) -> Option<&ExtractorFn> {
        self.extractors.get(name)
    }
}
/// Debug view of a [`Registry`]: function and extractor names (their closures are not
/// `Debug`) and the alias table.
///
/// Names and aliases are emitted in sorted order so the output is deterministic
/// (the underlying `HashMap`s iterate in arbitrary order), which keeps logs and
/// snapshot comparisons stable. The shape of the output is unchanged: lists for
/// names, a `{alias: target}` map for aliases.
impl std::fmt::Debug for Registry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut functions: Vec<&String> = self.functions.keys().collect();
        functions.sort_unstable();

        let mut extractors: Vec<&String> = self.extractors.keys().collect();
        extractors.sort_unstable();

        let aliases: std::collections::BTreeMap<&String, &String> = self.aliases.iter().collect();

        f.debug_struct("Registry")
            .field("functions", &functions)
            .field("extractors", &extractors)
            .field("aliases", &aliases)
            .finish()
    }
}
/// Hook for types that populate a [`Registry`].
///
/// NOTE(review): no implementations are visible in this file; presumably
/// implemented by groups of built-in functions — confirm at the call sites.
pub trait FunctionRegistry {
    /// Adds this implementor's entries to `registry`, which is borrowed mutably so
    /// any of its `register_*` methods may be used.
    fn register(registry: &mut Registry);
}
/// Levenshtein edit distance between `a` and `b`.
///
/// Every insertion, deletion, and substitution costs 1. Strings are compared as
/// sequences of `char`s (Unicode scalar values), not bytes or grapheme clusters.
/// Only two DP rows are kept, so memory is proportional to the length of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // `prev[j]` is the distance between the source prefix processed so far and
    // `target[..j]`; the first row is the cost of building `target[..j]` from "".
    // An empty `source` skips the loop and returns `target.len()`; an empty
    // `target` leaves single-cell rows ending at `source.len()`.
    let mut prev: Vec<usize> = (0..=target.len()).collect();
    let mut curr: Vec<usize> = vec![0; target.len() + 1];

    for (row, &sc) in source.iter().enumerate() {
        // Turning source[..=row] into "" takes row + 1 deletions.
        curr[0] = row + 1;
        for (col, &tc) in target.iter().enumerate() {
            let substitute = prev[col] + usize::from(sc != tc);
            let delete = prev[col + 1] + 1;
            let insert = curr[col] + 1;
            curr[col + 1] = delete.min(insert).min(substitute);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[target.len()]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Function stub that ignores its inputs and yields no values.
    fn stub(arity: std::ops::RangeInclusive<usize>) -> Function {
        Function::new(|_, _| Ok(vec![]), arity)
    }

    #[test]
    fn test_levenshtein() {
        let cases = [
            ("", "", 0),
            ("abc", "abc", 0),
            ("abc", "ab", 1),
            ("abc", "abd", 1),
            ("abc", "xyz", 3),
            // Plain Levenshtein has no transposition op: swapping two chars costs 2.
            ("count", "conut", 2),
        ];
        for (a, b, expected) in cases {
            assert_eq!(levenshtein(a, b), expected, "levenshtein({:?}, {:?})", a, b);
        }
    }

    #[test]
    fn test_registry_functions() {
        let mut registry = Registry::new();
        registry.register_function(
            "test",
            Function::new(|_args, _ctx| Ok(vec![Value::String("test".into())]), 0..=0),
        );

        assert!(registry.has_function("test"));
        assert!(!registry.has_function("nonexistent"));
    }

    #[test]
    fn test_registry_aliases() {
        let mut registry = Registry::new();
        registry.register_function(
            "count",
            Function::new(|_args, _ctx| Ok(vec![Value::Number(0.0)]), 0..=0),
        );
        registry.register_alias("length", "count");

        for name in ["count", "length"] {
            assert!(registry.has_function(name), "expected {:?} to resolve", name);
        }
        assert!(registry.get_function("length").is_some());
    }

    #[test]
    fn test_suggest_function() {
        let mut registry = Registry::new();
        for (name, arity) in [("contains", 1..=1), ("count", 0..=0), ("startswith", 1..=1)] {
            registry.register_function(name, stub(arity));
        }

        let suggestions = registry.suggest_function("cont");
        assert!(suggestions.contains(&"contains"));
        assert!(suggestions.contains(&"count"));
    }
}