use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use beamr::scheduler::Scheduler;
use crate::routing::FieldValue;
use crate::routing::function::loader::{ContentHash, RoutingFunction};
mod actor;
/// Identifier of a consumer that a routing function may select.
///
/// The wrapped string is private; read it back through [`ConsumerId::as_str`].
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ConsumerId(String);

impl ConsumerId {
    /// Creates an identifier from anything convertible into an owned `String`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        Self(owned)
    }

    /// Borrows the identifier text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
/// Snapshot of one consumer's load and advertised tags, as passed to routing
/// functions when they choose where to send a message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConsumerStateView {
    /// Consumer this snapshot describes.
    pub consumer: ConsumerId,
    /// Number of messages currently in flight for this consumer.
    pub current_in_flight: u32,
    /// Maximum number of in-flight messages this consumer accepts.
    pub max_in_flight: u32,
    /// Buffer depth reported for the consumer (not read by any method here).
    pub buffer_depth: u32,
    /// Tags the consumer advertises; matched exactly by [`Self::has_affinity`].
    pub affinity_tags: Vec<String>,
}

impl ConsumerStateView {
    /// Assembles a snapshot from its parts.
    #[must_use]
    pub const fn new(
        consumer: ConsumerId,
        current_in_flight: u32,
        max_in_flight: u32,
        buffer_depth: u32,
        affinity_tags: Vec<String>,
    ) -> Self {
        Self {
            affinity_tags,
            buffer_depth,
            consumer,
            current_in_flight,
            max_in_flight,
        }
    }

    /// Free in-flight slots: `max_in_flight - current_in_flight`, clamped at
    /// zero when the consumer is at or over its limit.
    #[must_use]
    pub const fn available_capacity(&self) -> u32 {
        let (limit, used) = (self.max_in_flight, self.current_in_flight);
        limit.saturating_sub(used)
    }

    /// `true` when at least one in-flight slot is free.
    #[must_use]
    pub const fn has_capacity(&self) -> bool {
        // Equivalent to `available_capacity() > 0` without the subtraction.
        self.current_in_flight < self.max_in_flight
    }

    /// `true` when `tag` exactly (case-sensitively) equals one of the advertised tags.
    #[must_use]
    pub fn has_affinity(&self, tag: &str) -> bool {
        self.affinity_tags
            .iter()
            .map(String::as_str)
            .any(|advertised| advertised == tag)
    }
}
/// Result of running a routing function: one chosen consumer, or no choice.
///
/// The default value is the "no choice" decision.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RoutingDecision {
    /// Chosen consumer; `None` when the function declined to pick one.
    selected: Option<ConsumerId>,
}

impl RoutingDecision {
    /// Decision that routes to `consumer`.
    #[must_use]
    pub const fn select(consumer: ConsumerId) -> Self {
        let selected = Some(consumer);
        Self { selected }
    }

    /// Decision that routes nowhere; equal to `RoutingDecision::default()`.
    #[must_use]
    pub const fn none() -> Self {
        Self { selected: None }
    }

    /// Borrows the chosen consumer, if there is one.
    #[must_use]
    pub const fn selected(&self) -> Option<&ConsumerId> {
        self.selected.as_ref()
    }

    /// `true` when a consumer was chosen.
    #[must_use]
    pub const fn is_selected(&self) -> bool {
        self.selected().is_some()
    }
}
/// Named field values of a message, made available to routing functions.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RoutingMessage {
    /// Field values keyed by name; a `BTreeMap`, so iteration is in name order.
    fields: BTreeMap<String, FieldValue>,
}

impl RoutingMessage {
    /// Creates a message with no fields.
    #[must_use]
    pub fn new() -> Self {
        Self {
            fields: BTreeMap::new(),
        }
    }

    /// Returns the message with `field` set to `value`, replacing any value
    /// previously stored under the same name.
    #[must_use]
    pub fn with(self, field: impl Into<String>, value: FieldValue) -> Self {
        let mut fields = self.fields;
        fields.insert(field.into(), value);
        Self { fields }
    }

    /// Looks up a field by name.
    #[must_use]
    pub fn get(&self, field: &str) -> Option<&FieldValue> {
        self.fields.get(field)
    }

    /// Iterates over `(name, value)` pairs in ascending name order.
    pub fn fields(&self) -> impl Iterator<Item = (&str, &FieldValue)> {
        self.fields
            .iter()
            .map(|(key, entry)| (String::as_str(key), entry))
    }
}
/// Ways a supervised routing-function invocation can fail, as returned by
/// [`SupervisedExecutor::execute`].
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum FunctionError {
    /// The invocation reported a crash (the function panicked). Carries the
    /// content hash of the function that was invoked.
    #[error("routing function '{0}' panicked during execution")]
    Crashed(ContentHash),
    /// The invocation exceeded the executor's timeout. Carries the content
    /// hash reported by the invocation layer.
    #[error("routing function '{0}' exceeded the supervision timeout")]
    TimedOut(ContentHash),
    /// The invocation could not be started. Carries the reason text reported
    /// by the invocation layer.
    #[error("routing function execution process could not be started: {0}")]
    SpawnFailed(String),
}
/// Runs routing functions through supervised `beamr` invocations, turning a
/// crash, a timeout, or a failed start into a [`FunctionError`].
///
/// Cloning is cheap: it shares the same scheduler handle.
#[derive(Clone)]
pub struct SupervisedExecutor {
    /// Scheduler each invocation is started on.
    scheduler: Arc<Scheduler>,
    /// Time limit handed to every invocation.
    timeout: Duration,
}

impl SupervisedExecutor {
    /// Time limit used by [`SupervisedExecutor::with_default_timeout`] (five seconds).
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates an executor that starts invocations on `scheduler` and bounds
    /// each one by `timeout`.
    #[must_use]
    pub const fn new(scheduler: Arc<Scheduler>, timeout: Duration) -> Self {
        Self { timeout, scheduler }
    }

    /// Creates an executor using [`SupervisedExecutor::DEFAULT_TIMEOUT`].
    #[must_use]
    pub const fn with_default_timeout(scheduler: Arc<Scheduler>) -> Self {
        let timeout = Self::DEFAULT_TIMEOUT;
        Self::new(scheduler, timeout)
    }

    /// Returns another handle to the shared scheduler (a reference-count
    /// increment, not a new scheduler).
    #[must_use]
    pub fn scheduler(&self) -> Arc<Scheduler> {
        Arc::clone(&self.scheduler)
    }

    /// Runs `function` on `message` and `consumers` in a fresh supervised
    /// invocation and returns the decision it produced.
    ///
    /// The function is cloned into the invocation; `message` and `consumers`
    /// are moved into it.
    ///
    /// # Errors
    ///
    /// - [`FunctionError::Crashed`] with `function`'s content hash when the
    ///   invocation reports a crash.
    /// - [`FunctionError::TimedOut`] with the hash reported by the invocation
    ///   when it exceeds the configured timeout.
    /// - [`FunctionError::SpawnFailed`] with the reported reason when the
    ///   invocation could not be started.
    pub fn execute(
        &self,
        function: &RoutingFunction,
        message: RoutingMessage,
        consumers: Vec<ConsumerStateView>,
    ) -> Result<RoutingDecision, FunctionError> {
        let supervised = actor::BeamrInvocation::new(Arc::clone(&self.scheduler), self.timeout);
        // A crash report carries no hash, so remember the one being invoked.
        let invoked_hash = function.content_hash();
        supervised
            .execute(function.clone(), message, consumers)
            .map_err(|failure| match failure {
                actor::InvocationError::Crashed => FunctionError::Crashed(invoked_hash),
                actor::InvocationError::TimedOut(reported_hash) => {
                    FunctionError::TimedOut(reported_hash)
                }
                actor::InvocationError::SpawnFailed(reason) => FunctionError::SpawnFailed(reason),
            })
    }
}
/// Debug output shows only the timeout; the scheduler handle is elided and
/// the struct is marked non-exhaustive (`..`).
impl std::fmt::Debug for SupervisedExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SupervisedExecutor");
        builder.field("timeout", &self.timeout);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;
    use std::time::{Duration, Instant};

    use super::{
        ConsumerId, ConsumerStateView, FunctionError, RoutingDecision, RoutingMessage,
        SupervisedExecutor,
    };
    use crate::conversation::ConversationSupervisor;
    use crate::routing::FieldValue;
    use crate::routing::function::loader::{ModuleLoader, RoutingModule, RoutingSlot};

    /// How long a test waits for a blocking routing function to start running
    /// before failing, instead of hanging the whole test run.
    const START_DEADLINE: Duration = Duration::from_secs(5);

    /// Sets the wrapped flag when dropped, so a routing function spinning on
    /// that flag is released even if an assertion panics before the test
    /// releases it explicitly.
    struct ReleaseOnDrop(Arc<AtomicBool>);

    impl Drop for ReleaseOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Builds a consumer snapshot with the given load, limit and tags, and an
    /// empty buffer.
    fn consumer(id: &str, current: u32, max: u32, tags: &[&str]) -> ConsumerStateView {
        ConsumerStateView::new(
            ConsumerId::new(id),
            current,
            max,
            0,
            tags.iter().map(|tag| (*tag).to_owned()).collect(),
        )
    }

    /// Routing module that selects the first consumer with spare capacity.
    fn select_first_with_capacity_module(bytecode: &[u8]) -> RoutingModule {
        RoutingModule::new(bytecode, |_message, consumers| {
            consumers
                .iter()
                .find(|state| state.has_capacity())
                .map_or_else(RoutingDecision::none, |state| {
                    RoutingDecision::select(state.consumer.clone())
                })
        })
    }

    /// Name of the selected consumer, if the decision picked one.
    fn selected_name(decision: &RoutingDecision) -> Option<&str> {
        decision.selected().map(ConsumerId::as_str)
    }

    /// Creates a supervisor and an executor with the default timeout on its
    /// scheduler. The supervisor is returned too, so the caller can hold it
    /// for the duration of the test.
    fn supervised_executor() -> Result<(ConversationSupervisor, SupervisedExecutor), Box<dyn Error>>
    {
        let supervisor = ConversationSupervisor::new()?;
        let executor = SupervisedExecutor::with_default_timeout(supervisor.scheduler());
        Ok((supervisor, executor))
    }

    /// The function sees consumer state and skips a saturated consumer.
    #[test]
    fn execution_returns_decision_using_consumer_state_view() -> Result<(), Box<dyn Error>> {
        let loader = ModuleLoader::new();
        let function = loader.load(select_first_with_capacity_module(b"v1"));
        let (_supervisor, executor) = supervised_executor()?;
        let consumers = vec![
            consumer("saturated", 5, 5, &["fast"]),
            consumer("ready", 1, 4, &["fast"]),
        ];
        let decision = executor.execute(&function, RoutingMessage::new(), consumers);
        assert!(matches!(decision, Ok(ref outcome) if selected_name(outcome) == Some("ready")));
        Ok(())
    }

    /// Message fields passed to `execute` are readable inside the function.
    #[test]
    fn message_fields_are_visible_to_routing_function() -> Result<(), Box<dyn Error>> {
        let loader = ModuleLoader::new();
        let function = loader.load(RoutingModule::new(
            b"amount-router",
            |message, consumers| {
                let high_value = matches!(
                    message.get("amount"),
                    Some(FieldValue::Integer(amount)) if *amount > 1_000
                );
                if high_value {
                    consumers
                        .first()
                        .map_or_else(RoutingDecision::none, |state| {
                            RoutingDecision::select(state.consumer.clone())
                        })
                } else {
                    RoutingDecision::none()
                }
            },
        ));
        let (_supervisor, executor) = supervised_executor()?;
        let message = RoutingMessage::new().with("amount", FieldValue::Integer(5_000));
        let decision = executor.execute(&function, message, vec![consumer("priority", 0, 1, &[])]);
        assert!(matches!(decision, Ok(ref outcome) if selected_name(outcome) == Some("priority")));
        Ok(())
    }

    /// A panicking function yields `Crashed` and a later healthy call still works.
    #[test]
    fn panic_in_function_is_contained_and_other_channels_proceed() -> Result<(), Box<dyn Error>> {
        let loader = ModuleLoader::new();
        // `resume_unwind` unwinds without invoking the panic hook.
        let crashing = loader.load(RoutingModule::new(b"channel-a", |_message, _consumers| {
            std::panic::resume_unwind(Box::new(
                "intentional crash for fault-isolation test".to_owned(),
            ))
        }));
        let healthy = loader.load(select_first_with_capacity_module(b"channel-b"));
        let (_supervisor, executor) = supervised_executor()?;
        let crashed = executor.execute(&crashing, RoutingMessage::new(), Vec::new());
        assert_eq!(
            crashed,
            Err(FunctionError::Crashed(crashing.content_hash()))
        );
        let recovered = executor.execute(
            &healthy,
            RoutingMessage::new(),
            vec![consumer("ready", 0, 1, &[])],
        );
        assert!(matches!(recovered, Ok(ref outcome) if selected_name(outcome) == Some("ready")));
        Ok(())
    }

    /// Alternating crashes and healthy calls on one executor keep working.
    #[test]
    fn repeated_panics_do_not_poison_the_shared_supervisor() -> Result<(), Box<dyn Error>> {
        let loader = ModuleLoader::new();
        let crashing = loader.load(RoutingModule::new(b"flaky", |_message, _consumers| {
            std::panic::resume_unwind(Box::new("repeated intentional crash".to_owned()))
        }));
        let healthy = loader.load(select_first_with_capacity_module(b"steady"));
        let (_supervisor, executor) = supervised_executor()?;
        for _ in 0..16 {
            let crashed = executor.execute(&crashing, RoutingMessage::new(), Vec::new());
            assert_eq!(
                crashed,
                Err(FunctionError::Crashed(crashing.content_hash()))
            );
            let served = executor.execute(
                &healthy,
                RoutingMessage::new(),
                vec![consumer("ready", 0, 1, &[])],
            );
            assert!(
                matches!(served, Ok(ref outcome) if selected_name(outcome) == Some("ready")),
                "scheduler must keep serving healthy invocations after a contained panic"
            );
        }
        Ok(())
    }

    /// A function slower than the executor's timeout yields `TimedOut`.
    #[test]
    fn function_exceeding_timeout_is_terminated_with_error() -> Result<(), Box<dyn Error>> {
        let loader = ModuleLoader::new();
        let slow = loader.load(RoutingModule::new(b"slow", |_message, _consumers| {
            thread::sleep(Duration::from_millis(200));
            RoutingDecision::none()
        }));
        let supervisor = ConversationSupervisor::new()?;
        let executor = SupervisedExecutor::new(supervisor.scheduler(), Duration::from_millis(20));
        let result = executor.execute(&slow, RoutingMessage::new(), Vec::new());
        assert_eq!(result, Err(FunctionError::TimedOut(slow.content_hash())));
        Ok(())
    }

    /// Deploying a new version while the old one is running lets the old call
    /// finish with the old result. The next call uses the new version.
    #[test]
    fn hot_deploy_does_not_interrupt_in_flight_and_swaps_next_version() -> Result<(), Box<dyn Error>>
    {
        let loader = ModuleLoader::new();
        let entered = Arc::new(AtomicBool::new(false));
        let release = Arc::new(AtomicBool::new(false));
        let entered_for_logic = Arc::clone(&entered);
        let release_for_logic = Arc::clone(&release);
        // Old version signals entry, then blocks until released.
        let old = loader.load(RoutingModule::new(b"v1", move |_message, _consumers| {
            entered_for_logic.store(true, Ordering::SeqCst);
            while !release_for_logic.load(Ordering::SeqCst) {
                thread::sleep(Duration::from_millis(1));
            }
            RoutingDecision::select(ConsumerId::new("old"))
        }));
        let new = loader.load(RoutingModule::new(b"v2", |_message, _consumers| {
            RoutingDecision::select(ConsumerId::new("new"))
        }));
        let old_hash = old.content_hash();
        let new_hash = new.content_hash();
        let slot = Arc::new(RoutingSlot::new(old));
        let (_supervisor, executor) = supervised_executor()?;
        // Declared after the supervisor so it drops first on a failing
        // assertion, releasing the blocked function before teardown.
        let release_guard = ReleaseOnDrop(release);
        let slot_for_thread = Arc::clone(&slot);
        let executor_for_thread = executor.clone();
        let in_flight = thread::spawn(move || {
            let function = slot_for_thread.current();
            executor_for_thread.execute(&function, RoutingMessage::new(), Vec::new())
        });
        // Bounded wait: fail rather than hang if the function never starts.
        let deadline = Instant::now() + START_DEADLINE;
        while !entered.load(Ordering::SeqCst) {
            assert!(
                Instant::now() < deadline,
                "in-flight routing function never started executing"
            );
            thread::sleep(Duration::from_millis(1));
        }
        slot.deploy(new);
        assert_eq!(slot.active_hash(), new_hash);
        assert!(
            loader.is_loaded(old_hash),
            "old module must remain loaded while in flight"
        );
        assert_eq!(loader.loaded_count(), 2);
        // Let the in-flight old version return.
        drop(release_guard);
        let in_flight_result = in_flight.join();
        assert!(matches!(
            in_flight_result,
            Ok(Ok(ref outcome)) if selected_name(outcome) == Some("old")
        ));
        let next = executor.execute(&slot.current(), RoutingMessage::new(), Vec::new());
        assert!(matches!(next, Ok(ref outcome) if selected_name(outcome) == Some("new")));
        Ok(())
    }
}