use serde::{Deserialize, Serialize};
/// Describes one shape: an identifier, a tenant id, the kind-specific
/// parameters ([`ShapeType`]) and a free-form description.
///
/// Derives both serde and rkyv (de)serialization, so the declaration order
/// and names of the fields below are part of the serialized representation.
#[derive(
Debug, Clone, Serialize, Deserialize, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize,
)]
pub struct ShapeDefinition {
/// Identifier of this shape.
pub shape_id: String,
/// Numeric id of the tenant this definition is associated with.
pub tenant_id: u32,
/// Which kind of shape this is, together with its kind-specific parameters.
pub shape_type: ShapeType,
/// Free-form description text.
pub description: String,
/// List of field names; `#[serde(default)]` makes it an empty `Vec` when the
/// field is absent from serde input.
// NOTE(review): presumably restricts which fields are included for this
// shape — the consumer is not visible in this file, confirm there.
#[serde(default)]
pub field_filter: Vec<String>,
}
/// The kind of a [`ShapeDefinition`] and the parameters specific to that kind.
///
/// Derives both serde and rkyv (de)serialization, so variant and field order
/// are part of the serialized representation.
#[derive(
Debug, Clone, Serialize, Deserialize, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize,
)]
pub enum ShapeType {
/// Shape bound to a single named collection.
Document {
/// Name of the collection; compared for equality in
/// `ShapeDefinition::could_match`.
collection: String,
/// Opaque predicate bytes. The encoding is not defined in this file and
/// `could_match` does not inspect them.
predicate: Vec<u8>,
},
/// Shape described by a set of root nodes rather than by a collection
/// (`ShapeDefinition::collection` returns `None` for it).
Graph {
/// Root node identifiers. `ShapeDefinition::could_match` returns `false`
/// for a graph shape whose list is empty.
root_nodes: Vec<String>,
// NOTE(review): presumably the maximum traversal depth from the roots —
// not used in this file, confirm against the consumer.
max_depth: usize,
// NOTE(review): presumably an optional edge-label restriction, with
// `None` meaning "any label" — not used in this file, confirm.
edge_label: Option<String>,
},
/// Shape bound to a single named collection, with an optional field name.
Vector {
/// Name of the collection; compared for equality in
/// `ShapeDefinition::could_match`.
collection: String,
// NOTE(review): presumably the name of the vector/embedding field, with
// `None` selecting a default — not used in this file, confirm.
field_name: Option<String>,
},
}
impl ShapeDefinition {
    /// Cheap pre-check: could a change in `collection` be relevant to this shape?
    ///
    /// * `Document` and `Vector` shapes answer `true` only when their own
    ///   collection name equals `collection`.
    /// * `Graph` shapes ignore both arguments and answer `true` whenever they
    ///   have at least one root node.
    ///
    /// `_doc_id` is accepted but not inspected by the current implementation.
    pub fn could_match(&self, collection: &str, _doc_id: &str) -> bool {
        match &self.shape_type {
            // Both collection-bound kinds share the same name comparison.
            ShapeType::Document {
                collection: owned, ..
            }
            | ShapeType::Vector {
                collection: owned, ..
            } => owned.as_str() == collection,
            // Graph shapes carry no collection; only the presence of roots counts.
            ShapeType::Graph { root_nodes, .. } => !root_nodes.is_empty(),
        }
    }

    /// Returns the collection name for `Document` and `Vector` shapes, and
    /// `None` for `Graph` shapes, which are not tied to a collection.
    pub fn collection(&self) -> Option<&str> {
        match &self.shape_type {
            ShapeType::Document {
                collection: name, ..
            }
            | ShapeType::Vector {
                collection: name, ..
            } => Some(name.as_str()),
            ShapeType::Graph { .. } => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A document shape matches its own collection only and reports it.
    #[test]
    fn document_shape_matches_collection() {
        let shape = ShapeDefinition {
            shape_id: "s1".into(),
            tenant_id: 1,
            shape_type: ShapeType::Document {
                collection: "orders".into(),
                predicate: Vec::new(),
            },
            description: "all orders".into(),
            field_filter: vec![],
        };
        assert!(shape.could_match("orders", "o1"));
        assert!(!shape.could_match("users", "u1"));
        assert_eq!(shape.collection(), Some("orders"));
    }

    /// A graph shape with roots matches any collection; without roots it
    /// matches nothing. Either way it has no collection.
    #[test]
    fn graph_shape() {
        let shape = ShapeDefinition {
            shape_id: "g1".into(),
            tenant_id: 1,
            shape_type: ShapeType::Graph {
                root_nodes: vec!["alice".into()],
                max_depth: 2,
                edge_label: Some("KNOWS".into()),
            },
            description: "alice's network".into(),
            field_filter: vec![],
        };
        assert!(shape.could_match("any_collection", "any_doc"));
        assert_eq!(shape.collection(), None);

        // The empty-roots branch of `could_match` was previously untested.
        let rootless = ShapeDefinition {
            shape_type: ShapeType::Graph {
                root_nodes: Vec::new(),
                max_depth: 2,
                edge_label: None,
            },
            ..shape
        };
        assert!(!rootless.could_match("any_collection", "any_doc"));
        assert_eq!(rootless.collection(), None);
    }

    /// A vector shape matches its own collection only and reports it.
    #[test]
    fn vector_shape() {
        let shape = ShapeDefinition {
            shape_id: "v1".into(),
            tenant_id: 1,
            shape_type: ShapeType::Vector {
                collection: "embeddings".into(),
                field_name: Some("title".into()),
            },
            description: "title embeddings".into(),
            field_filter: vec![],
        };
        assert!(shape.could_match("embeddings", "e1"));
        assert!(!shape.could_match("other", "e1"));
        assert_eq!(shape.collection(), Some("embeddings"));
    }

    /// Encoding to named MessagePack and decoding again preserves every
    /// field, including the payload carried inside the enum variant.
    #[test]
    fn msgpack_roundtrip() {
        let shape = ShapeDefinition {
            shape_id: "test".into(),
            tenant_id: 5,
            shape_type: ShapeType::Document {
                collection: "users".into(),
                predicate: vec![1, 2, 3],
            },
            description: "test shape".into(),
            field_filter: vec![],
        };
        let bytes = rmp_serde::to_vec_named(&shape).unwrap();
        let decoded: ShapeDefinition = rmp_serde::from_slice(&bytes).unwrap();

        assert_eq!(decoded.shape_id, "test");
        assert_eq!(decoded.tenant_id, 5);
        assert_eq!(decoded.description, "test shape");
        assert!(decoded.field_filter.is_empty());

        // Checking only the variant tag would let a corrupted payload slip
        // through, so compare the variant's contents as well.
        match decoded.shape_type {
            ShapeType::Document {
                collection,
                predicate,
            } => {
                assert_eq!(collection, "users");
                assert_eq!(predicate, vec![1u8, 2, 3]);
            }
            other => panic!("expected Document shape, got {:?}", other),
        }
    }
}