use crate::indexer::graphrag::types::{CodeGraph, CodeNode, CodeRelationship};
use crate::indexer::graphrag::utils::cosine_similarity;
use crate::store::Store;
use anyhow::Result;
use arrow::array::Array;
use arrow::datatypes::{DataType, Field, Schema};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
pub struct DatabaseOperations<'a> {
store: &'a Store,
}
impl<'a> DatabaseOperations<'a> {
pub fn new(store: &'a Store) -> Self {
Self { store }
}
pub async fn load_graph(&self, _project_root: &Path, quiet: bool) -> Result<CodeGraph> {
let mut graph = CodeGraph::default();
if !self
.store
.tables_exist(&["graphrag_nodes", "graphrag_relationships"])
.await?
{
return Ok(graph); }
let vector_dim = self.store.get_code_vector_dim();
let node_batch = self
.store
.search_graph_nodes(&vec![0.0; vector_dim], 10000)
.await?;
if node_batch.num_rows() == 0 {
return Ok(graph); }
if !quiet {
println!(
"Loading {} GraphRAG nodes from database...",
node_batch.num_rows()
);
}
let id_array = node_batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let name_array = node_batch
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let kind_array = node_batch
.column_by_name("kind")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let path_array = node_batch
.column_by_name("path")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let description_array = node_batch
.column_by_name("description")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let symbols_array = node_batch
.column_by_name("symbols")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let imports_array = node_batch
.column_by_name("imports")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let exports_array = node_batch
.column_by_name("exports")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let hash_array = node_batch
.column_by_name("hash")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let embedding_array = node_batch
.column_by_name("embedding")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::FixedSizeListArray>()
.unwrap();
let embedding_values = embedding_array
.values()
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.unwrap();
for i in 0..node_batch.num_rows() {
let id = id_array.value(i).to_string();
let name = name_array.value(i).to_string();
let kind = kind_array.value(i).to_string();
let path = path_array.value(i).to_string();
let description = description_array.value(i).to_string();
let symbols: Vec<String> = if symbols_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(symbols_array.value(i)).unwrap_or_default()
};
let imports: Vec<String> = if imports_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(imports_array.value(i)).unwrap_or_default()
};
let exports: Vec<String> = if exports_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(exports_array.value(i)).unwrap_or_default()
};
let hash = hash_array.value(i).to_string();
let embedding_offset = i * embedding_array.value_length() as usize;
let embedding_len = embedding_array.value_length() as usize;
let mut embedding = Vec::with_capacity(embedding_len);
for j in 0..embedding_len {
embedding.push(embedding_values.value(embedding_offset + j));
}
let node = CodeNode {
id,
name,
kind,
path,
description,
symbols,
imports,
exports,
functions: Vec::new(), size_lines: 0, language: "unknown".to_string(), hash,
embedding,
};
graph.nodes.insert(node.id.clone(), node);
}
let rel_batch = self.store.get_graph_relationships().await?;
if rel_batch.num_rows() > 0 {
if !quiet {
println!(
"Loading {} GraphRAG relationships from database...",
rel_batch.num_rows()
);
}
let source_array = rel_batch
.column_by_name("source")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let target_array = rel_batch
.column_by_name("target")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let type_array = rel_batch
.column_by_name("relation_type")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let desc_array = rel_batch
.column_by_name("description")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let conf_array = rel_batch
.column_by_name("confidence")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.unwrap();
for i in 0..rel_batch.num_rows() {
let relationship = CodeRelationship {
source: source_array.value(i).to_string(),
target: target_array.value(i).to_string(),
relation_type: type_array
.value(i)
.parse()
.unwrap_or(crate::indexer::graphrag::types::RelationType::Imports),
description: desc_array.value(i).to_string(),
confidence: conf_array.value(i),
weight: 1.0, };
graph.relationships.push(relationship);
}
}
if !graph.nodes.is_empty() && !quiet {
println!(
"Loaded GraphRAG knowledge graph with {} nodes and {} relationships",
graph.nodes.len(),
graph.relationships.len()
);
}
Ok(graph)
}
pub async fn save_graph_incremental(
&self,
new_nodes: &[CodeNode],
new_relationships: &[CodeRelationship],
) -> Result<()> {
if new_nodes.is_empty() && new_relationships.is_empty() {
return Ok(());
}
if !new_nodes.is_empty() {
let nodes_map: HashMap<String, CodeNode> = new_nodes
.iter()
.map(|node| (node.id.clone(), node.clone()))
.collect();
let nodes_batch = self.nodes_to_batch(&nodes_map).await?;
self.store.store_graph_nodes(nodes_batch).await?;
}
if !new_relationships.is_empty() {
let rel_batch = self.relationships_to_batch(new_relationships).await?;
self.store.store_graph_relationships(rel_batch).await?;
}
Ok(())
}
pub async fn search_nodes_in_database(
&self,
query_embedding: &[f32],
query: &str,
) -> Result<Vec<CodeNode>> {
if !self.store.tables_exist(&["graphrag_nodes"]).await? {
return Ok(Vec::new()); }
let node_batch = self.store.search_graph_nodes(query_embedding, 50).await?;
if node_batch.num_rows() == 0 {
return Ok(Vec::new());
}
let mut nodes = Vec::new();
let query_lower = query.to_lowercase();
let id_array = node_batch
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let name_array = node_batch
.column_by_name("name")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let kind_array = node_batch
.column_by_name("kind")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let path_array = node_batch
.column_by_name("path")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let description_array = node_batch
.column_by_name("description")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let symbols_array = node_batch
.column_by_name("symbols")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let hash_array = node_batch
.column_by_name("hash")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.unwrap();
let embedding_array = node_batch
.column_by_name("embedding")
.unwrap()
.as_any()
.downcast_ref::<arrow::array::FixedSizeListArray>()
.unwrap();
let embedding_values = embedding_array
.values()
.as_any()
.downcast_ref::<arrow::array::Float32Array>()
.unwrap();
for i in 0..node_batch.num_rows() {
let id = id_array.value(i).to_string();
let name = name_array.value(i).to_string();
let kind = kind_array.value(i).to_string();
let path = path_array.value(i).to_string();
let description = description_array.value(i).to_string();
let symbols: Vec<String> = if symbols_array.is_null(i) {
Vec::new()
} else {
serde_json::from_str(symbols_array.value(i)).unwrap_or_default()
};
let hash = hash_array.value(i).to_string();
let embedding_offset = i * embedding_array.value_length() as usize;
let embedding_len = embedding_array.value_length() as usize;
let mut embedding = Vec::with_capacity(embedding_len);
for j in 0..embedding_len {
embedding.push(embedding_values.value(embedding_offset + j));
}
let similarity = cosine_similarity(query_embedding, &embedding);
let name_contains = name.to_lowercase().contains(&query_lower);
let kind_contains = kind.to_lowercase().contains(&query_lower);
let desc_contains = description.to_lowercase().contains(&query_lower);
let symbols_contain = symbols
.iter()
.any(|s| s.to_lowercase().contains(&query_lower));
if similarity > 0.5
|| name_contains
|| kind_contains
|| desc_contains
|| symbols_contain
{
let node = CodeNode {
id,
name,
kind,
path,
description,
symbols,
imports: Vec::new(), exports: Vec::new(), functions: Vec::new(), size_lines: 0, language: "unknown".to_string(), hash,
embedding,
};
nodes.push(node);
}
}
nodes.sort_by(|a, b| {
let a_contains = a.name.to_lowercase().contains(&query_lower)
|| a.kind.to_lowercase().contains(&query_lower)
|| a.symbols
.iter()
.any(|s| s.to_lowercase().contains(&query_lower));
let b_contains = b.name.to_lowercase().contains(&query_lower)
|| b.kind.to_lowercase().contains(&query_lower)
|| b.symbols
.iter()
.any(|s| s.to_lowercase().contains(&query_lower));
if a_contains && !b_contains {
std::cmp::Ordering::Less
} else if !a_contains && b_contains {
std::cmp::Ordering::Greater
} else {
let a_sim = cosine_similarity(query_embedding, &a.embedding);
let b_sim = cosine_similarity(query_embedding, &b.embedding);
b_sim
.partial_cmp(&a_sim)
.unwrap_or(std::cmp::Ordering::Equal)
}
});
Ok(nodes)
}
async fn nodes_to_batch(
&self,
nodes: &HashMap<String, CodeNode>,
) -> Result<arrow::record_batch::RecordBatch> {
let vector_dim = self.store.get_code_vector_dim();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("name", DataType::Utf8, false),
Field::new("kind", DataType::Utf8, false),
Field::new("path", DataType::Utf8, false),
Field::new("description", DataType::Utf8, false),
Field::new("symbols", DataType::Utf8, true), Field::new("imports", DataType::Utf8, true), Field::new("exports", DataType::Utf8, true), Field::new("functions", DataType::Utf8, true), Field::new("size_lines", DataType::UInt32, false),
Field::new("language", DataType::Utf8, false),
Field::new("hash", DataType::Utf8, false),
Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
vector_dim as i32,
),
true,
),
]));
let nodes_vec: Vec<&CodeNode> = nodes.values().collect();
if nodes_vec.is_empty() {
return Err(anyhow::anyhow!("No nodes to convert to batch"));
}
let ids: Vec<&str> = nodes_vec.iter().map(|n| n.id.as_str()).collect();
let names: Vec<&str> = nodes_vec.iter().map(|n| n.name.as_str()).collect();
let kinds: Vec<&str> = nodes_vec.iter().map(|n| n.kind.as_str()).collect();
let paths: Vec<&str> = nodes_vec.iter().map(|n| n.path.as_str()).collect();
let descriptions: Vec<&str> = nodes_vec.iter().map(|n| n.description.as_str()).collect();
let symbols: Vec<String> = nodes_vec
.iter()
.map(|n| serde_json::to_string(&n.symbols).unwrap_or_default())
.collect();
let imports: Vec<String> = nodes_vec
.iter()
.map(|n| serde_json::to_string(&n.imports).unwrap_or_default())
.collect();
let exports: Vec<String> = nodes_vec
.iter()
.map(|n| serde_json::to_string(&n.exports).unwrap_or_default())
.collect();
let functions: Vec<String> = nodes_vec
.iter()
.map(|n| serde_json::to_string(&n.functions).unwrap_or_default())
.collect();
let size_lines: Vec<u32> = nodes_vec.iter().map(|n| n.size_lines).collect();
let languages: Vec<&str> = nodes_vec.iter().map(|n| n.language.as_str()).collect();
let hashes: Vec<&str> = nodes_vec.iter().map(|n| n.hash.as_str()).collect();
let mut flattened_embeddings = Vec::with_capacity(nodes_vec.len() * vector_dim);
for node in &nodes_vec {
if node.embedding.len() != vector_dim {
return Err(anyhow::anyhow!(
"Node embedding has dimension {} but expected {}",
node.embedding.len(),
vector_dim
));
}
flattened_embeddings.extend_from_slice(&node.embedding);
}
let values = arrow::array::Float32Array::from(flattened_embeddings);
let embedding_array = arrow::array::FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Float32, true)),
vector_dim as i32,
Arc::new(values),
None, );
let batch = arrow::record_batch::RecordBatch::try_new(
schema,
vec![
Arc::new(arrow::array::StringArray::from(ids)),
Arc::new(arrow::array::StringArray::from(names)),
Arc::new(arrow::array::StringArray::from(kinds)),
Arc::new(arrow::array::StringArray::from(paths)),
Arc::new(arrow::array::StringArray::from(descriptions)),
Arc::new(arrow::array::StringArray::from(symbols)),
Arc::new(arrow::array::StringArray::from(imports)),
Arc::new(arrow::array::StringArray::from(exports)),
Arc::new(arrow::array::StringArray::from(functions)),
Arc::new(arrow::array::UInt32Array::from(size_lines)),
Arc::new(arrow::array::StringArray::from(languages)),
Arc::new(arrow::array::StringArray::from(hashes)),
Arc::new(embedding_array),
],
)?;
Ok(batch)
}
async fn relationships_to_batch(
&self,
relationships: &[CodeRelationship],
) -> Result<arrow::record_batch::RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("source", DataType::Utf8, false),
Field::new("target", DataType::Utf8, false),
Field::new("relation_type", DataType::Utf8, false),
Field::new("description", DataType::Utf8, false),
Field::new("confidence", DataType::Float32, false),
Field::new("weight", DataType::Float32, false),
]));
let ids: Vec<String> = relationships
.iter()
.map(|_| uuid::Uuid::new_v4().to_string())
.collect();
let sources: Vec<&str> = relationships.iter().map(|r| r.source.as_str()).collect();
let targets: Vec<&str> = relationships.iter().map(|r| r.target.as_str()).collect();
let types: Vec<&str> = relationships
.iter()
.map(|r| r.relation_type.as_str())
.collect();
let descriptions: Vec<&str> = relationships
.iter()
.map(|r| r.description.as_str())
.collect();
let confidences: Vec<f32> = relationships.iter().map(|r| r.confidence).collect();
let weights: Vec<f32> = relationships.iter().map(|r| r.weight).collect();
let batch = arrow::record_batch::RecordBatch::try_new(
schema,
vec![
Arc::new(arrow::array::StringArray::from(ids)),
Arc::new(arrow::array::StringArray::from(sources)),
Arc::new(arrow::array::StringArray::from(targets)),
Arc::new(arrow::array::StringArray::from(types)),
Arc::new(arrow::array::StringArray::from(descriptions)),
Arc::new(arrow::array::Float32Array::from(confidences)),
Arc::new(arrow::array::Float32Array::from(weights)),
],
)?;
Ok(batch)
}
}