finalfusion/
util.rs

1use std::collections::VecDeque;
2use std::io::BufRead;
3use std::mem::size_of;
4
5use crate::error::{Error, Result};
6use ndarray::{Array1, ArrayViewMut1, ArrayViewMut2};
7
/// Build a collection from an iterator, creating the collection with
/// a caller-supplied capacity.
pub trait FromIteratorWithCapacity<T> {
    /// Consume `iter` and return a collection holding its items; the
    /// collection is created with the given `capacity`.
    fn from_iter_with_capacity<I: IntoIterator<Item = T>>(iter: I, capacity: usize) -> Self;
}
16
17impl<T> FromIteratorWithCapacity<T> for Vec<T> {
18    fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
19    where
20        I: IntoIterator<Item = T>,
21    {
22        let mut v = Vec::with_capacity(capacity);
23        v.extend(iter);
24        v
25    }
26}
27
28impl<T> FromIteratorWithCapacity<T> for VecDeque<T> {
29    fn from_iter_with_capacity<I>(iter: I, capacity: usize) -> Self
30    where
31        I: IntoIterator<Item = T>,
32    {
33        let mut v = VecDeque::with_capacity(capacity);
34        v.extend(iter);
35        v
36    }
37}
38
39/// Collect iterms from an `Iterator` into a collection with a
40/// capacity.
41pub trait CollectWithCapacity {
42    type Item;
43
44    /// Transform an iterator into a collection with the given capacity.
45    fn collect_with_capacity<B>(self, capacity: usize) -> B
46    where
47        B: FromIteratorWithCapacity<Self::Item>;
48}
49
50impl<I, T> CollectWithCapacity for I
51where
52    I: Iterator<Item = T>,
53{
54    type Item = T;
55
56    fn collect_with_capacity<B>(self, capacity: usize) -> B
57    where
58        B: FromIteratorWithCapacity<Self::Item>,
59    {
60        B::from_iter_with_capacity(self, capacity)
61    }
62}
63
/// Number of bytes between `pos` and the next multiple of
/// `size_of::<T>()` that is strictly greater than `pos`.
///
/// The result lies in `1..=size_of::<T>()`: when `pos` is already a
/// multiple of the size, a full `size_of::<T>()` is returned rather
/// than zero. NOTE(review): this is kept as-is on purpose, since
/// callers that serialize data presumably depend on the exact value;
/// confirm before changing it.
///
/// For a zero-sized `T` every position is trivially aligned, so `0` is
/// returned instead of panicking on a remainder by zero.
pub fn padding<T>(pos: u64) -> u64 {
    match size_of::<T>() as u64 {
        0 => 0,
        size => size - (pos % size),
    }
}
68
69pub fn l2_normalize(mut v: ArrayViewMut1<f32>) -> f32 {
70    let norm = v.dot(&v).sqrt();
71
72    if norm != 0. {
73        v /= norm;
74    }
75
76    norm
77}
78
79pub fn l2_normalize_array(mut v: ArrayViewMut2<f32>) -> Array1<f32> {
80    let mut norms = Vec::with_capacity(v.nrows());
81    for embedding in v.outer_iter_mut() {
82        norms.push(l2_normalize(embedding));
83    }
84
85    norms.into()
86}
87
88pub fn read_number(reader: &mut dyn BufRead, delim: u8) -> Result<usize> {
89    let field_str = read_string(reader, delim, false)?;
90    field_str
91        .parse()
92        .map_err(|e| {
93            Error::Format(format!(
94                "Cannot parse shape component '{}': {}",
95                field_str, e
96            ))
97        })
98        .map_err(Error::from)
99}
100
101pub fn read_string(reader: &mut dyn BufRead, delim: u8, lossy: bool) -> Result<String> {
102    let mut buf = Vec::new();
103    reader
104        .read_until(delim, &mut buf)
105        .map_err(|e| Error::read_error("Cannot read string", e))?;
106    buf.pop();
107
108    let s = if lossy {
109        String::from_utf8_lossy(&buf).into_owned()
110    } else {
111        String::from_utf8(buf)
112            .map_err(|e| Error::Format(format!("Token contains invalid UTF-8: {}", e)))?
113    };
114
115    Ok(s)
116}