Skip to main content

qdrant_edge/shard/operations/
point_ops.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::{Debug, Formatter};
3use std::hash::{Hash, Hasher};
4use std::mem;
5
6use crate::common::validation::validate_multi_vector;
7use itertools::Itertools as _;
8use ordered_float::OrderedFloat;
9use schemars::JsonSchema;
10use crate::segment::common::operation_error::OperationError;
11use crate::segment::common::utils::unordered_hash_unique;
12use crate::segment::data_types::named_vectors::NamedVectors;
13use crate::segment::data_types::segment_record::SegmentRecord;
14use crate::segment::data_types::vectors::{
15    BatchVectorStructInternal, DEFAULT_VECTOR_NAME, DenseVector, MultiDenseVector,
16    MultiDenseVectorInternal, VectorInternal, VectorStructInternal,
17};
18use crate::segment::types::{Filter, Payload, PointIdType, VectorNameBuf};
19use serde::{Deserialize, Serialize};
20use crate::sparse::common::types::{DimId, DimWeight};
21use strum::{EnumDiscriminants, EnumIter};
22use validator::{Validate, ValidationErrors};
23
24/// Defines the mode of the upsert operation
25///
26/// * `Upsert` - default mode, insert new points, update existing points
27/// * `InsertOnly` - only insert new points, do not update existing points
28/// * `UpdateOnly` - only update existing points, do not insert new points
29#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize, Hash)]
30#[serde(rename_all = "snake_case")]
31pub enum UpdateMode {
32    // Default mode - insert new points, update existing points
33    #[default]
34    Upsert,
35    // Only insert new points, do not update existing points
36    InsertOnly,
37    // Only update existing points, do not insert new points
38    UpdateOnly,
39}
40
41#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, JsonSchema, Validate, Hash)]
42#[serde(rename_all = "snake_case")]
43pub struct PointIdsList {
44    pub points: Vec<PointIdType>,
45    #[cfg(feature = "api")]
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub shard_key: Option<api::rest::ShardKeySelector>,
48}
49
50impl From<Vec<PointIdType>> for PointIdsList {
51    fn from(points: Vec<PointIdType>) -> Self {
52        Self {
53            points,
54            #[cfg(feature = "api")]
55            shard_key: None,
56        }
57    }
58}
59
60// General idea of having an extra layer of data structures after REST and gRPC
61// is to ensure that all vectors are inferenced and validated before they are persisted.
62//
63// This separation provides a single point, enforced by the type system,
64// where all Documents and other inference-able objects are resolved into raw vectors.
65//
66// Separation between VectorStructPersisted and VectorStructInternal is only needed
67// for legacy reasons, as the previous implementations wrote VectorStruct to WAL,
68// so we need an ability to read it back. VectorStructPersisted reproduces the same
69// structure as VectorStruct had in the previous versions.
70//
71//
72//        gRPC              REST API           ┌───┐              WAL
73//          │                  │               │ I │               ▲
74//          │                  │               │ n │               │
75//          │                  │               │ f │               │
76//  ┌───────▼───────┐    ┌─────▼──────┐        │ e │     ┌─────────┴───────────┐
77//  │ grpc::Vectors ├───►│VectorStruct├───────►│ r ├────►│VectorStructPersisted├─────┐
78//  └───────────────┘    └────────────┘        │ e │     └─────────────────────┘     │
79//                        Vectors              │ n │      Only Vectors               │
80//                        + Documents          │ c │                                 │
81//                        + Images             │ e │                                 │
82//                        + Other inference    └───┘                                 │
83//                        Implement JsonSchema                                       │
84//                                                       ┌─────────────────────┐     │
85//                                                       │                     ◄─────┘
86//                                                       │   Storage           │
87//                                                       │                     │
88//                        REST API Response              └────────┬────────────┘
89//                             ▲                                  │
90//                             │                                  │
91//                      ┌──────┴──────────────┐         ┌─────────▼───────────┐
92//                      │ VectorStructOutput  ◄───┬─────┤VectorStructInternal │
93//                      └─────────────────────┘   │     └─────────────────────┘
94//                       Only Vectors             │      Only Vectors
95//                       Implement JsonSchema     │      Optimized for search
96//                                                │
97//                                                │
98//                      ┌─────────────────────┐   │
99//                      │ grpc::VectorsOutput ◄───┘
100//                      └───────────┬─────────┘
101//                                  │
102//                                  ▼
103//                              gRPC Response
104
105#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, EnumDiscriminants, Hash)]
106#[strum_discriminants(derive(EnumIter))]
107#[serde(rename_all = "snake_case")]
108pub enum PointOperations {
109    /// Insert or update points
110    UpsertPoints(PointInsertOperationsInternal),
111    /// Insert points, or update existing points if condition matches
112    UpsertPointsConditional(ConditionalInsertOperationInternal),
113    /// Delete point if exists
114    DeletePoints { ids: Vec<PointIdType> },
115    /// Delete points by given filter criteria
116    DeletePointsByFilter(Filter),
117    /// Points Sync
118    SyncPoints(PointSyncOperation),
119}
120
121impl PointOperations {
122    pub fn point_ids(&self) -> Option<Vec<PointIdType>> {
123        match self {
124            Self::UpsertPoints(op) => Some(op.point_ids()),
125            Self::UpsertPointsConditional(op) => Some(op.points_op.point_ids()),
126            Self::DeletePoints { ids } => Some(ids.clone()),
127            Self::DeletePointsByFilter(_) => None,
128            Self::SyncPoints(op) => Some(op.points.iter().map(|point| point.id).collect()),
129        }
130    }
131
132    pub fn retain_point_ids<F>(&mut self, filter: F)
133    where
134        F: Fn(&PointIdType) -> bool,
135    {
136        match self {
137            Self::UpsertPoints(op) => op.retain_point_ids(filter),
138            Self::UpsertPointsConditional(op) => {
139                op.points_op.retain_point_ids(filter);
140            }
141            Self::DeletePoints { ids } => ids.retain(filter),
142            Self::DeletePointsByFilter(_) => (),
143            Self::SyncPoints(op) => op.points.retain(|point| filter(&point.id)),
144        }
145    }
146}
147
148#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, EnumDiscriminants, Hash)]
149#[strum_discriminants(derive(EnumIter))]
150#[serde(rename_all = "snake_case")]
151pub enum PointInsertOperationsInternal {
152    /// Inset points from a batch.
153    #[serde(rename = "batch")]
154    PointsBatch(BatchPersisted),
155    /// Insert points from a list
156    #[serde(rename = "points")]
157    PointsList(Vec<PointStructPersisted>),
158}
159
160impl PointInsertOperationsInternal {
161    pub fn point_ids(&self) -> Vec<PointIdType> {
162        match self {
163            Self::PointsBatch(batch) => batch.ids.clone(),
164            Self::PointsList(points) => points.iter().map(|point| point.id).collect(),
165        }
166    }
167
168    pub fn into_point_vec(self) -> Vec<PointStructPersisted> {
169        match self {
170            PointInsertOperationsInternal::PointsBatch(batch) => {
171                let batch_vectors = BatchVectorStructInternal::from(batch.vectors);
172                let all_vectors = batch_vectors.into_all_vectors(batch.ids.len());
173                let vectors_iter = batch.ids.into_iter().zip(all_vectors);
174                match batch.payloads {
175                    None => vectors_iter
176                        .map(|(id, vectors)| PointStructPersisted {
177                            id,
178                            vector: VectorStructInternal::from(vectors).into(),
179                            payload: None,
180                        })
181                        .collect(),
182                    Some(payloads) => vectors_iter
183                        .zip(payloads)
184                        .map(|((id, vectors), payload)| PointStructPersisted {
185                            id,
186                            vector: VectorStructInternal::from(vectors).into(),
187                            payload,
188                        })
189                        .collect(),
190                }
191            }
192            PointInsertOperationsInternal::PointsList(points) => points,
193        }
194    }
195
196    pub fn retain_point_ids<F>(&mut self, filter: F)
197    where
198        F: Fn(&PointIdType) -> bool,
199    {
200        match self {
201            Self::PointsBatch(batch) => {
202                let mut retain_indices = HashSet::new();
203
204                retain_with_index(&mut batch.ids, |index, id| {
205                    if filter(id) {
206                        retain_indices.insert(index);
207                        true
208                    } else {
209                        false
210                    }
211                });
212
213                match &mut batch.vectors {
214                    BatchVectorStructPersisted::Single(vectors) => {
215                        retain_with_index(vectors, |index, _| retain_indices.contains(&index));
216                    }
217
218                    BatchVectorStructPersisted::MultiDense(vectors) => {
219                        retain_with_index(vectors, |index, _| retain_indices.contains(&index));
220                    }
221
222                    BatchVectorStructPersisted::Named(vectors) => {
223                        for (_, vectors) in vectors.iter_mut() {
224                            retain_with_index(vectors, |index, _| retain_indices.contains(&index));
225                        }
226                    }
227                }
228
229                if let Some(payload) = &mut batch.payloads {
230                    retain_with_index(payload, |index, _| retain_indices.contains(&index));
231                }
232            }
233
234            Self::PointsList(points) => points.retain(|point| filter(&point.id)),
235        }
236    }
237}
238
239impl From<BatchPersisted> for PointInsertOperationsInternal {
240    fn from(batch: BatchPersisted) -> Self {
241        PointInsertOperationsInternal::PointsBatch(batch)
242    }
243}
244
245impl From<Vec<PointStructPersisted>> for PointInsertOperationsInternal {
246    fn from(points: Vec<PointStructPersisted>) -> Self {
247        PointInsertOperationsInternal::PointsList(points)
248    }
249}
250
251#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Hash)]
252pub struct ConditionalInsertOperationInternal {
253    pub points_op: PointInsertOperationsInternal,
254    /// Condition to check, if the point already exists
255    pub condition: Filter,
256    /// Mode of the upsert operation. If None, defaults to Upsert behavior.
257    #[serde(default, skip_serializing_if = "Option::is_none")]
258    pub update_mode: Option<UpdateMode>,
259}
260
261#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Hash)]
262pub struct PointSyncOperation {
263    /// Minimal id of the sync range
264    pub from_id: Option<PointIdType>,
265    /// Maximal id og
266    pub to_id: Option<PointIdType>,
267    pub points: Vec<PointStructPersisted>,
268}
269
270#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Hash)]
271#[serde(rename_all = "snake_case")]
272pub struct BatchPersisted {
273    pub ids: Vec<PointIdType>,
274    pub vectors: BatchVectorStructPersisted,
275    pub payloads: Option<Vec<Option<Payload>>>,
276}
277
#[cfg(feature = "api")]
impl TryFrom<BatchPersisted> for Vec<api::grpc::qdrant::PointStruct> {
    type Error = tonic::Status;

    /// Expands a column-wise batch into one gRPC `PointStruct` per id.
    ///
    /// The batch is owned, so vectors and payloads are moved out of it instead of
    /// being cloned per point. If there are fewer vector entries than ids, the
    /// remaining points get `vectors: None`. A missing or absent payload becomes
    /// an empty payload map.
    ///
    /// This conversion currently never returns an error.
    fn try_from(batch: BatchPersisted) -> Result<Self, Self::Error> {
        let BatchPersisted {
            ids,
            vectors,
            payloads,
        } = batch;

        let mut points = Vec::with_capacity(ids.len());

        let mut vectors_iter = BatchVectorStructInternal::from(vectors)
            .into_all_vectors(ids.len())
            .into_iter();
        // `None` if the batch carries no payloads at all
        let mut payloads_iter = payloads.map(|payloads| payloads.into_iter());

        for point_id in ids {
            let vectors = vectors_iter
                .next()
                .map(|named| api::grpc::qdrant::Vectors::from(VectorStructInternal::from(named)));

            let payload = payloads_iter
                .as_mut()
                .and_then(|iter| iter.next())
                .flatten()
                .map(api::conversions::json::payload_to_proto)
                .unwrap_or_default();

            points.push(api::grpc::qdrant::PointStruct {
                id: Some(point_id.into()),
                vectors,
                payload,
            });
        }

        Ok(points)
    }
}
313
314#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
315#[serde(untagged, rename_all = "snake_case")]
316pub enum BatchVectorStructPersisted {
317    Single(Vec<DenseVector>),
318    MultiDense(Vec<MultiDenseVector>),
319    Named(HashMap<VectorNameBuf, Vec<VectorPersisted>>),
320}
321
322impl Hash for BatchVectorStructPersisted {
323    fn hash<H: Hasher>(&self, state: &mut H) {
324        mem::discriminant(self).hash(state);
325        match self {
326            BatchVectorStructPersisted::Single(dense) => {
327                for vector in dense {
328                    for v in vector {
329                        OrderedFloat(*v).hash(state);
330                    }
331                }
332            }
333            BatchVectorStructPersisted::MultiDense(multidense) => {
334                for vector in multidense {
335                    for v in vector {
336                        for element in v {
337                            OrderedFloat(*element).hash(state);
338                        }
339                    }
340                }
341            }
342            BatchVectorStructPersisted::Named(named) => unordered_hash_unique(state, named.iter()),
343        }
344    }
345}
346
347impl From<BatchVectorStructPersisted> for BatchVectorStructInternal {
348    fn from(value: BatchVectorStructPersisted) -> Self {
349        match value {
350            BatchVectorStructPersisted::Single(vector) => BatchVectorStructInternal::Single(vector),
351            BatchVectorStructPersisted::MultiDense(vectors) => {
352                BatchVectorStructInternal::MultiDense(
353                    vectors
354                        .into_iter()
355                        .map(MultiDenseVectorInternal::new_unchecked)
356                        .collect(),
357                )
358            }
359            BatchVectorStructPersisted::Named(vectors) => BatchVectorStructInternal::Named(
360                vectors
361                    .into_iter()
362                    .map(|(k, v)| (k, v.into_iter().map(VectorInternal::from).collect()))
363                    .collect(),
364            ),
365        }
366    }
367}
368
369#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Validate, Hash)]
370#[serde(rename_all = "snake_case")]
371pub struct PointStructPersisted {
372    /// Point id
373    pub id: PointIdType,
374    /// Vectors
375    pub vector: VectorStructPersisted,
376    /// Payload values (optional)
377    pub payload: Option<Payload>,
378}
379
380impl PointStructPersisted {
381    pub fn get_vectors(&self) -> NamedVectors<'_> {
382        let mut named_vectors = NamedVectors::default();
383        match &self.vector {
384            VectorStructPersisted::Single(vector) => named_vectors.insert(
385                DEFAULT_VECTOR_NAME.to_owned(),
386                VectorInternal::from(vector.clone()),
387            ),
388            VectorStructPersisted::MultiDense(vector) => named_vectors.insert(
389                DEFAULT_VECTOR_NAME.to_owned(),
390                VectorInternal::from(MultiDenseVectorInternal::new_unchecked(vector.clone())),
391            ),
392            VectorStructPersisted::Named(vectors) => {
393                for (name, vector) in vectors {
394                    named_vectors.insert(name.clone(), VectorInternal::from(vector.clone()));
395                }
396            }
397        }
398        named_vectors
399    }
400
401    pub fn is_equal_to(&self, segment_record: &SegmentRecord) -> bool {
402        let SegmentRecord {
403            id,
404            vectors,
405            payload,
406        } = segment_record;
407
408        if &self.id != id {
409            return false;
410        }
411
412        let self_vectors = self.get_vectors().into_owned_map();
413
414        if let Some(segment_vectors) = vectors {
415            if self_vectors.len() != segment_vectors.len() {
416                return false;
417            }
418            for (name, vec) in segment_vectors {
419                if self_vectors.get(name) != Some(vec) {
420                    return false;
421                }
422            }
423        } else if !self_vectors.is_empty() {
424            return false;
425        }
426
427        // Check if payloads are equal, empty and non-existent payloads are considered equal
428        let self_payload = self.payload.as_ref().filter(|p| !p.is_empty());
429        let segment_payload = payload.as_ref().filter(|p| !p.is_empty());
430        self_payload == segment_payload
431    }
432}
433
#[cfg(feature = "api")]
impl TryFrom<api::rest::schema::Record> for PointStructPersisted {
    type Error = String;

    /// Converts a REST record into a persisted point.
    ///
    /// The record's `shard_key` and `order_value` are dropped.
    ///
    /// # Errors
    ///
    /// Returns `"Vector is empty"` if the record carries no vector.
    fn try_from(record: api::rest::schema::Record) -> Result<Self, Self::Error> {
        let api::rest::schema::Record {
            id,
            payload,
            vector,
            shard_key: _,
            order_value: _,
        } = record;

        let vector = vector.ok_or_else(|| "Vector is empty".to_string())?;

        Ok(Self {
            id,
            payload,
            vector: VectorStructPersisted::from(vector),
        })
    }
}
458
#[cfg(feature = "api")]
impl TryFrom<PointStructPersisted> for api::grpc::qdrant::PointStruct {
    type Error = tonic::Status;

    /// Converts a persisted point into its gRPC representation.
    ///
    /// A missing payload becomes an empty payload map.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_argument` status if the vectors cannot be converted
    /// into the internal representation.
    fn try_from(value: PointStructPersisted) -> Result<Self, Self::Error> {
        let PointStructPersisted {
            id,
            vector,
            payload,
        } = value;

        let internal = match VectorStructInternal::try_from(vector) {
            Ok(internal) => internal,
            Err(e) => {
                return Err(tonic::Status::invalid_argument(format!(
                    "Failed to convert vectors: {e}"
                )));
            }
        };

        let payload = payload
            .map(api::conversions::json::payload_to_proto)
            .unwrap_or_default();

        Ok(Self {
            id: Some(id.into()),
            vectors: Some(api::grpc::qdrant::Vectors::from(internal)),
            payload,
        })
    }
}
487
488/// Data structure for point vectors, as it is persisted in WAL
489#[derive(Clone, PartialEq, Deserialize, Serialize)]
490#[serde(untagged, rename_all = "snake_case")]
491pub enum VectorStructPersisted {
492    Single(DenseVector),
493    MultiDense(MultiDenseVector),
494    Named(HashMap<VectorNameBuf, VectorPersisted>),
495}
496
497impl std::hash::Hash for VectorStructPersisted {
498    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
499        mem::discriminant(self).hash(state);
500        match self {
501            VectorStructPersisted::Single(vec) => {
502                for v in vec {
503                    OrderedFloat(*v).hash(state);
504                }
505            }
506            VectorStructPersisted::MultiDense(multi_vec) => {
507                for vec in multi_vec {
508                    for v in vec {
509                        OrderedFloat(*v).hash(state);
510                    }
511                }
512            }
513            VectorStructPersisted::Named(map) => {
514                unordered_hash_unique(state, map.iter());
515            }
516        }
517    }
518}
519
520impl Debug for VectorStructPersisted {
521    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522        match self {
523            VectorStructPersisted::Single(vector) => {
524                let first_elements = vector.iter().take(4).join(", ");
525                write!(f, "Single([{}, ... x {}])", first_elements, vector.len())
526            }
527            VectorStructPersisted::MultiDense(vector) => {
528                let first_vectors = vector
529                    .iter()
530                    .take(4)
531                    .map(|v| {
532                        let first_elements = v.iter().take(4).join(", ");
533                        format!("[{}, ... x {}]", first_elements, v.len())
534                    })
535                    .join(", ");
536                write!(f, "MultiDense([{}, ... x {})", first_vectors, vector.len())
537            }
538            VectorStructPersisted::Named(vectors) => write!(f, "Named(( ")
539                .and_then(|_| {
540                    for (name, vector) in vectors {
541                        write!(f, "{name}: {vector:?}, ")?;
542                    }
543                    Ok(())
544                })
545                .and_then(|_| write!(f, "))")),
546        }
547    }
548}
549
550impl VectorStructPersisted {
551    /// Check if this vector struct is empty.
552    pub fn is_empty(&self) -> bool {
553        match self {
554            VectorStructPersisted::Single(vector) => vector.is_empty(),
555            VectorStructPersisted::MultiDense(vector) => vector.is_empty(),
556            VectorStructPersisted::Named(vectors) => vectors.values().all(|v| match v {
557                VectorPersisted::Dense(vector) => vector.is_empty(),
558                VectorPersisted::Sparse(vector) => vector.indices.is_empty(),
559                VectorPersisted::MultiDense(vector) => vector.is_empty(),
560            }),
561        }
562    }
563}
564
565impl Validate for VectorStructPersisted {
566    fn validate(&self) -> Result<(), ValidationErrors> {
567        match self {
568            VectorStructPersisted::Single(_) => Ok(()),
569            VectorStructPersisted::MultiDense(v) => validate_multi_vector(v),
570            VectorStructPersisted::Named(v) => crate::common::validation::validate_iter(v.values()),
571        }
572    }
573}
574
575impl From<DenseVector> for VectorStructPersisted {
576    fn from(value: DenseVector) -> Self {
577        VectorStructPersisted::Single(value)
578    }
579}
580
581impl From<VectorStructInternal> for VectorStructPersisted {
582    fn from(value: VectorStructInternal) -> Self {
583        match value {
584            VectorStructInternal::Single(vector) => VectorStructPersisted::Single(vector),
585            VectorStructInternal::MultiDense(vector) => {
586                VectorStructPersisted::MultiDense(vector.into_multi_vectors())
587            }
588            VectorStructInternal::Named(vectors) => VectorStructPersisted::Named(
589                vectors
590                    .into_iter()
591                    .map(|(k, v)| (k, VectorPersisted::from(v)))
592                    .collect(),
593            ),
594        }
595    }
596}
597
#[cfg(feature = "api")]
impl From<api::rest::VectorStructOutput> for VectorStructPersisted {
    /// Converts REST output vectors into their persisted form, variant by variant.
    fn from(value: api::rest::VectorStructOutput) -> Self {
        use api::rest::VectorStructOutput as Output;

        match value {
            Output::Single(dense) => Self::Single(dense),
            Output::MultiDense(multi) => Self::MultiDense(multi),
            Output::Named(named) => Self::Named(
                named
                    .into_iter()
                    .map(|(name, vector)| (name, VectorPersisted::from(vector)))
                    .collect(),
            ),
        }
    }
}
615
616impl TryFrom<VectorStructPersisted> for VectorStructInternal {
617    type Error = OperationError;
618    fn try_from(value: VectorStructPersisted) -> Result<Self, Self::Error> {
619        let vector_struct = match value {
620            VectorStructPersisted::Single(vector) => VectorStructInternal::Single(vector),
621            VectorStructPersisted::MultiDense(vector) => {
622                VectorStructInternal::MultiDense(MultiDenseVectorInternal::try_from(vector)?)
623            }
624            VectorStructPersisted::Named(vectors) => VectorStructInternal::Named(
625                vectors
626                    .into_iter()
627                    .map(|(k, v)| (k, VectorInternal::from(v)))
628                    .collect(),
629            ),
630        };
631        Ok(vector_struct)
632    }
633}
634
635impl From<VectorStructPersisted> for NamedVectors<'_> {
636    fn from(value: VectorStructPersisted) -> Self {
637        match value {
638            VectorStructPersisted::Single(vector) => {
639                NamedVectors::from_pairs([(DEFAULT_VECTOR_NAME.to_owned(), vector)])
640            }
641            VectorStructPersisted::MultiDense(vector) => {
642                let mut named_vector = NamedVectors::default();
643                let multivec = MultiDenseVectorInternal::new_unchecked(vector);
644
645                named_vector.insert(
646                    DEFAULT_VECTOR_NAME.to_owned(),
647                    crate::segment::data_types::vectors::VectorInternal::from(multivec),
648                );
649                named_vector
650            }
651            VectorStructPersisted::Named(vectors) => {
652                let mut named_vector = NamedVectors::default();
653                for (name, vector) in vectors {
654                    named_vector.insert(
655                        name,
656                        crate::segment::data_types::vectors::VectorInternal::from(vector),
657                    );
658                }
659                named_vector
660            }
661        }
662    }
663}
664
665/// Single vector data, as it is persisted in WAL
666/// Unlike [`api::rest::Vector`], this struct only stores raw vectors, inferenced or resolved.
667/// Unlike [`VectorInternal`], is not optimized for search
668#[derive(Clone, PartialEq, Deserialize, Serialize)]
669#[serde(untagged, rename_all = "snake_case")]
670pub enum VectorPersisted {
671    Dense(DenseVector),
672    Sparse(crate::sparse::common::sparse_vector::SparseVector),
673    MultiDense(MultiDenseVector),
674}
675
676impl Hash for VectorPersisted {
677    fn hash<H: Hasher>(&self, state: &mut H) {
678        mem::discriminant(self).hash(state);
679        match self {
680            VectorPersisted::Dense(vec) => {
681                for v in vec {
682                    OrderedFloat(*v).hash(state);
683                }
684            }
685            VectorPersisted::Sparse(sparse) => {
686                sparse.hash(state);
687            }
688            VectorPersisted::MultiDense(multi_vec) => {
689                for vec in multi_vec {
690                    for v in vec {
691                        OrderedFloat(*v).hash(state);
692                    }
693                }
694            }
695        }
696    }
697}
698
699impl VectorPersisted {
700    pub fn new_sparse(indices: Vec<DimId>, values: Vec<DimWeight>) -> Self {
701        Self::Sparse(crate::sparse::common::sparse_vector::SparseVector { indices, values })
702    }
703
704    pub fn empty_sparse() -> Self {
705        Self::new_sparse(vec![], vec![])
706    }
707}
708
709impl Debug for VectorPersisted {
710    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
711        match self {
712            VectorPersisted::Dense(vector) => {
713                let first_elements = vector.iter().take(4).join(", ");
714                write!(f, "Dense([{}, ... x {}])", first_elements, vector.len())
715            }
716            VectorPersisted::Sparse(vector) => {
717                let first_elements = vector
718                    .indices
719                    .iter()
720                    .zip(vector.values.iter())
721                    .take(4)
722                    .map(|(k, v)| format!("{k}->{v}"))
723                    .join(", ");
724                write!(
725                    f,
726                    "Sparse([{}, ... x {})",
727                    first_elements,
728                    vector.indices.len()
729                )
730            }
731            VectorPersisted::MultiDense(vector) => {
732                let first_vectors = vector
733                    .iter()
734                    .take(4)
735                    .map(|v| {
736                        let first_elements = v.iter().take(4).join(", ");
737                        format!("[{}, ... x {}]", first_elements, v.len())
738                    })
739                    .join(", ");
740                write!(f, "MultiDense([{}, ... x {})", first_vectors, vector.len())
741            }
742        }
743    }
744}
745
746impl Validate for VectorPersisted {
747    fn validate(&self) -> Result<(), ValidationErrors> {
748        match self {
749            VectorPersisted::Dense(_) => Ok(()),
750            VectorPersisted::Sparse(v) => v.validate(),
751            VectorPersisted::MultiDense(m) => validate_multi_vector(m),
752        }
753    }
754}
755
756impl From<VectorInternal> for VectorPersisted {
757    fn from(value: VectorInternal) -> Self {
758        match value {
759            VectorInternal::Dense(vector) => VectorPersisted::Dense(vector),
760            VectorInternal::Sparse(vector) => VectorPersisted::Sparse(vector),
761            VectorInternal::MultiDense(vector) => {
762                VectorPersisted::MultiDense(vector.into_multi_vectors())
763            }
764        }
765    }
766}
767
#[cfg(feature = "api")]
impl From<api::rest::VectorOutput> for VectorPersisted {
    /// Converts a REST output vector into its persisted form, variant by variant.
    fn from(value: api::rest::VectorOutput) -> Self {
        use api::rest::VectorOutput as Output;

        match value {
            Output::Dense(dense) => Self::Dense(dense),
            Output::Sparse(sparse) => Self::Sparse(sparse),
            Output::MultiDense(multi) => Self::MultiDense(multi),
        }
    }
}
778
779impl From<VectorPersisted> for VectorInternal {
780    fn from(value: VectorPersisted) -> Self {
781        match value {
782            VectorPersisted::Dense(vector) => VectorInternal::Dense(vector),
783            VectorPersisted::Sparse(vector) => VectorInternal::Sparse(vector),
784            VectorPersisted::MultiDense(vector) => {
785                // the REST vectors have been validated already
786                // we can use an internal constructor
787                VectorInternal::MultiDense(MultiDenseVectorInternal::new_unchecked(vector))
788            }
789        }
790    }
791}
792
/// Like [`Vec::retain`], but the predicate also receives the position
/// the element had in the vector before any removal.
///
/// `filter` is called exactly once per element, in order.
fn retain_with_index<T, F>(vec: &mut Vec<T>, mut filter: F)
where
    F: FnMut(usize, &T) -> bool,
{
    let mut next_position = 0usize;

    vec.retain(|element| {
        let position = next_position;
        let keep = filter(position, element);
        next_position += 1;
        keep
    });
}