use std::sync::Arc;
use crate::circuit_breaker::CircuitBreaker;
use crate::error::{ClusterError, Result};
use super::super::merge::{ArrayAggPartial, merge_slice_rows, reduce_agg_partials};
use super::super::rpc::ShardRpcDispatch;
use super::super::scatter::{FanOutParams, FanOutPartitionedParams, fan_out, fan_out_partitioned};
use super::super::wire::{
ArrayShardAggReq, ArrayShardAggResp, ArrayShardDeleteReq, ArrayShardDeleteResp,
ArrayShardSliceReq, ArrayShardSliceResp, ArrayShardSurrogateBitmapReq,
ArrayShardSurrogateBitmapResp,
};
/// Fan-out parameters shared by every coordinator RPC.
#[derive(Debug, Clone)]
pub struct ArrayCoordParams {
    /// Node id stamped on outgoing requests as the fan-out origin.
    pub source_node: u64,
    /// Vshard ids this coordinator addresses.
    pub shard_ids: Vec<u32>,
    /// Per-request RPC timeout, in milliseconds.
    pub timeout_ms: u64,
    /// Hilbert-prefix bits used to derive each shard's key sub-range;
    /// 0 disables per-shard range narrowing.
    pub prefix_bits: u8,
    /// Requested slice ranges, forwarded on every slice request.
    pub slice_hilbert_ranges: Vec<(u64, u64)>,
}
/// Outcome of a coordinated slice: merged rows plus a truncation flag.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CoordSliceResult {
    /// Rows merged (and limit-capped) from all responding shards.
    pub rows: Vec<Vec<u8>>,
    /// Set from `any_truncated_before_horizon_slice` over the shard
    /// responses — true when at least one shard flagged truncation.
    pub truncated_before_horizon: bool,
}
/// Returns the inclusive Hilbert key range `(lo, hi)` covered by `shard_id`,
/// given that the top `prefix_bits` bits of the 64-bit key select a bucket.
///
/// Buckets are `VSHARD_COUNT >> prefix_bits` vshards wide (at least 1), so
/// consecutive vshards in the same bucket share a range.
pub(super) fn shard_hilbert_range_for_vshard(shard_id: u32, prefix_bits: u8) -> (u64, u64) {
    use crate::routing::VSHARD_COUNT;
    // prefix_bits == 0 means a single bucket spanning the whole key space.
    // Handle it up front: the shift below would be 64, which overflows a
    // u64 shift (panics in debug builds).
    if prefix_bits == 0 {
        return (0, u64::MAX);
    }
    let stride = (VSHARD_COUNT >> (prefix_bits as u32)).max(1);
    let bucket = shard_id / stride;
    let shift = 64u8.saturating_sub(prefix_bits);
    let lo = (bucket as u64) << shift;
    let hi = if shift == 0 {
        // prefix_bits >= 64: degenerate case, cover the remainder of the space.
        u64::MAX
    } else {
        // lo + (2^shift - 1): last key whose top prefix_bits match this bucket.
        lo.saturating_add((1u64 << shift).saturating_sub(1))
    };
    (lo, hi)
}
/// Coordinates array-shard RPCs: fans requests out across the configured
/// shard set and merges the per-shard responses.
///
/// Construct via [`ArrayCoordinator::new`] or [`ArrayCoordinator::for_slice`].
pub struct ArrayCoordinator {
    // Shard set, timeout, and Hilbert-range configuration used by every fan-out.
    pub(super) params: ArrayCoordParams,
    // Transport used to deliver per-shard requests.
    pub(super) dispatch: Arc<dyn ShardRpcDispatch>,
    // Breaker handle passed to every fan-out call alongside the dispatcher.
    pub(super) circuit_breaker: Arc<CircuitBreaker>,
}
impl ArrayCoordinator {
pub fn new(
params: ArrayCoordParams,
dispatch: Arc<dyn ShardRpcDispatch>,
circuit_breaker: Arc<CircuitBreaker>,
) -> Self {
Self {
params,
dispatch,
circuit_breaker,
}
}
pub fn for_slice(
source_node: u64,
timeout_ms: u64,
slice_hilbert_ranges: &[(u64, u64)],
prefix_bits: u8,
total_shards: u32,
dispatch: Arc<dyn ShardRpcDispatch>,
circuit_breaker: Arc<CircuitBreaker>,
) -> crate::error::Result<Self> {
let shard_ids = super::super::routing::array_vshards_for_slice(
slice_hilbert_ranges,
prefix_bits,
total_shards,
)?;
Ok(Self {
params: ArrayCoordParams {
source_node,
shard_ids,
timeout_ms,
prefix_bits,
slice_hilbert_ranges: slice_hilbert_ranges.to_vec(),
},
dispatch,
circuit_breaker,
})
}
pub async fn coord_slice(
&self,
req: ArrayShardSliceReq,
coordinator_limit: u32,
) -> Result<CoordSliceResult> {
let prefix_bits = self.params.prefix_bits;
let per_shard: Vec<(u32, Vec<u8>)> = self
.params
.shard_ids
.iter()
.map(|&shard_id| {
let shard_hilbert_range = if prefix_bits > 0 {
Some(shard_hilbert_range_for_vshard(shard_id, prefix_bits))
} else {
None
};
let per_shard_req = ArrayShardSliceReq {
prefix_bits,
slice_hilbert_ranges: self.params.slice_hilbert_ranges.clone(),
shard_hilbert_range,
..req.clone()
};
let bytes =
zerompk::to_msgpack_vec(&per_shard_req).map_err(|e| ClusterError::Codec {
detail: format!("ArrayShardSliceReq serialise: {e}"),
})?;
Ok((shard_id, bytes))
})
.collect::<Result<Vec<_>>>()?;
let fo_params = FanOutPartitionedParams {
source_node: self.params.source_node,
timeout_ms: self.params.timeout_ms,
};
let raw = fan_out_partitioned(
&fo_params,
super::super::opcodes::ARRAY_SHARD_SLICE_REQ,
&per_shard,
&self.dispatch,
&self.circuit_breaker,
)
.await?;
let resps = decode_resps::<ArrayShardSliceResp>(&raw)?;
let truncated_before_horizon =
super::super::merge::any_truncated_before_horizon_slice(&resps);
let rows = merge_slice_rows(&resps, coordinator_limit);
Ok(CoordSliceResult {
rows,
truncated_before_horizon,
})
}
pub async fn coord_agg(&self, req: ArrayShardAggReq) -> Result<Vec<ArrayAggPartial>> {
let prefix_bits = self.params.prefix_bits;
let per_shard: Vec<(u32, Vec<u8>)> = self
.params
.shard_ids
.iter()
.map(|&shard_id| {
let hilbert_range = if prefix_bits > 0 {
Some(shard_hilbert_range_for_vshard(shard_id, prefix_bits))
} else {
None
};
let per_shard_req = ArrayShardAggReq {
shard_hilbert_range: hilbert_range,
..req.clone()
};
let bytes =
zerompk::to_msgpack_vec(&per_shard_req).map_err(|e| ClusterError::Codec {
detail: format!("ArrayShardAggReq serialise: {e}"),
})?;
Ok((shard_id, bytes))
})
.collect::<Result<Vec<_>>>()?;
let fo_params = FanOutPartitionedParams {
source_node: self.params.source_node,
timeout_ms: self.params.timeout_ms,
};
let raw = fan_out_partitioned(
&fo_params,
super::super::opcodes::ARRAY_SHARD_AGG_REQ,
&per_shard,
&self.dispatch,
&self.circuit_breaker,
)
.await?;
let resps = decode_resps::<ArrayShardAggResp>(&raw)?;
Ok(reduce_agg_partials(&resps))
}
pub async fn coord_delete(
&self,
req: ArrayShardDeleteReq,
) -> Result<Vec<ArrayShardDeleteResp>> {
let req_bytes = zerompk::to_msgpack_vec(&req).map_err(|e| ClusterError::Codec {
detail: format!("ArrayShardDeleteReq serialise: {e}"),
})?;
let raw = fan_out(
&self.fan_out_params(),
super::super::opcodes::ARRAY_SHARD_DELETE_REQ,
&req_bytes,
&self.dispatch,
&self.circuit_breaker,
)
.await?;
decode_resps::<ArrayShardDeleteResp>(&raw)
}
pub async fn coord_surrogate_bitmap_scan(
&self,
req: ArrayShardSurrogateBitmapReq,
) -> Result<Vec<ArrayShardSurrogateBitmapResp>> {
let req_bytes = zerompk::to_msgpack_vec(&req).map_err(|e| ClusterError::Codec {
detail: format!("ArrayShardSurrogateBitmapReq serialise: {e}"),
})?;
let raw = fan_out(
&self.fan_out_params(),
super::super::opcodes::ARRAY_SHARD_SURROGATE_BITMAP_REQ,
&req_bytes,
&self.dispatch,
&self.circuit_breaker,
)
.await?;
decode_resps::<ArrayShardSurrogateBitmapResp>(&raw)
}
pub(super) fn fan_out_params(&self) -> FanOutParams {
FanOutParams {
shard_ids: self.params.shard_ids.clone(),
timeout_ms: self.params.timeout_ms,
source_node: self.params.source_node,
}
}
}
/// Deserialises every `(shard_id, payload)` pair into `T`, discarding the
/// shard ids. Stops at the first payload that fails to decode and returns
/// that error as `ClusterError::Codec`.
pub(super) fn decode_resps<T>(raw: &[(u32, Vec<u8>)]) -> Result<Vec<T>>
where
    T: for<'a> zerompk::FromMessagePack<'a>,
{
    let mut decoded = Vec::with_capacity(raw.len());
    for (_shard_id, bytes) in raw {
        let resp = zerompk::from_msgpack(bytes).map_err(|e| ClusterError::Codec {
            detail: format!("array response deserialise: {e}"),
        })?;
        decoded.push(resp);
    }
    Ok(decoded)
}