// embed_anything 0.4.17
//
// Embed anything at lightning speed.
// Documentation
use candle_core::Tensor;
use ndarray::prelude::*;
use ndarray::{Array2, Array3};

/// Strategy used to collapse a model output along axis 1 into a pooled output.
///
/// [`Pooling::Mean`] is the default variant.
#[derive(Debug, Clone, Default)]
pub enum Pooling {
    /// Average the model output over axis 1.
    #[default]
    Mean,
    /// Take the entry at index 0 along axis 1 of the model output
    /// (presumably the `[CLS]` token position — confirm against the model in use).
    Cls,
}

/// Result of a pooling step, backed by either a candle tensor or an ndarray matrix.
#[derive(Debug, Clone)]
pub enum PooledOutput {
    /// Pooled values held as a `candle_core::Tensor`.
    Tensor(Tensor),
    /// Pooled values held as a two-dimensional `f32` array.
    Array(Array2<f32>),
}

impl PooledOutput {
    pub fn to_tensor(self) -> Result<Tensor, anyhow::Error> {
        Ok(match self {
            PooledOutput::Tensor(tensor) => tensor,
            PooledOutput::Array(_) => panic!("Not implemented"),
        })
    }

    pub fn to_array(self) -> Result<Array2<f32>, anyhow::Error> {
        Ok(match self {
            PooledOutput::Tensor(_) => panic!("Not implemented"),
            PooledOutput::Array(array) => array,
        })
    }
}

/// Raw model output handed to [`Pooling::pool`], backed by either a candle tensor
/// or an ndarray array.
///
/// Pooling reduces or indexes along axis 1 of the contained value.
#[derive(Debug, Clone)]
pub enum ModelOutput {
    /// Output held as a `candle_core::Tensor`
    /// (assumed to have at least two dimensions, since pooling operates on
    /// dimension 1 — TODO confirm the expected rank with callers).
    Tensor(Tensor),
    /// Output held as a three-dimensional `f32` array.
    Array(Array3<f32>),
}

impl Pooling {
    /// Applies this pooling strategy to `output`.
    ///
    /// The result keeps the backend of the input: a `ModelOutput::Tensor` yields a
    /// `PooledOutput::Tensor`, and a `ModelOutput::Array` yields a
    /// `PooledOutput::Array`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying tensor operation fails, or if an array
    /// input has length zero along axis 1.
    pub fn pool(&self, output: &ModelOutput) -> Result<PooledOutput, anyhow::Error> {
        match self {
            Pooling::Cls => Self::cls(output),
            Pooling::Mean => Self::mean(output),
        }
    }

    /// Selects index 0 along axis 1 of `output`.
    fn cls(output: &ModelOutput) -> Result<PooledOutput, anyhow::Error> {
        match output {
            ModelOutput::Tensor(tensor) => tensor
                .get_on_dim(1, 0)
                .map(PooledOutput::Tensor)
                // Keep the underlying cause: failure may be a rank or bounds problem,
                // not only an empty tensor.
                .map_err(|e| anyhow::anyhow!("Cls pooling failed on tensor: {}", e)),
            ModelOutput::Array(array) => {
                // Slicing at index 0 panics when axis 1 is empty; report it as an
                // error instead, matching the behavior of `mean`.
                if array.len_of(Axis(1)) == 0 {
                    return Err(anyhow::anyhow!("Cls of empty array"));
                }
                Ok(PooledOutput::Array(array.slice(s![.., 0, ..]).to_owned()))
            }
        }
    }

    /// Averages `output` over axis 1.
    ///
    /// Every position along axis 1 contributes equally; no mask is applied here.
    fn mean(output: &ModelOutput) -> Result<PooledOutput, anyhow::Error> {
        match output {
            ModelOutput::Tensor(tensor) => tensor
                .mean(1)
                .map(PooledOutput::Tensor)
                .map_err(|e| anyhow::anyhow!("Mean pooling failed on tensor: {}", e)),
            ModelOutput::Array(array) => array
                .mean_axis(Axis(1))
                .map(PooledOutput::Array)
                // `mean_axis` yields `None` when the axis has length zero.
                .ok_or_else(|| anyhow::anyhow!("Mean of empty array")),
        }
    }
}