1use super::{
2 PyAttributes, PyGraphRecord, PyGroup, PyNodeIndex,
3 attribute::PyGraphRecordAttribute,
4 datatype::PyDataType,
5 errors::PyGraphRecordError,
6 traits::{DeepFrom, DeepInto},
7};
8use graphrecords_core::{
9 errors::GraphError,
10 graphrecord::{
11 EdgeIndex, Group,
12 schema::{AttributeDataType, AttributeType, GroupSchema, Schema, SchemaType},
13 },
14};
15use parking_lot::RwLock;
16use pyo3::prelude::*;
17use std::collections::HashMap;
18
19#[pyclass(frozen, eq, eq_int)]
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub enum PyAttributeType {
22 Categorical = 0,
23 Continuous = 1,
24 Temporal = 2,
25 Unstructured = 3,
26}
27
28impl From<AttributeType> for PyAttributeType {
29 fn from(value: AttributeType) -> Self {
30 match value {
31 AttributeType::Categorical => Self::Categorical,
32 AttributeType::Continuous => Self::Continuous,
33 AttributeType::Temporal => Self::Temporal,
34 AttributeType::Unstructured => Self::Unstructured,
35 }
36 }
37}
38
39impl From<PyAttributeType> for AttributeType {
40 fn from(value: PyAttributeType) -> Self {
41 match value {
42 PyAttributeType::Categorical => Self::Categorical,
43 PyAttributeType::Continuous => Self::Continuous,
44 PyAttributeType::Temporal => Self::Temporal,
45 PyAttributeType::Unstructured => Self::Unstructured,
46 }
47 }
48}
49
50#[pymethods]
51impl PyAttributeType {
52 #[staticmethod]
53 pub fn infer(data_type: PyDataType) -> Self {
54 AttributeType::infer(&data_type.into()).into()
55 }
56}
57
58#[pyclass(frozen)]
59#[derive(Debug, Clone)]
60pub struct PyAttributeDataType {
61 data_type: PyDataType,
62 attribute_type: PyAttributeType,
63}
64
65impl From<AttributeDataType> for PyAttributeDataType {
66 fn from(value: AttributeDataType) -> Self {
67 Self {
68 data_type: value.data_type().clone().into(),
69 attribute_type: (*value.attribute_type()).into(),
70 }
71 }
72}
73
74impl TryFrom<PyAttributeDataType> for AttributeDataType {
75 type Error = GraphError;
76
77 fn try_from(value: PyAttributeDataType) -> Result<Self, Self::Error> {
78 Self::new(value.data_type.into(), value.attribute_type.into())
79 }
80}
81
82impl DeepFrom<AttributeDataType> for PyAttributeDataType {
83 fn deep_from(value: AttributeDataType) -> Self {
84 value.into()
85 }
86}
87
88#[pymethods]
89impl PyAttributeDataType {
90 #[new]
91 #[pyo3(signature = (data_type, attribute_type))]
92 pub const fn new(data_type: PyDataType, attribute_type: PyAttributeType) -> Self {
93 Self {
94 data_type,
95 attribute_type,
96 }
97 }
98
99 #[getter]
100 pub fn data_type(&self) -> PyDataType {
101 self.data_type.clone()
102 }
103
104 #[getter]
105 pub fn attribute_type(&self) -> PyAttributeType {
106 self.attribute_type.clone()
107 }
108}
109
110#[pyclass(frozen)]
111#[repr(transparent)]
112#[derive(Debug, Clone)]
113pub struct PyGroupSchema(GroupSchema);
114
115impl From<GroupSchema> for PyGroupSchema {
116 fn from(value: GroupSchema) -> Self {
117 Self(value)
118 }
119}
120
121impl From<PyGroupSchema> for GroupSchema {
122 fn from(value: PyGroupSchema) -> Self {
123 value.0
124 }
125}
126
127impl DeepFrom<GroupSchema> for PyGroupSchema {
128 fn deep_from(value: GroupSchema) -> Self {
129 value.into()
130 }
131}
132
133impl DeepFrom<PyGroupSchema> for GroupSchema {
134 fn deep_from(value: PyGroupSchema) -> Self {
135 value.into()
136 }
137}
138
139#[pymethods]
140impl PyGroupSchema {
141 #[new]
142 pub fn new(
143 nodes: HashMap<PyGraphRecordAttribute, PyAttributeDataType>,
144 edges: HashMap<PyGraphRecordAttribute, PyAttributeDataType>,
145 ) -> PyResult<Self> {
146 let nodes = nodes
147 .into_iter()
148 .map(|(k, v)| Ok((k.into(), v.try_into()?)))
149 .collect::<Result<HashMap<_, _>, GraphError>>()
150 .map_err(PyGraphRecordError::from)?
151 .into();
152 let edges = edges
153 .into_iter()
154 .map(|(k, v)| Ok((k.into(), v.try_into()?)))
155 .collect::<Result<HashMap<_, _>, GraphError>>()
156 .map_err(PyGraphRecordError::from)?
157 .into();
158
159 Ok(Self(GroupSchema::new(nodes, edges)))
160 }
161
162 #[getter]
163 pub fn nodes(&self) -> HashMap<PyGraphRecordAttribute, PyAttributeDataType> {
164 self.0.nodes().clone().deep_into()
165 }
166
167 #[getter]
168 pub fn edges(&self) -> HashMap<PyGraphRecordAttribute, PyAttributeDataType> {
169 self.0.edges().clone().deep_into()
170 }
171
172 pub fn validate_node(&self, index: PyNodeIndex, attributes: PyAttributes) -> PyResult<()> {
173 Ok(self
174 .0
175 .validate_node(&index.into(), &attributes.deep_into())
176 .map_err(PyGraphRecordError::from)?)
177 }
178
179 pub fn validate_edge(&self, index: EdgeIndex, attributes: PyAttributes) -> PyResult<()> {
180 Ok(self
181 .0
182 .validate_edge(&index, &attributes.deep_into())
183 .map_err(PyGraphRecordError::from)?)
184 }
185}
186
187#[pyclass(frozen, eq, eq_int)]
188#[derive(Debug, Clone, PartialEq, Eq)]
189pub enum PySchemaType {
190 Provided = 0,
191 Inferred = 1,
192}
193
194impl From<SchemaType> for PySchemaType {
195 fn from(value: SchemaType) -> Self {
196 match value {
197 SchemaType::Provided => Self::Provided,
198 SchemaType::Inferred => Self::Inferred,
199 }
200 }
201}
202
203impl From<PySchemaType> for SchemaType {
204 fn from(value: PySchemaType) -> Self {
205 match value {
206 PySchemaType::Provided => Self::Provided,
207 PySchemaType::Inferred => Self::Inferred,
208 }
209 }
210}
211
212#[pyclass(frozen)]
213#[repr(transparent)]
214#[derive(Debug)]
215pub struct PySchema(RwLock<Schema>);
216
217impl From<Schema> for PySchema {
218 fn from(value: Schema) -> Self {
219 Self(RwLock::new(value))
220 }
221}
222
223impl From<PySchema> for Schema {
224 fn from(value: PySchema) -> Self {
225 value.0.into_inner()
226 }
227}
228
229impl Clone for PySchema {
230 fn clone(&self) -> Self {
231 Self(RwLock::new(self.0.read().clone()))
232 }
233}
234
235#[pymethods]
236impl PySchema {
237 #[new]
238 #[pyo3(signature = (groups, ungrouped, schema_type=PySchemaType::Provided))]
239 pub fn new(
240 groups: HashMap<PyGroup, PyGroupSchema>,
241 ungrouped: PyGroupSchema,
242 schema_type: PySchemaType,
243 ) -> Self {
244 match schema_type {
245 PySchemaType::Provided => {
246 Schema::new_provided(groups.deep_into(), ungrouped.deep_into()).into()
247 }
248 PySchemaType::Inferred => {
249 Schema::new_inferred(groups.deep_into(), ungrouped.deep_into()).into()
250 }
251 }
252 }
253
254 #[staticmethod]
255 pub fn infer(graphrecord: Bound<'_, PyGraphRecord>) -> PyResult<Self> {
256 let graphrecord = graphrecord.get();
257
258 Ok(Schema::infer(&*graphrecord.inner()?).into())
259 }
260
261 #[getter]
262 pub fn groups(&self) -> Vec<PyGroup> {
263 self.0
264 .read()
265 .groups()
266 .keys()
267 .cloned()
268 .collect::<Vec<Group>>()
269 .deep_into()
270 }
271
272 pub fn group(&self, group: PyGroup) -> PyResult<PyGroupSchema> {
273 Ok(self
274 .0
275 .read()
276 .group(&group.into())
277 .map(|g| g.clone().into())
278 .map_err(PyGraphRecordError::from)?)
279 }
280
281 #[getter]
282 pub fn ungrouped(&self) -> PyGroupSchema {
283 self.0.read().ungrouped().clone().into()
284 }
285
286 #[getter]
287 pub fn schema_type(&self) -> PySchemaType {
288 self.0.read().schema_type().clone().into()
289 }
290
291 #[pyo3(signature = (index, attributes, group=None))]
292 pub fn validate_node(
293 &self,
294 index: PyNodeIndex,
295 attributes: PyAttributes,
296 group: Option<PyGroup>,
297 ) -> PyResult<()> {
298 Ok(self
299 .0
300 .read()
301 .validate_node(
302 &index.into(),
303 &attributes.deep_into(),
304 group.map(std::convert::Into::into).as_ref(),
305 )
306 .map_err(PyGraphRecordError::from)?)
307 }
308
309 #[pyo3(signature = (index, attributes, group=None))]
310 pub fn validate_edge(
311 &self,
312 index: EdgeIndex,
313 attributes: PyAttributes,
314 group: Option<PyGroup>,
315 ) -> PyResult<()> {
316 Ok(self
317 .0
318 .read()
319 .validate_edge(
320 &index,
321 &attributes.deep_into(),
322 group.map(std::convert::Into::into).as_ref(),
323 )
324 .map_err(PyGraphRecordError::from)?)
325 }
326
327 #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
328 pub fn set_node_attribute(
329 &self,
330 attribute: PyGraphRecordAttribute,
331 data_type: PyDataType,
332 attribute_type: PyAttributeType,
333 group: Option<PyGroup>,
334 ) -> PyResult<()> {
335 Ok(self
336 .0
337 .write()
338 .set_node_attribute(
339 &attribute.into(),
340 data_type.into(),
341 attribute_type.into(),
342 group.map(std::convert::Into::into).as_ref(),
343 )
344 .map_err(PyGraphRecordError::from)?)
345 }
346
347 #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
348 pub fn set_edge_attribute(
349 &self,
350 attribute: PyGraphRecordAttribute,
351 data_type: PyDataType,
352 attribute_type: PyAttributeType,
353 group: Option<PyGroup>,
354 ) -> PyResult<()> {
355 Ok(self
356 .0
357 .write()
358 .set_edge_attribute(
359 &attribute.into(),
360 data_type.into(),
361 attribute_type.into(),
362 group.map(std::convert::Into::into).as_ref(),
363 )
364 .map_err(PyGraphRecordError::from)?)
365 }
366
367 #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
368 pub fn update_node_attribute(
369 &self,
370 attribute: PyGraphRecordAttribute,
371 data_type: PyDataType,
372 attribute_type: PyAttributeType,
373 group: Option<PyGroup>,
374 ) -> PyResult<()> {
375 Ok(self
376 .0
377 .write()
378 .update_node_attribute(
379 &attribute.into(),
380 data_type.into(),
381 attribute_type.into(),
382 group.map(std::convert::Into::into).as_ref(),
383 )
384 .map_err(PyGraphRecordError::from)?)
385 }
386
387 #[pyo3(signature = (attribute, data_type, attribute_type, group=None))]
388 pub fn update_edge_attribute(
389 &self,
390 attribute: PyGraphRecordAttribute,
391 data_type: PyDataType,
392 attribute_type: PyAttributeType,
393 group: Option<PyGroup>,
394 ) -> PyResult<()> {
395 Ok(self
396 .0
397 .write()
398 .update_edge_attribute(
399 &attribute.into(),
400 data_type.into(),
401 attribute_type.into(),
402 group.map(std::convert::Into::into).as_ref(),
403 )
404 .map_err(PyGraphRecordError::from)?)
405 }
406
407 #[pyo3(signature = (attribute, group=None))]
408 pub fn remove_node_attribute(&self, attribute: PyGraphRecordAttribute, group: Option<PyGroup>) {
409 self.0.write().remove_node_attribute(
410 &attribute.into(),
411 group.map(std::convert::Into::into).as_ref(),
412 );
413 }
414
415 #[pyo3(signature = (attribute, group=None))]
416 pub fn remove_edge_attribute(&self, attribute: PyGraphRecordAttribute, group: Option<PyGroup>) {
417 self.0.write().remove_edge_attribute(
418 &attribute.into(),
419 group.map(std::convert::Into::into).as_ref(),
420 );
421 }
422
423 pub fn add_group(&self, group: PyGroup, schema: PyGroupSchema) -> PyResult<()> {
424 Ok(self
425 .0
426 .write()
427 .add_group(group.into(), schema.into())
428 .map_err(PyGraphRecordError::from)?)
429 }
430
431 pub fn remove_group(&self, group: PyGroup) {
432 self.0.write().remove_group(&group.into());
433 }
434
435 pub fn freeze(&self) {
436 self.0.write().freeze();
437 }
438
439 pub fn unfreeze(&self) {
440 self.0.write().unfreeze();
441 }
442}