use std::sync::Arc;
use arrow_schema::Field;
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use lance_core::{Result, cache::LanceCache};
use crate::progress::IndexBuildProgress;
use crate::registry::IndexPluginRegistry;
use crate::{
frag_reuse::FragReuseIndex,
scalar::{CreatedIndex, IndexStore, ScalarIndex, expression::ScalarQueryParser},
};
/// Column-name constant whose value is the string `"value"`.
///
/// NOTE(review): presumably the name of the column that carries the indexed
/// values in training input batches — confirm against the callers.
pub const VALUE_COLUMN_NAME: &str = "value";
/// Ordering selector stored in [`TrainingCriteria::ordering`].
///
/// The type is a plain `Copy` enum that can be compared for equality.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names only; the code that interprets them is not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingOrdering {
/// Presumably: training data is ordered by the value column — TODO confirm.
Values,
/// Presumably: training data is ordered by row address — TODO confirm.
Addresses,
/// Presumably: no ordering is required of the training data — TODO confirm.
None,
}
/// Requirements attached to a training request and returned by
/// [`TrainingRequest::criteria`].
///
/// Construct with [`TrainingCriteria::new`] (both `needs_*` flags start as
/// `false`) and opt in to the flags with `with_row_id` / `with_row_addr`.
#[derive(Debug, Clone)]
pub struct TrainingCriteria {
/// The requested [`TrainingOrdering`].
pub ordering: TrainingOrdering,
/// Set to `true` by `with_row_id`.
/// NOTE(review): presumably asks for a row-id column in the training data — confirm.
pub needs_row_ids: bool,
/// Set to `true` by `with_row_addr`.
/// NOTE(review): presumably asks for a row-address column in the training data — confirm.
pub needs_row_addrs: bool,
}
impl TrainingCriteria {
    /// Creates criteria with the given `ordering`; `needs_row_ids` and
    /// `needs_row_addrs` both start out as `false`.
    pub fn new(ordering: TrainingOrdering) -> Self {
        let needs_row_ids = false;
        let needs_row_addrs = false;
        Self {
            needs_row_addrs,
            needs_row_ids,
            ordering,
        }
    }

    /// Consumes the criteria and returns them with `needs_row_ids` set to
    /// `true`; every other field is carried over unchanged.
    pub fn with_row_id(self) -> Self {
        Self {
            needs_row_ids: true,
            ..self
        }
    }

    /// Consumes the criteria and returns them with `needs_row_addrs` set to
    /// `true`; every other field is carried over unchanged.
    pub fn with_row_addr(self) -> Self {
        Self {
            needs_row_addrs: true,
            ..self
        }
    }
}
/// A request describing how an index is to be trained.
///
/// [`ScalarIndexPlugin::new_training_request`] returns a boxed value of this
/// trait and [`ScalarIndexPlugin::train_index`] accepts one. Implementors must
/// be `Any + Send + Sync`.
pub trait TrainingRequest: std::any::Any + Send + Sync {
/// Exposes the request as `&dyn Any`, which allows downcasting to the
/// concrete request type.
fn as_any(&self) -> &dyn std::any::Any;
/// Returns a reference to the [`TrainingCriteria`] of this request.
fn criteria(&self) -> &TrainingCriteria;
}
/// Crate-internal [`TrainingRequest`] implementation that holds nothing but a
/// [`TrainingCriteria`].
pub(crate) struct DefaultTrainingRequest {
/// The criteria handed back by reference from `criteria()`.
criteria: TrainingCriteria,
}
impl DefaultTrainingRequest {
pub fn new(criteria: TrainingCriteria) -> Self {
Self { criteria }
}
}
impl TrainingRequest for DefaultTrainingRequest {
    /// Hands out this request as a type-erased `&dyn Any` so callers can
    /// downcast back to `DefaultTrainingRequest`.
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    /// Borrows the stored criteria.
    fn criteria(&self) -> &TrainingCriteria {
        let Self { criteria } = self;
        criteria
    }
}
/// Plugin interface for one kind of scalar index.
///
/// A plugin creates training requests, trains new indices, loads existing
/// ones, and creates query parsers. Implementations must be
/// `Send + Sync + Debug`; the `async fn` methods are expanded by
/// `#[async_trait]`.
#[async_trait]
pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug {
/// Creates a boxed [`TrainingRequest`] from `params` for the given `field`.
///
/// NOTE(review): the encoding of `params` is not visible in this file
/// (presumably a serialized parameter string, e.g. JSON) — confirm with the
/// implementors.
///
/// # Errors
///
/// Returns an error when the implementation cannot build a request; the
/// exact conditions are implementation-defined.
fn new_training_request(&self, params: &str, field: &Field)
-> Result<Box<dyn TrainingRequest>>;
/// Trains a new index from `data` and returns the resulting [`CreatedIndex`].
///
/// * `data` - stream of record batches to train on. NOTE(review): presumably
///   it satisfies the [`TrainingCriteria`] of `request` — confirm.
/// * `index_store` - the [`IndexStore`] made available to the plugin;
///   presumably where the index data is written — confirm.
/// * `request` - a boxed [`TrainingRequest`]; presumably the one returned by
///   `new_training_request` — confirm.
/// * `fragment_ids` - optional list of `u32` ids. NOTE(review): presumably the
///   fragments covered by this training run — confirm.
/// * `progress` - shared [`IndexBuildProgress`] handle; presumably used to
///   report build progress — confirm.
///
/// # Errors
///
/// Returns an error when training fails; the exact conditions are
/// implementation-defined.
async fn train_index(
&self,
data: SendableRecordBatchStream,
index_store: &dyn IndexStore,
request: Box<dyn TrainingRequest>,
fragment_ids: Option<Vec<u32>>,
progress: Arc<dyn IndexBuildProgress>,
) -> Result<CreatedIndex>;
/// Returns the name of this plugin.
fn name(&self) -> &str;
/// Reports whether the index provides an exact answer.
///
/// NOTE(review): presumably `false` means results must be rechecked against
/// the data — confirm against the query planner.
fn provides_exact_answer(&self) -> bool;
/// Returns the plugin's version number.
///
/// NOTE(review): how the version is used for compatibility checks is not
/// visible in this file — confirm.
fn version(&self) -> u32;
/// Creates a query parser for the index named `index_name`, using the
/// protobuf `index_details`.
///
/// Returns `None` when the plugin does not supply a parser.
/// NOTE(review): the caller's handling of `None` is not visible here.
fn new_query_parser(
&self,
index_name: String,
index_details: &prost_types::Any,
) -> Option<Box<dyn ScalarQueryParser>>;
/// Loads an existing index from `index_store` and returns it as a shared
/// [`ScalarIndex`].
///
/// * `index_details` - protobuf details describing the index.
/// * `frag_reuse_index` - optional [`FragReuseIndex`]. NOTE(review):
///   presumably used to remap row addresses — confirm.
/// * `cache` - a [`LanceCache`] the implementation may use.
///
/// # Errors
///
/// Returns an error when the index cannot be loaded; the exact conditions
/// are implementation-defined.
async fn load_index(
&self,
index_store: Arc<dyn IndexStore>,
index_details: &prost_types::Any,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
cache: &LanceCache,
) -> Result<Arc<dyn ScalarIndex>>;
/// Loads statistics for an index as a JSON value, if the plugin offers any.
///
/// The default implementation ignores both arguments and returns `Ok(None)`.
async fn load_statistics(
&self,
_index_store: Arc<dyn IndexStore>,
_index_details: &prost_types::Any,
) -> Result<Option<serde_json::Value>> {
Ok(None)
}
/// Hands the plugin a shared [`IndexPluginRegistry`].
///
/// The default implementation does nothing. Note the `&self` receiver: an
/// implementation that keeps the registry needs interior mutability.
fn attach_registry(&self, _registry: Arc<IndexPluginRegistry>) {}
/// Renders the protobuf index details as JSON.
///
/// The default implementation ignores `_details` and returns an empty JSON
/// object (`{}`).
fn details_as_json(&self, _details: &prost_types::Any) -> Result<serde_json::Value> {
Ok(serde_json::json!({}))
}
}