leankg 0.14.9

Lightweight Knowledge Graph for AI-Assisted Development
Documentation
use crate::db::schema::CozoDb;
use crate::graph::GraphEngine;
use std::collections::HashMap;

/// Groups graph elements into clusters ("communities") based on how they are
/// connected by `calls` and `imports` relationships.
///
/// All graph reads and the cluster write-back go through the wrapped
/// [`GraphEngine`].
pub struct CommunityDetector {
    /// Graph access layer: the source of elements and relationships, and the
    /// target of cluster assignments.
    graph_engine: GraphEngine,
}

impl CommunityDetector {
    pub fn new(db: &CozoDb) -> Self {
        Self {
            graph_engine: GraphEngine::new(db.clone()),
        }
    }

    pub fn detect_communities(
        &self,
    ) -> Result<HashMap<String, Cluster>, Box<dyn std::error::Error>> {
        let elements = self.graph_engine.all_elements()?;
        let relationships = self.graph_engine.all_relationships()?;

        if elements.is_empty() {
            return Ok(HashMap::new());
        }

        let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
        for elem in &elements {
            adjacency.entry(elem.qualified_name.clone()).or_default();
        }

        for rel in &relationships {
            if rel.rel_type == "calls" || rel.rel_type == "imports" {
                adjacency
                    .entry(rel.source_qualified.clone())
                    .or_default()
                    .push(rel.target_qualified.clone());
                adjacency
                    .entry(rel.target_qualified.clone())
                    .or_default()
                    .push(rel.source_qualified.clone());
            }
        }

        let mut labels: HashMap<String, String> = HashMap::new();
        let element_names: Vec<String> =
            elements.iter().map(|e| e.qualified_name.clone()).collect();
        for (i, name) in element_names.iter().enumerate() {
            labels.insert(name.clone(), format!("cluster_{}", i % 10));
        }

        for _ in 0..5 {
            for name in &element_names {
                if let Some(neighbors) = adjacency.get(name) {
                    if neighbors.is_empty() {
                        continue;
                    }
                    let mut label_counts: HashMap<String, usize> = HashMap::new();
                    for neighbor in neighbors {
                        if let Some(label) = labels.get(neighbor) {
                            *label_counts.entry(label.clone()).or_insert(0) += 1;
                        }
                    }
                    if let Some(max_label) = label_counts
                        .into_iter()
                        .max_by_key(|(_, count)| *count)
                        .map(|(label, _)| label)
                    {
                        labels.insert(name.clone(), max_label);
                    }
                }
            }
        }

        let mut clusters: HashMap<String, Cluster> = HashMap::new();
        for elem in &elements {
            let label = labels
                .get(&elem.qualified_name)
                .cloned()
                .unwrap_or_else(|| "unknown".to_string());
            let cluster_label = self.generate_cluster_label(&label, &elem.file_path);

            clusters
                .entry(label.clone())
                .or_insert_with(|| Cluster {
                    id: label.clone(),
                    label: cluster_label,
                    members: Vec::new(),
                    representative_files: Vec::new(),
                })
                .members
                .push(elem.qualified_name.clone());
        }

        for cluster in clusters.values_mut() {
            let mut file_counts: HashMap<String, usize> = HashMap::new();
            for member in &cluster.members {
                if let Some(elem) = elements.iter().find(|e| &e.qualified_name == member) {
                    *file_counts.entry(elem.file_path.clone()).or_insert(0) += 1;
                }
            }
            let mut files: Vec<(String, usize)> = file_counts.into_iter().collect();
            files.sort_by(|a, b| b.1.cmp(&a.1));
            cluster.representative_files =
                files.into_iter().take(3).map(|(path, _)| path).collect();
        }

        Ok(clusters)
    }

    fn generate_cluster_label(&self, cluster_id: &str, file_path: &str) -> String {
        let path_parts: Vec<&str> = file_path.split('/').collect();
        if path_parts.len() >= 2 {
            let dir = path_parts[path_parts.len() - 2];
            let normalized = dir
                .chars()
                .map(|c| {
                    if c.is_alphanumeric() {
                        c.to_ascii_lowercase()
                    } else {
                        '_'
                    }
                })
                .collect::<String>();
            if !normalized.is_empty() && normalized != "_" {
                return normalized;
            }
        }
        cluster_id.replace("cluster_", "module_")
    }

    pub fn assign_clusters_to_elements(&self) -> Result<(), Box<dyn std::error::Error>> {
        let clusters = self.detect_communities()?;

        for cluster in clusters.values() {
            for member_qualified in &cluster.members {
                self.graph_engine.update_element_cluster(
                    member_qualified,
                    Some(cluster.id.clone()),
                    Some(cluster.label.clone()),
                )?;
            }
        }

        Ok(())
    }
}

/// A group of graph elements that community detection placed together.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Cluster {
    /// Identifier of the cluster; also used as its key in the detection result map.
    pub id: String,
    /// Human-readable name for the cluster.
    pub label: String,
    /// Qualified names of the elements belonging to this cluster.
    pub members: Vec<String>,
    /// File paths that best represent the cluster.
    pub representative_files: Vec<String>,
}

/// Summarizes a set of clusters: how many there are, how many members they
/// hold in total, and the mean number of members per cluster.
///
/// The average is `0.0` when `clusters` is empty.
pub fn get_cluster_stats(clusters: &HashMap<String, Cluster>) -> ClusterStats {
    let total_clusters = clusters.len();

    let mut total_members: usize = 0;
    for cluster in clusters.values() {
        total_members += cluster.members.len();
    }

    // Guard the division so an empty map reports 0.0 instead of NaN.
    let avg_cluster_size = match total_clusters {
        0 => 0.0,
        count => total_members as f64 / count as f64,
    };

    ClusterStats {
        total_clusters,
        total_members,
        avg_cluster_size,
    }
}

/// Aggregate figures describing a set of clusters.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ClusterStats {
    /// Number of clusters.
    pub total_clusters: usize,
    /// Sum of the member counts of all clusters.
    pub total_members: usize,
    /// Mean number of members per cluster.
    pub avg_cluster_size: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cluster fixture with a single representative file.
    fn make_cluster(id: &str, label: &str, members: &[&str], file: &str) -> Cluster {
        Cluster {
            id: id.to_string(),
            label: label.to_string(),
            members: members.iter().map(|m| m.to_string()).collect(),
            representative_files: vec![file.to_string()],
        }
    }

    #[test]
    fn test_cluster_stats() {
        // Two clusters with 2 and 3 members: 5 members total, 2.5 on average.
        let fixtures = vec![
            make_cluster("c1", "auth", &["a", "b"], "auth.rs"),
            make_cluster("c2", "api", &["c", "d", "e"], "api.rs"),
        ];
        let clusters: HashMap<String, Cluster> = fixtures
            .into_iter()
            .map(|cluster| (cluster.id.clone(), cluster))
            .collect();

        let stats = get_cluster_stats(&clusters);

        assert_eq!(stats.total_clusters, 2);
        assert_eq!(stats.total_members, 5);
        assert!((stats.avg_cluster_size - 2.5).abs() < 0.001);
    }
}