1use serde::{Deserialize, Serialize};
13
/// A shape definition: identifying metadata plus the [`ShapeType`] that the
/// matching helpers on this type dispatch on.
///
/// The derives provide serde, rkyv and MessagePack (`zerompk`) encodings.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ShapeDefinition {
    /// Identifier of this shape.
    pub shape_id: String,
    /// Numeric tenant identifier carried with the shape.
    // NOTE(review): presumably the owning tenant; nothing in this file reads it.
    pub tenant_id: u32,
    /// Kind of shape; `could_match`, `matches_array_op` and `collection` all
    /// branch on this value.
    pub shape_type: ShapeType,
    /// Free-form description text.
    pub description: String,
    /// List of field names. When the key is missing from serde input the
    /// field falls back to an empty list (`#[serde(default)]`).
    // NOTE(review): presumably a projection of the fields to deliver; it is
    // not consulted anywhere in this file, so confirm against the consumer.
    #[serde(default)]
    pub field_filter: Vec<String>,
}
40
/// A range over array coordinates, used by [`ShapeType::Array`].
///
/// As applied by `ArrayCoordRange::contains`, both bounds are inclusive and
/// a missing `end` means the range has no upper bound.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ArrayCoordRange {
    /// Inclusive lower bound. Its length also fixes how many components a
    /// coordinate must have to be considered by `contains`.
    pub start: Vec<u64>,
    /// Inclusive upper bound, or `None` for no upper bound.
    pub end: Option<Vec<u64>>,
}
72
73impl ArrayCoordRange {
74 pub fn contains(&self, coord: &[u64]) -> bool {
79 if coord.len() != self.start.len() {
80 return false;
81 }
82 if coord < self.start.as_slice() {
83 return false;
84 }
85 if let Some(end) = &self.end
86 && (coord.len() != end.len() || coord > end.as_slice())
87 {
88 return false;
89 }
90 true
91 }
92}
93
/// The kind of data a [`ShapeDefinition`] refers to.
///
/// With serde the variants are tagged by the lowercase names `document`,
/// `graph`, `vector` and `array`. The enum is `#[non_exhaustive]`, so code in
/// other crates must keep a wildcard arm when matching on it.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum ShapeType {
    /// Shape over a document collection.
    #[serde(rename = "document")]
    Document {
        /// Collection name; `ShapeDefinition::could_match` compares it for
        /// equality with the collection being checked.
        collection: String,
        /// Opaque bytes that this file never interprets.
        // NOTE(review): presumably a serialized filter predicate evaluated
        // elsewhere; confirm the encoding with the producer.
        predicate: Vec<u8>,
    },

    /// Shape over a graph.
    #[serde(rename = "graph")]
    Graph {
        /// Root node identifiers. `ShapeDefinition::could_match` returns
        /// `true` for this variant whenever the list is non-empty.
        root_nodes: Vec<String>,
        /// Depth limit value; not read anywhere in this file.
        // NOTE(review): presumably the maximum traversal depth from the
        // roots; confirm with the graph engine.
        max_depth: usize,
        /// Optional edge label; not read anywhere in this file.
        // NOTE(review): presumably restricts traversal to edges with this
        // label, with `None` meaning any label; confirm.
        edge_label: Option<String>,
    },

    /// Shape over a vector collection.
    #[serde(rename = "vector")]
    Vector {
        /// Collection name; `ShapeDefinition::could_match` compares it for
        /// equality with the collection being checked.
        collection: String,
        /// Optional field name; not read anywhere in this file.
        field_name: Option<String>,
    },

    /// Shape over an array, matched via `ShapeDefinition::matches_array_op`.
    #[serde(rename = "array")]
    Array {
        /// Array name that an operation must equal to match.
        array_name: String,
        /// Optional coordinate restriction; `None` matches every coordinate
        /// of the named array.
        coord_range: Option<ArrayCoordRange>,
    },
}
150
151impl ShapeDefinition {
152 pub fn could_match(&self, collection: &str, _doc_id: &str) -> bool {
158 match &self.shape_type {
159 ShapeType::Document {
160 collection: shape_coll,
161 ..
162 } => shape_coll == collection,
163 ShapeType::Graph { root_nodes, .. } => {
164 !root_nodes.is_empty()
166 }
167 ShapeType::Vector {
168 collection: shape_coll,
169 ..
170 } => shape_coll == collection,
171 ShapeType::Array { .. } => false,
172 }
173 }
174
175 pub fn matches_array_op(&self, array: &str, coord: &[u64]) -> bool {
180 match &self.shape_type {
181 ShapeType::Array {
182 array_name,
183 coord_range,
184 } => {
185 if array_name != array {
186 return false;
187 }
188 match coord_range {
189 None => true,
190 Some(range) => range.contains(coord),
191 }
192 }
193 _ => false,
194 }
195 }
196
197 pub fn collection(&self) -> Option<&str> {
199 match &self.shape_type {
200 ShapeType::Document { collection, .. } => Some(collection),
201 ShapeType::Vector { collection, .. } => Some(collection),
202 ShapeType::Graph { .. } => None,
203 ShapeType::Array { .. } => None,
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ShapeDefinition` fixture with an empty `field_filter`.
    fn shape_of(
        shape_id: &str,
        tenant_id: u32,
        description: &str,
        shape_type: ShapeType,
    ) -> ShapeDefinition {
        ShapeDefinition {
            shape_id: shape_id.to_owned(),
            tenant_id,
            shape_type,
            description: description.to_owned(),
            field_filter: Vec::new(),
        }
    }

    #[test]
    fn document_shape_matches_collection() {
        let kind = ShapeType::Document {
            collection: "orders".to_owned(),
            predicate: Vec::new(),
        };
        let shape = shape_of("s1", 1, "all orders", kind);

        assert!(shape.could_match("orders", "o1"));
        assert!(!shape.could_match("users", "u1"));
        assert_eq!(shape.collection(), Some("orders"));
    }

    #[test]
    fn graph_shape() {
        let kind = ShapeType::Graph {
            root_nodes: vec!["alice".to_owned()],
            max_depth: 2,
            edge_label: Some("KNOWS".to_owned()),
        };
        let shape = shape_of("g1", 1, "alice's network", kind);

        // A graph shape with roots matches regardless of the collection.
        assert!(shape.could_match("any_collection", "any_doc"));
        assert_eq!(shape.collection(), None);
    }

    #[test]
    fn vector_shape() {
        let kind = ShapeType::Vector {
            collection: "embeddings".to_owned(),
            field_name: Some("title".to_owned()),
        };
        let shape = shape_of("v1", 1, "title embeddings", kind);

        assert!(shape.could_match("embeddings", "e1"));
        assert!(!shape.could_match("other", "e1"));
    }

    #[test]
    fn msgpack_roundtrip() {
        let kind = ShapeType::Document {
            collection: "users".to_owned(),
            predicate: vec![1, 2, 3],
        };
        let original = shape_of("test", 5, "test shape", kind);

        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: ShapeDefinition = zerompk::from_msgpack(&encoded).unwrap();

        assert_eq!(decoded.shape_id, "test");
        assert_eq!(decoded.tenant_id, 5);
        assert!(matches!(decoded.shape_type, ShapeType::Document { .. }));
    }

    #[test]
    fn array_shape_matches_array_op_no_range() {
        let kind = ShapeType::Array {
            array_name: "prices".to_owned(),
            coord_range: None,
        };
        let shape = shape_of("a1", 1, "all prices", kind);

        assert!(shape.matches_array_op("prices", &[0, 0]));
        assert!(shape.matches_array_op("prices", &[999, 999]));
        assert!(!shape.matches_array_op("other", &[0, 0]));
        // Array shapes are never reported through the collection-based check.
        assert!(!shape.could_match("prices", "x"));
        assert_eq!(shape.collection(), None);
    }

    #[test]
    fn array_shape_matches_array_op_with_range() {
        let kind = ShapeType::Array {
            array_name: "temps".to_owned(),
            coord_range: Some(ArrayCoordRange {
                start: vec![10, 10],
                end: Some(vec![20, 20]),
            }),
        };
        let shape = shape_of("a2", 1, "temps sub-range", kind);

        // Both bounds are inclusive.
        assert!(shape.matches_array_op("temps", &[10, 10]));
        assert!(shape.matches_array_op("temps", &[15, 15]));
        assert!(shape.matches_array_op("temps", &[20, 20]));
        assert!(!shape.matches_array_op("temps", &[9, 9]));
        assert!(!shape.matches_array_op("temps", &[21, 21]));
        // A coordinate with the wrong number of components never matches.
        assert!(!shape.matches_array_op("temps", &[10]));
    }

    #[test]
    fn array_coord_range_msgpack_roundtrip() {
        let original = ArrayCoordRange {
            start: vec![1, 2, 3],
            end: Some(vec![4, 5, 6]),
        };

        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: ArrayCoordRange = zerompk::from_msgpack(&encoded).unwrap();

        assert_eq!(decoded.start, vec![1, 2, 3]);
        assert_eq!(decoded.end, Some(vec![4, 5, 6]));
    }

    #[test]
    fn array_shape_msgpack_roundtrip() {
        let kind = ShapeType::Array {
            array_name: "sensor_matrix".to_owned(),
            coord_range: Some(ArrayCoordRange {
                start: vec![0, 0],
                end: None,
            }),
        };
        let original = shape_of("a3", 7, "half-open range", kind);

        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: ShapeDefinition = zerompk::from_msgpack(&encoded).unwrap();

        assert_eq!(decoded.shape_id, "a3");
        assert!(matches!(
            decoded.shape_type,
            ShapeType::Array { ref array_name, .. } if array_name == "sensor_matrix"
        ));
    }
}