use std::collections::{HashMap, HashSet, VecDeque};
use crate::sharding::reader::ShardReader;
use crate::sharding::shard::{CsrEdge, CsrShard};
/// Nodes and edges collected while expanding a set of seed tokens.
///
/// `nodes` is a set of every node id that was seeded or appeared as an edge endpoint,
/// `edges` keeps edges in the order they were added, and `max_depth` is a plain field
/// callers may record the expansion limit in.
#[derive(Debug, Clone, Default)]
pub struct Subgraph {
    /// Distinct node ids; re-adding an id has no effect.
    pub nodes: HashSet<u32>,
    /// Edges in insertion order; duplicates are kept exactly as added.
    pub edges: Vec<CsrEdge>,
    /// Expansion limit recorded by whoever built the subgraph (`0` when fresh).
    pub max_depth: usize,
}

impl Subgraph {
    /// Returns an empty subgraph: no nodes, no edges, `max_depth == 0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `node_id` as part of the subgraph.
    pub fn add_node(&mut self, node_id: u32) {
        self.nodes.insert(node_id);
    }

    /// Appends `edge` and records both of its endpoints as nodes.
    pub fn add_edge(&mut self, edge: CsrEdge) {
        // Copy the endpoints out before `edge` is moved into the edge list.
        self.nodes.extend([edge.src, edge.dst]);
        self.edges.push(edge);
    }

    /// Number of distinct node ids currently in the subgraph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges added so far, duplicates included.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }
}
/// Expands prompt tokens into a [`Subgraph`] by following edges stored in CSR shards.
///
/// The builder owns a copy of every shard it could obtain plus a lookup table from
/// source node id to the shard (by position in `shards`) whose source range covers it.
pub struct SubgraphBuilder {
    /// Shards copied from the reader, in reader order, skipping ids it could not provide.
    shards: Vec<CsrShard>,
    /// Source node id -> index into `shards` of the shard whose range contains it.
    shard_index: HashMap<u32, usize>,
}

impl SubgraphBuilder {
    /// Copies every shard available from `reader` and indexes each source id it covers.
    ///
    /// Shard ids for which `reader.get_shard` returns `None` are skipped. Index entries
    /// always refer to a shard's position in the compacted `shards` vector, so a missing
    /// shard does not shift lookups for the shards after it. If two shards claim the same
    /// source id, the one encountered later wins.
    pub fn new(reader: &ShardReader) -> Self {
        let shard_count = reader.shard_count();
        let mut shards = Vec::with_capacity(shard_count);
        let mut shard_index = HashMap::new();
        for shard in (0..shard_count).filter_map(|shard_id| reader.get_shard(shard_id)) {
            // Use the position in `shards`, not the reader's shard id: the two diverge as
            // soon as one id is skipped, and `get_outgoing_edges` indexes `shards`.
            let position = shards.len();
            for source_id in shard.source_start..shard.source_end {
                shard_index.insert(source_id, position);
            }
            shards.push(shard.clone());
        }
        Self {
            shards,
            shard_index,
        }
    }

    /// Breadth-first expansion from `prompt_tokens`, following outgoing edges for up to
    /// `max_depth` hops.
    ///
    /// Every distinct prompt token becomes a depth-0 node, even if no shard covers it.
    /// A node at depth `d < max_depth` has all of its outgoing edges added, and each
    /// destination not seen before is queued at depth `d + 1`. Each node is expanded at
    /// most once, so its edge run is added at most once.
    ///
    /// With `max_depth == 0` the result holds only the prompt tokens and no edges.
    /// The returned subgraph records `max_depth`.
    pub fn build_subgraph(&self, prompt_tokens: &[u32], max_depth: usize) -> Subgraph {
        let mut subgraph = Subgraph::new();
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        for &token_id in prompt_tokens {
            // `insert` returns false for tokens already seen, which dedups the prompt.
            if visited.insert(token_id) {
                subgraph.add_node(token_id);
                queue.push_back((token_id, 0));
            }
        }
        while let Some((node_id, depth)) = queue.pop_front() {
            if depth >= max_depth {
                continue;
            }
            if let Some(edges) = self.get_outgoing_edges(node_id) {
                for edge in edges {
                    let dst = edge.dst;
                    subgraph.add_edge(edge.clone());
                    if visited.insert(dst) {
                        queue.push_back((dst, depth + 1));
                    }
                }
            }
        }
        subgraph.max_depth = max_depth;
        subgraph
    }

    /// Borrows the contiguous run of edges whose `src` equals `source_id` from the owning
    /// shard. Returns `None` when no indexed shard covers `source_id`.
    ///
    /// The two binary searches assume the shard's `edges` are sorted by `src`. The test
    /// fixtures call `sort_edges` before use — TODO confirm that producers of real shards
    /// guarantee this ordering.
    fn get_outgoing_edges(&self, source_id: u32) -> Option<&[CsrEdge]> {
        let shard = self.shards.get(*self.shard_index.get(&source_id)?)?;
        let start = shard.edges.partition_point(|edge| edge.src < source_id);
        let len = shard.edges[start..].partition_point(|edge| edge.src == source_id);
        Some(&shard.edges[start..start + len])
    }

    /// Summarises each held shard as `(shard_id, edge_count, source_range_width)`, where
    /// the width is `source_end - source_start`.
    pub fn shard_stats(&self) -> Vec<(usize, usize, usize)> {
        self.shards
            .iter()
            .map(|shard| {
                (
                    shard.shard_id,
                    shard.edge_count(),
                    (shard.source_end - shard.source_start) as usize,
                )
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three shards covering source ids `0..1000`, `1000..2000` and `2000..3000`.
    ///
    /// Only sources 100, 200 (shard 0), 1500 (shard 1) and 2500 (shard 2) have outgoing
    /// edges. None of their destinations has outgoing edges of its own, so every
    /// expansion stops after one hop.
    fn create_test_shards() -> Vec<CsrShard> {
        let mut shard0 = CsrShard::new(0, 0, 1000);
        shard0.add_edge(CsrEdge {
            src: 100,
            dst: 2000,
            weight: 0.5,
            flags: 0,
        });
        shard0.add_edge(CsrEdge {
            src: 100,
            dst: 3000,
            weight: 0.3,
            flags: 0,
        });
        shard0.add_edge(CsrEdge {
            src: 200,
            dst: 4000,
            weight: 0.7,
            flags: 0,
        });
        // `get_outgoing_edges` binary-searches each shard's edges by `src`, so sort them
        // before handing the shards to a builder.
        shard0.sort_edges();

        let mut shard1 = CsrShard::new(1, 1000, 2000);
        shard1.add_edge(CsrEdge {
            src: 1500,
            dst: 5000,
            weight: 0.4,
            flags: 0,
        });
        shard1.add_edge(CsrEdge {
            src: 1500,
            dst: 6000,
            weight: 0.6,
            flags: 0,
        });
        shard1.sort_edges();

        let mut shard2 = CsrShard::new(2, 2000, 3000);
        shard2.add_edge(CsrEdge {
            src: 2500,
            dst: 7000,
            weight: 0.8,
            flags: 0,
        });
        shard2.add_edge(CsrEdge {
            src: 2500,
            dst: 8000,
            weight: 0.2,
            flags: 0,
        });
        shard2.sort_edges();

        vec![shard0, shard1, shard2]
    }

    /// Builds a `SubgraphBuilder` straight from in-memory shards, with no `ShardReader`
    /// involved. Every source id is mapped to the position of its shard in `shards`,
    /// which is how the builder looks shards up.
    fn builder_from(shards: Vec<CsrShard>) -> SubgraphBuilder {
        let mut shard_index = HashMap::new();
        for (position, shard) in shards.iter().enumerate() {
            for source_id in shard.source_start..shard.source_end {
                shard_index.insert(source_id, position);
            }
        }
        SubgraphBuilder {
            shards,
            shard_index,
        }
    }

    #[test]
    fn test_subgraph_builder_creation() {
        let shards = create_test_shards();
        assert_eq!(shards.len(), 3);
        assert_eq!(shards[0].source_start, 0);
        assert_eq!(shards[0].source_end, 1000);

        let builder = builder_from(shards);
        assert_eq!(
            builder.shard_stats(),
            vec![(0, 3, 1000), (1, 2, 1000), (2, 2, 1000)]
        );
    }

    #[test]
    fn test_subgraph_empty_tokens() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[], 2);
        assert_eq!(subgraph.node_count(), 0);
        assert_eq!(subgraph.edge_count(), 0);
    }

    #[test]
    fn test_subgraph_depth_zero_keeps_only_prompt() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[100], 0);
        assert_eq!(subgraph.node_count(), 1);
        assert_eq!(subgraph.edge_count(), 0);
        assert_eq!(subgraph.max_depth, 0);
    }

    #[test]
    fn test_subgraph_single_token_depth_1() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[100], 1);
        assert_eq!(subgraph.node_count(), 3);
        assert_eq!(subgraph.edge_count(), 2);
        assert_eq!(subgraph.max_depth, 1);
    }

    #[test]
    fn test_subgraph_single_token_depth_2() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[100], 2);
        // Destinations 2000 and 3000 have no outgoing edges, so depth 2 adds nothing.
        assert_eq!(subgraph.node_count(), 3);
        assert_eq!(subgraph.edge_count(), 2);
        assert_eq!(subgraph.max_depth, 2);
    }

    #[test]
    fn test_subgraph_duplicate_prompt_tokens() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[100, 100], 1);
        assert_eq!(subgraph.node_count(), 3);
        assert_eq!(subgraph.edge_count(), 2);
    }

    #[test]
    fn test_subgraph_unknown_token() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[5000], 3);
        assert_eq!(subgraph.node_count(), 1);
        assert_eq!(subgraph.edge_count(), 0);
    }

    #[test]
    fn test_subgraph_multi_token() {
        let builder = builder_from(create_test_shards());
        let subgraph = builder.build_subgraph(&[100, 1500], 1);
        assert_eq!(subgraph.node_count(), 6);
        assert_eq!(subgraph.edge_count(), 4);
    }
}