qdrant_edge/shard/operations/
payload_ops.rs1use std::fmt;
2
3#[cfg(feature = "api")]
4use schemars::JsonSchema;
5use crate::segment::json_path::JsonPath;
6use crate::segment::types::{Filter, Payload, PayloadKeyType, PointIdType};
7use serde::{self, Deserialize, Serialize};
8use strum::{EnumDiscriminants, EnumIter};
9#[cfg(feature = "api")]
10use validator::Validate;
11
/// A payload modification addressed to points of a shard.
///
/// Serialized as an externally tagged enum with snake_case variant names,
/// e.g. `{"set_payload": {...}}`. The `EnumDiscriminants` derive additionally
/// generates a fieldless discriminant enum (named `PayloadOpsDiscriminants` by
/// strum's default) which itself derives `EnumIter`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, EnumDiscriminants, Hash)]
#[strum_discriminants(derive(EnumIter))]
#[serde(rename_all = "snake_case")]
pub enum PayloadOps {
    /// Set payload values on the points selected by the [`SetPayloadOp`].
    SetPayload(SetPayloadOp),
    /// Delete the listed payload keys from the points selected by the [`DeletePayloadOp`].
    DeletePayload(DeletePayloadOp),
    /// Clear the payload of every point listed in `points`.
    ClearPayload { points: Vec<PointIdType> },
    /// Clear the payload of every point matching the filter.
    ClearPayloadByFilter(Filter),
    /// Overwrite payload on the selected points. Carries the same [`SetPayloadOp`]
    /// body as `SetPayload`; the difference lies in how the operation is applied.
    OverwritePayload(SetPayloadOp),
}
28
29impl PayloadOps {
30 pub fn point_ids(&self) -> Option<Vec<PointIdType>> {
31 match self {
32 Self::SetPayload(op) => op.points.clone(),
33 Self::DeletePayload(op) => op.points.clone(),
34 Self::ClearPayload { points } => Some(points.clone()),
35 Self::ClearPayloadByFilter(_) => None,
36 Self::OverwritePayload(op) => op.points.clone(),
37 }
38 }
39
40 pub fn retain_point_ids<F>(&mut self, filter: F)
41 where
42 F: Fn(&PointIdType) -> bool,
43 {
44 match self {
45 Self::SetPayload(op) => retain_opt(op.points.as_mut(), filter),
46 Self::DeletePayload(op) => retain_opt(op.points.as_mut(), filter),
47 Self::ClearPayload { points } => points.retain(filter),
48 Self::ClearPayloadByFilter(_) => (),
49 Self::OverwritePayload(op) => retain_opt(op.points.as_mut(), filter),
50 }
51 }
52}
53
54fn retain_opt<T, F>(vec: Option<&mut Vec<T>>, filter: F)
55where
56 F: Fn(&T) -> bool,
57{
58 if let Some(vec) = vec {
59 vec.retain(filter);
60 }
61}
62
#[cfg(feature = "api")]
/// Request to set payload on points selected by ids and/or a filter.
// Deserialization is routed through `SetPayloadShadow` via `serde(try_from)`,
// which rejects input carrying neither `points` nor `filter`. `Serialize` is
// derived directly on this type and performs no such check.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(try_from = "SetPayloadShadow")]
pub struct SetPayload {
    /// Payload values to assign.
    pub payload: Payload,
    /// Explicit list of points to update.
    pub points: Option<Vec<PointIdType>>,
    /// Filter selecting the points to update.
    #[validate(nested)]
    pub filter: Option<Filter>,
    /// Optional shard key selector.
    // Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shard_key: Option<api::rest::ShardKeySelector>,
    /// Optional payload path (`JsonPath`) for the assignment.
    // NOTE(review): presumably the nested location `payload` is written under;
    // confirm against the code that applies the operation.
    pub key: Option<JsonPath>,
}
79
/// Shard-level body of the `SetPayload` and `OverwritePayload` operations.
///
/// Unlike the API request type it has no shard key, and its plain `Deserialize`
/// derive does not require `points` or `filter` to be present.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Hash)]
pub struct SetPayloadOp {
    /// Payload values to assign.
    pub payload: Payload,
    /// Explicit list of target points, if any.
    pub points: Option<Vec<PointIdType>>,
    /// Filter selecting target points, if any.
    pub filter: Option<Filter>,
    /// Optional payload path (`JsonPath`) for the assignment.
    pub key: Option<JsonPath>,
}
95
#[cfg(feature = "api")]
/// Request to delete payload keys from points selected by ids and/or a filter.
// Deserialization is routed through `DeletePayloadShadow` via `serde(try_from)`,
// which rejects input carrying neither `points` nor `filter`. `Serialize` is
// derived directly on this type and performs no such check.
#[derive(Debug, Deserialize, Serialize, JsonSchema, Validate, Clone)]
#[serde(try_from = "DeletePayloadShadow")]
pub struct DeletePayload {
    /// Payload keys to delete.
    pub keys: Vec<PayloadKeyType>,
    /// Explicit list of points to update.
    pub points: Option<Vec<PointIdType>>,
    /// Filter selecting the points to update.
    #[validate(nested)]
    pub filter: Option<Filter>,
    /// Optional shard key selector.
    // Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shard_key: Option<api::rest::ShardKeySelector>,
}
111
/// Shard-level body of the `DeletePayload` operation.
///
/// Unlike the API request type it has no shard key, and its plain `Deserialize`
/// derive does not require `points` or `filter` to be present.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Hash)]
pub struct DeletePayloadOp {
    /// Payload keys to delete.
    pub keys: Vec<PayloadKeyType>,
    /// Explicit list of target points, if any.
    pub points: Option<Vec<PointIdType>>,
    /// Filter selecting target points, if any.
    pub filter: Option<Filter>,
}
126
#[cfg(feature = "api")]
/// Unvalidated wire shape of `SetPayload`.
///
/// serde decodes the input into this struct first; the `TryFrom` conversion to
/// `SetPayload` then checks that a point selector was supplied. The struct is
/// private to this module, so its fields need no extra visibility.
#[derive(Deserialize)]
struct SetPayloadShadow {
    payload: Payload,
    points: Option<Vec<PointIdType>>,
    filter: Option<Filter>,
    shard_key: Option<api::rest::ShardKeySelector>,
    key: Option<JsonPath>,
}
136
#[cfg(feature = "api")]
impl TryFrom<SetPayloadShadow> for SetPayload {
    type Error = PointsSelectorValidationError;

    /// Promotes the raw request, insisting on at least one of `points` / `filter`.
    fn try_from(shadow: SetPayloadShadow) -> Result<Self, Self::Error> {
        // Guard clause: no point selector at all is the only rejected shape.
        if shadow.points.is_none() && shadow.filter.is_none() {
            return Err(PointsSelectorValidationError);
        }

        Ok(Self {
            payload: shadow.payload,
            points: shadow.points,
            filter: shadow.filter,
            shard_key: shadow.shard_key,
            key: shadow.key,
        })
    }
}
163
#[cfg(feature = "api")]
/// Unvalidated wire shape of `DeletePayload`.
///
/// serde decodes the input into this struct first; the `TryFrom` conversion to
/// `DeletePayload` then checks that a point selector was supplied. The struct is
/// private to this module, so its fields need no extra visibility.
#[derive(Deserialize)]
struct DeletePayloadShadow {
    keys: Vec<PayloadKeyType>,
    points: Option<Vec<PointIdType>>,
    filter: Option<Filter>,
    shard_key: Option<api::rest::ShardKeySelector>,
}
172
#[cfg(feature = "api")]
impl TryFrom<DeletePayloadShadow> for DeletePayload {
    type Error = PointsSelectorValidationError;

    /// Promotes the raw request, insisting on at least one of `points` / `filter`.
    fn try_from(shadow: DeletePayloadShadow) -> Result<Self, Self::Error> {
        // Guard clause: no point selector at all is the only rejected shape.
        if shadow.points.is_none() && shadow.filter.is_none() {
            return Err(PointsSelectorValidationError);
        }

        Ok(Self {
            keys: shadow.keys,
            points: shadow.points,
            filter: shadow.filter,
            shard_key: shadow.shard_key,
        })
    }
}
196
/// Error for a payload request that selects points neither by id list nor by filter.
///
/// Returned by the `TryFrom` conversions behind `serde(try_from = ...)`; serde
/// reports it through its `Display` text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PointsSelectorValidationError;

impl fmt::Display for PointsSelectorValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Either list of point ids or filter must be provided")
    }
}

// Implementing `Error` lets callers box this as `dyn Error`, propagate it with `?`
// into generic error types, and wrap it as a `source()` of other errors.
impl std::error::Error for PointsSelectorValidationError {}
205
#[cfg(test)]
mod tests {
    use crate::segment::types::{Payload, PayloadContainer};
    use serde_json::Value;

    use super::*;

    /// Selector shape with a mandatory `points` list.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct TextSelector {
        pub points: Vec<PointIdType>,
    }

    /// Selector shape where `points` is optional and `filter` is an extra field.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct TextSelectorOpt {
        pub points: Option<Vec<PointIdType>>,
        pub filter: Option<Filter>,
    }

    /// CBOR written from `TextSelector` must decode as `TextSelectorOpt`: the
    /// `Vec` reads into `Option<Vec>`, and the field absent from the blob is `None`.
    #[test]
    fn test_replace_with_opt_in_cbor() {
        let obj1 = TextSelector {
            points: vec![1.into(), 2.into(), 3.into()],
        };
        let raw_cbor = serde_cbor::to_vec(&obj1).unwrap();
        let obj2 = serde_cbor::from_slice::<TextSelectorOpt>(&raw_cbor).unwrap();
        assert_eq!(obj2.points, Some(obj1.points));
        assert!(obj2.filter.is_none());
    }

    /// A JSON `set_payload` operation decodes into `PayloadOps::SetPayload` with
    /// its ids, payload values, and absent optional fields intact.
    #[test]
    fn test_serialization() {
        let query1 = r#"
        {
            "set_payload": {
                "points": [1, 2, 3],
                "payload": {
                    "key1":  "hello" ,
                    "key2": [1,2,3,4],
                    "key3": {"json": {"key1":"value1"} }
                }
            }
        }
        "#;

        let operation: PayloadOps = serde_json::from_str(query1).unwrap();

        match operation {
            PayloadOps::SetPayload(set_payload) => {
                let expected_points: Vec<PointIdType> = vec![1.into(), 2.into(), 3.into()];
                assert_eq!(set_payload.points, Some(expected_points));
                assert!(set_payload.filter.is_none());
                assert!(set_payload.key.is_none());

                let payload: Payload = set_payload.payload;
                assert_eq!(payload.len(), 3);
                assert!(payload.contains_key("key1"));

                let key1 = payload
                    .get_value(&"key1".parse().unwrap())
                    .into_iter()
                    .next()
                    .cloned()
                    .expect("No key key1");
                assert_eq!(key1, Value::String("hello".to_string()));

                let key3 = payload
                    .get_value(&"key3".parse().unwrap())
                    .into_iter()
                    .next()
                    .cloned();
                assert!(matches!(key3, Some(Value::Object(_))));
            }
            _ => panic!("Wrong operation"),
        }
    }

    /// `SetPayload` deserialization rejects a request with no point selector.
    #[cfg(feature = "api")]
    #[test]
    fn test_set_payload_requires_points_or_filter() {
        let err = serde_json::from_str::<SetPayload>(r#"{"payload": {"a": 1}}"#).unwrap_err();
        assert!(err
            .to_string()
            .contains("Either list of point ids or filter must be provided"));

        let ok = serde_json::from_str::<SetPayload>(r#"{"payload": {"a": 1}, "points": [1, 2]}"#)
            .unwrap();
        assert_eq!(ok.points.map(|points| points.len()), Some(2));
    }

    /// `DeletePayload` deserialization rejects a request with no point selector.
    #[cfg(feature = "api")]
    #[test]
    fn test_delete_payload_requires_points_or_filter() {
        let err = serde_json::from_str::<DeletePayload>(r#"{"keys": ["a"]}"#).unwrap_err();
        assert!(err
            .to_string()
            .contains("Either list of point ids or filter must be provided"));

        let ok = serde_json::from_str::<DeletePayload>(r#"{"keys": ["a"], "points": [7]}"#)
            .unwrap();
        assert_eq!(ok.points.map(|points| points.len()), Some(1));
    }
}