use crate::persistence::error::{PersistenceError, PersistenceResult};
use crate::RetrieveError;
use std::fs::File;
use std::io::{BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
/// Eight-byte signature stored at the very start of a graph file: ASCII `DANN`
/// followed by the bytes `00 00 00 01`.
///
/// `DiskGraphWriter::new` writes it first and `DiskGraphReader::open` rejects
/// any file whose first eight bytes differ.
// NOTE(review): the trailing `00 00 00 01` looks like a format version number — confirm.
const GRAPH_MAGIC: &[u8; 8] = b"DANN\x00\x00\x00\x01";
/// Sequential writer for the on-disk graph file: a header followed by one
/// fixed-size adjacency record per `write_adjacency` call.
pub struct DiskGraphWriter {
/// Buffered handle to the output file; callers should `flush` when done.
writer: BufWriter<File>,
/// Node count recorded in the header. It is stored here but not consulted
/// again by the writer's methods.
num_nodes: usize,
/// Upper bound on neighbors per node; fixes the size of every record.
max_degree: usize,
/// Start node id recorded in the header. It is stored here but not consulted
/// again by the writer's methods.
start_node: u32,
}
impl DiskGraphWriter {
    /// Creates (or truncates) the file at `path` and writes the graph header.
    ///
    /// Header layout (64 bytes, integers little-endian):
    /// 8-byte magic, `num_nodes` as `u64`, `max_degree` as `u64`,
    /// `start_node` as `u64`, then 32 zero bytes.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created or the header cannot be
    /// written.
    pub fn new(
        path: &Path,
        num_nodes: usize,
        max_degree: usize,
        start_node: u32,
    ) -> PersistenceResult<Self> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);

        writer.write_all(GRAPH_MAGIC)?;
        writer.write_all(&(num_nodes as u64).to_le_bytes())?;
        writer.write_all(&(max_degree as u64).to_le_bytes())?;
        writer.write_all(&u64::from(start_node).to_le_bytes())?;
        // Zero-filled tail that pads the header out to 64 bytes.
        writer.write_all(&[0u8; 32])?;

        Ok(Self {
            writer,
            num_nodes,
            max_degree,
            start_node,
        })
    }

    /// Appends one fixed-size adjacency record.
    ///
    /// Record layout: the neighbor count as a little-endian `u32`, then
    /// `max_degree` little-endian `u32` slots. The first `neighbors.len()`
    /// slots hold the neighbor ids; the remaining slots are zero.
    ///
    /// # Errors
    /// Returns `PersistenceError::Serialization` if `neighbors` has more than
    /// `max_degree` entries or its length does not fit in a `u32`, and an I/O
    /// error if the write fails.
    pub fn write_adjacency(&mut self, neighbors: &[u32]) -> PersistenceResult<()> {
        if neighbors.len() > self.max_degree {
            return Err(PersistenceError::Serialization(format!(
                "Node degree {} exceeds max_degree {}",
                neighbors.len(),
                self.max_degree
            )));
        }
        let degree = u32::try_from(neighbors.len()).map_err(|_| {
            PersistenceError::Serialization(format!(
                "Node degree {} does not fit in the 32-bit degree field",
                neighbors.len()
            ))
        })?;

        // Assemble the whole record in memory and hand it to the writer once,
        // instead of issuing one write call per neighbor and per padding byte.
        // The buffer starts zeroed, so unused neighbor slots are already padded.
        let mut record = vec![0u8; 4 + self.max_degree * 4];
        record[..4].copy_from_slice(&degree.to_le_bytes());
        for (slot, neighbor) in record[4..].chunks_exact_mut(4).zip(neighbors) {
            slot.copy_from_slice(&neighbor.to_le_bytes());
        }
        self.writer.write_all(&record)?;
        Ok(())
    }

    /// Flushes buffered bytes to the underlying file.
    ///
    /// # Errors
    /// Returns an error if the underlying flush fails.
    pub fn flush(&mut self) -> PersistenceResult<()> {
        self.writer.flush()?;
        Ok(())
    }
}
/// Random-access reader for a graph file produced by `DiskGraphWriter`.
pub struct DiskGraphReader {
/// Open handle to the graph file; repositioned on every `get_neighbors` call.
file: File,
/// Node count read from the file header.
pub num_nodes: usize,
/// Maximum neighbors per node, read from the file header.
pub max_degree: usize,
/// Start node id read from the file header.
pub start_node: u32,
/// Byte offset of the first adjacency record (the header length).
header_size: u64,
/// Size in bytes of one adjacency record: a 4-byte degree plus
/// `max_degree` 4-byte neighbor slots.
record_size: u64,
}
impl DiskGraphReader {
pub fn open(path: &Path) -> PersistenceResult<Self> {
let mut file = File::open(path)?;
let mut magic = [0u8; 8];
file.read_exact(&mut magic)?;
if &magic != GRAPH_MAGIC {
return Err(PersistenceError::Format(
"Invalid DiskANN graph file".to_string(),
));
}
let mut buf_u64 = [0u8; 8];
file.read_exact(&mut buf_u64)?;
let num_nodes = u64::from_le_bytes(buf_u64) as usize;
file.read_exact(&mut buf_u64)?;
let max_degree = u64::from_le_bytes(buf_u64) as usize;
file.read_exact(&mut buf_u64)?;
let start_node = u64::from_le_bytes(buf_u64) as u32;
file.seek(SeekFrom::Current(32))?;
let header_size = 8 + 8 + 8 + 8 + 32;
let record_size = 4 + (max_degree as u64 * 4);
Ok(Self {
file,
num_nodes,
max_degree,
start_node,
header_size,
record_size,
})
}
pub fn get_neighbors(&mut self, node_id: u32) -> Result<Vec<u32>, RetrieveError> {
if node_id as usize >= self.num_nodes {
return Err(RetrieveError::OutOfBounds(node_id as usize));
}
let offset = self.header_size + (node_id as u64 * self.record_size);
self.file
.seek(SeekFrom::Start(offset))
.map_err(|e| RetrieveError::Io(e.to_string()))?;
let mut degree_buf = [0u8; 4];
self.file
.read_exact(&mut degree_buf)
.map_err(|e| RetrieveError::Io(e.to_string()))?;
let degree = u32::from_le_bytes(degree_buf) as usize;
if degree > self.max_degree {
return Err(RetrieveError::Other(
"Invalid node degree in graph file".to_string(),
));
}
let mut neighbors = Vec::with_capacity(degree);
let mut neighbor_buf = [0u8; 4];
for _ in 0..degree {
self.file
.read_exact(&mut neighbor_buf)
.map_err(|e| RetrieveError::Io(e.to_string()))?;
neighbors.push(u32::from_le_bytes(neighbor_buf));
}
Ok(neighbors)
}
}