omni_search 0.1.1

A unified Rust SDK for multimodal embedding and similarity search.
Documentation
use std::fmt;
use std::path::PathBuf;

use serde::{Deserialize, Serialize};

use crate::error::Error;

/// Model family that a model bundle belongs to.
///
/// Serde (de)serializes the variants in `snake_case`, i.e. `"fg_clip"` and
/// `"chinese_clip"`. The [`fmt::Display`] spelling of [`ModelFamily::FgClip`]
/// is `"fgclip"` (no underscore). That spelling is therefore also accepted when
/// deserializing, so a value printed with `Display` can be written back into a
/// config and parsed again. Serialization output is unchanged.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelFamily {
    /// The `FgClip` family. Serialized as `"fg_clip"`; `"fgclip"` is accepted as an alias.
    #[serde(alias = "fgclip")]
    FgClip,
    /// The `ChineseClip` family. Serialized and displayed as `"chinese_clip"`.
    ChineseClip,
}

impl fmt::Display for ModelFamily {
    /// Writes the family name: `"fgclip"` or `"chinese_clip"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::FgClip => f.write_str("fgclip"),
            Self::ChineseClip => f.write_str("chinese_clip"),
        }
    }
}

/// Payload-free tag identifying which [`ModelSource`] variant is in use.
///
/// Serialized in `snake_case`; the [`fmt::Display`] output uses that same spelling.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelSourceKind {
    /// Tag for [`ModelSource::LocalBundleDir`].
    LocalBundleDir,
}

impl fmt::Display for ModelSourceKind {
    /// Writes the `snake_case` name of the kind, e.g. `"local_bundle_dir"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ModelSourceKind::LocalBundleDir => "local_bundle_dir",
        };
        f.write_str(label)
    }
}

/// Location a model is loaded from.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelSource {
    /// A bundle directory given by its filesystem path.
    LocalBundleDir(PathBuf),
}

impl ModelSource {
    /// Returns the payload-free [`ModelSourceKind`] tag for this source.
    pub fn kind(&self) -> ModelSourceKind {
        match *self {
            ModelSource::LocalBundleDir(_) => ModelSourceKind::LocalBundleDir,
        }
    }
}

/// Selects which model family to use and where to load it from.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelConfig {
    /// Family of the model.
    pub family: ModelFamily,
    /// Where the model files come from.
    pub source: ModelSource,
}

impl ModelConfig {
    /// Builds a config from an explicit family and source.
    pub fn new(family: ModelFamily, source: ModelSource) -> Self {
        Self { family, source }
    }

    /// Builds a config whose source is [`ModelSource::LocalBundleDir`] at `path`.
    pub fn from_local_bundle(family: ModelFamily, path: impl Into<PathBuf>) -> Self {
        let source = ModelSource::LocalBundleDir(path.into());
        Self::new(family, source)
    }
}

/// Builder for [`RuntimeConfig`].
///
/// The builder works in three steps:
/// - it starts from [`RuntimeConfig::default`], or from a caller-supplied config via
///   [`RuntimeConfigBuilder::from_config`];
/// - chained `&mut self` setters apply overrides;
/// - [`build`](Self::build) runs [`RuntimeConfig::validate`] before handing out a copy.
///
/// Because `build` borrows rather than consumes, one builder can produce several
/// configs.
#[derive(Clone, Debug)]
pub struct RuntimeConfigBuilder {
    config: RuntimeConfig,
}

impl Default for RuntimeConfigBuilder {
    /// Same as [`RuntimeConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RuntimeConfigBuilder {
    /// Creates a builder seeded with [`RuntimeConfig::default`].
    pub fn new() -> Self {
        Self::from_config(RuntimeConfig::default())
    }

    /// Creates a builder seeded with an existing configuration.
    pub fn from_config(config: RuntimeConfig) -> Self {
        Self { config }
    }

    /// Applies `edit` to the wrapped config and returns the builder for chaining.
    fn update(&mut self, edit: impl FnOnce(&mut RuntimeConfig)) -> &mut Self {
        edit(&mut self.config);
        self
    }

    /// Sets `intra_threads`.
    pub fn intra_threads(&mut self, val: usize) -> &mut Self {
        self.update(|cfg| cfg.intra_threads = val)
    }

    /// Sets `inter_threads` to `Some(val)`.
    pub fn inter_threads(&mut self, val: usize) -> &mut Self {
        self.update(|cfg| cfg.inter_threads = Some(val))
    }

    /// Resets `inter_threads` to `None`.
    pub fn clear_inter_threads(&mut self) -> &mut Self {
        self.update(|cfg| cfg.inter_threads = None)
    }

    /// Sets `fgclip_max_patches` to `Some(val)`.
    pub fn fgclip_max_patches(&mut self, val: usize) -> &mut Self {
        self.update(|cfg| cfg.fgclip_max_patches = Some(val))
    }

    /// Resets `fgclip_max_patches` to `None`.
    pub fn clear_fgclip_max_patches(&mut self) -> &mut Self {
        self.update(|cfg| cfg.fgclip_max_patches = None)
    }

    /// Sets `session_policy`.
    pub fn session_policy(&mut self, val: SessionPolicy) -> &mut Self {
        self.update(|cfg| cfg.session_policy = val)
    }

    /// Sets `graph_optimization_level`.
    pub fn graph_optimization_level(&mut self, val: GraphOptimizationLevel) -> &mut Self {
        self.update(|cfg| cfg.graph_optimization_level = val)
    }

    /// Validates the accumulated settings and returns a copy of them.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RuntimeConfig::validate`].
    pub fn build(&mut self) -> Result<RuntimeConfig, Error> {
        self.config.validate().map(|()| self.config.clone())
    }
}

/// Session policy stored in [`RuntimeConfig::session_policy`].
///
/// Serialized in `snake_case` (`"single_active"`, `"keep_both_loaded"`). The legacy
/// PascalCase variant names are still accepted when deserializing.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionPolicy {
    /// `"single_active"`; legacy alias `"SingleActive"`.
    #[serde(alias = "SingleActive")]
    SingleActive,
    /// `"keep_both_loaded"`; legacy alias `"KeepBothLoaded"`.
    #[serde(alias = "KeepBothLoaded")]
    KeepBothLoaded,
}

/// Graph-optimization level stored in [`RuntimeConfig::graph_optimization_level`].
///
/// Serialized in `snake_case` (`"disabled"`, `"basic"`, `"extended"`, `"all"`). The
/// legacy PascalCase variant names are still accepted when deserializing.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GraphOptimizationLevel {
    /// `"disabled"`; legacy alias `"Disabled"`.
    #[serde(alias = "Disabled")]
    Disabled,
    /// `"basic"`; legacy alias `"Basic"`.
    #[serde(alias = "Basic")]
    Basic,
    /// `"extended"`; legacy alias `"Extended"`.
    #[serde(alias = "Extended")]
    Extended,
    /// `"all"`; legacy alias `"All"`.
    #[serde(alias = "All")]
    All,
}

/// Runtime tuning knobs for model execution.
///
/// Build one with [`RuntimeConfig::builder`] or start from [`RuntimeConfig::default`];
/// [`RuntimeConfig::validate`] checks the thread counts.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RuntimeConfig {
    /// Intra-op thread count. Must be greater than 0. Defaults to
    /// [`std::thread::available_parallelism`], or 4 when that cannot be determined.
    pub intra_threads: usize,
    /// Optional inter-op thread count; `None` (the default) leaves it unset.
    /// When set, it must be greater than 0.
    pub inter_threads: Option<usize>,
    /// Optional patch cap; judging by the name, presumably only relevant to
    /// [`ModelFamily::FgClip`] models (TODO confirm). `None` by default.
    pub fgclip_max_patches: Option<usize>,
    /// Session policy; defaults to [`SessionPolicy::SingleActive`].
    pub session_policy: SessionPolicy,
    /// Graph-optimization level; defaults to [`GraphOptimizationLevel::All`].
    pub graph_optimization_level: GraphOptimizationLevel,
}

impl Default for RuntimeConfig {
    /// Uses every available hardware thread for intra-op work (4 if the count cannot
    /// be queried). Leaves the optional settings unset, selects
    /// [`SessionPolicy::SingleActive`] and [`GraphOptimizationLevel::All`].
    fn default() -> Self {
        let intra_threads = match std::thread::available_parallelism() {
            Ok(count) => count.get(),
            Err(_) => 4,
        };

        Self {
            intra_threads,
            inter_threads: None,
            fgclip_max_patches: None,
            session_policy: SessionPolicy::SingleActive,
            graph_optimization_level: GraphOptimizationLevel::All,
        }
    }
}

impl RuntimeConfig {
    /// Returns a [`RuntimeConfigBuilder`] seeded with [`RuntimeConfig::default`].
    pub fn builder() -> RuntimeConfigBuilder {
        RuntimeConfigBuilder::new()
    }

    /// Checks that every numeric setting holds a usable value.
    ///
    /// # Errors
    ///
    /// Returns an invalid-config [`Error`] when any of these hold:
    /// - `intra_threads` is 0;
    /// - `inter_threads` is `Some(0)`;
    /// - `fgclip_max_patches` is `Some(0)`. A zero patch budget is not a usable
    ///   limit, and `None` is already the way to leave the value unset.
    pub fn validate(&self) -> Result<(), Error> {
        if self.intra_threads == 0 {
            return Err(Error::invalid_config(
                "runtime.intra_threads must be greater than 0",
            ));
        }
        if matches!(self.inter_threads, Some(0)) {
            return Err(Error::invalid_config(
                "runtime.inter_threads must be greater than 0 when set",
            ));
        }
        if matches!(self.fgclip_max_patches, Some(0)) {
            return Err(Error::invalid_config(
                "runtime.fgclip_max_patches must be greater than 0 when set",
            ));
        }
        Ok(())
    }
}

/// Top-level configuration: which model to load plus how to run it.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OmniSearchConfig {
    /// Model family and source.
    pub model: ModelConfig,
    /// Runtime tuning knobs.
    pub runtime: RuntimeConfig,
}

impl OmniSearchConfig {
    /// Combines an explicit model config with a runtime config.
    pub fn new(model: ModelConfig, runtime: RuntimeConfig) -> Self {
        Self { model, runtime }
    }

    /// Shorthand for a model loaded from a local bundle directory at `path`.
    pub fn from_local_bundle(
        family: ModelFamily,
        path: impl Into<PathBuf>,
        runtime: RuntimeConfig,
    ) -> Self {
        let model = ModelConfig::from_local_bundle(family, path);
        Self::new(model, runtime)
    }
}

#[cfg(test)]
mod tests {
    use super::{GraphOptimizationLevel, RuntimeConfig, RuntimeConfigBuilder, SessionPolicy};

    #[test]
    fn runtime_builder_uses_defaults_when_fields_are_not_overridden() {
        let expected = RuntimeConfig::default();
        let actual = RuntimeConfigBuilder::new().build().unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn runtime_builder_overrides_selected_fields_only() {
        let actual = RuntimeConfig::builder()
            .intra_threads(2)
            .inter_threads(1)
            .fgclip_max_patches(256)
            .session_policy(SessionPolicy::KeepBothLoaded)
            .graph_optimization_level(GraphOptimizationLevel::Basic)
            .build()
            .unwrap();

        assert_eq!(actual.intra_threads, 2);
        assert_eq!(actual.inter_threads, Some(1));
        assert_eq!(actual.fgclip_max_patches, Some(256));
        assert_eq!(actual.session_policy, SessionPolicy::KeepBothLoaded);
        assert_eq!(
            actual.graph_optimization_level,
            GraphOptimizationLevel::Basic
        );
    }

    #[test]
    fn runtime_builder_can_clear_optional_overrides() {
        let actual = RuntimeConfig::builder()
            .inter_threads(2)
            .clear_inter_threads()
            .fgclip_max_patches(256)
            .clear_fgclip_max_patches()
            .build()
            .unwrap();

        assert_eq!(actual.inter_threads, None);
        assert_eq!(actual.fgclip_max_patches, None);
    }

    #[test]
    fn runtime_builder_rejects_invalid_values() {
        let error = RuntimeConfig::builder()
            .intra_threads(0)
            .build()
            .unwrap_err();
        assert!(
            error
                .to_string()
                .contains("runtime.intra_threads must be greater than 0")
        );
    }

    // `validate` has a separate branch for an explicitly-set zero inter-op thread
    // count; without this test that branch is never exercised.
    #[test]
    fn runtime_builder_rejects_zero_inter_threads() {
        let error = RuntimeConfig::builder()
            .inter_threads(0)
            .build()
            .unwrap_err();
        assert!(
            error
                .to_string()
                .contains("runtime.inter_threads must be greater than 0 when set")
        );
    }

    #[test]
    fn session_policy_deserializes_snake_case_and_legacy_pascal_case() {
        let snake_case: SessionPolicy = serde_json::from_str(r#""keep_both_loaded""#).unwrap();
        let legacy_pascal_case: SessionPolicy =
            serde_json::from_str(r#""KeepBothLoaded""#).unwrap();

        assert_eq!(snake_case, SessionPolicy::KeepBothLoaded);
        assert_eq!(legacy_pascal_case, SessionPolicy::KeepBothLoaded);
    }

    #[test]
    fn graph_optimization_level_deserializes_snake_case_and_legacy_pascal_case() {
        let snake_case: GraphOptimizationLevel = serde_json::from_str(r#""basic""#).unwrap();
        let legacy_pascal_case: GraphOptimizationLevel =
            serde_json::from_str(r#""Basic""#).unwrap();

        assert_eq!(snake_case, GraphOptimizationLevel::Basic);
        assert_eq!(legacy_pascal_case, GraphOptimizationLevel::Basic);
    }
}