// graphrecords_python/graphrecord/schema.rs

1use super::{
2    PyAttributes, PyGraphRecord, PyGroup, PyNodeIndex,
3    attribute::PyGraphRecordAttribute,
4    datatype::PyDataType,
5    errors::PyGraphRecordError,
6    traits::{DeepFrom, DeepInto},
7};
8use graphrecords_core::{
9    errors::GraphError,
10    graphrecord::{
11        EdgeIndex, Group,
12        schema::{AttributeDataType, AttributeType, GroupSchema, Schema, SchemaType},
13    },
14};
15use parking_lot::RwLock;
16use pyo3::prelude::*;
17use std::collections::HashMap;
18
19#[pyclass(frozen, eq, eq_int)]
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub enum PyAttributeType {
22    Categorical = 0,
23    Continuous = 1,
24    Temporal = 2,
25    Unstructured = 3,
26}
27
28impl From<AttributeType> for PyAttributeType {
29    fn from(value: AttributeType) -> Self {
30        match value {
31            AttributeType::Categorical => Self::Categorical,
32            AttributeType::Continuous => Self::Continuous,
33            AttributeType::Temporal => Self::Temporal,
34            AttributeType::Unstructured => Self::Unstructured,
35        }
36    }
37}
38
39impl From<PyAttributeType> for AttributeType {
40    fn from(value: PyAttributeType) -> Self {
41        match value {
42            PyAttributeType::Categorical => Self::Categorical,
43            PyAttributeType::Continuous => Self::Continuous,
44            PyAttributeType::Temporal => Self::Temporal,
45            PyAttributeType::Unstructured => Self::Unstructured,
46        }
47    }
48}
49
50#[pymethods]
51impl PyAttributeType {
52    #[staticmethod]
53    pub fn infer(data_type: PyDataType) -> Self {
54        AttributeType::infer(&data_type.into()).into()
55    }
56}
57
58#[pyclass(frozen)]
59#[derive(Debug, Clone)]
60pub struct PyAttributeDataType {
61    data_type: PyDataType,
62    attribute_type: PyAttributeType,
63}
64
65impl From<AttributeDataType> for PyAttributeDataType {
66    fn from(value: AttributeDataType) -> Self {
67        Self {
68            data_type: value.data_type().clone().into(),
69            attribute_type: (*value.attribute_type()).into(),
70        }
71    }
72}
73
74impl TryFrom<PyAttributeDataType> for AttributeDataType {
75    type Error = GraphError;
76
77    fn try_from(value: PyAttributeDataType) -> Result<Self, Self::Error> {
78        Self::new(value.data_type.into(), value.attribute_type.into())
79    }
80}
81
82impl DeepFrom<AttributeDataType> for PyAttributeDataType {
83    fn deep_from(value: AttributeDataType) -> Self {
84        value.into()
85    }
86}
87
88#[pymethods]
89impl PyAttributeDataType {
90    #[new]
91    #[pyo3(signature = (data_type, attribute_type))]
92    pub const fn new(data_type: PyDataType, attribute_type: PyAttributeType) -> Self {
93        Self {
94            data_type,
95            attribute_type,
96        }
97    }
98
99    #[getter]
100    pub fn data_type(&self) -> PyDataType {
101        self.data_type.clone()
102    }
103
104    #[getter]
105    pub fn attribute_type(&self) -> PyAttributeType {
106        self.attribute_type.clone()
107    }
108}
109
110#[pyclass(frozen)]
111#[repr(transparent)]
112#[derive(Debug, Clone)]
113pub struct PyGroupSchema(GroupSchema);
114
115impl From<GroupSchema> for PyGroupSchema {
116    fn from(value: GroupSchema) -> Self {
117        Self(value)
118    }
119}
120
121impl From<PyGroupSchema> for GroupSchema {
122    fn from(value: PyGroupSchema) -> Self {
123        value.0
124    }
125}
126
127impl DeepFrom<GroupSchema> for PyGroupSchema {
128    fn deep_from(value: GroupSchema) -> Self {
129        value.into()
130    }
131}
132
133impl DeepFrom<PyGroupSchema> for GroupSchema {
134    fn deep_from(value: PyGroupSchema) -> Self {
135        value.into()
136    }
137}
138
139#[pymethods]
140impl PyGroupSchema {
141    #[new]
142    pub fn new(
143        nodes: HashMap<PyGraphRecordAttribute, PyAttributeDataType>,
144        edges: HashMap<PyGraphRecordAttribute, PyAttributeDataType>,
145    ) -> PyResult<Self> {
146        let nodes = nodes
147            .into_iter()
148            .map(|(k, v)| Ok((k.into(), v.try_into()?)))
149            .collect::<Result<HashMap<_, _>, GraphError>>()
150            .map_err(PyGraphRecordError::from)?
151            .into();
152        let edges = edges
153            .into_iter()
154            .map(|(k, v)| Ok((k.into(), v.try_into()?)))
155            .collect::<Result<HashMap<_, _>, GraphError>>()
156            .map_err(PyGraphRecordError::from)?
157            .into();
158
159        Ok(Self(GroupSchema::new(nodes, edges)))
160    }
161
162    #[getter]
163    pub fn nodes(&self) -> HashMap<PyGraphRecordAttribute, PyAttributeDataType> {
164        self.0.nodes().clone().deep_into()
165    }
166
167    #[getter]
168    pub fn edges(&self) -> HashMap<PyGraphRecordAttribute, PyAttributeDataType> {
169        self.0.edges().clone().deep_into()
170    }
171
172    pub fn validate_node(&self, index: PyNodeIndex, attributes: PyAttributes) -> PyResult<()> {
173        Ok(self
174            .0
175            .validate_node(&index.into(), &attributes.deep_into())
176            .map_err(PyGraphRecordError::from)?)
177    }
178
179    pub fn validate_edge(&self, index: EdgeIndex, attributes: PyAttributes) -> PyResult<()> {
180        Ok(self
181            .0
182            .validate_edge(&index, &attributes.deep_into())
183            .map_err(PyGraphRecordError::from)?)
184    }
185}
186
187#[pyclass(frozen, eq, eq_int)]
188#[derive(Debug, Clone, PartialEq, Eq)]
189pub enum PySchemaType {
190    Provided = 0,
191    Inferred = 1,
192}
193
194impl From<SchemaType> for PySchemaType {
195    fn from(value: SchemaType) -> Self {
196        match value {
197            SchemaType::Provided => Self::Provided,
198            SchemaType::Inferred => Self::Inferred,
199        }
200    }
201}
202
203impl From<PySchemaType> for SchemaType {
204    fn from(value: PySchemaType) -> Self {
205        match value {
206            PySchemaType::Provided => Self::Provided,
207            PySchemaType::Inferred => Self::Inferred,
208        }
209    }
210}
211
212#[pyclass(frozen)]
213#[repr(transparent)]
214#[derive(Debug)]
215pub struct PySchema(RwLock<Schema>);
216
217impl From<Schema> for PySchema {
218    fn from(value: Schema) -> Self {
219        Self(RwLock::new(value))
220    }
221}
222
223impl From<PySchema> for Schema {
224    fn from(value: PySchema) -> Self {
225        value.0.into_inner()
226    }
227}
228
229impl Clone for PySchema {
230    fn clone(&self) -> Self {
231        Self(RwLock::new(self.0.read().clone()))
232    }
233}
234
235#[pymethods]
236impl PySchema {
237    #[new]
238    #[pyo3(signature = (groups, ungrouped, schema_type=PySchemaType::Provided))]
239    pub fn new(
240        groups: HashMap<PyGroup, PyGroupSchema>,
241        ungrouped: PyGroupSchema,
242        schema_type: PySchemaType,
243    ) -> Self {
244        match schema_type {
245            PySchemaType::Provided => {
246                Schema::new_provided(groups.deep_into(), ungrouped.deep_into()).into()
247            }
248            PySchemaType::Inferred => {
249                Schema::new_inferred(groups.deep_into(), ungrouped.deep_into()).into()
250            }
251        }
252    }
253
254    #[staticmethod]
255    pub fn infer(graphrecord: Bound<'_, PyGraphRecord>) -> PyResult<Self> {
256        let graphrecord = graphrecord.get();
257
258        Ok(Schema::infer(&*graphrecord.inner()?).into())
259    }
260
261    #[getter]
262    pub fn groups(&self) -> Vec<PyGroup> {
263        self.0
264            .read()
265            .groups()
266            .keys()
267            .cloned()
268            .collect::<Vec<Group>>()
269            .deep_into()
270    }
271
272    pub fn group(&self, group: PyGroup) -> PyResult<PyGroupSchema> {
273        Ok(self
274            .0
275            .read()
276            .group(&group.into())
277            .map(|g| g.clone().into())
278            .map_err(PyGraphRecordError::from)?)
279    }
280
281    #[getter]
282    pub fn ungrouped(&self) -> PyGroupSchema {
283        self.0.read().ungrouped().clone().into()
284    }
285
286    #[getter]
287    pub fn schema_type(&self) -> PySchemaType {
288        self.0.read().schema_type().clone().into()
289    }
290
291    #[pyo3(signature = (index, attributes, group=None))]
292    pub fn validate_node(
293        &self,
294        index: PyNodeIndex,
295        attributes: PyAttributes,
296        group: Option<PyGroup>,
297    ) -> PyResult<()> {
298        Ok(self
299            .0
300            .read()
301            .validate_node(
302                &index.into(),
303                &attributes.deep_into(),
304                group.map(std::convert::Into::into).as_ref(),
305            )
306            .map_err(PyGraphRecordError::from)?)
307    }
308
309    #[pyo3(signature = (index, attributes, group=None))]
310    pub fn validate_edge(
311        &self,
312        index: EdgeIndex,
313        attributes: PyAttributes,
314        group: Option<PyGroup>,
315    ) -> PyResult<()> {
316        Ok(self
317            .0
318            .read()
319            .validate_edge(
320                &index,
321                &attributes.deep_into(),
322                group.map(std::convert::Into::into).as_ref(),
323            )
324            .map_err(PyGraphRecordError::from)?)
325    }
326
327    #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
328    pub fn set_node_attribute(
329        &self,
330        attribute: PyGraphRecordAttribute,
331        data_type: PyDataType,
332        attribute_type: PyAttributeType,
333        group: Option<PyGroup>,
334    ) -> PyResult<()> {
335        Ok(self
336            .0
337            .write()
338            .set_node_attribute(
339                &attribute.into(),
340                data_type.into(),
341                attribute_type.into(),
342                group.map(std::convert::Into::into).as_ref(),
343            )
344            .map_err(PyGraphRecordError::from)?)
345    }
346
347    #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
348    pub fn set_edge_attribute(
349        &self,
350        attribute: PyGraphRecordAttribute,
351        data_type: PyDataType,
352        attribute_type: PyAttributeType,
353        group: Option<PyGroup>,
354    ) -> PyResult<()> {
355        Ok(self
356            .0
357            .write()
358            .set_edge_attribute(
359                &attribute.into(),
360                data_type.into(),
361                attribute_type.into(),
362                group.map(std::convert::Into::into).as_ref(),
363            )
364            .map_err(PyGraphRecordError::from)?)
365    }
366
367    #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
368    pub fn update_node_attribute(
369        &self,
370        attribute: PyGraphRecordAttribute,
371        data_type: PyDataType,
372        attribute_type: PyAttributeType,
373        group: Option<PyGroup>,
374    ) -> PyResult<()> {
375        Ok(self
376            .0
377            .write()
378            .update_node_attribute(
379                &attribute.into(),
380                data_type.into(),
381                attribute_type.into(),
382                group.map(std::convert::Into::into).as_ref(),
383            )
384            .map_err(PyGraphRecordError::from)?)
385    }
386
387    #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
388    pub fn update_edge_attribute(
389        &self,
390        attribute: PyGraphRecordAttribute,
391        data_type: PyDataType,
392        attribute_type: PyAttributeType,
393        group: Option<PyGroup>,
394    ) -> PyResult<()> {
395        Ok(self
396            .0
397            .write()
398            .update_edge_attribute(
399                &attribute.into(),
400                data_type.into(),
401                attribute_type.into(),
402                group.map(std::convert::Into::into).as_ref(),
403            )
404            .map_err(PyGraphRecordError::from)?)
405    }
406
407    #[pyo3(signature = (attribute, group=None))]
408    pub fn remove_node_attribute(&self, attribute: PyGraphRecordAttribute, group: Option<PyGroup>) {
409        self.0.write().remove_node_attribute(
410            &attribute.into(),
411            group.map(std::convert::Into::into).as_ref(),
412        );
413    }
414
415    #[pyo3(signature = (attribute, group=None))]
416    pub fn remove_edge_attribute(&self, attribute: PyGraphRecordAttribute, group: Option<PyGroup>) {
417        self.0.write().remove_edge_attribute(
418            &attribute.into(),
419            group.map(std::convert::Into::into).as_ref(),
420        );
421    }
422
423    pub fn add_group(&self, group: PyGroup, schema: PyGroupSchema) -> PyResult<()> {
424        Ok(self
425            .0
426            .write()
427            .add_group(group.into(), schema.into())
428            .map_err(PyGraphRecordError::from)?)
429    }
430
431    pub fn remove_group(&self, group: PyGroup) {
432        self.0.write().remove_group(&group.into());
433    }
434
435    pub fn freeze(&self) {
436        self.0.write().freeze();
437    }
438
439    pub fn unfreeze(&self) {
440        self.0.write().unfreeze();
441    }
442}