Skip to main content

kyma_graph/
schema_graph.rs

1//! The synthetic schema-graph: the catalog rendered as a property-graph.
2
3use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7
8use crate::provider::GraphProvider;
9use crate::source::{ColumnDef, SchemaSource};
10use crate::types::{
11    Direction, EdgeExpansion, GraphNode, GraphPayload, GraphRelationship, GraphSchema, GraphStats,
12    NodeMetadata, Props, SearchHits,
13};
14
/// Fixed timestamp written to both `created_at` and `updated_at` of every
/// schema-graph node. The schema graph is timeless, so a constant (the Unix
/// epoch) is used instead of "now" to keep the serialized JSON deterministic.
const SCHEMA_TS: &str = "1970-01-01T00:00:00Z";
17
/// Returns `true` when `name` is one half of a `<g>_nodes` / `<g>_edges`
/// table pair — the storage behind a registered property-graph.
///
/// A table is only treated as storage when its sibling (same `<g>` prefix,
/// opposite suffix) is also present in `all`; a lone `foo_nodes` stays
/// visible. These tables are hidden from the catalog *schema* graph because
/// surfacing them as standalone `Table` nodes is confusing.
fn is_graph_storage_table(name: &str, all: &std::collections::HashSet<String>) -> bool {
    // (suffix this table may carry, suffix its sibling must carry)
    const PAIRS: [(&str, &str); 2] = [("_nodes", "_edges"), ("_edges", "_nodes")];
    PAIRS
        .iter()
        .find_map(|&(suffix, sibling)| {
            name.strip_suffix(suffix)
                .map(|stem| all.contains(&format!("{stem}{sibling}")))
        })
        .unwrap_or(false)
}
31
/// Stable, human-readable node id for a table: `<database>::<table>`.
fn table_node_id(database: &str, table: &str) -> String {
    [database, table].join("::")
}
36
37/// Infer `REFERENCES` edges among tables of one database from `<base>_id`
38/// column names. Pure + deterministic so it is trivially testable.
39pub(crate) fn infer_edges(
40    database: &str,
41    tables: &[(String, Vec<ColumnDef>)],
42) -> Vec<GraphRelationship> {
43    let names: Vec<String> = tables.iter().map(|(n, _)| n.to_lowercase()).collect();
44    let mut edges = Vec::new();
45    for (tname, cols) in tables {
46        for c in cols {
47            let lname = c.name.to_lowercase();
48            let Some(base) = lname.strip_suffix("_id") else { continue };
49            if base.is_empty() {
50                continue;
51            }
52            // Match a table whose name is `base` or `base` + "s".
53            let target = names.iter().find(|n| *n == base || **n == format!("{base}s"));
54            if let Some(target_lc) = target {
55                let target_name = tables
56                    .iter()
57                    .find(|(n, _)| n.to_lowercase() == *target_lc)
58                    .map(|(n, _)| n.clone())
59                    .unwrap();
60                if target_name == *tname {
61                    continue; // no self-edge
62                }
63                let mut props: Props = BTreeMap::new();
64                props.insert("via".into(), serde_json::json!(c.name));
65                edges.push(GraphRelationship {
66                    id: format!(
67                        "{}->{}:{}",
68                        table_node_id(database, tname),
69                        table_node_id(database, &target_name),
70                        c.name
71                    ),
72                    source_id: table_node_id(database, tname),
73                    target_id: table_node_id(database, &target_name),
74                    relationship_type: "REFERENCES".into(),
75                    properties: props,
76                });
77            }
78        }
79    }
80    edges
81}
82
83/// Synthetic graph computed live from a [`SchemaSource`]. Cheap to construct;
84/// each call snapshots the catalog (the server already caches schema reads).
85pub struct SchemaGraphProvider {
86    source: Arc<dyn SchemaSource>,
87}
88
89impl SchemaGraphProvider {
90    pub fn new(source: Arc<dyn SchemaSource>) -> Self {
91        Self { source }
92    }
93
94    /// Build the full node + edge set, optionally restricted to one realm
95    /// (= database). Timestamps are a fixed stable value (the schema graph is
96    /// timeless), keeping JSON deterministic.
97    async fn build(
98        &self,
99        realm: Option<&str>,
100    ) -> anyhow::Result<(Vec<GraphNode>, Vec<GraphRelationship>)> {
101        let mut nodes = Vec::new();
102        let mut edges = Vec::new();
103        for db in self.source.databases().await? {
104            if let Some(r) = realm {
105                if r != db {
106                    continue;
107                }
108            }
109            let all_names: std::collections::HashSet<String> =
110                self.source.tables(&db).await?.into_iter().collect();
111            let mut tables: Vec<(String, Vec<ColumnDef>)> = Vec::new();
112            for t in &all_names {
113                // Hide graph-storage plumbing (`<g>_nodes`/`<g>_edges` pairs).
114                if is_graph_storage_table(t, &all_names) {
115                    continue;
116                }
117                let cols = self.source.columns(&db, t).await?;
118                tables.push((t.clone(), cols));
119            }
120            tables.sort_by(|a, b| a.0.cmp(&b.0));
121            for (tname, cols) in &tables {
122                let mut props: Props = BTreeMap::new();
123                props.insert("database".into(), serde_json::json!(db));
124                props.insert("column_count".into(), serde_json::json!(cols.len()));
125                props.insert(
126                    "columns".into(),
127                    serde_json::json!(cols
128                        .iter()
129                        .map(|c| serde_json::json!({"name": c.name, "type": c.type_, "nullable": c.nullable}))
130                        .collect::<Vec<_>>()),
131                );
132                nodes.push(GraphNode {
133                    id: table_node_id(&db, tname),
134                    labels: vec!["Table".into()],
135                    properties: props,
136                    metadata: NodeMetadata {
137                        created_at: SCHEMA_TS.into(),
138                        updated_at: SCHEMA_TS.into(),
139                        source_type: Some("schema".into()),
140                        source_id: None,
141                        realm: db.clone(),
142                    },
143                });
144            }
145            edges.extend(infer_edges(&db, &tables));
146        }
147        Ok((nodes, edges))
148    }
149}
150
151fn compute_stats(nodes: &[GraphNode], edges: &[GraphRelationship]) -> GraphStats {
152    let mut label_counts: BTreeMap<String, usize> = BTreeMap::new();
153    for n in nodes {
154        for l in &n.labels {
155            *label_counts.entry(l.clone()).or_default() += 1;
156        }
157    }
158    let mut relationship_type_counts: BTreeMap<String, usize> = BTreeMap::new();
159    for e in edges {
160        *relationship_type_counts.entry(e.relationship_type.clone()).or_default() += 1;
161    }
162    GraphStats {
163        total_nodes: nodes.len(),
164        total_relationships: edges.len(),
165        label_counts,
166        relationship_type_counts,
167    }
168}
169
170#[async_trait]
171impl GraphProvider for SchemaGraphProvider {
172    async fn overview(&self, realm: Option<&str>, limit: usize) -> anyhow::Result<GraphPayload> {
173        let (mut nodes, edges) = self.build(realm).await?;
174        // stats reflect the FULL catalog; nodes/edges below are the capped view, so stats.total_* may exceed the returned slice lengths.
175        let stats = compute_stats(&nodes, &edges);
176        if nodes.len() > limit {
177            nodes.truncate(limit);
178        }
179        let kept: std::collections::HashSet<&String> = nodes.iter().map(|n| &n.id).collect();
180        let edges = edges
181            .into_iter()
182            .filter(|e| kept.contains(&e.source_id) && kept.contains(&e.target_id))
183            .collect();
184        Ok(GraphPayload { stats, nodes, edges })
185    }
186
187    async fn node(&self, id: &str) -> anyhow::Result<Option<GraphNode>> {
188        let (nodes, _) = self.build(None).await?;
189        Ok(nodes.into_iter().find(|n| n.id == id))
190    }
191
192    async fn neighbors(
193        &self,
194        ids: &[String],
195        dir: Direction,
196        _only_internal: bool,
197        limit: usize,
198    ) -> anyhow::Result<EdgeExpansion> {
199        // `only_internal` is N/A for the schema graph: every node is internal by definition.
200        let (_, all_edges) = self.build(None).await?;
201        let idset: std::collections::HashSet<&String> = ids.iter().collect();
202        let mut edges = Vec::new();
203        let mut new_ids = Vec::new();
204        for e in all_edges {
205            let touches = match dir {
206                Direction::Forward => idset.contains(&e.source_id),
207                Direction::Backward => idset.contains(&e.target_id),
208                Direction::Both => idset.contains(&e.source_id) || idset.contains(&e.target_id),
209            };
210            if !touches {
211                continue;
212            }
213            for end in [&e.source_id, &e.target_id] {
214                if !idset.contains(end) && !new_ids.contains(end) {
215                    new_ids.push(end.clone());
216                }
217            }
218            edges.push(e);
219            if edges.len() >= limit {
220                break;
221            }
222        }
223        Ok(EdgeExpansion { edges, new_node_ids: new_ids })
224    }
225
226    async fn subgraph(&self, id: &str, depth: usize) -> anyhow::Result<GraphPayload> {
227        let (all_nodes, all_edges) = self.build(None).await?;
228        let mut frontier = vec![id.to_string()];
229        let mut visited: std::collections::HashSet<String> = frontier.iter().cloned().collect();
230        let mut kept_edges: Vec<GraphRelationship> = Vec::new();
231        for _ in 0..depth {
232            let mut next = Vec::new();
233            for e in &all_edges {
234                let (a, b) = (&e.source_id, &e.target_id);
235                let hit = frontier.contains(a) || frontier.contains(b);
236                if hit && !kept_edges.iter().any(|k| k.id == e.id) {
237                    kept_edges.push(e.clone());
238                    for end in [a, b] {
239                        if visited.insert(end.clone()) {
240                            next.push(end.clone());
241                        }
242                    }
243                }
244            }
245            if next.is_empty() {
246                break;
247            }
248            frontier = next;
249        }
250        let nodes: Vec<GraphNode> =
251            all_nodes.into_iter().filter(|n| visited.contains(&n.id)).collect();
252        let stats = compute_stats(&nodes, &kept_edges);
253        Ok(GraphPayload { stats, nodes, edges: kept_edges })
254    }
255
256    async fn search(
257        &self,
258        text: &str,
259        labels: &[String],
260        realm: Option<&str>,
261        limit: usize,
262        offset: usize,
263    ) -> anyhow::Result<SearchHits> {
264        let (nodes, _) = self.build(realm).await?;
265        let needle = text.to_lowercase();
266        let mut matched: Vec<GraphNode> = nodes
267            .into_iter()
268            .filter(|n| {
269                let table_name = n.id.rsplit("::").next().unwrap_or(n.id.as_str());
270                let name_ok = table_name.to_lowercase().contains(&needle);
271                let label_ok = labels.is_empty() || labels.iter().any(|l| n.labels.contains(l));
272                name_ok && label_ok
273            })
274            .collect();
275        let total = matched.len();
276        let hits = matched.drain(..).skip(offset).take(limit).collect();
277        Ok(SearchHits { hits, total, limit, offset })
278    }
279
280    async fn stats(&self, realm: Option<&str>) -> anyhow::Result<GraphStats> {
281        let (nodes, edges) = self.build(realm).await?;
282        Ok(compute_stats(&nodes, &edges))
283    }
284
285    async fn schema(&self) -> anyhow::Result<GraphSchema> {
286        let (nodes, edges) = self.build(None).await?;
287        let mut edge_types: Vec<String> =
288            edges.iter().map(|e| e.relationship_type.clone()).collect();
289        edge_types.sort();
290        edge_types.dedup();
291        let mut property_keys: BTreeMap<String, Vec<String>> = BTreeMap::new();
292        if !nodes.is_empty() {
293            property_keys.insert(
294                "Table".into(),
295                vec!["database".into(), "column_count".into(), "columns".into()],
296            );
297        }
298        Ok(GraphSchema {
299            node_kinds: if nodes.is_empty() { vec![] } else { vec!["Table".into()] },
300            edge_types,
301            property_keys,
302        })
303    }
304}
305
#[cfg(test)]
mod edge_tests {
    use super::*;
    use crate::source::ColumnDef;

    /// A nullable `string` column called `name`.
    fn col(name: &str) -> ColumnDef {
        ColumnDef { name: name.into(), type_: "string".into(), nullable: true }
    }

    #[test]
    fn hides_graph_storage_table_pairs() {
        let all: std::collections::HashSet<String> = [
            "github_nodes", "github_edges", "kg_nodes", "kg_edges", "api_calls", "users", "lonely_nodes",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        // `<g>_nodes`/`<g>_edges` pairs that back a graph are hidden.
        assert!(is_graph_storage_table("github_nodes", &all));
        assert!(is_graph_storage_table("github_edges", &all));
        assert!(is_graph_storage_table("kg_nodes", &all));
        assert!(is_graph_storage_table("kg_edges", &all));
        // Real schema tables stay visible.
        assert!(!is_graph_storage_table("api_calls", &all));
        assert!(!is_graph_storage_table("users", &all));
        // A lone `_nodes` table with no `_edges` sibling is NOT hidden.
        assert!(!is_graph_storage_table("lonely_nodes", &all));
    }

    #[test]
    fn infers_fk_edge_from_user_id_to_users() {
        let tables = vec![
            ("users".to_string(), vec![col("id"), col("email")]),
            ("orders".to_string(), vec![col("id"), col("user_id"), col("total")]),
        ];
        let edges = infer_edges("default", &tables);
        assert_eq!(edges.len(), 1);
        let e = &edges[0];
        assert_eq!(e.source_id, "default::orders");
        assert_eq!(e.target_id, "default::users");
        assert_eq!(e.relationship_type, "REFERENCES");
        assert_eq!(e.properties["via"], "user_id");
    }

    #[test]
    fn no_edge_when_no_matching_table() {
        let tables = vec![
            ("orders".to_string(), vec![col("id"), col("customer_id")]),
        ];
        assert!(infer_edges("default", &tables).is_empty());
    }

    #[test]
    fn plain_id_column_is_not_an_edge() {
        let tables = vec![("users".to_string(), vec![col("id")])];
        assert!(infer_edges("default", &tables).is_empty());
    }

    #[test]
    fn self_reference_is_not_an_edge() {
        // `users.user_id` resolves to `users` itself; self-edges are skipped.
        let tables = vec![("users".to_string(), vec![col("id"), col("user_id")])];
        assert!(infer_edges("default", &tables).is_empty());
    }

    #[test]
    fn matching_is_case_insensitive_and_keeps_original_names() {
        let tables = vec![
            ("Users".to_string(), vec![col("id")]),
            ("orders".to_string(), vec![col("id"), col("User_ID")]),
        ];
        let edges = infer_edges("default", &tables);
        assert_eq!(edges.len(), 1);
        let e = &edges[0];
        assert_eq!(e.source_id, "default::orders");
        assert_eq!(e.target_id, "default::Users");
        assert_eq!(e.properties["via"], "User_ID");
        assert_eq!(e.id, "default::orders->default::Users:User_ID");
    }

    #[test]
    fn singular_table_name_also_matches() {
        let tables = vec![
            ("customer".to_string(), vec![col("id")]),
            ("orders".to_string(), vec![col("id"), col("customer_id")]),
        ];
        let edges = infer_edges("default", &tables);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].target_id, "default::customer");
    }

    #[test]
    fn bare_underscore_id_column_is_not_an_edge() {
        // `_id` has an empty base; without the guard, the "+s" rule would
        // wrongly resolve it to a table literally named `s`.
        let tables = vec![
            ("s".to_string(), vec![col("id")]),
            ("t".to_string(), vec![col("_id")]),
        ];
        assert!(infer_edges("default", &tables).is_empty());
    }
}
364
#[cfg(test)]
mod provider_tests {
    use super::*;
    use crate::source::{ColumnDef, SchemaSource};

    /// Single `default` database with `users` and `orders` (orders.user_id → users).
    struct FakeSource;

    fn col(name: &str, t: &str) -> ColumnDef {
        ColumnDef { name: name.into(), type_: t.into(), nullable: true }
    }

    #[async_trait]
    impl SchemaSource for FakeSource {
        async fn databases(&self) -> anyhow::Result<Vec<String>> {
            Ok(vec!["default".into()])
        }
        async fn tables(&self, _db: &str) -> anyhow::Result<Vec<String>> {
            Ok(vec!["users".into(), "orders".into()])
        }
        async fn columns(&self, _db: &str, table: &str) -> anyhow::Result<Vec<ColumnDef>> {
            Ok(match table {
                "users" => vec![col("id", "string"), col("email", "string")],
                "orders" => vec![col("id", "string"), col("user_id", "string")],
                _ => vec![],
            })
        }
    }

    fn provider() -> SchemaGraphProvider {
        SchemaGraphProvider::new(std::sync::Arc::new(FakeSource))
    }

    #[tokio::test]
    async fn overview_has_two_table_nodes_and_one_edge() {
        let p = provider();
        let payload = p.overview(None, 100).await.unwrap();
        assert_eq!(payload.nodes.len(), 2);
        assert!(payload.nodes.iter().all(|n| n.labels == vec!["Table".to_string()]));
        assert!(payload.nodes.iter().any(|n| n.id == "default::users"));
        assert_eq!(payload.edges.len(), 1);
        assert_eq!(payload.stats.total_nodes, 2);
        assert_eq!(payload.stats.total_relationships, 1);
        assert_eq!(payload.stats.label_counts["Table"], 2);
    }

    #[tokio::test]
    async fn node_lookup_returns_table_props() {
        let p = provider();
        let n = p.node("default::orders").await.unwrap().unwrap();
        assert_eq!(n.metadata.realm, "default");
        assert_eq!(n.properties["database"], "default");
        assert_eq!(n.properties["column_count"], 2);
        assert!(p.node("default::nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn search_filters_by_name_substring() {
        let p = provider();
        let hits = p.search("ord", &[], None, 10, 0).await.unwrap();
        assert_eq!(hits.total, 1);
        assert_eq!(hits.hits[0].id, "default::orders");
    }

    #[tokio::test]
    async fn neighbors_of_orders_returns_the_reference_edge() {
        let p = provider();
        let exp = p
            .neighbors(&["default::orders".into()], Direction::Both, true, 100)
            .await
            .unwrap();
        assert_eq!(exp.edges.len(), 1);
        assert_eq!(exp.new_node_ids, vec!["default::users".to_string()]);
    }

    #[tokio::test]
    async fn neighbors_respect_direction() {
        let p = provider();
        // users has no outgoing REFERENCES edge...
        let fwd = p
            .neighbors(&["default::users".into()], Direction::Forward, false, 100)
            .await
            .unwrap();
        assert!(fwd.edges.is_empty());
        assert!(fwd.new_node_ids.is_empty());
        // ...but is the target of orders.user_id.
        let back = p
            .neighbors(&["default::users".into()], Direction::Backward, false, 100)
            .await
            .unwrap();
        assert_eq!(back.edges.len(), 1);
        assert_eq!(back.new_node_ids, vec!["default::orders".to_string()]);
    }

    #[tokio::test]
    async fn search_does_not_match_database_prefix() {
        let p = provider();
        // "default" is the database name (part of every id) but not a table name.
        let hits = p.search("default", &[], None, 10, 0).await.unwrap();
        assert_eq!(hits.total, 0, "search must match table names, not the db prefix");
        // sanity: a real table-name substring still matches
        let hits2 = p.search("ord", &[], None, 10, 0).await.unwrap();
        assert_eq!(hits2.total, 1);
    }

    #[tokio::test]
    async fn search_paginates_but_reports_full_total() {
        let p = provider();
        // "S" matches both "orders" and "users" case-insensitively; nodes are
        // sorted by table name, so the second page entry is `users`.
        let page = p.search("S", &[], None, 10, 1).await.unwrap();
        assert_eq!(page.total, 2);
        assert_eq!(page.hits.len(), 1);
        assert_eq!(page.hits[0].id, "default::users");
        // A label no node carries filters everything out.
        let none = p.search("s", &["Column".to_string()], None, 10, 0).await.unwrap();
        assert_eq!(none.total, 0);
    }

    #[tokio::test]
    async fn unknown_realm_yields_empty_stats() {
        let p = provider();
        let s = p.stats(Some("nope")).await.unwrap();
        assert_eq!(s.total_nodes, 0);
        assert_eq!(s.total_relationships, 0);
    }

    #[tokio::test]
    async fn schema_reports_table_kind_and_references_edge() {
        let p = provider();
        let s = p.schema().await.unwrap();
        assert_eq!(s.node_kinds, vec!["Table".to_string()]);
        assert_eq!(s.edge_types, vec!["REFERENCES".to_string()]);
    }

    #[tokio::test]
    async fn overview_caps_nodes_but_stats_reflect_full_graph() {
        let p = provider();
        let payload = p.overview(None, 1).await.unwrap();
        assert_eq!(payload.nodes.len(), 1, "nodes capped to limit");
        assert_eq!(payload.stats.total_nodes, 2, "stats reflect full graph");
        // the single FK edge needs both endpoints; with only 1 node kept it's filtered out
        assert_eq!(payload.edges.len(), 0);
        assert_eq!(payload.stats.total_relationships, 1);
    }

    /// Chain `as_ -> bs -> cs` (via `b_id` and `c_id`).
    struct ChainSource;

    #[async_trait]
    impl SchemaSource for ChainSource {
        async fn databases(&self) -> anyhow::Result<Vec<String>> {
            Ok(vec!["default".into()])
        }
        async fn tables(&self, _db: &str) -> anyhow::Result<Vec<String>> {
            Ok(vec!["as_".into(), "bs".into(), "cs".into()])
        }
        async fn columns(&self, _db: &str, table: &str) -> anyhow::Result<Vec<ColumnDef>> {
            // as_ -> bs (via b_id), bs -> cs (via c_id)
            Ok(match table {
                "as_" => vec![col("id", "string"), col("b_id", "string")],
                "bs" => vec![col("id", "string"), col("c_id", "string")],
                "cs" => vec![col("id", "string")],
                _ => vec![],
            })
        }
    }

    #[tokio::test]
    async fn subgraph_two_hops_collects_chain() {
        let p = SchemaGraphProvider::new(std::sync::Arc::new(ChainSource));
        // depth 2 from as_ should reach bs (hop 1) and cs (hop 2)
        let sg = p.subgraph("default::as_", 2).await.unwrap();
        let ids: std::collections::HashSet<String> = sg.nodes.iter().map(|n| n.id.clone()).collect();
        assert!(ids.contains("default::as_"));
        assert!(ids.contains("default::bs"));
        assert!(ids.contains("default::cs"));
        assert_eq!(sg.edges.len(), 2);
    }

    #[tokio::test]
    async fn subgraph_depth_one_stops_after_first_hop() {
        let p = SchemaGraphProvider::new(std::sync::Arc::new(ChainSource));
        let sg = p.subgraph("default::as_", 1).await.unwrap();
        let ids: std::collections::HashSet<String> = sg.nodes.iter().map(|n| n.id.clone()).collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains("default::as_"));
        assert!(ids.contains("default::bs"));
        assert_eq!(sg.edges.len(), 1);
        assert_eq!(sg.stats.total_nodes, 2);
        assert_eq!(sg.stats.total_relationships, 1);
    }

    #[tokio::test]
    async fn neighbors_stop_at_limit() {
        let p = SchemaGraphProvider::new(std::sync::Arc::new(ChainSource));
        // bs touches two edges (as_->bs, bs->cs); limit 1 keeps only the first.
        let exp = p
            .neighbors(&["default::bs".into()], Direction::Both, false, 1)
            .await
            .unwrap();
        assert_eq!(exp.edges.len(), 1);
        assert_eq!(exp.edges[0].source_id, "default::as_");
        assert_eq!(exp.new_node_ids, vec!["default::as_".to_string()]);
    }
}