use anyhow::Result;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use arrow_array::{Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use crate::store::{table_ops::TableOperations, CodeBlock};
use futures::TryStreamExt;
use lancedb::{
query::{ExecutableQuery, QueryBase},
Connection, DistanceType,
};
use crate::indexer::graphrag::types::RelationType;
/// Hit/miss counters for the in-memory adjacency cache.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    // Lookups answered from the cache.
    pub hits: usize,
    // Lookups that fell through (entry absent or lock unavailable).
    pub misses: usize,
}
impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    /// Returns `0.0` when no lookups have been recorded yet (avoids 0/0).
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
// node id -> (relation type -> ids of adjacent nodes).
type AdjacencyCache = HashMap<String, HashMap<RelationType, Vec<String>>>;
/// Storage-layer operations for the GraphRAG tables (`graphrag_nodes`,
/// `graphrag_relationships`), with an in-memory adjacency cache that
/// accelerates graph traversals.
pub struct GraphRagOperations<'a> {
    // Open LanceDB connection; all tables are resolved through it.
    pub db: &'a Connection,
    // Generic table helpers (existence checks, batch writes, deletes).
    pub table_ops: TableOperations<'a>,
    // Dimensionality of node embedding vectors.
    pub code_vector_dim: usize,
    // Cached adjacency lists, shared across clones/threads.
    adjacency_cache: Arc<RwLock<AdjacencyCache>>,
    // Counters describing how effective the adjacency cache is.
    cache_stats: Arc<RwLock<CacheStats>>,
}
impl<'a> GraphRagOperations<'a> {
/// Creates a new operations handle over `db`.
///
/// `code_vector_dim` is the dimensionality of embeddings stored in
/// `graphrag_nodes`; it is used for index tuning and query validation.
pub fn new(db: &'a Connection, code_vector_dim: usize) -> Self {
    Self {
        db,
        table_ops: TableOperations::new(db),
        code_vector_dim,
        // Caches start empty; build_adjacency_cache() populates them.
        adjacency_cache: Arc::new(RwLock::new(HashMap::new())),
        cache_stats: Arc::new(RwLock::new(CacheStats::default())),
    }
}
/// Empties the adjacency cache and resets its hit/miss counters.
/// A poisoned lock is silently skipped (best-effort cleanup).
pub fn clear_cache(&self) {
    if let Ok(mut entries) = self.adjacency_cache.write() {
        entries.clear();
    }
    if let Ok(mut counters) = self.cache_stats.write() {
        *counters = CacheStats::default();
    }
}
/// Drops every cached adjacency list whose source is `node_id`.
/// Note this does not purge lists of other nodes that point *to* it.
pub fn invalidate_cache_for_node(&self, node_id: &str) {
    if let Ok(mut entries) = self.adjacency_cache.write() {
        entries.remove(node_id);
    }
}
pub fn get_cache_stats(&self) -> CacheStats {
self.cache_stats
.read()
.map(|stats| stats.clone())
.unwrap_or_default()
}
/// Looks up `node_id`'s neighbors for `relation_type` in the cache and
/// records a hit or miss. Returns `None` on a miss or an unavailable lock.
fn get_adjacent_nodes_cached(
    &self,
    node_id: &str,
    relation_type: &RelationType,
) -> Option<Vec<String>> {
    // Scope the read guard so it is released before we touch the stats lock.
    let found = {
        let guard = self.adjacency_cache.read().ok()?;
        guard
            .get(node_id)
            .and_then(|by_type| by_type.get(relation_type))
            .cloned()
    };
    if let Ok(mut counters) = self.cache_stats.write() {
        if found.is_some() {
            counters.hits += 1;
        } else {
            counters.misses += 1;
        }
    }
    found
}
/// Appends `target_id` to the cached adjacency list for
/// (`source_id`, `relation_type`), creating intermediate maps on demand.
fn update_cache(&self, source_id: &str, relation_type: RelationType, target_id: String) {
    if let Ok(mut entries) = self.adjacency_cache.write() {
        let by_type = entries.entry(source_id.to_string()).or_default();
        by_type.entry(relation_type).or_default().push(target_id);
    }
}
/// Rebuilds the adjacency cache from the full `graphrag_relationships`
/// table. Clears any existing cache first; a missing table is a no-op.
///
/// # Errors
/// Propagates table-open, scan, and column-decoding failures.
pub async fn build_adjacency_cache(&self) -> Result<()> {
    self.clear_cache();
    if !self
        .table_ops
        .table_exists("graphrag_relationships")
        .await?
    {
        return Ok(());
    }
    let table = self
        .db
        .open_table("graphrag_relationships")
        .execute()
        .await?;
    let mut results = table.query().execute().await?;
    let mut total_relationships = 0usize;
    while let Some(batch) = results.try_next().await? {
        if batch.num_rows() == 0 {
            continue;
        }
        let source_ids = batch
            .column_by_name("source")
            .and_then(|col| col.as_any().downcast_ref::<StringArray>())
            .ok_or_else(|| anyhow::anyhow!("Missing or invalid source column"))?;
        let target_ids = batch
            .column_by_name("target")
            .and_then(|col| col.as_any().downcast_ref::<StringArray>())
            .ok_or_else(|| anyhow::anyhow!("Missing or invalid target column"))?;
        let relation_types = batch
            .column_by_name("relation_type")
            .and_then(|col| col.as_any().downcast_ref::<StringArray>())
            .ok_or_else(|| anyhow::anyhow!("Missing or invalid relation_type column"))?;
        for i in 0..batch.num_rows() {
            // The previous `.to_string().into()` pattern always produced
            // `Some(..)`, so nulls were never actually checked; test
            // validity explicitly and skip incomplete rows.
            if source_ids.is_null(i) || target_ids.is_null(i) || relation_types.is_null(i) {
                continue;
            }
            // Unknown relation names are skipped rather than failing the build.
            if let Ok(rel_type) = RelationType::from_str(relation_types.value(i)) {
                self.update_cache(source_ids.value(i), rel_type, target_ids.value(i).to_string());
                total_relationships += 1;
            }
        }
    }
    tracing::info!(
        "Built adjacency cache with {} relationships",
        total_relationships
    );
    Ok(())
}
/// Returns the ids of nodes reachable from `node_id` via an outgoing edge
/// of `relation_type`, consulting the adjacency cache first and falling
/// back to a table scan on a miss.
///
/// # Errors
/// Propagates failures from the underlying relationship query.
pub async fn get_neighbors_cached(
    &self,
    node_id: &str,
    relation_type: &RelationType,
) -> Result<Vec<String>> {
    if let Some(neighbors) = self.get_adjacent_nodes_cached(node_id, relation_type) {
        return Ok(neighbors);
    }
    let relationships = self
        .get_node_relationships(
            node_id,
            crate::indexer::graphrag::types::RelationshipDirection::Outgoing,
        )
        .await?;
    let neighbors: Vec<String> = relationships
        .iter()
        .filter(|rel| &rel.relation_type == relation_type)
        .map(|rel| rel.target.clone())
        .collect();
    // Insert the whole list under one lock acquisition instead of pushing
    // per target. Crucially this also caches *empty* results, so repeated
    // lookups for leaf nodes become hits instead of re-querying the table
    // every time (the original never cached an empty list).
    if let Ok(mut cache) = self.adjacency_cache.write() {
        cache
            .entry(node_id.to_string())
            .or_default()
            .insert(relation_type.clone(), neighbors.clone());
    }
    Ok(neighbors)
}
/// Breadth-first traversal from `start_node`, following edges of any of
/// `relation_types`, bounded to `max_depth` hops. Returns visited node ids
/// in visit order, including `start_node` itself.
///
/// # Errors
/// Propagates failures from neighbor lookups.
pub async fn traverse_path_cached(
    &self,
    start_node: &str,
    relation_types: &[RelationType],
    max_depth: usize,
) -> Result<Vec<String>> {
    use std::collections::{HashSet, VecDeque};
    let mut visited = HashSet::new();
    let mut queue = VecDeque::new();
    let mut result = Vec::new();
    queue.push_back((start_node.to_string(), 0));
    visited.insert(start_node.to_string());
    while let Some((current_node, depth)) = queue.pop_front() {
        result.push(current_node.clone());
        if depth >= max_depth {
            continue;
        }
        for rel_type in relation_types {
            // Fixed: `&current_node` had been mangled into the mojibake
            // `¤t_node`, which does not compile.
            let neighbors = self.get_neighbors_cached(&current_node, rel_type).await?;
            for neighbor in neighbors {
                // insert() returns false for already-seen nodes.
                if visited.insert(neighbor.clone()) {
                    queue.push_back((neighbor, depth + 1));
                }
            }
        }
    }
    tracing::debug!(
        "Traversed from {} with depth {}: found {} nodes (cache hit rate: {:.2}%)",
        start_node,
        max_depth,
        result.len(),
        self.get_cache_stats().hit_rate() * 100.0
    );
    Ok(result)
}
/// Partitions the graph into connected components via repeated BFS over
/// edges of the given `relation_types` (outgoing direction only, per
/// `get_neighbors_cached`). Returns one `Vec` of node ids per component.
///
/// # Errors
/// Propagates failures from node-id enumeration or neighbor lookups.
pub async fn find_connected_components_cached(
    &self,
    relation_types: &[RelationType],
) -> Result<Vec<Vec<String>>> {
    use std::collections::{HashSet, VecDeque};
    let all_nodes = self.get_all_node_ids().await?;
    let mut visited = HashSet::new();
    let mut components = Vec::new();
    for start_node in &all_nodes {
        if visited.contains(start_node) {
            continue;
        }
        let mut component = Vec::new();
        let mut queue = VecDeque::new();
        queue.push_back(start_node.clone());
        visited.insert(start_node.clone());
        while let Some(current_node) = queue.pop_front() {
            component.push(current_node.clone());
            for rel_type in relation_types {
                // Fixed: `&current_node` had been mangled into the mojibake
                // `¤t_node`, which does not compile.
                let neighbors = self.get_neighbors_cached(&current_node, rel_type).await?;
                for neighbor in neighbors {
                    if visited.insert(neighbor.clone()) {
                        queue.push_back(neighbor);
                    }
                }
            }
        }
        if !component.is_empty() {
            components.push(component);
        }
    }
    tracing::info!(
        "Found {} connected components (cache hit rate: {:.2}%)",
        components.len(),
        self.get_cache_stats().hit_rate() * 100.0
    );
    Ok(components)
}
/// Scans `graphrag_nodes` and returns every node id. A missing table
/// yields an empty list.
///
/// # Errors
/// Propagates table-open, scan, and column-decoding failures.
async fn get_all_node_ids(&self) -> Result<Vec<String>> {
    if !self.table_ops.table_exists("graphrag_nodes").await? {
        return Ok(Vec::new());
    }
    let table = self.db.open_table("graphrag_nodes").execute().await?;
    let mut results = table.query().execute().await?;
    let mut node_ids = Vec::new();
    while let Some(batch) = results.try_next().await? {
        if batch.num_rows() == 0 {
            continue;
        }
        let id_array = batch
            .column_by_name("id")
            .and_then(|col| col.as_any().downcast_ref::<StringArray>())
            .ok_or_else(|| anyhow::anyhow!("Missing or invalid id column"))?;
        for i in 0..batch.num_rows() {
            // The previous `.to_string().into()` pattern always produced
            // `Some(..)`; check validity explicitly instead.
            if !id_array.is_null(i) {
                node_ids.push(id_array.value(i).to_string());
            }
        }
    }
    Ok(node_ids)
}
/// Reports whether GraphRAG indexing should run: true when either graph
/// table is missing, or when both exist but are completely empty.
///
/// # Errors
/// Propagates table-existence checks, opens, and row counts.
pub async fn graphrag_needs_indexing(&self) -> Result<bool> {
    let both_exist = self
        .table_ops
        .tables_exist(&["graphrag_nodes", "graphrag_relationships"])
        .await?;
    if !both_exist {
        return Ok(true);
    }
    let nodes_table = self.db.open_table("graphrag_nodes").execute().await?;
    let relationships_table = self
        .db
        .open_table("graphrag_relationships")
        .execute()
        .await?;
    let nodes_empty = nodes_table.count_rows(None).await? == 0;
    let relationships_empty = relationships_table.count_rows(None).await? == 0;
    Ok(nodes_empty && relationships_empty)
}
/// Loads every row of `code_blocks` as `CodeBlock`s for GraphRAG
/// processing. A missing table yields an empty list.
///
/// # Errors
/// Propagates table-open, scan, and batch-conversion failures.
pub async fn get_all_code_blocks_for_graphrag(&self) -> Result<Vec<CodeBlock>> {
    let mut all_blocks = Vec::new();
    if !self.table_ops.table_exists("code_blocks").await? {
        return Ok(all_blocks);
    }
    let table = self.db.open_table("code_blocks").execute().await?;
    let mut results = table.query().execute().await?;
    // Loop-invariant hoisted: the converter only depends on the vector
    // dimension, so build it once instead of once per batch.
    let converter = crate::store::batch_converter::BatchConverter::new(self.code_vector_dim);
    while let Some(batch) = results.try_next().await? {
        if batch.num_rows() == 0 {
            continue;
        }
        let mut code_blocks = converter.batch_to_code_blocks(&batch, None)?;
        all_blocks.append(&mut code_blocks);
        // Debug-only progress logging; fires when a batch boundary lands
        // exactly on a multiple of 1000 blocks.
        if cfg!(debug_assertions) && all_blocks.len() % 1000 == 0 {
            tracing::debug!(
                "Loaded {} code blocks for GraphRAG processing...",
                all_blocks.len()
            );
        }
    }
    Ok(all_blocks)
}
/// Persists a batch of graph nodes, then ensures the `embedding` column
/// has a vector index: creates one if absent, or rebuilds it when the
/// dataset has grown enough to warrant re-optimization. Index maintenance
/// failures are logged, never fatal.
///
/// # Errors
/// Propagates the batch write and row/index inspection failures.
pub async fn store_graph_nodes(&self, node_batch: RecordBatch) -> Result<()> {
    self.table_ops
        .store_batch("graphrag_nodes", node_batch)
        .await?;
    // If the table can't be re-opened, skip index maintenance silently —
    // the data itself was already stored.
    let table = match self.db.open_table("graphrag_nodes").execute().await {
        Ok(t) => t,
        Err(_) => return Ok(()),
    };
    let row_count = table.count_rows(None).await?;
    let indices = table.list_indices().await?;
    let has_index = indices.iter().any(|idx| idx.columns == vec!["embedding"]);
    if !has_index {
        if let Err(e) = self
            .table_ops
            .create_vector_index_optimized("graphrag_nodes", "embedding", self.code_vector_dim)
            .await
        {
            tracing::warn!(
                "Failed to create optimized vector index on graph_nodes: {}",
                e
            );
        }
    } else if super::vector_optimizer::VectorOptimizer::should_optimize_for_growth(
        row_count,
        self.code_vector_dim,
        true,
    ) {
        tracing::info!("Dataset growth detected, optimizing graphrag_nodes index");
        if let Err(e) = self
            .table_ops
            .recreate_vector_index_optimized("graphrag_nodes", "embedding", self.code_vector_dim)
            .await
        {
            tracing::warn!(
                "Failed to recreate optimized vector index on graphrag_nodes: {}",
                e
            );
        }
    }
    Ok(())
}
/// Persists a batch of relationships and mirrors them into the adjacency
/// cache so traversals see them without a rebuild.
/// The clone is cheap: RecordBatch columns are Arc-backed buffers.
pub async fn store_graph_relationships(&self, rel_batch: RecordBatch) -> Result<()> {
    self.table_ops
        .store_batch("graphrag_relationships", rel_batch.clone())
        .await?;
    self.update_cache_from_batch(&rel_batch)?;
    Ok(())
}
/// Folds every (source, relation_type, target) row of `rel_batch` into the
/// adjacency cache. Rows with a null key column or an unknown relation
/// name are skipped.
///
/// # Errors
/// Returns an error when a required column is missing or mistyped.
fn update_cache_from_batch(&self, rel_batch: &RecordBatch) -> Result<()> {
    if rel_batch.num_rows() == 0 {
        return Ok(());
    }
    let source_ids = rel_batch
        .column_by_name("source")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| anyhow::anyhow!("Missing or invalid source column"))?;
    let target_ids = rel_batch
        .column_by_name("target")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| anyhow::anyhow!("Missing or invalid target column"))?;
    let relation_types = rel_batch
        .column_by_name("relation_type")
        .and_then(|col| col.as_any().downcast_ref::<StringArray>())
        .ok_or_else(|| anyhow::anyhow!("Missing or invalid relation_type column"))?;
    for i in 0..rel_batch.num_rows() {
        // The previous `.to_string().into()` pattern always produced
        // `Some(..)`, so nulls were never actually checked; test validity
        // explicitly and skip incomplete rows.
        if source_ids.is_null(i) || target_ids.is_null(i) || relation_types.is_null(i) {
            continue;
        }
        if let Ok(rel_type) = RelationType::from_str(relation_types.value(i)) {
            self.update_cache(source_ids.value(i), rel_type, target_ids.value(i).to_string());
        }
    }
    Ok(())
}
/// Removes every row from `graphrag_nodes` and invalidates the adjacency
/// cache, which may reference the removed nodes.
pub async fn clear_graph_nodes(&self) -> Result<()> {
    self.table_ops.clear_table("graphrag_nodes").await?;
    self.clear_cache();
    Ok(())
}
/// Removes every row from `graphrag_relationships` and invalidates the
/// adjacency cache, which mirrors those rows.
pub async fn clear_graph_relationships(&self) -> Result<()> {
    self.table_ops.clear_table("graphrag_relationships").await?;
    self.clear_cache();
    Ok(())
}
/// Deletes all GraphRAG data derived from `file_path`: first the
/// relationships touching its nodes, then the nodes themselves, then the
/// corresponding cache entries. Returns the number of nodes removed.
///
/// # Errors
/// Propagates lookup and delete failures from the underlying tables.
pub async fn remove_graph_nodes_by_path(&self, file_path: &str) -> Result<usize> {
    // Capture the ids up front — they are needed for cache invalidation
    // after the rows are gone.
    let node_ids = self.get_node_ids_for_file_path(file_path).await?;
    // Relationships go first so no edge is ever left pointing at a
    // deleted node.
    let relationships_removed = self.remove_graph_relationships_by_path(file_path).await?;
    let nodes_removed = self
        .table_ops
        .remove_blocks_by_path(file_path, "graphrag_nodes")
        .await?;
    node_ids
        .iter()
        .for_each(|id| self.invalidate_cache_for_node(id));
    if nodes_removed > 0 || relationships_removed > 0 {
        tracing::info!(
            file = %file_path,
            nodes_removed = nodes_removed,
            relationships_removed = relationships_removed,
            "Cleaned up GraphRAG data for file"
        );
    }
    Ok(nodes_removed)
}
/// Deletes every relationship whose source or target node belongs to
/// `file_path`. Returns the (approximate) number of rows removed, computed
/// as a before/after row-count difference.
///
/// # Errors
/// Propagates lookup, open, count, and delete failures.
pub async fn remove_graph_relationships_by_path(&self, file_path: &str) -> Result<usize> {
    if !self
        .table_ops
        .table_exists("graphrag_relationships")
        .await?
    {
        return Ok(0);
    }
    let node_ids = self.get_node_ids_for_file_path(file_path).await?;
    if node_ids.is_empty() {
        return Ok(0);
    }
    let table = self
        .db
        .open_table("graphrag_relationships")
        .execute()
        .await?;
    let before_count = table.count_rows(None).await?;
    // Build "source = 'id' OR target = 'id'" terms. Single quotes inside
    // an id are doubled so a path-derived id containing `'` cannot break
    // out of the SQL string literal (filter-injection hardening).
    let filter = node_ids
        .iter()
        .map(|node_id| {
            let escaped = node_id.replace('\'', "''");
            format!("source = '{}' OR target = '{}'", escaped, escaped)
        })
        .collect::<Vec<_>>()
        .join(" OR ");
    table
        .delete(&filter)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to delete from graphrag_relationships: {}", e))?;
    for node_id in &node_ids {
        self.invalidate_cache_for_node(node_id);
    }
    let after_count = table.count_rows(None).await?;
    Ok(before_count.saturating_sub(after_count))
}
async fn get_node_ids_for_file_path(&self, file_path: &str) -> Result<Vec<String>> {
let mut node_ids = Vec::new();
if !self.table_ops.table_exists("graphrag_nodes").await? {
return Ok(node_ids);
}
let table = self.db.open_table("graphrag_nodes").execute().await?;
let mut results = table
.query()
.only_if(format!("path = '{}'", file_path))
.select(lancedb::query::Select::Columns(vec!["id".to_string()]))
.execute()
.await?;
while let Some(batch) = results.try_next().await? {
if batch.num_rows() > 0 {
if let Some(column) = batch.column_by_name("id") {
if let Some(id_array) =
column.as_any().downcast_ref::<arrow::array::StringArray>()
{
for i in 0..id_array.len() {
node_ids.push(id_array.value(i).to_string());
}
}
}
}
}
Ok(node_ids)
}
/// Vector-searches `graphrag_nodes` with cosine distance and returns up to
/// `limit` rows as a single RecordBatch. Missing table or empty results
/// yield an empty batch.
///
/// # Errors
/// Returns an error when `embedding` has the wrong dimensionality, or on
/// query/optimization/concatenation failures.
pub async fn search_graph_nodes(&self, embedding: &[f32], limit: usize) -> Result<RecordBatch> {
    if embedding.len() != self.code_vector_dim {
        return Err(anyhow::anyhow!(
            "Embedding dimension {} doesn't match expected {}",
            embedding.len(),
            self.code_vector_dim
        ));
    }
    if !self.table_ops.table_exists("graphrag_nodes").await? {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("kind", DataType::Utf8, false),
            Field::new("path", DataType::Utf8, false),
            Field::new("description", DataType::Utf8, false),
            Field::new("symbols", DataType::Utf8, true),
            Field::new("imports", DataType::Utf8, true),
            Field::new("exports", DataType::Utf8, true),
            Field::new("functions", DataType::Utf8, true),
            Field::new("size_lines", DataType::UInt32, false),
            Field::new("language", DataType::Utf8, false),
            Field::new("hash", DataType::Utf8, false),
        ]));
        return Ok(RecordBatch::new_empty(schema));
    }
    let table = self.db.open_table("graphrag_nodes").execute().await?;
    let query = table
        .vector_search(embedding)?
        .distance_type(DistanceType::Cosine)
        .limit(limit);
    let optimized_query = crate::store::vector_optimizer::VectorOptimizer::optimize_query(
        query,
        &table,
        "graphrag_nodes",
    )
    .await
    .map_err(|e| anyhow::anyhow!("Failed to optimize query: {}", e))?;
    let mut results = optimized_query.execute().await?;
    let mut all_batches = Vec::new();
    while let Some(batch) = results.try_next().await? {
        if batch.num_rows() > 0 {
            all_batches.push(batch);
        }
    }
    if all_batches.is_empty() {
        // NOTE(review): this empty-result schema differs from the
        // missing-table schema above — confirm which one callers expect.
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("file_path", DataType::Utf8, false),
            Field::new("node_type", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, false),
            Field::new("description", DataType::Utf8, true),
        ]));
        Ok(RecordBatch::new_empty(schema))
    } else if all_batches.len() == 1 {
        Ok(all_batches.into_iter().next().unwrap())
    } else {
        // Use Arrow's concat kernel. The previous hand-rolled merge only
        // copied Utf8 columns, silently dropping every non-string column
        // and producing length-mismatched batches.
        let schema = all_batches[0].schema();
        Ok(arrow::compute::concat_batches(&schema, &all_batches)?)
    }
}
pub async fn get_graph_relationships(&self) -> Result<RecordBatch> {
if !self
.table_ops
.table_exists("graphrag_relationships")
.await?
{
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("source", DataType::Utf8, false),
Field::new("target", DataType::Utf8, false),
Field::new("relation_type", DataType::Utf8, false),
Field::new("description", DataType::Utf8, false),
Field::new("confidence", DataType::Float32, false),
Field::new("weight", DataType::Float32, false),
]));
return Ok(RecordBatch::new_empty(schema));
}
let table = self
.db
.open_table("graphrag_relationships")
.execute()
.await?;
let mut results = table.query().execute().await?;
let mut all_batches = Vec::new();
while let Some(batch) = results.try_next().await? {
if batch.num_rows() > 0 {
all_batches.push(batch);
}
}
if all_batches.is_empty() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("source", DataType::Utf8, false),
Field::new("target", DataType::Utf8, false),
Field::new("relation_type", DataType::Utf8, false),
Field::new("description", DataType::Utf8, false),
Field::new("confidence", DataType::Float32, false),
Field::new("weight", DataType::Float32, false),
]));
Ok(RecordBatch::new_empty(schema))
} else if all_batches.len() == 1 {
Ok(all_batches.into_iter().next().unwrap())
} else {
Ok(all_batches.into_iter().next().unwrap())
}
}
pub async fn get_node_relationships(
&self,
node_id: &str,
direction: crate::indexer::graphrag::types::RelationshipDirection,
) -> Result<Vec<crate::indexer::graphrag::types::CodeRelationship>> {
use crate::indexer::graphrag::types::RelationshipDirection;
if !self
.table_ops
.table_exists("graphrag_relationships")
.await?
{
return Ok(Vec::new());
}
let table = self
.db
.open_table("graphrag_relationships")
.execute()
.await?;
let filter = match direction {
RelationshipDirection::Outgoing => format!("source = '{}'", node_id),
RelationshipDirection::Incoming => format!("target = '{}'", node_id),
RelationshipDirection::Both => {
format!("source = '{}' OR target = '{}'", node_id, node_id)
}
};
let mut results = table.query().only_if(&filter).execute().await?;
let mut relationships = Vec::new();
while let Some(batch) = results.try_next().await? {
if batch.num_rows() == 0 {
continue;
}
let source_array = batch
.column_by_name("source")
.ok_or_else(|| anyhow::anyhow!("Missing source column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid source column type"))?;
let target_array = batch
.column_by_name("target")
.ok_or_else(|| anyhow::anyhow!("Missing target column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid target column type"))?;
let type_array = batch
.column_by_name("relation_type")
.ok_or_else(|| anyhow::anyhow!("Missing relation_type column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid relation_type column type"))?;
let desc_array = batch
.column_by_name("description")
.ok_or_else(|| anyhow::anyhow!("Missing description column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid description column type"))?;
let conf_array = batch
.column_by_name("confidence")
.ok_or_else(|| anyhow::anyhow!("Missing confidence column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid confidence column type"))?;
let weight_array = batch
.column_by_name("weight")
.ok_or_else(|| anyhow::anyhow!("Missing weight column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid weight column type"))?;
for i in 0..batch.num_rows() {
let relationship = crate::indexer::graphrag::types::CodeRelationship {
source: source_array.value(i).to_string(),
target: target_array.value(i).to_string(),
relation_type: type_array
.value(i)
.parse()
.unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: desc_array.value(i).to_string(),
confidence: conf_array.value(i),
weight: weight_array.value(i),
};
relationships.push(relationship);
}
}
Ok(relationships)
}
pub async fn get_relationships_by_type(
&self,
relation_type: &crate::indexer::graphrag::types::RelationType,
) -> Result<Vec<crate::indexer::graphrag::types::CodeRelationship>> {
if !self
.table_ops
.table_exists("graphrag_relationships")
.await?
{
return Ok(Vec::new());
}
let table = self
.db
.open_table("graphrag_relationships")
.execute()
.await?;
let filter = format!("relation_type = '{}'", relation_type.as_str());
let mut results = table.query().only_if(&filter).execute().await?;
let mut relationships = Vec::new();
while let Some(batch) = results.try_next().await? {
if batch.num_rows() == 0 {
continue;
}
let source_array = batch
.column_by_name("source")
.ok_or_else(|| anyhow::anyhow!("Missing source column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid source column type"))?;
let target_array = batch
.column_by_name("target")
.ok_or_else(|| anyhow::anyhow!("Missing target column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid target column type"))?;
let type_array = batch
.column_by_name("relation_type")
.ok_or_else(|| anyhow::anyhow!("Missing relation_type column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid relation_type column type"))?;
let desc_array = batch
.column_by_name("description")
.ok_or_else(|| anyhow::anyhow!("Missing description column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid description column type"))?;
let conf_array = batch
.column_by_name("confidence")
.ok_or_else(|| anyhow::anyhow!("Missing confidence column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid confidence column type"))?;
let weight_array = batch
.column_by_name("weight")
.ok_or_else(|| anyhow::anyhow!("Missing weight column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid weight column type"))?;
for i in 0..batch.num_rows() {
let relationship = crate::indexer::graphrag::types::CodeRelationship {
source: source_array.value(i).to_string(),
target: target_array.value(i).to_string(),
relation_type: type_array
.value(i)
.parse()
.unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: desc_array.value(i).to_string(),
confidence: conf_array.value(i),
weight: weight_array.value(i),
};
relationships.push(relationship);
}
}
Ok(relationships)
}
pub async fn get_all_nodes_paginated(
&self,
offset: usize,
limit: usize,
) -> Result<Vec<crate::indexer::graphrag::types::CodeNode>> {
if !self.table_ops.table_exists("graphrag_nodes").await? {
return Ok(Vec::new());
}
let table = self.db.open_table("graphrag_nodes").execute().await?;
let mut results = table.query().limit(limit).offset(offset).execute().await?;
let mut nodes = Vec::new();
while let Some(batch) = results.try_next().await? {
if batch.num_rows() == 0 {
continue;
}
let id_array = batch
.column_by_name("id")
.ok_or_else(|| anyhow::anyhow!("Missing id column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid id column type"))?;
let name_array = batch
.column_by_name("name")
.ok_or_else(|| anyhow::anyhow!("Missing name column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid name column type"))?;
let kind_array = batch
.column_by_name("kind")
.ok_or_else(|| anyhow::anyhow!("Missing kind column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid kind column type"))?;
let path_array = batch
.column_by_name("path")
.ok_or_else(|| anyhow::anyhow!("Missing path column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid path column type"))?;
let description_array = batch
.column_by_name("description")
.ok_or_else(|| anyhow::anyhow!("Missing description column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid description column type"))?;
let symbols_array = batch
.column_by_name("symbols")
.ok_or_else(|| anyhow::anyhow!("Missing symbols column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid symbols column type"))?;
let hash_array = batch
.column_by_name("hash")
.ok_or_else(|| anyhow::anyhow!("Missing hash column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid hash column type"))?;
let imports_array = batch
.column_by_name("imports")
.ok_or_else(|| anyhow::anyhow!("Missing imports column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid imports column type"))?;
let exports_array = batch
.column_by_name("exports")
.ok_or_else(|| anyhow::anyhow!("Missing exports column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid exports column type"))?;
let size_lines_array = batch
.column_by_name("size_lines")
.ok_or_else(|| anyhow::anyhow!("Missing size_lines column"))?
.as_any()
.downcast_ref::<arrow::array::UInt32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid size_lines column type"))?;
let language_array = batch
.column_by_name("language")
.ok_or_else(|| anyhow::anyhow!("Missing language column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid language column type"))?;
let embedding_array = batch
.column_by_name("embedding")
.ok_or_else(|| anyhow::anyhow!("Missing embedding column"))?
.as_any()
.downcast_ref::<arrow::array::FixedSizeListArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid embedding column type"))?;
let embedding_values = embedding_array
.values()
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid embedding values type"))?;
for i in 0..batch.num_rows() {
let symbols: Vec<String> = if symbols_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(symbols_array.value(i)).unwrap_or_default()
};
let imports: Vec<String> = if imports_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(imports_array.value(i)).unwrap_or_default()
};
let exports: Vec<String> = if exports_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(exports_array.value(i)).unwrap_or_default()
};
let embedding_start = i * self.code_vector_dim;
let embedding_end = embedding_start + self.code_vector_dim;
let embedding: Vec<f32> = (embedding_start..embedding_end)
.map(|idx| embedding_values.value(idx))
.collect();
let node = crate::indexer::graphrag::types::CodeNode {
id: id_array.value(i).to_string(),
name: name_array.value(i).to_string(),
kind: kind_array.value(i).to_string(),
path: path_array.value(i).to_string(),
description: description_array.value(i).to_string(),
symbols,
hash: hash_array.value(i).to_string(),
embedding,
imports,
exports,
functions: Vec::new(), size_lines: size_lines_array.value(i),
language: language_array.value(i).to_string(),
};
nodes.push(node);
}
}
Ok(nodes)
}
pub async fn get_all_relationships_efficient(
&self,
) -> Result<Vec<crate::indexer::graphrag::types::CodeRelationship>> {
if !self
.table_ops
.table_exists("graphrag_relationships")
.await?
{
return Ok(Vec::new());
}
let table = self
.db
.open_table("graphrag_relationships")
.execute()
.await?;
let mut results = table.query().execute().await?;
let mut relationships = Vec::new();
while let Some(batch) = results.try_next().await? {
if batch.num_rows() == 0 {
continue;
}
let source_array = batch
.column_by_name("source")
.ok_or_else(|| anyhow::anyhow!("Missing source column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid source column type"))?;
let target_array = batch
.column_by_name("target")
.ok_or_else(|| anyhow::anyhow!("Missing target column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid target column type"))?;
let type_array = batch
.column_by_name("relation_type")
.ok_or_else(|| anyhow::anyhow!("Missing relation_type column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid relation_type column type"))?;
let desc_array = batch
.column_by_name("description")
.ok_or_else(|| anyhow::anyhow!("Missing description column"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Invalid description column type"))?;
let conf_array = batch
.column_by_name("confidence")
.ok_or_else(|| anyhow::anyhow!("Missing confidence column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid confidence column type"))?;
let weight_array = batch
.column_by_name("weight")
.ok_or_else(|| anyhow::anyhow!("Missing weight column"))?
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.ok_or_else(|| anyhow::anyhow!("Invalid weight column type"))?;
for i in 0..batch.num_rows() {
let relationship = crate::indexer::graphrag::types::CodeRelationship {
source: source_array.value(i).to_string(),
target: target_array.value(i).to_string(),
relation_type: type_array
.value(i)
.parse()
.unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: desc_array.value(i).to_string(),
confidence: conf_array.value(i),
weight: weight_array.value(i),
};
relationships.push(relationship);
}
}
Ok(relationships)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::indexer::graphrag::types::RelationType;
#[test]
fn test_relationship_direction_filter_generation() {
    // The filter strings must match what get_node_relationships builds.
    let node_id = "src/main.rs";
    assert_eq!(format!("source = '{}'", node_id), "source = 'src/main.rs'");
    assert_eq!(format!("target = '{}'", node_id), "target = 'src/main.rs'");
    assert_eq!(
        format!("source = '{}' OR target = '{}'", node_id, node_id),
        "source = 'src/main.rs' OR target = 'src/main.rs'"
    );
}
#[test]
fn test_relationship_type_filter_generation() {
    // Each variant's as_str() drives the relation_type filter verbatim.
    let cases = [
        (RelationType::Implements, "relation_type = 'implements'"),
        (RelationType::Imports, "relation_type = 'imports'"),
    ];
    for (rel_type, expected) in cases {
        assert_eq!(format!("relation_type = '{}'", rel_type.as_str()), expected);
    }
}
#[test]
fn test_code_relationship_parsing() {
    // Known relation names parse to the matching variant.
    let rel_type_str = "implements";
    let parsed: RelationType = rel_type_str.parse().unwrap();
    assert_eq!(parsed, RelationType::Implements);
    // NOTE(review): per this assertion, RelationType's FromStr is
    // effectively infallible and maps unknown names to `Imports` — confirm,
    // since other code in this file treats the parse as fallible
    // (`if let Ok(..) = RelationType::from_str(..)`).
    let unknown_str = "unknown_type";
    let parsed: RelationType = unknown_str.parse().unwrap();
    assert_eq!(parsed, RelationType::Imports);
}
#[test]
fn test_pagination_parameters() {
    // Placeholder sanity checks on typical pagination values; the offsets
    // are intentionally unused (get_all_nodes_paginated accepts any offset).
    let _offset = 0;
    let limit = 100;
    assert!(limit > 0, "Limit should be positive");
    let _offset = 10000;
    let limit = 1000;
    assert!(limit <= 10000, "Limit should be reasonable");
}
#[test]
fn test_cache_stats_default() {
    // Fresh stats carry zero counters and report a 0.0 hit rate (no 0/0).
    let stats = CacheStats::default();
    assert_eq!((stats.hits, stats.misses), (0, 0));
    assert_eq!(stats.hit_rate(), 0.0);
}
#[test]
fn test_cache_stats_hit_rate() {
    // hit_rate = hits / (hits + misses), table-driven over edge cases.
    let cases = [(7usize, 3usize, 0.7f64), (0, 10, 0.0), (10, 0, 1.0)];
    for (hits, misses, expected) in cases {
        let stats = CacheStats { hits, misses };
        assert_eq!(stats.hit_rate(), expected);
    }
}
#[test]
fn test_cache_update_and_get() {
    use std::sync::{Arc, RwLock};
    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));
    let _stats: Arc<RwLock<CacheStats>> = Arc::new(RwLock::new(CacheStats::default()));
    // Seed one edge: node1 --Imports--> node2.
    cache
        .write()
        .unwrap()
        .entry("node1".to_string())
        .or_default()
        .entry(RelationType::Imports)
        .or_default()
        .push("node2".to_string());
    // Shared lookup mirroring get_adjacent_nodes_cached's read path.
    let lookup = |node: &str| {
        cache
            .read()
            .unwrap()
            .get(node)
            .and_then(|rels| rels.get(&RelationType::Imports))
            .cloned()
    };
    assert_eq!(lookup("node1"), Some(vec!["node2".to_string()]));
    assert_eq!(lookup("nonexistent"), None);
}
#[test]
fn test_cache_invalidation() {
    use std::sync::{Arc, RwLock};
    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));
    // Seed node1 with two different relation types.
    {
        let mut guard = cache.write().unwrap();
        let node = guard.entry("node1".to_string()).or_default();
        node.entry(RelationType::Imports)
            .or_default()
            .push("node2".to_string());
        node.entry(RelationType::Extends)
            .or_default()
            .push("node3".to_string());
    }
    assert!(cache.read().unwrap().contains_key("node1"));
    // Removing the key drops every relation type at once, which is what
    // invalidate_cache_for_node does.
    cache.write().unwrap().remove("node1");
    assert!(!cache.read().unwrap().contains_key("node1"));
}
#[test]
fn test_cache_clear() {
    use std::sync::{Arc, RwLock};

    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));
    let stats: Arc<RwLock<CacheStats>> = Arc::new(RwLock::new(CacheStats::default()));

    // Seed five nodes, each with one imports edge.
    {
        let mut guard = cache.write().unwrap();
        for i in 0..5 {
            guard
                .entry(format!("node{}", i))
                .or_default()
                .entry(RelationType::Imports)
                .or_default()
                .push(format!("target{}", i));
        }
    }
    // Record non-zero stats so the reset below is observable.
    {
        let mut guard = stats.write().unwrap();
        guard.hits = 10;
        guard.misses = 5;
    }

    assert_eq!(cache.read().unwrap().len(), 5);

    // Mirror clear_cache(): empty the map and reset the counters.
    cache.write().unwrap().clear();
    *stats.write().unwrap() = CacheStats::default();

    assert_eq!(cache.read().unwrap().len(), 0);
    let after = stats.read().unwrap();
    assert_eq!(after.hits, 0);
    assert_eq!(after.misses, 0);
}
#[test]
fn test_cache_multiple_relation_types() {
    use std::sync::{Arc, RwLock};

    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));

    // One node can carry several relation types simultaneously.
    let seeds = [
        (RelationType::Imports, "dep1"),
        (RelationType::Extends, "base1"),
        (RelationType::Implements, "interface1"),
    ];
    {
        let mut guard = cache.write().unwrap();
        let rels = guard.entry("node1".to_string()).or_default();
        for (rel, target) in seeds {
            rels.entry(rel).or_default().push(target.to_string());
        }
    }

    let guard = cache.read().unwrap();
    let rels = guard.get("node1").unwrap();
    assert_eq!(rels.len(), 3);
    for rel in [
        RelationType::Imports,
        RelationType::Extends,
        RelationType::Implements,
    ] {
        assert!(rels.contains_key(&rel));
    }
}
#[test]
fn test_cache_multiple_targets_per_relation() {
    use std::sync::{Arc, RwLock};

    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));
    let deps = ["dep1", "dep2", "dep3"];

    // A single relation type can fan out to many targets; extend()
    // appends all of them under one entry lookup.
    cache
        .write()
        .unwrap()
        .entry("node1".to_string())
        .or_default()
        .entry(RelationType::Imports)
        .or_default()
        .extend(deps.iter().map(|d| d.to_string()));

    let guard = cache.read().unwrap();
    let targets = guard
        .get("node1")
        .and_then(|rels| rels.get(&RelationType::Imports))
        .unwrap();
    assert_eq!(targets.len(), 3);
    for dep in deps {
        assert!(targets.contains(&dep.to_string()));
    }
}
#[test]
fn test_cache_concurrent_access() {
    use std::sync::{Arc, RwLock};
    use std::thread;

    let cache: Arc<RwLock<AdjacencyCache>> = Arc::new(RwLock::new(HashMap::new()));

    // Ten writer threads each insert a distinct node under the write lock.
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let shared = Arc::clone(&cache);
            thread::spawn(move || {
                shared
                    .write()
                    .unwrap()
                    .entry(format!("node{}", i))
                    .or_default()
                    .entry(RelationType::Imports)
                    .or_default()
                    .push(format!("target{}", i));
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    // All ten distinct nodes must survive the concurrent inserts.
    assert_eq!(cache.read().unwrap().len(), 10);
}
#[test]
fn test_relation_type_parsing_for_cache() {
    // Every relation string must round-trip: parse() yields the expected
    // variant, and as_str() maps it back to the original input.
    for (text, expected) in [
        ("implements", RelationType::Implements),
        ("extends", RelationType::Extends),
        ("imports", RelationType::Imports),
        ("calls", RelationType::Calls),
        ("uses", RelationType::Uses),
    ] {
        let parsed = text.parse::<RelationType>().unwrap();
        assert_eq!(parsed, expected);
        assert_eq!(parsed.as_str(), text);
    }
}
}