use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
#[derive(Debug, Clone, thiserror::Error)]
pub enum SparseError {
#[error("Length mismatch: indices and values must have the same length")]
LengthMismatch,
#[error("Index out of bounds: index {0} >= dimension {1}")]
IndexOutOfBounds(u32, u32),
#[error("Parse error: {0}")]
ParseError(String),
#[error("Invalid format: expected '{{idx:val, ...}}'")]
InvalidFormat,
#[error("Empty sparse vector")]
EmptyVector,
}
#[derive(Debug, Clone, Serialize, Deserialize, PostgresType)]
#[inoutfuncs]
pub struct SparseVec {
indices: Vec<u32>,
values: Vec<f32>,
dim: u32,
}
impl SparseVec {
pub fn new(indices: Vec<u32>, values: Vec<f32>, dim: u32) -> Result<Self, SparseError> {
if indices.len() != values.len() {
return Err(SparseError::LengthMismatch);
}
if indices.is_empty() {
return Ok(Self {
indices: Vec::new(),
values: Vec::new(),
dim,
});
}
let mut pairs: Vec<_> = indices.into_iter().zip(values.into_iter()).collect();
pairs.sort_by_key(|(i, _)| *i);
pairs.dedup_by_key(|(i, _)| *i);
let (indices, values): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();
if let Some(&max_idx) = indices.last() {
if max_idx >= dim {
return Err(SparseError::IndexOutOfBounds(max_idx, dim));
}
}
Ok(Self {
indices,
values,
dim,
})
}
#[inline]
pub fn nnz(&self) -> usize {
self.indices.len()
}
#[inline]
pub fn dim(&self) -> u32 {
self.dim
}
#[inline]
pub fn get(&self, index: u32) -> f32 {
match self.indices.binary_search(&index) {
Ok(pos) => self.values[pos],
Err(_) => 0.0,
}
}
pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
self.indices
.iter()
.copied()
.zip(self.values.iter().copied())
}
#[inline]
pub fn indices(&self) -> &[u32] {
&self.indices
}
#[inline]
pub fn values(&self) -> &[f32] {
&self.values
}
pub fn norm(&self) -> f32 {
self.values.iter().map(|&v| v * v).sum::<f32>().sqrt()
}
pub fn l1_norm(&self) -> f32 {
self.values.iter().map(|v| v.abs()).sum()
}
pub fn prune(&mut self, threshold: f32) {
let pairs: Vec<_> = self
.indices
.iter()
.copied()
.zip(self.values.iter().copied())
.filter(|(_, v)| v.abs() >= threshold)
.collect();
self.indices = pairs.iter().map(|(i, _)| *i).collect();
self.values = pairs.iter().map(|(_, v)| *v).collect();
}
pub fn top_k(&self, k: usize) -> Self {
if k >= self.nnz() {
return self.clone();
}
let mut indexed: Vec<_> = self
.indices
.iter()
.copied()
.zip(self.values.iter().copied())
.collect();
indexed.sort_by(|(_, a), (_, b)| b.abs().partial_cmp(&a.abs()).unwrap());
indexed.truncate(k);
indexed.sort_by_key(|(i, _)| *i);
let (indices, values): (Vec<_>, Vec<_>) = indexed.into_iter().unzip();
Self {
indices,
values,
dim: self.dim,
}
}
pub fn to_dense(&self) -> Vec<f32> {
let mut dense = vec![0.0; self.dim as usize];
for (idx, val) in self.iter() {
dense[idx as usize] = val;
}
dense
}
}
impl FromStr for SparseVec {
type Err = SparseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.trim();
if !s.starts_with('{') || !s.ends_with('}') {
return Err(SparseError::InvalidFormat);
}
let s = &s[1..s.len() - 1];
if s.trim().is_empty() {
return Ok(Self {
indices: Vec::new(),
values: Vec::new(),
dim: 0,
});
}
let mut indices = Vec::new();
let mut values = Vec::new();
let mut max_index = 0u32;
for pair in s.split(',') {
let parts: Vec<_> = pair.trim().split(':').collect();
if parts.len() != 2 {
return Err(SparseError::ParseError(format!(
"Invalid pair format: '{}'",
pair
)));
}
let idx: u32 = parts[0]
.trim()
.parse()
.map_err(|_| SparseError::ParseError(format!("Invalid index: '{}'", parts[0])))?;
let val: f32 = parts[1]
.trim()
.parse()
.map_err(|_| SparseError::ParseError(format!("Invalid value: '{}'", parts[1])))?;
indices.push(idx);
values.push(val);
max_index = max_index.max(idx);
}
Self::new(indices, values, max_index + 1)
}
}
impl fmt::Display for SparseVec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{{")?;
for (i, (idx, val)) in self.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}:{}", idx, val)?;
}
write!(f, "}}")
}
}
impl pgrx::InOutFuncs for SparseVec {
fn input(input: &core::ffi::CStr) -> Self {
let s = input.to_str().unwrap_or("");
s.parse().unwrap_or_else(|_| Self {
indices: Vec::new(),
values: Vec::new(),
dim: 0,
})
}
fn output(&self, buffer: &mut pgrx::StringInfo) {
buffer.push_str(&format!("{}", self));
}
}
#[cfg(feature = "pg_test")]
#[pg_schema]
mod tests {
use super::*;
#[test]
fn test_sparse_vec_creation() {
let sparse = SparseVec::new(vec![0, 2, 5], vec![1.0, 2.0, 3.0], 10).unwrap();
assert_eq!(sparse.nnz(), 3);
assert_eq!(sparse.dim(), 10);
assert_eq!(sparse.get(0), 1.0);
assert_eq!(sparse.get(2), 2.0);
assert_eq!(sparse.get(5), 3.0);
assert_eq!(sparse.get(1), 0.0);
}
#[test]
fn test_sparse_vec_sorted() {
let sparse = SparseVec::new(vec![5, 0, 2], vec![3.0, 1.0, 2.0], 10).unwrap();
assert_eq!(sparse.indices(), &[0, 2, 5]);
assert_eq!(sparse.values(), &[1.0, 2.0, 3.0]);
}
#[test]
fn test_sparse_vec_dedup() {
let sparse = SparseVec::new(vec![0, 2, 2, 5], vec![1.0, 2.0, 3.0, 4.0], 10).unwrap();
assert_eq!(sparse.nnz(), 3);
assert_eq!(sparse.get(2), 3.0); }
#[test]
fn test_sparse_vec_norm() {
let sparse = SparseVec::new(vec![0, 1, 2], vec![3.0, 4.0, 0.0], 10).unwrap();
assert_eq!(sparse.norm(), 5.0); }
#[test]
fn test_sparse_vec_parse() {
let sparse: SparseVec = "{1:0.5, 2:0.3, 5:0.8}".parse().unwrap();
assert_eq!(sparse.nnz(), 3);
assert_eq!(sparse.get(1), 0.5);
assert_eq!(sparse.get(2), 0.3);
assert_eq!(sparse.get(5), 0.8);
}
#[test]
fn test_sparse_vec_display() {
let sparse = SparseVec::new(vec![1, 2, 5], vec![0.5, 0.3, 0.8], 10).unwrap();
let s = format!("{}", sparse);
assert_eq!(s, "{1:0.5, 2:0.3, 5:0.8}");
}
#[test]
fn test_sparse_vec_prune() {
let mut sparse = SparseVec::new(vec![0, 1, 2, 3], vec![0.1, 0.5, 0.05, 0.8], 10).unwrap();
sparse.prune(0.2);
assert_eq!(sparse.nnz(), 2);
assert_eq!(sparse.get(1), 0.5);
assert_eq!(sparse.get(3), 0.8);
}
#[test]
fn test_sparse_vec_top_k() {
let sparse = SparseVec::new(vec![0, 1, 2, 3], vec![0.1, 0.5, 0.05, 0.8], 10).unwrap();
let top2 = sparse.top_k(2);
assert_eq!(top2.nnz(), 2);
assert!(top2.indices().contains(&1));
assert!(top2.indices().contains(&3));
}
#[pg_test]
fn pg_test_sparse_vec_type() {
let sparse = SparseVec::new(vec![0, 2, 5], vec![1.0, 2.0, 3.0], 10).unwrap();
assert_eq!(sparse.nnz(), 3);
}
}