use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use super::protocol::*;
use super::*;
use tokio::sync::mpsc;
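/// Warning emitted when a channel between scheduler components has been dropped,
/// which typically happens while the runtime is tearing down.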
pub const DISCONNECTED_WARNING: &str =
"runtime error: connections between components were lost; likely tearing down";
#[derive(Debug, thiserror::Error)]
pub enum SchedulerError {
#[error("runtime error: connections between components were lost; likely tearing down")]
Disconnected,
}
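/// Decision made by the scheduler for a transfer request: execute it now or cancel it.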
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchedulingDecision {
Execute,
Cancel,
}
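/// Client handle used by the transfer side to submit transfer requests to the
/// [`Scheduler`].
///
/// Scheduled requests resolve once the scheduler has made a decision; a sketch of the
/// calling pattern (request construction and error handling elided):
///
/// ```ignore
/// let handle = transfer_client.clone().schedule_transfer(request).await?;
/// if handle.scheduler_decision() == SchedulingDecision::Execute {
///     // ... perform the transfer ...
///     handle.mark_complete(Ok(())).await;
/// }
/// ```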
#[derive(Clone)]
pub struct TransferSchedulerClient {
scheduler_tx: mpsc::Sender<TransferToSchedulerMessage>,
}
impl TransferSchedulerClient {
pub fn new(scheduler_tx: mpsc::Sender<TransferToSchedulerMessage>) -> Self {
Self { scheduler_tx }
}
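    /// Submits a transfer request to the scheduler.
    ///
    /// Immediate requests get a completion handle straight away; scheduled requests are
    /// forwarded to the scheduler and this call resolves once a scheduling decision has
    /// been made.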
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.request_id, operation_id = %request.uuid))]
pub async fn schedule_transfer(
self,
request: LeaderTransferRequest,
) -> anyhow::Result<Box<dyn TransferCompletionHandle>> {
let scheduler_tx = self.scheduler_tx.clone();
match request.request_type {
RequestType::Immediate => {
let handle = ImmediateTransferCompletionHandle::new(
request.request_id,
request.uuid,
scheduler_tx.clone(),
);
Ok(Box::new(handle))
}
RequestType::Scheduled => {
let (response_tx, response_rx) = oneshot::channel();
let request = TransferScheduleRequest {
leader_request: request,
response_tx,
};
tracing::debug!("sending schedule request to scheduler");
scheduler_tx
.send(TransferToSchedulerMessage::ScheduleRequest(request))
.await?;
tracing::debug!("awaiting response from scheduler");
let handle = response_rx.await?.wait_for_decision().await;
tracing::debug!(
"received scheduler decision: {:?}",
handle.scheduler_decision()
);
Ok(handle)
}
}
}
}
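/// Client handle used by the worker to create and complete request slots and to report
/// iteration progress to the [`Scheduler`].
///
/// Sketch of the per-iteration flow (identifiers are illustrative; error handling
/// elided):
///
/// ```ignore
/// worker_client.create_slot("request-0".to_string())?;
/// worker_client.start_next_iteration()?;
/// worker_client.mark_layer_complete("layer.0".to_string())?;
/// worker_client.mark_iteration_complete()?;
/// ```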
pub struct WorkerSchedulerClient {
slots: HashMap<String, WorkerSchedulerClientSlot>,
scheduler_tx: mpsc::UnboundedSender<SchedulerMessage>,
iteration: u64,
iteration_complete: bool,
layers_complete: u32,
}
impl WorkerSchedulerClient {
pub fn new(
scheduler_tx: mpsc::UnboundedSender<SchedulerMessage>,
_cancel_token: CancellationToken,
) -> Self {
Self {
slots: HashMap::new(),
scheduler_tx,
iteration: 0,
iteration_complete: true,
layers_complete: 0,
}
}
pub fn iteration(&self) -> u64 {
self.iteration
}
pub fn start_next_iteration(&mut self) -> Result<(), SchedulerError> {
self.iteration += 1;
self.iteration_complete = false;
self.layers_complete = 0;
self.scheduler_tx
.send(SchedulerMessage::StartIteration(self.iteration))
.map_err(|_| SchedulerError::Disconnected)
}
pub fn mark_layer_complete(&mut self, layer_name: String) -> Result<(), SchedulerError> {
        debug_assert!(
            !self.iteration_complete,
            "an iteration must be in progress when marking a layer as complete"
        );
self.layers_complete += 1;
self.scheduler_tx
.send(SchedulerMessage::UpdateLayersCompleted(
layer_name,
self.layers_complete,
))
.map_err(|_| SchedulerError::Disconnected)
}
pub fn mark_iteration_complete(&mut self) -> Result<(), SchedulerError> {
        debug_assert!(
            !self.iteration_complete,
            "an iteration must be in progress before it can be marked complete"
        );
self.iteration_complete = true;
self.scheduler_tx
.send(SchedulerMessage::EndIteration(self.iteration))
.map_err(|_| SchedulerError::Disconnected)
}
}
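/// Worker-side view of a request slot: the operations recorded against it plus a shared
/// counter of completed operations that the scheduler increments.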
#[derive(Debug, Default)]
pub struct WorkerSchedulerClientSlot {
operations: Vec<uuid::Uuid>,
completed: Arc<AtomicU64>,
}
impl WorkerSchedulerClientSlot {
fn new() -> Self {
Self {
operations: Vec::new(),
completed: Arc::new(AtomicU64::new(0)),
}
}
fn make_scheduler_slot_request(
&self,
request_id: String,
expected_immediate_ops: u64,
) -> SchedulerCreateSlotDetails {
SchedulerCreateSlotDetails {
request_id,
completed: self.completed.clone(),
expected_immediate_ops,
}
}
pub fn is_complete(&self) -> bool {
self.completed.load(Ordering::Relaxed) == self.operations.len() as u64
}
}
impl WorkerSchedulerClient {
pub fn create_slot_with_immediate_ops(
&mut self,
request_id: String,
expected_immediate_ops: u64,
) -> Result<(), SchedulerError> {
let slot = WorkerSchedulerClientSlot::new();
let request = slot.make_scheduler_slot_request(request_id.clone(), expected_immediate_ops);
self.slots.insert(request_id.clone(), slot);
self.scheduler_tx
.send(SchedulerMessage::CreateSlot(request))
.map_err(|_| SchedulerError::Disconnected)?;
Ok(())
}
pub fn create_slot(&mut self, request_id: String) -> Result<(), SchedulerError> {
self.create_slot_with_immediate_ops(request_id, 0)
}
    pub fn remove_slot(&mut self, request_id: &str) {
        let slot = self.slots.remove(request_id).expect("slot does not exist");
        assert!(
            slot.is_complete(),
            "slot removed before all of its operations completed"
        );
        self.scheduler_tx
            .send(SchedulerMessage::RequestFinished(request_id.to_string()))
            .expect("failed to send request finished message; disconnected");
    }
pub fn enqueue_request(&mut self, request: WorkerTransferRequest) {
debug_assert!(
self.slots.contains_key(&request.request_id),
"slot does not exist"
);
let slot = self
.slots
.get_mut(&request.request_id)
.expect("slot does not exist");
slot.operations.push(request.uuid);
match request.request_type {
RequestType::Immediate => {}
RequestType::Scheduled => {
self.scheduler_tx
.send(SchedulerMessage::EnqueueRequest(request))
.expect("failed to enqueue request; disconnected");
}
}
}
pub fn has_slot(&self, request_id: &str) -> bool {
self.slots.contains_key(request_id)
}
pub fn is_complete(&self, request_id: &str) -> bool {
match self.slots.get(request_id) {
Some(slot) => slot.is_complete(),
None => true,
}
}
pub fn get_scheduler_tx(&self) -> mpsc::UnboundedSender<SchedulerMessage> {
self.scheduler_tx.clone()
}
pub fn record_operation(&mut self, request_id: &str, uuid: uuid::Uuid) {
let slot = self.slots.get_mut(request_id).expect("slot does not exist");
slot.operations.push(uuid);
}
}
pub type Iteration = u64;
pub type LayerName = String;
pub type LayerIndex = u32;
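/// Messages sent from the worker-side client to the scheduler task.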
pub enum SchedulerMessage {
CreateSlot(SchedulerCreateSlotDetails),
EnqueueRequest(WorkerTransferRequest),
StartIteration(Iteration),
EndIteration(Iteration),
UpdateLayersCompleted(LayerName, LayerIndex),
RequestFinished(String),
}
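/// Scheduler task that matches worker transfer requests with transfer-side schedule
/// requests, tracks iteration progress, and dispatches scheduled transfers.
///
/// Illustrative wiring (see the tests in this module for step-by-step usage):
///
/// ```ignore
/// let cancel_token = CancellationToken::new();
/// let (mut scheduler, worker_client, transfer_client) = Scheduler::new(cancel_token);
/// // Drive the scheduler on its own task; hand the clients to their components.
/// tokio::spawn(async move { scheduler.run().await });
/// ```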
pub struct Scheduler {
slots: HashMap<String, SchedulerSlot>,
cancel_tokens: HashMap<String, CancellationToken>,
unprocessed_immediate_results: HashMap<String, HashSet<uuid::Uuid>>,
enqueued_requests: HashMap<String, HashMap<uuid::Uuid, TransferRequestSource>>,
worker_rx: mpsc::UnboundedReceiver<SchedulerMessage>,
transfer_rx: mpsc::Receiver<TransferToSchedulerMessage>,
iteration: u64,
layers_complete: u32,
iteration_complete: bool,
}
impl Scheduler {
pub fn new(
cancel_token: CancellationToken,
) -> (Self, WorkerSchedulerClient, TransferSchedulerClient) {
let (scheduler_tx, scheduler_rx) = mpsc::unbounded_channel();
let (transfer_tx, transfer_rx) = mpsc::channel(128);
let worker_client = WorkerSchedulerClient::new(scheduler_tx, cancel_token);
let transfer_client = TransferSchedulerClient::new(transfer_tx);
(
Scheduler {
slots: HashMap::new(),
cancel_tokens: HashMap::new(),
unprocessed_immediate_results: HashMap::new(),
enqueued_requests: HashMap::new(),
worker_rx: scheduler_rx,
transfer_rx,
iteration: 0,
layers_complete: 0,
iteration_complete: true,
},
worker_client,
transfer_client,
)
}
pub async fn run(&mut self) -> anyhow::Result<()> {
loop {
if !self.step().await {
break;
}
}
Ok(())
}
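    /// Processes a single message from the worker or transfer channel, returning `false`
    /// once either channel has closed so the run loop can exit.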
async fn step(&mut self) -> bool {
if self.worker_rx.is_closed() || self.transfer_rx.is_closed() {
return false;
}
tokio::select! {
maybe_worker_msg = self.worker_rx.recv(), if !self.worker_rx.is_closed() => {
match maybe_worker_msg {
Some(SchedulerMessage::StartIteration(new_iteration)) => {
self.start_iteration(new_iteration);
}
Some(SchedulerMessage::EndIteration(iteration)) => {
self.end_iteration(iteration);
}
Some(SchedulerMessage::UpdateLayersCompleted(last_layer_name, layers_completed)) => {
self.update_layers_completed(last_layer_name, layers_completed);
}
Some(SchedulerMessage::CreateSlot(request)) => {
self.add_slot(request);
}
Some(SchedulerMessage::RequestFinished(request_id)) => {
self.remove_slot(request_id);
}
Some(SchedulerMessage::EnqueueRequest(request)) => {
self.handle_worker_request(request);
}
None => {
return false;
}
}
}
maybe_transfer_msg = self.transfer_rx.recv(), if !self.transfer_rx.is_closed() => {
match maybe_transfer_msg {
Some(TransferToSchedulerMessage::ScheduleRequest(request)) => {
self.handle_scheduled_transfer_request(request);
}
Some(TransferToSchedulerMessage::ImmediateResult(result)) => {
self.handle_immediate_result(result);
}
None => {
return false;
}
}
}
            else => {
                // Both channels closed between the check above and the select; exit cleanly
                // instead of panicking on a select with no enabled branches.
                return false;
            }
        }
true
}
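    /// Creates a scheduler slot for a request, folding any immediate results that arrived
    /// before the slot existed into its completion counter.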
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %req.request_id))]
fn add_slot(&mut self, req: SchedulerCreateSlotDetails) {
let request_id = req.request_id.clone();
let slot = SchedulerSlot {
completed: req.completed,
};
if let Some(buffered_results) = self.unprocessed_immediate_results.get(&request_id) {
let num_buffered = buffered_results.len() as u64;
debug_assert!(
num_buffered <= req.expected_immediate_ops,
"buffered results ({}) exceed expected immediate ops ({})",
num_buffered,
req.expected_immediate_ops
);
slot.completed.fetch_add(num_buffered, Ordering::Relaxed);
}
self.slots.insert(request_id, slot);
}
fn remove_slot(&mut self, request_id: String) {
debug_assert!(self.slots.contains_key(&request_id), "slot not found");
self.cancel_tokens.remove(&request_id);
self.slots.remove(&request_id);
        let maybe_pending = self.enqueued_requests.remove(&request_id);
        debug_assert!(
            maybe_pending.map_or(true, |pending| pending.is_empty()),
            "all scheduled requests must be matched and dispatched before the slot is removed"
        );
self.unprocessed_immediate_results.remove(&request_id);
tracing::debug!(
request_id,
iteration = self.iteration,
"engine state removing slot"
);
}
fn handle_worker_request(&mut self, request: WorkerTransferRequest) {
debug_assert!(
self.slots.contains_key(&request.request_id),
"slot does not exist"
);
let maybe_controller = self.try_prepare_controller(
request.request_id,
request.uuid,
TransferRequestSource::Worker,
);
if let Some(controller) = maybe_controller {
self.schedule_request(controller);
}
}
fn start_iteration(&mut self, iteration: u64) {
debug_assert_eq!(
self.iteration,
iteration - 1,
"iteration must be incremented by 1"
);
self.iteration = iteration;
self.layers_complete = 0;
self.iteration_complete = false;
}
    fn end_iteration(&mut self, iteration: u64) {
        tracing::debug!(iteration, "engine state ending iteration");
        self.iteration_complete = true;
    }
fn update_layers_completed(&mut self, last_layer_name: String, layers_completed: u32) {
self.layers_complete = layers_completed;
tracing::debug!(
iteration = self.iteration,
layers_completed,
"layer {last_layer_name} is complete"
);
}
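    /// Records completion of an immediate transfer, buffering the result if the slot has
    /// not been created yet.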
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %result.request_id, operation_id = %result.uuid))]
fn handle_immediate_result(&mut self, result: ImmediateTransferResult) {
match self.slots.get_mut(&result.request_id) {
Some(slot) => {
slot.completed.fetch_add(1, Ordering::Relaxed);
tracing::debug!(
"matched slot; incrementing completed counter to {}",
slot.completed.load(Ordering::Relaxed)
);
}
None => {
tracing::debug!("no slot found; adding to unprocessed immediate results");
self.unprocessed_immediate_results
.entry(result.request_id)
.or_default()
.insert(result.uuid);
}
}
}
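    /// Rendezvous between the worker and transfer halves of a scheduled operation.
    ///
    /// Whichever side arrives first is parked in `enqueued_requests`; when its peer
    /// arrives, the stored transfer controller is returned so the transfer can be
    /// dispatched.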
fn try_prepare_controller(
&mut self,
request_id: String,
uuid: uuid::Uuid,
incoming: TransferRequestSource,
) -> Option<ScheduledTaskController> {
let entry = self.enqueued_requests.entry(request_id).or_default();
match (entry.remove(&uuid), incoming) {
(Some(TransferRequestSource::Worker), TransferRequestSource::Transfer(controller)) => {
tracing::debug!("worker arrived first, then transfer ==> scheduling transfer");
Some(controller)
}
(Some(TransferRequestSource::Transfer(controller)), TransferRequestSource::Worker) => {
tracing::debug!("transfer arrived first, then worker ==> scheduling transfer");
Some(controller)
}
(None, TransferRequestSource::Worker) => {
tracing::debug!("worker arrived first; must wait for transfer");
entry.insert(uuid, TransferRequestSource::Worker);
None
}
(None, TransferRequestSource::Transfer(controller)) => {
tracing::debug!("transfer arrived first; must wait for worker");
entry.insert(uuid, TransferRequestSource::Transfer(controller));
None
}
            _ => {
                panic!("duplicate request for operation {uuid} from the same source");
            }
}
}
#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.leader_request.request_id))]
fn handle_scheduled_transfer_request(&mut self, request: TransferScheduleRequest) {
        let Ok(controller) = self.process_scheduled_transfer_request(request) else {
            tracing::warn!("transfer client disconnected before a schedule handle could be delivered");
            return;
        };
let maybe_controller = self.try_prepare_controller(
controller.request.request_id.clone(),
controller.request.uuid,
TransferRequestSource::Transfer(controller),
);
if let Some(controller) = maybe_controller {
tracing::debug!("scheduling transfer");
self.schedule_request(controller);
}
}
fn schedule_request(&mut self, xfer_req: ScheduledTaskController) {
self.execute_scheduled_transfer(xfer_req);
}
fn execute_scheduled_transfer(&mut self, xfer_req: ScheduledTaskController) {
debug_assert!(
self.slots.contains_key(&xfer_req.request.request_id),
"slot not found"
);
let completed = self
.slots
.get(&xfer_req.request.request_id)
.unwrap()
.completed
.clone();
tokio::spawn(xfer_req.execute(SchedulingDecision::Execute, completed));
}
fn process_scheduled_transfer_request(
&mut self,
xfer_req: TransferScheduleRequest,
) -> anyhow::Result<ScheduledTaskController> {
let (decision_tx, decision_rx) = oneshot::channel();
let cancel_token = self
.cancel_tokens
.entry(xfer_req.leader_request.request_id.clone())
.or_default()
.child_token();
let task_handle = ScheduledTaskHandle {
decision_rx,
cancel_token,
};
xfer_req
.response_tx
.send(task_handle)
.map_err(|_| anyhow::anyhow!("Failed to send scheduled task handle to xfer client"))?;
let controller = ScheduledTaskController {
request: xfer_req.leader_request,
decision_tx,
};
Ok(controller)
}
}
#[derive(Debug, thiserror::Error)]
pub enum ScheduledTaskError {}
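/// Scheduler-held controller for a scheduled transfer: it delivers the scheduling
/// decision to the transfer task and waits for the completion signal before bumping the
/// slot's completion counter.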
pub struct ScheduledTaskController {
request: LeaderTransferRequest,
decision_tx: oneshot::Sender<(SchedulingDecision, oneshot::Sender<anyhow::Result<()>>)>,
}
impl ScheduledTaskController {
pub async fn execute(
self,
decision: SchedulingDecision,
completed: Arc<AtomicU64>,
) -> anyhow::Result<()> {
let (completion_tx, completion_rx) = oneshot::channel();
self.decision_tx
.send((decision, completion_tx))
.map_err(|_| anyhow::anyhow!(DISCONNECTED_WARNING))?;
let _ = completion_rx
.await
.map_err(|_| anyhow::anyhow!(DISCONNECTED_WARNING))?;
completed.fetch_add(1, Ordering::Relaxed);
Ok(())
}
}
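/// Records which side of a scheduled operation arrived first while it waits for its peer.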
enum TransferRequestSource {
Worker,
Transfer(ScheduledTaskController),
}
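/// Awaitable handle resolved when a scheduled task reports completion.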
pub struct ScheduledTaskAsyncResult {
completion_rx: oneshot::Receiver<anyhow::Result<()>>,
}
impl ScheduledTaskAsyncResult {
    pub async fn await_completion(self) -> anyhow::Result<()> {
        self.completion_rx
            .await
            .map_err(|_| anyhow::anyhow!(DISCONNECTED_WARNING))?
    }
}
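/// Details the scheduler needs to create a slot: the request id, the shared completion
/// counter, and the number of immediate operations expected for the request.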
pub struct SchedulerCreateSlotDetails {
pub request_id: String,
pub completed: Arc<AtomicU64>,
pub expected_immediate_ops: u64,
}
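/// Scheduler-side view of a request slot; shares the completion counter with the worker
/// client.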
pub struct SchedulerSlot {
completed: Arc<AtomicU64>,
}
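/// Trait for components that need to be notified when a new iteration starts.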
pub trait TaskScheduler {
fn start_iteration(&mut self, iteration: u64) -> Result<(), SchedulerError>;
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_scheduler_lifecycle() {
let cancel_token = CancellationToken::new();
let (mut scheduler, mut worker_client, _transfer_client) = Scheduler::new(cancel_token);
worker_client.create_slot("test".to_string()).unwrap();
assert!(!scheduler.slots.contains_key("test"));
scheduler.step().await;
assert!(scheduler.slots.contains_key("test"));
worker_client.start_next_iteration().unwrap();
scheduler.step().await;
assert_eq!(scheduler.iteration, 1);
worker_client.mark_iteration_complete().unwrap();
scheduler.step().await;
assert_eq!(scheduler.iteration, 1);
assert!(scheduler.iteration_complete);
}
#[tokio::test]
async fn test_transfer_immediate_arrives_first() {
dynamo_runtime::logging::init();
let cancel_token = CancellationToken::new();
let (mut scheduler, mut worker_client, transfer_client) = Scheduler::new(cancel_token);
let operation_id = uuid::Uuid::new_v4();
let request = LeaderTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
requirement: None,
request_type: RequestType::Immediate,
};
let handle = transfer_client
.clone()
.schedule_transfer(request)
.await
.unwrap();
assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute);
handle.mark_complete(Ok(())).await;
assert_eq!(scheduler.unprocessed_immediate_results.len(), 0);
scheduler.step().await;
assert_eq!(scheduler.unprocessed_immediate_results.len(), 1);
worker_client
.create_slot_with_immediate_ops("test".to_string(), 1)
.unwrap();
assert!(!scheduler.slots.contains_key("test"));
scheduler.step().await;
assert!(scheduler.slots.contains_key("test"));
assert_eq!(scheduler.unprocessed_immediate_results.len(), 1);
assert_eq!(
scheduler
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(
worker_client
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(worker_client.slots.get("test").unwrap().operations.len(), 0);
let worker_request = WorkerTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
transfer_type: TransferType::Load,
request_type: RequestType::Immediate,
};
worker_client.enqueue_request(worker_request);
assert_eq!(worker_client.slots.get("test").unwrap().operations.len(), 1);
assert!(worker_client.is_complete("test"));
assert_eq!(scheduler.unprocessed_immediate_results.len(), 1);
        worker_client.remove_slot("test");
scheduler.step().await;
assert_eq!(scheduler.unprocessed_immediate_results.len(), 0);
assert!(!scheduler.slots.contains_key("test"));
}
#[tokio::test]
async fn test_transfer_immediate_arrives_last() {
dynamo_runtime::logging::init();
let cancel_token = CancellationToken::new();
let (mut scheduler, mut worker_client, transfer_client) = Scheduler::new(cancel_token);
let operation_id = uuid::Uuid::new_v4();
let request = LeaderTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
requirement: None,
request_type: RequestType::Immediate,
};
let handle = transfer_client
.clone()
.schedule_transfer(request)
.await
.unwrap();
assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute);
worker_client.create_slot("test".to_string()).unwrap();
assert!(!scheduler.slots.contains_key("test"));
scheduler.step().await;
assert!(scheduler.slots.contains_key("test"));
assert_eq!(scheduler.unprocessed_immediate_results.len(), 0);
let request = WorkerTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
transfer_type: TransferType::Load,
request_type: RequestType::Immediate,
};
worker_client.enqueue_request(request);
let worker_slot = worker_client.slots.get("test").unwrap();
assert_eq!(worker_slot.operations.len(), 1);
assert_eq!(worker_slot.completed.load(Ordering::Relaxed), 0);
handle.mark_complete(Ok(())).await;
assert_eq!(scheduler.unprocessed_immediate_results.len(), 0);
scheduler.step().await;
assert_eq!(scheduler.unprocessed_immediate_results.len(), 0);
assert_eq!(
scheduler
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(
worker_client
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(worker_client.slots.get("test").unwrap().operations.len(), 1);
}
#[tokio::test]
async fn test_transfer_scheduled_arrives_first() {
dynamo_runtime::logging::init();
let cancel_token = CancellationToken::new();
let (mut scheduler, mut worker_client, transfer_client) = Scheduler::new(cancel_token);
let operation_id = uuid::Uuid::new_v4();
let request = LeaderTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
requirement: None,
request_type: RequestType::Scheduled,
};
let handle = tokio::spawn(transfer_client.schedule_transfer(request));
scheduler.step().await;
assert_eq!(scheduler.enqueued_requests.get("test").unwrap().len(), 1);
assert!(matches!(
scheduler
.enqueued_requests
.get("test")
.unwrap()
.get(&operation_id),
Some(TransferRequestSource::Transfer(_))
));
worker_client.create_slot("test".to_string()).unwrap();
assert!(!scheduler.slots.contains_key("test"));
scheduler.step().await;
assert!(scheduler.slots.contains_key("test"));
let request = WorkerTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
transfer_type: TransferType::Store,
request_type: RequestType::Scheduled,
};
worker_client.enqueue_request(request);
scheduler.step().await;
let handle = handle.await.unwrap().unwrap();
handle.mark_complete(Ok(())).await;
assert_eq!(scheduler.enqueued_requests.get("test").unwrap().len(), 0);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(
worker_client
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(
scheduler
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert!(worker_client.slots.get("test").unwrap().is_complete());
}
#[tokio::test]
async fn test_transfer_scheduled_arrives_last() {
dynamo_runtime::logging::init();
let cancel_token = CancellationToken::new();
let (mut scheduler, mut worker_client, transfer_client) = Scheduler::new(cancel_token);
let operation_id = uuid::Uuid::new_v4();
worker_client.create_slot("test".to_string()).unwrap();
assert!(!scheduler.slots.contains_key("test"));
scheduler.step().await;
assert!(scheduler.slots.contains_key("test"));
let request = WorkerTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
transfer_type: TransferType::Store,
request_type: RequestType::Scheduled,
};
worker_client.enqueue_request(request);
scheduler.step().await;
assert_eq!(scheduler.enqueued_requests.get("test").unwrap().len(), 1);
assert!(matches!(
scheduler
.enqueued_requests
.get("test")
.unwrap()
.get(&operation_id),
Some(TransferRequestSource::Worker)
));
let request = LeaderTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
requirement: None,
request_type: RequestType::Scheduled,
};
let handle = tokio::spawn(transfer_client.schedule_transfer(request));
scheduler.step().await;
let handle = handle.await.unwrap().unwrap();
assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute);
handle.mark_complete(Ok(())).await;
assert_eq!(scheduler.enqueued_requests.get("test").unwrap().len(), 0);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
assert_eq!(
worker_client
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert_eq!(
scheduler
.slots
.get("test")
.unwrap()
.completed
.load(Ordering::Relaxed),
1
);
assert!(worker_client.slots.get("test").unwrap().is_complete());
}
#[tokio::test]
async fn test_coordinate_scheduled_transfer_execution() {
dynamo_runtime::logging::init();
let cancel_token = CancellationToken::new();
let (mut scheduler, _worker_client, transfer_client) = Scheduler::new(cancel_token);
let operation_id = uuid::Uuid::new_v4();
let request = LeaderTransferRequest {
request_id: "test".to_string(),
uuid: operation_id,
requirement: None,
request_type: RequestType::Scheduled,
};
let (got_handle_tx, got_handle_rx) = oneshot::channel();
let _transfer_task = tokio::spawn(async move {
let handle = transfer_client
.clone()
.schedule_transfer(request)
.await
.unwrap();
got_handle_tx
.send(handle)
.map_err(|_| {
anyhow::anyhow!("failed to send handle back on testing oneshot channel")
})
.unwrap();
});
assert!(got_handle_rx.is_empty());
let controller = match scheduler.transfer_rx.recv().await {
Some(msg) => match msg {
TransferToSchedulerMessage::ScheduleRequest(schedule_req) => scheduler
.process_scheduled_transfer_request(schedule_req)
.ok(),
_ => {
unreachable!("unexpected message type");
}
},
None => {
unreachable!("channel closed");
}
};
let scheduler_controller = controller.expect("Expected a controller from the scheduler");
assert!(got_handle_rx.is_empty());
let completed = Arc::new(AtomicU64::new(0));
let scheduler_result = tokio::spawn(
scheduler_controller.execute(SchedulingDecision::Execute, completed.clone()),
);
let transfer_handle = got_handle_rx.await.unwrap();
assert_eq!(
transfer_handle.scheduler_decision(),
SchedulingDecision::Execute
);
transfer_handle.mark_complete(Ok(())).await;
scheduler_result.await.unwrap().unwrap();
assert_eq!(completed.load(Ordering::Relaxed), 1);
}
}