1use half::f16;
2
3#[cfg(feature = "diesel")]
4use crate::diesel_ext::halfvec::HalfVectorType;
5
6#[cfg(feature = "diesel")]
7use diesel::{deserialize::FromSqlRow, expression::AsExpression};
8
9#[derive(Clone, Debug, PartialEq)]
11#[cfg_attr(feature = "diesel", derive(FromSqlRow, AsExpression))]
12#[cfg_attr(feature = "diesel", diesel(sql_type = HalfVectorType))]
13pub struct HalfVector(pub(crate) Vec<f16>);
14
15impl From<Vec<f16>> for HalfVector {
16 fn from(v: Vec<f16>) -> Self {
17 HalfVector(v)
18 }
19}
20
21impl From<HalfVector> for Vec<f16> {
22 fn from(val: HalfVector) -> Self {
23 val.0
24 }
25}
26
27impl HalfVector {
28 pub fn from_f32_slice(slice: &[f32]) -> HalfVector {
30 HalfVector(slice.iter().map(|v| f16::from_f32(*v)).collect())
31 }
32
33 pub fn to_vec(&self) -> Vec<f16> {
35 self.0.clone()
36 }
37
38 pub fn as_slice(&self) -> &[f16] {
40 self.0.as_slice()
41 }
42
43 #[cfg(any(feature = "postgres", feature = "sqlx", feature = "diesel"))]
44 pub(crate) fn from_sql(
45 buf: &[u8],
46 ) -> Result<HalfVector, Box<dyn std::error::Error + Sync + Send>> {
47 let dim = u16::from_be_bytes(buf[0..2].try_into()?).into();
48 let unused = u16::from_be_bytes(buf[2..4].try_into()?);
49 if unused != 0 {
50 return Err("expected unused to be 0".into());
51 }
52
53 let mut vec = Vec::with_capacity(dim);
54 for i in 0..dim {
55 let s = 4 + 2 * i;
56 vec.push(f16::from_be_bytes(buf[s..s + 2].try_into()?));
57 }
58
59 Ok(HalfVector(vec))
60 }
61}
62
#[cfg(test)]
mod tests {
    use crate::HalfVector;
    use half::f16;

    /// Converts `f32` literals to `f16` so the expected values stay readable.
    fn halves(values: &[f32]) -> Vec<f16> {
        values.iter().map(|&x| f16::from_f32(x)).collect()
    }

    #[test]
    fn test_into() {
        let vec = HalfVector::from(halves(&[1.0, 2.0, 3.0]));
        let f16_vec: Vec<f16> = vec.into();
        assert_eq!(f16_vec, halves(&[1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_to_vec() {
        let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
        let expected = halves(&[1.0, 2.0, 3.0]);
        assert_eq!(vec.to_vec(), expected);
    }

    #[test]
    fn test_as_slice() {
        let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
        let expected = halves(&[1.0, 2.0, 3.0]);
        assert_eq!(vec.as_slice(), expected.as_slice());
    }
}