Skip to main content

panproto_protocols/data_science/
dataframe.rs

//! DataFrame schema protocol definition (pandera-style).
//!
//! DataFrame uses a constrained hypergraph schema theory
//! (`colimit(ThHypergraph, ThConstraint)`) and a set-valued functor
//! instance theory (`ThFunctor`).
6
7use std::collections::HashMap;
8use std::hash::BuildHasher;
9
10use panproto_gat::Theory;
11use panproto_schema::{EdgeRule, Protocol, Schema, SchemaBuilder};
12
13use crate::emit::{children_by_edge, find_roots, vertex_constraints};
14use crate::error::ProtocolError;
15use crate::theories;
16
17/// Returns the `DataFrame` protocol definition.
18#[must_use]
19pub fn protocol() -> Protocol {
20    Protocol {
21        name: "dataframe".into(),
22        schema_theory: "ThDataFrameSchema".into(),
23        instance_theory: "ThDataFrameInstance".into(),
24        edge_rules: edge_rules(),
25        obj_kinds: vec![
26            "dataframe".into(),
27            "column".into(),
28            "index".into(),
29            "string".into(),
30            "int64".into(),
31            "float64".into(),
32            "bool".into(),
33            "datetime".into(),
34            "timedelta".into(),
35            "category".into(),
36            "object".into(),
37        ],
38        constraint_sorts: vec![
39            "nullable".into(),
40            "unique".into(),
41            "coerce".into(),
42            "regex".into(),
43            "ge".into(),
44            "le".into(),
45            "gt".into(),
46            "lt".into(),
47            "isin".into(),
48        ],
49        has_order: true,
50        nominal_identity: true,
51        ..Protocol::default()
52    }
53}
54
55/// Register the component GATs for `DataFrame` with a theory registry.
56pub fn register_theories<S: BuildHasher>(registry: &mut HashMap<String, Theory, S>) {
57    theories::register_hypergraph_functor(registry, "ThDataFrameSchema", "ThDataFrameInstance");
58}
59
60/// Parse a pandera-style `DataFrame` schema JSON into a [`Schema`].
61///
62/// Expects a JSON object with `columns` (object mapping names to column defs)
63/// and optional `index` array.
64///
65/// # Errors
66///
67/// Returns [`ProtocolError`] if the JSON is invalid.
68pub fn parse_dataframe_schema(json: &serde_json::Value) -> Result<Schema, ProtocolError> {
69    let proto = protocol();
70    let mut builder = SchemaBuilder::new(&proto);
71    let mut he_counter: usize = 0;
72
73    let df_name = json
74        .get("name")
75        .and_then(serde_json::Value::as_str)
76        .unwrap_or("dataframe");
77
78    builder = builder.vertex(df_name, "dataframe", None)?;
79
80    let mut sig = HashMap::new();
81
82    // Parse columns.
83    if let Some(columns) = json.get("columns").and_then(serde_json::Value::as_object) {
84        for (col_name, col_def) in columns {
85            let col_id = format!("{df_name}.{col_name}");
86            let dtype = col_def
87                .get("dtype")
88                .and_then(serde_json::Value::as_str)
89                .unwrap_or("object");
90            let kind = df_type_to_kind(dtype);
91
92            builder = builder.vertex(&col_id, &kind, None)?;
93            builder = builder.edge(df_name, &col_id, "prop", Some(col_name))?;
94            sig.insert(col_name.clone(), col_id.clone());
95
96            // Parse column constraints.
97            if let Some(nullable) = col_def.get("nullable").and_then(serde_json::Value::as_bool) {
98                builder = builder.constraint(&col_id, "nullable", &nullable.to_string());
99            }
100            if let Some(unique) = col_def.get("unique").and_then(serde_json::Value::as_bool) {
101                if unique {
102                    builder = builder.constraint(&col_id, "unique", "true");
103                }
104            }
105            if let Some(coerce) = col_def.get("coerce").and_then(serde_json::Value::as_bool) {
106                if coerce {
107                    builder = builder.constraint(&col_id, "coerce", "true");
108                }
109            }
110            if let Some(regex) = col_def.get("regex").and_then(serde_json::Value::as_bool) {
111                if regex {
112                    builder = builder.constraint(&col_id, "regex", "true");
113                }
114            }
115
116            // Parse checks.
117            if let Some(checks) = col_def.get("checks").and_then(serde_json::Value::as_object) {
118                for (check_name, check_val) in checks {
119                    match check_name.as_str() {
120                        "ge" | "le" | "gt" | "lt" => {
121                            builder = builder.constraint(
122                                &col_id,
123                                check_name,
124                                &json_val_to_string(check_val),
125                            );
126                        }
127                        "isin" => {
128                            if let Some(arr) = check_val.as_array() {
129                                let vals: Vec<String> = arr
130                                    .iter()
131                                    .map(|v| v.as_str().map_or_else(|| v.to_string(), String::from))
132                                    .collect();
133                                builder = builder.constraint(&col_id, "isin", &vals.join(","));
134                            }
135                        }
136                        _ => {}
137                    }
138                }
139            }
140        }
141    } else {
142        return Err(ProtocolError::MissingField("columns".into()));
143    }
144
145    // Parse index columns.
146    if let Some(index) = json.get("index").and_then(serde_json::Value::as_array) {
147        for idx_def in index {
148            let idx_name = idx_def
149                .get("name")
150                .and_then(serde_json::Value::as_str)
151                .unwrap_or("index");
152            let idx_id = format!("{df_name}:idx:{idx_name}");
153            let dtype = idx_def
154                .get("dtype")
155                .and_then(serde_json::Value::as_str)
156                .unwrap_or("int64");
157            let kind = df_type_to_kind(dtype);
158
159            builder = builder.vertex(&idx_id, &kind, None)?;
160            builder = builder.edge(df_name, &idx_id, "prop", Some(idx_name))?;
161            sig.insert(idx_name.to_string(), idx_id);
162        }
163    }
164
165    if !sig.is_empty() {
166        let he_id = format!("he_{he_counter}");
167        he_counter += 1;
168        builder = builder.hyper_edge(&he_id, "dataframe", sig, df_name)?;
169    }
170    let _ = he_counter;
171
172    let schema = builder.build()?;
173    Ok(schema)
174}
175
176/// Emit a [`Schema`] as pandera-style `DataFrame` schema JSON.
177///
178/// # Errors
179///
180/// Returns [`ProtocolError::Emit`] if the schema cannot be serialized.
181pub fn emit_dataframe_schema(schema: &Schema) -> Result<serde_json::Value, ProtocolError> {
182    let roots: Vec<_> = find_roots(schema, &["prop"]);
183    let df = roots
184        .into_iter()
185        .find(|v| v.kind == "dataframe")
186        .ok_or_else(|| ProtocolError::Emit("no dataframe vertex found".into()))?;
187
188    let children = children_by_edge(schema, &df.id, "prop");
189    let mut columns = serde_json::Map::new();
190
191    for (edge, vertex) in &children {
192        let col_name = edge.name.as_deref().unwrap_or(&vertex.id);
193        let dtype = kind_to_df_type(&vertex.kind);
194
195        let mut col_obj = serde_json::json!({ "dtype": dtype });
196
197        let constraints = vertex_constraints(schema, &vertex.id);
198        let mut checks = serde_json::Map::new();
199
200        for c in &constraints {
201            match c.sort.as_str() {
202                "nullable" => {
203                    col_obj["nullable"] = serde_json::Value::Bool(c.value == "true");
204                }
205                "unique" if c.value == "true" => {
206                    col_obj["unique"] = serde_json::Value::Bool(true);
207                }
208                "coerce" if c.value == "true" => {
209                    col_obj["coerce"] = serde_json::Value::Bool(true);
210                }
211                "regex" if c.value == "true" => {
212                    col_obj["regex"] = serde_json::Value::Bool(true);
213                }
214                "ge" | "le" | "gt" | "lt" => {
215                    let val = c.value.parse::<f64>().map_or_else(
216                        |_| serde_json::Value::String(c.value.clone()),
217                        |n| serde_json::json!(n),
218                    );
219                    checks.insert(c.sort.to_string(), val);
220                }
221                "isin" => {
222                    let vals: Vec<serde_json::Value> = c
223                        .value
224                        .split(',')
225                        .map(|s| serde_json::Value::String(s.to_string()))
226                        .collect();
227                    checks.insert("isin".into(), serde_json::Value::Array(vals));
228                }
229                _ => {}
230            }
231        }
232
233        if !checks.is_empty() {
234            col_obj["checks"] = serde_json::Value::Object(checks);
235        }
236
237        columns.insert(col_name.to_string(), col_obj);
238    }
239
240    Ok(serde_json::json!({
241        "name": df.id,
242        "columns": columns
243    }))
244}
245
/// Maps a dtype name to the vertex kind used by this protocol.
///
/// The comparison is case-insensitive. Several aliases collapse onto one
/// kind (for example every listed integer width becomes `int64`), and the
/// dtype `object` is treated as `string`. Any name that is not listed maps
/// to the `object` kind.
fn df_type_to_kind(dtype: &str) -> String {
    // Each row pairs the accepted (lower-case) aliases with their kind.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["str", "string", "object"], "string"),
        (&["int", "int64", "int32", "int16", "int8"], "int64"),
        (&["float", "float64", "float32"], "float64"),
        (&["bool", "boolean"], "bool"),
        (&["datetime", "datetime64", "datetime64[ns]"], "datetime"),
        (&["timedelta", "timedelta64", "timedelta64[ns]"], "timedelta"),
        (&["category"], "category"),
    ];

    let wanted = dtype.to_lowercase();
    ALIASES
        .iter()
        .find(|(names, _)| names.contains(&wanted.as_str()))
        .map_or("object", |&(_, kind)| kind)
        .to_owned()
}
259
/// Maps a vertex kind back to the dtype name written on emit.
///
/// `datetime` and `timedelta` gain their `64[ns]` suffix; the other listed
/// kinds map to themselves. Any kind that is not listed is emitted as
/// `object`. The match is exact (case-sensitive).
fn kind_to_df_type(kind: &str) -> &'static str {
    const DTYPES: &[(&str, &str)] = &[
        ("string", "string"),
        ("int64", "int64"),
        ("float64", "float64"),
        ("bool", "bool"),
        ("datetime", "datetime64[ns]"),
        ("timedelta", "timedelta64[ns]"),
        ("category", "category"),
    ];

    DTYPES
        .iter()
        .find(|(known, _)| *known == kind)
        .map_or("object", |&(_, dtype)| dtype)
}
272
273fn json_val_to_string(val: &serde_json::Value) -> String {
274    match val {
275        serde_json::Value::String(s) => s.clone(),
276        serde_json::Value::Number(n) => n.to_string(),
277        _ => val.to_string(),
278    }
279}
280
281fn edge_rules() -> Vec<EdgeRule> {
282    vec![EdgeRule {
283        edge_kind: "prop".into(),
284        src_kinds: vec!["dataframe".into()],
285        tgt_kinds: vec![],
286    }]
287}
288
// Tests may unwrap/expect freely: a panic is the desired failure mode here.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The protocol descriptor carries the expected names and a `prop` rule.
    #[test]
    fn protocol_creates_valid_definition() {
        let proto = protocol();
        assert_eq!(proto.name, "dataframe");
        assert_eq!(proto.schema_theory, "ThDataFrameSchema");
        assert!(proto.find_edge_rule("prop").is_some());
    }

    /// Both theory names end up in the registry.
    #[test]
    fn register_theories_adds_correct_theories() {
        let mut registry = HashMap::new();
        register_theories(&mut registry);
        for theory in ["ThDataFrameSchema", "ThDataFrameInstance"] {
            assert!(registry.contains_key(theory));
        }
    }

    /// A named frame yields a root vertex plus one typed vertex per column.
    #[test]
    fn parse_simple_schema() {
        let input = serde_json::json!({
            "name": "users",
            "columns": {
                "name": { "dtype": "string", "nullable": false },
                "age": { "dtype": "int64", "nullable": false, "checks": { "ge": 0, "le": 150 } },
                "score": { "dtype": "float64", "nullable": true }
            }
        });
        let parsed = parse_dataframe_schema(&input).expect("should parse");
        for id in ["users", "users.name"] {
            assert!(parsed.has_vertex(id));
        }
        for (id, kind) in [("users.name", "string"), ("users.age", "int64")] {
            assert_eq!(parsed.vertices.get(id).unwrap().kind, kind);
        }
    }

    /// An unnamed frame falls back to the `dataframe` root id.
    #[test]
    fn parse_with_checks() {
        let input = serde_json::json!({
            "columns": {
                "status": {
                    "dtype": "string",
                    "checks": { "isin": ["active", "inactive", "pending"] }
                }
            }
        });
        let parsed = parse_dataframe_schema(&input).expect("should parse");
        assert!(parsed.has_vertex("dataframe.status"));
    }

    /// Parsing then emitting preserves the number of columns.
    #[test]
    fn emit_roundtrip() {
        let input = serde_json::json!({
            "columns": {
                "x": { "dtype": "int64" },
                "y": { "dtype": "float64" }
            }
        });
        let parsed = parse_dataframe_schema(&input).expect("parse");
        let emitted = emit_dataframe_schema(&parsed).expect("emit");
        let columns = emitted["columns"].as_object().unwrap();
        assert_eq!(columns.len(), 2);
    }

    /// A document without `columns` is rejected.
    #[test]
    fn parse_missing_columns_fails() {
        let input = serde_json::json!({ "name": "broken" });
        assert!(parse_dataframe_schema(&input).is_err());
    }
}