Skip to main content

infigraph_core/structured/
cozo.rs

1use std::path::Path;
2
3use anyhow::{bail, Context, Result};
4
5use super::schema::{escape, interpolate_template, IngestResult, SchemaMeta};
6
/// Maps a schema column type name to the Cozo column type used in `:create`.
///
/// `INT64`, `BOOL` and `DOUBLE` map to `Int`, `Bool` and `Float`. Every other
/// name, including `STRING`, `STRING[]` and unrecognised types, maps to `String`.
///
/// NOTE(review): `format_cozo_value` renders a `STRING[]` JSON array as a list
/// literal (`["a", "b"]`), which does not look compatible with a `String` column.
/// Confirm whether `[String]` was intended for `STRING[]`.
fn cozo_col_type(col_type: &str) -> &str {
    // Only the non-string mappings are listed; anything not found falls back to String.
    static NON_STRING_TYPES: [(&str, &str); 3] =
        [("INT64", "Int"), ("BOOL", "Bool"), ("DOUBLE", "Float")];

    NON_STRING_TYPES
        .iter()
        .find(|(schema_ty, _)| *schema_ty == col_type)
        .map_or("String", |&(_, cozo_ty)| cozo_ty)
}
17
/// Returns the CozoScript literal used as the `default` for a column of the
/// given schema type in `:create` statements.
///
/// Numeric and boolean types get their zero value. `STRING`, `STRING[]` and any
/// unrecognised type get the empty string literal `""`.
fn cozo_col_default(col_type: &str) -> &str {
    match col_type {
        "INT64" => "0",
        "DOUBLE" => "0.0",
        "BOOL" => "false",
        // STRING, STRING[] and unknown types all share the empty-string literal.
        _ => r#""""#,
    }
}
27
28impl SchemaMeta {
29    pub fn generate_cozo_ddl(&self) -> Vec<String> {
30        let mut stmts = Vec::new();
31
32        let cols: Vec<String> = self
33            .columns
34            .iter()
35            .map(|c| {
36                format!(
37                    "{}: {} default {}",
38                    c.name,
39                    cozo_col_type(&c.col_type),
40                    cozo_col_default(&c.col_type)
41                )
42            })
43            .collect();
44        let table_name = self.node_table.to_lowercase();
45        if cols.is_empty() {
46            stmts.push(format!(":create {table_name} {{id: String}}"));
47        } else {
48            stmts.push(format!(
49                ":create {table_name} {{id: String => {}}}",
50                cols.join(", ")
51            ));
52        }
53
54        for edge in &self.edges {
55            let edge_name = edge.name.to_lowercase();
56            let prop_cols: Vec<String> = edge
57                .properties
58                .iter()
59                .map(|c| {
60                    format!(
61                        ", {}: {} default {}",
62                        c.name,
63                        cozo_col_type(&c.col_type),
64                        cozo_col_default(&c.col_type)
65                    )
66                })
67                .collect();
68            stmts.push(format!(
69                ":create {edge_name} {{from_id: String, to_id: String{}}}",
70                prop_cols.join("")
71            ));
72        }
73
74        stmts
75    }
76}
77
78pub fn ingest_data_cozo(
79    db: &cozo::DbInstance,
80    schema: &SchemaMeta,
81    data: &[serde_json::Value],
82) -> Result<IngestResult> {
83    for ddl in schema.generate_cozo_ddl() {
84        match db.run_script(
85            &ddl,
86            std::collections::BTreeMap::new(),
87            cozo::ScriptMutability::Mutable,
88        ) {
89            Ok(_) => {}
90            Err(e) => {
91                let msg = format!("{e}");
92                if !msg.contains("already exists") && !msg.contains("conflicts") {
93                    bail!("DDL failed: {}", e);
94                }
95            }
96        }
97    }
98
99    let table_name = schema.node_table.to_lowercase();
100    let mut nodes_created = 0usize;
101    let mut edges_created = 0usize;
102
103    for (idx, record) in data.iter().enumerate() {
104        let obj = record
105            .as_object()
106            .with_context(|| format!("record {} is not an object", idx))?;
107
108        let id = if let Some(tmpl) = &schema.id_template {
109            interpolate_template(tmpl, obj)
110        } else if let Some(v) = obj.get("id") {
111            v.as_str()
112                .unwrap_or(&format!("{}_{}", schema.schema_id, idx))
113                .to_string()
114        } else {
115            format!("{}_{}", schema.schema_id, idx)
116        };
117
118        let mut col_names = vec!["id".to_string()];
119        let mut col_vals = vec![format!("\"{}\"", escape(&id))];
120        for col in &schema.columns {
121            let val = obj.get(&col.name);
122            if col.required && val.is_none() {
123                bail!("Record {}: missing required field '{}'", idx, col.name);
124            }
125            col_names.push(col.name.clone());
126            col_vals.push(format_cozo_value(&col.col_type, val));
127        }
128
129        let put_script = format!(
130            "?[{}] <- [[{}]]\n:put {table_name} {{{}}}",
131            col_names.join(", "),
132            col_vals.join(", "),
133            col_names.join(", "),
134        );
135        db.run_script(
136            &put_script,
137            std::collections::BTreeMap::new(),
138            cozo::ScriptMutability::Mutable,
139        )
140        .map_err(|e| anyhow::anyhow!("failed to create node {}: {}", id, e))?;
141        nodes_created += 1;
142
143        for edge in &schema.edges {
144            let targets = match obj.get(&edge.source_field) {
145                Some(serde_json::Value::Array(arr)) => arr
146                    .iter()
147                    .filter_map(|v| v.as_str().map(String::from))
148                    .collect::<Vec<_>>(),
149                Some(serde_json::Value::String(s)) => vec![s.clone()],
150                _ => continue,
151            };
152
153            let edge_name = edge.name.to_lowercase();
154            for target in &targets {
155                let target_id = if edge.to_table == "Symbol" {
156                    resolve_symbol_cozo(db, target).unwrap_or_else(|| {
157                        eprintln!("[warn] unresolved symbol reference: '{}'", target);
158                        target.clone()
159                    })
160                } else if let Some(lookup) = &edge.target_lookup {
161                    format!("{}_{}", lookup, target)
162                } else {
163                    target.clone()
164                };
165
166                let to_table = edge.to_table.to_lowercase();
167                let check_script = format!(
168                    "?[count(id)] := *{to_table}{{id}}, id = \"{}\"",
169                    escape(&target_id)
170                );
171                let target_exists = db
172                    .run_script(
173                        &check_script,
174                        std::collections::BTreeMap::new(),
175                        cozo::ScriptMutability::Immutable,
176                    )
177                    .ok()
178                    .and_then(|r| {
179                        r.rows.first().and_then(|row| row.first()).map(|v| match v {
180                            cozo::DataValue::Num(cozo::Num::Int(i)) => *i > 0,
181                            _ => false,
182                        })
183                    })
184                    .unwrap_or(false);
185
186                if target_exists {
187                    let mut edge_col_names = vec!["from_id".to_string(), "to_id".to_string()];
188                    let mut edge_col_vals = vec![
189                        format!("\"{}\"", escape(&id)),
190                        format!("\"{}\"", escape(&target_id)),
191                    ];
192                    for prop in &edge.properties {
193                        edge_col_names.push(prop.name.clone());
194                        edge_col_vals.push(format_cozo_value(&prop.col_type, obj.get(&prop.name)));
195                    }
196
197                    let put_edge = format!(
198                        "?[{}] <- [[{}]]\n:put {edge_name} {{{}}}",
199                        edge_col_names.join(", "),
200                        edge_col_vals.join(", "),
201                        edge_col_names.join(", "),
202                    );
203                    match db.run_script(
204                        &put_edge,
205                        std::collections::BTreeMap::new(),
206                        cozo::ScriptMutability::Mutable,
207                    ) {
208                        Ok(_) => edges_created += 1,
209                        Err(_) => {}
210                    }
211                }
212            }
213        }
214    }
215
216    Ok(IngestResult {
217        nodes_created,
218        edges_created,
219    })
220}
221
222pub(crate) fn format_cozo_value(col_type: &str, val: Option<&serde_json::Value>) -> String {
223    match val {
224        None => match col_type {
225            "STRING" | "STRING[]" => "\"\"".to_string(),
226            "INT64" => "0".to_string(),
227            "BOOL" => "false".to_string(),
228            "DOUBLE" => "0.0".to_string(),
229            _ => "\"\"".to_string(),
230        },
231        Some(v) => match col_type {
232            "STRING" => format!(
233                "\"{}\"",
234                escape(&v.as_str().unwrap_or_default().to_string())
235            ),
236            "INT64" => v.as_i64().unwrap_or(0).to_string(),
237            "BOOL" => v.as_bool().unwrap_or(false).to_string(),
238            "DOUBLE" => v.as_f64().unwrap_or(0.0).to_string(),
239            "STRING[]" => {
240                if let Some(arr) = v.as_array() {
241                    let items: Vec<String> = arr
242                        .iter()
243                        .filter_map(|s| s.as_str().map(|s| format!("\"{}\"", escape(s))))
244                        .collect();
245                    format!("[{}]", items.join(", "))
246                } else {
247                    "\"\"".to_string()
248                }
249            }
250            _ => format!("\"{}\"", escape(&v.to_string())),
251        },
252    }
253}
254
255fn resolve_symbol_cozo(db: &cozo::DbInstance, reference: &str) -> Option<String> {
256    let esc = reference.replace('"', "\\\"");
257    let script =
258        format!("?[id] := *symbol{{id, name}}, id = \"{esc}\" or name = \"{esc}\"\n:limit 1");
259    db.run_script(
260        &script,
261        std::collections::BTreeMap::new(),
262        cozo::ScriptMutability::Immutable,
263    )
264    .ok()
265    .and_then(|r| {
266        r.rows.first().and_then(|row| {
267            row.first().map(|v| match v {
268                cozo::DataValue::Str(s) => s.to_string(),
269                _ => reference.to_string(),
270            })
271        })
272    })
273}
274
275pub fn ingest_file_cozo(
276    db: &cozo::DbInstance,
277    schema: &SchemaMeta,
278    data_path: &Path,
279) -> Result<IngestResult> {
280    let content = std::fs::read_to_string(data_path)
281        .with_context(|| format!("failed to read data file: {}", data_path.display()))?;
282
283    let ext = data_path.extension().and_then(|e| e.to_str()).unwrap_or("");
284    let data: Vec<serde_json::Value> = match ext {
285        "json" => {
286            let parsed: serde_json::Value = serde_json::from_str(&content)
287                .with_context(|| format!("invalid JSON: {}", data_path.display()))?;
288            match parsed {
289                serde_json::Value::Array(arr) => arr,
290                obj @ serde_json::Value::Object(_) => vec![obj],
291                _ => bail!("JSON must be an array or object"),
292            }
293        }
294        "yaml" | "yml" => {
295            let parsed: serde_json::Value = serde_yaml::from_str(&content)
296                .with_context(|| format!("invalid YAML: {}", data_path.display()))?;
297            match parsed {
298                serde_json::Value::Array(arr) => arr,
299                obj @ serde_json::Value::Object(_) => vec![obj],
300                _ => bail!("YAML must be a sequence or mapping"),
301            }
302        }
303        _ => bail!(
304            "Unsupported data file format '{}' — use .json or .yaml/.yml",
305            ext
306        ),
307    };
308
309    ingest_data_cozo(db, schema, &data)
310}
311
312pub fn ingest_directory_cozo(
313    db: &cozo::DbInstance,
314    schema: &SchemaMeta,
315    dir_path: &Path,
316) -> Result<IngestResult> {
317    if !dir_path.is_dir() {
318        bail!("'{}' is not a directory", dir_path.display());
319    }
320
321    let mut total = IngestResult {
322        nodes_created: 0,
323        edges_created: 0,
324    };
325
326    for entry in std::fs::read_dir(dir_path)
327        .with_context(|| format!("failed to read directory: {}", dir_path.display()))?
328    {
329        let entry = entry?;
330        let path = entry.path();
331        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
332        if !matches!(ext, "json" | "yaml" | "yml") {
333            continue;
334        }
335        let result = ingest_file_cozo(db, schema, &path)?;
336        total.nodes_created += result.nodes_created;
337        total.edges_created += result.edges_created;
338    }
339
340    Ok(total)
341}