use crate::network::client::agent::{ActorInferenceMode, ClientModes};
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::coordination::lifecycle_manager::SharedTransportAddresses;
use crate::network::client::runtime::coordination::state_manager::ActorUuid;
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::data::transport_sink::TransportError;
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
use crate::network::client::runtime::data::transport_sink::transport_dispatcher::{
InferenceDispatcher, TrainingDispatcher,
};
use crate::network::client::runtime::router::{
InferenceRequest, RoutedMessage, RoutedPayload, RoutingProtocol,
};
use crate::utilities::configuration::ClientConfigLoader;
use crate::utilities::tokio::get_or_init_tokio_runtime;
use relayrl_types::data::action::RelayRLAction;
use relayrl_types::data::tensor::{BackendMatcher, ConversionBurnTensor, DeviceType};
use relayrl_types::data::trajectory::RelayRLTrajectory;
use relayrl_types::model::utils::{deserialize_model_module, validate_module};
use relayrl_types::model::{HotReloadableModel, ModelError, ModelModule};
use relayrl_types::prelude::tensor::relayrl::AnyBurnTensor;
use bincode::config;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::oneshot;
use tokio::time::{Duration, timeout};
use uuid::Uuid;
use burn_tensor::{Tensor, backend::Backend};
use thiserror::Error;
/// Shared, async-lockable handle to an actor's hot-reloadable model.
///
/// The inner `Option` is `None` until a model has been installed (see the startup check in
/// `Actor::new` and the handshake / refresh handlers, which populate it).
pub(crate) type LocalModelHandle<B> = Arc<RwLock<Option<HotReloadableModel<B>>>>;
/// Errors an actor can produce while handling routed messages.
#[derive(Debug, Error)]
pub enum ActorError {
/// A model-layer failure, forwarded transparently from [`ModelError`].
#[error(transparent)]
ModelError(#[from] ModelError),
/// Forwarding a trajectory over the buffer channel failed.
#[error("Trajectory send failed: {0}")]
TrajectorySendError(String),
/// An inference request could not be completed (not constructed in this file's visible code).
#[error("Inference request failed: {0}")]
InferenceRequestError(String),
/// A routed message carried an unexpected payload, or a reply channel could not be written to.
#[error("Message handling failed: {0}")]
MessageHandlingError(String),
/// A type-erased payload component could not be downcast to the expected concrete type.
#[error("Type conversion failed: {0}")]
TypeConversionError(String),
/// Runtime-level failures: blocking-task join errors, system clock errors, missing configuration.
#[error("System error: {0}")]
SystemError(String),
/// A transport-layer failure, forwarded transparently; only exists when a transport feature is on.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
#[error(transparent)]
TransportError(#[from] TransportError),
}
/// Lifecycle and message-handling contract for an actor bound to backend `B`.
///
/// The behavior notes on each method describe what the `Actor` implementation in this file does;
/// other implementors are not visible here.
pub trait ActorEntity<B: Backend + BackendMatcher<Backend = B>>: Send + Sync + 'static {
/// Builds an actor from its shared handles and channels.
///
/// The dispatcher and transport-address parameters only exist when the `nats-transport` or
/// `zmq-transport` feature is enabled.
async fn new(
client_namespace: Arc<str>,
actor_id: ActorUuid,
device: DeviceType,
model_handle: LocalModelHandle<B>,
shared_local_model_path: Arc<RwLock<PathBuf>>,
shared_max_traj_length: Arc<RwLock<u128>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
rx_from_router: Receiver<RoutedMessage>,
shared_tx_to_buffer: Sender<RoutedMessage>,
shared_client_modes: Arc<ClientModes>,
) -> Self
where
Self: Sized;
/// Runs the actor's message loop until shutdown, channel closure, or the first handler error.
async fn spawn_loop(&mut self) -> Result<(), ActorError>;
/// Handles a `ModelHandshake` message: obtains an initial model when none is loaded yet.
async fn _initial_model_handshake(&mut self, msg: RoutedMessage) -> Result<(), ActorError>;
/// Handles a `ModelVersion` message: replies with the loaded model's version, or `-1` if none.
async fn __get_model_version(&self, msg: RoutedMessage) -> Result<(), ActorError>;
/// Handles a `ModelUpdate` message: installs the model carried in the payload.
async fn _refresh_model(&self, msg: RoutedMessage) -> Result<(), ActorError>;
/// Handles a `Shutdown` message: forwards any pending trajectory before the loop exits.
async fn _handle_shutdown(&self, _msg: RoutedMessage) -> Result<(), ActorError>;
}
/// Concrete actor: serves inference requests, accumulates a trajectory, and keeps its model current.
///
/// `D_IN` and `D_OUT` parameterize the `AnyBurnTensor` types used for observations and action
/// masks respectively (presumably the tensor ranks; confirm against `relayrl_types`).
pub(crate) struct Actor<
B: Backend + BackendMatcher<Backend = B>,
const D_IN: usize,
const D_OUT: usize,
> {
/// Namespace string included in the actor entry sent with transport requests.
client_namespace: Arc<str>,
/// Identifier stamped on produced actions and outgoing routed messages.
actor_id: ActorUuid,
/// Shared handle to the model used for local inference; `None` inside until a model is installed.
reloadable_model: LocalModelHandle<B>,
/// Shared filesystem path that received models are saved to.
shared_local_model_path: Arc<RwLock<PathBuf>>,
/// Shared maximum trajectory length; read once in `new` to size `current_traj`.
shared_max_traj_length: Arc<RwLock<u128>>,
/// Dispatcher used for server-side inference requests, when configured.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
/// Dispatcher used for the initial model handshake, when configured.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
/// Shared transport addresses handed to the dispatchers on each request.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
/// Device passed along when deserializing or constructing models.
model_device: DeviceType,
/// Trajectory that collects every action this actor produces.
current_traj: RelayRLTrajectory,
/// Inbound message channel consumed by `spawn_loop`.
rx_from_router: Receiver<RoutedMessage>,
/// Outbound channel on which `SendTrajectory` messages are forwarded.
shared_tx_to_buffer: Sender<RoutedMessage>,
/// Client mode configuration; selects local versus server inference.
shared_client_modes: Arc<ClientModes>,
}
impl<B: Backend + BackendMatcher<Backend = B>, const D_IN: usize, const D_OUT: usize>
Actor<B, D_IN, D_OUT>
{
#[inline(always)]
fn extract_inference_request(
msg: RoutedMessage,
) -> Result<
(
Arc<AnyBurnTensor<B, D_IN>>,
Option<Arc<AnyBurnTensor<B, D_OUT>>>,
f32,
oneshot::Sender<Arc<RelayRLAction>>,
),
ActorError,
> {
let RoutedPayload::RequestInference(req) = msg.payload else {
return Err(ActorError::MessageHandlingError(
"Expected RequestInference payload".to_string(),
));
};
let InferenceRequest {
observation,
mask,
reward,
reply_to,
} = *req;
let obs: Arc<AnyBurnTensor<B, D_IN>> = *observation
.downcast::<Arc<AnyBurnTensor<B, D_IN>>>()
.map_err(|_| {
ActorError::TypeConversionError("Failed to downcast observation".into())
})?;
let mask: Option<Arc<AnyBurnTensor<B, D_OUT>>> = *mask
.downcast::<Option<Arc<AnyBurnTensor<B, D_OUT>>>>()
.map_err(|_| ActorError::TypeConversionError("Failed to downcast mask".into()))?;
Ok((obs, mask, reward, reply_to))
}
/// Routes an inference request to the backend selected by the client's inference mode.
///
/// Server mode (only available with a transport feature) is delegated to
/// `request_server_inference`; local mode runs `perform_local_inference`.
#[inline(always)]
async fn handle_inference_kind(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
    match self.shared_client_modes.actor_inference_mode {
        #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
        ActorInferenceMode::Server(_) => return self.request_server_inference(msg).await,
        ActorInferenceMode::Local(_) => {}
    }
    // Only the local mode falls through to here.
    self.perform_local_inference(msg).await
}
async fn perform_local_inference(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
let model_handle = self.reloadable_model.clone();
let (obs, mask, reward, reply_to) = Self::extract_inference_request(msg)?;
let actor_id = self.actor_id;
let r4sa = tokio::task::spawn_blocking(move || -> Result<RelayRLAction, ModelError> {
let guard = model_handle.blocking_read();
let reloadable_model = guard.as_ref().ok_or_else(|| {
ModelError::IoError("Model not loaded/available for actor inference".to_string())
})?;
reloadable_model.forward::<D_IN, D_OUT>(obs, mask, reward, actor_id)
})
.await
.map_err(|e| ActorError::SystemError(format!("spawn_blocking join error: {e}")))?
.map_err(ActorError::from)?;
self.current_traj.add_action(r4sa.clone());
reply_to.send(Arc::new(r4sa)).map_err(|e| {
ActorError::MessageHandlingError(format!("reply_to send failed: {e:?}"))
})?;
Ok(())
}
/// Requests an action from the inference server and replies with it.
///
/// Falls back to `perform_local_inference` when no inference dispatcher is configured. On
/// success the returned action is appended to the current trajectory and sent over the
/// request's `reply_to` channel.
///
/// # Errors
/// * `ActorError::SystemError` if the shared transport addresses are not available.
/// * Any error from `extract_inference_request` or from the dispatcher request.
/// * `ActorError::MessageHandlingError` if the reply cannot be delivered.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
async fn request_server_inference(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
    // No dispatcher configured: serve the request from the local model instead.
    let Some(inference_dispatcher) = &self.shared_inference_dispatcher else {
        return self.perform_local_inference(msg).await;
    };

    let shared_transport_addresses = self
        .shared_transport_addresses
        .as_ref()
        .ok_or_else(|| ActorError::SystemError("Server addresses not available".into()))?
        .clone();

    let (_obs, _mask, _reward, reply_to) = Self::extract_inference_request(msg)?;

    // NOTE(review): the observation is extracted but never serialized, so the server receives an
    // empty byte buffer. The wire encoding is not visible from this file; confirm the intended
    // serialization and wire it in here.
    let obs_bytes: Vec<u8> = Vec::new();

    let actor_entry = (
        self.client_namespace.to_string(),
        crate::network::ACTOR_CONTEXT.to_string(),
        self.actor_id,
    );
    let r4sa = inference_dispatcher
        .send_inference_request(actor_entry, obs_bytes, shared_transport_addresses)
        .await?;

    // Record the action before handing ownership of it to the requester.
    self.current_traj.add_action(r4sa.clone());
    reply_to
        .send(Arc::new(r4sa))
        .map_err(|e| ActorError::MessageHandlingError(format!("reply_to send failed: {e:?}")))
}
/// Appends a terminal action carrying `reward` and forwards the trajectory to the buffer.
///
/// Messages whose payload is not `RoutedPayload::FlagLastInference` are ignored.
///
/// NOTE(review): `current_traj` is cloned for sending and is not cleared here; confirm whether
/// `RelayRLTrajectory` resets itself or whether a reset is expected at this point.
///
/// # Errors
/// * `ActorError::SystemError` if the system clock reads earlier than the Unix epoch.
/// * `ActorError::TrajectorySendError` if the buffer channel rejects the message.
async fn perform_flag_last_action(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
    let RoutedPayload::FlagLastInference { reward } = msg.payload else {
        return Ok(());
    };

    // Build the terminal marker action and record it on the trajectory.
    let mut terminal_action =
        RelayRLAction::new(None, None, None, reward, true, None, Some(self.actor_id));
    terminal_action.update_reward(reward);
    self.current_traj.add_action(terminal_action);

    let trajectory = self.current_traj.clone();
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| ActorError::SystemError(format!("Clock skew: {e}")))?;

    let outbound = RoutedMessage {
        actor_id: self.actor_id,
        protocol: RoutingProtocol::SendTrajectory,
        payload: RoutedPayload::SendTrajectory {
            timestamp: (since_epoch.as_millis(), since_epoch.as_nanos()),
            trajectory,
        },
    };

    self.shared_tx_to_buffer
        .send(outbound)
        .await
        .map_err(|e| ActorError::TrajectorySendError(format!("{e:?}")))
}
}
impl<B: Backend + BackendMatcher<Backend = B>, const D_IN: usize, const D_OUT: usize> ActorEntity<B>
for Actor<B, D_IN, D_OUT>
{
/// Builds an actor from its shared handles and channels.
///
/// Reads the shared maximum trajectory length once to size the actor's trajectory. A value
/// larger than `usize::MAX` saturates to `usize::MAX` instead of silently wrapping. Logs to
/// stderr when the model handle is still empty, since a model handshake is then required
/// before inference can succeed.
async fn new(
    client_namespace: Arc<str>,
    actor_id: ActorUuid,
    device: DeviceType,
    model_handle: LocalModelHandle<B>,
    shared_local_model_path: Arc<RwLock<PathBuf>>,
    shared_max_traj_length: Arc<RwLock<u128>>,
    #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
    shared_inference_dispatcher: Option<Arc<InferenceDispatcher<B>>>,
    #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
    shared_training_dispatcher: Option<Arc<TrainingDispatcher<B>>>,
    #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
    shared_transport_addresses: Option<Arc<RwLock<SharedTransportAddresses>>>,
    rx_from_router: Receiver<RoutedMessage>,
    shared_tx_to_buffer: Sender<RoutedMessage>,
    shared_client_modes: Arc<ClientModes>,
) -> Self
where
    Self: Sized,
{
    // `u128` is `Copy`; dereference the guard instead of cloning through it.
    let configured_max: u128 = *shared_max_traj_length.read().await;
    // A plain `as usize` cast would wrap for oversized values; saturate instead.
    let max_traj_length = usize::try_from(configured_max).unwrap_or(usize::MAX);

    if model_handle.read().await.is_none() {
        eprintln!(
            "[ActorEntity] Startup model is None, initial model handshake necessitated..."
        );
    }

    Self {
        client_namespace,
        actor_id,
        reloadable_model: model_handle,
        shared_local_model_path,
        shared_max_traj_length,
        #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
        shared_inference_dispatcher,
        #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
        shared_training_dispatcher,
        #[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
        shared_transport_addresses,
        model_device: device,
        current_traj: RelayRLTrajectory::new(max_traj_length),
        rx_from_router,
        shared_tx_to_buffer,
        shared_client_modes,
    }
}
async fn spawn_loop(&mut self) -> Result<(), ActorError> {
while let Some(msg) = self.rx_from_router.recv().await {
match msg.protocol {
RoutingProtocol::ModelHandshake => {
<Actor<B, D_IN, D_OUT> as ActorEntity<B>>::_initial_model_handshake(self, msg)
.await?;
}
RoutingProtocol::RequestInference => {
self.handle_inference_kind(msg).await?;
}
RoutingProtocol::FlagLastInference => {
self.perform_flag_last_action(msg).await?;
}
RoutingProtocol::ModelVersion => {
self.__get_model_version(msg).await?;
}
RoutingProtocol::ModelUpdate => {
<Actor<B, D_IN, D_OUT> as ActorEntity<B>>::_refresh_model(self, msg).await?;
}
RoutingProtocol::Shutdown => {
<Actor<B, D_IN, D_OUT> as ActorEntity<B>>::_handle_shutdown(self, msg).await?;
break;
}
_ => {}
}
}
Ok(())
}
/// Handles a `ModelHandshake` message by obtaining an initial model when none is loaded.
///
/// Messages with any other payload are ignored. If a model is already present the handshake
/// is skipped. With a transport feature enabled and a training dispatcher configured, the
/// model is requested from the dispatcher, saved to the shared local model path, and installed
/// into the shared model handle. Without a transport feature (or without a dispatcher) the
/// condition is only logged.
///
/// # Errors
/// * `ActorError::SystemError` if the shared transport addresses are not available.
/// * `ActorError::ModelError` if reloading or constructing the model fails.
///
/// A failed or empty handshake response and a failed save are logged but do not return an error.
async fn _initial_model_handshake(&mut self, msg: RoutedMessage) -> Result<(), ActorError> {
if let RoutedPayload::ModelHandshake = msg.payload {
// Fast path: nothing to do when a model is already loaded. The read guard lives only inside
// this inner block, so it is released before the write lock is requested further down.
{
let model_guard = self.reloadable_model.read().await;
if model_guard.is_some() {
println!(
"[Actor {:?}] Model already available, handshake not needed",
self.actor_id
);
return Ok(());
}
}
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
if let Some(training_dispatcher) = &self.shared_training_dispatcher {
println!(
"[Actor {:?}] Starting training model handshake",
self.actor_id
);
let shared_transport_addresses = self
.shared_transport_addresses
.as_ref()
.ok_or_else(|| {
ActorError::SystemError("Server addresses not available".into())
})?
.clone();
let actor_entry = (
self.client_namespace.to_string(),
crate::network::ACTOR_CONTEXT.to_string(),
self.actor_id,
);
// Both an `Err` and an `Ok(None)` response take the `else` branch below and are only logged.
if let Ok(Some(model)) = training_dispatcher
.initial_model_handshake(actor_entry, shared_transport_addresses)
.await
{
println!(
"[Actor {:?}] Model handshake successful, received model data",
self.actor_id
);
// A failed save is logged only; processing continues with the in-memory model.
if let Err(e) = model.save(&self.shared_local_model_path.read().await.clone()) {
eprintln!("[Actor {:?}] Failed to save model: {:?}", self.actor_id, e);
}
let model_path = self.shared_local_model_path.clone();
let model_device = self.model_device.clone();
let actor_id = self.actor_id;
let mut model_guard = self.reloadable_model.write().await;
// Re-check under the write lock: the handle is shared, so a model may have been installed
// since the read check above.
match model_guard.as_ref() {
Some(existing_model) => {
// NOTE(review): this reloads from the saved path, not from `model`; if the save above
// failed, whatever is currently at that path gets loaded. Confirm this is intended.
let version = existing_model.version() + 1;
existing_model
.reload_from_path(model_path.read().await.clone(), version)
.await
.map_err(|e| {
eprintln!(
"[Actor {:?}] Failed to reload model: {:?}",
actor_id, e
);
ActorError::from(e)
})?;
}
None => {
*model_guard = Some(
HotReloadableModel::<B>::new_from_module(model, model_device)
.await
.map_err(ActorError::from)?,
);
}
}
} else {
eprintln!(
"[Actor {:?}] Model handshake failed or no model update needed",
self.actor_id
);
}
} else {
eprintln!(
"[Actor {:?}] No transport dispatcher configured for model handshake",
self.actor_id
);
}
// Without any transport feature there is no dispatcher to ask, so this is only logged.
#[cfg(not(any(feature = "nats-transport", feature = "zmq-transport")))]
{
eprintln!(
"[Actor {:?}] No transport dispatcher configured for model handshake",
self.actor_id
);
}
}
Ok(())
}
/// Replies to a `ModelVersion` query with the loaded model's version.
///
/// Sends `-1` when no model is loaded. Messages with any other payload are ignored.
///
/// # Errors
/// `ActorError::MessageHandlingError` if the reply channel can no longer be written to.
async fn __get_model_version(&self, msg: RoutedMessage) -> Result<(), ActorError> {
    let RoutedPayload::ModelVersion { reply_to } = msg.payload else {
        return Ok(());
    };

    // The read guard is a temporary of this statement, so it is released before replying.
    let version = self
        .reloadable_model
        .read()
        .await
        .as_ref()
        .map_or(-1, |model| model.version());

    reply_to
        .send(version)
        .map_err(|e| ActorError::MessageHandlingError(format!("{:?}", e)))
}
/// Installs the model carried by a `ModelUpdate` message.
///
/// The model bytes are deserialized, validated, saved to the shared local model path, and then
/// either hot-reloaded into the existing model or installed as the first model. Messages with
/// any other payload are ignored.
///
/// # Errors
/// `ActorError::ModelError` if deserialization, validation, saving, reloading or construction
/// fails. Deserialization, validation and save failures are also logged to stderr.
/// Deserialization failures were previously dropped silently; they are now reported the same
/// way as validation and save failures.
async fn _refresh_model(&self, msg: RoutedMessage) -> Result<(), ActorError> {
    let RoutedPayload::ModelUpdate {
        model_bytes,
        version,
    } = msg.payload
    else {
        return Ok(());
    };

    let model: Result<ModelModule<B>, ModelError> =
        deserialize_model_module::<B>(model_bytes, self.model_device.clone());
    let ok_model = match model.map_err(ActorError::from) {
        Ok(ok_model) => ok_model,
        Err(e) => {
            eprintln!(
                "[ActorEntity {:?}] Failed to deserialize model: {:?}",
                self.actor_id, e
            );
            return Err(e);
        }
    };

    if let Err(e) = validate_module::<B>(&ok_model).map_err(ActorError::from) {
        eprintln!(
            "[ActorEntity {:?}] Failed to validate model: {:?}",
            self.actor_id, e
        );
        return Err(e);
    }

    let model_path: PathBuf = self.shared_local_model_path.read().await.clone();
    if let Err(e) = ok_model.save(&model_path).map_err(ActorError::from) {
        eprintln!(
            "[ActorEntity {:?}] Failed to save model: {:?}",
            self.actor_id, e
        );
        return Err(e);
    }

    let model_device = self.model_device.clone();
    // Hold the write lock while swapping so readers never observe a half-installed model.
    let mut model_guard = self.reloadable_model.write().await;
    match model_guard.as_ref() {
        Some(existing_model) => {
            existing_model
                .reload_from_module(ok_model, version)
                .await
                .map_err(ActorError::from)?;
        }
        None => {
            *model_guard = Some(
                HotReloadableModel::<B>::new_from_module(ok_model, model_device)
                    .await
                    .map_err(ActorError::from)?,
            );
        }
    }
    Ok(())
}
/// Flushes the in-progress trajectory to the buffer before the actor stops.
///
/// Does nothing when the current trajectory holds no actions. Delivery is best-effort: a failed
/// send is logged to stderr and does not turn shutdown into an error.
///
/// # Errors
/// `ActorError::SystemError` if the system clock reads earlier than the Unix epoch.
async fn _handle_shutdown(&self, _msg: RoutedMessage) -> Result<(), ActorError> {
    if self.current_traj.actions.is_empty() {
        return Ok(());
    }

    let traj_clone = self.current_traj.clone();
    // One clock reading feeds both timestamp components.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| ActorError::SystemError(format!("Clock skew: {}", e)))?;

    let send_traj_msg = RoutedMessage {
        actor_id: self.actor_id,
        protocol: RoutingProtocol::SendTrajectory,
        payload: RoutedPayload::SendTrajectory {
            timestamp: (since_epoch.as_millis(), since_epoch.as_nanos()),
            trajectory: traj_clone,
        },
    };

    // Best-effort flush: report the failure but let shutdown proceed.
    if let Err(e) = self.shared_tx_to_buffer.send(send_traj_msg).await {
        eprintln!(
            "[Actor {:?}] Failed to send final trajectory on shutdown: {:?}",
            self.actor_id, e
        );
    }
    Ok(())
}
}
// Placeholder test module: compiled only for test builds and currently contains no tests.
#[cfg(test)]
mod tests {}