use std::{collections::BTreeMap, mem::replace};
use reifydb_core::{
common::CommitVersion,
encoded::shape::RowShape,
interface::{catalog::flow::FlowId, change::Change},
internal,
};
use reifydb_rql::flow::flow::FlowDag;
use reifydb_runtime::{
actor::{
context::Context,
mailbox::ActorRef,
system::ActorConfig,
traits::{Actor, Directive},
},
context::clock::{Clock, Instant},
};
use reifydb_type::{util::hex::encode, value::datetime::DateTime};
use tracing::{Span, field, instrument};
use super::{
instruction::WorkerBatch,
worker::{FlowMsg, FlowResponse},
};
use crate::transaction::pending::{Pending, PendingWrite};
/// Requests handled by [`PoolActor`].
///
/// Every client-facing variant carries a one-shot `reply` callback that the
/// pool invokes with a [`PoolResponse`]. Worker indices (the `usize` map keys
/// and `worker_id` fields) are positions in the pool's list of worker refs.
pub enum PoolMsg {
    /// Register `flow` on the worker chosen from the flow's id.
    RegisterFlow {
        flow: FlowDag,
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
    },
    /// Process one batch per worker; `batches` is keyed by worker index.
    Submit {
        batches: BTreeMap<usize, WorkerBatch>,
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
    },
    /// Process a single batch on the explicitly chosen worker `worker_id`.
    SubmitToWorker {
        worker_id: usize,
        batch: WorkerBatch,
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
    },
    /// Tick the listed flows on each worker (keyed by worker index), passing
    /// the same `timestamp` and `state_version` to every worker.
    Tick {
        ticks: BTreeMap<usize, Vec<FlowId>>,
        timestamp: DateTime,
        state_version: CommitVersion,
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
    },
    /// Internal: a worker's response, posted back into the pool's own mailbox
    /// by the callback that was handed to worker `worker_id`.
    WorkerReply {
        worker_id: usize,
        response: FlowResponse,
    },
}
/// Outcome delivered to a [`PoolMsg`] request's `reply` callback.
pub enum PoolResponse {
    /// Work finished: the combined pending writes plus the row shapes and view
    /// changes collected from the participating worker(s).
    Success {
        pending: Pending,
        pending_shapes: Vec<RowShape>,
        view_changes: Vec<Change>,
    },
    /// A `RegisterFlow` request was acknowledged by its worker.
    RegisterSuccess,
    /// The request failed or was rejected; carries a human-readable message.
    Error(String),
}
/// What the pool is currently waiting on. New requests are only accepted in
/// `Idle`.
enum Phase {
    /// No request in flight.
    Idle,
    /// A fan-out request is waiting for `pending_count` more worker replies;
    /// partial results are accumulated until the last one arrives.
    WaitingForWorkers {
        pending_count: usize,
        results: Vec<Pending>,
        pending_shapes: Vec<RowShape>,
        view_changes: Vec<Change>,
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
        started_at: Instant,
    },
    /// A single-worker request is waiting for its one reply. `is_register`
    /// selects a `RegisterSuccess` response instead of `Success`.
    WaitingForSingleWorker {
        reply: Box<dyn FnOnce(PoolResponse) + Send>,
        is_register: bool,
    },
}
/// Actor that fans flow work out to a fixed set of worker actors and gathers
/// their replies.
pub struct PoolActor {
    /// Worker handles; a worker's index in this list is its `worker_id`.
    refs: Vec<ActorRef<FlowMsg>>,
    /// Time source used to timestamp in-flight requests.
    clock: Clock,
}
impl PoolActor {
    /// Creates a pool over `refs` (worker `i` is `refs[i]`), timing requests
    /// with `clock`.
    pub fn new(refs: Vec<ActorRef<FlowMsg>>, clock: Clock) -> Self {
        Self { refs, clock }
    }
}
/// Mutable per-actor state owned by the actor runtime for [`PoolActor`].
pub struct PoolState {
    /// Current request phase; see `Phase`.
    phase: Phase,
}
impl Actor for PoolActor {
type State = PoolState;
type Message = PoolMsg;
fn init(&self, _ctx: &Context<Self::Message>) -> Self::State {
PoolState {
phase: Phase::Idle,
}
}
fn handle(&self, state: &mut Self::State, msg: Self::Message, ctx: &Context<Self::Message>) -> Directive {
match msg {
PoolMsg::RegisterFlow {
flow,
reply,
} => {
if !matches!(state.phase, Phase::Idle) {
(reply)(PoolResponse::Error("Pool actor is busy".to_string()));
return Directive::Continue;
}
let flow_id = flow.id;
let worker_id = (flow_id.0 as usize) % self.refs.len();
let self_ref = ctx.self_ref().clone();
let callback: Box<dyn FnOnce(FlowResponse) + Send> = Box::new(move |resp| {
let _ = self_ref.send(PoolMsg::WorkerReply {
worker_id,
response: resp,
});
});
if self.refs[worker_id]
.send(FlowMsg::Register {
flow,
reply: callback,
})
.is_err()
{
reply(PoolResponse::Error(format!("Worker {} stopped", worker_id)));
return Directive::Continue;
}
state.phase = Phase::WaitingForSingleWorker {
reply,
is_register: true,
};
}
PoolMsg::Submit {
batches,
reply,
} => {
if !matches!(state.phase, Phase::Idle) {
reply(PoolResponse::Error("Pool actor is busy".to_string()));
return Directive::Continue;
}
self.handle_submit_async(state, ctx, batches, reply);
}
PoolMsg::SubmitToWorker {
worker_id,
batch,
reply,
} => {
if !matches!(state.phase, Phase::Idle) {
(reply)(PoolResponse::Error("Pool actor is busy".to_string()));
return Directive::Continue;
}
if worker_id >= self.refs.len() {
(reply)(PoolResponse::Error(
internal!("Invalid worker_id: {}", worker_id).to_string(),
));
return Directive::Continue;
}
let self_ref = ctx.self_ref().clone();
let callback: Box<dyn FnOnce(FlowResponse) + Send> = Box::new(move |resp| {
let _ = self_ref.send(PoolMsg::WorkerReply {
worker_id,
response: resp,
});
});
if self.refs[worker_id]
.send(FlowMsg::Process {
batch,
reply: callback,
})
.is_err()
{
reply(PoolResponse::Error(format!("Worker {} stopped", worker_id)));
return Directive::Continue;
}
state.phase = Phase::WaitingForSingleWorker {
reply,
is_register: false,
};
}
PoolMsg::Tick {
ticks,
timestamp,
state_version,
reply,
} => {
if !matches!(state.phase, Phase::Idle) {
(reply)(PoolResponse::Error("Pool actor is busy".to_string()));
return Directive::Continue;
}
self.handle_tick_async(state, ctx, ticks, timestamp, state_version, reply);
}
PoolMsg::WorkerReply {
worker_id,
response,
} => {
self.handle_worker_reply(state, worker_id, response);
}
}
Directive::Continue
}
fn config(&self) -> ActorConfig {
ActorConfig::new()
}
}
impl PoolActor {
#[instrument(name = "flow::pool::submit", level = "debug", skip(self, state, ctx, batches, reply), fields(
batches = batches.len(),
instructions = field::Empty,
elapsed_us = field::Empty
))]
fn handle_submit_async(
&self,
state: &mut PoolState,
ctx: &Context<PoolMsg>,
batches: BTreeMap<usize, WorkerBatch>,
reply: Box<dyn FnOnce(PoolResponse) + Send>,
) {
let start = self.clock.instant();
let total_instructions: usize = batches.values().map(|b| b.instructions.len()).sum();
Span::current().record("instructions", total_instructions);
let batch_count = batches.len();
for (worker_id, batch) in batches {
if worker_id >= self.refs.len() {
(reply)(PoolResponse::Error(internal!("Invalid worker_id: {}", worker_id).to_string()));
return;
}
let self_ref = ctx.self_ref().clone();
let callback: Box<dyn FnOnce(FlowResponse) + Send> = Box::new(move |resp| {
let _ = self_ref.send(PoolMsg::WorkerReply {
worker_id,
response: resp,
});
});
if self.refs[worker_id]
.send(FlowMsg::Process {
batch,
reply: callback,
})
.is_err()
{
(reply)(PoolResponse::Error(format!("Worker {} stopped", worker_id)));
return;
}
}
state.phase = Phase::WaitingForWorkers {
pending_count: batch_count,
results: Vec::with_capacity(batch_count),
pending_shapes: Vec::new(),
view_changes: Vec::new(),
reply,
started_at: start,
};
}
fn handle_tick_async(
&self,
state: &mut PoolState,
ctx: &Context<PoolMsg>,
ticks: BTreeMap<usize, Vec<FlowId>>,
timestamp: DateTime,
state_version: CommitVersion,
reply: Box<dyn FnOnce(PoolResponse) + Send>,
) {
let tick_count = ticks.len();
for (worker_id, flow_ids) in ticks {
if worker_id >= self.refs.len() {
(reply)(PoolResponse::Error(internal!("Invalid worker_id: {}", worker_id).to_string()));
return;
}
let self_ref = ctx.self_ref().clone();
let callback: Box<dyn FnOnce(FlowResponse) + Send> = Box::new(move |resp| {
let _ = self_ref.send(PoolMsg::WorkerReply {
worker_id,
response: resp,
});
});
if self.refs[worker_id]
.send(FlowMsg::Tick {
flow_ids,
timestamp,
state_version,
reply: callback,
})
.is_err()
{
(reply)(PoolResponse::Error(format!("Worker {} stopped", worker_id)));
return;
}
}
state.phase = Phase::WaitingForWorkers {
pending_count: tick_count,
results: Vec::with_capacity(tick_count),
pending_shapes: Vec::new(),
view_changes: Vec::new(),
reply,
started_at: self.clock.instant(),
};
}
fn handle_worker_reply(&self, state: &mut PoolState, worker_id: usize, response: FlowResponse) {
let phase = replace(&mut state.phase, Phase::Idle);
match phase {
Phase::WaitingForSingleWorker {
reply: original_reply,
is_register,
} => {
let resp = match response {
FlowResponse::Success {
pending,
pending_shapes,
view_changes,
} => {
if is_register {
PoolResponse::RegisterSuccess
} else {
PoolResponse::Success {
pending,
pending_shapes,
view_changes,
}
}
}
FlowResponse::Error(e) => PoolResponse::Error(e),
};
(original_reply)(resp);
}
Phase::WaitingForWorkers {
mut pending_count,
mut results,
mut pending_shapes,
mut view_changes,
reply: original_reply,
started_at: start,
} => {
match response {
FlowResponse::Success {
pending,
pending_shapes: worker_pending_shapes,
view_changes: worker_view_changes,
} => {
results.push(pending);
pending_shapes.extend(worker_pending_shapes);
view_changes.extend(worker_view_changes);
pending_count -= 1;
if pending_count == 0 {
match self.aggregate_pending_writes(results) {
Ok(combined) => {
Span::current().record(
"elapsed_us",
start.elapsed().as_micros() as u64,
);
(original_reply)(PoolResponse::Success {
pending: combined,
pending_shapes,
view_changes,
});
}
Err(e) => {
(original_reply)(PoolResponse::Error(e));
}
}
} else {
state.phase = Phase::WaitingForWorkers {
pending_count,
results,
pending_shapes,
view_changes,
reply: original_reply,
started_at: start,
};
}
}
FlowResponse::Error(e) => {
(original_reply)(PoolResponse::Error(format!(
"Worker {} error: {}",
worker_id, e
)));
}
}
}
Phase::Idle => {
}
}
}
fn aggregate_pending_writes(&self, writes: Vec<Pending>) -> Result<Pending, String> {
let mut combined = Pending::new();
for pending in writes {
for (key, value) in pending.iter_sorted() {
if combined.contains_key(key) {
return Err(internal!(
"keyspace overlap detected during worker aggregation: {}",
encode(key.as_ref())
)
.to_string());
}
match value {
PendingWrite::Set(v) => {
combined.insert(key.clone(), v.clone());
}
PendingWrite::Remove => {
combined.remove(key.clone());
}
}
}
}
Ok(combined)
}
}