Skip to main content

uni_algo/
projection_input.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Discriminated union over the three projection sources an algorithm
5//! invocation can reference, per proposal §4.10.1.
6//!
7//! Every algorithm call uses the 2-arg `(graphRef, config)` shape where
8//! `graphRef` is a `Map` that [`parse_graph_ref`] decodes into one of the
9//! variants below. `Native` materialises immediately from labels + edge
10//! types; `Cypher` runs inner queries through `QueryProcedureHost`;
11//! `Named` resolves through the per-`Database` `ProjectionStore`.
12
13use serde_json::Value;
14
/// Where the graph projection for an algorithm invocation comes from.
///
/// [`parse_graph_ref`] chooses the variant from the keys the caller
/// supplied. Key sets belonging to different variants are mutually
/// exclusive, so a call site cannot combine e.g. `nodeLabels` with
/// `nodeQuery`.
///
/// Every field is a plain string, string list or flag, so the type is
/// `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProjectionInput {
    /// Build a CSR directly from native labels and edge types.
    Native {
        /// Vertex labels to include.
        node_labels: Vec<String>,
        /// Edge types to traverse.
        edge_types: Vec<String>,
        /// Edge property to read as the scalar weight, if any.
        weight_property: Option<String>,
        /// When `true`, the reverse CSR is built alongside the forward one.
        include_reverse: bool,
    },
    /// Build a CSR from two inner Cypher queries. The node query must
    /// yield an `id` column; the edge query must yield `source` and
    /// `target` (plus, optionally, the column named by `weight_column`).
    Cypher {
        /// Cypher query producing the node rows.
        node_query: String,
        /// Cypher query producing the edge rows.
        edge_query: String,
        /// Name of the edge-query column carrying the scalar weight, if
        /// any.
        weight_column: Option<String>,
        /// When `true`, the reverse CSR is built alongside the forward one.
        include_reverse: bool,
    },
    /// Refer to a previously materialised projection held in the per-
    /// `Database` `ProjectionStore`. Resolution is done by the host
    /// crate; `uni-algo` itself only carries the name.
    Named {
        /// Name the projection was registered under.
        name: String,
    },
}
55
/// Error returned by [`parse_graph_ref`] when the input map cannot be
/// decoded as exactly one [`ProjectionInput`] variant.
///
/// The only payload is a `String`, so the type is `Eq` as well as
/// `PartialEq`, which lets callers and tests compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphRefParseError {
    /// Human-readable message, passed through to the caller's
    /// `FnError::new(0x820, ...)` site.
    pub message: String,
}
64
65impl std::fmt::Display for GraphRefParseError {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.write_str(&self.message)
68    }
69}
70
71impl std::error::Error for GraphRefParseError {}
72
73fn err(msg: impl Into<String>) -> GraphRefParseError {
74    GraphRefParseError {
75        message: msg.into(),
76    }
77}
78
79/// Decode a `graphRef` map into one of the [`ProjectionInput`] variants.
80///
81/// Selection rules:
82/// - Presence of `nodeLabels` or `edgeTypes` → `Native`. `nodeQuery` /
83///   `edgeQuery` / `name` must be absent.
84/// - Presence of `nodeQuery` or `edgeQuery` → `Cypher`. Both must be
85///   supplied. `nodeLabels` / `edgeTypes` / `name` must be absent.
86/// - Presence of `name` (with no labels or queries) → `Named`.
87/// - Anything else is a parse error.
88///
89/// Optional keys per variant:
90/// - `Native`: `weightProperty: String`, `includeReverse: Bool`.
91/// - `Cypher`: `weightColumn: String`, `includeReverse: Bool`.
92///
93/// # Errors
94///
95/// Returns [`GraphRefParseError`] when the input is not a map, when key
96/// sets conflict (`nodeLabels` + `nodeQuery` together), when a required
97/// `Cypher` query is missing, or when the value attached to a key has the
98/// wrong shape (e.g. `nodeLabels` not an array of strings).
99pub fn parse_graph_ref(v: &Value) -> Result<ProjectionInput, GraphRefParseError> {
100    let map = v.as_object().ok_or_else(|| err("graphRef must be a Map"))?;
101
102    let has_native = map.contains_key("nodeLabels") || map.contains_key("edgeTypes");
103    let has_cypher = map.contains_key("nodeQuery") || map.contains_key("edgeQuery");
104    let has_named = map.contains_key("name");
105
106    let variants = [has_native, has_cypher, has_named];
107    let selected = variants.iter().filter(|b| **b).count();
108    if selected == 0 {
109        return Err(err(
110            "graphRef must contain one of: nodeLabels/edgeTypes (Native), \
111             nodeQuery/edgeQuery (Cypher), or name (Named)",
112        ));
113    }
114    if selected > 1 {
115        return Err(err(
116            "graphRef keys conflict: pick exactly one of Native (nodeLabels/edgeTypes), \
117             Cypher (nodeQuery/edgeQuery), or Named (name)",
118        ));
119    }
120
121    if has_native {
122        let node_labels = map
123            .get("nodeLabels")
124            .map(parse_string_array)
125            .transpose()?
126            .unwrap_or_default();
127        let edge_types = map
128            .get("edgeTypes")
129            .map(parse_string_array)
130            .transpose()?
131            .unwrap_or_default();
132        let weight_property = map
133            .get("weightProperty")
134            .map(parse_optional_string)
135            .transpose()?
136            .flatten();
137        let include_reverse = map
138            .get("includeReverse")
139            .map(parse_bool)
140            .transpose()?
141            .unwrap_or(true);
142        Ok(ProjectionInput::Native {
143            node_labels,
144            edge_types,
145            weight_property,
146            include_reverse,
147        })
148    } else if has_cypher {
149        let node_query = map
150            .get("nodeQuery")
151            .ok_or_else(|| err("Cypher graphRef requires nodeQuery"))?
152            .as_str()
153            .ok_or_else(|| err("nodeQuery must be a String"))?
154            .to_owned();
155        let edge_query = map
156            .get("edgeQuery")
157            .ok_or_else(|| err("Cypher graphRef requires edgeQuery"))?
158            .as_str()
159            .ok_or_else(|| err("edgeQuery must be a String"))?
160            .to_owned();
161        let weight_column = map
162            .get("weightColumn")
163            .map(parse_optional_string)
164            .transpose()?
165            .flatten();
166        let include_reverse = map
167            .get("includeReverse")
168            .map(parse_bool)
169            .transpose()?
170            .unwrap_or(true);
171        Ok(ProjectionInput::Cypher {
172            node_query,
173            edge_query,
174            weight_column,
175            include_reverse,
176        })
177    } else {
178        let name = map
179            .get("name")
180            .and_then(Value::as_str)
181            .ok_or_else(|| err("Named graphRef requires a String `name`"))?
182            .to_owned();
183        Ok(ProjectionInput::Named { name })
184    }
185}
186
187fn parse_string_array(v: &Value) -> Result<Vec<String>, GraphRefParseError> {
188    let arr = v.as_array().ok_or_else(|| err("expected a String array"))?;
189    arr.iter()
190        .map(|x| {
191            x.as_str()
192                .map(str::to_owned)
193                .ok_or_else(|| err("array element must be a String"))
194        })
195        .collect()
196}
197
198fn parse_optional_string(v: &Value) -> Result<Option<String>, GraphRefParseError> {
199    if v.is_null() {
200        Ok(None)
201    } else {
202        v.as_str()
203            .map(|s| Some(s.to_owned()))
204            .ok_or_else(|| err("expected a String"))
205    }
206}
207
208fn parse_bool(v: &Value) -> Result<bool, GraphRefParseError> {
209    v.as_bool().ok_or_else(|| err("expected a Bool"))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parse `v`, expecting failure, and return the error message.
    fn parse_err(v: &Value) -> String {
        parse_graph_ref(v).unwrap_err().message
    }

    #[test]
    fn native_minimal() {
        // `include_reverse` defaults to `true` so PageRank /
        // Louvain / WCC don't silently lose in-neighbors when the
        // caller omits the field.
        let v = json!({ "nodeLabels": ["Person"], "edgeTypes": ["KNOWS"] });
        let got = parse_graph_ref(&v).unwrap();
        assert_eq!(
            got,
            ProjectionInput::Native {
                node_labels: vec!["Person".to_owned()],
                edge_types: vec!["KNOWS".to_owned()],
                weight_property: None,
                include_reverse: true,
            }
        );
    }

    #[test]
    fn native_full() {
        // Uses the non-default `includeReverse: false` so the test fails
        // if the key is ignored and the default is returned instead.
        let v = json!({
            "nodeLabels": ["Person"],
            "edgeTypes": ["KNOWS"],
            "weightProperty": "weight",
            "includeReverse": false,
        });
        assert_eq!(
            parse_graph_ref(&v).unwrap(),
            ProjectionInput::Native {
                node_labels: vec!["Person".to_owned()],
                edge_types: vec!["KNOWS".to_owned()],
                weight_property: Some("weight".to_owned()),
                include_reverse: false,
            }
        );
    }

    #[test]
    fn native_single_key_defaults_the_other_list() {
        let v = json!({ "edgeTypes": ["KNOWS"] });
        assert_eq!(
            parse_graph_ref(&v).unwrap(),
            ProjectionInput::Native {
                node_labels: Vec::new(),
                edge_types: vec!["KNOWS".to_owned()],
                weight_property: None,
                include_reverse: true,
            }
        );
    }

    #[test]
    fn native_null_weight_property_is_none() {
        let v = json!({ "nodeLabels": ["Person"], "weightProperty": null });
        match parse_graph_ref(&v).unwrap() {
            ProjectionInput::Native {
                weight_property, ..
            } => assert_eq!(weight_property, None),
            other => panic!("expected Native, got {other:?}"),
        }
    }

    #[test]
    fn cypher_minimal() {
        let v = json!({
            "nodeQuery": "MATCH (p:Person) RETURN id(p) AS id",
            "edgeQuery": "MATCH (a)-[:KNOWS]->(b) RETURN id(a) AS source, id(b) AS target",
        });
        let got = parse_graph_ref(&v).unwrap();
        match got {
            ProjectionInput::Cypher {
                node_query,
                edge_query,
                weight_column,
                include_reverse,
            } => {
                assert!(node_query.starts_with("MATCH (p:Person)"));
                assert!(edge_query.starts_with("MATCH (a)"));
                assert_eq!(weight_column, None);
                // `include_reverse` defaults to `true` (see
                // `native_minimal` rationale).
                assert!(include_reverse);
            }
            _ => panic!("expected Cypher"),
        }
    }

    #[test]
    fn cypher_full() {
        let v = json!({
            "nodeQuery": "RETURN 1 AS id",
            "edgeQuery": "RETURN 1 AS source, 1 AS target, 2.0 AS w",
            "weightColumn": "w",
            "includeReverse": false,
        });
        match parse_graph_ref(&v).unwrap() {
            ProjectionInput::Cypher {
                weight_column,
                include_reverse,
                ..
            } => {
                assert_eq!(weight_column.as_deref(), Some("w"));
                assert!(!include_reverse);
            }
            other => panic!("expected Cypher, got {other:?}"),
        }
    }

    #[test]
    fn named() {
        let v = json!({ "name": "myGraph" });
        assert_eq!(
            parse_graph_ref(&v).unwrap(),
            ProjectionInput::Named {
                name: "myGraph".to_owned()
            }
        );
    }

    #[test]
    fn named_requires_string_name() {
        let msg = parse_err(&json!({ "name": 42 }));
        assert!(msg.contains("String `name`"), "{msg}");
    }

    #[test]
    fn conflicting_keys_rejected() {
        let msg = parse_err(&json!({ "nodeLabels": ["Person"], "name": "g" }));
        assert!(msg.contains("conflict"), "{msg}");

        let msg = parse_err(&json!({ "nodeLabels": ["Person"], "nodeQuery": "RETURN 1 AS id" }));
        assert!(msg.contains("conflict"), "{msg}");
    }

    #[test]
    fn missing_cypher_partner_rejected() {
        let msg = parse_err(&json!({ "nodeQuery": "RETURN 1 AS id" }));
        assert!(msg.contains("edgeQuery"), "{msg}");

        let msg = parse_err(&json!({ "edgeQuery": "RETURN 1 AS source, 1 AS target" }));
        assert!(msg.contains("nodeQuery"), "{msg}");
    }

    #[test]
    fn wrong_shapes_rejected() {
        let msg = parse_err(&json!({ "nodeLabels": "Person" }));
        assert!(msg.contains("expected a String array"), "{msg}");

        let msg = parse_err(&json!({ "nodeLabels": ["Person", 1] }));
        assert!(msg.contains("array element must be a String"), "{msg}");

        let msg = parse_err(&json!({ "nodeLabels": ["Person"], "weightProperty": 3 }));
        assert!(msg.contains("expected a String"), "{msg}");

        let msg = parse_err(&json!({ "nodeLabels": ["Person"], "includeReverse": "yes" }));
        assert!(msg.contains("expected a Bool"), "{msg}");

        let msg = parse_err(&json!({ "nodeQuery": 1, "edgeQuery": "RETURN 1" }));
        assert!(msg.contains("nodeQuery must be a String"), "{msg}");
    }

    #[test]
    fn empty_map_rejected() {
        let msg = parse_err(&json!({}));
        assert!(msg.contains("must contain"), "{msg}");
    }

    #[test]
    fn non_map_rejected() {
        let msg = parse_err(&json!("not a map"));
        assert!(msg.contains("must be a Map"), "{msg}");
    }
}