1use std::fmt;
16
17use serde::{Deserialize, Serialize};
18
19use crate::error::{VectorError, VectorResult};
20
/// Dense `f32` embedding implemented as a thin wrapper around
/// [`pgvector::Vector`].
///
/// `Clone` and `PartialEq` are derived, so both delegate to the wrapped
/// vector's own implementations.
#[derive(Clone, PartialEq)]
pub struct Embedding {
    /// The wrapped pgvector value; every accessor on this type delegates to it.
    inner: pgvector::Vector,
}
49
50impl Embedding {
51 pub fn new(dimensions: Vec<f32>) -> Self {
53 Self {
54 inner: pgvector::Vector::from(dimensions),
55 }
56 }
57
58 pub fn from_slice(slice: &[f32]) -> Self {
60 Self {
61 inner: pgvector::Vector::from(slice.to_vec()),
62 }
63 }
64
65 pub fn zeros(dimensions: usize) -> Self {
67 Self::new(vec![0.0; dimensions])
68 }
69
70 pub fn try_new(dimensions: Vec<f32>) -> VectorResult<Self> {
76 if dimensions.is_empty() {
77 return Err(VectorError::EmptyVector);
78 }
79 Ok(Self::new(dimensions))
80 }
81
82 pub fn validate_dimensions(&self, expected: usize) -> VectorResult<()> {
88 let actual = self.len();
89 if actual != expected {
90 return Err(VectorError::dimension_mismatch(expected, actual));
91 }
92 Ok(())
93 }
94
95 pub fn len(&self) -> usize {
97 self.as_slice().len()
98 }
99
100 pub fn is_empty(&self) -> bool {
102 self.as_slice().is_empty()
103 }
104
105 pub fn as_slice(&self) -> &[f32] {
107 self.inner.as_slice()
108 }
109
110 pub fn to_vec(&self) -> Vec<f32> {
112 self.as_slice().to_vec()
113 }
114
115 pub fn into_inner(self) -> pgvector::Vector {
117 self.inner
118 }
119
120 pub fn inner(&self) -> &pgvector::Vector {
122 &self.inner
123 }
124
125 pub fn l2_norm(&self) -> f32 {
127 self.as_slice().iter().map(|x| x * x).sum::<f32>().sqrt()
128 }
129
130 pub fn normalize(&self) -> Option<Self> {
134 let norm = self.l2_norm();
135 if norm == 0.0 {
136 return None;
137 }
138 let normalized: Vec<f32> = self.as_slice().iter().map(|x| x / norm).collect();
139 Some(Self::new(normalized))
140 }
141
142 pub fn dot_product(&self, other: &Self) -> VectorResult<f32> {
148 if self.len() != other.len() {
149 return Err(VectorError::dimension_mismatch(self.len(), other.len()));
150 }
151 Ok(self
152 .as_slice()
153 .iter()
154 .zip(other.as_slice().iter())
155 .map(|(a, b)| a * b)
156 .sum())
157 }
158
159 pub fn cosine_similarity(&self, other: &Self) -> VectorResult<f32> {
167 let dot = self.dot_product(other)?;
168 let norm_a = self.l2_norm();
169 let norm_b = other.l2_norm();
170
171 if norm_a == 0.0 || norm_b == 0.0 {
172 return Ok(0.0);
173 }
174
175 Ok(dot / (norm_a * norm_b))
176 }
177
178 pub fn l2_distance(&self, other: &Self) -> VectorResult<f32> {
184 if self.len() != other.len() {
185 return Err(VectorError::dimension_mismatch(self.len(), other.len()));
186 }
187 Ok(self
188 .as_slice()
189 .iter()
190 .zip(other.as_slice().iter())
191 .map(|(a, b)| (a - b) * (a - b))
192 .sum::<f32>()
193 .sqrt())
194 }
195
196 pub fn to_sql_literal(&self) -> String {
200 let nums: Vec<String> = self.as_slice().iter().map(|f| f.to_string()).collect();
201 format!("'[{}]'::vector", nums.join(","))
202 }
203
204 pub fn to_sql_literal_typed(&self) -> String {
208 let nums: Vec<String> = self.as_slice().iter().map(|f| f.to_string()).collect();
209 format!("'[{}]'::vector({})", nums.join(","), self.len())
210 }
211}
212
213impl fmt::Debug for Embedding {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 write!(f, "Embedding({:?})", self.as_slice())
216 }
217}
218
219impl fmt::Display for Embedding {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 let nums: Vec<String> = self.as_slice().iter().map(|x| format!("{x:.4}")).collect();
222 write!(f, "[{}]", nums.join(", "))
223 }
224}
225
226impl From<Vec<f32>> for Embedding {
227 fn from(v: Vec<f32>) -> Self {
228 Self::new(v)
229 }
230}
231
232impl From<&[f32]> for Embedding {
233 fn from(s: &[f32]) -> Self {
234 Self::from_slice(s)
235 }
236}
237
238impl From<pgvector::Vector> for Embedding {
239 fn from(v: pgvector::Vector) -> Self {
240 Self { inner: v }
241 }
242}
243
244impl From<Embedding> for pgvector::Vector {
245 fn from(e: Embedding) -> Self {
246 e.inner
247 }
248}
249
250impl From<Embedding> for Vec<f32> {
251 fn from(e: Embedding) -> Self {
252 e.to_vec()
253 }
254}
255
256impl Serialize for Embedding {
257 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: serde::Serializer,
260 {
261 self.as_slice().serialize(serializer)
262 }
263}
264
265impl<'de> Deserialize<'de> for Embedding {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: serde::Deserializer<'de>,
269 {
270 let v = Vec::<f32>::deserialize(deserializer)?;
271 Ok(Self::new(v))
272 }
273}
274
/// Sparse embedding implemented as a thin wrapper around
/// [`pgvector::SparseVector`].
///
/// `Clone` and `PartialEq` are derived, so both delegate to the wrapped
/// vector's own implementations.
#[derive(Clone, PartialEq)]
pub struct SparseEmbedding {
    /// The wrapped pgvector sparse vector; every accessor delegates to it.
    inner: pgvector::SparseVector,
}
299
300impl SparseEmbedding {
301 pub fn from_dense(values: Vec<f32>) -> Self {
305 Self {
306 inner: pgvector::SparseVector::from_dense(&values),
307 }
308 }
309
310 pub fn from_parts(indices: &[i32], values: &[f32], dimensions: usize) -> VectorResult<Self> {
317 if indices.len() != values.len() {
318 return Err(VectorError::InvalidDimensions(format!(
319 "indices length ({}) must match values length ({})",
320 indices.len(),
321 values.len()
322 )));
323 }
324
325 for &idx in indices {
326 if idx < 0 || idx as usize >= dimensions {
327 return Err(VectorError::InvalidDimensions(format!(
328 "index {idx} out of bounds for {dimensions} dimensions"
329 )));
330 }
331 }
332
333 let mut dense = vec![0.0f32; dimensions];
335 for (&idx, &val) in indices.iter().zip(values.iter()) {
336 dense[idx as usize] = val;
337 }
338 Ok(Self::from_dense(dense))
339 }
340
341 pub fn dimensions(&self) -> i32 {
343 self.inner.dimensions()
344 }
345
346 pub fn indices(&self) -> &[i32] {
348 self.inner.indices()
349 }
350
351 pub fn values(&self) -> &[f32] {
353 self.inner.values()
354 }
355
356 pub fn nnz(&self) -> usize {
358 self.inner.indices().len()
359 }
360
361 pub fn to_dense(&self) -> Vec<f32> {
363 self.inner.to_vec()
364 }
365
366 pub fn into_inner(self) -> pgvector::SparseVector {
368 self.inner
369 }
370
371 pub fn inner(&self) -> &pgvector::SparseVector {
373 &self.inner
374 }
375}
376
377impl fmt::Debug for SparseEmbedding {
378 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379 write!(
380 f,
381 "SparseEmbedding(dims={}, nnz={})",
382 self.dimensions(),
383 self.nnz()
384 )
385 }
386}
387
388impl fmt::Display for SparseEmbedding {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 write!(f, "sparse[dims={}, nnz={}]", self.dimensions(), self.nnz())
391 }
392}
393
394impl From<pgvector::SparseVector> for SparseEmbedding {
395 fn from(v: pgvector::SparseVector) -> Self {
396 Self { inner: v }
397 }
398}
399
400impl From<SparseEmbedding> for pgvector::SparseVector {
401 fn from(e: SparseEmbedding) -> Self {
402 e.inner
403 }
404}
405
406impl Serialize for SparseEmbedding {
407 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
408 where
409 S: serde::Serializer,
410 {
411 self.to_dense().serialize(serializer)
413 }
414}
415
416impl<'de> Deserialize<'de> for SparseEmbedding {
417 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
418 where
419 D: serde::Deserializer<'de>,
420 {
421 let v = Vec::<f32>::deserialize(deserializer)?;
422 Ok(Self::from_dense(v))
423 }
424}
425
/// Bit vector implemented as a thin wrapper around [`pgvector::Bit`].
///
/// `Clone` and `PartialEq` are derived, so both delegate to the wrapped
/// value's own implementations.
#[derive(Clone, PartialEq)]
pub struct BinaryVector {
    /// The wrapped pgvector bit string; every accessor delegates to it.
    inner: pgvector::Bit,
}
447
448impl BinaryVector {
449 pub fn from_bools(bits: &[bool]) -> Self {
451 Self {
452 inner: pgvector::Bit::new(bits),
453 }
454 }
455
456 pub fn from_bytes(bytes: &[u8]) -> Self {
460 Self {
461 inner: pgvector::Bit::from_bytes(bytes),
462 }
463 }
464
465 pub fn len(&self) -> usize {
467 self.inner.len()
468 }
469
470 pub fn is_empty(&self) -> bool {
472 self.inner.len() == 0
473 }
474
475 pub fn as_bytes(&self) -> &[u8] {
477 self.inner.as_bytes()
478 }
479
480 pub fn into_inner(self) -> pgvector::Bit {
482 self.inner
483 }
484
485 pub fn inner(&self) -> &pgvector::Bit {
487 &self.inner
488 }
489}
490
491impl fmt::Debug for BinaryVector {
492 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493 write!(f, "BinaryVector(len={})", self.len())
494 }
495}
496
497impl fmt::Display for BinaryVector {
498 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
499 write!(f, "bit[{}]", self.len())
500 }
501}
502
503impl From<pgvector::Bit> for BinaryVector {
504 fn from(v: pgvector::Bit) -> Self {
505 Self { inner: v }
506 }
507}
508
509impl From<BinaryVector> for pgvector::Bit {
510 fn from(e: BinaryVector) -> Self {
511 e.inner
512 }
513}
514
/// Half-precision embedding implemented as a thin wrapper around
/// [`pgvector::HalfVector`].
///
/// Only compiled when the `halfvec` feature is enabled. `Clone` and
/// `PartialEq` are derived, so both delegate to the wrapped vector.
#[cfg(feature = "halfvec")]
#[derive(Clone, PartialEq)]
pub struct HalfEmbedding {
    /// The wrapped pgvector half-precision vector; accessors delegate to it.
    inner: pgvector::HalfVector,
}
538
#[cfg(feature = "halfvec")]
impl HalfEmbedding {
    /// Builds a half-precision embedding from `f32` values via
    /// `pgvector::HalfVector::from_f32_slice`.
    pub fn from_f32_slice(values: &[f32]) -> Self {
        let inner = pgvector::HalfVector::from_f32_slice(values);
        Self { inner }
    }

    /// Number of components in the embedding.
    pub fn len(&self) -> usize {
        self.inner.as_slice().len()
    }

    /// Returns `true` when the embedding has no components.
    pub fn is_empty(&self) -> bool {
        self.inner.as_slice().is_empty()
    }

    /// Borrows the components as a slice of `half::f16`.
    pub fn as_slice(&self) -> &[half::f16] {
        self.inner.as_slice()
    }

    /// Consumes the wrapper and returns the underlying `pgvector::HalfVector`.
    pub fn into_inner(self) -> pgvector::HalfVector {
        let Self { inner } = self;
        inner
    }

    /// Borrows the underlying `pgvector::HalfVector`.
    pub fn inner(&self) -> &pgvector::HalfVector {
        &self.inner
    }
}
575
/// Debug output is `HalfEmbedding(len=<length>)`; components are not printed.
#[cfg(feature = "halfvec")]
impl fmt::Debug for HalfEmbedding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.len();
        write!(f, "HalfEmbedding(len={len})")
    }
}
582
/// Display output is `halfvec[<length>]`.
#[cfg(feature = "halfvec")]
impl fmt::Display for HalfEmbedding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.len();
        write!(f, "halfvec[{len}]")
    }
}
589
/// Wraps an existing `pgvector::HalfVector` without touching its contents.
#[cfg(feature = "halfvec")]
impl From<pgvector::HalfVector> for HalfEmbedding {
    fn from(inner: pgvector::HalfVector) -> Self {
        Self { inner }
    }
}
596
/// Unwraps the embedding into the `pgvector::HalfVector` it holds.
#[cfg(feature = "halfvec")]
impl From<HalfEmbedding> for pgvector::HalfVector {
    fn from(embedding: HalfEmbedding) -> Self {
        let HalfEmbedding { inner } = embedding;
        inner
    }
}
603
#[cfg(test)]
mod tests {
    //! Unit tests for the embedding wrapper types defined in this file.

    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    // --- Embedding: construction ---

    #[test]
    fn test_embedding_new() {
        let e = Embedding::new(vec![0.1, 0.2, 0.3]);
        assert_eq!(e.len(), 3);
        assert!(!e.is_empty());
    }

    #[test]
    fn test_embedding_from_slice() {
        let e = Embedding::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(e.len(), 4);
        assert_eq!(e.as_slice().first().copied(), Some(1.0));
    }

    #[test]
    fn test_embedding_zeros() {
        let e = Embedding::zeros(5);
        assert_eq!(e.len(), 5);
        assert!(e.as_slice().iter().all(|&component| component == 0.0));
    }

    #[test]
    fn test_embedding_try_new_empty() {
        assert!(Embedding::try_new(Vec::new()).is_err());
    }

    #[test]
    fn test_embedding_try_new_valid() {
        let e = Embedding::try_new(vec![1.0, 2.0]).expect("non-empty input must be accepted");
        assert_eq!(e.len(), 2);
    }

    #[test]
    fn test_embedding_validate_dimensions() {
        let e = Embedding::new(vec![1.0, 2.0, 3.0]);
        assert!(e.validate_dimensions(3).is_ok());
        assert!(e.validate_dimensions(5).is_err());
    }

    // --- Embedding: math ---

    #[test]
    fn test_embedding_l2_norm() {
        let e = Embedding::new(vec![3.0, 4.0]);
        assert!(close(e.l2_norm(), 5.0));
    }

    #[test]
    fn test_embedding_normalize() {
        let unit = Embedding::new(vec![3.0, 4.0]).normalize().unwrap();
        assert!(close(unit.l2_norm(), 1.0));
    }

    #[test]
    fn test_embedding_normalize_zero() {
        assert!(Embedding::zeros(3).normalize().is_none());
    }

    #[test]
    fn test_embedding_dot_product() {
        let lhs = Embedding::new(vec![1.0, 2.0, 3.0]);
        let rhs = Embedding::new(vec![4.0, 5.0, 6.0]);
        assert!(close(lhs.dot_product(&rhs).unwrap(), 32.0));
    }

    #[test]
    fn test_embedding_dot_product_dimension_mismatch() {
        let short = Embedding::new(vec![1.0, 2.0]);
        let long = Embedding::new(vec![1.0, 2.0, 3.0]);
        assert!(short.dot_product(&long).is_err());
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let lhs = Embedding::new(vec![1.0, 0.0]);
        let rhs = Embedding::new(vec![1.0, 0.0]);
        assert!(close(lhs.cosine_similarity(&rhs).unwrap(), 1.0));
    }

    #[test]
    fn test_embedding_cosine_similarity_orthogonal() {
        let x_axis = Embedding::new(vec![1.0, 0.0]);
        let y_axis = Embedding::new(vec![0.0, 1.0]);
        assert!(close(x_axis.cosine_similarity(&y_axis).unwrap(), 0.0));
    }

    #[test]
    fn test_embedding_l2_distance() {
        let origin = Embedding::new(vec![0.0, 0.0]);
        let point = Embedding::new(vec![3.0, 4.0]);
        assert!(close(origin.l2_distance(&point).unwrap(), 5.0));
    }

    // --- Embedding: formatting and conversions ---

    #[test]
    fn test_embedding_to_sql_literal() {
        let literal = Embedding::new(vec![0.1, 0.2, 0.3]).to_sql_literal();
        assert!(literal.contains("::vector"));
        assert!(literal.contains("0.1"));
    }

    #[test]
    fn test_embedding_to_sql_literal_typed() {
        let literal = Embedding::new(vec![0.1, 0.2, 0.3]).to_sql_literal_typed();
        assert!(literal.contains("::vector(3)"));
    }

    #[test]
    fn test_embedding_display() {
        let rendered = Embedding::new(vec![0.1, 0.2]).to_string();
        assert!(rendered.contains("0.1000"));
    }

    #[test]
    fn test_embedding_from_vec() {
        let e = Embedding::from(vec![1.0, 2.0, 3.0]);
        assert_eq!(e.len(), 3);
    }

    #[test]
    fn test_embedding_to_vec() {
        let e = Embedding::new(vec![1.0, 2.0, 3.0]);
        let components = Vec::<f32>::from(e);
        assert_eq!(components, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_embedding_serde_roundtrip() {
        let original = Embedding::new(vec![0.1, 0.2, 0.3]);
        let json = serde_json::to_string(&original).unwrap();
        let restored: Embedding = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_embedding_pgvector_roundtrip() {
        let original = Embedding::new(vec![1.0, 2.0, 3.0]);
        let raw = pgvector::Vector::from(original.clone());
        let restored = Embedding::from(raw);
        assert_eq!(original, restored);
    }

    // --- SparseEmbedding ---

    #[test]
    fn test_sparse_embedding_from_dense() {
        let s = SparseEmbedding::from_dense(vec![1.0, 0.0, 2.0, 0.0, 3.0]);
        assert_eq!(s.dimensions(), 5);
        assert_eq!(s.nnz(), 3);
    }

    #[test]
    fn test_sparse_embedding_from_parts() {
        let s = SparseEmbedding::from_parts(&[0, 2, 4], &[1.0, 2.0, 3.0], 5).unwrap();
        assert_eq!(s.dimensions(), 5);
        assert_eq!(s.nnz(), 3);
    }

    #[test]
    fn test_sparse_embedding_from_parts_mismatched() {
        assert!(SparseEmbedding::from_parts(&[0, 2], &[1.0], 5).is_err());
    }

    #[test]
    fn test_sparse_embedding_from_parts_out_of_bounds() {
        assert!(SparseEmbedding::from_parts(&[10], &[1.0], 5).is_err());
    }

    #[test]
    fn test_sparse_embedding_to_dense() {
        let s = SparseEmbedding::from_dense(vec![1.0, 0.0, 2.0]);
        assert_eq!(s.to_dense(), vec![1.0, 0.0, 2.0]);
    }

    #[test]
    fn test_sparse_embedding_serde_roundtrip() {
        let original = SparseEmbedding::from_dense(vec![1.0, 0.0, 2.0]);
        let json = serde_json::to_string(&original).unwrap();
        let restored: SparseEmbedding = serde_json::from_str(&json).unwrap();
        assert_eq!(original.to_dense(), restored.to_dense());
    }

    // --- BinaryVector ---

    #[test]
    fn test_binary_vector_from_bools() {
        let bits = BinaryVector::from_bools(&[true, false, true, true]);
        assert_eq!(bits.len(), 4);
        assert!(!bits.is_empty());
    }

    #[test]
    fn test_binary_vector_from_bytes() {
        let bits = BinaryVector::from_bytes(&[0b10110000]);
        assert_eq!(bits.len(), 8);
    }

    #[test]
    fn test_binary_vector_display() {
        let bits = BinaryVector::from_bools(&[true, false, true]);
        assert!(bits.to_string().contains('3'));
    }

    #[test]
    fn test_binary_vector_pgvector_roundtrip() {
        let original = BinaryVector::from_bools(&[true, false, true, false]);
        let raw = pgvector::Bit::from(original.clone());
        let restored = BinaryVector::from(raw);
        assert_eq!(original, restored);
    }
}