1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
use crate::message_queue::MessageQueue;
use crate::task::TaskExecutionData;
use log::{debug, error, info};
use std::{borrow::BorrowMut, collections::HashMap, thread, time::Duration};
use tokio::sync::mpsc::{Receiver, Sender};
use ulid::Ulid;
use super::{
errors::DispatchError,
worker::{Worker, WorkerState, WorkerStateMachine, WorkerStatusReport},
};
/// Hands queued tasks to a pool of workers.
///
/// Holds the receiving ends of the task-request and worker-status channels, plus
/// per-worker bookkeeping keyed by worker id.
pub struct TaskDispatcher {
/// Receiving end of the channel on which workers report their status.
worker_status_rx: Receiver<WorkerStatusReport>,
/// Task sender for each worker, keyed by worker id.
worker_channels: HashMap<Ulid, Sender<TaskExecutionData>>,
/// Last known state of each worker (together with its task sender), keyed by worker id.
worker_states: HashMap<Ulid, WorkerStateMachine>,
/// Receiving end of the channel that carries tasks waiting to be dispatched.
task_request_rx: Receiver<TaskExecutionData>,
}
// Responsible for dispatching tasks to workers
impl TaskDispatcher {
/// Starts the dispatcher with the specified number of workers.
/// As the consumers pop tasks from the message broker, they will place them into the task request channel.
/// The dispatcher will then pick up the tasks from the task request channel and dispatch them to the workers.
pub async fn start(
n_workers: u32,
task_request_rx: Receiver<TaskExecutionData>,
message_queue: MessageQueue,
) {
let mut worker_pool = Vec::new();
let mut worker_channels = HashMap::new();
// These channels are used by the workers to report their status back to the dispatcher and deliver updates to the message queue
let (worker_status_tx, worker_status_rx) = tokio::sync::mpsc::channel(1);
let (update_tx, mut update_rx) = tokio::sync::mpsc::channel(1);
// Create the workers and assign them a task input channel and a task output channel
for i in 0..n_workers {
let (task_tx, task_rx) = tokio::sync::mpsc::channel(1);
let worker = Worker::new(
format!("worker-{}", i),
task_rx,
worker_status_tx.clone(),
update_tx.clone(),
);
info!("Spawned worker {}", worker.name);
let worker_id = worker.id();
// Store the channel used to send tasks to the worker and add the worker to the worker pool
worker_channels.insert(worker_id, task_tx);
worker_pool.push(worker);
}
// Worker state machine contains the worker state and the channel used to send tasks to it
let mut worker_states = HashMap::new();
for (id, tx) in worker_channels.borrow_mut() {
worker_states.insert(
id.to_owned(),
WorkerStateMachine {
task_tx: tx.to_owned(),
worker_state: WorkerState::Idle,
},
);
}
// Start the workers in separate threads
for mut worker in worker_pool {
tokio::spawn(async move {
worker.work().await;
});
}
// Start a thread that monitors the task_update_rx channel and updates the message queue
// TODO: Make sure this drains even if the dispatcher thread dies
// TODO: Batching?
tokio::spawn(async move {
loop {
match update_rx.try_recv() {
Ok(update) => {
debug!("Got update: {:?}", update);
match message_queue.push_task_execution_update(update).await {
Ok(_) => {}
Err(e) => {
error!("Error updating execution record: {:?}", e);
}
}
}
Err(e) => match e {
tokio::sync::mpsc::error::TryRecvError::Empty => {
continue;
}
tokio::sync::mpsc::error::TryRecvError::Disconnected => {
error!("Task update channel closed!");
break;
}
},
}
}
});
// Start the dispatcher
TaskDispatcher {
worker_status_rx,
worker_channels,
worker_states,
task_request_rx,
}
.dispatch()
.await;
}
/// Poll the worker status channel for updates and update the worker states accordingly
async fn update_worker_states(&mut self) -> Result<(), DispatchError> {
// Each worker can only be IDLE or RUNNING, so we allow for 2 updates per worker
for _ in 0..self.worker_channels.len() * 2 {
match self.worker_status_rx.try_recv() {
// If there was an update in the channel, update the worker state
Ok(worker_status_update) => {
if let Some(worker_state) = self.worker_states.get_mut(&worker_status_update.id)
{
worker_state.worker_state = worker_status_update.worker_state;
// This should never happen, but we check just in case
} else {
return Err(DispatchError::WorkerStatusReceiverError(String::from(
"Worker state not found in worker state map!",
)));
}
}
// If there was an error encountered, determine what kind of error it was and act accordingly
Err(e) => match e {
tokio::sync::mpsc::error::TryRecvError::Empty => {
continue;
}
tokio::sync::mpsc::error::TryRecvError::Disconnected => {
error!("Worker status channel closed!");
break;
}
},
}
}
Ok(())
}
/// Get the first idle worker transmitter
async fn get_idle_job_transmitter(
&mut self,
) -> Result<Option<Sender<TaskExecutionData>>, DispatchError> {
// Update the worker states first
self.update_worker_states().await?;
// Filter the worker states to get the ones that are idle
let workers_in_idle_state: Vec<(&Ulid, &WorkerStateMachine)> = self
.worker_states
.iter()
.filter(|(_, state)| state.worker_state == WorkerState::Idle)
.collect();
// If there are idle workers, return the first one
if let Some((worker_id, worker_state)) = workers_in_idle_state.first() {
debug!("Found idle worker {}", worker_id);
Ok(Some(worker_state.task_tx.clone()))
} else {
Ok(None)
}
}
/// Start a thread that monitors the input queue and dispatches tasks to the workers as they become available
async fn dispatch(mut self) {
loop {
// Get an idle worker transmitter and send a task to it.
// Breaks out of the loop if an IO channel is closed, or errors occurred when getting an idle worker transmitter
match self.get_idle_job_transmitter().await {
Ok(Some(transmitter)) => match self.task_request_rx.try_recv() {
Ok(task) => {
if let Err(e) = transmitter.send(task).await {
error!(
"Error sending task to worker! {:?} stopping dispatcher...",
e
);
break;
}
}
Err(e) => match e {
tokio::sync::mpsc::error::TryRecvError::Empty => {
continue;
}
tokio::sync::mpsc::error::TryRecvError::Disconnected => {
error!("Task request channel closed!");
break;
}
},
},
Ok(None) => {
if let Err(e) = self.update_worker_states().await {
error!(
"Error updating worker states! {:?} stopping dispatcher...",
e
);
break;
}
}
Err(e) => {
error!("Error getting idle worker! {:?} stopping dispatcher...", e);
break;
}
}
thread::sleep(Duration::from_millis(100));
}
}
}