#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::agent::AlgorithmInitArgs;
use crate::network::client::agent::{ActorInferenceMode, ClientModes};
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::coordination::lifecycle_manager::SharedTransportAddresses;
use crate::network::client::runtime::coordination::state_manager::{ActorUuid, env_dtype_to_dtype};
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::data::sinks::transport_sink::TransportError;
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::data::sinks::transport_sink::transport_dispatcher::{
InferenceDispatcher, TrainingDispatcher,
};
use crate::network::client::runtime::router::{
ControlPayload, DataPayload, InferenceRequest, RoutedMessage, RoutingProtocol,
};
#[cfg(feature = "metrics")]
use crate::utilities::observability::metrics::MetricsManager;
use active_uuid_registry::registry_uuid::Uuid;
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
#[cfg(feature = "tch-backend")]
use relayrl_env_trait::EnvTchDType;
use relayrl_env_trait::{EnvDType, EnvNdArrayDType};
use relayrl_types::data::action::RelayRLAction;
#[cfg(feature = "tch-backend")]
use relayrl_types::data::tensor::TchDType;
use relayrl_types::data::tensor::{
BackendMatcher, DType, DeviceType, NdArrayDType, SupportedTensorBackend, TensorData,
};
use relayrl_types::data::trajectory::RelayRLTrajectory;
use relayrl_types::model::utils::{deserialize_model_module, validate_module};
use relayrl_types::model::{HotReloadableModel, ModelError, ModelModule};
use relayrl_types::prelude::tensor::relayrl::AnyBurnTensor;
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
#[cfg(feature = "metrics")]
use std::time::Instant;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::oneshot;
use tokio::sync::{Mutex, RwLock};
use burn_tensor::backend::Backend;
use thiserror::Error;
pub(crate) type LocalModelHandle<B> = Arc<ArcSwapOption<HotReloadableModel<B>>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ActorShape {
pub(crate) d_in: usize,
pub(crate) d_out: usize,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ActorDTypes {
pub(crate) dtype_in: DType,
pub(crate) dtype_out: DType,
}
pub(crate) trait ErasedActorRuntime<B: Backend + BackendMatcher<Backend = B>>:
Send + Sync
{
fn actor_shape(&self) -> ActorShape;
fn current_model_dtypes(&self) -> Result<ActorDTypes, ActorError>;
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
fn current_algorithm_args(&self) -> Result<AlgorithmInitArgs, ActorError>;
fn request_inference_erased(
&self,
observation: Box<dyn Any + Send + Sync>,
mask: Box<dyn Any + Send + Sync>,
reward: f32,
) -> Pin<Box<dyn Future<Output = Result<RelayRLAction, ActorError>> + Send + '_>>;
fn flag_last_action_erased(
&self,
reward: f32,
env_id: Option<Uuid>,
env_label: Option<String>,
return_traj_override: bool,
) -> Pin<Box<dyn Future<Output = Result<Option<RelayRLTrajectory>, ActorError>> + Send + '_>>;
fn perform_env_byte_inference_erased<'a>(
&'a self,
model_name: &'a str,
obs_bytes: &'a [u8],
n_envs: usize,
obs_dim: usize,
obs_dtype: &'a DType,
) -> Pin<Box<dyn Future<Output = Result<TensorData, ActorError>> + Send + 'a>>;
fn perform_env_refresh_model_erased<'a>(
&'a self,
name: &'a str,
model_module: ModelModule<B>,
device: DeviceType,
) -> Pin<Box<dyn Future<Output = Result<(), ActorError>> + Send + 'a>>;
}
struct ActorTrajectoryState {
current_traj: RelayRLTrajectory,
current_episode: u64,
per_env_trajs: HashMap<Uuid, RelayRLTrajectory>,
per_env_episodes: HashMap<Uuid, u64>,
per_env_labels: HashMap<Uuid, String>,
}
impl ActorTrajectoryState {
fn new(max_traj_length: Arc<usize>) -> Self {
Self {
current_traj: RelayRLTrajectory::new(*max_traj_length),
current_episode: 0,
per_env_trajs: HashMap::new(),
per_env_episodes: HashMap::new(),
per_env_labels: HashMap::new(),
}
}
fn ensure_env_trajectory(
&mut self,
actor_id: ActorUuid,
env_id: Uuid,
env_label: impl Into<String>,
max_traj_length: Arc<usize>,
) -> &mut RelayRLTrajectory {
let label = env_label.into();
self.per_env_labels
.entry(env_id)
.or_insert_with(|| label.clone());
self.per_env_trajs.entry(env_id).or_insert_with(|| {
RelayRLTrajectory::with_metadata(
*max_traj_length,
Some(actor_id),
Some(env_id),
Some(label),
None,
None,
)
})
}
fn take_completed_env_trajectory(
&mut self,
actor_id: ActorUuid,
env_id: Uuid,
env_label: Option<String>,
reward: f32,
max_traj_length: Arc<usize>,
) -> RelayRLTrajectory {
let label = env_label.unwrap_or_else(|| {
self.per_env_labels
.get(&env_id)
.cloned()
.unwrap_or_else(|| env_id.to_string())
});
let traj =
self.ensure_env_trajectory(actor_id, env_id, label.clone(), max_traj_length.clone());
traj.add_action(RelayRLAction::new(
None,
None,
None,
reward,
true,
None,
Some(actor_id),
));
let mut traj_to_send = self
.per_env_trajs
.insert(
env_id,
RelayRLTrajectory::with_metadata(
*max_traj_length,
Some(actor_id),
Some(env_id),
Some(label),
None,
None,
),
)
.expect("env trajectory should exist before flush");
let episode = self.per_env_episodes.entry(env_id).or_insert(0);
traj_to_send.set_episode(*episode);
*episode += 1;
traj_to_send
}
fn take_completed_actor_trajectory(
&mut self,
actor_id: ActorUuid,
reward: f32,
max_traj_length: Arc<usize>,
) -> RelayRLTrajectory {
let last_action = RelayRLAction::new(None, None, None, reward, true, None, Some(actor_id));
self.current_traj.add_action(last_action);
let mut traj_to_send = std::mem::replace(
&mut self.current_traj,
RelayRLTrajectory::new(*max_traj_length),
);
traj_to_send.set_episode(self.current_episode);
self.current_episode += 1;
traj_to_send
}
fn take_shutdown_trajectories(
&mut self,
actor_id: ActorUuid,
max_traj_length: Arc<usize>,
) -> Vec<RelayRLTrajectory> {
let mut trajectories = Vec::new();
if !self.current_traj.actions.is_empty() {
trajectories.push(std::mem::replace(
&mut self.current_traj,
RelayRLTrajectory::new(*max_traj_length),
));
}
let env_ids_to_flush: Vec<_> = self
.per_env_trajs
.iter()
.filter_map(|(env_id, traj)| (!traj.actions.is_empty()).then_some(*env_id))
.collect();
for env_id in env_ids_to_flush {
let env_label = self.per_env_labels.get(&env_id).cloned();
trajectories.push(self.take_completed_env_trajectory(
actor_id,
env_id,
env_label,
0.0,
max_traj_length.clone(),
));
}
trajectories
}
}
pub(crate) struct ActorRuntime<
B: Backend + BackendMatcher<Backend = B>,
const D_IN: usize,
const D_OUT: usize,
> {
actor_id: ActorUuid,
pub(crate) reloadable_model: LocalModelHandle<B>,
temp_env_models: DashMap<String, LocalModelHandle<B>>,
max_traj_length: Arc<usize>,
shared_tx_to_buffer: Sender<RoutedMessage>,
trajectories: Mutex<ActorTrajectoryState>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
algorithm_args: AlgorithmInitArgs,
#[cfg(feature = "metrics")]
metrics: MetricsManager,
}
impl<
B: Backend + BackendMatcher<Backend = B> + Send + Sync + 'static,
const D_IN: usize,
const D_OUT: usize,
> ActorRuntime<B, D_IN, D_OUT>
{
pub(crate) async fn new(
actor_id: ActorUuid,
reloadable_model: LocalModelHandle<B>,
max_traj_length: usize,
shared_tx_to_buffer: Sender<RoutedMessage>,
#[cfg(feature = "metrics")] metrics: MetricsManager,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
algorithm_args: AlgorithmInitArgs,
) -> Self {
let max_traj_length = Arc::new(max_traj_length);
Self {
actor_id,
reloadable_model,
temp_env_models: DashMap::new(),
max_traj_length: max_traj_length.clone(),
shared_tx_to_buffer,
trajectories: Mutex::new(ActorTrajectoryState::new(max_traj_length.clone())),
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
algorithm_args,
#[cfg(feature = "metrics")]
metrics,
}
}
fn build_send_trajectory_message(
&self,
trajectory: RelayRLTrajectory,
) -> Result<RoutedMessage, ActorError> {
let now = SystemTime::now();
let duration = now
.duration_since(UNIX_EPOCH)
.map_err(|e| ActorError::SystemError(format!("Clock skew: {e}")))?;
Ok(RoutedMessage {
actor_id: self.actor_id,
protocol: RoutingProtocol::Data(DataPayload::SendTrajectory {
timestamp: (duration.as_millis(), duration.as_nanos()),
trajectory,
}),
})
}
async fn send_trajectory(&self, trajectory: RelayRLTrajectory) -> Result<(), ActorError> {
if self.shared_tx_to_buffer.is_closed() {
return Ok(());
}
let send_traj_msg = self.build_send_trajectory_message(trajectory)?;
self.shared_tx_to_buffer
.send(send_traj_msg)
.await
.map_err(|e| ActorError::TrajectorySendError(format!("{e:?}")))
}
pub(crate) async fn perform_local_inference(
&self,
observation: Arc<AnyBurnTensor<B, D_IN>>,
mask: Option<Arc<AnyBurnTensor<B, D_OUT>>>,
reward: f32,
) -> Result<RelayRLAction, ActorError> {
#[cfg(feature = "metrics")]
let start_time = Instant::now();
let result = async {
let action = {
let guard = self.reloadable_model.load();
let reloadable_model = match &*guard {
Some(m) => m,
None => {
return Err(ActorError::SystemError(
"Model not loaded/available for actor inference".to_string(),
));
}
};
reloadable_model
.forward::<D_IN, D_OUT>(observation, mask, reward, self.actor_id)
.map_err(ActorError::from)?
};
let mut trajectories = self.trajectories.lock().await;
trajectories.current_traj.add_action(action.clone());
Ok(action)
}
.await;
#[cfg(feature = "metrics")]
match &result {
Ok(_) => {
let duration = start_time.elapsed().as_secs_f64();
self.metrics
.record_histogram("actor_local_inference_latency", duration, &[])
.await;
self.metrics
.record_counter("actor_local_inferences", 1, &[])
.await;
}
Err(_) => {
self.metrics
.record_counter("actor_local_inference_failures", 1, &[])
.await;
}
}
result
}
pub(crate) async fn perform_env_byte_inference(
&self,
model_name: &str,
obs_bytes: &[u8],
n_envs: usize,
obs_dim: usize,
obs_dtype: &DType,
) -> Result<TensorData, ActorError> {
let guard = self.temp_env_models.get(model_name);
let slot = guard.as_ref().ok_or_else(|| {
ActorError::SystemError(format!(
"Model slot '{}' not loaded — call perform_env_refresh_model first",
model_name
))
})?;
let model = slot.load_full().ok_or_else(|| {
ActorError::SystemError(format!("Model slot '{}' handle is empty", model_name))
})?;
let module = model.current_module();
let input = TensorData::new(
vec![n_envs, obs_dim],
obs_dtype.clone(),
obs_bytes.to_vec(),
B::get_supported_backend(),
);
Ok(module
.flat_batch_inference(input)
.unwrap_or_else(|_| module.flat_batch_zeros(n_envs)))
}
pub(crate) async fn flag_last_action(
&self,
reward: f32,
env_id: Option<Uuid>,
env_label: Option<String>,
return_traj_override: bool,
) -> Result<Option<RelayRLTrajectory>, ActorError> {
#[cfg(feature = "metrics")]
let start_time = Instant::now();
let result = async {
let maybe_trajectory = {
let mut trajectories = self.trajectories.lock().await;
Some(match env_id {
Some(env_id) => trajectories.take_completed_env_trajectory(
self.actor_id,
env_id,
env_label,
reward,
self.max_traj_length.clone(),
),
None => trajectories.take_completed_actor_trajectory(
self.actor_id,
reward,
self.max_traj_length.clone(),
),
})
};
if let Some(trajectory) = maybe_trajectory {
if return_traj_override {
return Ok(Some(trajectory));
} else {
self.send_trajectory(trajectory).await?;
}
}
Ok(None)
}
.await;
#[cfg(feature = "metrics")]
match &result {
Ok(_) => {
let duration = start_time.elapsed().as_secs_f64();
self.metrics
.record_histogram("actor_flag_last_action_latency", duration, &[])
.await;
self.metrics
.record_counter("actor_flag_last_actions", 1, &[])
.await;
}
Err(_) => {
self.metrics
.record_counter("actor_flag_last_action_failures", 1, &[])
.await;
}
}
result
}
pub(crate) async fn perform_env_refresh_model(
&self,
name: &str,
model_module: ModelModule<B>,
device: DeviceType,
) -> Result<(), ActorError> {
if let Some(slot) = self.temp_env_models.get(name) {
match slot.load_full() {
Some(model) => {
model
.reload_from_module(model_module, model.version())
.await
.map_err(ActorError::from)?;
}
None => {
let reloadable = Arc::new(
HotReloadableModel::<B>::new_from_module(model_module, device).await?,
);
slot.store(Some(reloadable));
}
}
} else {
let reloadable =
Arc::new(HotReloadableModel::<B>::new_from_module(model_module, device).await?);
self.temp_env_models.insert(
name.to_string(),
Arc::new(arc_swap::ArcSwapOption::new(Some(reloadable))),
);
}
Ok(())
}
pub(crate) async fn flush_shutdown_trajectories(&self) -> Result<(), ActorError> {
let trajectories = {
let mut state = self.trajectories.lock().await;
state.take_shutdown_trajectories(self.actor_id, self.max_traj_length.clone())
};
for trajectory in trajectories {
self.send_trajectory(trajectory).await?;
}
Ok(())
}
}
impl<B, const D_IN: usize, const D_OUT: usize> ErasedActorRuntime<B>
for ActorRuntime<B, D_IN, D_OUT>
where
B: Backend + BackendMatcher<Backend = B> + Send + Sync + 'static,
{
fn actor_shape(&self) -> ActorShape {
ActorShape {
d_in: D_IN,
d_out: D_OUT,
}
}
fn current_model_dtypes(&self) -> Result<ActorDTypes, ActorError> {
let model = self
.reloadable_model
.load()
.clone()
.ok_or(ActorError::ModelNotLoadedError)?;
Ok(ActorDTypes {
dtype_in: model.current_module().metadata.input_dtype.clone(),
dtype_out: model.current_module().metadata.output_dtype.clone(),
})
}
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
fn current_algorithm_args(&self) -> Result<AlgorithmInitArgs, ActorError> {
Ok(self.algorithm_args.clone())
}
fn request_inference_erased(
&self,
observation: Box<dyn Any + Send + Sync>,
mask: Box<dyn Any + Send + Sync>,
reward: f32,
) -> Pin<Box<dyn Future<Output = Result<RelayRLAction, ActorError>> + Send + '_>> {
Box::pin(async move {
let observation = *observation
.downcast::<Arc<AnyBurnTensor<B, D_IN>>>()
.map_err(|_| {
ActorError::TypeConversionError(
"invalid observation shape for actor runtime".to_string(),
)
})?;
let mask = *mask
.downcast::<Option<Arc<AnyBurnTensor<B, D_OUT>>>>()
.map_err(|_| {
ActorError::TypeConversionError(
"invalid mask shape for actor runtime".to_string(),
)
})?;
self.perform_local_inference(observation, mask, reward)
.await
})
}
fn flag_last_action_erased(
&self,
reward: f32,
env_id: Option<Uuid>,
env_label: Option<String>,
return_traj_override: bool,
) -> Pin<Box<dyn Future<Output = Result<Option<RelayRLTrajectory>, ActorError>> + Send + '_>>
{
Box::pin(async move {
self.flag_last_action(reward, env_id, env_label, return_traj_override)
.await
})
}
fn perform_env_byte_inference_erased<'a>(
&'a self,
model_name: &'a str,
obs_bytes: &'a [u8],
n_envs: usize,
obs_dim: usize,
obs_dtype: &'a DType,
) -> Pin<Box<dyn Future<Output = Result<TensorData, ActorError>> + Send + 'a>> {
Box::pin(async move {
self.perform_env_byte_inference(model_name, obs_bytes, n_envs, obs_dim, obs_dtype)
.await
})
}
fn perform_env_refresh_model_erased<'a>(
&'a self,
name: &'a str,
model_module: ModelModule<B>,
device: DeviceType,
) -> Pin<Box<dyn Future<Output = Result<(), ActorError>> + Send + 'a>> {
Box::pin(async move {
self.perform_env_refresh_model(name, model_module, device)
.await
})
}
}
#[derive(Debug, Error)]
#[allow(clippy::enum_variant_names)]
pub enum ActorError {
#[error(transparent)]
ModelError(#[from] ModelError),
#[error("Model not loaded")]
ModelNotLoadedError,
#[error("Trajectory send failed: {0}")]
TrajectorySendError(String),
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
#[error("Inference request failed: {0}")]
InferenceRequestError(String),
#[error("Message handling failed: {0}")]
MessageHandlingError(String),
#[error("Type conversion failed: {0}")]
TypeConversionError(String),
#[error("System error: {0}")]
SystemError(String),
#[error(transparent)]
UuidPoolError(#[from] active_uuid_registry::UuidPoolError),
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
#[error(transparent)]
TransportError(#[from] TransportError),
}
pub trait ActorEntity<
B: Backend + BackendMatcher<Backend = B>,
const D_IN: usize,
const D_OUT: usize,
>: Send + Sync + 'static
{
#[allow(clippy::too_many_arguments)]
async fn new(
client_namespace: Arc<str>,
actor_id: ActorUuid,
device: DeviceType,
runtime: Arc<ActorRuntime<B, D_IN, D_OUT>>,
shared_local_model_path: Arc<RwLock<PathBuf>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
rx_from_router: Receiver<RoutedMessage>,
shared_client_modes: Arc<ClientModes>,
#[cfg(feature = "metrics")] metrics: MetricsManager,
) -> Self
where
Self: Sized;
async fn spawn_loop(&mut self) -> Result<(), ActorError>;
async fn initial_model_handshake(&mut self, msg: RoutedMessage) -> Result<(), ActorError>;
async fn get_model_version(&self, msg: RoutedMessage) -> Result<(), ActorError>;
async fn refresh_model(&self, msg: RoutedMessage) -> Result<(), ActorError>;
async fn handle_shutdown(&mut self, _msg: RoutedMessage) -> Result<(), ActorError>;
}
pub(crate) struct Actor<
B: Backend + BackendMatcher<Backend = B>,
const D_IN: usize,
const D_OUT: usize,
> {
#[allow(dead_code)]
client_namespace: Arc<str>,
actor_id: ActorUuid,
runtime: Arc<ActorRuntime<B, D_IN, D_OUT>>,
shared_local_model_path: Arc<RwLock<PathBuf>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
model_device: DeviceType,
rx_from_router: Receiver<RoutedMessage>,
shared_client_modes: Arc<ClientModes>,
#[cfg(feature = "metrics")]
metrics: MetricsManager,
}
impl<
B: Backend + BackendMatcher<Backend = B> + Send + Sync + 'static,
const D_IN: usize,
const D_OUT: usize,
> Actor<B, D_IN, D_OUT>
{
#[inline(always)]
#[allow(clippy::type_complexity)]
fn extract_inference_request(
msg: RoutedMessage,
) -> Result<
(
Arc<AnyBurnTensor<B, D_IN>>,
Option<Arc<AnyBurnTensor<B, D_OUT>>>,
f32,
oneshot::Sender<Arc<RelayRLAction>>,
),
ActorError,
> {
let RoutingProtocol::Data(DataPayload::RequestInference(req)) = msg.protocol else {
return Err(ActorError::MessageHandlingError(
"Expected RequestInference payload".to_string(),
));
};
let InferenceRequest {
observation,
mask,
reward,
reply_to,
} = *req;
let obs: Arc<AnyBurnTensor<B, D_IN>> = *observation
.downcast::<Arc<AnyBurnTensor<B, D_IN>>>()
.map_err(|_| {
ActorError::TypeConversionError("Failed to downcast observation".into())
})?;
let mask: Option<Arc<AnyBurnTensor<B, D_OUT>>> = *mask
.downcast::<Option<Arc<AnyBurnTensor<B, D_OUT>>>>()
.map_err(|_| ActorError::TypeConversionError("Failed to downcast mask".into()))?;
Ok((obs, mask, reward, reply_to))
}
#[inline(always)]
async fn handle_inference_kind(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
match self.shared_client_modes.actor_inference_mode {
ActorInferenceMode::Local(_) => self.perform_local_inference(msg).await,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
ActorInferenceMode::Server(_) => self.request_server_inference(msg).await,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
ActorInferenceMode::ServerOverflow(_, _) => {
self.request_server_overflow_inference(msg).await
}
}
}
async fn perform_local_inference(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
let (obs, mask, reward, reply_to) = Self::extract_inference_request(msg)?;
let action = self
.runtime
.perform_local_inference(obs, mask, reward)
.await?;
reply_to.send(Arc::new(action)).map_err(|e| {
ActorError::MessageHandlingError(format!("reply_to send failed: {e:?}"))
})?;
Ok(())
}
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
async fn request_server_inference(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
if let Some(inference_dispatcher) = &self.shared_inference_dispatcher {
let shared_transport_addresses = self
.shared_transport_addresses
.as_ref()
.ok_or_else(|| ActorError::SystemError("Server addresses not available".into()))?
.clone();
let (obs, _mask, _reward, reply_to) = Self::extract_inference_request(msg)?;
let obs_bytes: Vec<u8> = Vec::new();
let _ = obs;
let actor_entry = (
self.client_namespace.to_string(),
crate::network::ACTOR_CONTEXT.to_string(),
self.actor_id,
);
let rla = inference_dispatcher
.send_inference_request(actor_entry, obs_bytes, shared_transport_addresses)
.await?;
let mut trajectories = self.runtime.trajectories.lock().await;
trajectories.current_traj.add_action(rla.clone());
reply_to.send(Arc::new(rla)).map_err(|e| {
ActorError::MessageHandlingError(format!("reply_to send failed: {e:?}"))
})?;
} else {
return self.perform_local_inference(msg).await;
}
Ok(())
}
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
async fn request_server_overflow_inference(
&mut self,
msg: RoutedMessage,
) -> Result<(), ActorError> {
todo!()
}
async fn perform_flag_last_action(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
if let RoutingProtocol::Data(DataPayload::FlagLastAction {
reward,
env_id,
env_label,
}) = msg.protocol
{
let _ = self
.runtime
.flag_last_action(reward, env_id, env_label, false)
.await;
}
Ok(())
}
}
impl<
B: Backend + BackendMatcher<Backend = B> + Send + Sync + 'static,
const D_IN: usize,
const D_OUT: usize,
> ActorEntity<B, D_IN, D_OUT> for Actor<B, D_IN, D_OUT>
{
async fn new(
client_namespace: Arc<str>,
actor_id: ActorUuid,
device: DeviceType,
runtime: Arc<ActorRuntime<B, D_IN, D_OUT>>,
shared_local_model_path: Arc<RwLock<PathBuf>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
rx_from_router: Receiver<RoutedMessage>,
shared_client_modes: Arc<ClientModes>,
#[cfg(feature = "metrics")] metrics: MetricsManager,
) -> Self
where
Self: Sized,
{
let model_init_flag = runtime.reloadable_model.load_full().is_none();
if model_init_flag {
log::warn!(
"[ActorEntity] Startup model is None, initial model handshake necessitated..."
);
}
let actor: Actor<B, D_IN, D_OUT> = Self {
client_namespace,
actor_id,
runtime,
shared_local_model_path,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses,
model_device: device,
rx_from_router,
shared_client_modes,
#[cfg(feature = "metrics")]
metrics,
};
actor
}
async fn spawn_loop(&mut self) -> Result<(), ActorError> {
while let Some(msg) = self.rx_from_router.recv().await {
match msg.protocol {
RoutingProtocol::Control(ControlPayload::ModelHandshake) => {
self.initial_model_handshake(msg).await?;
}
RoutingProtocol::Data(DataPayload::RequestInference(_)) => {
self.handle_inference_kind(msg).await?;
}
RoutingProtocol::Data(DataPayload::FlagLastAction {
reward: _,
env_id: _,
env_label: _,
}) => {
self.perform_flag_last_action(msg).await?;
}
RoutingProtocol::Control(ControlPayload::ModelVersion { reply_to: _ }) => {
self.get_model_version(msg).await?;
}
RoutingProtocol::Control(ControlPayload::ModelUpdate {
model_bytes: _,
version: _,
}) => {
self.refresh_model(msg).await?;
}
RoutingProtocol::Control(ControlPayload::Shutdown) => {
self.handle_shutdown(msg).await?;
break;
}
_ => {}
}
}
Ok(())
}
async fn initial_model_handshake(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
if let RoutingProtocol::Control(ControlPayload::ModelHandshake) = msg.protocol {
if self.runtime.reloadable_model.load_full().is_some() {
log::warn!(
"[Actor {:?}] Model already available, handshake not needed",
self.actor_id
);
return Ok(());
}
#[cfg(feature = "metrics")]
let start_time = Instant::now();
let result: Result<bool, ActorError> = async {
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
{
if let Some(training_dispatcher) = &self.shared_training_dispatcher {
log::info!(
"[Actor {:?}] Starting training model handshake",
self.actor_id
);
let shared_transport_addresses = self
.shared_transport_addresses
.as_ref()
.ok_or_else(|| {
ActorError::SystemError("Server addresses not available".into())
})?
.clone();
let actor_entry = (
self.client_namespace.to_string(),
crate::network::ACTOR_CONTEXT.to_string(),
self.actor_id,
);
match training_dispatcher
.initial_model_handshake(actor_entry, shared_transport_addresses)
.await
{
Ok(Some(model)) => {
log::info!(
"[Actor {:?}] Model handshake successful, received model data",
self.actor_id
);
if let Err(e) =
model.save(self.shared_local_model_path.read().await.clone())
{
log::error!(
"[Actor {:?}] Failed to save model: {:?}",
self.actor_id,
e
);
}
let model_path = self.shared_local_model_path.clone();
let model_device = self.model_device.clone();
let actor_id = self.actor_id;
match self.runtime.reloadable_model.load_full() {
Some(existing_model) => {
let version = existing_model.version() + 1;
existing_model
.reload_from_path(
model_path.read().await.clone(),
version,
)
.await
.map_err(|e| {
log::error!(
"[Actor {:?}] Failed to reload model: {:?}",
actor_id,
e
);
ActorError::ModelError(e)
})?;
}
None => {
let reloadable_model = Arc::new(
HotReloadableModel::<B>::new_from_module(
model,
model_device,
)
.await
.map_err(ActorError::from)?,
);
self.runtime.reloadable_model.store(Some(reloadable_model));
}
}
Ok(true)
}
_ => {
log::error!(
"[Actor {:?}] Model handshake failed or no model update needed",
self.actor_id
);
Ok(false)
}
}
} else {
log::error!(
"[Actor {:?}] No transport dispatcher configured for model handshake",
self.actor_id
);
Ok(false)
}
}
#[cfg(not(any(feature = "nats-transport", feature = "zmq-transport")))]
{
log::error!(
"[Actor {:?}] No transport dispatcher configured for model handshake",
self.actor_id
);
Ok(false)
}
}
.await;
#[cfg(feature = "metrics")]
match &result {
Ok(true) => {
let duration = start_time.elapsed().as_secs_f64();
self.metrics
.record_histogram("actor_model_handshake_latency", duration, &[])
.await;
self.metrics
.record_counter("actor_model_handshakes", 1, &[])
.await;
}
Ok(false) | Err(_) => {
self.metrics
.record_counter("actor_model_handshake_failures", 1, &[])
.await;
}
}
return result.map(|_| ());
}
Ok(())
}
async fn get_model_version(&self, msg: RoutedMessage) -> Result<(), ActorError> {
if let RoutingProtocol::Control(ControlPayload::ModelVersion { reply_to }) = msg.protocol {
let version = self
.runtime
.reloadable_model
.load_full()
.map(|model| model.version())
.unwrap_or(-1);
reply_to
.send(version)
.map_err(|e| ActorError::MessageHandlingError(format!("{:?}", e)))?;
}
Ok(())
}
async fn refresh_model(&self, msg: RoutedMessage) -> Result<(), ActorError> {
if let RoutingProtocol::Control(ControlPayload::ModelUpdate {
model_bytes,
version,
}) = msg.protocol
{
#[cfg(feature = "metrics")]
let start_time = Instant::now();
let result: Result<bool, ActorError> = async {
let model: Result<ModelModule<B>, ModelError> =
deserialize_model_module::<B>(model_bytes, self.model_device.clone());
let model_path: PathBuf = self.shared_local_model_path.read().await.clone();
if let Ok(ok_model) = model {
if let Err(e) = validate_module::<B>(&ok_model).map_err(ActorError::from) {
log::error!(
"[ActorEntity {:?}] Failed to validate model: {:?}",
self.actor_id,
e
);
return Err(e);
}
if let Err(e) = ok_model.save(&model_path).map_err(ActorError::from) {
log::error!(
"[ActorEntity {:?}] Failed to save model: {:?}",
self.actor_id,
e
);
return Err(e);
}
let model_device = self.model_device.clone();
match self.runtime.reloadable_model.load_full() {
Some(existing_model) => {
existing_model
.reload_from_module(ok_model, version)
.await
.map_err(ActorError::from)?;
}
None => {
let reloadable_model = Arc::new(
HotReloadableModel::<B>::new_from_module(ok_model, model_device)
.await
.map_err(ActorError::from)?,
);
self.runtime.reloadable_model.store(Some(reloadable_model));
}
}
Ok(true)
} else {
Ok(false)
}
}
.await;
#[cfg(feature = "metrics")]
match &result {
Ok(true) => {
let duration = start_time.elapsed().as_secs_f64();
self.metrics
.record_histogram("actor_model_refresh_latency", duration, &[])
.await;
self.metrics
.record_counter("actor_model_refreshes", 1, &[])
.await;
}
Ok(false) | Err(_) => {
self.metrics
.record_counter("actor_model_refresh_failures", 1, &[])
.await;
}
}
return result.map(|_| ());
}
Ok(())
}
async fn handle_shutdown(&mut self, _msg: RoutedMessage) -> Result<(), ActorError> {
let _ = self.runtime.flush_shutdown_trajectories().await;
active_uuid_registry::interface::remove_id(
self.client_namespace.as_ref(),
crate::network::ACTOR_CONTEXT,
self.actor_id,
)
.map_err(ActorError::from)?;
Ok(())
}
}
#[cfg(test)]
mod unit_tests {
use super::*;
use crate::network::client::agent::{
ActorInferenceMode, ActorTrainingDataMode, ClientModes, ModelMode,
};
use crate::network::client::runtime::coordination::coordinator::CHANNEL_THROUGHPUT;
use active_uuid_registry::registry_uuid::Uuid;
use relayrl_types::data::tensor::{DType, DeviceType, NdArrayDType};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc, oneshot};
use burn_ndarray::NdArray;
use burn_tensor::{Float, Tensor, TensorData as BurnTensorData};
use relayrl_types::prelude::tensor::relayrl::FloatBurnTensor;
type NdArrayBackend = NdArray<f32>;
const D_IN: usize = 4;
const D_OUT: usize = 1;
fn disabled_data_mode() -> Arc<ClientModes> {
Arc::new(ClientModes {
actor_inference_mode: ActorInferenceMode::Local(ModelMode::Independent),
actor_training_data_mode: ActorTrainingDataMode::Disabled,
})
}
fn empty_onnx_model_handle() -> LocalModelHandle<NdArrayBackend> {
Arc::new(ArcSwapOption::new(None))
}
fn float_any_tensor(values: &[f32]) -> Arc<AnyBurnTensor<NdArrayBackend, D_IN>> {
let device = NdArrayBackend::get_device(&DeviceType::Cpu).unwrap();
let tensor = Tensor::<NdArrayBackend, D_IN, Float>::from_data(
BurnTensorData::new(values.to_vec(), [1, 1, 1, values.len()]),
&device,
);
Arc::new(AnyBurnTensor::Float(FloatBurnTensor {
tensor: Arc::new(tensor),
dtype: DType::NdArray(NdArrayDType::F32),
}))
}
async fn create_ndarray_actor(
max_traj_length: usize,
device: DeviceType,
) -> (
Actor<NdArrayBackend, D_IN, D_OUT>,
mpsc::Sender<RoutedMessage>,
mpsc::Receiver<RoutedMessage>,
) {
let actor_id = Uuid::new_v4();
let (tx_to_actor, rx_from_router) = mpsc::channel::<RoutedMessage>(CHANNEL_THROUGHPUT);
let (tx_to_buffer, rx_from_buffer) = mpsc::channel::<RoutedMessage>(CHANNEL_THROUGHPUT);
let model_handle = empty_onnx_model_handle();
let runtime = Arc::new(
ActorRuntime::new(
actor_id,
model_handle,
max_traj_length,
tx_to_buffer.clone(),
#[cfg(feature = "metrics")]
test_metrics(),
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
AlgorithmInitArgs::default(),
)
.await,
);
let actor = Actor::<NdArrayBackend, D_IN, D_OUT>::new(
Arc::from("test-actor-namespace"),
actor_id,
device,
runtime,
Arc::new(RwLock::new(PathBuf::new())),
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
None,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
None,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
None,
rx_from_router,
disabled_data_mode(),
#[cfg(feature = "metrics")]
test_metrics(),
)
.await;
(actor, tx_to_actor, rx_from_buffer)
}
fn build_msg(actor_id: ActorUuid, protocol: RoutingProtocol) -> RoutedMessage {
RoutedMessage { actor_id, protocol }
}
#[cfg(feature = "metrics")]
fn test_metrics() -> MetricsManager {
MetricsManager::new(
Arc::new(RwLock::new(("test-actor".to_string(), String::new()))),
("test-actor".to_string(), String::new()),
None,
)
}
#[tokio::test]
async fn spawn_loop_exits_on_channel_close() {
let (mut actor, tx, _rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let handle = tokio::spawn(async move { actor.spawn_loop().await });
drop(tx);
let result = tokio::time::timeout(tokio::time::Duration::from_millis(300), handle)
.await
.expect("spawn_loop did not exit in time")
.expect("join error");
assert!(
result.is_ok(),
"spawn_loop did not exit in time on channel close for NdArray, CPU actor"
);
}
#[tokio::test]
async fn spawn_loop_exits_on_shutdown_message() {
let (mut actor, tx, _rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
active_uuid_registry::interface::add_id(
"test-actor-namespace",
crate::network::ACTOR_CONTEXT,
actor_id,
)
.unwrap();
let handle = tokio::spawn(async move { actor.spawn_loop().await });
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::Shutdown),
))
.await
.unwrap();
let result = tokio::time::timeout(tokio::time::Duration::from_millis(300), handle)
.await
.expect("spawn_loop did not exit on Shutdown")
.expect("join error");
assert!(
result.is_ok(),
"spawn_loop did not exit on Shutdown for NdArray, CPU actor"
);
}
#[tokio::test]
async fn get_model_version_returns_minus_one_when_no_model() {
let (mut actor, tx, _rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
let handle = tokio::spawn(async move { actor.spawn_loop().await });
let (reply_tx, reply_rx) = oneshot::channel::<i64>();
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::ModelVersion { reply_to: reply_tx }),
))
.await
.unwrap();
let version = tokio::time::timeout(tokio::time::Duration::from_millis(200), reply_rx)
.await
.expect("timeout waiting for model version")
.expect("oneshot cancelled");
assert_eq!(version, -1, "Unloaded model should report version -1");
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::Shutdown),
))
.await
.unwrap();
let _ = handle.await;
}
#[tokio::test]
async fn handle_shutdown_sends_trajectory_when_non_empty() {
let (mut actor, tx, mut rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
let handle = tokio::spawn(async move { actor.spawn_loop().await });
tx.send(build_msg(
actor_id,
RoutingProtocol::Data(DataPayload::FlagLastAction {
reward: 1.0,
env_id: None,
env_label: None,
}),
))
.await
.unwrap();
let traj_msg = tokio::time::timeout(tokio::time::Duration::from_millis(200), rx_buf.recv())
.await
.expect("timeout waiting for trajectory after FlagLastInference")
.expect("buffer rx closed");
assert!(matches!(
traj_msg.protocol,
RoutingProtocol::Data(DataPayload::SendTrajectory { .. })
));
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::Shutdown),
))
.await
.unwrap();
let _ = handle.await;
}
#[tokio::test]
async fn handle_shutdown_does_not_send_when_empty_traj() {
let (mut actor, tx, mut rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
let handle = tokio::spawn(async move { actor.spawn_loop().await });
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::Shutdown),
))
.await
.unwrap();
let _ = handle.await;
assert!(
rx_buf.try_recv().is_err(),
"Buffer should be empty after shutdown with no trajectory"
);
}
#[tokio::test]
async fn flag_last_action_appends_terminal_action_and_sends_traj() {
let (mut actor, tx, mut rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
let _handle = tokio::spawn(async move { actor.spawn_loop().await });
tx.send(build_msg(
actor_id,
RoutingProtocol::Data(DataPayload::FlagLastAction {
reward: 0.0,
env_id: None,
env_label: None,
}),
))
.await
.unwrap();
let msg = tokio::time::timeout(tokio::time::Duration::from_millis(300), rx_buf.recv())
.await
.expect("timeout waiting for trajectory")
.expect("buffer rx closed");
assert!(
matches!(
msg.protocol,
RoutingProtocol::Data(DataPayload::SendTrajectory { .. })
),
"FlagLastInference should produce a SendTrajectory message"
);
}
#[tokio::test]
async fn flag_last_action_with_env_id_flushes_only_target_env_traj() {
let (mut actor, _tx, mut rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let env_id_1 = Uuid::new_v4();
let env_id_2 = Uuid::new_v4();
let mut trajectories = actor.runtime.trajectories.lock().await;
trajectories
.per_env_labels
.insert(env_id_1, "env-1".to_string());
trajectories
.per_env_labels
.insert(env_id_2, "env-2".to_string());
let mut traj_1 = RelayRLTrajectory::with_metadata(
10,
Some(actor.actor_id),
Some(env_id_1),
Some("env-1".to_string()),
None,
None,
);
traj_1.add_action(RelayRLAction::minimal(0.5, false));
let mut traj_2 = RelayRLTrajectory::with_metadata(
10,
Some(actor.actor_id),
Some(env_id_2),
Some("env-2".to_string()),
None,
None,
);
traj_2.add_action(RelayRLAction::minimal(0.25, false));
trajectories.per_env_trajs.insert(env_id_1, traj_1);
trajectories.per_env_trajs.insert(env_id_2, traj_2);
drop(trajectories);
actor
.perform_flag_last_action(build_msg(
actor.actor_id,
RoutingProtocol::Data(DataPayload::FlagLastAction {
reward: 1.0,
env_id: Some(env_id_1),
env_label: Some("env-1".to_string()),
}),
))
.await
.unwrap();
let msg = rx_buf.recv().await.expect("expected env trajectory flush");
match msg.protocol {
RoutingProtocol::Data(DataPayload::SendTrajectory { trajectory, .. }) => {
assert_eq!(trajectory.get_env_id(), Some(&env_id_1));
assert_eq!(trajectory.get_env_label(), Some("env-1"));
}
other => panic!(
"expected SendTrajectory payload, got {:?}",
std::mem::discriminant(&other)
),
}
let trajectories = actor.runtime.trajectories.lock().await;
assert!(
trajectories
.per_env_trajs
.get(&env_id_1)
.is_some_and(|trajectory| trajectory.is_empty())
);
assert!(
trajectories
.per_env_trajs
.get(&env_id_2)
.is_some_and(|trajectory| !trajectory.is_empty())
);
}
#[tokio::test]
async fn spawn_loop_can_run_concurrently() {
let mut handles = Vec::new();
for _ in 0..3 {
let (mut actor, tx, _rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
active_uuid_registry::interface::add_id(
"test-actor-namespace",
crate::network::ACTOR_CONTEXT,
actor_id,
)
.unwrap();
let h = tokio::spawn(async move { actor.spawn_loop().await });
tx.send(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::Shutdown),
))
.await
.unwrap();
handles.push(h);
}
for h in handles {
let result = tokio::time::timeout(tokio::time::Duration::from_millis(500), h)
.await
.expect("actor did not shut down in time")
.expect("join error");
assert!(result.is_ok());
}
}
#[tokio::test]
async fn trajectory_send_failure_returns_err() {
let (mut actor, tx, rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
drop(rx_buf);
let result = actor
.perform_flag_last_action(build_msg(
actor_id,
RoutingProtocol::Data(DataPayload::FlagLastAction {
reward: 0.0,
env_id: None,
env_label: None,
}),
))
.await;
assert!(
matches!(result, Err(ActorError::TrajectorySendError(_))),
"Expected TrajectorySendError, got {:?}",
result
);
drop(tx);
}
#[tokio::test]
async fn model_version_reply_failure_returns_err() {
let (actor, _tx, _rx_buf) = create_ndarray_actor(10, DeviceType::Cpu).await;
let actor_id = actor.actor_id;
let (reply_tx, reply_rx) = oneshot::channel::<i64>();
drop(reply_rx);
let result = actor
.get_model_version(build_msg(
actor_id,
RoutingProtocol::Control(ControlPayload::ModelVersion { reply_to: reply_tx }),
))
.await;
assert!(
matches!(result, Err(ActorError::MessageHandlingError(_))),
"Expected MessageHandlingError when oneshot rx is dropped, got {:?}",
result
);
}
}