use std::borrow::Cow;
use std::fmt;
use std::mem::size_of;
use bytemuck::{bytes_of, cast_slice, pod_read_unaligned};
use byteorder::{ByteOrder, NativeEndian};
use heed::{BoxedError, BytesDecode, BytesEncode};
use roaring::RoaringBitmap;
use crate::distance::Distance;
use crate::unaligned_vector::UnalignedVector;
use crate::{ItemId, NodeId};
/// A node as stored through [`NodeCodec`].
///
/// The serialized form of every variant starts with a one-byte tag
/// (`LEAF_TAG`, `DESCENDANTS_TAG` or `SPLIT_PLANE_NORMAL_TAG`) followed by
/// the variant's payload.
#[derive(Clone, Debug)]
pub enum Node<'a, D: Distance> {
/// A distance-specific header together with a vector.
Leaf(Leaf<'a, D>),
/// A bitmap of descendants.
Descendants(Descendants<'a>),
/// A normal vector together with a `left` and a `right` node id.
SplitPlaneNormal(SplitPlaneNormal<'a, D>),
}
// One-byte tags written as the first byte of an encoded `Node`; the decoder
// dispatches on them, so the values must stay stable for already-written data.
/// Tag prefixing an encoded `Node::Leaf`.
const LEAF_TAG: u8 = 0;
/// Tag prefixing an encoded `Node::Descendants`.
const DESCENDANTS_TAG: u8 = 1;
/// Tag prefixing an encoded `Node::SplitPlaneNormal`.
const SPLIT_PLANE_NORMAL_TAG: u8 = 2;
impl<'a, D: Distance> Node<'a, D> {
    /// Consumes the node and returns its [`Leaf`] payload.
    ///
    /// Returns `None` when the node is any variant other than `Node::Leaf`.
    pub fn leaf(self) -> Option<Leaf<'a, D>> {
        match self {
            Node::Leaf(leaf) => Some(leaf),
            Node::Descendants(_) | Node::SplitPlaneNormal(_) => None,
        }
    }
}
/// Payload of a `Node::Leaf`: a distance-specific header plus a vector.
pub struct Leaf<'a, D: Distance> {
/// Header of type `D::Header`; `NodeCodec` writes it right after the tag byte.
pub header: D::Header,
/// The vector, either borrowed or owned; `NodeCodec` writes it after the header.
pub vector: Cow<'a, UnalignedVector<D::VectorCodec>>,
}
impl<D: Distance> fmt::Debug for Leaf<'_, D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Leaf").field("header", &self.header).field("vector", &self.vector).finish()
}
}
impl<D: Distance> Clone for Leaf<'_, D> {
fn clone(&self) -> Self {
Self { header: self.header, vector: self.vector.clone() }
}
}
impl<D: Distance> Leaf<'_, D> {
    /// Converts the leaf into one that owns its vector, so that it no longer
    /// borrows from the original buffer.
    pub fn into_owned(self) -> Leaf<'static, D> {
        let Leaf { header, vector } = self;
        let vector = Cow::Owned(vector.into_owned());
        Leaf { header, vector }
    }
}
/// Payload of a `Node::Descendants`: a bitmap of descendants.
#[derive(Clone)]
pub struct Descendants<'a> {
/// The bitmap, either borrowed or owned; `NodeCodec` serializes it with
/// `serialize_into` right after the tag byte.
pub descendants: Cow<'a, RoaringBitmap>,
}
impl fmt::Debug for Descendants<'_> {
    /// Formats the bitmap as the list of the values it contains.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Expand the bitmap into a plain vector so the output lists every value.
        let ids: Vec<_> = self.descendants.iter().collect();
        let mut out = f.debug_struct("Descendants");
        out.field("descendants", &ids);
        out.finish()
    }
}
/// A borrowed byte slice viewed as a sequence of `ItemId`s stored in native
/// endianness.
#[derive(Clone)]
pub struct ItemIds<'a> {
// Raw bytes; each run of `size_of::<ItemId>()` bytes is read as one id.
bytes: &'a [u8],
}
impl<'a> ItemIds<'a> {
    /// Views a slice of `u32` as item ids by reinterpreting it as bytes.
    pub fn from_slice(slice: &[u32]) -> ItemIds<'_> {
        ItemIds::from_bytes(cast_slice(slice))
    }

    /// Wraps raw bytes. The length is not validated: trailing bytes that do
    /// not fill a whole id are ignored by [`len`](Self::len) and
    /// [`iter`](Self::iter).
    pub fn from_bytes(bytes: &[u8]) -> ItemIds<'_> {
        ItemIds { bytes }
    }

    /// Returns the underlying bytes untouched.
    pub fn raw_bytes(&self) -> &[u8] {
        self.bytes
    }

    /// Returns the number of complete ids held by the byte slice.
    pub fn len(&self) -> usize {
        self.bytes.len() / size_of::<ItemId>()
    }

    /// Returns `true` when the slice holds no complete id.
    ///
    /// Defined through `len` rather than on the raw bytes so that a slice
    /// with only a partial trailing id is still reported as empty, which
    /// matches what `iter` yields.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates over the ids, reading each one as a native-endian `u32`.
    pub fn iter(&self) -> impl Iterator<Item = ItemId> + 'a {
        self.bytes.chunks_exact(size_of::<ItemId>()).map(NativeEndian::read_u32)
    }
}
impl fmt::Debug for ItemIds<'_> {
    /// Formats the ids as a debug list, one entry per decoded id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list().entries(self.iter()).finish()
    }
}
/// Payload of a `Node::SplitPlaneNormal`: two node ids and a normal vector.
pub struct SplitPlaneNormal<'a, D: Distance> {
/// Node id encoded first, right after the tag byte.
pub left: NodeId,
/// Node id encoded second, after `left`.
pub right: NodeId,
/// The normal vector, either borrowed or owned; encoded last.
pub normal: Cow<'a, UnalignedVector<D::VectorCodec>>,
}
impl<D: Distance> fmt::Debug for SplitPlaneNormal<'_, D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = format!("SplitPlaneNormal<{}>", D::name());
f.debug_struct(&name)
.field("left", &self.left)
.field("right", &self.right)
.field("normal", &self.normal)
.finish()
}
}
impl<D: Distance> Clone for SplitPlaneNormal<'_, D> {
fn clone(&self) -> Self {
Self { left: self.left, right: self.right, normal: self.normal.clone() }
}
}
/// Codec implementing `BytesEncode` and `BytesDecode` for `Node`.
///
/// The `D` parameter selects the `Distance` whose header type and vector
/// codec are used when encoding and decoding.
pub struct NodeCodec<D>(D);
impl<'a, D: Distance> BytesEncode<'a> for NodeCodec<D> {
    type EItem = Node<'a, D>;

    /// Encodes a [`Node`] as a one-byte tag followed by the variant payload:
    ///
    /// - leaf: header bytes, then vector bytes;
    /// - split plane normal: `left` id, `right` id, then normal vector bytes;
    /// - descendants: the serialized bitmap.
    ///
    /// The output buffer is allocated once with its final size instead of
    /// growing while the payload is appended.
    ///
    /// # Errors
    ///
    /// Returns an error when serializing the descendants bitmap fails.
    fn bytes_encode(item: &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
        let bytes = match item {
            Node::Leaf(Leaf { header, vector }) => {
                let header_bytes = bytes_of(header);
                let vector_bytes = vector.as_bytes();
                let mut bytes = Vec::with_capacity(1 + header_bytes.len() + vector_bytes.len());
                bytes.push(LEAF_TAG);
                bytes.extend_from_slice(header_bytes);
                bytes.extend_from_slice(vector_bytes);
                bytes
            }
            Node::SplitPlaneNormal(SplitPlaneNormal { normal, left, right }) => {
                let left_bytes = left.to_bytes();
                let right_bytes = right.to_bytes();
                let normal_bytes = normal.as_bytes();
                let mut bytes = Vec::with_capacity(
                    1 + left_bytes.len() + right_bytes.len() + normal_bytes.len(),
                );
                bytes.push(SPLIT_PLANE_NORMAL_TAG);
                bytes.extend_from_slice(&left_bytes);
                bytes.extend_from_slice(&right_bytes);
                bytes.extend_from_slice(normal_bytes);
                bytes
            }
            Node::Descendants(Descendants { descendants }) => {
                let mut bytes = Vec::with_capacity(1 + descendants.serialized_size());
                bytes.push(DESCENDANTS_TAG);
                descendants.serialize_into(&mut bytes)?;
                bytes
            }
        };
        Ok(Cow::Owned(bytes))
    }
}
impl<'a, D: Distance> BytesDecode<'a> for NodeCodec<D> {
    type DItem = Node<'a, D>;

    /// Decodes a [`Node`] from `bytes`, dispatching on the leading tag byte.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is empty, starts with an unknown tag,
    /// holds a leaf payload shorter than `D::Header`, or when the vector or
    /// bitmap payload fails to decode.
    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
        match bytes {
            [LEAF_TAG, bytes @ ..] => {
                let header_size = size_of::<D::Header>();
                // `split_at` panics when the split point is past the end of
                // the slice, so reject truncated payloads with an error first.
                if bytes.len() < header_size {
                    return Err(format!(
                        "Invalid node decoding: leaf payload of {} bytes is shorter than the {header_size}-byte header",
                        bytes.len()
                    )
                    .into());
                }
                let (header_bytes, remaining) = bytes.split_at(header_size);
                let header = pod_read_unaligned(header_bytes);
                let vector = UnalignedVector::<D::VectorCodec>::from_bytes(remaining)?;
                Ok(Node::Leaf(Leaf { header, vector }))
            }
            [SPLIT_PLANE_NORMAL_TAG, bytes @ ..] => {
                // NOTE(review): `NodeId::from_bytes` is defined elsewhere; confirm
                // how it behaves on a payload too short to hold two node ids.
                let (left, bytes) = NodeId::from_bytes(bytes);
                let (right, bytes) = NodeId::from_bytes(bytes);
                Ok(Node::SplitPlaneNormal(SplitPlaneNormal {
                    normal: UnalignedVector::<D::VectorCodec>::from_bytes(bytes)?,
                    left,
                    right,
                }))
            }
            [DESCENDANTS_TAG, bytes @ ..] => Ok(Node::Descendants(Descendants {
                descendants: Cow::Owned(RoaringBitmap::deserialize_from(bytes)?),
            })),
            [unknown_tag, ..] => {
                Err(Box::new(InvalidNodeDecoding { unknown_tag: Some(*unknown_tag) }))
            }
            [] => Err(Box::new(InvalidNodeDecoding { unknown_tag: None })),
        }
    }
}
/// Error returned by the `NodeCodec` decoder when the input does not start
/// with a known tag byte.
#[derive(Debug, thiserror::Error)]
pub struct InvalidNodeDecoding {
// The unrecognized tag byte, or `None` when the input was empty.
unknown_tag: Option<u8>,
}
impl fmt::Display for InvalidNodeDecoding {
    /// Writes a message naming the unknown tag, or stating that the input
    /// was empty when no tag byte was available.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(unknown_tag) = self.unknown_tag {
            write!(f, "Invalid node decoding: unknown tag {unknown_tag}")
        } else {
            f.write_str("Invalid node decoding: empty array of bytes")
        }
    }
}