use std::ops::{Deref, DerefMut};
/// An `f32` embedding vector: a thin newtype around `Vec<f32>`.
///
/// The inner vector is public and the wrapper dereferences to it, so all `Vec`/slice
/// methods are available directly on a `Vector`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Vector(pub Vec<f32>);
impl Vector {
    /// Wraps `items` as a [`Vector`] (the `Vec` is moved, not copied).
    #[must_use]
    pub fn new(items: Vec<f32>) -> Self {
        Self(items)
    }

    /// Consumes the wrapper and returns the underlying `Vec<f32>`.
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Renders the vector as a bracketed, comma-separated literal such as `[1,2.5,-3]`.
    ///
    /// Elements are formatted with `f32`'s `Display` impl, so whole numbers print without
    /// a fractional part (`1.0` -> `1`) and non-finite values print as `NaN` / `inf`.
    #[must_use]
    pub fn to_pg_text(&self) -> String {
        use std::fmt::Write as _;

        let mut s = String::with_capacity(self.0.len() * 4 + 2);
        s.push('[');
        for (i, v) in self.0.iter().enumerate() {
            if i > 0 {
                s.push(',');
            }
            // Format straight into `s` rather than allocating a temporary `String` per
            // element. Writing into a `String` never fails, so the `fmt::Result` is ignored.
            let _ = write!(s, "{v}");
        }
        s.push(']');
        s
    }
}
impl Deref for Vector {
type Target = Vec<f32>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Vector {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl From<Vec<f32>> for Vector {
fn from(v: Vec<f32>) -> Self {
Self(v)
}
}
impl From<Vector> for Vec<f32> {
fn from(v: Vector) -> Self {
v.0
}
}
/// Collects any iterator of `f32` into a [`Vector`], preserving order.
impl FromIterator<f32> for Vector {
    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
        let items: Vec<f32> = iter.into_iter().collect();
        Self(items)
    }
}
/// Serializes exactly like the wrapped `Vec<f32>`, i.e. as a plain sequence of numbers.
impl serde::Serialize for Vector {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serde::Serialize::serialize(&self.0, serializer)
    }
}
/// Deserializes from whatever a `Vec<f32>` accepts (e.g. a plain JSON array of numbers).
impl<'de> serde::Deserialize<'de> for Vector {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let items = <Vec<f32> as serde::Deserialize<'de>>::deserialize(deserializer)?;
        Ok(Self(items))
    }
}
impl From<Vector> for crate::core::SqlValue {
fn from(v: Vector) -> Self {
crate::core::SqlValue::Vector(v.0)
}
}
/// Appends `v` to `buf` in pgvector's binary layout: a big-endian `int16` dimension,
/// an `int16` written as zero, then every element as a big-endian `float4`.
///
/// The dimension comes from a truncating `as` cast, so callers must keep
/// `v.len() <= i16::MAX as usize`; a longer slice would get a wrong header.
#[cfg(feature = "postgres")]
fn encode_pgvector(v: &[f32], buf: &mut Vec<u8>) {
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    let dim_be = (v.len() as i16).to_be_bytes();
    buf.extend_from_slice(&dim_be);
    // Second header field: always zero here (`decode_pgvector_binary` ignores it).
    buf.extend_from_slice(&[0u8; 2]);
    for element_be in v.iter().map(|x| x.to_be_bytes()) {
        buf.extend_from_slice(&element_be);
    }
}
/// Decodes pgvector's binary wire format into a `Vec<f32>`.
///
/// Layout: big-endian `int16` dimension, a second `int16` that is ignored, then
/// `dim` big-endian `float4` values.
///
/// # Errors
///
/// Returns an error if the 4-byte header is missing, if the header dimension is
/// negative, or if the payload is not exactly `dim * 4` bytes long.
#[cfg(feature = "postgres")]
fn decode_pgvector_binary(bytes: &[u8]) -> Result<Vec<f32>, sqlx::error::BoxDynError> {
    if bytes.len() < 4 {
        return Err("pgvector binary value too short (missing header)".into());
    }
    let raw_dim = i16::from_be_bytes([bytes[0], bytes[1]]);
    // A plain `as usize` would sign-extend a negative header into a huge value, and the
    // `dim * 4` length check below would then overflow (a panic in debug builds).
    let dim = usize::try_from(raw_dim)
        .map_err(|_| format!("pgvector binary value: negative header dim {raw_dim}"))?;
    let body = &bytes[4..];
    if body.len() != dim * 4 {
        return Err(format!(
            "pgvector binary value: header dim {dim} != {} float4 payload bytes",
            body.len()
        )
        .into());
    }
    Ok(body
        .chunks_exact(4)
        .map(|chunk| f32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect())
}
/// Parses pgvector's text form (e.g. `[1,2.5,-3]`) into a `Vec<f32>`.
///
/// Surrounding whitespace is removed, and so is any run of leading `[` or trailing `]`
/// characters. A blank interior yields an empty vector. Each comma-separated element is
/// trimmed before it is parsed.
///
/// # Errors
///
/// Returns the boxed `ParseFloatError` for the first element that is not a valid `f32`.
#[cfg(feature = "postgres")]
fn decode_pgvector_text(s: &str) -> Result<Vec<f32>, sqlx::error::BoxDynError> {
    let body = s
        .trim()
        .trim_start_matches('[')
        .trim_end_matches(']');
    if body.trim().is_empty() {
        return Ok(Vec::new());
    }
    let mut values = Vec::new();
    for piece in body.split(',') {
        let value: f32 = piece.trim().parse()?;
        values.push(value);
    }
    Ok(values)
}
/// Declares [`Vector`] as the pgvector `vector` SQL type on PostgreSQL.
#[cfg(feature = "postgres")]
impl sqlx::Type<sqlx::Postgres> for Vector {
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        sqlx::postgres::PgTypeInfo::with_name("vector")
    }

    /// Accepts any column whose type name is `vector`, compared ASCII-case-insensitively.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        use sqlx::TypeInfo as _;
        let type_name = ty.name();
        type_name.eq_ignore_ascii_case("vector")
    }
}
/// Binds a [`Vector`] as a PostgreSQL parameter using pgvector's binary format.
#[cfg(feature = "postgres")]
impl sqlx::Encode<'_, sqlx::Postgres> for Vector {
    /// # Errors
    ///
    /// Fails if the vector has more than `i16::MAX` elements. The binary header stores the
    /// dimension as a big-endian `int16`, and `encode_pgvector` would otherwise truncate a
    /// larger length, producing a header that disagrees with the payload.
    fn encode_by_ref(
        &self,
        buf: &mut sqlx::postgres::PgArgumentBuffer,
    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
        if i16::try_from(self.0.len()).is_err() {
            return Err(format!(
                "pgvector binary encoding: {} elements exceed the int16 dimension header (max {})",
                self.0.len(),
                i16::MAX
            )
            .into());
        }
        // 4 header bytes + 4 bytes per float4; bounded by the check above.
        let mut bytes = Vec::with_capacity(4 + self.0.len() * 4);
        encode_pgvector(&self.0, &mut bytes);
        buf.extend_from_slice(&bytes);
        Ok(sqlx::encode::IsNull::No)
    }
}
/// Reads a pgvector column into a [`Vector`], in either wire format.
#[cfg(feature = "postgres")]
impl sqlx::Decode<'_, sqlx::Postgres> for Vector {
    /// Dispatches on the format the server used for this value: binary payloads go to
    /// `decode_pgvector_binary`, text payloads to `decode_pgvector_text`.
    fn decode(value: sqlx::postgres::PgValueRef<'_>) -> Result<Self, sqlx::error::BoxDynError> {
        use sqlx::postgres::PgValueFormat;

        let items = match value.format() {
            PgValueFormat::Binary => {
                let raw = value.as_bytes()?;
                decode_pgvector_binary(raw)?
            }
            PgValueFormat::Text => {
                let text = value.as_str()?;
                decode_pgvector_text(text)?
            }
        };
        Ok(Self(items))
    }
}
/// Maps [`Vector`] onto the same MySQL column type as `Vec<u8>`.
///
/// Decoding a `Vector` on MySQL always fails; see this type's `Decode<MySql>` impl.
#[cfg(feature = "mysql")]
impl sqlx::Type<sqlx::MySql> for Vector {
    fn type_info() -> sqlx::mysql::MySqlTypeInfo {
        <Vec<u8> as sqlx::Type<sqlx::MySql>>::type_info()
    }

    /// Accepts exactly the column types `Vec<u8>` accepts, consistent with `type_info`.
    ///
    /// Without this override, sqlx's default `compatible` matches `type_info()` exactly.
    /// Other binary column types would then fail sqlx's pre-decode type check with a
    /// generic "mismatched types" error instead of the explanatory `Decode` error.
    fn compatible(ty: &sqlx::mysql::MySqlTypeInfo) -> bool {
        <Vec<u8> as sqlx::Type<sqlx::MySql>>::compatible(ty)
    }
}
/// Rejects decoding a [`Vector`] from MySQL; vectors are PostgreSQL/pgvector-only.
#[cfg(feature = "mysql")]
impl sqlx::Decode<'_, sqlx::MySql> for Vector {
    /// Always returns an error explaining that `Vector` is unsupported on MySQL.
    fn decode(_value: sqlx::mysql::MySqlValueRef<'_>) -> Result<Self, sqlx::error::BoxDynError> {
        const UNSUPPORTED: &str =
            "`Vector` columns are PostgreSQL/pgvector-only; cannot decode on MySQL (issue #824)";
        Err(UNSUPPORTED.into())
    }
}
/// Maps [`Vector`] onto the same SQLite column type as `Vec<u8>`.
///
/// Decoding a `Vector` on SQLite always fails; see this type's `Decode<Sqlite>` impl.
#[cfg(feature = "sqlite")]
impl sqlx::Type<sqlx::Sqlite> for Vector {
    fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
        <Vec<u8> as sqlx::Type<sqlx::Sqlite>>::type_info()
    }

    /// Accepts exactly the column types `Vec<u8>` accepts, consistent with `type_info`.
    ///
    /// This lets compatible columns reach the explanatory `Decode` error rather than
    /// failing sqlx's pre-decode type check with a generic mismatch.
    fn compatible(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool {
        <Vec<u8> as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
    }
}
/// Rejects decoding a [`Vector`] from SQLite; vectors are PostgreSQL/pgvector-only.
#[cfg(feature = "sqlite")]
impl sqlx::Decode<'_, sqlx::Sqlite> for Vector {
    /// Always returns an error explaining that `Vector` is unsupported on SQLite.
    fn decode(_value: sqlx::sqlite::SqliteValueRef<'_>) -> Result<Self, sqlx::error::BoxDynError> {
        const UNSUPPORTED: &str =
            "`Vector` columns are PostgreSQL/pgvector-only; cannot decode on SQLite (issue #824)";
        Err(UNSUPPORTED.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deref_from_vec_and_text() {
        let v: Vector = vec![1.0, 2.0, 3.0].into();
        assert_eq!(v.len(), 3);
        assert_eq!(v.to_pg_text(), "[1,2,3]");
        assert_eq!(v.into_inner(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn pg_text_empty_and_fractional() {
        assert_eq!(Vector::default().to_pg_text(), "[]");
        assert_eq!(Vector::new(vec![0.5, -1.25]).to_pg_text(), "[0.5,-1.25]");
    }

    #[test]
    fn from_iterator_collects_in_order() {
        let v: Vector = vec![1.0f32, 2.0].into_iter().collect();
        assert_eq!(v, Vector(vec![1.0, 2.0]));
    }

    #[test]
    fn serde_round_trips_as_plain_array() {
        let v = Vector::new(vec![0.5, 1.5]);
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(json, "[0.5,1.5]");
        let back: Vector = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn into_sqlvalue_vector() {
        let sv: crate::core::SqlValue = Vector(vec![1.0, 2.0]).into();
        match sv {
            crate::core::SqlValue::Vector(items) => assert_eq!(items, vec![1.0, 2.0]),
            _ => panic!("expected SqlValue::Vector"),
        }
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn pgvector_binary_round_trip() {
        let mut buf = Vec::new();
        encode_pgvector(&[1.0, -2.5, 3.0], &mut buf);
        assert_eq!(decode_pgvector_binary(&buf).unwrap(), vec![1.0, -2.5, 3.0]);
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn pgvector_binary_empty_round_trip() {
        let mut buf = Vec::new();
        encode_pgvector(&[], &mut buf);
        assert_eq!(buf, vec![0u8, 0, 0, 0]);
        assert!(decode_pgvector_binary(&buf).unwrap().is_empty());
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn pgvector_binary_rejects_malformed() {
        // Shorter than the 4-byte header.
        assert!(decode_pgvector_binary(&[0, 1]).is_err());
        // Header claims 2 dimensions but only one float4 follows.
        let mut buf = vec![0u8, 2, 0, 0];
        buf.extend_from_slice(&1.0f32.to_be_bytes());
        assert!(decode_pgvector_binary(&buf).is_err());
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn pgvector_text_parse() {
        assert_eq!(
            decode_pgvector_text("[1,2,3]").unwrap(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(decode_pgvector_text("[]").unwrap(), Vec::<f32>::new());
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn pgvector_text_whitespace_and_garbage() {
        assert_eq!(
            decode_pgvector_text(" [ 1.5 , -2 ] ").unwrap(),
            vec![1.5, -2.0]
        );
        assert_eq!(decode_pgvector_text("[ ]").unwrap(), Vec::<f32>::new());
        assert!(decode_pgvector_text("[1,x]").is_err());
    }
}