Skip to main content

infigraph_core/structured/
cozo.rs

1use std::path::Path;
2
3use anyhow::{bail, Context, Result};
4
5use super::schema::{escape, interpolate_template, IngestResult, SchemaMeta};
6
/// Maps a schema column type name (such as `"INT64"`) to the CozoDB column
/// type written into `:create` statements.
///
/// Any name not in the table below falls back to `String`; this includes
/// `STRING[]`, which is therefore declared as a scalar `String` column.
fn cozo_col_type(col_type: &str) -> &str {
    const KNOWN: [(&str, &str); 4] = [
        ("STRING", "String"),
        ("INT64", "Int"),
        ("BOOL", "Bool"),
        ("DOUBLE", "Float"),
    ];
    KNOWN
        .iter()
        .find(|entry| entry.0 == col_type)
        .map_or("String", |entry| entry.1)
}
17
/// Returns the CozoScript literal used as the `default` value for a column of
/// the given schema type.
///
/// `INT64` → `0`, `BOOL` → `false`, `DOUBLE` → `0.0`; `STRING`, `STRING[]`
/// and every unrecognised type get the empty string literal `""`.
fn cozo_col_default(col_type: &str) -> &str {
    if col_type == "INT64" {
        "0"
    } else if col_type == "BOOL" {
        "false"
    } else if col_type == "DOUBLE" {
        "0.0"
    } else {
        // STRING, STRING[] and unknown types all default to an empty string.
        "\"\""
    }
}
27
28impl SchemaMeta {
29    pub fn generate_cozo_ddl(&self) -> Vec<String> {
30        let mut stmts = Vec::new();
31
32        let cols: Vec<String> = self
33            .columns
34            .iter()
35            .map(|c| {
36                format!(
37                    "{}: {} default {}",
38                    c.name,
39                    cozo_col_type(&c.col_type),
40                    cozo_col_default(&c.col_type)
41                )
42            })
43            .collect();
44        let table_name = self.node_table.to_lowercase();
45        if cols.is_empty() {
46            stmts.push(format!(":create {table_name} {{id: String}}"));
47        } else {
48            stmts.push(format!(
49                ":create {table_name} {{id: String => {}}}",
50                cols.join(", ")
51            ));
52        }
53
54        for edge in &self.edges {
55            let edge_name = edge.name.to_lowercase();
56            let prop_cols: Vec<String> = edge
57                .properties
58                .iter()
59                .map(|c| {
60                    format!(
61                        ", {}: {} default {}",
62                        c.name,
63                        cozo_col_type(&c.col_type),
64                        cozo_col_default(&c.col_type)
65                    )
66                })
67                .collect();
68            stmts.push(format!(
69                ":create {edge_name} {{from_id: String, to_id: String{}}}",
70                prop_cols.join("")
71            ));
72        }
73
74        stmts
75    }
76}
77
78pub fn ingest_data_cozo(
79    db: &cozo::DbInstance,
80    schema: &SchemaMeta,
81    data: &[serde_json::Value],
82) -> Result<IngestResult> {
83    for ddl in schema.generate_cozo_ddl() {
84        match db.run_script(
85            &ddl,
86            std::collections::BTreeMap::new(),
87            cozo::ScriptMutability::Mutable,
88        ) {
89            Ok(_) => {}
90            Err(e) => {
91                let msg = format!("{e}");
92                if !msg.contains("already exists") && !msg.contains("conflicts") {
93                    bail!("DDL failed: {}", e);
94                }
95            }
96        }
97    }
98
99    let table_name = schema.node_table.to_lowercase();
100    let mut nodes_created = 0usize;
101    let mut edges_created = 0usize;
102
103    for (idx, record) in data.iter().enumerate() {
104        let obj = record
105            .as_object()
106            .with_context(|| format!("record {} is not an object", idx))?;
107
108        let id = if let Some(tmpl) = &schema.id_template {
109            interpolate_template(tmpl, obj)
110        } else if let Some(v) = obj.get("id") {
111            v.as_str()
112                .unwrap_or(&format!("{}_{}", schema.schema_id, idx))
113                .to_string()
114        } else {
115            format!("{}_{}", schema.schema_id, idx)
116        };
117
118        let mut col_names = vec!["id".to_string()];
119        let mut col_vals = vec![format!("\"{}\"", escape(&id))];
120        for col in &schema.columns {
121            let val = obj.get(&col.name);
122            if col.required && val.is_none() {
123                bail!("Record {}: missing required field '{}'", idx, col.name);
124            }
125            col_names.push(col.name.clone());
126            col_vals.push(format_cozo_value(&col.col_type, val));
127        }
128
129        let put_script = format!(
130            "?[{}] <- [[{}]]\n:put {table_name} {{{}}}",
131            col_names.join(", "),
132            col_vals.join(", "),
133            col_names.join(", "),
134        );
135        db.run_script(
136            &put_script,
137            std::collections::BTreeMap::new(),
138            cozo::ScriptMutability::Mutable,
139        )
140        .map_err(|e| anyhow::anyhow!("failed to create node {}: {}", id, e))?;
141        nodes_created += 1;
142
143        for edge in &schema.edges {
144            let targets = match obj.get(&edge.source_field) {
145                Some(serde_json::Value::Array(arr)) => arr
146                    .iter()
147                    .filter_map(|v| v.as_str().map(String::from))
148                    .collect::<Vec<_>>(),
149                Some(serde_json::Value::String(s)) => vec![s.clone()],
150                _ => continue,
151            };
152
153            let edge_name = edge.name.to_lowercase();
154            for target in &targets {
155                let target_id = if edge.to_table == "Symbol" {
156                    resolve_symbol_cozo(db, target).unwrap_or_else(|| {
157                        eprintln!("[warn] unresolved symbol reference: '{}'", target);
158                        target.clone()
159                    })
160                } else if let Some(lookup) = &edge.target_lookup {
161                    format!("{}_{}", lookup, target)
162                } else {
163                    target.clone()
164                };
165
166                let to_table = edge.to_table.to_lowercase();
167                let check_script = format!(
168                    "?[count(id)] := *{to_table}{{id}}, id = \"{}\"",
169                    escape(&target_id)
170                );
171                let target_exists = db
172                    .run_script(
173                        &check_script,
174                        std::collections::BTreeMap::new(),
175                        cozo::ScriptMutability::Immutable,
176                    )
177                    .ok()
178                    .and_then(|r| {
179                        r.rows.first().and_then(|row| row.first()).map(|v| match v {
180                            cozo::DataValue::Num(cozo::Num::Int(i)) => *i > 0,
181                            _ => false,
182                        })
183                    })
184                    .unwrap_or(false);
185
186                if target_exists {
187                    let mut edge_col_names = vec!["from_id".to_string(), "to_id".to_string()];
188                    let mut edge_col_vals = vec![
189                        format!("\"{}\"", escape(&id)),
190                        format!("\"{}\"", escape(&target_id)),
191                    ];
192                    for prop in &edge.properties {
193                        edge_col_names.push(prop.name.clone());
194                        edge_col_vals.push(format_cozo_value(&prop.col_type, obj.get(&prop.name)));
195                    }
196
197                    let put_edge = format!(
198                        "?[{}] <- [[{}]]\n:put {edge_name} {{{}}}",
199                        edge_col_names.join(", "),
200                        edge_col_vals.join(", "),
201                        edge_col_names.join(", "),
202                    );
203                    if db
204                        .run_script(
205                            &put_edge,
206                            std::collections::BTreeMap::new(),
207                            cozo::ScriptMutability::Mutable,
208                        )
209                        .is_ok()
210                    {
211                        edges_created += 1;
212                    }
213                }
214            }
215        }
216    }
217
218    Ok(IngestResult {
219        nodes_created,
220        edges_created,
221    })
222}
223
224pub(crate) fn format_cozo_value(col_type: &str, val: Option<&serde_json::Value>) -> String {
225    match val {
226        None => match col_type {
227            "STRING" | "STRING[]" => "\"\"".to_string(),
228            "INT64" => "0".to_string(),
229            "BOOL" => "false".to_string(),
230            "DOUBLE" => "0.0".to_string(),
231            _ => "\"\"".to_string(),
232        },
233        Some(v) => match col_type {
234            "STRING" => format!("\"{}\"", escape(v.as_str().unwrap_or_default())),
235            "INT64" => v.as_i64().unwrap_or(0).to_string(),
236            "BOOL" => v.as_bool().unwrap_or(false).to_string(),
237            "DOUBLE" => v.as_f64().unwrap_or(0.0).to_string(),
238            "STRING[]" => {
239                if let Some(arr) = v.as_array() {
240                    let items: Vec<String> = arr
241                        .iter()
242                        .filter_map(|s| s.as_str().map(|s| format!("\"{}\"", escape(s))))
243                        .collect();
244                    format!("[{}]", items.join(", "))
245                } else {
246                    "\"\"".to_string()
247                }
248            }
249            _ => format!("\"{}\"", escape(&v.to_string())),
250        },
251    }
252}
253
254fn resolve_symbol_cozo(db: &cozo::DbInstance, reference: &str) -> Option<String> {
255    let esc = reference.replace('"', "\\\"");
256    let script =
257        format!("?[id] := *symbol{{id, name}}, id = \"{esc}\" or name = \"{esc}\"\n:limit 1");
258    db.run_script(
259        &script,
260        std::collections::BTreeMap::new(),
261        cozo::ScriptMutability::Immutable,
262    )
263    .ok()
264    .and_then(|r| {
265        r.rows.first().and_then(|row| {
266            row.first().map(|v| match v {
267                cozo::DataValue::Str(s) => s.to_string(),
268                _ => reference.to_string(),
269            })
270        })
271    })
272}
273
274pub fn ingest_file_cozo(
275    db: &cozo::DbInstance,
276    schema: &SchemaMeta,
277    data_path: &Path,
278) -> Result<IngestResult> {
279    let content = std::fs::read_to_string(data_path)
280        .with_context(|| format!("failed to read data file: {}", data_path.display()))?;
281
282    let ext = data_path.extension().and_then(|e| e.to_str()).unwrap_or("");
283    let data: Vec<serde_json::Value> = match ext {
284        "json" => {
285            let parsed: serde_json::Value = serde_json::from_str(&content)
286                .with_context(|| format!("invalid JSON: {}", data_path.display()))?;
287            match parsed {
288                serde_json::Value::Array(arr) => arr,
289                obj @ serde_json::Value::Object(_) => vec![obj],
290                _ => bail!("JSON must be an array or object"),
291            }
292        }
293        "yaml" | "yml" => {
294            let parsed: serde_json::Value = serde_yaml::from_str(&content)
295                .with_context(|| format!("invalid YAML: {}", data_path.display()))?;
296            match parsed {
297                serde_json::Value::Array(arr) => arr,
298                obj @ serde_json::Value::Object(_) => vec![obj],
299                _ => bail!("YAML must be a sequence or mapping"),
300            }
301        }
302        _ => bail!(
303            "Unsupported data file format '{}' — use .json or .yaml/.yml",
304            ext
305        ),
306    };
307
308    ingest_data_cozo(db, schema, &data)
309}
310
311pub fn ingest_directory_cozo(
312    db: &cozo::DbInstance,
313    schema: &SchemaMeta,
314    dir_path: &Path,
315) -> Result<IngestResult> {
316    if !dir_path.is_dir() {
317        bail!("'{}' is not a directory", dir_path.display());
318    }
319
320    let mut total = IngestResult {
321        nodes_created: 0,
322        edges_created: 0,
323    };
324
325    for entry in std::fs::read_dir(dir_path)
326        .with_context(|| format!("failed to read directory: {}", dir_path.display()))?
327    {
328        let entry = entry?;
329        let path = entry.path();
330        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
331        if !matches!(ext, "json" | "yaml" | "yml") {
332            continue;
333        }
334        let result = ingest_file_cozo(db, schema, &path)?;
335        total.nodes_created += result.nodes_created;
336        total.edges_created += result.edges_created;
337    }
338
339    Ok(total)
340}