use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use anyhow::{Result, anyhow, bail};
use futures::future::BoxFuture;
use kvbm_common::LogicalLayoutHandle;
use kvbm_engine::object::ObjectBlockOps;
use kvbm_engine::worker::{
ConnectRemoteResponse, ImportMetadataResponse, RemoteDescriptor, SerializedLayoutResponse,
Worker, WorkerTransfers,
};
use kvbm_engine::{BlockId, InstanceId, SequenceHash};
use kvbm_physical::manager::{LayoutHandle, SerializedLayout};
use kvbm_physical::transfer::{PhysicalLayout, TransferCompleteNotification, TransferOptions};
use tokio::sync::Notify;
use velo::{Event, EventManager};
use super::bandwidth_sharing_model::{BandwidthSharingModel, TransferId};
/// Which simulated link a local transfer travels over.
///
/// Each variant is routed to its own bandwidth model inside [`MockWorker`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferDirection {
    /// Offload: device tier (G1) to host tier (G2).
    G1ToG2,
    /// Onboard: host tier (G2) back to device tier (G1).
    G2ToG1,
}
/// Completion hook for one simulated local transfer.
struct PipelineAwaiter {
    /// Triggered when the transfer's bandwidth model reports it finished.
    event: Event,
    /// Link the transfer was started on; used to bucket drained block counts.
    direction: TransferDirection,
    /// Blocks moved by the transfer; summed into the drain statistics.
    num_blocks: usize,
}
/// Mutable simulation state shared behind `MockWorker::state`.
pub(crate) struct TransferState {
    /// Link model for G1→G2 (offload) transfers.
    offload_bw: BandwidthSharingModel,
    /// Link model for G2→G1 (onboard) transfers and swap-ins.
    onboard_bw: BandwidthSharingModel,
    /// Per-transfer events fired when the matching transfer drains.
    awaiters: HashMap<TransferId, PipelineAwaiter>,
    /// Per-swap-in flags set to `true` when the matching transfer drains.
    swap_in_flags: HashMap<TransferId, Arc<std::sync::atomic::AtomicBool>>,
}
impl TransferState {
fn new(offload_gbps: f64, onboard_gbps: f64) -> Self {
let id_counter = Arc::new(std::sync::atomic::AtomicU64::new(0));
Self {
offload_bw: BandwidthSharingModel::new(offload_gbps, id_counter.clone()),
onboard_bw: BandwidthSharingModel::new(onboard_gbps, id_counter),
awaiters: HashMap::new(),
swap_in_flags: HashMap::new(),
}
}
}
/// Converts a millisecond timestamp into whole microseconds for the atomic clock.
///
/// Negative inputs and NaN clamp to 0 (`f64::max` returns the non-NaN operand).
/// The product is rounded to the nearest microsecond rather than truncated.
/// Many decimal millisecond values are not exactly representable (e.g. `1.015`
/// multiplies out to `1014.999…`), and truncation would store 1014 µs. That
/// makes `now_ms()` lag the value just passed to `set_now_ms`. Values beyond
/// `u64::MAX` µs saturate, because float-to-int `as` casts saturate.
fn ms_to_us(ms: f64) -> u64 {
    (ms.max(0.0) * 1000.0).round() as u64
}
/// Converts a microsecond count from the atomic clock back into milliseconds.
fn us_to_ms(us: u64) -> f64 {
    let micros = us as f64;
    micros / 1_000.0
}
/// In-process stand-in for a KVBM worker.
///
/// Local G1↔G2 transfers are not executed. Instead they are timed by two
/// bandwidth-sharing models driven by an externally-advanced simulated clock.
pub struct MockWorker {
    /// Simulated "now", stored in microseconds (see `set_now_ms` / `now_ms`).
    now_us: Arc<AtomicU64>,
    /// Link models and pending completion hooks, behind one mutex.
    pub(crate) state: Arc<Mutex<TransferState>>,
    /// Source of the velo events returned as transfer notifications.
    event_manager: EventManager,
    /// Number of local-transfer reservations made so far.
    reservation_count: AtomicU64,
    /// Woken (`notify_waiters`) whenever a local transfer is reserved.
    reservation_notify: Arc<Notify>,
    /// Bytes charged per block when sizing a simulated transfer.
    block_bytes: usize,
    /// Returned verbatim from `Worker::g1_handle`.
    g1_handle: Option<LayoutHandle>,
    /// Returned verbatim from `Worker::g2_handle`.
    g2_handle: Option<LayoutHandle>,
}
impl MockWorker {
    /// Creates a mock worker whose simulated clock starts at 0 ms.
    ///
    /// * `block_bytes` — bytes charged per block when sizing a transfer.
    /// * `offload_gbps` — rate given to the G1→G2 link model.
    /// * `onboard_gbps` — rate given to the G2→G1 link model (also used by swap-ins).
    /// * `g1_handle` / `g2_handle` — returned verbatim from the `Worker` getters.
    pub fn new(
        block_bytes: usize,
        offload_gbps: f64,
        onboard_gbps: f64,
        g1_handle: Option<LayoutHandle>,
        g2_handle: Option<LayoutHandle>,
    ) -> Self {
        Self {
            now_us: Arc::new(AtomicU64::new(0)),
            state: Arc::new(Mutex::new(TransferState::new(offload_gbps, onboard_gbps))),
            event_manager: EventManager::local(),
            reservation_count: AtomicU64::new(0),
            reservation_notify: Arc::new(Notify::new()),
            block_bytes,
            g1_handle,
            g2_handle,
        }
    }

    /// Sets the simulated clock. The value is stored at microsecond resolution
    /// via `ms_to_us`, so negative or NaN inputs clamp to 0.
    pub fn set_now_ms(&self, now_ms: f64) {
        self.now_us.store(ms_to_us(now_ms), Ordering::Release);
    }

    /// Current simulated time in milliseconds; local transfers start at this time.
    pub fn now_ms(&self) -> f64 {
        us_to_ms(self.now_us.load(Ordering::Acquire))
    }

    /// Number of local-transfer reservations successfully made so far.
    pub fn reservation_count(&self) -> u64 {
        self.reservation_count.load(Ordering::Acquire)
    }

    /// Shared handle that is notified (`notify_waiters`) after each local reservation.
    pub(crate) fn reservation_notifier(&self) -> Arc<Notify> {
        self.reservation_notify.clone()
    }

    /// Advances both link models to `now_ms` and completes everything that finished.
    ///
    /// Returns `(offload_drained, onboard_drained, offload_awaiter_blocks,
    /// onboard_awaiter_blocks)`:
    /// * The first two count drained transfer IDs per link. Swap-ins count
    ///   toward `onboard_drained`.
    /// * The last two sum `num_blocks` only over transfers that had a
    ///   registered awaiter, so swap-ins are excluded.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    pub fn drain_completions(&self, now_ms: f64) -> (usize, usize, usize, usize) {
        let mut state = self.state.lock().expect("TransferState mutex poisoned");
        Self::drain_locked(&mut state, now_ms)
    }

    /// Lock-held body of [`Self::drain_completions`]. Also called at the start of
    /// every reservation, so the link models are brought up to `now_ms` before a
    /// new transfer joins them.
    fn drain_locked(state: &mut TransferState, now_ms: f64) -> (usize, usize, usize, usize) {
        let offload_before = state.offload_bw.active_count();
        let onboard_before = state.onboard_bw.active_count();
        let offload_drained = state.offload_bw.advance_to(now_ms);
        let onboard_drained = state.onboard_bw.advance_to(now_ms);
        let offload_drained_count = offload_drained.len();
        let onboard_drained_count = onboard_drained.len();
        let drained: Vec<TransferId> = offload_drained.into_iter().chain(onboard_drained).collect();
        tracing::debug!(
            now_ms,
            offload_active_before = offload_before,
            onboard_active_before = onboard_before,
            drained_count = drained.len(),
            offload_drained_count,
            onboard_drained_count,
            awaiter_map_size = state.awaiters.len(),
            "kvbm-offload: drain transfer completions"
        );
        let mut awaiter_fired = 0usize;
        let mut offload_awaiter_blocks = 0usize;
        let mut onboard_awaiter_blocks = 0usize;
        let mut swap_in_flipped = 0usize;
        for id in drained {
            // Local transfers are completed by firing their event...
            if let Some(awaiter) = state.awaiters.remove(&id) {
                match awaiter.direction {
                    TransferDirection::G1ToG2 => offload_awaiter_blocks += awaiter.num_blocks,
                    TransferDirection::G2ToG1 => onboard_awaiter_blocks += awaiter.num_blocks,
                }
                let _ = awaiter.event.trigger();
                awaiter_fired += 1;
            }
            // ...swap-ins by flipping their caller-owned flag.
            if let Some(flag) = state.swap_in_flags.remove(&id) {
                flag.store(true, Ordering::Release);
                swap_in_flipped += 1;
            }
        }
        tracing::debug!(
            awaiter_fired,
            offload_awaiter_blocks,
            onboard_awaiter_blocks,
            swap_in_flipped,
            "kvbm-offload: fired completed transfer waiters"
        );
        (
            offload_drained_count,
            onboard_drained_count,
            offload_awaiter_blocks,
            onboard_awaiter_blocks,
        )
    }

    /// Starts a swap-in of `num_blocks` on the onboard link at `now_ms`.
    ///
    /// Completions up to `now_ms` are drained first. `complete` is set to `true`
    /// by the drain that observes this transfer finishing. Unlike local
    /// transfers, no event is created and `reservation_count` is not bumped.
    ///
    /// # Panics
    /// If the state mutex is poisoned.
    pub fn reserve_swap_in(
        &self,
        now_ms: f64,
        num_blocks: usize,
        complete: Arc<std::sync::atomic::AtomicBool>,
    ) -> TransferId {
        let bytes = num_blocks.saturating_mul(self.block_bytes);
        let mut state = self.state.lock().expect("TransferState mutex poisoned");
        Self::drain_locked(&mut state, now_ms);
        let id = state.onboard_bw.start_transfer(now_ms, bytes);
        state.swap_in_flags.insert(id, complete);
        id
    }

    /// Earliest projected finish time (ms) across both links, or `None` when
    /// neither link reports one.
    pub fn earliest_finish(&self) -> Option<f64> {
        let state = self.state.lock().expect("TransferState mutex poisoned");
        state
            .offload_bw
            .earliest_finish()
            .into_iter()
            .chain(state.onboard_bw.earliest_finish())
            .reduce(f64::min)
    }

    /// Reserves a simulated local transfer of `num_blocks` starting at `now_ms`.
    /// Returns a notification that resolves once a drain passes its finish time.
    ///
    /// # Errors
    /// If the velo event or its awaiter cannot be allocated. In that case no
    /// transfer is started, nothing is counted and no waiter is notified.
    fn reserve_transfer(
        &self,
        direction: TransferDirection,
        now_ms: f64,
        num_blocks: usize,
    ) -> Result<TransferCompleteNotification> {
        // Do the fallible allocations before touching the link models. A failure
        // after `start_transfer` would leave a phantom transfer stealing bandwidth
        // (and surfacing in `earliest_finish`) with nobody to complete it.
        let event = self
            .event_manager
            .new_event()
            .map_err(|e| anyhow!("MockWorker: failed to allocate velo event: {e}"))?;
        let awaiter = event
            .awaiter()
            .map_err(|e| anyhow!("MockWorker: failed to build event awaiter: {e}"))?;
        let bytes = num_blocks.saturating_mul(self.block_bytes);
        {
            let mut state = self.state.lock().expect("TransferState mutex poisoned");
            Self::drain_locked(&mut state, now_ms);
            let id = match direction {
                TransferDirection::G1ToG2 => state.offload_bw.start_transfer(now_ms, bytes),
                TransferDirection::G2ToG1 => state.onboard_bw.start_transfer(now_ms, bytes),
            };
            state.awaiters.insert(
                id,
                PipelineAwaiter {
                    event,
                    direction,
                    num_blocks,
                },
            );
        }
        // Publish only once the reservation is fully recorded and the lock is
        // released, so woken waiters see a complete state and do not contend on it.
        self.reservation_count.fetch_add(1, Ordering::AcqRel);
        self.reservation_notify.notify_waiters();
        Ok(TransferCompleteNotification::from_awaiter(awaiter))
    }
}
/// Maps a (source, destination) tier pair onto the simulated link it uses.
///
/// # Errors
/// For any pair other than G1→G2 or G2→G1, e.g. same-tier copies or anything
/// involving G3.
fn infer_direction(
    src: LogicalLayoutHandle,
    dst: LogicalLayoutHandle,
) -> Result<TransferDirection> {
    let direction = match (src, dst) {
        (LogicalLayoutHandle::G1, LogicalLayoutHandle::G2) => TransferDirection::G1ToG2,
        (LogicalLayoutHandle::G2, LogicalLayoutHandle::G1) => TransferDirection::G2ToG1,
        (unsupported_src, unsupported_dst) => bail!(
            "MockWorker only simulates G1↔G2 transfers; got src={unsupported_src:?} dst={unsupported_dst:?}"
        ),
    };
    Ok(direction)
}
impl WorkerTransfers for MockWorker {
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
_dst_block_ids: Arc<[BlockId]>,
_options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let direction = infer_direction(src, dst)?;
let now_ms = self.now_ms();
self.reserve_transfer(direction, now_ms, src_block_ids.len())
}
fn execute_remote_onboard(
&self,
_src: RemoteDescriptor,
_dst: LogicalLayoutHandle,
_dst_block_ids: Arc<[BlockId]>,
_options: TransferOptions,
) -> Result<TransferCompleteNotification> {
bail!("MockWorker: execute_remote_onboard not supported (mocker simulates G1↔G2 only)")
}
fn execute_remote_offload(
&self,
_src: LogicalLayoutHandle,
_src_block_ids: Arc<[BlockId]>,
_dst: RemoteDescriptor,
_options: TransferOptions,
) -> Result<TransferCompleteNotification> {
bail!("MockWorker: execute_remote_offload not supported")
}
fn connect_remote(
&self,
_instance_id: InstanceId,
_metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse> {
bail!("MockWorker: connect_remote not supported")
}
fn has_remote_metadata(&self, _instance_id: InstanceId) -> bool {
false
}
fn execute_remote_onboard_for_instance(
&self,
_instance_id: InstanceId,
_remote_logical_type: LogicalLayoutHandle,
_src_block_ids: Vec<BlockId>,
_dst: LogicalLayoutHandle,
_dst_block_ids: Arc<[BlockId]>,
_options: TransferOptions,
) -> Result<TransferCompleteNotification> {
bail!("MockWorker: execute_remote_onboard_for_instance not supported")
}
}
impl Worker for MockWorker {
fn g1_handle(&self) -> Option<LayoutHandle> {
self.g1_handle
}
fn g2_handle(&self) -> Option<LayoutHandle> {
self.g2_handle
}
fn g3_handle(&self) -> Option<LayoutHandle> {
None
}
fn export_metadata(&self) -> Result<SerializedLayoutResponse> {
bail!("MockWorker: export_metadata not supported (mocker is single-instance)")
}
fn import_metadata(&self, _metadata: SerializedLayout) -> Result<ImportMetadataResponse> {
bail!("MockWorker: import_metadata not supported (mocker is single-instance)")
}
}
/// Resolves every key as `Err(key)`. The mock has no object store, so every
/// put/get is reported as failed for each requested key.
fn mock_object_store_reject_all(
    keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
    Box::pin(async move {
        let rejected: Vec<Result<SequenceHash, SequenceHash>> =
            keys.into_iter().map(Err).collect();
        rejected
    })
}

impl ObjectBlockOps for MockWorker {
    /// Reports every key as absent (`None`) from the nonexistent object store.
    fn has_blocks(
        &self,
        keys: Vec<SequenceHash>,
    ) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
        Box::pin(async move {
            let absent: Vec<(SequenceHash, Option<usize>)> =
                keys.into_iter().zip(std::iter::repeat(None)).collect();
            absent
        })
    }

    /// Fails every key; nothing is written.
    fn put_blocks(
        &self,
        keys: Vec<SequenceHash>,
        _src_layout: LogicalLayoutHandle,
        _block_ids: Vec<BlockId>,
    ) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
        mock_object_store_reject_all(keys)
    }

    /// Fails every key; nothing is read.
    fn get_blocks(
        &self,
        keys: Vec<SequenceHash>,
        _dst_layout: LogicalLayoutHandle,
        _block_ids: Vec<BlockId>,
    ) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
        mock_object_store_reject_all(keys)
    }

    /// Fails every key; nothing is written.
    fn put_blocks_with_layout(
        &self,
        keys: Vec<SequenceHash>,
        _layout: PhysicalLayout,
        _block_ids: Vec<BlockId>,
    ) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
        mock_object_store_reject_all(keys)
    }

    /// Fails every key; nothing is read.
    fn get_blocks_with_layout(
        &self,
        keys: Vec<SequenceHash>,
        _layout: PhysicalLayout,
        _block_ids: Vec<BlockId>,
    ) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
        mock_object_store_reject_all(keys)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing simulated millisecond timestamps.
    const EPS: f64 = 1e-6;

    /// Worker with 1 MB blocks and rate 1.0 on both links.
    ///
    /// The tests below expect a lone single-block transfer to finish 1 ms after
    /// it starts (see `mock_worker_earliest_finish_min_of_both_links`).
    fn make_worker() -> MockWorker {
        MockWorker::new(1_000_000, 1.0, 1.0, None, None)
    }

    /// One offload reserved at t=0 stays pending until a drain at t=1 ms fires it.
    #[tokio::test]
    async fn mock_worker_single_transfer_completes_on_tick() {
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let src_ids: Arc<[BlockId]> = Arc::from(vec![0usize]);
        let dst_ids: Arc<[BlockId]> = Arc::from(vec![0usize]);
        let notification = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G1,
                LogicalLayoutHandle::G2,
                src_ids,
                dst_ids,
                TransferOptions::default(),
            )
            .expect("reservation should succeed");
        // `could_yield()` is used throughout as the "still pending" check —
        // presumably true until the backing event fires; confirm against kvbm_physical.
        assert!(notification.could_yield());
        worker.drain_completions(1.0);
        notification
            .await
            .expect("transfer notification should resolve Ok after drain");
    }

    /// Two single-block offloads reserved together share the G1→G2 link. Both
    /// are expected to still be pending at t=1 ms and to complete by t=2 ms.
    #[tokio::test]
    async fn mock_worker_two_concurrent_transfers_complete_at_2x() {
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let mk_ids = || -> Arc<[BlockId]> { Arc::from(vec![0usize]) };
        let n1 = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G1,
                LogicalLayoutHandle::G2,
                mk_ids(),
                mk_ids(),
                TransferOptions::default(),
            )
            .unwrap();
        let n2 = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G1,
                LogicalLayoutHandle::G2,
                mk_ids(),
                mk_ids(),
                TransferOptions::default(),
            )
            .unwrap();
        // At the single-transfer finish time neither may have completed yet.
        worker.drain_completions(1.0);
        assert!(n1.could_yield());
        assert!(n2.could_yield());
        worker.drain_completions(2.0);
        n1.await.expect("n1 should resolve Ok");
        n2.await.expect("n2 should resolve Ok");
    }

    /// Tier pairs other than G1↔G2 are rejected with the G1↔G2 error message.
    #[tokio::test]
    async fn mock_worker_rejects_unsupported_directions() {
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let ids: Arc<[BlockId]> = Arc::from(vec![0usize]);
        let result = worker.execute_local_transfer(
            LogicalLayoutHandle::G2,
            LogicalLayoutHandle::G3,
            ids.clone(),
            ids,
            TransferOptions::default(),
        );
        let err = match result {
            Ok(_) => panic!("G2→G3 must be rejected"),
            Err(e) => e,
        };
        let msg = err.to_string();
        assert!(msg.contains("G1↔G2"), "unexpected error: {msg}");
    }

    /// A swap-in (onboard link) and an offload (offload link) must receive
    /// distinct `TransferId`s, since both maps in `TransferState` key on them.
    #[tokio::test]
    async fn mock_worker_offload_and_swap_in_share_id_keyspace() {
        use std::sync::atomic::AtomicBool;
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let swap_id = worker.reserve_swap_in(0.0, 1, Arc::new(AtomicBool::new(false)));
        let ids: Arc<[BlockId]> = Arc::from(vec![0usize]);
        let _offload = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G1,
                LogicalLayoutHandle::G2,
                ids.clone(),
                ids,
                TransferOptions::default(),
            )
            .unwrap();
        // Only the offload registers an awaiter, so the single key is its ID.
        let state = worker.state.lock().unwrap();
        let awaiter_id = *state
            .awaiters
            .keys()
            .next()
            .expect("offload must register an awaiter");
        assert_ne!(
            awaiter_id, swap_id,
            "offload and swap-in must draw distinct TransferIds"
        );
    }

    /// A swap-in's completion flag stays false before its finish time (1 ms)
    /// and flips on the first drain at or past it.
    #[tokio::test]
    async fn mock_worker_swap_in_flag_flips_on_drain() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let complete = Arc::new(AtomicBool::new(false));
        let _id = worker.reserve_swap_in(0.0, 1, complete.clone());
        assert!(!complete.load(Ordering::Acquire));
        worker.drain_completions(0.5);
        assert!(
            !complete.load(Ordering::Acquire),
            "swap-in must not complete before its finish time"
        );
        worker.drain_completions(1.0);
        assert!(
            complete.load(Ordering::Acquire),
            "swap-in flag must flip after drain past finish time"
        );
    }

    /// With one transfer on each link (so no sharing), `earliest_finish`
    /// reports the minimum across both: 1 ms.
    #[tokio::test]
    async fn mock_worker_earliest_finish_min_of_both_links() {
        let worker = make_worker();
        worker.set_now_ms(0.0);
        let ids: Arc<[BlockId]> = Arc::from(vec![0usize]);
        let _n1 = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G1,
                LogicalLayoutHandle::G2,
                ids.clone(),
                ids.clone(),
                TransferOptions::default(),
            )
            .unwrap();
        let _n2 = worker
            .execute_local_transfer(
                LogicalLayoutHandle::G2,
                LogicalLayoutHandle::G1,
                ids.clone(),
                ids,
                TransferOptions::default(),
            )
            .unwrap();
        let earliest = worker.earliest_finish().unwrap();
        assert!(
            (earliest - 1.0).abs() < EPS,
            "expected 1.0 ms, got {earliest}"
        );
    }
}