dynamic_learned_index 0.1.0

Dynamic Learned Index for efficient search in unstructured data.
use std::path::PathBuf;

use crate::model::candle_model;
use crate::model::tch_model;
use crate::model::BaseModelBuilder;
use crate::model::CandleBackend;
use crate::model::ModelInterface as _;
use crate::model::TchBackend;
use crate::structs::FloatElement;
use crate::DliError;
use candle_core::Tensor;
#[cfg(feature = "measure_time")]
use log::debug;
use measure_time_macro::log_time;
use rand::Rng as _;

use crate::{model::MixBackend, ArraySlice, DliResult, ModelConfig};

impl<F: FloatElement> crate::model::BaseModelBuilder<MixBackend, F> {
    /// Builds a mixed-backend `Model` holding both a tch model and a candle model.
    ///
    /// The same `device`, `input_nodes`, `labels`, `train_params`, `label_method`
    /// and `layers` are forwarded to a tch builder and a candle builder. The
    /// `quantize` setting is forwarded to the candle builder only, and an
    /// optional `weights_path` is forwarded to the tch builder only. After both
    /// models are built, the tch weights are copied into the candle model.
    ///
    /// # Errors
    ///
    /// Returns `DliError::MissingAttribute` when `device`, `input_nodes`,
    /// `labels` or `label_method` has not been set, and propagates any error
    /// from building either backend model or from the initial weight sync.
    pub fn build(&self) -> DliResult<Model<F>> {
        let device = self.device.ok_or(DliError::MissingAttribute("device"))?;
        let input_nodes = self
            .input_nodes
            .ok_or(DliError::MissingAttribute("input_nodes"))?;
        let labels = self.labels.ok_or(DliError::MissingAttribute("labels"))?;
        // Training parameters are optional; fall back to their `Default`.
        let train_params = self.train_params.unwrap_or_default();
        let label_method = self
            .label_method
            .ok_or(DliError::MissingAttribute("label_method"))?;

        let mut tch_builder = BaseModelBuilder::<TchBackend, F>::default();
        tch_builder
            .device(device)
            .input_nodes(input_nodes)
            .labels(labels)
            .train_params(train_params)
            .label_method(label_method)
            .layers(self.layers.clone());

        let mut candle_builder = BaseModelBuilder::<CandleBackend, F>::default();
        candle_builder
            .device(device)
            .input_nodes(input_nodes)
            .labels(labels)
            .train_params(train_params)
            .label_method(label_method)
            .quantize(self.quantize)
            .layers(self.layers.clone());

        // Only the tch model loads user-supplied weights; the candle model
        // receives them through the sync step below. Borrow the option so the
        // path is cloned exactly once.
        if let Some(weights_path) = &self.weights_path {
            tch_builder.weights_path(weights_path.clone());
        }

        // The candle model built here is a placeholder needed to construct the
        // struct; `sync_weights_from_tch_to_candle` replaces it right away.
        let mut model = Model {
            tch_model: tch_builder.build()?,
            candle_model: candle_builder.build()?,
            candle_builder,
            input_shape: input_nodes as usize,
        };
        model.sync_weights_from_tch_to_candle()?;
        Ok(model)
    }
}

/// Model that pairs a tch-backed model with a candle-backed model.
///
/// In this file, `train`, `retrain` and `dump` are delegated to the tch model,
/// while `predict`, `predict_many` and `vec2tensor` are delegated to the
/// candle model.
pub struct Model<F: FloatElement> {
    /// tch-backed model; the source of the weights copied into `candle_model`.
    tch_model: tch_model::Model<F>,
    /// candle-backed model; replaced whenever weights are synced from `tch_model`.
    candle_model: candle_model::Model<F>,
    /// Builder kept around so `candle_model` can be rebuilt from dumped weights.
    candle_builder: BaseModelBuilder<CandleBackend, F>,
    /// The builder's `input_nodes` value cast to `usize`.
    pub input_shape: usize,
}

impl<F: FloatElement> crate::model::ModelInterface<F> for Model<F> {
    type TensorType = Tensor;

    /// Forwards a single prediction request to the candle model.
    fn predict(&self, xs: &Self::TensorType) -> DliResult<Vec<(usize, f32)>> {
        let Self { candle_model, .. } = self;
        candle_model.predict(xs)
    }

    /// Forwards a batched prediction request to the candle model.
    #[log_time]
    fn predict_many(&self, xs: &[F]) -> DliResult<Vec<Vec<(usize, f32)>>> {
        let Self { candle_model, .. } = self;
        candle_model.predict_many(xs)
    }

    /// Trains the tch model, then copies its weights into the candle model.
    fn train(&mut self, xs: &ArraySlice) -> DliResult<()> {
        self.tch_model.train(xs)?;
        self.sync_weights_from_tch_to_candle()?;
        Ok(())
    }

    /// Retrains the tch model, then copies its weights into the candle model.
    fn retrain(&mut self, xs: &ArraySlice) -> DliResult<()> {
        self.tch_model.retrain(xs)?;
        self.sync_weights_from_tch_to_candle()?;
        Ok(())
    }

    /// Dumps the tch model to `weights_filename` and returns its configuration.
    fn dump(&self, weights_filename: PathBuf) -> DliResult<ModelConfig> {
        let Self { tch_model, .. } = self;
        tch_model.dump(weights_filename)
    }

    /// Reports the combined memory usage of the candle and tch models.
    fn memory_usage(&self) -> usize {
        let candle_usage = self.candle_model.memory_usage();
        let tch_usage = self.tch_model.memory_usage();
        candle_usage + tch_usage
    }

    /// Converts a slice of `f32` into the candle model's tensor type.
    fn vec2tensor(&self, xs: &[f32]) -> DliResult<Self::TensorType> {
        let Self { candle_model, .. } = self;
        candle_model.vec2tensor(xs)
    }
}

impl<F: FloatElement> Model<F> {
    /// Copies the current tch model weights into the candle model.
    ///
    /// The tch model is dumped to a uniquely named safetensors file in the
    /// system temporary directory, the candle model is rebuilt from that file
    /// through the stored candle builder, and the file is then removed so
    /// repeated syncs do not pile up stale weight files.
    ///
    /// # Errors
    ///
    /// Propagates any error from dumping the tch model or from rebuilding the
    /// candle model. On error the existing candle model is left in place.
    pub fn sync_weights_from_tch_to_candle(&mut self) -> DliResult<()> {
        let random_id = rand::rng().random::<u64>();
        // Honor the platform's temp directory instead of hard-coding `/tmp`.
        let weights_path =
            std::env::temp_dir().join(format!("mix_model_weights_{random_id}.safetensors"));

        let rebuilt = match self.tch_model.dump(weights_path.clone()) {
            Ok(_) => self
                .candle_builder
                .weights_path(weights_path.clone())
                .build(),
            Err(err) => Err(err),
        };

        // Best-effort cleanup on both success and failure paths; a failed
        // removal must not mask the result of the sync itself.
        // NOTE(review): assumes the candle model does not re-read the file by
        // path after `build` returns — confirm against the candle builder.
        let _ = std::fs::remove_file(&weights_path);

        self.candle_model = rebuilt?;
        Ok(())
    }
}