pgvector/
halfvec.rs

1use half::f16;
2
3#[cfg(feature = "diesel")]
4use crate::diesel_ext::halfvec::HalfVectorType;
5
6#[cfg(feature = "diesel")]
7use diesel::{deserialize::FromSqlRow, expression::AsExpression};
8
9/// A half vector.
10#[derive(Clone, Debug, PartialEq)]
11#[cfg_attr(feature = "diesel", derive(FromSqlRow, AsExpression))]
12#[cfg_attr(feature = "diesel", diesel(sql_type = HalfVectorType))]
13pub struct HalfVector(pub(crate) Vec<f16>);
14
15impl From<Vec<f16>> for HalfVector {
16    fn from(v: Vec<f16>) -> Self {
17        HalfVector(v)
18    }
19}
20
21impl From<HalfVector> for Vec<f16> {
22    fn from(val: HalfVector) -> Self {
23        val.0
24    }
25}
26
27impl HalfVector {
28    /// Creates a half vector from a `f32` slice.
29    pub fn from_f32_slice(slice: &[f32]) -> HalfVector {
30        HalfVector(slice.iter().map(|v| f16::from_f32(*v)).collect())
31    }
32
33    /// Returns a copy of the half vector as a `Vec<f16>`.
34    pub fn to_vec(&self) -> Vec<f16> {
35        self.0.clone()
36    }
37
38    /// Returns the half vector as a slice.
39    pub fn as_slice(&self) -> &[f16] {
40        self.0.as_slice()
41    }
42
43    #[cfg(any(feature = "postgres", feature = "sqlx", feature = "diesel"))]
44    pub(crate) fn from_sql(
45        buf: &[u8],
46    ) -> Result<HalfVector, Box<dyn std::error::Error + Sync + Send>> {
47        let dim = u16::from_be_bytes(buf[0..2].try_into()?).into();
48        let unused = u16::from_be_bytes(buf[2..4].try_into()?);
49        if unused != 0 {
50            return Err("expected unused to be 0".into());
51        }
52
53        let mut vec = Vec::with_capacity(dim);
54        for i in 0..dim {
55            let s = 4 + 2 * i;
56            vec.push(f16::from_be_bytes(buf[s..s + 2].try_into()?));
57        }
58
59        Ok(HalfVector(vec))
60    }
61}
62
#[cfg(test)]
mod tests {
    use crate::HalfVector;
    use half::f16;

    /// Sample input shared by every test.
    const SAMPLE: [f32; 3] = [1.0, 2.0, 3.0];

    /// Converts `f32` values to the `f16` values the tests expect.
    fn halves(values: &[f32]) -> Vec<f16> {
        values.iter().copied().map(f16::from_f32).collect()
    }

    #[test]
    fn test_into() {
        let expected = halves(&SAMPLE);
        let half_vector = HalfVector::from(expected.clone());
        let round_trip: Vec<f16> = half_vector.into();
        assert_eq!(round_trip, expected);
    }

    #[test]
    fn test_to_vec() {
        let half_vector = HalfVector::from_f32_slice(&SAMPLE);
        assert_eq!(half_vector.to_vec(), halves(&SAMPLE));
    }

    #[test]
    fn test_as_slice() {
        let half_vector = HalfVector::from_f32_slice(&SAMPLE);
        let expected = halves(&SAMPLE);
        assert_eq!(half_vector.as_slice(), expected.as_slice());
    }
}