impl ThreadedProtocolMachine {
    /// Builds a machine with one worker per unit of available host parallelism.
    ///
    /// Falls back to a single worker when `std::thread::available_parallelism`
    /// returns an error.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::with_workers`]: when the
    /// configuration fails `validate_invariants` or the worker thread pool
    /// cannot be built.
    #[must_use]
    pub fn auto(config: ProtocolMachineConfig) -> Self {
        let workers = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        Self::with_workers(config, workers)
    }

    /// Fallible constructor: validates `config` and builds a dedicated thread pool.
    ///
    /// `workers` is clamped to at least 1. The same clamped value is used for
    /// the pool size, `workers`, and `lane_count`. Every guard layer declared in
    /// `config.guard_layers` starts with a `Value::Unit` resource.
    ///
    /// # Errors
    ///
    /// * `ProtocolMachineError::InvalidConfig` if `config.validate_invariants()` fails.
    /// * `ProtocolMachineError::ThreadPoolBuild` if the thread pool cannot be created.
    pub fn try_with_workers(
        config: ProtocolMachineConfig,
        workers: usize,
    ) -> Result<Self, ProtocolMachineError> {
        config
            .validate_invariants()
            .map_err(|reason| ProtocolMachineError::InvalidConfig { reason })?;
        let worker_count = workers.max(1);
        let pool = ThreadPoolBuilder::new()
            .num_threads(worker_count)
            .build()
            .map_err(|e| ProtocolMachineError::ThreadPoolBuild {
                message: e.to_string(),
            })?;
        // Pull out what we need from `config` before it is moved into `Self`.
        let tick_duration = config.tick_duration;
        let communication_replay_mode = config.communication_replay_mode;
        let scheduler = Scheduler::new(config.sched_policy.clone());
        let mut guard_resources = BTreeMap::new();
        for layer in &config.guard_layers {
            guard_resources.insert(layer.id.clone(), Value::Unit);
        }
        Ok(Self {
            config,
            programs: ProgramStore::new(),
            coroutines: Vec::new(),
            sessions: ThreadedSessionStore::new(),
            scheduler,
            trace: Vec::new(),
            role_symbols: SymbolTable::new(),
            label_symbols: SymbolTable::new(),
            handler_symbols: SymbolTable::new(),
            edge_symbols: EdgeSymbolTable::new(),
            clock: SimClock::new(tick_duration),
            next_coro_id: 0,
            non_terminal_coroutines: 0,
            pool,
            workers: worker_count,
            lane_count: worker_count,
            guard_resources: Arc::new(Mutex::new(guard_resources)),
            resource_states: Arc::new(Mutex::new(BTreeMap::new())),
            communication_consumption: Arc::new(Mutex::new(DefaultCommunicationConsumption::new(
                communication_replay_mode,
            ))),
            communication_consumption_artifacts: Arc::new(Mutex::new(Vec::new())),
            effect_trace: Vec::new(),
            effect_exchanges: Vec::new(),
            operation_instances: Vec::new(),
            outstanding_effects: Vec::new(),
            progress_contracts: Vec::new(),
            progress_transitions: Vec::new(),
            next_effect_id: 0,
            output_condition_checks: Vec::new(),
            crashed_sites: BTreeSet::new(),
            paused_roles: BTreeSet::new(),
            partitioned_edges: BTreeSet::new(),
            corrupted_edges: BTreeMap::new(),
            timed_out_sites: BTreeMap::new(),
            lane_trace: Vec::new(),
            pending_handoffs: VecDeque::new(),
            handoff_trace_log: Vec::new(),
            next_handoff_id: 0,
            delegation_audit_log: Vec::new(),
            next_delegation_receipt_id: 0,
            contention_metrics: ContentionMetrics::default(),
            force_invalid_wave_certificate_once: false,
            handler_identity_anchor: None,
        })
    }

    /// Infallible wrapper around [`Self::try_with_workers`].
    ///
    /// # Panics
    ///
    /// Panics with `"threaded ProtocolMachine initialization failed: …"` if
    /// `try_with_workers` returns an error.
    #[must_use]
    pub fn with_workers(config: ProtocolMachineConfig, workers: usize) -> Self {
        Self::try_with_workers(config, workers)
            .unwrap_or_else(|e| panic!("threaded ProtocolMachine initialization failed: {e}"))
    }

    /// Errors with `TooManySessions` once `max_sessions` sessions are active.
    fn ensure_session_capacity(&self) -> Result<(), ProtocolMachineError> {
        if self.sessions.active_count() >= self.config.max_sessions {
            return Err(ProtocolMachineError::TooManySessions {
                max: self.config.max_sessions,
            });
        }
        Ok(())
    }

    /// Errors with `TooManyCoroutines` if one more coroutine would exceed
    /// `max_coroutines`.
    ///
    /// This counts every entry in `self.coroutines`, including terminal ones.
    fn ensure_coroutine_capacity(&self) -> Result<(), ProtocolMachineError> {
        if self.coroutines.len() >= self.config.max_coroutines {
            return Err(ProtocolMachineError::TooManyCoroutines {
                max: self.config.max_coroutines,
            });
        }
        Ok(())
    }

    /// Errors with `TooManyCoroutines` if `additional` more coroutines would
    /// exceed `max_coroutines`.
    ///
    /// It fails exactly when spawning `additional` coroutines one at a time,
    /// each guarded by [`Self::ensure_coroutine_capacity`], would fail at some
    /// point. That makes it usable as an up-front check.
    fn ensure_coroutine_capacity_for(&self, additional: usize) -> Result<(), ProtocolMachineError> {
        let available = self
            .config
            .max_coroutines
            .saturating_sub(self.coroutines.len());
        if additional > available {
            return Err(ProtocolMachineError::TooManyCoroutines {
                max: self.config.max_coroutines,
            });
        }
        Ok(())
    }

    /// Points session `sid` at `crate::session::DEFAULT_HANDLER_ID`, if the
    /// session exists.
    ///
    /// The default handler id is interned into `handler_symbols` either way.
    fn bind_default_handlers_for_session(&mut self, sid: SessionId) {
        if let Some(session) = self.sessions.get(sid) {
            let mut session_guard = session.lock().expect("threaded ProtocolMachine lock poisoned");
            session_guard.default_handler = crate::session::DEFAULT_HANDLER_ID.to_string();
        }
        let _: StringId = self
            .handler_symbols
            .intern(crate::session::DEFAULT_HANDLER_ID);
    }

    /// Creates the coroutine that runs `role`'s program in session `sid` and
    /// marks it ready.
    ///
    /// If `image` has no program for the role, the default (empty) program is
    /// interned instead. The coroutine owns its endpoint. If it has at least one
    /// register, register 0 is seeded with that endpoint.
    ///
    /// Ids come from `next_coro_id`, and the coroutine is pushed onto
    /// `self.coroutines` in the same call. `check_wave_admissible` looks
    /// coroutines up by id as an index into that vector, so ids and positions
    /// must stay aligned.
    ///
    /// # Errors
    ///
    /// `TooManyCoroutines` if the coroutine limit is already reached.
    fn spawn_role_coroutine(
        &mut self,
        sid: SessionId,
        role: &str,
        image: &CodeImage,
    ) -> Result<(), ProtocolMachineError> {
        self.ensure_coroutine_capacity()?;
        let role_key = role.to_string();
        let program_id = self
            .programs
            .intern(image.programs.get(&role_key).cloned().unwrap_or_default());
        let coro_id = self.next_coro_id;
        self.next_coro_id += 1;
        let ep = Endpoint {
            sid,
            role: role_key.clone(),
        };
        let mut coro = Coroutine::new(
            coro_id,
            program_id,
            sid,
            role_key,
            self.config.num_registers,
            self.config.initial_cost_budget,
        );
        coro.owned_endpoints.push(ep.clone());
        if !coro.regs.is_empty() {
            coro.regs[0] = Value::Endpoint(ep);
        }
        self.scheduler.add_ready(coro_id);
        self.coroutines.push(Arc::new(Mutex::new(coro)));
        self.non_terminal_coroutines = self.non_terminal_coroutines.saturating_add(1);
        Ok(())
    }

    /// Opens a new session for `image` and spawns one coroutine per role.
    ///
    /// On success, the session has default handlers bound, its runtime symbols
    /// interned, and an empty resource-state entry. An `ObsEvent::Opened` is
    /// pushed to the trace, and every role coroutine is ready.
    ///
    /// # Errors
    ///
    /// * `TooManySessions` if the session limit is reached.
    /// * `InvalidCodeImage` if `image.validate_runtime_shape()` fails.
    /// * `TooManyCoroutines` if spawning all of the image's roles would exceed
    ///   the coroutine limit.
    ///
    /// All three are checked before any session state is created.
    #[doc(hidden)]
    pub fn load_choreography(&mut self, image: &CodeImage) -> Result<SessionId, ProtocolMachineError> {
        self.ensure_session_capacity()?;
        image
            .validate_runtime_shape()
            .map_err(|reason| ProtocolMachineError::InvalidCodeImage { reason })?;
        let roles = image.roles();
        // Reserve room for *every* role before opening the session. Otherwise
        // hitting the coroutine limit partway through the spawn loop below
        // would leave behind an opened session with only some of its roles.
        self.ensure_coroutine_capacity_for(roles.len())?;
        self.programs.reserve(image.programs.len());
        self.coroutines.reserve(roles.len());
        let sid = self.sessions.open(
            roles.clone(),
            &self.config.buffer_config,
            &image.local_types,
        );
        self.bind_default_handlers_for_session(sid);
        self.intern_session_runtime_symbols(sid);
        self.resource_states
            .lock()
            .expect("threaded ProtocolMachine lock poisoned")
            .entry(sid)
            .or_default();
        self.trace.push(ObsEvent::Opened {
            tick: self.clock.tick,
            session: sid,
            roles: roles.clone(),
        });
        for role in &roles {
            self.spawn_role_coroutine(sid, role, image)?;
        }
        Ok(sid)
    }

    /// Loads `image` like [`Self::load_choreography`], then claims
    /// session-scoped ownership for `owner_id`.
    ///
    /// # Errors
    ///
    /// * Any error from `load_choreography`.
    /// * `OwnershipContract` (wrapping the debug form of the claim error) if
    ///   ownership cannot be claimed. In that case the session stays loaded.
    pub fn load_choreography_owned(
        &mut self,
        image: &CodeImage,
        owner_id: impl Into<String>,
    ) -> Result<OwnedSession, ProtocolMachineError> {
        let sid = self.load_choreography(image)?;
        let capability = self
            .sessions
            .claim_ownership(sid, owner_id, OwnershipScope::Session)
            .map_err(|err| ProtocolMachineError::OwnershipContract(format!("{err:?}")))?;
        Ok(OwnedSession::new(sid, capability))
    }

    /// Executes one scheduler round of at most `n` instruction steps.
    ///
    /// The clock is advanced *before* the completion check, so the tick moves
    /// even when everything is already done.
    ///
    /// Under `CanonicalOneStep`, the step budget is 1. Under
    /// `WaveParallelExtension`, it is `n`, spent over one or more waves. Waves
    /// must pass the certificate check whenever the budget exceeds 1.
    ///
    /// Returns `AllDone` when nothing remains, `Continue` if any pick made
    /// progress, and `Stuck` otherwise.
    ///
    /// # Errors
    ///
    /// * `InvalidConcurrency` when `n == 0`.
    /// * Any error from topology ingestion, progress-contract evaluation, wave
    ///   planning, or pick execution (including coroutine faults).
    pub(crate) fn kernel_step_round(
        &mut self,
        handler: &dyn EffectHandler,
        n: usize,
    ) -> Result<StepResult, ProtocolMachineError> {
        if n == 0 {
            return Err(ProtocolMachineError::InvalidConcurrency { n });
        }
        self.clock.advance();
        let tick = self.clock.tick;
        if self.all_done() {
            return Ok(StepResult::AllDone);
        }
        self.ingest_topology_events(handler)?;
        self.prune_expired_timeouts();
        self.try_unblock_senders();
        self.try_unblock_receivers();
        self.evaluate_progress_contracts()?;
        let mut progressed = false;
        let mut remaining = match self.config.threaded_round_semantics {
            ThreadedRoundSemantics::CanonicalOneStep => 1,
            ThreadedRoundSemantics::WaveParallelExtension => n,
        };
        let mut wave = 0_u64;
        // Only multi-step budgets can produce multi-pick waves, which are the
        // ones that need a certificate.
        let enforce_certified_wave = remaining > 1;
        while remaining > 0 {
            let Some(planned_wave) = self.plan_next_wave(remaining, enforce_certified_wave)? else {
                break;
            };
            let stop_after_wave = planned_wave.stop_after_wave;
            let picks = planned_wave.picks;
            self.record_lane_trace(&picks, tick, wave);
            let picks_len = picks.len();
            progressed |= self.execute_picks(picks, handler, tick)?;
            remaining = remaining.saturating_sub(picks_len);
            wave = wave.saturating_add(1);
            if stop_after_wave {
                break;
            }
        }
        self.evaluate_progress_contracts()?;
        if self.all_done() {
            Ok(StepResult::AllDone)
        } else if progressed {
            Ok(StepResult::Continue)
        } else {
            Ok(StepResult::Stuck)
        }
    }

    /// Picks up to `budget` ready coroutines for the next wave.
    ///
    /// Returns `None` when nothing can be picked. When
    /// `enforce_certified_wave` is set and the wave fails
    /// [`Self::check_wave_certificate`], the picks are handed back to the
    /// scheduler via `reschedule`. A single-coroutine wave is returned instead,
    /// flagged `stop_after_wave` so the round ends after it.
    fn plan_next_wave(
        &mut self,
        budget: usize,
        enforce_certified_wave: bool,
    ) -> Result<Option<PlannedWave>, ProtocolMachineError> {
        self.contention_metrics
            .observe_ready_depth(self.scheduler.ready_count());
        let ready_before_pick = self.scheduler.ready_set_snapshot();
        let picks = self.pick_ready(budget)?;
        if picks.is_empty() {
            return Ok(None);
        }
        if enforce_certified_wave {
            let cert = WaveCertificate {
                waves: vec![picks.iter().map(|pick| pick.coro_id).collect()],
                planner_step: self.scheduler.step_count(),
            };
            if !self.check_wave_certificate(&cert, &ready_before_pick) {
                self.restore_picks_to_ready(&picks);
                let fallback = self.pick_ready(1)?;
                if fallback.is_empty() {
                    return Ok(None);
                }
                return Ok(Some(PlannedWave {
                    picks: fallback,
                    stop_after_wave: true,
                }));
            }
        }
        Ok(Some(PlannedWave {
            picks,
            stop_after_wave: false,
        }))
    }

    /// Runs one instruction for each pick and commits the results in pick order.
    ///
    /// A single pick executes on the calling thread. Multiple picks execute in
    /// parallel on `self.pool`.
    ///
    /// Every result of a parallel wave is committed even if an earlier one
    /// faults. All picks have already executed, and certified waves have
    /// disjoint endpoint footprints, so abandoning the rest would drop their
    /// effects and strand them outside the scheduler. The first error
    /// encountered is returned.
    ///
    /// Returns whether any pick made progress.
    ///
    /// # Errors
    ///
    /// * Handler-identity contract violations.
    /// * The first fault produced while committing a pick.
    fn execute_picks(
        &mut self,
        picks: Vec<Picked>,
        handler: &dyn EffectHandler,
        tick: u64,
    ) -> Result<bool, ProtocolMachineError> {
        let handler_identity = handler.handler_identity();
        self.enforce_handler_identity_contract(&handler_identity)?;
        self.contention_metrics.observe_wave_width(picks.len());
        let step_ctx = ThreadedStepCtx {
            config: &self.config,
            guard_resources: &self.guard_resources,
            resource_states: &self.resource_states,
            communication_consumption: &self.communication_consumption,
            communication_consumption_artifacts: &self.communication_consumption_artifacts,
            crashed_sites: &self.crashed_sites,
            partitioned_edges: &self.partitioned_edges,
            corrupted_edges: &self.corrupted_edges,
            timed_out_sites: &self.timed_out_sites,
            handler,
            tick,
        };
        let exec_ctx = ThreadedExecCtx {
            store: &self.sessions,
            programs: &self.programs,
            step: step_ctx,
        };
        if picks.len() == 1 {
            let Some(pick) = picks.into_iter().next() else {
                return Ok(false);
            };
            let result = exec_instr(&pick.coro, &pick.session, &exec_ctx);
            return self.commit_pick_result(&pick, result, tick, handler, &handler_identity);
        }
        // `collect` on the indexed parallel iterator preserves pick order, so
        // `results[i]` belongs to `picks[i]`.
        let results: Vec<Result<ThreadedExecSuccess, ThreadedExecFault>> =
            self.pool.install(|| {
                picks
                    .par_iter()
                    .map(|pick| exec_instr(&pick.coro, &pick.session, &exec_ctx))
                    .collect()
            });
        let mut progressed = false;
        let mut first_error: Option<ProtocolMachineError> = None;
        for (pick, result) in picks.iter().zip(results) {
            match self.commit_pick_result(pick, result, tick, handler, &handler_identity) {
                Ok(pick_progressed) => progressed |= pick_progressed,
                Err(err) => {
                    if first_error.is_none() {
                        first_error = Some(err);
                    }
                }
            }
        }
        match first_error {
            Some(err) => Err(err),
            None => Ok(progressed),
        }
    }

    /// Applies one pick's execution result to the machine.
    ///
    /// * On success, the pack is committed via `commit_pack`. The coroutine is
    ///   then rescheduled, marked blocked, or marked done (with an
    ///   `ObsEvent::Halted`), depending on the outcome, and `Ok(true)` is
    ///   returned.
    /// * On a fault — raised either while executing or while committing — an
    ///   `ObsEvent::Faulted` is traced, the coroutine's status becomes
    ///   `Faulted`, it is marked done, and `ProtocolMachineError::Fault` is
    ///   returned.
    ///
    /// Effect observations from an execution fault are still recorded.
    fn commit_pick_result(
        &mut self,
        pick: &Picked,
        result: Result<ThreadedExecSuccess, ThreadedExecFault>,
        tick: u64,
        handler: &dyn EffectHandler,
        handler_identity: &str,
    ) -> Result<bool, ProtocolMachineError> {
        // Funnel both failure sources into a single `Err(fault)` so they share
        // one fault-handling path below.
        let committed = match result {
            Ok(ThreadedExecSuccess {
                pack,
                effect_observations,
                output_observation,
            }) => self.commit_pack(
                &pick.coro,
                &pick.session,
                pack,
                effect_observations,
                output_observation,
                handler,
                handler_identity,
            ),
            Err(ThreadedExecFault {
                fault,
                effect_observations,
            }) => {
                self.record_effect_observations(effect_observations, handler_identity);
                Err(fault)
            }
        };
        match committed {
            Ok(ExecOutcome::Continue) => {
                self.scheduler.reschedule(pick.coro_id);
            }
            Ok(ExecOutcome::Blocked(reason)) => {
                self.scheduler.mark_blocked(pick.coro_id, reason);
            }
            Ok(ExecOutcome::Halted) => {
                self.scheduler.mark_done(pick.coro_id);
                self.trace.push(ObsEvent::Halted {
                    tick,
                    coro_id: pick.coro_id,
                });
            }
            Err(fault) => {
                self.trace.push(ObsEvent::Faulted {
                    tick,
                    coro_id: pick.coro_id,
                    fault: fault.clone(),
                });
                // Release the coroutine lock before touching other machine state.
                let (was_terminal, is_terminal) = {
                    let mut coro = pick.coro.lock().expect("threaded ProtocolMachine lock poisoned");
                    let was_terminal = coro.is_terminal();
                    coro.status = CoroStatus::Faulted(fault.clone());
                    (was_terminal, coro.is_terminal())
                };
                self.note_status_transition(was_terminal, is_terminal);
                self.scheduler.mark_done(pick.coro_id);
                return Err(ProtocolMachineError::Fault {
                    coro_id: pick.coro_id,
                    fault,
                });
            }
        }
        Ok(true)
    }

    /// Appends one `LaneSelection` per pick, tagged with `tick` and `wave`.
    fn record_lane_trace(&mut self, picks: &[Picked], tick: u64, wave: u64) {
        for pick in picks {
            self.lane_trace.push(LaneSelection {
                tick,
                wave,
                coro_id: pick.coro_id,
                session: pick.sid,
                lane: pick.lane,
            });
        }
    }

    /// Hands every picked coroutine back to the scheduler via `reschedule`.
    fn restore_picks_to_ready(&mut self, picks: &[Picked]) {
        for pick in picks {
            self.scheduler.reschedule(pick.coro_id);
        }
    }

    /// Validates a wave certificate against the current scheduler state.
    ///
    /// The certificate is rejected when any of the following holds:
    /// * a forced-invalid request is pending (it is consumed here);
    /// * `planner_step` no longer matches the scheduler's step count;
    /// * a coroutine appears twice across waves;
    /// * any wave fails [`Self::check_wave_admissible`].
    fn check_wave_certificate(
        &mut self,
        cert: &WaveCertificate,
        ready_before_pick: &BTreeSet<usize>,
    ) -> bool {
        if self.force_invalid_wave_certificate_once {
            self.force_invalid_wave_certificate_once = false;
            return false;
        }
        if cert.planner_step != self.scheduler.step_count() {
            return false;
        }
        let mut seen = BTreeSet::new();
        for wave in &cert.waves {
            for coro_id in wave {
                if !seen.insert(*coro_id) {
                    return false;
                }
            }
            if !self.check_wave_admissible(wave, ready_before_pick) {
                return false;
            }
        }
        true
    }

    /// Checks that one wave can execute in parallel.
    ///
    /// Every coroutine must:
    /// * be distinct within the wave;
    /// * have been ready before picking;
    /// * exist;
    /// * be `Ready` or `Speculating`;
    /// * occupy its own lane (`coro_id % lane_count`);
    /// * own endpoints (by session and role fingerprint) that no other member
    ///   of the wave owns.
    fn check_wave_admissible(&self, wave: &[usize], ready_before_pick: &BTreeSet<usize>) -> bool {
        let mut seen = BTreeSet::new();
        let mut lanes = BTreeSet::new();
        let mut footprint = BTreeSet::new();
        for coro_id in wave {
            if !seen.insert(*coro_id) || !ready_before_pick.contains(coro_id) {
                return false;
            }
            let Some(coro) = self.coroutines.get(*coro_id) else {
                return false;
            };
            let guard = coro.lock().expect("threaded ProtocolMachine lock poisoned");
            if !matches!(guard.status, CoroStatus::Ready | CoroStatus::Speculating) {
                return false;
            }
            let lane = *coro_id % self.lane_count.max(1);
            if !lanes.insert(lane) {
                return false;
            }
            for endpoint in &guard.owned_endpoints {
                let key = (endpoint.sid, role_fingerprint(&endpoint.role));
                if !footprint.insert(key) {
                    return false;
                }
            }
        }
        true
    }

    /// Executes one round with concurrency `n` through the shared kernel.
    pub fn step_round(
        &mut self,
        handler: &dyn EffectHandler,
        n: usize,
    ) -> Result<StepResult, ProtocolMachineError> {
        ProtocolMachineKernel::step_round(self, handler, n)
    }

    /// Runs up to `max_rounds` rounds using the configured worker count as
    /// the concurrency.
    pub fn run(
        &mut self,
        handler: &dyn EffectHandler,
        max_rounds: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        self.run_concurrent(handler, max_rounds, self.workers.max(1))
    }

    /// Like [`Self::run`], but effects are answered from `replay_trace`, with
    /// `fallback` as the fallback handler.
    ///
    /// The trace is copied into a shared `Arc<[_]>`.
    pub fn run_replay(
        &mut self,
        fallback: &dyn EffectHandler,
        replay_trace: &[EffectTraceEntry],
        max_rounds: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        self.run_replay_shared(
            fallback,
            Arc::<[EffectTraceEntry]>::from(replay_trace),
            max_rounds,
        )
    }

    /// [`Self::run_replay`] for a trace that is already shared, avoiding a copy.
    pub fn run_replay_shared(
        &mut self,
        fallback: &dyn EffectHandler,
        replay_trace: Arc<[EffectTraceEntry]>,
        max_rounds: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        let replay = ReplayEffectHandler::with_fallback(replay_trace, fallback);
        self.run(&replay, max_rounds)
    }

    /// Runs up to `max_rounds` rounds with an explicit concurrency, through
    /// the shared kernel.
    pub fn run_concurrent(
        &mut self,
        handler: &dyn EffectHandler,
        max_rounds: usize,
        concurrency: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        ProtocolMachineKernel::run_concurrent(self, handler, max_rounds, concurrency)
    }

    /// Makes the next wave-certificate check fail, exercising the
    /// single-coroutine fallback path.
    ///
    /// The flag is consumed by that check.
    pub fn force_invalid_wave_certificate_once(&mut self) {
        self.force_invalid_wave_certificate_once = true;
    }

    /// [`Self::run_concurrent`] with effects answered from `replay_trace`,
    /// with `fallback` as the fallback handler.
    ///
    /// The trace is copied into a shared `Arc<[_]>`.
    pub fn run_concurrent_replay(
        &mut self,
        fallback: &dyn EffectHandler,
        replay_trace: &[EffectTraceEntry],
        max_rounds: usize,
        concurrency: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        self.run_concurrent_replay_shared(
            fallback,
            Arc::<[EffectTraceEntry]>::from(replay_trace),
            max_rounds,
            concurrency,
        )
    }

    /// [`Self::run_concurrent_replay`] for a trace that is already shared,
    /// avoiding a copy.
    pub fn run_concurrent_replay_shared(
        &mut self,
        fallback: &dyn EffectHandler,
        replay_trace: Arc<[EffectTraceEntry]>,
        max_rounds: usize,
        concurrency: usize,
    ) -> Result<RunStatus, ProtocolMachineError> {
        let replay = ReplayEffectHandler::with_fallback(replay_trace, fallback);
        self.run_concurrent(&replay, max_rounds, concurrency)
    }
}