1use std::collections::VecDeque;
2use std::io::BufRead;
3use std::mem::size_of;
4
5use crate::error::{Error, Result};
6use ndarray::{Array1, ArrayViewMut1, ArrayViewMut2};
7
/// Conversion from an iterator into a collection, with an explicit
/// capacity request.
///
/// Implementors build `Self` from the items yielded by `iter` and are
/// handed `capacity` as the number of slots the caller wants reserved.
pub trait FromIteratorWithCapacity<T> {
    /// Build the collection from `iter`, using `capacity` as the
    /// requested amount of pre-allocated space.
    fn from_iter_with_capacity<I: IntoIterator<Item = T>>(iter: I, capacity: usize) -> Self;
}
16
17impl<T> FromIteratorWithCapacity<T> for Vec<T> {
18 fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
19 where
20 I: IntoIterator<Item = T>,
21 {
22 let mut v = Vec::with_capacity(capacity);
23 v.extend(iter);
24 v
25 }
26}
27
28impl<T> FromIteratorWithCapacity<T> for VecDeque<T> {
29 fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
30 where
31 I: IntoIterator<Item = T>,
32 {
33 let mut v = VecDeque::with_capacity(capacity);
34 v.extend(iter);
35 v
36 }
37}
38
39pub trait CollectWithCapacity {
42 type Item;
43
44 fn collect_with_capacity<B>(self, capacity: usize) -> B
46 where
47 B: FromIteratorWithCapacity<Self::Item>;
48}
49
50impl<I, T> CollectWithCapacity for I
51where
52 I: Iterator<Item = T>,
53{
54 type Item = T;
55
56 fn collect_with_capacity<B>(self, capacity: usize) -> B
57 where
58 B: FromIteratorWithCapacity<Self::Item>,
59 {
60 B::from_iter_with_capacity(self, capacity)
61 }
62}
63
/// Number of padding bytes between `pos` and the next multiple of
/// `size_of::<T>()`.
///
/// The result is `size_of::<T>() - (pos % size_of::<T>())`. For a
/// non-zero-sized `T` it is therefore never 0: when `pos` is already a
/// multiple of the size, a full `size_of::<T>()` bytes is returned. That
/// behaviour is intentionally left untouched here, since code outside this
/// block may compute offsets with the same formula (NOTE(review): confirm
/// against callers before changing it).
///
/// For a zero-sized `T` the result is 0; the remainder operation would
/// otherwise panic with a division by zero.
pub fn padding<T>(pos: u64) -> u64 {
    let size = size_of::<T>() as u64;

    // Zero-sized types have nothing to align; avoid `pos % 0`.
    if size == 0 {
        return 0;
    }

    size - (pos % size)
}
68
69pub fn l2_normalize(mut v: ArrayViewMut1<f32>) -> f32 {
70 let norm = v.dot(&v).sqrt();
71
72 if norm != 0. {
73 v /= norm;
74 }
75
76 norm
77}
78
79pub fn l2_normalize_array(mut v: ArrayViewMut2<f32>) -> Array1<f32> {
80 let mut norms = Vec::with_capacity(v.nrows());
81 for embedding in v.outer_iter_mut() {
82 norms.push(l2_normalize(embedding));
83 }
84
85 norms.into()
86}
87
88pub fn read_number(reader: &mut dyn BufRead, delim: u8) -> Result<usize> {
89 let field_str = read_string(reader, delim, false)?;
90 field_str
91 .parse()
92 .map_err(|e| {
93 Error::Format(format!(
94 "Cannot parse shape component '{}': {}",
95 field_str, e
96 ))
97 })
98 .map_err(Error::from)
99}
100
101pub fn read_string(reader: &mut dyn BufRead, delim: u8, lossy: bool) -> Result<String> {
102 let mut buf = Vec::new();
103 reader
104 .read_until(delim, &mut buf)
105 .map_err(|e| Error::read_error("Cannot read string", e))?;
106 buf.pop();
107
108 let s = if lossy {
109 String::from_utf8_lossy(&buf).into_owned()
110 } else {
111 String::from_utf8(buf)
112 .map_err(|e| Error::Format(format!("Token contains invalid UTF-8: {}", e)))?
113 };
114
115 Ok(s)
116}