use crate::{Error, Result};
use bytemuck::{Pod, Zeroable};
use memmap2::Mmap;
use std::fs::File;
use std::sync::Arc;
/// File signature expected in the header's `magic` field: the ASCII bytes `MVC1` interpreted as a
/// little-endian `u32` (numerically `0x3143_564D`).
const MAGIC: u32 = u32::from_le_bytes(*b"MVC1");
/// Element encoding of the vectors stored after the file header.
///
/// The explicit discriminants are the codes accepted in the header's `data_type` field
/// (see `VectorDataType::from_u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum VectorDataType {
    /// 32-bit floating point, 4 bytes per element.
    F32 = 0,
    /// 16-bit element, 2 bytes each (presumably IEEE half precision — confirm with the writer).
    F16 = 1,
    /// 8-bit element, 1 byte each (presumably signed integer, per the name).
    I8 = 2,
}

impl VectorDataType {
    /// Number of bytes one element of this type occupies in the payload.
    pub const fn element_size(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::I8 => 1,
        }
    }

    /// Decodes a header `data_type` code, returning `None` for codes this version does not know.
    fn from_u32(value: u32) -> Option<Self> {
        match value {
            0 => Some(Self::F32),
            1 => Some(Self::F16),
            2 => Some(Self::I8),
            _ => None,
        }
    }
}
/// Fixed 32-byte header at the start of a vector file.
///
/// It is reinterpreted directly from the file bytes, so fields are decoded in the host's native
/// byte order.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct VectorHeader {
    /// Must equal `MAGIC`.
    magic: u32,
    /// Number of vectors (rows) in the payload.
    vocab_size: u32,
    /// Number of elements in each vector.
    dim: u32,
    /// Element-encoding code, decoded with `VectorDataType::from_u32`.
    data_type: u32,
    /// Not interpreted by this module; pads the header to 32 bytes.
    reserved: [u32; 4],
}

// The on-disk format fixes the header at exactly 32 bytes; fail the build if the layout drifts.
const _: () = assert!(std::mem::size_of::<VectorHeader>() == 32);

// SAFETY: `VectorHeader` is `#[repr(C)]`, `Copy`, `'static`, and made only of `u32` fields, so it
// contains no padding bytes and every bit pattern is a valid value.
unsafe impl Pod for VectorHeader {}
// SAFETY: the all-zero bit pattern is a valid value for every `u32` field.
unsafe impl Zeroable for VectorHeader {}
/// Read-only view over a memory-mapped vector file.
///
/// The mapping is kept alive by `_mmap`, while `data_ptr` addresses the vector payload inside it
/// so lookups avoid re-slicing the mapping.
pub struct VectorStore {
    // Owns the mapping that `data_ptr` points into; held only to keep it alive.
    _mmap: Arc<Mmap>,
    // First byte after the header inside `_mmap`.
    data_ptr: *const u8,
    // Number of elements per vector, from the header.
    dim: usize,
    // Number of vectors in the payload, from the header.
    vocab_size: usize,
    // Element encoding of the payload, from the header.
    data_type: VectorDataType,
}
// SAFETY: `data_ptr` points into the read-only `Mmap` owned by `_mmap` and is only ever read
// through, so moving or sharing a `VectorStore` across threads is no riskier than doing so with
// the `Arc<Mmap>` itself. NOTE(review): this relies on `Mmap` being `Send + Sync`.
unsafe impl Send for VectorStore {}
unsafe impl Sync for VectorStore {}
// Prints only the store's shape; the raw payload pointer and the mapping are left out.
impl std::fmt::Debug for VectorStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("VectorStore");
        builder.field("vocab_size", &self.vocab_size);
        builder.field("dim", &self.dim);
        builder.field("data_type", &self.data_type);
        builder.finish()
    }
}
impl VectorStore {
const HEADER_SIZE: usize = 32;
pub fn from_mmap(mmap: Arc<Mmap>) -> Result<Self> {
let data = &mmap[..];
if data.len() < Self::HEADER_SIZE {
return Err(Error::VectorError(
"Vector file too small for header".to_string(),
));
}
let header: VectorHeader = bytemuck::pod_read_unaligned(&data[0..Self::HEADER_SIZE]);
if header.magic != MAGIC {
return Err(Error::VectorError(format!(
"Invalid magic number: expected 0x{:08X}, got 0x{:08X}",
MAGIC, header.magic
)));
}
let data_type = VectorDataType::from_u32(header.data_type).ok_or_else(|| {
Error::VectorError(format!("Invalid data type: {}", header.data_type))
})?;
let vocab_size = header.vocab_size as usize;
let dim = header.dim as usize;
let element_size = data_type.element_size();
let expected_size = Self::HEADER_SIZE + (vocab_size * dim * element_size);
if data.len() != expected_size {
return Err(Error::VectorError(format!(
"Vector file size mismatch: expected {} bytes, got {}",
expected_size,
data.len()
)));
}
let data_ptr = data[Self::HEADER_SIZE..].as_ptr();
Ok(Self {
_mmap: mmap,
data_ptr,
dim,
vocab_size,
data_type,
})
}
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
let file = File::open(path).map_err(|e| Error::VectorError(e.to_string()))?;
let mmap = unsafe { Mmap::map(&file).map_err(|e| Error::VectorError(e.to_string()))? };
Self::from_mmap(Arc::new(mmap))
}
#[inline]
pub fn get(&self, word_id: u32) -> Option<&[f32]> {
if self.data_type != VectorDataType::F32 {
return None;
}
let word_id = word_id as usize;
if word_id >= self.vocab_size {
return None;
}
let element_size = self.data_type.element_size();
let start = word_id * self.dim * element_size;
unsafe {
let slice =
std::slice::from_raw_parts(self.data_ptr.add(start), self.dim * element_size);
Some(bytemuck::cast_slice(slice))
}
}
#[inline]
pub const fn dim(&self) -> usize {
self.dim
}
#[inline]
pub const fn vocab_size(&self) -> usize {
self.vocab_size
}
#[inline]
pub const fn data_type(&self) -> VectorDataType {
self.data_type
}
pub fn mean_pooling(&self, word_ids: &[u32]) -> Option<Vec<f32>> {
let mut sum = vec![0.0_f32; self.dim];
let mut count = 0;
for &word_id in word_ids {
if let Some(vec) = self.get(word_id) {
for (i, &val) in vec.iter().enumerate() {
sum[i] += val;
}
count += 1;
}
}
if count == 0 {
return None;
}
let count_f32 = count as f32;
for val in sum.iter_mut() {
*val /= count_f32;
}
Some(sum)
}
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f32> {
if a.len() != b.len() {
return None;
}
let mut dot = 0.0_f32;
let mut norm_a = 0.0_f32;
let mut norm_b = 0.0_f32;
for i in 0..a.len() {
dot += a[i] * b[i];
norm_a += a[i] * a[i];
norm_b += b[i] * b[i];
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom < 1e-10 {
return None;
}
Some(dot / denom)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a header (native byte order, matching how `from_mmap` reads it) plus `payload`.
    fn build_store_bytes(vocab_size: u32, dim: u32, data_type: u32, payload: &[u8]) -> Vec<u8> {
        let header = VectorHeader {
            magic: MAGIC,
            vocab_size,
            dim,
            data_type,
            reserved: [0; 4],
        };
        let mut bytes = bytemuck::bytes_of(&header).to_vec();
        bytes.extend_from_slice(payload);
        bytes
    }

    /// Copies `bytes` into a fresh anonymous mapping and returns it as a read-only `Mmap`.
    fn mmap_from_bytes(bytes: &[u8]) -> Arc<Mmap> {
        let mut map = memmap2::MmapMut::map_anon(bytes.len()).expect("anonymous mapping");
        map.copy_from_slice(bytes);
        Arc::new(map.make_read_only().expect("read-only remap"))
    }

    /// Builds a store that the test expects to be valid.
    fn open_valid(bytes: &[u8]) -> VectorStore {
        match VectorStore::from_mmap(mmap_from_bytes(bytes)) {
            Ok(store) => store,
            Err(_) => panic!("valid vector store was rejected"),
        }
    }

    #[test]
    fn test_vector_data_type() {
        assert_eq!(VectorDataType::F32.element_size(), 4);
        assert_eq!(VectorDataType::F16.element_size(), 2);
        assert_eq!(VectorDataType::I8.element_size(), 1);
        assert_eq!(VectorDataType::from_u32(0), Some(VectorDataType::F32));
        assert_eq!(VectorDataType::from_u32(1), Some(VectorDataType::F16));
        assert_eq!(VectorDataType::from_u32(2), Some(VectorDataType::I8));
        assert_eq!(VectorDataType::from_u32(99), None);
    }

    #[test]
    fn test_header_size() {
        assert_eq!(std::mem::size_of::<VectorHeader>(), 32);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((VectorStore::cosine_similarity(&a, &b).unwrap() - 1.0).abs() < 1e-6);
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((VectorStore::cosine_similarity(&a, &b).unwrap() - 0.0).abs() < 1e-6);
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!((VectorStore::cosine_similarity(&a, &b).unwrap() + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths.
        assert_eq!(VectorStore::cosine_similarity(&[1.0, 0.0], &[1.0]), None);
        // Zero-norm vector.
        assert_eq!(VectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), None);
        // Empty input has zero norm as well.
        assert_eq!(VectorStore::cosine_similarity(&[], &[]), None);
    }

    #[test]
    fn test_from_mmap_round_trip() {
        let values: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let store = open_valid(&build_store_bytes(2, 3, 0, bytemuck::cast_slice(&values)));
        assert_eq!(store.vocab_size(), 2);
        assert_eq!(store.dim(), 3);
        assert_eq!(store.data_type(), VectorDataType::F32);
        assert_eq!(store.get(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(store.get(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(store.get(2), None);
        assert_eq!(store.mean_pooling(&[0, 1, 99]), Some(vec![2.5, 3.5, 4.5]));
        assert_eq!(store.mean_pooling(&[99]), None);
        assert_eq!(store.mean_pooling(&[]), None);
    }

    #[test]
    fn test_from_mmap_rejects_malformed_input() {
        // Shorter than the header.
        assert!(VectorStore::from_mmap(mmap_from_bytes(&[0u8; 8])).is_err());
        // Wrong magic number.
        let mut bad_magic = build_store_bytes(0, 0, 0, &[]);
        bad_magic[0] ^= 0xFF;
        assert!(VectorStore::from_mmap(mmap_from_bytes(&bad_magic)).is_err());
        // Unknown data-type code.
        assert!(VectorStore::from_mmap(mmap_from_bytes(&build_store_bytes(0, 0, 99, &[]))).is_err());
        // Header promises 1 x 2 f32 (8 bytes) but only 4 payload bytes follow.
        assert!(
            VectorStore::from_mmap(mmap_from_bytes(&build_store_bytes(1, 2, 0, &[0u8; 4])))
                .is_err()
        );
    }

    #[test]
    fn test_non_f32_store_has_no_f32_view() {
        let store = open_valid(&build_store_bytes(1, 2, 1, &[0u8; 4]));
        assert_eq!(store.data_type(), VectorDataType::F16);
        assert_eq!(store.get(0), None);
        assert_eq!(store.mean_pooling(&[0]), None);
    }
}