use crate::distributed_array::rpc::ShardRpcDispatch;
use crate::error::{ClusterError, Result};
use crate::wire::VShardEnvelope;
#[derive(Debug, Clone)]
pub struct ShardMessage {
pub kind: ShardMessageKind,
pub payload: Vec<u8>,
}
/// Kind tag of a [`ShardMessage`].
///
/// `build_message_envelope` maps each variant to the wire message type placed on the
/// outgoing envelope.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardMessageKind {
/// Sent on the wire as `VShardMessageType::VectorCoarseRouteRequest`.
CompassCoarseDescriptor,
/// Sent on the wire as `VShardMessageType::VectorBuildExchangeRequest`.
SpireCentroidTable,
}
/// Reply produced by [`VectorShardSeam::direct_message`].
///
/// The default `direct_message` implementation fills `payload` with the payload of the
/// envelope returned by the dispatcher, without inspecting it.
///
/// Equality is structural (identical payload bytes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardMessageReply {
    /// Opaque reply bytes taken from the peer's response envelope.
    pub payload: Vec<u8>,
}
/// Errors reported by [`VectorShardSeam`] operations.
///
/// Only `Transport` is constructed by code visible in this file (the default
/// `VectorShardSeam::direct_message`). NOTE(review): the other variants are presumably
/// meant for trait implementors — confirm against the implementations.
#[derive(Debug, thiserror::Error)]
pub enum VectorSeamError {
/// The implementation does not offer direct shard-to-shard messaging.
#[error("direct shard-to-shard messaging is not supported by this implementation")]
DirectMessagingUnsupported,
/// A build-time exchange with `peer_shard` failed; `detail` is free-form text that is
/// included in the display message.
#[error("build-time exchange failed for shard {peer_shard}: {detail}")]
BuildExchangeFailed { peer_shard: u32, detail: String },
/// A cluster-level error surfaced while performing a seam call. `#[from]` provides
/// `From<ClusterError>`, so `?` converts a `ClusterError` into this variant.
#[error("cluster transport error during seam call: {0}")]
Transport(#[from] ClusterError),
}
/// Descriptor of a memory region that a seam may advertise through
/// [`VectorShardSeam::exposed_region`].
///
/// This file only stores and copies these values; nothing here dereferences the address.
/// NOTE(review): the field names suggest an RDMA-style registration (remote address,
/// remote key, length) — confirm against the consumer of `exposed_region`.
///
/// Equality is field-wise, so two descriptors compare equal only when address, key and
/// length all match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryRegion {
    /// Address of the region — presumably as seen by the remote peer; TODO confirm.
    pub remote_addr: u64,
    /// Key accompanying `remote_addr` — presumably the remote access key; TODO confirm.
    pub rkey: u32,
    /// Length of the region — presumably in bytes; TODO confirm.
    pub len: usize,
}
/// Address of a peer shard: the node that hosts it plus the shard id on that node.
///
/// `build_message_envelope` copies both fields into the outgoing envelope
/// (`node_id` becomes `target_node`, `vshard_id` is carried over as `vshard_id`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ShardRef {
/// Identifier of the node the message is sent to.
pub node_id: u64,
/// Identifier of the target shard.
pub vshard_id: u32,
}
/// Result of [`VectorShardSeam::select_shards`]: either every shard or an explicit list.
///
/// Equality is structural: `All` equals only `All`, and two `Subset`s are equal when
/// their id lists match element for element (order matters).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardSubset {
    /// Every shard known to the caller.
    All,
    /// Only the listed shard ids.
    Subset(Vec<u32>),
}

impl ShardSubset {
    /// Turns the selection into a concrete slice of shard ids.
    ///
    /// * `All` yields `all_shards` unchanged.
    /// * `Subset(ids)` yields `ids` exactly as stored. The ids are **not** checked
    ///   against `all_shards`, so unknown or duplicate ids are passed through, and an
    ///   empty subset resolves to an empty slice.
    ///
    /// The returned slice borrows from whichever of `self` / `all_shards` was chosen,
    /// hence the shared lifetime.
    pub fn resolve<'a>(&'a self, all_shards: &'a [u32]) -> &'a [u32] {
        match self {
            Self::All => all_shards,
            Self::Subset(ids) => ids.as_slice(),
        }
    }
}
/// Extension point through which a vector index customises shard selection and
/// shard-to-shard communication.
///
/// Every method has a default, so `impl VectorShardSeam for T {}` is a valid
/// implementation: it selects all shards, exchanges nothing at build time, forwards
/// direct messages through the supplied dispatcher and exposes no memory region.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait VectorShardSeam: Send + Sync + 'static {
    /// Timeout, in milliseconds, that the default [`direct_message`](Self::direct_message)
    /// passes to the dispatcher. Implementors may override it; the default keeps the
    /// previously hard-coded value of 5 000 ms.
    const DIRECT_MESSAGE_TIMEOUT_MS: u64 = 5_000;

    /// Chooses which shards a query should be sent to.
    ///
    /// The default ignores both arguments and returns [`ShardSubset::All`].
    fn select_shards(&self, _query_vector: &[f32], _all_shards: &[u32]) -> ShardSubset {
        ShardSubset::All
    }

    /// Hook for exchanging data with `_peer` at build time.
    ///
    /// The default does nothing and returns `Ok(())` without touching the dispatcher.
    ///
    /// # Errors
    ///
    /// The default never fails; overriding implementations report failures through the
    /// crate-level `Result`.
    fn build_time_exchange(&self, _peer: ShardRef, _dispatch: &dyn ShardRpcDispatch) -> Result<()> {
        Ok(())
    }

    /// Sends `msg` to `peer` and resolves to the peer's reply.
    ///
    /// The default wraps `msg` in an envelope addressed from `source_node` to `peer`,
    /// sends it with `dispatch.call` using [`Self::DIRECT_MESSAGE_TIMEOUT_MS`], and
    /// returns the reply envelope's payload.
    ///
    /// # Errors
    ///
    /// The default returns [`VectorSeamError::Transport`] when the dispatcher call fails.
    fn direct_message(
        &self,
        peer: ShardRef,
        msg: ShardMessage,
        dispatch: &dyn ShardRpcDispatch,
        source_node: u64,
    ) -> impl std::future::Future<Output = std::result::Result<ShardMessageReply, VectorSeamError>> + Send
    {
        // Build the envelope and read the timeout eagerly so the returned future
        // captures only `envelope`, `timeout_ms` and `dispatch` — not `self` or `msg`.
        let envelope = build_message_envelope(source_node, peer, &msg);
        let timeout_ms = Self::DIRECT_MESSAGE_TIMEOUT_MS;
        async move {
            dispatch
                .call(envelope, timeout_ms)
                .await
                .map(|reply| ShardMessageReply {
                    payload: reply.payload,
                })
                .map_err(VectorSeamError::Transport)
        }
    }

    /// Memory region this seam makes available, if any.
    ///
    /// The default exposes nothing and returns `None`.
    fn exposed_region(&self) -> Option<MemoryRegion> {
        None
    }
}
fn build_message_envelope(source_node: u64, peer: ShardRef, msg: &ShardMessage) -> VShardEnvelope {
use crate::wire::{VShardEnvelope, VShardMessageType, WIRE_VERSION};
let msg_type = match msg.kind {
ShardMessageKind::CompassCoarseDescriptor => VShardMessageType::VectorCoarseRouteRequest,
ShardMessageKind::SpireCentroidTable => VShardMessageType::VectorBuildExchangeRequest,
};
VShardEnvelope {
version: WIRE_VERSION,
msg_type,
source_node,
target_node: peer.node_id,
vshard_id: peer.vshard_id,
payload: msg.payload.clone(),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Seam that overrides nothing, so every call exercises the trait defaults.
    struct DefaultSeam;

    impl VectorShardSeam for DefaultSeam {}

    #[test]
    fn default_select_shards_returns_all() {
        let seam = DefaultSeam;
        let all = [0u32, 1, 2, 3];
        let result = seam.select_shards(&[0.1, 0.2], &all);
        assert!(matches!(result, ShardSubset::All));
        assert_eq!(result.resolve(&all), &all);
    }

    #[test]
    fn default_exposed_region_is_none() {
        let seam = DefaultSeam;
        assert!(seam.exposed_region().is_none());
    }

    #[test]
    fn default_build_time_exchange_is_noop() {
        // The dispatcher always fails, so `Ok` below shows the default never calls it.
        struct MockDispatch;

        #[async_trait::async_trait]
        impl crate::distributed_array::rpc::ShardRpcDispatch for MockDispatch {
            async fn call(
                &self,
                _req: VShardEnvelope,
                _timeout_ms: u64,
            ) -> crate::error::Result<VShardEnvelope> {
                Err(crate::error::ClusterError::Transport {
                    detail: "mock".into(),
                })
            }
        }

        let seam = DefaultSeam;
        let peer = ShardRef {
            node_id: 2,
            vshard_id: 5,
        };
        let dispatch = MockDispatch;
        assert!(seam.build_time_exchange(peer, &dispatch).is_ok());
    }

    #[test]
    fn shard_subset_resolve_subset() {
        let all = [0u32, 1, 2, 3, 4];
        let subset = ShardSubset::Subset(vec![1, 3]);
        assert_eq!(subset.resolve(&all), &[1u32, 3]);
    }

    #[test]
    fn shard_subset_resolve_all() {
        let all = [0u32, 1, 2];
        let subset = ShardSubset::All;
        assert_eq!(subset.resolve(&all), &[0u32, 1, 2]);
    }

    #[test]
    fn memory_region_fields() {
        let region = MemoryRegion {
            remote_addr: 0xDEAD_BEEF_0000_0000,
            rkey: 42,
            len: 1024 * 1024,
        };
        assert_eq!(region.remote_addr, 0xDEAD_BEEF_0000_0000);
        assert_eq!(region.rkey, 42);
        assert_eq!(region.len, 1024 * 1024);
    }

    #[test]
    fn build_message_envelope_compass() {
        use crate::wire::VShardMessageType;

        let peer = ShardRef {
            node_id: 7,
            vshard_id: 3,
        };
        let msg = ShardMessage {
            kind: ShardMessageKind::CompassCoarseDescriptor,
            payload: vec![1, 2, 3],
        };
        let env = build_message_envelope(1, peer, &msg);
        assert_eq!(env.msg_type, VShardMessageType::VectorCoarseRouteRequest);
        assert_eq!(env.source_node, 1);
        assert_eq!(env.target_node, 7);
        assert_eq!(env.vshard_id, 3);
        assert_eq!(env.payload, vec![1, 2, 3]);
    }

    #[test]
    fn build_message_envelope_spire() {
        use crate::wire::VShardMessageType;

        let peer = ShardRef {
            node_id: 9,
            vshard_id: 1,
        };
        let msg = ShardMessage {
            kind: ShardMessageKind::SpireCentroidTable,
            payload: vec![0xFF, 0xAA],
        };
        let env = build_message_envelope(2, peer, &msg);
        assert_eq!(env.msg_type, VShardMessageType::VectorBuildExchangeRequest);
        assert_eq!(env.source_node, 2);
        assert_eq!(env.target_node, 9);
        assert_eq!(env.vshard_id, 1);
        assert_eq!(env.payload, vec![0xFF, 0xAA]);
    }
}