fastembed 5.13.2

Library for generating vector embeddings, reranking locally.
Documentation
use anyhow::{anyhow, Result};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use ndarray::{Array, Array3};
use std::ops::{Div, Sub};
#[cfg(feature = "hf-hub")]
use std::{fs::read_to_string, path::Path};

pub enum TransformData {
    Image(DynamicImage),
    NdArray(Array3<f32>),
}

impl TransformData {
    pub fn image(self) -> anyhow::Result<DynamicImage> {
        match self {
            TransformData::Image(img) => Ok(img),
            _ => Err(anyhow!("TransformData convert error")),
        }
    }

    pub fn array(self) -> anyhow::Result<Array3<f32>> {
        match self {
            TransformData::NdArray(array) => Ok(array),
            _ => Err(anyhow!("TransformData convert error")),
        }
    }
}

pub trait Transform: Send + Sync {
    fn transform(&self, images: TransformData) -> anyhow::Result<TransformData>;
}

struct ConvertToRGB;

impl Transform for ConvertToRGB {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        let image = data.image()?;
        let image = image.into_rgb8().into();
        Ok(TransformData::Image(image))
    }
}

pub struct Resize {
    pub size: (u32, u32),
    pub resample: FilterType,
}

impl Transform for Resize {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        let image = data.image()?;
        let image = image.resize_exact(self.size.0, self.size.1, self.resample);
        Ok(TransformData::Image(image))
    }
}

pub struct CenterCrop {
    pub size: (u32, u32),
}

impl Transform for CenterCrop {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        let mut image = data.image()?;
        let (mut origin_width, mut origin_height) = image.dimensions();
        let (crop_width, crop_height) = self.size;
        if origin_width >= crop_width && origin_height >= crop_height {
            // cropped area is within image boundaries
            let x = (origin_width - crop_width) / 2;
            let y = (origin_height - crop_height) / 2;
            let image = image.crop_imm(x, y, crop_width, crop_height);
            Ok(TransformData::Image(image))
        } else {
            if origin_width > crop_width || origin_height > crop_height {
                let (new_width, new_height) =
                    (origin_width.min(crop_width), origin_height.min(crop_height));
                let (x, y) = if origin_width > crop_width {
                    ((origin_width - crop_width) / 2, 0)
                } else {
                    (0, (origin_height - crop_height) / 2)
                };
                image = image.crop_imm(x, y, new_width, new_height);
                (origin_width, origin_height) = image.dimensions();
            }
            let mut pixels_array =
                Array3::zeros((3usize, crop_width as usize, crop_height as usize));
            let offset_x = (crop_width - origin_width) / 2;
            let offset_y = (crop_height - origin_height) / 2;
            // whc -> chw
            for (x, y, pixel) in image.to_rgb8().enumerate_pixels() {
                pixels_array[[0, (y + offset_y) as usize, (x + offset_x) as usize]] =
                    pixel[0] as f32;
                pixels_array[[1, (y + offset_y) as usize, (x + offset_x) as usize]] =
                    pixel[1] as f32;
                pixels_array[[2, (y + offset_y) as usize, (x + offset_x) as usize]] =
                    pixel[2] as f32;
            }
            Ok(TransformData::NdArray(pixels_array))
        }
    }
}

struct PILToNDarray;

impl Transform for PILToNDarray {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        match data {
            TransformData::Image(image) => {
                let image = image.to_rgb8();
                let (width, height) = image.dimensions();
                // whc -> chw
                let mut pixels_array = Array3::zeros((3usize, height as usize, width as usize));
                for (x, y, pixel) in image.enumerate_pixels() {
                    pixels_array[[0, y as usize, x as usize]] = pixel[0] as f32;
                    pixels_array[[1, y as usize, x as usize]] = pixel[1] as f32;
                    pixels_array[[2, y as usize, x as usize]] = pixel[2] as f32;
                }
                Ok(TransformData::NdArray(pixels_array))
            }
            ndarray => Ok(ndarray),
        }
    }
}

pub struct Rescale {
    pub scale: f32,
}

impl Transform for Rescale {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        let array = data.array()?;
        let array = array * self.scale;
        Ok(TransformData::NdArray(array))
    }
}

pub struct Normalize {
    pub mean: Vec<f32>,
    pub std: Vec<f32>,
}

impl Transform for Normalize {
    fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> {
        let array = data.array()?;
        let mean = Array::from_vec(self.mean.clone())
            .into_shape_with_order((3, 1, 1))
            .map_err(|e| anyhow!("Failed to reshape mean array: {}", e))?;
        let std = Array::from_vec(self.std.clone())
            .into_shape_with_order((3, 1, 1))
            .map_err(|e| anyhow!("Failed to reshape std array: {}", e))?;

        let shape = array.shape().to_vec();
        match shape.as_slice() {
            [c, h, w] => {
                let mean_broadcast = mean.broadcast((*c, *h, *w)).ok_or_else(|| {
                    anyhow!("Failed to broadcast mean array to shape {:?}", (*c, *h, *w))
                })?;
                let std_broadcast = std.broadcast((*c, *h, *w)).ok_or_else(|| {
                    anyhow!("Failed to broadcast std array to shape {:?}", (*c, *h, *w))
                })?;
                let array_normalized = array.sub(mean_broadcast).div(std_broadcast);
                Ok(TransformData::NdArray(array_normalized))
            }
            _ => Err(anyhow!(
                "Transformer convert error. Normalize operator got error shape."
            )),
        }
    }
}

pub struct Compose {
    transforms: Vec<Box<dyn Transform>>,
}

impl Compose {
    fn new(transforms: Vec<Box<dyn Transform>>) -> Self {
        Self { transforms }
    }

    #[cfg(feature = "hf-hub")]
    pub fn from_file<P: AsRef<Path>>(file: P) -> anyhow::Result<Self> {
        let content = read_to_string(file)?;
        let config = serde_json::from_str(&content)?;
        load_preprocessor(config)
    }

    pub fn from_bytes<P: AsRef<[u8]>>(bytes: P) -> anyhow::Result<Compose> {
        let config = serde_json::from_slice(bytes.as_ref())?;
        load_preprocessor(config)
    }
}

impl Transform for Compose {
    fn transform(&self, mut image: TransformData) -> anyhow::Result<TransformData> {
        for transform in &self.transforms {
            image = transform.transform(image)?;
        }
        Ok(image)
    }
}

fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
    let mut transformers: Vec<Box<dyn Transform>> = vec![];
    transformers.push(Box::new(ConvertToRGB));

    let mode = config["image_processor_type"]
        .as_str()
        .unwrap_or("CLIPImageProcessor");
    match mode {
        "CLIPImageProcessor" => {
            if config["do_resize"].as_bool().unwrap_or(false) {
                let size = config["size"].clone();
                let shortest_edge = size["shortest_edge"].as_u64();
                let (height, width) = (size["height"].as_u64(), size["width"].as_u64());

                if let Some(shortest_edge) = shortest_edge {
                    let size = (shortest_edge as u32, shortest_edge as u32);
                    transformers.push(Box::new(Resize {
                        size,
                        resample: FilterType::CatmullRom,
                    }));
                } else if let (Some(height), Some(width)) = (height, width) {
                    let size = (height as u32, width as u32);
                    transformers.push(Box::new(Resize {
                        size,
                        resample: FilterType::CatmullRom,
                    }));
                } else {
                    return Err(anyhow!(
                        "Size must contain either 'shortest_edge' or 'height' and 'width'."
                    ));
                }
            }

            if config["do_center_crop"].as_bool().unwrap_or(false) {
                let crop_size = config["crop_size"].clone();
                let (height, width) = if crop_size.is_u64() {
                    let size = crop_size
                        .as_u64()
                        .ok_or_else(|| anyhow!("crop_size must be a valid u64"))?
                        as u32;
                    (size, size)
                } else if crop_size.is_object() {
                    (
                        crop_size["height"]
                            .as_u64()
                            .map(|height| height as u32)
                            .ok_or_else(|| anyhow!("crop_size height must be contained"))?,
                        crop_size["width"]
                            .as_u64()
                            .map(|width| width as u32)
                            .ok_or_else(|| anyhow!("crop_size width must be contained"))?,
                    )
                } else {
                    return Err(anyhow!("Invalid crop size: {:?}", crop_size));
                };
                transformers.push(Box::new(CenterCrop {
                    size: (width, height),
                }));
            }
        }
        "ConvNextFeatureExtractor" => {
            let shortest_edge = config["size"]["shortest_edge"].as_u64();
            if shortest_edge.is_none() {
                return Err(anyhow!("Size dictionary must contain 'shortest_edge' key."));
            }
            let shortest_edge = shortest_edge.unwrap() as u32;
            let crop_pct = config["crop_pct"].as_f64().unwrap_or(0.875);
            if shortest_edge < 384 {
                let resize_shortet_edge = shortest_edge as f64 / crop_pct;
                transformers.push(Box::new(Resize {
                    size: (resize_shortet_edge as u32, resize_shortet_edge as u32),
                    resample: FilterType::CatmullRom,
                }));
                transformers.push(Box::new(CenterCrop {
                    size: (shortest_edge, shortest_edge),
                }))
            } else {
                transformers.push(Box::new(Resize {
                    size: (shortest_edge, shortest_edge),
                    resample: FilterType::CatmullRom,
                }));
            }
        }
        "BitImageProcessor" => {
            if config["do_convert_rgb"].as_bool().unwrap_or(false) {
                transformers.push(Box::new(ConvertToRGB));
            }
            if config["do_resize"].as_bool().unwrap_or(false) {
                let size = config["size"].clone();
                let shortest_edge = size["shortest_edge"].as_u64();
                let (height, width) = (size["height"].as_u64(), size["width"].as_u64());

                if let Some(shortest_edge) = shortest_edge {
                    let size = (shortest_edge as u32, shortest_edge as u32);
                    transformers.push(Box::new(Resize {
                        size,
                        resample: FilterType::CatmullRom,
                    }));
                } else if let (Some(height), Some(width)) = (height, width) {
                    let size = (height as u32, width as u32);
                    transformers.push(Box::new(Resize {
                        size,
                        resample: FilterType::CatmullRom,
                    }));
                } else {
                    return Err(anyhow!(
                        "Size must contain either 'shortest_edge' or 'height' and 'width'."
                    ));
                }
            }

            if config["do_center_crop"].as_bool().unwrap_or(false) {
                let crop_size = config["crop_size"].clone();
                let (height, width) = if crop_size.is_u64() {
                    let size = crop_size
                        .as_u64()
                        .ok_or_else(|| anyhow!("crop_size must be a valid u64"))?
                        as u32;
                    (size, size)
                } else if crop_size.is_object() {
                    (
                        crop_size["height"]
                            .as_u64()
                            .map(|height| height as u32)
                            .ok_or_else(|| anyhow!("crop_size height must be contained"))?,
                        crop_size["width"]
                            .as_u64()
                            .map(|width| width as u32)
                            .ok_or_else(|| anyhow!("crop_size width must be contained"))?,
                    )
                } else {
                    return Err(anyhow!("Invalid crop size: {:?}", crop_size));
                };
                transformers.push(Box::new(CenterCrop {
                    size: (width, height),
                }));
            }
        }
        mode => return Err(anyhow!("Preprocessor {} is not supported", mode)),
    }

    transformers.push(Box::new(PILToNDarray));

    if config["do_rescale"].as_bool().unwrap_or(true) {
        let rescale_factor = config["rescale_factor"].as_f64().unwrap_or(1.0f64 / 255.0);
        transformers.push(Box::new(Rescale {
            scale: rescale_factor as f32,
        }));
    }

    if config["do_normalize"].as_bool().unwrap_or(false) {
        let mean = config["image_mean"]
            .as_array()
            .ok_or(anyhow!("image_mean must be contained"))?
            .iter()
            .map(|value| {
                value
                    .as_f64()
                    .map(|num| num as f32)
                    .ok_or(anyhow!("image_mean must be float"))
            })
            .collect::<Result<Vec<f32>>>()?;
        let std = config["image_std"]
            .as_array()
            .ok_or(anyhow!("image_std must be contained"))?
            .iter()
            .map(|value| {
                value
                    .as_f64()
                    .map(|num| num as f32)
                    .ok_or(anyhow!("image_std must be float"))
            })
            .collect::<Result<Vec<f32>>>()?;
        transformers.push(Box::new(Normalize { mean, std }));
    }

    Ok(Compose::new(transformers))
}