use surrealdb::Surreal;
use crate::error::GraphError;
use crate::store::Db;
use crate::types::*;
/// Traverses the graph around the entity called `entity_name` without
/// restricting which entity types may appear in the result.
///
/// This is a convenience wrapper around [`traverse_filtered`] that passes
/// `None` as the type filter.
///
/// # Errors
///
/// Returns whatever error [`traverse_filtered`] produces.
pub async fn traverse(
    db: &Surreal<Db>,
    entity_name: &str,
    max_depth: u32,
) -> Result<TraversalNode, GraphError> {
    let no_filter: Option<&str> = None;
    traverse_filtered(db, entity_name, max_depth, no_filter).await
}
/// Traverses the graph starting at the entity called `entity_name`.
///
/// The root entity is looked up by name, its access count is incremented, and
/// the traversal then starts at depth 0 with an empty visited list.
/// `max_depth` and `type_filter` are forwarded unchanged to `traverse_from`;
/// the root itself is not checked against `type_filter`.
///
/// # Errors
///
/// * [`GraphError::NotFound`] (carrying `entity_name`) when no entity with
///   that name exists.
/// * Any error propagated from the lookup, the access-count update, or the
///   traversal itself.
pub async fn traverse_filtered(
    db: &Surreal<Db>,
    entity_name: &str,
    max_depth: u32,
    type_filter: Option<&str>,
) -> Result<TraversalNode, GraphError> {
    let lookup = crate::crud::get_entity_by_name(db, entity_name).await?;
    let full = match lookup {
        Some(found) => found,
        None => return Err(GraphError::NotFound(entity_name.to_string())),
    };

    let root_id = full.id_string();
    crate::crud::increment_access_counts(db, &[root_id]).await?;

    // Only the summary fields are carried into the traversal tree.
    let root = EntitySummary {
        id: full.id,
        name: full.name,
        entity_type: full.entity_type,
        abstract_text: full.abstract_text,
    };

    let mut visited = Vec::new();
    traverse_from(db, &root, max_depth, 0, &mut visited, type_filter).await
}
/// Boxed, pinned future that resolves to the result of traversing one node.
///
/// `traverse_from` returns this type rather than being an `async fn`: it is
/// re-entered recursively through `collect_edges`, and boxing gives that
/// recursive future a nameable, fixed-size type. The `'a` bound ties the
/// future to the borrows it captures.
type TraversalFuture<'a> =
std::pin::Pin<Box<dyn std::future::Future<Output = Result<TraversalNode, GraphError>> + 'a>>;
/// Recursively builds the [`TraversalNode`] for `entity`.
///
/// The entity's id is appended to `visited` first. If `current_depth` has
/// already reached `max_depth`, the node is returned with no edges. Otherwise
/// the `relates_to` rows whose `valid_until` is `NONE` are queried, outgoing
/// rows first and incoming rows second, and each batch is handed to
/// `collect_edges`, which recurses into the targets and fills in the edges.
///
/// # Errors
///
/// Propagates errors from the queries, from deserializing their results, and
/// from `collect_edges`.
fn traverse_from<'a>(
    db: &'a Surreal<Db>,
    entity: &'a EntitySummary,
    max_depth: u32,
    current_depth: u32,
    visited: &'a mut Vec<String>,
    type_filter: Option<&'a str>,
) -> TraversalFuture<'a> {
    // `entity` is the `in` side of the relation; the neighbour is `out`.
    let outgoing_sql = r#"
SELECT
rel_type,
valid_from,
valid_until,
out AS target_id
FROM relates_to
WHERE in = type::record($id)
AND valid_until IS NONE
"#;
    // `entity` is the `out` side of the relation; the neighbour is `in`.
    let incoming_sql = r#"
SELECT
rel_type,
valid_from,
valid_until,
in AS target_id
FROM relates_to
WHERE out = type::record($id)
AND valid_until IS NONE
"#;

    Box::pin(async move {
        // Mark this node before expanding so that neighbours do not walk back into it.
        visited.push(entity.id_string());

        let mut edges = Vec::new();
        if current_depth < max_depth {
            // Outgoing edges are fully expanded before incoming ones are queried.
            for (sql, arrow) in [(outgoing_sql, "->"), (incoming_sql, "<-")] {
                let mut response = db.query(sql).bind(("id", entity.id_string())).await?;
                let rows: Vec<EdgeRow> = crate::deserialize_take(&mut response, 0)?;
                collect_edges(
                    db,
                    rows,
                    arrow,
                    max_depth,
                    current_depth,
                    visited,
                    type_filter,
                    &mut edges,
                )
                .await?;
            }
        }

        Ok(TraversalNode {
            entity: entity.clone(),
            edges,
        })
    })
}
/// Expands a batch of edge rows into [`TraversalEdge`]s appended to `edges`.
///
/// For every row the target entity is resolved through
/// `crate::crud::get_entity_summary`. A row is skipped when:
///
/// * its target id is already in `visited`,
/// * the target cannot be found,
/// * `type_filter` is set and the target's `entity_type`, rendered as a
///   string, differs from it.
///
/// Each remaining target is traversed recursively at `current_depth + 1`, and
/// the resulting subtree is pushed onto `edges` together with the row's
/// `rel_type`, `valid_from`, `valid_until` and the given `direction` label.
///
/// # Errors
///
/// Propagates any error from the entity lookup or the recursive traversal.
// Allowed: the parameters are the recursion state shared with `traverse_from`.
#[allow(clippy::too_many_arguments)]
async fn collect_edges<'a>(
    db: &'a Surreal<Db>,
    edge_rows: Vec<EdgeRow>,
    direction: &str,
    max_depth: u32,
    current_depth: u32,
    visited: &'a mut Vec<String>,
    type_filter: Option<&'a str>,
    edges: &'a mut Vec<TraversalEdge>,
) -> Result<(), GraphError> {
    for edge in edge_rows {
        let tid = edge.target_id_string();

        // Skip known nodes before paying for a database round-trip. Both edge
        // directions are followed, so every child sees its parent again via the
        // reverse edge.
        // NOTE(review): assumes `target_id_string()` and
        // `EntitySummary::id_string()` share a format; if they do not, this
        // never matches and the check after the lookup still applies.
        if visited.iter().any(|seen| *seen == tid) {
            continue;
        }

        let target: Option<EntitySummary> = crate::crud::get_entity_summary(db, &tid).await?;
        let target = match target {
            Some(found) => found,
            // Edge points at an entity that could not be found: ignore it.
            None => continue,
        };

        // Authoritative cycle guard, keyed on the resolved entity's own id.
        if visited.contains(&target.id_string()) {
            continue;
        }

        if let Some(filter) = type_filter {
            if target.entity_type.to_string() != filter {
                continue;
            }
        }

        let child = traverse_from(
            db,
            &target,
            max_depth,
            current_depth + 1,
            visited,
            type_filter,
        )
        .await?;

        edges.push(TraversalEdge {
            rel_type: edge.rel_type,
            direction: direction.to_string(),
            target: child,
            valid_from: edge.valid_from,
            valid_until: edge.valid_until,
        });
    }
    Ok(())
}
/// Renders a traversal tree as indented, line-oriented text.
///
/// When `indent` is 0 a header line `name (entity_type)` is emitted for
/// `node`. Every edge of `node` then produces one line made of the indent
/// prefix (one space per level), the `├──` marker, the edge direction, the
/// relation type and the target's name, followed by ` [superseded]` when the
/// edge has a `valid_until` value. Targets that have edges of their own are
/// rendered recursively at `indent + 1`.
///
/// Returns the accumulated text; each emitted line ends with `\n`.
pub fn format_traversal(node: &TraversalNode, indent: usize) -> String {
    let pad = " ".repeat(indent);
    let mut rendered = String::new();

    // Only the top-level call prints the node itself; nested nodes are named
    // on their parent's edge line.
    if indent == 0 {
        rendered += &format!(
            "{} ({})\n",
            node.entity.name, node.entity.entity_type
        );
    }

    for link in node.edges.iter() {
        let suffix = if link.valid_until.is_none() {
            ""
        } else {
            " [superseded]"
        };
        rendered += &format!(
            "{}{} {} {} {}{}\n",
            pad, "├──", link.direction, link.rel_type, link.target.entity.name, suffix,
        );

        if link.target.edges.is_empty() {
            continue;
        }
        let subtree = format_traversal(&link.target, indent + 1);
        rendered += &subtree;
    }

    rendered
}