use std::collections::VecDeque;
use std::io::BufRead;
use std::mem::size_of;
use crate::error::{Error, Result};
use ndarray::{Array1, ArrayViewMut1, ArrayViewMut2};
/// Construction of a collection from an iterator together with an explicit
/// capacity request.
///
/// This mirrors the shape of `FromIterator`, but the caller additionally
/// passes the `capacity` that the implementation should use when allocating.
pub trait FromIteratorWithCapacity<T> {
    /// Builds `Self` from the items of `iter`, using `capacity` as the
    /// requested allocation size.
    fn from_iter_with_capacity<I: IntoIterator<Item = T>>(iter: I, capacity: usize) -> Self;
}
impl<T> FromIteratorWithCapacity<T> for Vec<T> {
    /// Allocates a `Vec` with room for `capacity` elements and then appends
    /// every item produced by `iter`.
    fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let mut items: Self = Self::with_capacity(capacity);
        // `extend` is kept (rather than a push loop) so growth beyond
        // `capacity` follows the same reservation strategy as before.
        Extend::extend(&mut items, iter);
        items
    }
}
impl<T> FromIteratorWithCapacity<T> for VecDeque<T> {
    /// Allocates a `VecDeque` with room for `capacity` elements and then
    /// appends every item produced by `iter` at the back.
    fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let mut queue: Self = Self::with_capacity(capacity);
        // `extend` is kept (rather than a push loop) so growth beyond
        // `capacity` follows the same reservation strategy as before.
        Extend::extend(&mut queue, iter);
        queue
    }
}
/// Extension trait that adds a capacity-aware counterpart of `collect`.
pub trait CollectWithCapacity {
    /// The type of the items that are collected.
    type Item;

    /// Consumes `self` and collects its items into a `B`, passing `capacity`
    /// on to `B`'s [`FromIteratorWithCapacity`] implementation.
    fn collect_with_capacity<B: FromIteratorWithCapacity<Self::Item>>(self, capacity: usize) -> B;
}
/// Blanket implementation: every iterator can be collected with an explicit
/// capacity.
impl<I> CollectWithCapacity for I
where
    I: Iterator,
{
    type Item = <I as Iterator>::Item;

    /// Hands the iterator and `capacity` to the target collection's
    /// `from_iter_with_capacity` constructor.
    fn collect_with_capacity<B>(self, capacity: usize) -> B
    where
        B: FromIteratorWithCapacity<Self::Item>,
    {
        <B as FromIteratorWithCapacity<Self::Item>>::from_iter_with_capacity(self, capacity)
    }
}
/// Returns the distance in bytes from `pos` to the next multiple of
/// `size_of::<T>()` that is strictly greater than `pos`.
///
/// Note that a `pos` which is already a multiple of the size yields a full
/// `size_of::<T>()`, not `0`.
///
/// # Panics
///
/// Panics if `T` is zero-sized, because the remainder is then taken modulo
/// zero.
pub fn padding<T>(pos: u64) -> u64 {
    let width = size_of::<T>() as u64;
    let offset_in_unit = pos % width;
    width - offset_in_unit
}
/// Scales `v` in place to unit Euclidean length and returns the length it
/// had before scaling.
///
/// A vector whose norm is zero is left untouched, which avoids a division
/// by zero.
pub fn l2_normalize(mut v: ArrayViewMut1<f32>) -> f32 {
    let squared_length = v.dot(&v);
    let norm = squared_length.sqrt();

    if norm == 0. {
        return norm;
    }

    v /= norm;
    norm
}
pub fn l2_normalize_array(mut v: ArrayViewMut2<f32>) -> Array1<f32> {
let mut norms = Vec::with_capacity(v.nrows());
for embedding in v.outer_iter_mut() {
norms.push(l2_normalize(embedding));
}
norms.into()
}
pub fn read_number(reader: &mut dyn BufRead, delim: u8) -> Result<usize> {
let field_str = read_string(reader, delim, false)?;
field_str
.parse()
.map_err(|e| {
Error::Format(format!(
"Cannot parse shape component '{}': {}",
field_str, e
))
})
.map_err(Error::from)
}
pub fn read_string(reader: &mut dyn BufRead, delim: u8, lossy: bool) -> Result<String> {
let mut buf = Vec::new();
reader
.read_until(delim, &mut buf)
.map_err(|e| Error::read_error("Cannot read string", e))?;
buf.pop();
let s = if lossy {
String::from_utf8_lossy(&buf).into_owned()
} else {
String::from_utf8(buf)
.map_err(|e| Error::Format(format!("Token contains invalid UTF-8: {}", e)))?
};
Ok(s)
}