coraline 0.8.0

Coraline: semantic code knowledge graph for faster AI-assisted development.
Documentation
#![forbid(unsafe_code)]

use std::collections::{HashMap, HashSet, VecDeque};

use crate::db;
use crate::types::{Edge, EdgeKind, Subgraph, TraversalDirection, TraversalOptions};

/// Field-less marker type for the graph module.
///
/// It carries no state; `Debug` and `Default` are derived. The traversal
/// entry point in this file is the free function [`build_subgraph`], which
/// does not take a `Graph` value.
// NOTE(review): presumably kept as a namespace/placeholder for future
// methods — confirm against callers elsewhere in the crate.
#[derive(Debug, Default)]
pub struct Graph;

/// Collects the neighbourhood of `roots` with a breadth-first walk.
///
/// Unset fields of `options` fall back to: `max_depth = 1`,
/// `include_start = true`, `limit = 200`, `direction = Both`, and no
/// edge-kind / node-kind filtering.
///
/// * A node is stored in the result when it is not a depth-0 start node
///   excluded by `include_start`, the database returns a row for its id, and
///   its kind passes the optional `node_kinds` filter. Nodes that are *not*
///   stored are still expanded, so the walk passes through them.
/// * `limit` caps the number of collected **edges**; once it is reached the
///   walk stops. Nodes are only bounded indirectly through that cap.
/// * Edges of nodes sitting exactly at `max_depth` are still collected, but
///   their far endpoints are not queued, so the returned edges may mention
///   node ids that are absent from `nodes`.
///
/// # Errors
///
/// Propagates any error returned by `db::get_node_by_id` or by the edge
/// lookups performed through `fetch_edges`.
// NOTE(review): with `TraversalDirection::Both`, an edge joining two visited
// nodes appears to be fetched once from each endpoint (outgoing at its
// source, incoming at its target) and is pushed both times — confirm whether
// callers expect de-duplicated edges before changing this.
pub fn build_subgraph(
    conn: &rusqlite::Connection,
    roots: &[String],
    options: &TraversalOptions,
) -> std::io::Result<Subgraph> {
    // Resolve every optional knob up front.
    let depth_cap = options.max_depth.unwrap_or(1);
    let keep_roots = options.include_start.unwrap_or(true);
    let edge_cap = options.limit.unwrap_or(200);
    let direction = options.direction.unwrap_or(TraversalDirection::Both);
    let edge_filter = options.edge_kinds.as_ref();
    let node_filter = options.node_kinds.as_ref();

    let follow_outgoing = direction != TraversalDirection::Incoming;
    let follow_incoming = direction != TraversalDirection::Outgoing;

    let mut collected_nodes = HashMap::new();
    let mut collected_edges = Vec::new();
    let mut seen = HashSet::new();

    // Every root starts the walk at depth 0.
    let mut frontier: VecDeque<_> = roots.iter().map(|root| (root.clone(), 0)).collect();

    while let Some((current, depth)) = frontier.pop_front() {
        // Too deep, or already handled: nothing to do. The depth test comes
        // first so an over-deep id is never marked as seen.
        if depth > depth_cap || !seen.insert(current.clone()) {
            continue;
        }

        // The lookup is skipped entirely for start nodes the caller excluded.
        if keep_roots || depth > 0 {
            let found = db::get_node_by_id(conn, &current)?
                .filter(|node| node_filter.is_none_or(|kinds| kinds.contains(&node.kind)));
            if let Some(node) = found {
                collected_nodes.insert(current.clone(), node);
            }
        }

        if collected_edges.len() >= edge_cap {
            break;
        }

        // Outgoing edges first, then incoming ones.
        let mut incident = if follow_outgoing {
            fetch_edges(conn, &current, true, edge_filter, edge_cap)?
        } else {
            Vec::new()
        };
        if follow_incoming {
            incident.extend(fetch_edges(conn, &current, false, edge_filter, edge_cap)?);
        }

        // The check above guarantees the subtraction cannot underflow; taking
        // at most `room` edges is the same as re-testing the cap per edge.
        let room = edge_cap - collected_edges.len();
        for edge in incident.into_iter().take(room) {
            // The neighbour is whichever endpoint is not the current node.
            let neighbour = if edge.source == current {
                edge.target.clone()
            } else {
                edge.source.clone()
            };
            let neighbour_depth = depth + 1;

            collected_edges.push(edge);
            if neighbour_depth <= depth_cap {
                frontier.push_back((neighbour, neighbour_depth));
            }
        }
    }

    Ok(Subgraph {
        nodes: collected_nodes,
        edges: collected_edges,
        roots: roots.to_vec(),
    })
}

/// Loads the edges touching `node_id` in one direction, at most `limit` of them.
///
/// * `outgoing == true` looks edges up by source (`db::get_edges_by_source`),
///   otherwise by target (`db::get_edges_by_target`).
/// * With `edge_kinds == None` a single unfiltered query is issued.
/// * With `edge_kinds == Some(kinds)` one query per kind is issued, in the
///   order given, and the results are concatenated. Each query is only asked
///   for the rows still missing to reach `limit`, and the loop stops as soon
///   as `limit` edges have been gathered, so the combined result never
///   exceeds `limit` and no query is run once the budget is spent.
///
/// # Errors
///
/// Propagates any error returned by the underlying `db` lookups.
fn fetch_edges(
    conn: &rusqlite::Connection,
    node_id: &str,
    outgoing: bool,
    edge_kinds: Option<&Vec<EdgeKind>>,
    limit: usize,
) -> std::io::Result<Vec<Edge>> {
    // No kind filter: one query, already bounded by `limit` in the db layer.
    let Some(kinds) = edge_kinds else {
        let unfiltered = if outgoing {
            db::get_edges_by_source(conn, node_id, None, limit)?
        } else {
            db::get_edges_by_target(conn, node_id, None, limit)?
        };
        return Ok(unfiltered);
    };

    let mut results = Vec::new();
    for kind in kinds {
        // Only request what is still needed to fill the budget.
        let remaining = limit.saturating_sub(results.len());
        let batch = if outgoing {
            db::get_edges_by_source(conn, node_id, Some(*kind), remaining)?
        } else {
            db::get_edges_by_target(conn, node_id, Some(*kind), remaining)?
        };
        results.extend(batch);

        if results.len() >= limit {
            // Defensive: keep the cap even if a lookup returned extra rows.
            results.truncate(limit);
            break;
        }
    }

    Ok(results)
}