use crate::{as_vector::AsVectorRef, error::Error, iter::VecSpaceIter, vector::Vector};
use ahash::AHashMap;
use order_struct::{float_ord::FloatOrd, OrderVal};
use std::slice::Iter;
/// A collection of `f32` vectors that all share the same dimension, each
/// labelled with a term.
///
/// All vector components live in one contiguous buffer: the vector at
/// position `i` occupies `vec_data[i * dimension..(i + 1) * dimension]` and its
/// term is `words[i]`.
///
/// The derived `PartialEq` also compares `term_map`. Two spaces holding
/// identical vectors therefore compare unequal if only one of them has a term map.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct VecSpace {
/// Components of all stored vectors, concatenated in insertion order.
vec_data: Vec<f32>,
/// Term of each stored vector, indexed by position.
words: Vec<String>,
/// Number of components every stored vector must have.
dimension: usize,
/// Optional term -> position index, kept in sync by `VecSpace`'s own methods.
/// This field is public, so code that writes to it directly is responsible
/// for keeping it consistent with the stored terms.
pub term_map: Option<AHashMap<String, u32>>,
}
impl VecSpace {
    /// Creates an empty space in which every vector must have exactly
    /// `dimension` components.
    ///
    /// No term map is built. Call [`VecSpace::with_termmap`] to enable hashed
    /// term lookups.
    #[inline]
    pub fn new(dimension: usize) -> Self {
        Self {
            vec_data: vec![],
            words: vec![],
            dimension,
            term_map: None,
        }
    }

    /// Enables the term -> position map used by [`VecSpace::find_term`] and
    /// indexes every term already stored.
    ///
    /// Subsequent [`VecSpace::insert`] calls keep the map up to date.
    #[inline]
    pub fn with_termmap(mut self) -> Self {
        // `index_terms` creates the map if needed and handles an empty space,
        // so no special-casing is required here.
        self.index_terms();
        self
    }

    /// Returns the number of stored vectors.
    #[inline]
    pub fn len(&self) -> usize {
        self.words.len()
    }

    /// Returns `true` if the space holds no vectors.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of components every vector in this space has.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dimension
    }

    /// Releases unused capacity of the term list, the component buffer and,
    /// if present, the term map.
    pub fn shrink_to_fit(&mut self) {
        self.words.shrink_to_fit();
        self.vec_data.shrink_to_fit();
        if let Some(term_map) = self.term_map.as_mut() {
            term_map.shrink_to_fit();
        }
    }

    /// Returns the summed capacities of the term list (in terms), the
    /// component buffer (in `f32`s) and the term map (in entries, `0` if absent).
    ///
    /// The three capacities are counted in different units. The sum is only a
    /// rough indicator of reserved space, not a byte count.
    pub fn total_cap(&self) -> usize {
        self.words.capacity()
            + self.vec_data.capacity()
            + self.term_map.as_ref().map(|i| i.capacity()).unwrap_or(0)
    }

    /// Reserves room for at least `additional` more vectors in the term list,
    /// the component buffer and, if present, the term map.
    ///
    /// # Panics
    ///
    /// Panics on capacity overflow, like [`Vec::reserve`].
    pub fn reserve(&mut self, additional: usize) {
        self.words.reserve(additional);
        // Saturate so an oversized request reaches `Vec`'s own capacity-overflow
        // panic in every build profile instead of silently wrapping in release.
        self.vec_data
            .reserve(additional.saturating_mul(self.dimension));
        if let Some(term_map) = self.term_map.as_mut() {
            term_map.reserve(additional);
        }
    }

    /// Returns a [`VecSpaceIter`] over this space.
    #[inline]
    pub fn iter(&self) -> VecSpaceIter {
        VecSpaceIter::new(self)
    }

    /// Returns an iterator over the stored terms in insertion order.
    #[inline]
    pub fn terms(&self) -> Iter<'_, String> {
        self.words.iter()
    }

    /// Appends `vec` to the end of the space.
    ///
    /// A repeated term is stored as a separate entry. When a term map is
    /// present, it points at the most recently inserted entry.
    ///
    /// # Errors
    ///
    /// Returns `Error::DimMismatch(vector_dim, space_dim)` if `vec` does not
    /// have exactly [`VecSpace::dim`] components. In that case the space is
    /// left unchanged.
    pub fn insert<'v, 't, R: AsVectorRef<'v, 't>>(&mut self, vec: R) -> Result<(), Error> {
        let vec = vec.as_vec_ref();
        if vec.dim() != self.dimension {
            return Err(Error::DimMismatch(vec.dim(), self.dim()));
        }
        if let Some(term_map) = self.term_map.as_mut() {
            // NOTE(review): positions are stored as `u32`, so a space with more
            // than `u32::MAX` vectors would record truncated positions here.
            term_map.insert(vec.term().to_string(), self.words.len() as u32);
        }
        self.vec_data.extend_from_slice(vec.data());
        self.words.push(vec.term().to_string());
        Ok(())
    }

    /// Returns the vector stored at position `pos`, or `None` if `pos` is out
    /// of range.
    pub fn get(&self, pos: usize) -> Option<Vector> {
        // Bounds-check `pos` before any index arithmetic, and keep the
        // arithmetic checked. A huge `pos` (or an inconsistent, e.g.
        // deserialized, `dimension`) then yields `None` instead of
        // overflowing `pos * dimension`.
        let word = self.words.get(pos)?;
        let start = pos.checked_mul(self.dimension)?;
        let end = start.checked_add(self.dimension)?;
        let vec_data = self.vec_data.get(start..end)?;
        Some(Vector::new(vec_data, word))
    }

    /// Scores every stored vector with `sim` and keeps at most `k` of them,
    /// together with their scores, in a `priority_container::PrioContainerMax`.
    ///
    /// Presumably the result holds the highest-scoring vectors first. That
    /// order relies on the container's iteration order, which is reversed
    /// here; confirm against `priority_container`.
    pub fn top_k<S>(&self, k: usize, sim: S) -> Vec<(f32, Vector)>
    where
        S: Fn(&Vector) -> f32,
    {
        // Answer `k == 0` directly rather than relying on how the container
        // treats a zero capacity.
        if k == 0 {
            return Vec::new();
        }
        let mut cont = priority_container::PrioContainerMax::new(k);
        let vectors = (0..self.len()).map(|pos| {
            self.get(pos)
                .expect("every position below len() holds a vector")
        });
        for v in vectors {
            let s = sim(&v);
            cont.insert(OrderVal::new(v, FloatOrd(s)));
        }
        let mut res: Vec<_> = cont
            .into_iter()
            .map(|i| (i.0.ord().0, i.0.into_inner()))
            .collect();
        res.reverse();
        res
    }

    /// Looks up the vector stored under `term`.
    ///
    /// Uses the term map when one is present (see [`VecSpace::with_termmap`]).
    /// Otherwise it falls back to a linear scan over the stored terms. If
    /// `term` was inserted more than once, the most recently inserted vector
    /// is returned in both cases.
    #[inline]
    pub fn find_term<S: AsRef<str>>(&self, term: S) -> Option<Vector> {
        self.get(self.find_term_idx(term.as_ref())?)
    }

    /// Removes every stored vector.
    ///
    /// Allocated capacity is kept, and an existing term map is emptied but
    /// stays enabled.
    pub fn clear(&mut self) {
        self.vec_data.clear();
        self.words.clear();
        if let Some(term_map) = self.term_map.as_mut() {
            term_map.clear();
        }
    }

    /// Returns the position of `term`.
    ///
    /// Uses the term map when present. Otherwise it scans the term list from
    /// the back, so the result matches the map's last-insert-wins semantics.
    #[inline]
    fn find_term_idx(&self, term: &str) -> Option<usize> {
        match self.term_map.as_ref() {
            Some(map) => map.get(term).map(|&i| i as usize),
            None => self.words.iter().rposition(|w| w.as_str() == term),
        }
    }

    /// Rebuilds the term map from the term list and leaves it enabled.
    ///
    /// The current map's allocation is reused if there is one. Later
    /// duplicates overwrite earlier ones.
    fn index_terms(&mut self) {
        let mut map = self.term_map.take().unwrap_or_default();
        map.clear();
        map.reserve(self.words.len());
        for (pos, term) in self.words.iter().enumerate() {
            map.insert(term.clone(), pos as u32);
        }
        self.term_map = Some(map);
    }
}
impl<'v, 't, V> Extend<V> for VecSpace
where
    V: AsVectorRef<'v, 't>,
{
    /// Inserts each vector produced by `iter`, in order, via [`VecSpace::insert`].
    ///
    /// # Panics
    ///
    /// Panics on the first vector whose dimension differs from
    /// [`VecSpace::dim`]. Vectors yielded before it remain inserted.
    fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
        iter.into_iter().for_each(|item| {
            let vector = item.as_vec_ref();
            let inserted = self.insert(vector).is_ok();
            if !inserted {
                panic!(
                    "Tried to insert a {} dimensional vec into a space with {} dimensions",
                    vector.dim(),
                    self.dim(),
                );
            }
        });
    }
}
#[cfg(test)]
mod test {
    use super::VecSpace;
    use crate::vector::Vector;

    /// Three fixed 3-dimensional fixture vectors with distinct terms.
    fn get_vectors() -> [Vector<'static, 'static>; 3] {
        [
            Vector::new(&[1.0, 0.07, 23.1], "a"),
            Vector::new(&[0.13, 3.19, 3.12], "b"),
            Vector::new(&[3.193, 3.1, 32.1], "c"),
        ]
    }

    /// A 3-dimensional space filled with the fixture vectors (no term map).
    fn get_space() -> VecSpace {
        let mut space = VecSpace::new(3);
        space.extend(get_vectors().iter());
        space
    }

    /// Asserts that every fixture vector can be found in `space` by its term.
    fn assert_all_findable(space: &VecSpace) {
        for expected in get_vectors() {
            let found = space.find_term(expected.term()).unwrap();
            assert_eq!(found, expected);
        }
    }

    #[test]
    fn test_space_get() {
        let space = get_space();
        get_vectors()
            .iter()
            .enumerate()
            .for_each(|(pos, expected)| {
                assert_eq!(space.get(pos).unwrap(), *expected);
            });
    }

    #[test]
    fn test_space_find() {
        // Term map built after the vectors are already stored.
        assert_all_findable(&get_space().with_termmap());

        // Term map enabled before any vector is inserted.
        let mut space = VecSpace::new(3).with_termmap();
        space.extend(get_vectors());
        assert_all_findable(&space);
    }
}