1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
use std::ops::{Deref, DerefMut};

use bytes::Buf;
use snafu::ensure;

use crate::codec;
use crate::descriptors::TypePos;
use crate::errors::{self, DecodeError};
use crate::queryable::DescriptorMismatch;
use crate::queryable::{Decoder, DescriptorContext, Queryable};
use crate::serialization::decode::queryable::scalars::check_scalar;

/// Client-side value of the `ext::pgvector::vector` type.
///
/// A thin newtype over the vector's components: the wrapped `Vec<f32>` is
/// public, so it can be built with `Vector(components)` and unwrapped via `.0`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "with-serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Vector(pub Vec<f32>);

impl Deref for Vector {
    type Target = Vec<f32>;
    fn deref(&self) -> &Vec<f32> {
        &self.0
    }
}

impl DerefMut for Vector {
    fn deref_mut(&mut self) -> &mut Vec<f32> {
        &mut self.0
    }
}

impl Queryable for Vector {
    /// Decodes a pgvector `vector` from its binary representation.
    ///
    /// Layout, read with `bytes::Buf`'s big-endian getters:
    /// a `u16` element count, a `u16` reserved field (read and ignored),
    /// then that many IEEE-754 `f32` values.
    ///
    /// # Errors
    ///
    /// - `Underflow` if `buf` is shorter than the 4-byte header, or shorter
    ///   than the number of elements the header declares.
    /// - `ExtraData` if bytes remain after the last declared element.
    fn decode(_decoder: &Decoder, mut buf: &[u8]) -> Result<Self, DecodeError> {
        ensure!(buf.remaining() >= 4, errors::Underflow);
        let length = buf.get_u16() as usize;
        let _reserved = buf.get_u16();
        // `length` originates from a `u16`, so `length * 4` cannot overflow.
        ensure!(buf.remaining() >= length * 4, errors::Underflow);
        let vec = (0..length).map(|_| buf.get_f32()).collect();
        // A well-formed value ends exactly after its last element; leftover
        // bytes mean the payload is malformed and must not be silently skipped.
        ensure!(buf.remaining() == 0, errors::ExtraData);
        Ok(Vector(vec))
    }

    /// Checks that the descriptor at `type_pos` is the scalar type
    /// `ext::pgvector::vector`, by delegating to `check_scalar` with the
    /// `codec::PGVECTOR_VECTOR` type id and that type name.
    fn check_descriptor(
        ctx: &DescriptorContext,
        type_pos: TypePos,
    ) -> Result<(), DescriptorMismatch> {
        check_scalar(
            ctx,
            type_pos,
            codec::PGVECTOR_VECTOR,
            "ext::pgvector::vector",
        )
    }
}