evoc 0.0.1

Embedding Vector Oriented Clustering — fast clustering of high-dimensional embedding vectors (Rust port of EVōC)
//! RLX-backed compute backends that delegate to the strict reference until GPU kernels land.

use super::{ComputeBackend, RlxBackend};
use crate::knn::KnnError;
use crate::numpy_rng::NumpyRandomState;
use crate::{EmbeddingData, KnnGraphOptions};
use ndarray::Array2;
use sprs::CsMat;

/// Runs the strict EVōC pipeline; tagged with an RLX [`ComputeBackend`] for routing/env.
pub struct DelegateBackend {
    pub kind: ComputeBackend,
}

impl RlxBackend for DelegateBackend {
    /// Returns the RLX backend tag this delegate was constructed with.
    fn kind(&self) -> ComputeBackend {
        let Self { kind, .. } = self;
        *kind
    }

    /// Forwards the kNN-graph request, with every argument passed through unchanged,
    /// to the strict reference backend and returns its result as-is.
    fn knn_graph(
        &self,
        data: EmbeddingData,
        opts: KnnGraphOptions,
        rng: &mut NumpyRandomState,
        strict_precision: bool,
    ) -> Result<(Array2<i32>, Array2<f32>), KnnError> {
        // Until GPU kernels land, the strict implementation does the actual work.
        let reference = super::strict::StrictBackend;
        reference.knn_graph(data, opts, rng, strict_precision)
    }

    /// Forwards label-propagation initialisation to the strict reference backend.
    ///
    /// All parameters, including the shared `rng` and the optional `data`, are passed
    /// through in the same order. The strict backend's output is returned unchanged.
    fn label_propagation_init(
        &self,
        graph: &CsMat<f32>,
        n_label_prop_iter: usize,
        n_embedding_epochs: usize,
        approx_n_parts: usize,
        n_components: usize,
        scaling: f32,
        random_scale: f32,
        noise_level: f32,
        rng: &mut NumpyRandomState,
        data: Option<&Array2<f32>>,
        strict_precision: bool,
    ) -> Array2<f32> {
        let reference = super::strict::StrictBackend;
        reference.label_propagation_init(
            graph,
            n_label_prop_iter,
            n_embedding_epochs,
            approx_n_parts,
            n_components,
            scaling,
            random_scale,
            noise_level,
            rng,
            data,
            strict_precision,
        )
    }

    /// Forwards node-embedding optimisation to the strict reference backend.
    ///
    /// `initial_embedding` is moved into the strict call. The embedding it returns is
    /// passed straight back to the caller.
    fn node_embedding(
        &self,
        graph: &CsMat<f32>,
        n_components: usize,
        n_epochs: usize,
        initial_embedding: Option<Array2<f32>>,
        initial_alpha: f32,
        negative_sample_rate: f32,
        noise_level: f32,
        rng: &mut NumpyRandomState,
        reproducible_flag: bool,
        strict_precision: bool,
    ) -> Array2<f32> {
        let reference = super::strict::StrictBackend;
        reference.node_embedding(
            graph,
            n_components,
            n_epochs,
            initial_embedding,
            initial_alpha,
            negative_sample_rate,
            noise_level,
            rng,
            reproducible_flag,
            strict_precision,
        )
    }
}