1use std::collections::BTreeMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7
8use crate::provider::GraphProvider;
9use crate::source::{ColumnDef, SchemaSource};
10use crate::types::{
11 Direction, EdgeExpansion, GraphNode, GraphPayload, GraphRelationship, GraphSchema, GraphStats,
12 NodeMetadata, Props, SearchHits,
13};
14
/// Timestamp written to both `created_at` and `updated_at` of every schema-derived node.
///
/// It is a fixed value (the Unix epoch) rather than a wall-clock time, so rebuilding the
/// graph from the same schema produces identical node metadata.
const SCHEMA_TS: &str = "1970-01-01T00:00:00Z";
17
/// Reports whether `name` is one half of a `<base>_nodes` / `<base>_edges` table pair.
///
/// Such pairs are the backing storage of a graph rather than user data, so the schema
/// graph hides them. A table is only considered storage when its sibling (same `<base>`,
/// opposite suffix) is also present in `all`; a lone `foo_nodes` is an ordinary table.
fn is_graph_storage_table(name: &str, all: &std::collections::HashSet<String>) -> bool {
    // (suffix this table may carry, suffix its sibling must carry)
    const PAIRINGS: [(&str, &str); 2] = [("_nodes", "_edges"), ("_edges", "_nodes")];
    PAIRINGS.iter().any(|&(own_suffix, sibling_suffix)| match name.strip_suffix(own_suffix) {
        Some(base) => all.contains(&format!("{base}{sibling_suffix}")),
        None => false,
    })
}
31
/// Builds the graph node id for `table` in `database`, formatted as `database::table`.
fn table_node_id(database: &str, table: &str) -> String {
    [database, table].join("::")
}
36
37pub(crate) fn infer_edges(
40 database: &str,
41 tables: &[(String, Vec<ColumnDef>)],
42) -> Vec<GraphRelationship> {
43 let names: Vec<String> = tables.iter().map(|(n, _)| n.to_lowercase()).collect();
44 let mut edges = Vec::new();
45 for (tname, cols) in tables {
46 for c in cols {
47 let lname = c.name.to_lowercase();
48 let Some(base) = lname.strip_suffix("_id") else { continue };
49 if base.is_empty() {
50 continue;
51 }
52 let target = names.iter().find(|n| *n == base || **n == format!("{base}s"));
54 if let Some(target_lc) = target {
55 let target_name = tables
56 .iter()
57 .find(|(n, _)| n.to_lowercase() == *target_lc)
58 .map(|(n, _)| n.clone())
59 .unwrap();
60 if target_name == *tname {
61 continue; }
63 let mut props: Props = BTreeMap::new();
64 props.insert("via".into(), serde_json::json!(c.name));
65 edges.push(GraphRelationship {
66 id: format!(
67 "{}->{}:{}",
68 table_node_id(database, tname),
69 table_node_id(database, &target_name),
70 c.name
71 ),
72 source_id: table_node_id(database, tname),
73 target_id: table_node_id(database, &target_name),
74 relationship_type: "REFERENCES".into(),
75 properties: props,
76 });
77 }
78 }
79 }
80 edges
81}
82
/// A [`GraphProvider`] that presents database schema metadata as a graph.
///
/// Every non-storage table becomes a `Table` node, and columns named like foreign keys
/// become `REFERENCES` edges. The graph is rebuilt from `source` on each query; this
/// struct holds no cached state.
pub struct SchemaGraphProvider {
    /// Supplies the databases, tables, and columns the graph is built from.
    source: Arc<dyn SchemaSource>,
}
88
89impl SchemaGraphProvider {
90 pub fn new(source: Arc<dyn SchemaSource>) -> Self {
91 Self { source }
92 }
93
94 async fn build(
98 &self,
99 realm: Option<&str>,
100 ) -> anyhow::Result<(Vec<GraphNode>, Vec<GraphRelationship>)> {
101 let mut nodes = Vec::new();
102 let mut edges = Vec::new();
103 for db in self.source.databases().await? {
104 if let Some(r) = realm {
105 if r != db {
106 continue;
107 }
108 }
109 let all_names: std::collections::HashSet<String> =
110 self.source.tables(&db).await?.into_iter().collect();
111 let mut tables: Vec<(String, Vec<ColumnDef>)> = Vec::new();
112 for t in &all_names {
113 if is_graph_storage_table(t, &all_names) {
115 continue;
116 }
117 let cols = self.source.columns(&db, t).await?;
118 tables.push((t.clone(), cols));
119 }
120 tables.sort_by(|a, b| a.0.cmp(&b.0));
121 for (tname, cols) in &tables {
122 let mut props: Props = BTreeMap::new();
123 props.insert("database".into(), serde_json::json!(db));
124 props.insert("column_count".into(), serde_json::json!(cols.len()));
125 props.insert(
126 "columns".into(),
127 serde_json::json!(cols
128 .iter()
129 .map(|c| serde_json::json!({"name": c.name, "type": c.type_, "nullable": c.nullable}))
130 .collect::<Vec<_>>()),
131 );
132 nodes.push(GraphNode {
133 id: table_node_id(&db, tname),
134 labels: vec!["Table".into()],
135 properties: props,
136 metadata: NodeMetadata {
137 created_at: SCHEMA_TS.into(),
138 updated_at: SCHEMA_TS.into(),
139 source_type: Some("schema".into()),
140 source_id: None,
141 realm: db.clone(),
142 },
143 });
144 }
145 edges.extend(infer_edges(&db, &tables));
146 }
147 Ok((nodes, edges))
148 }
149}
150
151fn compute_stats(nodes: &[GraphNode], edges: &[GraphRelationship]) -> GraphStats {
152 let mut label_counts: BTreeMap<String, usize> = BTreeMap::new();
153 for n in nodes {
154 for l in &n.labels {
155 *label_counts.entry(l.clone()).or_default() += 1;
156 }
157 }
158 let mut relationship_type_counts: BTreeMap<String, usize> = BTreeMap::new();
159 for e in edges {
160 *relationship_type_counts.entry(e.relationship_type.clone()).or_default() += 1;
161 }
162 GraphStats {
163 total_nodes: nodes.len(),
164 total_relationships: edges.len(),
165 label_counts,
166 relationship_type_counts,
167 }
168}
169
170#[async_trait]
171impl GraphProvider for SchemaGraphProvider {
172 async fn overview(&self, realm: Option<&str>, limit: usize) -> anyhow::Result<GraphPayload> {
173 let (mut nodes, edges) = self.build(realm).await?;
174 let stats = compute_stats(&nodes, &edges);
176 if nodes.len() > limit {
177 nodes.truncate(limit);
178 }
179 let kept: std::collections::HashSet<&String> = nodes.iter().map(|n| &n.id).collect();
180 let edges = edges
181 .into_iter()
182 .filter(|e| kept.contains(&e.source_id) && kept.contains(&e.target_id))
183 .collect();
184 Ok(GraphPayload { stats, nodes, edges })
185 }
186
187 async fn node(&self, id: &str) -> anyhow::Result<Option<GraphNode>> {
188 let (nodes, _) = self.build(None).await?;
189 Ok(nodes.into_iter().find(|n| n.id == id))
190 }
191
192 async fn neighbors(
193 &self,
194 ids: &[String],
195 dir: Direction,
196 _only_internal: bool,
197 limit: usize,
198 ) -> anyhow::Result<EdgeExpansion> {
199 let (_, all_edges) = self.build(None).await?;
201 let idset: std::collections::HashSet<&String> = ids.iter().collect();
202 let mut edges = Vec::new();
203 let mut new_ids = Vec::new();
204 for e in all_edges {
205 let touches = match dir {
206 Direction::Forward => idset.contains(&e.source_id),
207 Direction::Backward => idset.contains(&e.target_id),
208 Direction::Both => idset.contains(&e.source_id) || idset.contains(&e.target_id),
209 };
210 if !touches {
211 continue;
212 }
213 for end in [&e.source_id, &e.target_id] {
214 if !idset.contains(end) && !new_ids.contains(end) {
215 new_ids.push(end.clone());
216 }
217 }
218 edges.push(e);
219 if edges.len() >= limit {
220 break;
221 }
222 }
223 Ok(EdgeExpansion { edges, new_node_ids: new_ids })
224 }
225
226 async fn subgraph(&self, id: &str, depth: usize) -> anyhow::Result<GraphPayload> {
227 let (all_nodes, all_edges) = self.build(None).await?;
228 let mut frontier = vec![id.to_string()];
229 let mut visited: std::collections::HashSet<String> = frontier.iter().cloned().collect();
230 let mut kept_edges: Vec<GraphRelationship> = Vec::new();
231 for _ in 0..depth {
232 let mut next = Vec::new();
233 for e in &all_edges {
234 let (a, b) = (&e.source_id, &e.target_id);
235 let hit = frontier.contains(a) || frontier.contains(b);
236 if hit && !kept_edges.iter().any(|k| k.id == e.id) {
237 kept_edges.push(e.clone());
238 for end in [a, b] {
239 if visited.insert(end.clone()) {
240 next.push(end.clone());
241 }
242 }
243 }
244 }
245 if next.is_empty() {
246 break;
247 }
248 frontier = next;
249 }
250 let nodes: Vec<GraphNode> =
251 all_nodes.into_iter().filter(|n| visited.contains(&n.id)).collect();
252 let stats = compute_stats(&nodes, &kept_edges);
253 Ok(GraphPayload { stats, nodes, edges: kept_edges })
254 }
255
256 async fn search(
257 &self,
258 text: &str,
259 labels: &[String],
260 realm: Option<&str>,
261 limit: usize,
262 offset: usize,
263 ) -> anyhow::Result<SearchHits> {
264 let (nodes, _) = self.build(realm).await?;
265 let needle = text.to_lowercase();
266 let mut matched: Vec<GraphNode> = nodes
267 .into_iter()
268 .filter(|n| {
269 let table_name = n.id.rsplit("::").next().unwrap_or(n.id.as_str());
270 let name_ok = table_name.to_lowercase().contains(&needle);
271 let label_ok = labels.is_empty() || labels.iter().any(|l| n.labels.contains(l));
272 name_ok && label_ok
273 })
274 .collect();
275 let total = matched.len();
276 let hits = matched.drain(..).skip(offset).take(limit).collect();
277 Ok(SearchHits { hits, total, limit, offset })
278 }
279
280 async fn stats(&self, realm: Option<&str>) -> anyhow::Result<GraphStats> {
281 let (nodes, edges) = self.build(realm).await?;
282 Ok(compute_stats(&nodes, &edges))
283 }
284
285 async fn schema(&self) -> anyhow::Result<GraphSchema> {
286 let (nodes, edges) = self.build(None).await?;
287 let mut edge_types: Vec<String> =
288 edges.iter().map(|e| e.relationship_type.clone()).collect();
289 edge_types.sort();
290 edge_types.dedup();
291 let mut property_keys: BTreeMap<String, Vec<String>> = BTreeMap::new();
292 if !nodes.is_empty() {
293 property_keys.insert(
294 "Table".into(),
295 vec!["database".into(), "column_count".into(), "columns".into()],
296 );
297 }
298 Ok(GraphSchema {
299 node_kinds: if nodes.is_empty() { vec![] } else { vec!["Table".into()] },
300 edge_types,
301 property_keys,
302 })
303 }
304}
305
/// Unit tests for the pure helpers: storage-table detection and foreign-key edge inference.
#[cfg(test)]
mod edge_tests {
    use super::*;
    use crate::source::ColumnDef;

    /// Builds a nullable column of type `string` named `name`.
    fn col(name: &str) -> ColumnDef {
        ColumnDef { name: name.into(), type_: "string".into(), nullable: true }
    }

    /// Both halves of a `<base>_nodes` / `<base>_edges` pair are hidden; everything else is kept.
    #[test]
    fn hides_graph_storage_table_pairs() {
        let all: std::collections::HashSet<String> = [
            "github_nodes", "github_edges", "kg_nodes", "kg_edges", "api_calls", "users", "lonely_nodes",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        assert!(is_graph_storage_table("github_nodes", &all));
        assert!(is_graph_storage_table("github_edges", &all));
        assert!(is_graph_storage_table("kg_nodes", &all));
        assert!(is_graph_storage_table("kg_edges", &all));
        // Ordinary tables without a storage suffix.
        assert!(!is_graph_storage_table("api_calls", &all));
        assert!(!is_graph_storage_table("users", &all));
        // A `_nodes` suffix alone is not enough: there is no `lonely_edges` sibling.
        assert!(!is_graph_storage_table("lonely_nodes", &all));
    }

    /// `orders.user_id` resolves to the plural table `users`.
    #[test]
    fn infers_fk_edge_from_user_id_to_users() {
        let tables = vec![
            ("users".to_string(), vec![col("id"), col("email")]),
            ("orders".to_string(), vec![col("id"), col("user_id"), col("total")]),
        ];
        let edges = infer_edges("default", &tables);
        assert_eq!(edges.len(), 1);
        let e = &edges[0];
        assert_eq!(e.source_id, "default::orders");
        assert_eq!(e.target_id, "default::users");
        assert_eq!(e.relationship_type, "REFERENCES");
        assert_eq!(e.properties["via"], "user_id");
    }

    /// A `_id` column whose base names no table (`customer`/`customers`) yields no edge.
    #[test]
    fn no_edge_when_no_matching_table() {
        let tables = vec![
            ("orders".to_string(), vec![col("id"), col("customer_id")]),
        ];
        assert!(infer_edges("default", &tables).is_empty());
    }

    /// A bare `id` column is a primary key, not a reference.
    #[test]
    fn plain_id_column_is_not_an_edge() {
        let tables = vec![("users".to_string(), vec![col("id")])];
        assert!(infer_edges("default", &tables).is_empty());
    }
}
364
365#[cfg(test)]
366mod provider_tests {
367 use super::*;
368 use crate::source::{ColumnDef, SchemaSource};
369
370 struct FakeSource;
371
372 fn col(name: &str, t: &str) -> ColumnDef {
373 ColumnDef { name: name.into(), type_: t.into(), nullable: true }
374 }
375
376 #[async_trait]
377 impl SchemaSource for FakeSource {
378 async fn databases(&self) -> anyhow::Result<Vec<String>> {
379 Ok(vec!["default".into()])
380 }
381 async fn tables(&self, _db: &str) -> anyhow::Result<Vec<String>> {
382 Ok(vec!["users".into(), "orders".into()])
383 }
384 async fn columns(&self, _db: &str, table: &str) -> anyhow::Result<Vec<ColumnDef>> {
385 Ok(match table {
386 "users" => vec![col("id", "string"), col("email", "string")],
387 "orders" => vec![col("id", "string"), col("user_id", "string")],
388 _ => vec![],
389 })
390 }
391 }
392
393 fn provider() -> SchemaGraphProvider {
394 SchemaGraphProvider::new(std::sync::Arc::new(FakeSource))
395 }
396
397 #[tokio::test]
398 async fn overview_has_two_table_nodes_and_one_edge() {
399 let p = provider();
400 let payload = p.overview(None, 100).await.unwrap();
401 assert_eq!(payload.nodes.len(), 2);
402 assert!(payload.nodes.iter().all(|n| n.labels == vec!["Table".to_string()]));
403 assert!(payload.nodes.iter().any(|n| n.id == "default::users"));
404 assert_eq!(payload.edges.len(), 1);
405 assert_eq!(payload.stats.total_nodes, 2);
406 assert_eq!(payload.stats.total_relationships, 1);
407 assert_eq!(payload.stats.label_counts["Table"], 2);
408 }
409
410 #[tokio::test]
411 async fn node_lookup_returns_table_props() {
412 let p = provider();
413 let n = p.node("default::orders").await.unwrap().unwrap();
414 assert_eq!(n.metadata.realm, "default");
415 assert_eq!(n.properties["database"], "default");
416 assert_eq!(n.properties["column_count"], 2);
417 assert!(p.node("default::nope").await.unwrap().is_none());
418 }
419
420 #[tokio::test]
421 async fn search_filters_by_name_substring() {
422 let p = provider();
423 let hits = p.search("ord", &[], None, 10, 0).await.unwrap();
424 assert_eq!(hits.total, 1);
425 assert_eq!(hits.hits[0].id, "default::orders");
426 }
427
428 #[tokio::test]
429 async fn neighbors_of_orders_returns_the_reference_edge() {
430 let p = provider();
431 let exp = p
432 .neighbors(&["default::orders".into()], Direction::Both, true, 100)
433 .await
434 .unwrap();
435 assert_eq!(exp.edges.len(), 1);
436 assert_eq!(exp.new_node_ids, vec!["default::users".to_string()]);
437 }
438
439 #[tokio::test]
440 async fn search_does_not_match_database_prefix() {
441 let p = provider();
442 let hits = p.search("default", &[], None, 10, 0).await.unwrap();
444 assert_eq!(hits.total, 0, "search must match table names, not the db prefix");
445 let hits2 = p.search("ord", &[], None, 10, 0).await.unwrap();
447 assert_eq!(hits2.total, 1);
448 }
449
450 #[tokio::test]
451 async fn schema_reports_table_kind_and_references_edge() {
452 let p = provider();
453 let s = p.schema().await.unwrap();
454 assert_eq!(s.node_kinds, vec!["Table".to_string()]);
455 assert_eq!(s.edge_types, vec!["REFERENCES".to_string()]);
456 }
457
458 #[tokio::test]
459 async fn overview_caps_nodes_but_stats_reflect_full_graph() {
460 let p = provider();
461 let payload = p.overview(None, 1).await.unwrap();
462 assert_eq!(payload.nodes.len(), 1, "nodes capped to limit");
463 assert_eq!(payload.stats.total_nodes, 2, "stats reflect full graph");
464 assert_eq!(payload.edges.len(), 0);
466 assert_eq!(payload.stats.total_relationships, 1);
467 }
468
469 struct ChainSource;
470
471 #[async_trait]
472 impl SchemaSource for ChainSource {
473 async fn databases(&self) -> anyhow::Result<Vec<String>> {
474 Ok(vec!["default".into()])
475 }
476 async fn tables(&self, _db: &str) -> anyhow::Result<Vec<String>> {
477 Ok(vec!["as_".into(), "bs".into(), "cs".into()])
478 }
479 async fn columns(&self, _db: &str, table: &str) -> anyhow::Result<Vec<ColumnDef>> {
480 Ok(match table {
482 "as_" => vec![col("id", "string"), col("b_id", "string")],
483 "bs" => vec![col("id", "string"), col("c_id", "string")],
484 "cs" => vec![col("id", "string")],
485 _ => vec![],
486 })
487 }
488 }
489
490 #[tokio::test]
491 async fn subgraph_two_hops_collects_chain() {
492 let p = SchemaGraphProvider::new(std::sync::Arc::new(ChainSource));
493 let sg = p.subgraph("default::as_", 2).await.unwrap();
495 let ids: std::collections::HashSet<String> = sg.nodes.iter().map(|n| n.id.clone()).collect();
496 assert!(ids.contains("default::as_"));
497 assert!(ids.contains("default::bs"));
498 assert!(ids.contains("default::cs"));
499 assert_eq!(sg.edges.len(), 2);
500 }
501}