use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write};
use super::{StorageReadProvider, StorageWriteProvider};
use byteorder::{LittleEndian, ReadBytesExt};
use diskann::{
ANNError, ANNResult,
utils::{IntoUsize, VectorRepr},
};
use diskann_utils::io::Metadata;
use crate::utils::load_metadata_from_file;
/// Write access to a dense vector dataset: one fixed-dimension vector per slot.
pub(crate) trait SetData {
    /// Scalar type of the vector components.
    type Item;
    /// Stores `element` as the vector at position `i`.
    fn set_data(&mut self, i: usize, element: &[Self::Item]) -> ANNResult<()>;
}
/// Read access to a dense vector dataset.
pub(crate) trait GetData {
    /// Scalar type of the vector components.
    type Element;
    /// Borrowed view of a single vector; a GAT so implementors may return
    /// either plain slices or guard types that deref to a slice.
    type Item<'a>: std::ops::Deref<Target = [Self::Element]>
    where
        Self: 'a;
    /// Returns the vector stored at position `i`.
    fn get_data(&self, i: usize) -> ANNResult<Self::Item<'_>>;
    /// Number of vectors in the dataset.
    fn total(&self) -> usize;
    /// Dimensionality of each vector.
    fn dim(&self) -> usize;
}
/// Write access to a graph: one out-neighbor list per node.
pub(crate) trait SetAdjacencyList {
    /// Node-id type stored in the adjacency lists.
    type Item;
    /// Stores `element` as the out-neighbor list of node `i`.
    fn set_adjacency_list(&mut self, i: usize, element: &[Self::Item]) -> ANNResult<()>;
}
/// Read access to a graph's out-neighbor lists.
pub(crate) trait GetAdjacencyList {
    /// Node-id type stored in the adjacency lists.
    type Element;
    /// Borrowed view of a single adjacency list; a GAT so implementors may
    /// return either plain slices or guard types that deref to a slice.
    type Item<'a>: std::ops::Deref<Target = [Self::Element]>
    where
        Self: 'a;
    /// Returns the out-neighbor list of node `i`.
    fn get_adjacency_list(&self, i: usize) -> ANNResult<Self::Item<'_>>;
    /// Number of nodes in the graph.
    fn total(&self) -> usize;
    /// Number of extra points beyond the user data; written into the graph
    /// file header by `save_graph`. (Presumably start/frozen points — confirm
    /// with implementors.)
    fn additional_points(&self) -> u64;
    /// Advertised maximum out-degree, if known; `save_graph` falls back to
    /// the observed maximum when this is `None`.
    fn max_degree(&self) -> Option<u32>;
}
/// Loads a binary vector file at `path` from `provider` into a dataset built
/// by `create`.
///
/// The file starts with a metadata header (point count and dimensionality);
/// `create` receives `(npoints, ndims)` and must return an empty dataset,
/// which is then filled one vector at a time.
///
/// # Errors
/// Fails if the metadata cannot be read, if `create` fails, or if any vector
/// cannot be read from storage or stored into the dataset.
pub(crate) fn load_from_bin<T, P, F, S>(provider: &P, path: &str, create: F) -> ANNResult<S>
where
    P: StorageReadProvider,
    F: FnOnce(usize, usize) -> ANNResult<S>,
    S: SetData<Item = T>,
    T: VectorRepr,
{
    let metadata = load_metadata_from_file(provider, path).map_err(|err| {
        ANNError::log_index_error(format_args!(
            "failed to load data file \"{}\" due to the following error: {}",
            path, err
        ))
    })?;
    tracing::info!(
        "Loading {} vectors with dimension {} from storage system {} into dataset...",
        metadata.npoints(),
        metadata.ndims(),
        path
    );
    let mut data = create(metadata.npoints(), metadata.ndims())?;
    let itr = crate::utils::VectorDataIterator::<_, T>::new(path, None, provider)?;
    for (i, (vector, _)) in itr.enumerate() {
        // `enumerate` already yields `usize`; no conversion needed.
        data.set_data(i, &vector)?;
    }
    tracing::info!("Dataset loaded.");
    Ok(data)
}
pub(crate) fn save_to_bin<S, T, P>(data: &S, provider: &P, path: &str) -> ANNResult<usize>
where
S: GetData<Element = T>,
T: bytemuck::Pod,
P: StorageWriteProvider,
{
let total = data.total();
let dim = data.dim();
let mut writer = provider.create_for_write(path)?;
let mut points_written: u32 = 0;
Metadata::new(points_written, dim)?.write(&mut writer)?;
for i in 0..total {
let binding = data.get_data(i)?;
let slice = &*binding;
let len = slice.len();
if len != dim {
return Err(ANNError::log_index_error(
"data provider returned a vector with a dimension other than advertised",
));
}
let reinterpret: &[u8] = bytemuck::must_cast_slice(slice);
writer.write_all(reinterpret)?;
points_written += 1;
}
writer.seek(std::io::SeekFrom::Start(0_u64))?;
writer.write_all(&points_written.to_le_bytes())?;
writer.flush()?;
let bytes_written = 2 * std::mem::size_of::<u32>()
+ points_written.into_usize() * dim * std::mem::size_of::<T>();
Ok(bytes_written)
}
pub(crate) fn load_graph<P, S, F>(provider: &P, path: &str, create: F) -> ANNResult<S>
where
P: StorageReadProvider,
S: SetAdjacencyList<Item = u32>,
F: FnOnce(usize, usize, usize) -> ANNResult<S>,
{
const METADATA_SIZE: usize = 24;
let mut file = BufReader::new(provider.open_reader(path)?);
let file_size = file.read_u64::<LittleEndian>()?.into_usize();
let max_degree = file.read_u32::<LittleEndian>()?.into_usize();
let start = file.read_u32::<LittleEndian>()?;
let num_start_points = file.read_u64::<LittleEndian>()?.into_usize();
let mut position = METADATA_SIZE;
let mut num_points: usize = 0;
while position < file_size {
num_points += 1;
let num_neighbors: i64 = file.read_u32::<LittleEndian>()?.into();
let seek_amount: i64 = num_neighbors * (std::mem::size_of::<u32>() as i64);
file.seek_relative(seek_amount)?;
position += std::mem::size_of::<u32>() + (seek_amount as usize);
}
tracing::info!("Num points: {}, max degree: {}", num_points, max_degree);
file.seek_relative(-((position - METADATA_SIZE) as i64))?;
let mut graph = create(num_points, max_degree, num_start_points)?;
position = METADATA_SIZE;
let mut buffer: Vec<u32> = vec![0; max_degree];
num_points = 0;
let mut num_edges = 0;
while position < file_size {
let num_neighbors = file.read_u32::<LittleEndian>()?;
if num_neighbors == 0 {
tracing::debug!("Point found with no out-neighbors, point# {}", num_points);
}
let buffer = &mut buffer[..num_neighbors.into_usize()];
file.read_u32_into::<LittleEndian>(buffer)?;
graph.set_adjacency_list(num_points, buffer)?;
position += std::mem::size_of::<u32>() * (1 + num_neighbors.into_usize());
num_edges += num_neighbors.into_usize();
num_points += 1;
}
tracing::info!(
"Done. Index has {} nodes and {} out-edges, _start is set to {}",
num_points,
num_edges,
start
);
Ok(graph)
}
/// Serializes `graph` to `path` in the binary layout `load_graph` reads:
/// a 24-byte header (`index_size: u64`, `max_degree: u32`, `start: u32`,
/// `additional_points: u64`), then one record per node — a `u32` neighbor
/// count followed by that many `u32` neighbor ids.
///
/// The header is written first with placeholder size/degree values and
/// patched once every list has been emitted, since both are only known at
/// the end. Returns the total file size in bytes.
///
/// # Errors
/// Fails if the file cannot be created, if any adjacency list cannot be
/// fetched, or on I/O errors.
pub(crate) fn save_graph<S, P>(
    graph: &S,
    provider: &P,
    start_point: u32,
    path: &str,
) -> ANNResult<usize>
where
    S: GetAdjacencyList<Element = u32>,
    P: StorageWriteProvider,
{
    const HEADER_BYTES: u64 = 24;
    let mut writer = BufWriter::new(provider.create_for_write(path)?);

    // Reserve the header with placeholders; patched after the loop.
    let mut total_bytes: u64 = HEADER_BYTES;
    let mut widest_list: u32 = 0;
    writer.write_all(&total_bytes.to_le_bytes())?;
    writer.write_all(&widest_list.to_le_bytes())?;
    writer.write_all(&start_point.to_le_bytes())?;
    writer.write_all(&graph.additional_points().to_le_bytes())?;

    for node in 0..graph.total() {
        let list = graph.get_adjacency_list(node)?;
        let neighbors: &[u32] = &list;
        let degree = neighbors.len() as u32;
        writer.write_all(&degree.to_le_bytes())?;
        for &neighbor in neighbors {
            writer.write_all(&neighbor.to_le_bytes())?;
        }
        widest_list = widest_list.max(degree);
        total_bytes += (std::mem::size_of::<u32>() * (1 + neighbors.len())) as u64;
    }

    // Prefer the graph's advertised max degree, falling back to the widest
    // list actually observed.
    let recorded_max = graph.max_degree().unwrap_or(widest_list);
    writer.seek(SeekFrom::Start(0))?;
    writer.write_all(&total_bytes.to_le_bytes())?;
    writer.write_all(&recorded_max.to_le_bytes())?;
    writer.flush()?;
    Ok(total_bytes.into_usize())
}