use crate::scheduler::ExecutionTask;
use crate::scheduler::ExecutionTaskCompletionNotification;
use crate::scheduler::ExecutionTaskCompletionNotifier;
use std::sync::mpsc::{Receiver, Sender};
use super::core::CoreMessage;
/// Iterator over the execution tasks handed out by the scheduler core.
///
/// Each call to `next` sends a `CoreMessage::Next` request on `tx` and then
/// blocks on `rx` for the reply; a reply of `None` ends the iteration.
pub struct SerialExecutionTaskIterator {
    /// Channel used to ask the scheduler core for the next task.
    tx: Sender<CoreMessage>,
    /// Channel on which the scheduler core replies with a task, or `None` when finished.
    rx: Receiver<Option<ExecutionTask>>,
    /// Set once `next` has produced `None`; later calls return `None` without contacting the core.
    is_complete: bool,
}

impl SerialExecutionTaskIterator {
    /// Builds an iterator that requests tasks over `tx` and receives them on `rx`.
    pub fn new(tx: Sender<CoreMessage>, rx: Receiver<Option<ExecutionTask>>) -> Self {
        Self {
            is_complete: false,
            tx,
            rx,
        }
    }
}
impl Iterator for SerialExecutionTaskIterator {
type Item = ExecutionTask;
fn next(&mut self) -> Option<ExecutionTask> {
if self.is_complete {
debug!(
"Execution task iterator already returned `None`; `next` should not be called again"
);
return None;
}
match self.tx.send(CoreMessage::Next) {
Ok(_) => match self.rx.recv() {
Ok(task) => {
self.is_complete = task.is_none();
task
}
Err(_) => {
error!(
"Failed to receive next execution task; scheduler shutdown unexpectedly"
);
self.is_complete = true;
None
}
},
Err(_) => {
trace!("Scheduler core message receiver dropped; checking if it shutdown properly");
match self.rx.recv() {
Ok(Some(_)) => error!(
"Scheduler sent unexpected execution task before shutting down unexpectedly"
),
Ok(None) => {}
_ => error!(
"Failed to request next execution task; scheduler shutdown unexpectedly"
),
}
self.is_complete = true;
None
}
}
}
}
/// Reports execution task completions back to the scheduler core.
#[derive(Clone)]
pub struct SerialExecutionTaskCompletionNotifier {
    /// Channel to the scheduler core.
    tx: Sender<CoreMessage>,
}

impl SerialExecutionTaskCompletionNotifier {
    /// Builds a notifier that forwards completion results over `tx`.
    pub fn new(tx: Sender<CoreMessage>) -> Self {
        Self { tx }
    }
}
impl ExecutionTaskCompletionNotifier for SerialExecutionTaskCompletionNotifier {
fn notify(&self, notification: ExecutionTaskCompletionNotification) {
self.tx
.send(CoreMessage::ExecutionResult(notification))
.unwrap_or_else(|err| error!("failed to send notification to core: {}", err));
}
fn clone_box(&self) -> Box<dyn ExecutionTaskCompletionNotifier> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{mpsc::channel, Arc, Mutex};

    use cylinder::{secp256k1::Secp256k1Context, Context, Signer};
    use log::{set_boxed_logger, set_max_level, Level, LevelFilter, Log, Metadata, Record};
    use rusty_fork::rusty_fork_test;

    use crate::context::ContextId;
    use crate::protocol::transaction::{HashMethod, TransactionBuilder};

    // Every test installs a global logger through `init_logger`, and
    // `set_boxed_logger` fails once a logger is already installed, so each test
    // runs in its own forked process.
    rusty_fork_test! {
        /// A task followed by `None` is yielded in order without any error logs.
        #[test]
        fn task_iterator_successful() {
            let logger = init_logger();
            let (core_tx, core_rx) = channel();
            let (task_tx, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                let mut iter = SerialExecutionTaskIterator::new(core_tx, task_rx);
                (iter.next(), iter.next())
            });
            recv_next(&core_rx);
            task_tx
                .send(Some(mock_execution_task()))
                .expect("Failed to send execution task");
            recv_next(&core_rx);
            task_tx.send(None).expect("Failed to send `None`");
            let (task1, task2) = join_handle.join().expect("Iterator thread panicked");
            assert!(task1.is_some());
            assert!(task2.is_none());
            assert!(!logger.has_err());
        }

        /// After `None`, further `next` calls return `None`, log at debug level,
        /// and do not send another request to the core.
        #[test]
        fn task_iterator_multiple_nones() {
            let logger = init_logger();
            let (core_tx, core_rx) = channel();
            let (task_tx, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                let mut iter = SerialExecutionTaskIterator::new(core_tx, task_rx);
                (iter.next(), iter.next(), iter.next())
            });
            recv_next(&core_rx);
            task_tx
                .send(Some(mock_execution_task()))
                .expect("Failed to send execution task");
            recv_next(&core_rx);
            task_tx.send(None).expect("Failed to send `None`");
            let (task1, task2, task3) = join_handle.join().expect("Iterator thread panicked");
            // Check only after the iterator thread has finished. By then every message
            // it could have sent is already queued, so a stray request from the third
            // `next` call cannot slip in after this check.
            core_rx.try_recv().expect_err("Got an unexpected task request");
            assert!(task1.is_some());
            assert!(task2.is_none());
            assert!(task3.is_none());
            assert!(logger.has_debug());
        }

        /// Core receiver dropped but the scheduler left `None`: clean shutdown, no error logs.
        #[test]
        fn task_iterator_send_failed_but_shutdown_properly() {
            let logger = init_logger();
            let (core_tx, _) = channel();
            let (task_tx, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                SerialExecutionTaskIterator::new(core_tx, task_rx).next()
            });
            task_tx.send(None).expect("Failed to send `None`");
            let task = join_handle.join().expect("Iterator thread panicked");
            assert!(task.is_none());
            assert!(!logger.has_err());
        }

        /// Core receiver dropped but a task was still queued: returns `None` and logs an error.
        #[test]
        fn task_iterator_send_failed_with_unexpected_task() {
            let logger = init_logger();
            let (core_tx, _) = channel();
            let (task_tx, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                SerialExecutionTaskIterator::new(core_tx, task_rx).next()
            });
            task_tx.send(Some(mock_execution_task())).expect("Failed to send task");
            let task = join_handle.join().expect("Iterator thread panicked");
            assert!(task.is_none());
            assert!(logger.has_err());
        }

        /// Both channels disconnected: returns `None` and logs an error.
        #[test]
        fn task_iterator_send_failed_no_notification() {
            let logger = init_logger();
            let (core_tx, _) = channel();
            let (_, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                SerialExecutionTaskIterator::new(core_tx, task_rx).next()
            });
            let task = join_handle.join().expect("Iterator thread panicked");
            assert!(task.is_none());
            assert!(logger.has_err());
        }

        /// Request sent but the task channel is disconnected: returns `None` and logs an error.
        #[test]
        fn task_iterator_send_successful_but_receive_failed() {
            let logger = init_logger();
            let (core_tx, _core_rx) = channel();
            let (_, task_rx) = channel();
            let join_handle = std::thread::spawn(move || {
                SerialExecutionTaskIterator::new(core_tx, task_rx).next()
            });
            let task = join_handle.join().expect("Iterator thread panicked");
            assert!(task.is_none());
            assert!(logger.has_err());
        }
    }

    /// Blocks until the iterator sends a message to the core and asserts it is `CoreMessage::Next`.
    fn recv_next(core_rx: &Receiver<CoreMessage>) {
        match core_rx.recv() {
            Ok(CoreMessage::Next) => {}
            res => panic!("Expected `Ok(CoreMessage::Next)`, got {:?} instead", res),
        }
    }

    /// Builds a minimal execution task around a freshly signed, empty test transaction.
    fn mock_execution_task() -> ExecutionTask {
        ExecutionTask {
            pair: TransactionBuilder::new()
                .with_family_name("test".into())
                .with_family_version("0.1".into())
                .with_inputs(vec![])
                .with_outputs(vec![])
                .with_payload_hash_method(HashMethod::Sha512)
                .with_payload(vec![])
                .build_pair(&*new_signer())
                .expect("Failed to build txn pair"),
            context_id: ContextId::default(),
        }
    }

    /// Creates a signer backed by a new random secp256k1 private key.
    fn new_signer() -> Box<dyn Signer> {
        let context = Secp256k1Context::new();
        let key = context.new_random_private_key();
        context.new_signer(key)
    }

    /// Installs a `MockLogger` as the global logger (Debug level and above) and returns a handle to it.
    fn init_logger() -> MockLogger {
        let logger = MockLogger::default();
        set_boxed_logger(Box::new(logger.clone())).expect("Failed to set logger");
        set_max_level(LevelFilter::Debug);
        logger
    }

    /// Logger that records only the level of each log record; clones share the same record list.
    #[derive(Clone, Default)]
    struct MockLogger {
        log_levels: Arc<Mutex<Vec<Level>>>,
    }

    impl MockLogger {
        /// Returns `true` if any record was logged at `Level::Error`.
        pub fn has_err(&self) -> bool {
            self.log_levels
                .lock()
                .expect("Failed to get log_levels lock")
                .iter()
                .any(|level| level == &Level::Error)
        }

        /// Returns `true` if any record was logged at `Level::Debug`.
        pub fn has_debug(&self) -> bool {
            self.log_levels
                .lock()
                .expect("Failed to get log_levels lock")
                .iter()
                .any(|level| level == &Level::Debug)
        }
    }

    impl Log for MockLogger {
        fn enabled(&self, _metadata: &Metadata) -> bool {
            true
        }

        fn log(&self, record: &Record) {
            self.log_levels
                .lock()
                .expect("Failed to get log_levels lock")
                .push(record.level());
        }

        fn flush(&self) {}
    }
}