omni_search 0.2.6

A unified Rust SDK for multimodal embedding and similarity search.
Documentation
use std::fs;
use std::path::Path;

use image::{DynamicImage, imageops::FilterType};
use ndarray::{Array, ArrayD, IxDyn};

use crate::error::Error;

/// Per-image model inputs as built by `preprocess_image`.
pub(crate) struct FgClipImageInputs {
    /// Number of patch rows in the resized image.
    pub spatial_height: usize,
    /// Number of patch columns in the resized image.
    pub spatial_width: usize,
    /// Flattened patches; `preprocess_image` builds this with shape
    /// `[1, max_patches, patch_size * patch_size * 3]`, zero-filled past the valid patches.
    pub pixel_values: ArrayD<f32>,
    /// `1` for real patches and `0` for padding; `preprocess_image` builds this with
    /// shape `[1, max_patches]`.
    pub pixel_attention_mask: ArrayD<i32>,
}

/// Loads a file of packed little-endian `f32` values.
///
/// # Errors
/// Read failures from `fs::read` are propagated with `?`. A file whose byte length
/// is not a multiple of four yields `Error::image_preprocess`.
pub(crate) fn read_f32_file(path: &Path) -> Result<Vec<f32>, Error> {
    let raw = fs::read(path)?;
    let byte_count = raw.len();
    let has_partial_word = byte_count % 4 != 0;
    if has_partial_word {
        return Err(Error::image_preprocess(format!(
            "{} has {} bytes, not divisible by 4",
            path.display(),
            byte_count
        )));
    }

    let mut floats = Vec::with_capacity(byte_count / 4);
    for word in raw.chunks_exact(4) {
        floats.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    Ok(floats)
}

/// Chooses a patch budget for an image of `width x height` pixels.
///
/// The required patch count is `(width / patch_size) * (height / patch_size)` (floor
/// division). The result is the smallest of the standard buckets `128, 256, 576, 784`
/// that is no larger than `default_max_patches` and can hold that count. When no
/// such bucket exists, `default_max_patches` is returned.
///
/// # Panics
/// Panics (integer division by zero) if `patch_size` is zero.
pub(crate) fn determine_max_patches(
    width: u32,
    height: u32,
    patch_size: usize,
    default_max_patches: usize,
) -> usize {
    const BUCKETS: [usize; 4] = [128, 256, 576, 784];

    let patches_across = width as usize / patch_size;
    let patches_down = height as usize / patch_size;
    let required = patches_across * patches_down;

    BUCKETS
        .iter()
        .copied()
        .filter(|&bucket| bucket <= default_max_patches && required <= bucket)
        .min()
        .unwrap_or(default_max_patches)
}

pub(crate) fn preprocess_image(
    image: &DynamicImage,
    patch_size: usize,
    max_patches: usize,
) -> Result<FgClipImageInputs, Error> {
    let image = image.to_rgb8();
    let (original_width, original_height) = image.dimensions();
    let (target_height, target_width) = get_image_size_for_max_num_patches(
        original_height as usize,
        original_width as usize,
        patch_size,
        max_patches,
    );
    let resized = image::imageops::resize(
        &image,
        target_width as u32,
        target_height as u32,
        FilterType::Triangle,
    );

    let spatial_height = target_height / patch_size;
    let spatial_width = target_width / patch_size;
    let valid_patches = spatial_height * spatial_width;
    let channels = patch_size * patch_size * 3;
    if valid_patches > max_patches {
        return Err(Error::image_preprocess(format!(
            "internal error: {valid_patches} valid patches > max_patches {max_patches}"
        )));
    }

    let mut pixel_values = vec![0.0f32; max_patches * channels];
    for patch_y in 0..spatial_height {
        for patch_x in 0..spatial_width {
            let patch_index = patch_y * spatial_width + patch_x;
            let mut dst = patch_index * channels;
            for y in 0..patch_size {
                for x in 0..patch_size {
                    let pixel = resized.get_pixel(
                        (patch_x * patch_size + x) as u32,
                        (patch_y * patch_size + y) as u32,
                    );
                    for channel in 0..3 {
                        pixel_values[dst] = pixel[channel] as f32 / 127.5 - 1.0;
                        dst += 1;
                    }
                }
            }
        }
    }

    let mut mask = vec![0i32; max_patches];
    for item in mask.iter_mut().take(valid_patches) {
        *item = 1;
    }

    Ok(FgClipImageInputs {
        pixel_values: Array::from_shape_vec(IxDyn(&[1, max_patches, channels]), pixel_values)
            .map_err(|error| Error::image_preprocess(error.to_string()))?,
        pixel_attention_mask: Array::from_shape_vec(IxDyn(&[1, max_patches]), mask)
            .map_err(|error| Error::image_preprocess(error.to_string()))?,
        spatial_height,
        spatial_width,
    })
}

pub(crate) fn stack_pixel_values(images: &[FgClipImageInputs]) -> Result<ArrayD<f32>, Error> {
    let batch = images.len();
    let shape = images
        .first()
        .ok_or_else(|| Error::image_preprocess("cannot stack an empty image batch"))?
        .pixel_values
        .shape()
        .to_vec();
    let max_patches = shape[1];
    let channels = shape[2];
    let mut values = Vec::with_capacity(batch * max_patches * channels);
    for image in images {
        if image.pixel_values.shape() != [1, max_patches, channels] {
            return Err(Error::image_preprocess(format!(
                "all fgclip pixel arrays must have shape [1,{max_patches},{channels}]"
            )));
        }
        values.extend_from_slice(
            image
                .pixel_values
                .as_slice()
                .ok_or_else(|| Error::image_preprocess("pixel array is not contiguous"))?,
        );
    }
    Array::from_shape_vec(IxDyn(&[batch, max_patches, channels]), values)
        .map_err(|error| Error::image_preprocess(error.to_string()))
}

pub(crate) fn stack_attention_masks(images: &[FgClipImageInputs]) -> Result<ArrayD<i32>, Error> {
    let batch = images.len();
    let max_patches = images
        .first()
        .ok_or_else(|| Error::image_preprocess("cannot stack an empty image batch"))?
        .pixel_attention_mask
        .shape()[1];
    let mut values = Vec::with_capacity(batch * max_patches);
    for image in images {
        if image.pixel_attention_mask.shape() != [1, max_patches] {
            return Err(Error::image_preprocess(format!(
                "all fgclip masks must have shape [1,{max_patches}]"
            )));
        }
        values.extend_from_slice(
            image
                .pixel_attention_mask
                .as_slice()
                .ok_or_else(|| Error::image_preprocess("mask array is not contiguous"))?,
        );
    }
    Array::from_shape_vec(IxDyn(&[batch, max_patches]), values)
        .map_err(|error| Error::image_preprocess(error.to_string()))
}

/// Concatenates `shape[0]` arrays of shape `[1, shape[1], shape[2]]` into one array of `shape`.
///
/// Elements are copied in logical (row-major) order, so arrays that are not stored
/// contiguously are handled too.
///
/// # Errors
/// Returns `Error::image_preprocess` in these cases:
/// - `arrays.len()` differs from `shape[0]`.
/// - Any array's shape differs from `[1, shape[1], shape[2]]`.
pub(crate) fn stack_f32_batches(
    arrays: &[ArrayD<f32>],
    shape: [usize; 3],
) -> Result<ArrayD<f32>, Error> {
    let [batch, rows, cols] = shape;
    // Without this check a count mismatch only surfaces as an opaque ndarray ShapeError.
    if arrays.len() != batch {
        return Err(Error::image_preprocess(format!(
            "expected {batch} arrays to stack, got {}",
            arrays.len()
        )));
    }

    let mut values = Vec::with_capacity(batch * rows * cols);
    for array in arrays {
        if array.shape() != [1, rows, cols] {
            return Err(Error::image_preprocess(format!(
                "all arrays must have shape [1,{rows},{cols}]"
            )));
        }
        // Use the memcpy fast path when possible; otherwise fall back to a logical-order copy.
        match array.as_slice() {
            Some(contiguous) => values.extend_from_slice(contiguous),
            None => values.extend(array.iter().copied()),
        }
    }
    Array::from_shape_vec(IxDyn(&shape), values)
        .map_err(|error| Error::image_preprocess(error.to_string()))
}

/// Resamples a learned 2-D position-embedding grid to `target_height x target_width`
/// and pads the result to `max_patches` tokens.
///
/// `base_pos` is indexed as a row-major `[source_height, source_width, channels]` grid.
/// Each output token is bilinearly interpolated using half-pixel centres
/// (`linear_source_coordinate`), clamped at the grid edges. Tokens from
/// `target_height * target_width` up to `max_patches` are copies of the first
/// interpolated token.
///
/// NOTE(review): this is plain bilinear sampling without anti-aliasing. If the
/// reference model resizes with `antialias=True` (as Hugging Face's SigLIP2 embedding
/// resize appears to), results will differ when downsampling. Confirm if exact
/// parity matters.
///
/// Returns an array of shape `[1, max_patches, channels]`.
///
/// # Errors
/// Returns `Error::image_preprocess` in these cases:
/// - `base_pos` does not have `source_height * source_width * channels` elements.
/// - The target grid holds more than `max_patches` tokens.
/// - A non-empty target is requested from an empty source grid.
pub(crate) fn build_positional_embedding(
    base_pos: &[f32],
    source_height: usize,
    source_width: usize,
    target_height: usize,
    target_width: usize,
    max_patches: usize,
    channels: usize,
) -> Result<ArrayD<f32>, Error> {
    let expected_len = source_height * source_width * channels;
    if base_pos.len() != expected_len {
        return Err(Error::image_preprocess(format!(
            "unexpected vision position embedding length {}, expected {}",
            base_pos.len(),
            expected_len
        )));
    }

    let valid = target_height * target_width;
    // Writing more tokens than `output` holds would index out of bounds.
    if valid > max_patches {
        return Err(Error::image_preprocess(format!(
            "target position grid {target_height}x{target_width} needs {valid} tokens, \
             exceeding max_patches {max_patches}"
        )));
    }

    let mut output = vec![0.0f32; max_patches * channels];
    if valid > 0 {
        // An empty source grid has nothing to sample, and `size - 1` below would underflow.
        if source_height == 0 || source_width == 0 {
            return Err(Error::image_preprocess(format!(
                "cannot resize an empty {source_height}x{source_width} position embedding grid"
            )));
        }

        for out_y in 0..target_height {
            let in_y = linear_source_coordinate(out_y, target_height, source_height);
            let y0 = in_y.floor().clamp(0.0, (source_height - 1) as f32) as usize;
            let y1 = (y0 + 1).min(source_height - 1);
            let wy = in_y - y0 as f32;

            for out_x in 0..target_width {
                let in_x = linear_source_coordinate(out_x, target_width, source_width);
                let x0 = in_x.floor().clamp(0.0, (source_width - 1) as f32) as usize;
                let x1 = (x0 + 1).min(source_width - 1);
                let wx = in_x - x0 as f32;
                let token = out_y * target_width + out_x;
                for channel in 0..channels {
                    let top = lerp(
                        base_pos[((y0 * source_width + x0) * channels) + channel],
                        base_pos[((y0 * source_width + x1) * channels) + channel],
                        wx,
                    );
                    let bottom = lerp(
                        base_pos[((y1 * source_width + x0) * channels) + channel],
                        base_pos[((y1 * source_width + x1) * channels) + channel],
                        wx,
                    );
                    output[token * channels + channel] = lerp(top, bottom, wy);
                }
            }
        }

        // Padding tokens repeat the first interpolated token; copy in place rather than
        // allocating a temporary for every padded token.
        for token in valid..max_patches {
            output.copy_within(0..channels, token * channels);
        }
    }

    Array::from_shape_vec(IxDyn(&[1, max_patches, channels]), output)
        .map_err(|error| Error::image_preprocess(error.to_string()))
}

/// Finds the patch-aligned `(height, width)` to resize an image to, so that it yields
/// at most `max_num_patches` patches.
///
/// The function bisects a uniform scale factor over `[1e-6, 100]` until the bracket is
/// narrower than `1e-5`, keeping the largest scale seen that still fits. For a given
/// scale, each side is multiplied by it, rounded up to a multiple of `patch_size`, and
/// never shrinks below a single patch.
fn get_image_size_for_max_num_patches(
    image_height: usize,
    image_width: usize,
    patch_size: usize,
    max_num_patches: usize,
) -> (usize, usize) {
    const TOLERANCE: f64 = 1e-5;

    // Scale one side and snap it up to a whole number of patches (minimum one patch).
    let round_up = |scale: f64, length: usize| -> usize {
        let units = (length as f64 * scale / patch_size as f64).ceil() as usize;
        (units * patch_size).max(patch_size)
    };
    let fits = |scale: f64| -> bool {
        let rows = round_up(scale, image_height) / patch_size;
        let cols = round_up(scale, image_width) / patch_size;
        rows * cols <= max_num_patches
    };

    let (mut low, mut high) = (TOLERANCE / 10.0, 100.0_f64);
    while high - low >= TOLERANCE {
        let mid = (low + high) / 2.0;
        if fits(mid) {
            low = mid;
        } else {
            high = mid;
        }
    }

    (round_up(low, image_height), round_up(low, image_width))
}

/// Maps an output index to a fractional input coordinate using half-pixel centres
/// (`(i + 0.5) * in / out - 0.5`). The result is clamped to `[0, input_size - 1]`.
fn linear_source_coordinate(output_index: usize, output_size: usize, input_size: usize) -> f32 {
    let centre = output_index as f32 + 0.5;
    let mapped = centre * input_size as f32 / output_size as f32 - 0.5;
    let last_index = (input_size - 1) as f32;
    mapped.clamp(0.0, last_index)
}

/// Linear interpolation: returns `a` at `weight == 0` and `b` at `weight == 1`.
fn lerp(a: f32, b: f32, weight: f32) -> f32 {
    let span = b - a;
    a + weight * span
}

#[cfg(test)]
mod tests {
    use image::{DynamicImage, Rgb, RgbImage};

    use super::{build_positional_embedding, determine_max_patches, preprocess_image};

    /// A solid-red 64x64 RGB image.
    fn red_square() -> DynamicImage {
        DynamicImage::ImageRgb8(RgbImage::from_pixel(64, 64, Rgb([255, 0, 0])))
    }

    #[test]
    fn determines_patch_bucket() {
        assert_eq!(determine_max_patches(1920, 1080, 16, 1024), 1024);
        assert_eq!(determine_max_patches(320, 240, 16, 1024), 576);
        assert_eq!(determine_max_patches(640, 427, 16, 576), 576);
        // A 4x4 = 16 patch image fits the smallest bucket.
        assert_eq!(determine_max_patches(64, 64, 16, 1024), 128);
        // Buckets above the default are discarded, so only the default remains.
        assert_eq!(determine_max_patches(64, 64, 16, 100), 100);
    }

    #[test]
    fn preprocesses_fgclip_image() {
        let encoded = preprocess_image(&red_square(), 16, 128).unwrap();
        assert_eq!(encoded.pixel_values.shape(), [1, 128, 768]);
        assert_eq!(encoded.pixel_attention_mask.shape(), [1, 128]);
    }

    #[test]
    fn masks_only_valid_patches_and_zero_pads() {
        // A 64x64 image with a 128-patch budget scales to an 11x11 patch grid.
        let encoded = preprocess_image(&red_square(), 16, 128).unwrap();
        assert_eq!((encoded.spatial_height, encoded.spatial_width), (11, 11));

        let mask = encoded.pixel_attention_mask.as_slice().unwrap();
        assert_eq!(mask.iter().sum::<i32>(), 121);
        assert!(mask[..121].iter().all(|&m| m == 1));

        let pixels = encoded.pixel_values.as_slice().unwrap();
        // Red normalises to about +1; the empty green/blue channels map to exactly -1.
        assert!(pixels[0] > 0.99);
        assert_eq!(pixels[1], -1.0);
        assert_eq!(pixels[2], -1.0);
        // Patch rows past the valid count stay zero.
        assert!(pixels[121 * 768..].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn resizes_positional_embedding() {
        let base = vec![0.0f32; 16 * 16 * 4];
        let pos = build_positional_embedding(&base, 16, 16, 2, 2, 8, 4).unwrap();
        assert_eq!(pos.shape(), [1, 8, 4]);
    }

    #[test]
    fn same_size_positional_embedding_passes_through_and_pads_with_first_token() {
        let base = [1.0f32, 2.0, 3.0, 4.0];
        let pos = build_positional_embedding(&base, 2, 2, 2, 2, 6, 1).unwrap();
        assert_eq!(pos.shape(), [1, 6, 1]);
        assert_eq!(pos.as_slice().unwrap(), &[1.0f32, 2.0, 3.0, 4.0, 1.0, 1.0]);
    }

    #[test]
    fn downsampled_positional_embedding_blends_neighbours() {
        // A 2x2 -> 1x1 resample samples the grid centre, i.e. the mean of the four values.
        let base = [1.0f32, 2.0, 3.0, 4.0];
        let pos = build_positional_embedding(&base, 2, 2, 1, 1, 2, 1).unwrap();
        assert_eq!(pos.as_slice().unwrap(), &[2.5f32, 2.5]);
    }
}