pub mod config;
pub mod embedder;
pub mod error;
pub mod model;
pub mod pooling;
pub mod ruvector_integration;
pub mod tokenizer;
#[cfg(feature = "gpu")]
pub mod gpu;
#[cfg(not(feature = "gpu"))]
pub mod gpu {
    //! CPU-only stand-in for the GPU module, compiled when the `gpu` feature is off.
    //!
    //! Provides a settings-free `GpuConfig` and an availability probe that always
    //! reports `false`.

    /// Placeholder GPU configuration. It carries no settings in a CPU-only build.
    #[derive(Debug, Clone, Default)]
    pub struct GpuConfig;

    impl GpuConfig {
        /// Returns the empty configuration. No device probing happens in this build.
        pub fn auto() -> Self {
            GpuConfig
        }

        /// Returns the empty configuration. It is indistinguishable from [`GpuConfig::auto`] here.
        pub fn cpu_only() -> Self {
            GpuConfig
        }
    }

    /// Resolves to `false` unconditionally, because GPU support is compiled out.
    pub async fn is_gpu_available() -> bool {
        false
    }
}
pub use config::{EmbedderConfig, ModelSource, PoolingStrategy};
pub use embedder::{Embedder, EmbedderBuilder, EmbeddingOutput};
pub use error::{EmbeddingError, Result};
pub use model::{OnnxModel, ModelInfo};
pub use pooling::Pooler;
pub use ruvector_integration::{
Distance, IndexConfig, RagPipeline, RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
};
pub use tokenizer::Tokenizer;
#[cfg(feature = "gpu")]
pub use gpu::{
GpuAccelerator, GpuConfig, GpuMode, GpuInfo, GpuBackend,
HybridAccelerator, is_gpu_available,
};
pub mod prelude {
pub use crate::{
Distance, Embedder, EmbedderBuilder, EmbedderConfig, EmbeddingError,
IndexConfig, ModelSource, PoolingStrategy, RagPipeline, Result,
RuVectorBuilder, RuVectorEmbeddings, SearchResult, VectorId,
};
}
/// Catalogue of pretrained sentence-embedding models this crate knows about.
///
/// Each variant maps to an upstream repository id, an embedding dimension and a
/// maximum sequence length via the methods in `impl PretrainedModel`.
/// The default is `AllMiniLmL6V2`.
// NOTE(review): this enum derives serde; non-self-describing formats encode unit
// variants by index, so append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum PretrainedModel {
/// `sentence-transformers/all-MiniLM-L6-v2`: 384 dims, 256-token limit. The default.
#[default]
AllMiniLmL6V2,
/// `sentence-transformers/all-MiniLM-L12-v2`: 384 dims, 256-token limit.
AllMiniLmL12V2,
/// `sentence-transformers/all-mpnet-base-v2`: 768 dims, 384-token limit.
AllMpnetBaseV2,
/// `sentence-transformers/multi-qa-MiniLM-L6-cos-v1`: 384 dims, 256-token limit.
MultiQaMiniLmL6,
/// `sentence-transformers/paraphrase-MiniLM-L6-v2`: 384 dims, 256-token limit.
ParaphraseMiniLmL6V2,
/// `BAAI/bge-small-en-v1.5`: 384 dims, 512-token limit.
BgeSmallEnV15,
/// `intfloat/e5-small-v2`: 384 dims, 512-token limit.
E5SmallV2,
/// `thenlper/gte-small`: 384 dims, 512-token limit.
GteSmall,
}
impl PretrainedModel {
    /// Returns the upstream repository identifier in `owner/name` form,
    /// e.g. `"BAAI/bge-small-en-v1.5"` for [`PretrainedModel::BgeSmallEnV15`].
    pub fn model_id(&self) -> &'static str {
        match *self {
            // Published under the sentence-transformers organisation.
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMiniLmL12V2 => "sentence-transformers/all-MiniLM-L12-v2",
            Self::AllMpnetBaseV2 => "sentence-transformers/all-mpnet-base-v2",
            Self::MultiQaMiniLmL6 => "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
            Self::ParaphraseMiniLmL6V2 => "sentence-transformers/paraphrase-MiniLM-L6-v2",
            // Published by other organisations.
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::E5SmallV2 => "intfloat/e5-small-v2",
            Self::GteSmall => "thenlper/gte-small",
        }
    }

    /// Number of components in each embedding vector produced by this model.
    ///
    /// Only the MPNet base model is 768 wide; every other entry is 384.
    pub fn dimension(&self) -> usize {
        const SMALL: usize = 384;
        const BASE: usize = 768;

        match *self {
            Self::AllMpnetBaseV2 => BASE,
            Self::AllMiniLmL6V2
            | Self::AllMiniLmL12V2
            | Self::MultiQaMiniLmL6
            | Self::ParaphraseMiniLmL6V2
            | Self::BgeSmallEnV15
            | Self::E5SmallV2
            | Self::GteSmall => SMALL,
        }
    }

    /// Maximum input sequence length, in tokens, configured for this model.
    ///
    /// Presumably inputs longer than this are truncated by the tokenizer. Verify this
    /// against the tokenizer/embedder code.
    // NOTE(review): confirm these limits against each model's upstream
    // `sentence_bert_config.json` when adding or updating entries.
    pub fn max_seq_length(&self) -> usize {
        match *self {
            Self::BgeSmallEnV15 | Self::E5SmallV2 | Self::GteSmall => 512,
            Self::AllMpnetBaseV2 => 384,
            Self::AllMiniLmL6V2
            | Self::AllMiniLmL12V2
            | Self::MultiQaMiniLmL6
            | Self::ParaphraseMiniLmL6V2 => 256,
        }
    }

    /// Whether output embeddings should be normalized. This is `true` for every catalogued model.
    pub fn normalize_output(&self) -> bool {
        // The answer is uniform across the catalogue, so the variant is not inspected.
        true
    }
}