use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{thread, time};
use reqwest::{blocking::Client, header, StatusCode};
use crate::protos::IntoBytes;
use super::error::WorkloadRunnerError;
use super::BatchWorkload;
use super::ExpectedBatchResult;
/// Shared map of outstanding batch status links. Each entry maps the link returned by a batch
/// submission to the target the batch was submitted to and the optional expected result.
type ExpectedBatchResults = Arc<Mutex<HashMap<String, (String, Option<ExpectedBatchResult>)>>>;

/// Runs a set of named batch workloads, each on its own worker thread.
#[derive(Default)]
pub struct WorkloadRunner {
    // Active workers keyed by workload ID.
    workloads: HashMap<String, Worker>,
}
impl WorkloadRunner {
    /// Starts a new worker for `workload` and registers it under `id`.
    ///
    /// # Errors
    ///
    /// Returns `WorkloadRunnerError::WorkloadAddError` if a workload is already registered under
    /// `id`, or if the worker could not be built or started.
    #[allow(clippy::too_many_arguments)] // mirrors the worker configuration one-to-one
    pub fn add_workload(
        &mut self,
        id: String,
        workload: Box<dyn BatchWorkload>,
        targets: Vec<String>,
        time_to_wait: Duration,
        auth: String,
        get_batch_status: bool,
        duration: Option<Duration>,
        request_counter: Arc<HttpRequestCounter>,
    ) -> Result<(), WorkloadRunnerError> {
        if self.workloads.contains_key(&id) {
            return Err(WorkloadRunnerError::WorkloadAddError(format!(
                "Workload already running with ID: {}",
                id,
            )));
        }
        let worker = WorkerBuilder::default()
            .with_id(id.clone())
            .with_workload(workload)
            .with_targets(targets)
            .with_time_to_wait(time_to_wait)
            .with_auth(auth)
            .get_batch_status(get_batch_status)
            .with_duration(duration)
            .with_request_counter(request_counter)
            .build()?;
        self.workloads.insert(id, worker);
        Ok(())
    }

    /// Stops the worker registered under `id` and waits for its threads to finish.
    ///
    /// A worker whose thread has already exited on its own (for example because its duration
    /// elapsed or it hit an error) can no longer receive the shutdown message. Such a worker is
    /// still joined and cleaned up rather than being left behind with an error.
    ///
    /// # Errors
    ///
    /// Returns `WorkloadRunnerError::WorkloadRemoveError` if no workload is registered under
    /// `id` or the worker thread panicked. Errors from stopping the batch status checker are
    /// propagated.
    pub fn remove_workload(&mut self, id: &str) -> Result<(), WorkloadRunnerError> {
        let mut worker = self.workloads.remove(id).ok_or_else(|| {
            WorkloadRunnerError::WorkloadRemoveError(format!(
                "Workload with ID {} does not exist",
                id,
            ))
        })?;
        debug!("Shutting down worker {}", worker.id);
        // The worker's receiver is owned by its thread, so a failed send only means that thread
        // has already stopped. Joining below is then immediate and still required for cleanup.
        if worker.sender.send(ShutdownMessage).is_err() {
            debug!("Worker {} has already stopped", worker.id);
        }
        if let Some(thread) = worker.thread.take() {
            thread.join().map_err(|err| {
                WorkloadRunnerError::WorkloadRemoveError(format!(
                    "Failed to cleanly join worker thread {}: {:?}",
                    id, err,
                ))
            })?;
        }
        if let Some(mut batch_status_checker) = worker.batch_status_checker {
            batch_status_checker.remove_batch_status_checker()?;
        }
        Ok(())
    }

    /// Returns a handle that can signal every currently registered worker to shut down.
    ///
    /// Workers added after this call are not covered by the returned handle.
    pub fn shutdown_signaler(&self) -> WorkerShutdownSignaler {
        WorkerShutdownSignaler {
            senders: self
                .workloads
                .values()
                .map(|worker| (worker.id.clone(), worker.sender.clone()))
                .collect(),
        }
    }

    /// Consumes the runner and waits for every worker and its batch status checker to finish.
    ///
    /// This does not ask workers to stop. They end when their duration elapses, on an error, or
    /// after a shutdown message (see [`WorkloadRunner::shutdown_signaler`]).
    ///
    /// # Errors
    ///
    /// Returns `WorkloadRunnerError::WorkloadAddError` if a worker or batch status checker
    /// thread panicked.
    pub fn wait_for_shutdown(self) -> Result<(), WorkloadRunnerError> {
        for (_, mut worker) in self.workloads {
            // Join the worker first: on exit it tells its batch status checker to stop, so the
            // checker has been signaled by the time it is joined. If the worker panicked, we
            // return here instead of waiting forever on a checker that nobody will stop.
            if let Some(thread) = worker.thread.take() {
                thread.join().map_err(|_| {
                    WorkloadRunnerError::WorkloadAddError(
                        "Failed to join worker thread".to_string(),
                    )
                })?;
            }
            if let Some(mut batch_status_checker) = worker.batch_status_checker {
                if let Some(thread) = batch_status_checker.thread.take() {
                    thread.join().map_err(|_| {
                        WorkloadRunnerError::WorkloadAddError(
                            "Failed to join batch status checker thread".to_string(),
                        )
                    })?;
                }
            }
        }
        Ok(())
    }
}
/// Handle for asking every worker of a `WorkloadRunner` to shut down. It holds its own clones of
/// the workers' shutdown senders, so it can be used independently of the runner.
pub struct WorkerShutdownSignaler {
    // (workload ID, shutdown channel sender) for each worker captured at creation.
    senders: Vec<(String, Sender<ShutdownMessage>)>,
}

impl WorkerShutdownSignaler {
    /// Sends a shutdown message to every worker.
    ///
    /// Every worker is signaled even if an earlier send fails. A single worker whose thread has
    /// already exited therefore cannot leave the remaining workers running.
    ///
    /// # Errors
    ///
    /// Returns `WorkloadRunnerError::WorkloadRemoveError` naming the workers the message could
    /// not be delivered to. A send fails only when that worker's receiver has been dropped.
    pub fn signal_shutdown(&self) -> Result<(), WorkloadRunnerError> {
        let mut undelivered = Vec::new();
        for (id, sender) in &self.senders {
            debug!("Shutting down worker {}", id);
            if sender.send(ShutdownMessage).is_err() {
                undelivered.push(id.as_str());
            }
        }
        if undelivered.is_empty() {
            Ok(())
        } else {
            Err(WorkloadRunnerError::WorkloadRemoveError(format!(
                "Failed to send shutdown message to worker(s): {}",
                undelivered.join(", ")
            )))
        }
    }
}
/// Message sent over a channel to ask a worker or batch status checker thread to stop.
struct ShutdownMessage;

/// A running workload: its thread, the channel used to stop it, and its optional batch status
/// checker.
struct Worker {
    // Workload ID; also used as the worker thread's name.
    id: String,
    // Worker thread handle; `None` once it has been taken for joining.
    thread: Option<thread::JoinHandle<()>>,
    // Sending half of the worker's shutdown channel.
    sender: Sender<ShutdownMessage>,
    // Present only when batch status checking is enabled for this workload.
    batch_status_checker: Option<BatchStatusChecker>,
}
/// Builder for a [`Worker`].
///
/// `get_batch_status` (default `false`) and `duration` (default: no time limit) are optional;
/// every other field must be set before `build`.
#[derive(Default)]
struct WorkerBuilder {
    id: Option<String>,
    workload: Option<Box<dyn BatchWorkload>>,
    targets: Option<Vec<String>>,
    time_to_wait: Option<Duration>,
    auth: Option<String>,
    get_batch_status: Option<bool>,
    duration: Option<Duration>,
    request_counter: Option<Arc<HttpRequestCounter>>,
}
impl WorkerBuilder {
pub fn with_id(mut self, id: String) -> WorkerBuilder {
self.id = Some(id);
self
}
pub fn with_workload(mut self, workload: Box<dyn BatchWorkload>) -> WorkerBuilder {
self.workload = Some(workload);
self
}
pub fn with_targets(mut self, targets: Vec<String>) -> WorkerBuilder {
self.targets = Some(targets);
self
}
pub fn with_time_to_wait(mut self, time_to_wait: Duration) -> WorkerBuilder {
self.time_to_wait = Some(time_to_wait);
self
}
pub fn with_auth(mut self, auth: String) -> WorkerBuilder {
self.auth = Some(auth);
self
}
pub fn with_duration(mut self, duration: Option<Duration>) -> WorkerBuilder {
self.duration = duration;
self
}
pub fn get_batch_status(mut self, get_batch_status: bool) -> WorkerBuilder {
self.get_batch_status = Some(get_batch_status);
self
}
pub fn with_request_counter(
mut self,
request_counter: Arc<HttpRequestCounter>,
) -> WorkerBuilder {
self.request_counter = Some(request_counter);
self
}
pub fn build(self) -> Result<Worker, WorkloadRunnerError> {
let id = self.id.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `id`".to_string(),
)
})?;
let time_to_wait = self.time_to_wait.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `time_to_wait`".to_string(),
)
})?;
let workload = self.workload.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `workload`".to_string(),
)
})?;
let targets = self.targets.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `target`".to_string(),
)
})?;
let auth = self.auth.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `auth`".to_string(),
)
})?;
let http_counter = self.request_counter.ok_or_else(|| {
WorkloadRunnerError::WorkloadAddError(
"unable to build, missing field: `request_counter`".to_string(),
)
})?;
let get_batch_status = self.get_batch_status.unwrap_or(false);
let end_time = self.duration.map(|d| time::Instant::now() + d);
let (sender, receiver) = channel();
let (batch_status_checker_sender, batch_status_checker_receiver) = channel();
let batch_status_checker_shutdown = batch_status_checker_sender.clone();
let batch_status_links: ExpectedBatchResults = Arc::new(Mutex::new(HashMap::new()));
let batch_status_checker = if get_batch_status {
Some(BatchStatusChecker::new(
batch_status_links.clone(),
auth.clone(),
id.clone(),
sender.clone(),
batch_status_checker_sender,
batch_status_checker_receiver,
)?)
} else {
None
};
let thread_id = id.to_string();
let thread = Some(
thread::Builder::new()
.name(id.to_string())
.spawn(move || {
let mut next_target = 0;
let mut workload = workload;
let mut start_time = time::Instant::now();
let mut submitted_batches = 0;
let mut submission_start = time::Instant::now();
let mut submission_avg: Option<time::Duration> = None;
loop {
if let Some(end_time) = end_time {
if time::Instant::now() > end_time {
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
}
match receiver.try_recv() {
Ok(_) => {
info!("Worker received shutdown");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
Err(TryRecvError::Empty) => {
let target = match targets.get(next_target) {
Some(target) => target,
None => {
error!("No targets provided");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
};
let (batch, expected_result) = match workload.next_batch() {
Ok((batch, expected_result)) => (batch, expected_result),
Err(_) => {
error!("Failed to get next batch");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
};
let batch_bytes = match vec![batch.batch().clone()].into_bytes() {
Ok(bytes) => bytes,
Err(err) => {
error!("Unable to get batch bytes {}", err);
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
};
match submit_batch(target, &auth, batch_bytes.clone()) {
Ok(link) => {
if get_batch_status {
match batch_status_links.lock() {
Ok(mut l) => l.insert(
link,
(target.to_string(), expected_result),
),
Err(_) => {
error!("ExpectedBatchResults lock poisoned");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
};
}
submitted_batches += 1;
http_counter.increment_sent()
}
Err(err) => {
if err == WorkloadRunnerError::TooManyRequests {
http_counter.increment_queue_full();
match slow_rate(
target,
&auth,
batch_bytes.clone(),
start_time,
submitted_batches,
&receiver,
end_time,
) {
Ok((true, _)) => {
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
Ok((false, Some(link))) => {
if get_batch_status {
match batch_status_links.lock() {
Ok(mut l) => l.insert(
link,
(target.to_string(), expected_result),
),
Err(_) => {
error!("ExpectedBatchResults lock poisoned");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
};
}
submitted_batches = 1;
start_time = time::Instant::now();
http_counter.increment_sent()
}
Ok((false, None)) => {
if get_batch_status {
error!("Failed to get batch status link");
}
submitted_batches = 1;
start_time = time::Instant::now();
http_counter.increment_sent()
}
Err(err) => error!("{}:{}", thread_id, err),
}
} else {
error!("{}:{}", thread_id, err);
}
}
}
next_target = (next_target + 1) % targets.len();
let diff = time::Instant::now() - submission_start;
let submission_time = match submission_avg {
Some(val) => (diff + val) / 2,
None => diff,
};
submission_avg = Some(submission_time);
let wait_time = time_to_wait.saturating_sub(submission_time);
thread::sleep(wait_time);
submission_start = time::Instant::now();
}
Err(TryRecvError::Disconnected) => {
error!("Channel has disconnected");
signal_batch_status_checker_shutdown(
batch_status_checker_shutdown,
thread_id,
get_batch_status,
);
break;
}
}
}
})
.map_err(|err| {
WorkloadRunnerError::WorkloadAddError(format!(
"Unable to spawn worker thread: {}",
err
))
})?,
);
Ok(Worker {
id,
thread,
sender,
batch_status_checker,
})
}
}
/// JSON error body returned by a target when a request is not successful.
#[derive(Deserialize)]
pub struct ServerError {
    /// Error message reported by the server.
    pub message: String,
}
/// Resubmits a batch that a target rejected with `TooManyRequests`, backing off between tries.
///
/// The back-off interval comes from the submission rate observed since `start_time`
/// (`submitted_batches` per whole elapsed second). It is one second divided by that rate, or
/// one second when less than a second has passed or the rate rounds down to zero. The function
/// sleeps once for that interval, then retries until the batch is accepted, sleeping the same
/// interval after each further `TooManyRequests` response.
///
/// Returns `Ok((true, None))` when the worker should stop: a shutdown message arrived,
/// `end_time` passed, or the shutdown channel disconnected. Returns `Ok((false, Some(link)))`
/// once the batch has been accepted.
///
/// # Errors
///
/// Returns `WorkloadRunnerError::SubmitError` if a resubmission fails for any reason other
/// than `TooManyRequests`.
fn slow_rate(
    target: &str,
    auth: &str,
    batch_bytes: Vec<u8>,
    start_time: time::Instant,
    submitted_batches: u32,
    receiver: &Receiver<ShutdownMessage>,
    end_time: Option<time::Instant>,
) -> Result<(bool, Option<String>), WorkloadRunnerError> {
    debug!("Received TooManyRequests message from target, attempting to resubmit batch");
    let elapsed_secs = (time::Instant::now() - start_time).as_secs() as u32;
    let wait = match elapsed_secs {
        0 => time::Duration::from_secs(1),
        secs => match submitted_batches / secs {
            0 => time::Duration::from_secs(1),
            rate => time::Duration::from_secs(1) / rate,
        },
    };
    thread::sleep(wait);
    loop {
        if let Some(end_time) = end_time {
            if time::Instant::now() > end_time {
                return Ok((true, None));
            }
        }
        match receiver.try_recv() {
            Ok(_) => {
                info!("Worker received shutdown");
                return Ok((true, None));
            }
            Err(TryRecvError::Disconnected) => {
                error!("Channel has disconnected");
                // Nothing can signal this worker any more, so stop it. Reporting "not shut
                // down" here would make the caller count a batch that was never sent.
                return Ok((true, None));
            }
            Err(TryRecvError::Empty) => (),
        }
        match submit_batch(target, auth, batch_bytes.clone()) {
            Ok(link) => return Ok((false, Some(link))),
            Err(WorkloadRunnerError::TooManyRequests) => thread::sleep(wait),
            Err(err) => {
                return Err(WorkloadRunnerError::SubmitError(format!(
                    "Failed to submit batch: {}",
                    err
                )))
            }
        }
    }
}
/// Posts serialized batch bytes to `{target}/batches` and returns the status link from the
/// JSON response body.
///
/// # Errors
///
/// - `WorkloadRunnerError::TooManyRequests` if the target responds with HTTP 429.
/// - `WorkloadRunnerError::SubmitError` in any of these cases:
///   - the request cannot be sent;
///   - a successful response body is not a valid `Link`;
///   - the target reports any other error.
fn submit_batch(
    target: &str,
    auth: &str,
    batch_bytes: Vec<u8>,
) -> Result<String, WorkloadRunnerError> {
    let res = Client::new()
        .post(format!("{}/batches", target))
        // `application/octet-stream` is the registered media type for raw bytes; a bare
        // `octet-stream` is not a valid media type.
        .header(header::CONTENT_TYPE, "application/octet-stream")
        .header("Authorization", auth)
        .body(batch_bytes)
        .send()
        .map_err(|err| {
            WorkloadRunnerError::SubmitError(format!("Failed to submit batch: {}", err))
        })?;

    let status = res.status();
    if status.is_success() {
        let status_link: Link = res.json().map_err(|_| {
            WorkloadRunnerError::SubmitError("Failed to deserialize response body".to_string())
        })?;
        return Ok(status_link.link);
    }
    if status == StatusCode::TOO_MANY_REQUESTS {
        return Err(WorkloadRunnerError::TooManyRequests);
    }
    let message = res
        .json::<ServerError>()
        .map_err(|_| {
            WorkloadRunnerError::SubmitError(format!(
                "Batch submit request failed with status code '{}', but \
                 error response was not valid",
                status
            ))
        })?
        .message;
    Err(WorkloadRunnerError::SubmitError(format!(
        "Failed to submit batch: {}",
        message
    )))
}
fn signal_batch_status_checker_shutdown(
batch_status_checker_sender: Sender<ShutdownMessage>,
id: String,
check_batch_status: bool,
) {
if check_batch_status {
debug!(
"Shutting down batch status checker BatchStatusChecker-{}",
id
);
if batch_status_checker_sender.send(ShutdownMessage).is_err() {
error!(
"Failed to send shutdown message to BatchStatusChecker-{}",
id,
);
}
}
}
/// Background thread that polls the status of submitted batches and compares each result with
/// the expected one.
struct BatchStatusChecker {
    // Thread name, `BatchStatusChecker-<workload id>`.
    id: String,
    // Sending half of the checker's shutdown channel.
    sender: Sender<ShutdownMessage>,
    // Checker thread handle; `None` once it has been taken for joining.
    thread: Option<thread::JoinHandle<()>>,
}
impl BatchStatusChecker {
fn new(
status_links: ExpectedBatchResults,
auth: String,
id: String,
worker_sender: Sender<ShutdownMessage>,
sender: Sender<ShutdownMessage>,
reciever: Receiver<ShutdownMessage>,
) -> Result<Self, WorkloadRunnerError> {
let id = format!("BatchStatusChecker-{}", id);
let thread = Some(
thread::Builder::new()
.name(id.clone())
.spawn(move || {
'outer: loop {
match reciever.try_recv() {
Ok(_) => {
info!("Batch status checker received shutdown");
break;
}
Err(TryRecvError::Empty) => {
let is_empty = match status_links.lock() {
Ok(l) => l.is_empty(),
Err(_) => {
error!("ExpectedBatchResults lock poisoned");
break;
}
};
if is_empty {
thread::sleep(Duration::new(2, 0))
} else {
let (status_link, (target, expected_result)) =
match status_links.lock() {
Ok(l) => match l.iter().next() {
Some((s, (t, e))) => {
(s.clone(), (t.clone(), e.clone()))
}
None => {
error!("Status links empty");
break;
}
},
Err(_) => {
error!("ExpectedBatchResults lock poisoned");
break;
}
};
let url =
get_batch_status_url(target.to_string(), &status_link);
match Client::new()
.get(url)
.header("Authorization", &auth)
.send()
{
Ok(res) => {
let status = res.status();
if status.is_success() {
let batch_info: Vec<BatchInfo> = match res.json() {
Ok(b) => b,
Err(_) => {
error!(
"Failed to deserialize response body"
);
break;
}
};
for info in batch_info {
match compare_batch_results(
status_links.clone(),
status_link.clone(),
expected_result.clone(),
info,
) {
Ok(()) => (),
Err(e) => {
error!("{}", e);
match worker_sender
.send(ShutdownMessage)
{
Ok(_) => break 'outer,
Err(_) => {
error!(
"Failed to send shutdown \
message to worker thread"
);
break 'outer;
}
}
}
}
}
} else {
let message = match res.json::<ServerError>() {
Ok(r) => r.message,
Err(_) => {
error!(
"Batch status request failed with \
status code '{}', but error response \
was not valid",
status
);
break;
}
};
error!(
"Failed to get submitted batch status: {}",
message
);
break;
}
}
Err(e) => {
error!("Failed send batch status request: {}", e);
break;
}
}
}
thread::sleep(Duration::from_millis(250));
}
Err(TryRecvError::Disconnected) => {
error!("Channel has disconnected");
break;
}
}
}
})
.map_err(|err| {
WorkloadRunnerError::WorkloadAddError(format!(
"Unable to spawn batch status checker thread: {}",
err
))
})?,
);
Ok(BatchStatusChecker { id, sender, thread })
}
fn remove_batch_status_checker(&mut self) -> Result<(), WorkloadRunnerError> {
debug!("Shutting down batch status checker {}", self.id);
if self.sender.send(ShutdownMessage).is_err() {
return Err(WorkloadRunnerError::WorkloadRemoveError(format!(
"Failed to send shutdown messages to {}",
&self.id,
)));
}
if let Some(thread) = self.thread.take() {
if let Err(err) = thread.join() {
return Err(WorkloadRunnerError::WorkloadRemoveError(format!(
"Failed to cleanly join batch status checker thread {}: {:?}",
&self.id, err,
)));
}
}
Ok(())
}
}
/// Builds the absolute URL for a batch status link returned by `target`.
///
/// `status_link` is an absolute path on the target's server (for example
/// `/scabbard/<circuit>/<service>/batch_statuses?ids=...`), while `target` may already end in
/// a prefix of that path. The longest run of whole trailing path segments of `target` that
/// `status_link` starts with is dropped before the link is appended, so the shared prefix is
/// not duplicated. If nothing overlaps, `status_link` is appended to `target` as-is. Trailing
/// `/` characters on `target` are ignored.
///
/// Only the path portion of `target` (after `scheme://host[:port]`) is considered, and only
/// whole segments match. A host name such as `scabbard-node`, or a segment such as
/// `abc-long`, is therefore never mistaken for a `/scabbard` or `/abc` segment of the link.
fn get_batch_status_url(target: String, status_link: &str) -> String {
    let base = target.trim_end_matches('/');

    // Byte offset where the path of `base` starts: the first '/' after the authority.
    let path_start = match base.find("://") {
        Some(scheme_end) => {
            let authority_start = scheme_end + 3;
            base[authority_start..]
                .find('/')
                .map(|offset| authority_start + offset)
        }
        None => base.find('/'),
    };

    if let Some(path_start) = path_start {
        let path = &base[path_start..];
        // Segment boundaries are visited left to right, so the first hit is the longest
        // overlapping suffix.
        for (offset, _) in path.match_indices('/') {
            if let Some(rest) = status_link.strip_prefix(&path[offset..]) {
                if rest.is_empty() || rest.starts_with('/') || rest.starts_with('?') {
                    return format!("{}{}", &base[..path_start + offset], status_link);
                }
            }
        }
    }

    format!("{}{}", base, status_link)
}
/// Compares one status returned for `current_status_link` against `expected_result`.
///
/// `Pending` and `Unknown` statuses are ignored, so the link is checked again later. A final
/// status (`Invalid`, `Valid` or `Committed`) removes the link from `status_links` when it
/// agrees with the expectation or when no result was expected. Without that removal, links
/// with no expectation would be polled forever.
///
/// # Errors
///
/// Returns `WorkloadRunnerError::BatchStatusError` if a final status contradicts the expected
/// result, or if the `status_links` lock is poisoned.
fn compare_batch_results(
    status_links: ExpectedBatchResults,
    current_status_link: String,
    expected_result: Option<ExpectedBatchResult>,
    returned_batch_info: BatchInfo,
) -> Result<(), WorkloadRunnerError> {
    match returned_batch_info.status {
        // Not final yet: keep the link so it is polled again.
        BatchStatus::Pending | BatchStatus::Unknown => return Ok(()),
        BatchStatus::Invalid(invalid_txns) => match expected_result {
            Some(ExpectedBatchResult::Valid) => {
                return Err(WorkloadRunnerError::BatchStatusError(format!(
                    "Expected valid result, received invalid {:?}",
                    invalid_txns
                )))
            }
            Some(ExpectedBatchResult::Invalid) | None => (),
        },
        BatchStatus::Valid(valid_txns) | BatchStatus::Committed(valid_txns) => {
            match expected_result {
                Some(ExpectedBatchResult::Invalid) => {
                    return Err(WorkloadRunnerError::BatchStatusError(format!(
                        "Expected invalid result, received valid {:?}",
                        valid_txns
                    )))
                }
                Some(ExpectedBatchResult::Valid) | None => (),
            }
        }
    }

    // Final state consistent with the expectation (or nothing expected): stop tracking it.
    status_links
        .lock()
        .map_err(|_| {
            WorkloadRunnerError::BatchStatusError("ExpectedBatchResults lock poisoned".into())
        })?
        .remove(&current_status_link);
    Ok(())
}
/// JSON body of a successful batch submission.
#[derive(Debug, Deserialize)]
pub struct Link {
    // Status link for the submitted batch.
    link: String,
}
/// One element of the JSON array returned by a batch status request.
#[derive(Debug, Deserialize)]
struct BatchInfo {
    pub status: BatchStatus,
}
/// Status of a submitted batch, as an adjacently tagged enum: the `statusType` field names the
/// variant and the `message` field holds the variant's transactions.
#[derive(Debug, Deserialize)]
#[serde(tag = "statusType", content = "message")]
enum BatchStatus {
    Unknown,
    Pending,
    Invalid(Vec<InvalidTransaction>),
    Valid(Vec<ValidTransaction>),
    Committed(Vec<ValidTransaction>),
}
/// Transaction reported in a `Valid` or `Committed` batch status.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // fields are only read through the derived `Debug` output
struct ValidTransaction {
    transaction_id: String,
}
/// Transaction reported in an `Invalid` batch status.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // fields are only read through the derived `Debug` output
struct InvalidTransaction {
    transaction_id: String,
    error_message: String,
    error_data: Vec<u8>,
}
/// Per-workload counters of sent batches and queue-full (`TooManyRequests`) responses.
///
/// The counters are atomics updated through `&self`, so one instance can be shared between
/// threads (for example behind an `Arc`).
pub struct HttpRequestCounter {
    // Label identifying the workload these counts belong to.
    id: String,
    // Batches counted as sent since the last reset.
    sent_count: AtomicUsize,
    // Queue-full responses since the last reset.
    queue_full_count: AtomicUsize,
}

impl HttpRequestCounter {
    /// Creates a counter labelled `id`, with every count starting at zero.
    pub fn new(id: String) -> Self {
        let zero = || AtomicUsize::new(0);
        Self {
            id,
            sent_count: zero(),
            queue_full_count: zero(),
        }
    }

    /// Records one sent batch.
    pub fn increment_sent(&self) {
        Self::bump(&self.sent_count);
    }

    /// Records one queue-full (`TooManyRequests`) response.
    pub fn increment_queue_full(&self) {
        Self::bump(&self.queue_full_count);
    }

    /// Sets the sent-batch count back to zero.
    pub fn reset_sent_count(&self) {
        Self::clear(&self.sent_count);
    }

    /// Sets the queue-full count back to zero.
    pub fn reset_queue_full_count(&self) {
        Self::clear(&self.queue_full_count);
    }

    /// Returns the sent-batch count divided by `update`.
    ///
    /// `update` is presumably the number of seconds since the sent count was last reset; confirm
    /// against callers. There is no zero check, so `update == 0.0` yields an infinite or NaN
    /// result.
    pub fn get_batches_per_second(&self, update: f64) -> f64 {
        let sent = self.sent_count.load(Ordering::Relaxed) as f64;
        sent / update
    }

    // Adds one to `counter`. Relaxed ordering is enough for an independent statistic.
    fn bump(counter: &AtomicUsize) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    // Resets `counter` to zero.
    fn clear(counter: &AtomicUsize) {
        counter.store(0, Ordering::Relaxed);
    }
}
impl fmt::Display for HttpRequestCounter {
    /// Renders `<id>: <UTC time>, Sent: <sent>, Queue Full <queue full>`.
    ///
    /// The time uses the chrono format `%h-%d-%Y %H:%M:%S%.3f`, where `%h` is the abbreviated
    /// month name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let now = chrono::Utc::now();
        let stamp = now.format("%h-%d-%Y %H:%M:%S%.3f");
        let sent = self.sent_count.load(Ordering::Relaxed);
        let queue_full = self.queue_full_count.load(Ordering::Relaxed);
        write!(
            f,
            "{0}: {1}, Sent: {2}, Queue Full {3}",
            self.id, stamp, sent, queue_full
        )
    }
}