use std::{collections::HashMap, sync::Arc};
use lance_core::{Error, Result};
#[cfg(feature = "geo")]
use crate::scalar::rtree::RTreeIndexPlugin;
use crate::{
pb, pbold,
scalar::{
bitmap::BitmapIndexPlugin, bloomfilter::BloomFilterIndexPlugin, btree::BTreeIndexPlugin,
inverted::InvertedIndexPlugin, json::JsonIndexPlugin, label_list::LabelListIndexPlugin,
ngram::NGramIndexPlugin, registry::ScalarIndexPlugin, zonemap::ZoneMapIndexPlugin,
},
};
/// Registry mapping scalar index plugin names (e.g. `"btree"`, `"bitmap"`) to
/// their [`ScalarIndexPlugin`] implementations.
///
/// Plugin names are derived from the protobuf details message registered
/// alongside each plugin: the message name is lowercased and a trailing
/// `indexdetails` suffix is removed, so `BTreeIndexDetails` is registered
/// under `"btree"`.
pub struct IndexPluginRegistry {
    plugins: HashMap<String, Box<dyn ScalarIndexPlugin>>,
}

impl IndexPluginRegistry {
    /// Derives a plugin name from a protobuf details message name.
    ///
    /// The name is lowercased and, if it ends with `indexdetails`, that suffix
    /// is stripped. Only the trailing suffix is removed; an `indexdetails`
    /// occurring elsewhere in the name is preserved.
    fn get_plugin_name_from_details_name(&self, details_name: &str) -> String {
        let mut name = details_name.to_lowercase();
        // `strip_suffix` removes only the trailing occurrence, unlike
        // `str::replace`, which would drop every occurrence in the name.
        // The stripped stem is a prefix of `name`, so truncating to its length
        // lands on a char boundary and reuses the existing allocation.
        let stem_len = name
            .strip_suffix("indexdetails")
            .map_or(name.len(), str::len);
        name.truncate(stem_len);
        name
    }

    /// Registers a plugin of type `PluginType` under the name derived from
    /// `DetailsType::NAME`.
    ///
    /// The name is the lowercased message name with any trailing
    /// `indexdetails` removed. If a plugin is already registered under that
    /// name, it is replaced.
    pub fn add_plugin<
        DetailsType: prost::Message + prost::Name,
        PluginType: ScalarIndexPlugin + Default + 'static,
    >(
        &mut self,
    ) {
        let plugin_name = self.get_plugin_name_from_details_name(DetailsType::NAME);
        self.plugins
            .insert(plugin_name, Box::new(PluginType::default()));
    }

    /// Builds a registry containing every built-in scalar index plugin.
    ///
    /// The `rtree` plugin is only included when the `geo` feature is enabled.
    /// Once the registry is wrapped in an [`Arc`], each plugin is handed a
    /// clone of it through `attach_registry`.
    pub fn with_default_plugins() -> Arc<Self> {
        let mut registry = Self {
            plugins: HashMap::new(),
        };
        registry.add_plugin::<pbold::BTreeIndexDetails, BTreeIndexPlugin>();
        registry.add_plugin::<pbold::BitmapIndexDetails, BitmapIndexPlugin>();
        registry.add_plugin::<pbold::LabelListIndexDetails, LabelListIndexPlugin>();
        registry.add_plugin::<pbold::NGramIndexDetails, NGramIndexPlugin>();
        registry.add_plugin::<pbold::ZoneMapIndexDetails, ZoneMapIndexPlugin>();
        registry.add_plugin::<pb::BloomFilterIndexDetails, BloomFilterIndexPlugin>();
        registry.add_plugin::<pbold::InvertedIndexDetails, InvertedIndexPlugin>();
        registry.add_plugin::<pb::JsonIndexDetails, JsonIndexPlugin>();
        #[cfg(feature = "geo")]
        registry.add_plugin::<pb::RTreeIndexDetails, RTreeIndexPlugin>();

        let registry = Arc::new(registry);
        // NOTE(review): if `attach_registry` stores this `Arc`, every plugin holds
        // a strong reference back to the registry that owns it. That is a
        // reference cycle, so the registry is never dropped. Consider passing a
        // `Weak` if the trait can be changed; confirm against
        // `ScalarIndexPlugin::attach_registry`.
        for plugin in registry.plugins.values() {
            plugin.attach_registry(Arc::clone(&registry));
        }
        registry
    }

    /// Looks up a plugin by its registered name (e.g. `"btree"`).
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if no plugin is registered under `name`.
    /// When `name` is `"rtree"` and the crate was built without the `geo`
    /// feature, the message explains how to enable it.
    pub fn get_plugin_by_name(&self, name: &str) -> Result<&dyn ScalarIndexPlugin> {
        self.plugins
            .get(name)
            .map(|plugin| plugin.as_ref())
            .ok_or_else(|| {
                // The feature hint is only accurate when `geo` is actually off.
                let hint = if cfg!(not(feature = "geo")) && name == "rtree" {
                    ". The 'rtree' index requires the `geo` feature. \
                     Rebuild with `--features geo` to enable geospatial support"
                } else {
                    ""
                };
                Error::invalid_input_source(
                    format!("No scalar index plugin found for name '{name}'{hint}").into(),
                )
            })
    }

    /// Looks up the plugin for a serialized index details message.
    ///
    /// The message type name is the last `/`- or `.`-separated segment of
    /// `details.type_url`. For example, `BTreeIndexDetails` is taken from
    /// `type.googleapis.com/lance.table.BTreeIndexDetails`. That type name is
    /// then mapped to a plugin name in the same way as in `add_plugin`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if no plugin matches the derived name.
    pub fn get_plugin_by_details(
        &self,
        details: &prost_types::Any,
    ) -> Result<&dyn ScalarIndexPlugin> {
        let type_url = details.type_url.as_str();
        // Splitting on '/' as well as '.' handles type names without a package
        // (e.g. `type.googleapis.com/FooIndexDetails`). `rsplit` always yields
        // at least one item; the fallback just avoids an `unwrap`.
        let details_name = type_url.rsplit(['/', '.']).next().unwrap_or(type_url);
        let plugin_name = self.get_plugin_name_from_details_name(details_name);
        self.get_plugin_by_name(&plugin_name)
    }
}