#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use crate::network::HyperparameterArgs;
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use crate::network::TransportType;
use crate::network::client::runtime::coordination::coordinator::{
ClientCoordinator, ClientInterface, CoordinatorError,
};
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use crate::network::client::runtime::coordination::scale_manager::AlgorithmArgs;
use crate::network::client::runtime::coordination::state_manager::ActorUuid;
use crate::prelude::config::ClientConfigLoader;
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use crate::utilities::configuration::Algorithm;
use active_uuid_registry::UuidPoolError;
use active_uuid_registry::interface::get;
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use relayrl_types::types::data::action::CodecConfig;
use relayrl_types::types::data::action::RelayRLAction;
use relayrl_types::types::data::tensor::{
AnyBurnTensor, BackendMatcher, BoolBurnTensor, DType, DeviceType, FloatBurnTensor,
IntBurnTensor, SupportedTensorBackend,
};
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
use relayrl_types::types::data::tensor::{NdArrayDType, TchDType};
use relayrl_types::types::model::ModelModule;
use burn_tensor::{Bool, Float, Int, Tensor, TensorKind, backend::Backend};
use serde::{Deserialize, Serialize};
#[cfg(any(feature = "metrics", feature = "logging"))]
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
use tokio::task::JoinHandle;
use uuid::Uuid;
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum ClientError {
#[error(transparent)]
UuidPoolError(#[from] UuidPoolError),
#[error("Inference server mode disabled: {0}")]
InferenceServerModeDisabled(String),
#[error("Inference server mode enabled: {0}")]
InferenceServerModeEnabled(String),
#[error(transparent)]
CoordinatorError(#[from] CoordinatorError),
#[error("Backend mismatch: {0}")]
BackendMismatchError(String),
#[error("No input or output dtype set")]
NoInputOrOutputDtypeSet(String),
#[error("Noop router scale: {0}")]
NoopRouterScale(String),
#[error("Noop actor count: {0}")]
NoopActorCount(String),
#[error("Invalid inference mode: {0}")]
InvalidInferenceMode(String),
#[error("Invalid trajectory file directory: {0}")]
InvalidTrajectoryFileDirectory(String),
}
/// Requested return format for `runtime_statistics`.
///
/// Only compiled when the `metrics` or `logging` feature is enabled.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
#[cfg(any(feature = "metrics", feature = "logging"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "metrics", feature = "logging"))))]
pub enum RuntimeStatisticsReturnType {
    /// Statistics written to a JSON file at the given path.
    JsonFile(PathBuf),
    /// Statistics serialized as an in-memory JSON string.
    JsonString(String),
    /// Statistics as a raw key/value map.
    Hashmap(HashMap<String, String>),
}
/// Where action inference runs for actors.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum ActorInferenceMode {
    /// Run inference in-process on the client.
    Local,
    /// Delegate inference to a remote inference server
    /// (only available when a transport feature is compiled in).
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    Server,
}
impl Default for ActorInferenceMode {
fn default() -> Self {
Self::Local
}
}
/// Settings for persisting trajectories to local files.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TrajectoryFileParams {
    /// Whether file persistence is active.
    pub enabled: bool,
    /// Whether trajectories are encoded before being written.
    pub encode: bool,
    /// Target directory for trajectory files (validated by `new`).
    pub directory: PathBuf,
}
impl TrajectoryFileParams {
    /// Builds trajectory-file settings after a light sanity check on
    /// `directory`.
    ///
    /// # Errors
    /// Returns [`ClientError::InvalidTrajectoryFileDirectory`] when the path
    /// carries an extension and therefore looks like a file rather than a
    /// directory.
    pub fn new(enabled: bool, encode: bool, directory: PathBuf) -> Result<Self, ClientError> {
        // An extension is taken as evidence the path names a file.
        match directory.extension() {
            Some(_) => Err(ClientError::InvalidTrajectoryFileDirectory(format!(
                "Path '{}' appears to be a file, not a directory",
                directory.display()
            ))),
            None => Ok(Self {
                enabled,
                encode,
                directory,
            }),
        }
    }
}
impl Default for TrajectoryFileParams {
    /// Disabled recording, no encoding, current directory.
    fn default() -> Self {
        // "." has no extension, so validation in `new` cannot fail here;
        // `expect` states the invariant instead of a bare `unwrap`.
        Self::new(false, false, PathBuf::from("."))
            .expect("'.' has no extension and is always a valid trajectory directory")
    }
}
/// Connection parameters for the supported trajectory databases.
#[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
#[derive(Debug, Clone, PartialEq)]
pub enum DatabaseTypeParams {
    /// File-backed SQLite database.
    Sqlite(SqliteParams),
    /// Networked PostgreSQL database.
    PostgreSQL(PostgreSQLParams),
}
/// SQLite connection settings.
#[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
#[derive(Debug, Clone, PartialEq)]
pub struct SqliteParams {
    /// Path to the SQLite database file.
    pub path: PathBuf,
}
/// PostgreSQL connection settings.
#[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
#[derive(Debug, Clone, PartialEq)]
pub struct PostgreSQLParams {
    /// Connection string — presumably a PostgreSQL URI/DSN; confirm format
    /// against the consumer of this field.
    pub connection: String,
}
/// Where trajectories are persisted, if anywhere.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TrajectoryPersistenceMode {
    /// Local file persistence; `None` falls back to defaults downstream.
    Local(Option<TrajectoryFileParams>),
    /// Database-only persistence.
    #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "postgres_db", feature = "sqlite_db"))))]
    Database(Option<DatabaseTypeParams>),
    /// Persist to both a database and local files.
    #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "postgres_db", feature = "sqlite_db"))))]
    Hybrid(Option<DatabaseTypeParams>, Option<TrajectoryFileParams>),
    /// No persistence at all. Only available when a database or transport
    /// feature is compiled in.
    #[cfg(any(
        feature = "postgres_db",
        feature = "sqlite_db",
        feature = "async_transport",
        feature = "sync_transport"
    ))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(
            feature = "postgres_db",
            feature = "sqlite_db",
            feature = "async_transport",
            feature = "sync_transport"
        )))
    )]
    Disabled,
}
impl Default for TrajectoryPersistenceMode {
    /// Default for the compiled feature set: `Disabled` when any database or
    /// transport feature is on, otherwise local file persistence with
    /// default parameters.
    fn default() -> Self {
        // Exactly one of the two `cfg` branches below is compiled for any
        // given feature combination, so one `return` is always present.
        #[cfg(any(
            feature = "postgres_db",
            feature = "sqlite_db",
            feature = "async_transport",
            feature = "sync_transport"
        ))]
        return Self::Disabled;
        #[cfg(not(any(
            feature = "postgres_db",
            feature = "sqlite_db",
            feature = "async_transport",
            feature = "sync_transport"
        )))]
        return Self::Local(Some(TrajectoryFileParams::default()));
    }
}
/// How actors share a model with a server-side component.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum ActorServerModelMode {
    /// Each actor holds its own model instance.
    Independent,
    /// Actors share a single model instance.
    Shared,
    /// The server-side component is turned off.
    Disabled,
}
impl Default for ActorServerModelMode {
fn default() -> Self {
Self::Independent
}
}
/// Effective capabilities derived from [`ClientModes`] by
/// `ClientModes::capabilities`. Fields are feature-gated to mirror the modes
/// they are derived from.
#[derive(Debug, Clone)]
pub struct ClientCapabilities {
    /// Inference runs in-process on the client.
    pub local_inference: bool,
    /// Inference is delegated to a remote server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub server_inference: bool,
    /// Model-sharing mode of the inference server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub inference_server_mode: ActorServerModelMode,
    /// Trajectories are written to local files.
    pub local_trajectory_persistence: bool,
    /// Trajectories are written to a database.
    #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "postgres_db", feature = "sqlite_db"))))]
    pub database_trajectory_persistence: bool,
    /// Model-sharing mode of the training server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub training_server_mode: ActorServerModelMode,
    /// Database parameters extracted from the persistence mode, if any.
    #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "postgres_db", feature = "sqlite_db"))))]
    pub db_params: Option<DatabaseTypeParams>,
}
impl ClientCapabilities {
    /// True when trajectories are persisted anywhere: to local files or,
    /// when a database feature is compiled in, to a database.
    pub fn trajectory_recording_enabled(&self) -> bool {
        #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
        let to_database = self.database_trajectory_persistence;
        // Without a database feature there is no database sink.
        #[cfg(not(any(feature = "postgres_db", feature = "sqlite_db")))]
        let to_database = false;
        to_database || self.local_trajectory_persistence
    }
}
/// User-facing configuration of the client's operating modes; converted to
/// concrete [`ClientCapabilities`] via `capabilities()`.
pub struct ClientModes {
    /// Where actor inference runs.
    pub actor_inference_mode: ActorInferenceMode,
    /// Model-sharing mode of the inference server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub inference_server_mode: ActorServerModelMode,
    /// Model-sharing mode of the training server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub training_server_mode: ActorServerModelMode,
    /// Where trajectories are persisted.
    pub trajectory_persistence_mode: TrajectoryPersistenceMode,
}
impl Default for ClientModes {
    /// Defaults: local inference, the feature-appropriate persistence mode,
    /// and (with a transport feature) inference server off / independent
    /// training models.
    fn default() -> Self {
        Self {
            actor_inference_mode: ActorInferenceMode::default(),
            // Delegate to `TrajectoryPersistenceMode::default()` instead of
            // duplicating its feature-gate logic here; the two were
            // previously kept in sync by hand.
            trajectory_persistence_mode: TrajectoryPersistenceMode::default(),
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            inference_server_mode: ActorServerModelMode::Disabled,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            training_server_mode: ActorServerModelMode::Independent,
        }
    }
}
impl ClientModes {
    /// Validates that the actor inference mode is consistent with the
    /// inference-server mode.
    ///
    /// # Errors
    /// [`ClientError::InvalidInferenceMode`] when server-side inference is
    /// requested while the inference server is disabled, or when local
    /// inference is combined with an enabled inference server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn validate_modes(&self) -> Result<(), ClientError> {
        if self.actor_inference_mode == ActorInferenceMode::Server
            && self.inference_server_mode == ActorServerModelMode::Disabled
        {
            // BUG FIX: the `{:?}` placeholder was previously emitted
            // literally in the message; interpolate the offending mode.
            return Err(ClientError::InvalidInferenceMode(format!(
                "Inference server mode disabled for server-side inference: {:?}",
                self.inference_server_mode
            )));
        }
        if self.actor_inference_mode == ActorInferenceMode::Local
            && self.inference_server_mode != ActorServerModelMode::Disabled
        {
            return Err(ClientError::InvalidInferenceMode(format!(
                "Inference server mode enabled for client-side inference: {:?}",
                self.inference_server_mode
            )));
        }
        Ok(())
    }
    /// Derives the effective [`ClientCapabilities`] from the configured modes.
    ///
    /// # Errors
    /// Propagates `validate_modes` failures when a transport feature is on.
    pub fn capabilities(&self) -> Result<ClientCapabilities, ClientError> {
        #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
        self.validate_modes()?;
        let (local_inference, _server_inference) = match self.actor_inference_mode {
            ActorInferenceMode::Local => (true, false),
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            ActorInferenceMode::Server => (false, true),
        };
        let (local_trajectory_persistence, _database_trajectory_persistence) =
            match &self.trajectory_persistence_mode {
                TrajectoryPersistenceMode::Local(_) => (true, false),
                #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
                TrajectoryPersistenceMode::Database(_) => (false, true),
                #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
                TrajectoryPersistenceMode::Hybrid(_, _) => (true, true),
                #[cfg(any(
                    feature = "postgres_db",
                    feature = "sqlite_db",
                    feature = "async_transport",
                    feature = "sync_transport"
                ))]
                TrajectoryPersistenceMode::Disabled => (false, false),
            };
        // BUG FIX: these arms previously wrapped the inner `Option` in
        // another `Some(...)`, producing `Option<Option<_>>` and failing to
        // type-check under the database features.
        #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
        let db_params: Option<DatabaseTypeParams> = match &self.trajectory_persistence_mode {
            TrajectoryPersistenceMode::Database(params) => params.clone(),
            TrajectoryPersistenceMode::Hybrid(params, _) => params.clone(),
            _ => None,
        };
        Ok(ClientCapabilities {
            local_inference,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            server_inference: _server_inference,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            inference_server_mode: self.inference_server_mode.clone(),
            local_trajectory_persistence,
            #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
            database_trajectory_persistence: _database_trajectory_persistence,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            training_server_mode: self.training_server_mode.clone(),
            #[cfg(any(feature = "postgres_db", feature = "sqlite_db"))]
            db_params,
        })
    }
}
/// Startup parameters produced by `AgentBuilder::build` and consumed by
/// `RelayRLAgent::start`.
///
/// NOTE(review): `default_model` is gated on the `tch-model`/`onnx-model`
/// features here, while `RelayRLAgent::start` gates its model parameter on
/// the transport features — confirm the two feature axes are intended to
/// differ.
pub struct AgentStartParameters<B: Backend + BackendMatcher<Backend = B>> {
    /// Algorithm selection and hyperparameters for the training server.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "async_transport", feature = "sync_transport")))
    )]
    pub algorithm_args: AlgorithmArgs,
    /// Number of actors to spawn at startup.
    pub actor_count: u32,
    /// Number of routers to spawn at startup.
    pub router_scale: u32,
    /// Device actors run on by default.
    pub default_device: DeviceType,
    /// Optional initial model (a model feature can supply one later).
    #[cfg(any(feature = "tch-model", feature = "onnx-model"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "tch-model", feature = "onnx-model"))))]
    pub default_model: Option<ModelModule<B>>,
    /// Mandatory initial model when no model feature is compiled in.
    #[cfg(not(any(feature = "tch-model", feature = "onnx-model")))]
    #[cfg_attr(
        docsrs,
        doc(cfg(not(any(feature = "tch-model", feature = "onnx-model"))))
    )]
    pub default_model: ModelModule<B>,
    /// Optional path to a client configuration file.
    pub config_path: Option<PathBuf>,
    /// Codec used to encode data over the transport.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub codec: CodecConfig,
}
impl<B: Backend + BackendMatcher<Backend = B>> std::fmt::Debug for AgentStartParameters<B> {
    /// Opaque debug representation — fields (notably the model) are not
    /// printable. NOTE(review): the label says "RLAgentStartParameters"
    /// while the type is `AgentStartParameters`; kept byte-identical here,
    /// confirm whether the stale name is intentional.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RLAgentStartParameters")
    }
}
/// Fluent builder for a [`RelayRLAgent`] and its startup parameters.
///
/// Construct with `AgentBuilder::builder()`, chain setters, then call
/// `build()`.
#[must_use]
pub struct AgentBuilder<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B> + Send + Sync,
    KindOut: TensorKind<B> + Send + Sync,
> {
    /// Operating modes; pre-populated with defaults by `builder()`.
    pub client_modes: Option<ClientModes>,
    /// Transport selection; defaults applied in `build()` when unset.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub transport_type: Option<TransportType>,
    /// Algorithm selection plus hyperparameters.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub algorithm_args: Option<AlgorithmArgs>,
    /// Number of actors to spawn (defaults to 1 in `build()`).
    pub actor_count: Option<u32>,
    /// Number of routers to spawn (defaults to 1 in `build()`).
    pub router_scale: Option<u32>,
    /// Default device for actors.
    pub default_device: Option<DeviceType>,
    /// Initial model; mandatory when no model feature is compiled in.
    pub default_model: Option<ModelModule<B>>,
    /// Optional configuration file path.
    pub config_path: Option<PathBuf>,
    /// Transport codec configuration.
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub codec: Option<CodecConfig>,
    // Carries the tensor-kind type parameters without storing values.
    _phantom: PhantomData<(KindIn, KindOut)>,
}
impl<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B> + Send + Sync,
    KindOut: TensorKind<B> + Send + Sync,
> AgentBuilder<B, D_IN, D_OUT, KindIn, KindOut>
{
    /// Starts a builder pre-populated with default client modes (and, when a
    /// transport feature is enabled, default transport/algorithm settings).
    #[must_use]
    pub fn builder() -> Self {
        Self {
            client_modes: Some(ClientModes::default()),
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            transport_type: Some(TransportType::default()),
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            algorithm_args: Some(AlgorithmArgs::default()),
            actor_count: None,
            router_scale: None,
            default_device: None,
            default_model: None,
            config_path: None,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            codec: None,
            _phantom: PhantomData,
        }
    }
    /// Sets where actor inference runs (local or server).
    #[must_use]
    pub fn actor_inference_mode(mut self, actor_inference_mode: ActorInferenceMode) -> Self {
        if let Some(ref mut modes) = self.client_modes {
            modes.actor_inference_mode = actor_inference_mode;
        }
        self
    }
    /// Sets where trajectories are persisted.
    #[must_use]
    pub fn trajectory_persistence_mode(
        mut self,
        trajectory_persistence_mode: TrajectoryPersistenceMode,
    ) -> Self {
        if let Some(ref mut modes) = self.client_modes {
            modes.trajectory_persistence_mode = trajectory_persistence_mode;
        }
        self
    }
    /// Sets the inference server's model-sharing mode.
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn inference_server_mode(mut self, inference_server_mode: ActorServerModelMode) -> Self {
        if let Some(ref mut modes) = self.client_modes {
            modes.inference_server_mode = inference_server_mode;
        }
        self
    }
    /// Sets the training server's model-sharing mode.
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn training_server_mode(mut self, training_server_mode: ActorServerModelMode) -> Self {
        if let Some(ref mut modes) = self.client_modes {
            modes.training_server_mode = training_server_mode;
        }
        self
    }
    /// Selects the transport implementation.
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn transport_type(mut self, transport_type: TransportType) -> Self {
        self.transport_type = Some(transport_type);
        self
    }
    /// Sets the number of actors spawned at startup.
    #[must_use]
    pub fn actor_count(mut self, count: u32) -> Self {
        self.actor_count = Some(count);
        self
    }
    /// Sets the number of routers spawned at startup.
    #[must_use]
    pub fn router_scale(mut self, count: u32) -> Self {
        self.router_scale = Some(count);
        self
    }
    /// Sets the default device for actors.
    #[must_use]
    pub fn default_device(mut self, device: DeviceType) -> Self {
        self.default_device = Some(device);
        self
    }
    /// Sets the initial model.
    #[must_use]
    pub fn default_model(mut self, model: ModelModule<B>) -> Self {
        self.default_model = Some(model);
        self
    }
    /// Selects the training algorithm, preserving any hyperparameters
    /// already set on the builder.
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
        // `take` + `and_then` replaces the match-on-Option; moves the
        // existing hyperparams out without cloning.
        let hyperparams = self.algorithm_args.take().and_then(|args| args.hyperparams);
        self.algorithm_args = Some(AlgorithmArgs {
            algorithm,
            hyperparams,
        });
        self
    }
    /// Sets the hyperparameters, preserving any algorithm already chosen
    /// (falling back to config-driven initialization otherwise).
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn hyperparams(mut self, hyperparams: HyperparameterArgs) -> Self {
        let algorithm = self
            .algorithm_args
            .take()
            .map_or(Algorithm::ConfigInit, |args| args.algorithm);
        self.algorithm_args = Some(AlgorithmArgs {
            algorithm,
            hyperparams: Some(hyperparams),
        });
        self
    }
    /// Sets the configuration file path. Accepts anything convertible into
    /// a `PathBuf` (`&str`, `String`, `PathBuf`, ...) — a backward-compatible
    /// generalization; the previous signature took `PathBuf` and applied a
    /// redundant `.into()`.
    #[must_use]
    pub fn config_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.config_path = Some(path.into());
        self
    }
    /// Sets the transport codec.
    #[must_use]
    #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
    pub fn codec(mut self, codec: CodecConfig) -> Self {
        self.codec = Some(codec);
        self
    }
    /// Consumes the builder, constructing the agent plus the startup
    /// parameters that [`RelayRLAgent::start`] expects.
    ///
    /// # Errors
    /// Propagates coordinator construction failures as [`ClientError`].
    ///
    /// # Panics
    /// When no model feature (`tch-model`/`onnx-model`) is enabled, a model
    /// is mandatory and this panics if `default_model` was never set.
    pub async fn build(
        self,
    ) -> Result<
        (
            RelayRLAgent<B, D_IN, D_OUT, KindIn, KindOut>,
            AgentStartParameters<B>,
        ),
        ClientError,
    > {
        let agent: RelayRLAgent<B, D_IN, D_OUT, KindIn, KindOut> = RelayRLAgent::new(
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            self.transport_type.unwrap_or(TransportType::ZMQ),
            // `unwrap_or_default` avoids eagerly constructing the default.
            self.client_modes.unwrap_or_default(),
        )?;
        let startup_params: AgentStartParameters<B> = AgentStartParameters::<B> {
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            algorithm_args: self.algorithm_args.unwrap_or_default(),
            actor_count: self.actor_count.unwrap_or(1),
            router_scale: self.router_scale.unwrap_or(1),
            default_device: self.default_device.unwrap_or_default(),
            #[cfg(any(feature = "tch-model", feature = "onnx-model"))]
            default_model: self.default_model,
            // `expect` documents the invariant that used to be a bare unwrap.
            #[cfg(not(any(feature = "tch-model", feature = "onnx-model")))]
            default_model: self
                .default_model
                .expect("default_model must be set when no model feature is enabled"),
            config_path: self.config_path,
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            codec: self.codec.unwrap_or_default(),
        };
        Ok((agent, startup_params))
    }
}
/// Converts a concretely-typed Burn tensor into the type-erased
/// [`AnyBurnTensor`] wrapper, tagging it with the given `dtype`.
pub trait ToAnyBurnTensor<B: Backend + BackendMatcher<Backend = B>, const D: usize> {
    /// Consumes `self` and wraps it in the matching `AnyBurnTensor` variant.
    fn to_any_burn_tensor(self, dtype: DType) -> AnyBurnTensor<B, D>;
}
impl<B: Backend + BackendMatcher<Backend = B>, const D: usize> ToAnyBurnTensor<B, D>
    for Tensor<B, D, Float>
{
    /// Wraps a float tensor in the [`AnyBurnTensor::Float`] variant.
    fn to_any_burn_tensor(self, dtype: DType) -> AnyBurnTensor<B, D> {
        let tensor = Arc::new(self);
        AnyBurnTensor::Float(FloatBurnTensor { tensor, dtype })
    }
}
impl<B: Backend + BackendMatcher<Backend = B>, const D: usize> ToAnyBurnTensor<B, D>
    for Tensor<B, D, Int>
{
    /// Wraps an integer tensor in the [`AnyBurnTensor::Int`] variant.
    fn to_any_burn_tensor(self, dtype: DType) -> AnyBurnTensor<B, D> {
        let tensor = Arc::new(self);
        AnyBurnTensor::Int(IntBurnTensor { tensor, dtype })
    }
}
impl<B: Backend + BackendMatcher<Backend = B>, const D: usize> ToAnyBurnTensor<B, D>
    for Tensor<B, D, Bool>
{
    /// Wraps a boolean tensor in the [`AnyBurnTensor::Bool`] variant.
    fn to_any_burn_tensor(self, dtype: DType) -> AnyBurnTensor<B, D> {
        let tensor = Arc::new(self);
        AnyBurnTensor::Bool(BoolBurnTensor { tensor, dtype })
    }
}
/// The client-side RL agent: a typed facade over the coordinator runtime.
///
/// `D_IN`/`D_OUT` fix the observation and action tensor ranks; `KindIn`/
/// `KindOut` fix their tensor kinds (float/int/bool).
pub struct RelayRLAgent<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B>,
    KindOut: TensorKind<B>,
> {
    // Owns actors, routers, and (feature-gated) transports.
    coordinator: ClientCoordinator<B, D_IN, D_OUT>,
    // Backend detected at construction; checked again on each action request.
    supported_backend: SupportedTensorBackend,
    // Cached from the model's metadata in `start()`; None until then.
    input_dtype: Option<DType>,
    output_dtype: Option<DType>,
    // Carries the tensor-kind type parameters without storing values.
    _phantom: PhantomData<(KindIn, KindOut)>,
}
impl<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B> + Send + Sync,
    KindOut: TensorKind<B> + Send + Sync,
> std::fmt::Debug for RelayRLAgent<B, D_IN, D_OUT, KindIn, KindOut>
{
    /// Opaque debug representation; internal state is not printable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RLAgent")
    }
}
impl<
B: Backend + BackendMatcher<Backend = B>,
const D_IN: usize,
const D_OUT: usize,
KindIn: TensorKind<B> + Send + Sync,
KindOut: TensorKind<B> + Send + Sync,
> RelayRLAgent<B, D_IN, D_OUT, KindIn, KindOut>
{
pub fn new(
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
transport_type: TransportType,
client_modes: ClientModes,
) -> Result<Self, ClientError> {
let supported_backend = if B::matches_backend(&SupportedTensorBackend::NdArray) {
SupportedTensorBackend::NdArray
} else if B::matches_backend(&SupportedTensorBackend::Tch) {
SupportedTensorBackend::Tch
} else {
SupportedTensorBackend::None
};
Ok(Self {
coordinator: ClientCoordinator::<B, D_IN, D_OUT>::new(
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
transport_type,
client_modes,
)?,
supported_backend,
input_dtype: None,
output_dtype: None,
_phantom: PhantomData,
})
}
pub async fn start(
&mut self,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
algorithm_args: AlgorithmArgs,
actor_count: u32,
router_scale: u32,
default_device: DeviceType,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))] default_model: Option<
ModelModule<B>,
>,
#[cfg(not(any(feature = "async_transport", feature = "sync_transport")))]
default_model: ModelModule<B>,
config_path: Option<PathBuf>,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))] codec: Option<
CodecConfig,
>,
) -> Result<(), ClientError> {
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
let (input_dtype, output_dtype) = if let Some(ref model_module) = default_model {
(
Some(model_module.metadata.input_dtype.clone()),
Some(model_module.metadata.output_dtype.clone()),
)
} else {
(None, None)
};
#[cfg(not(any(feature = "async_transport", feature = "sync_transport")))]
let (input_dtype, output_dtype) = (
Some(default_model.metadata.input_dtype.clone()),
Some(default_model.metadata.output_dtype.clone()),
);
self.input_dtype = input_dtype;
self.output_dtype = output_dtype;
self.coordinator
._start(
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
algorithm_args,
actor_count,
router_scale,
default_device,
default_model,
config_path,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
codec,
)
.await
.map_err(Into::<ClientError>::into)?;
Ok(())
}
pub async fn restart(
&mut self,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
algorithm_args: AlgorithmArgs,
actor_count: u32,
router_scale: u32,
default_device: DeviceType,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))] default_model: Option<
ModelModule<B>,
>,
#[cfg(not(any(feature = "async_transport", feature = "sync_transport")))]
default_model: ModelModule<B>,
config_path: Option<PathBuf>,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))] codec: Option<
CodecConfig,
>,
) -> Result<(), ClientError> {
self.coordinator
._restart(
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
algorithm_args,
actor_count,
router_scale,
default_device,
default_model,
config_path,
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
codec,
)
.await?;
Ok(())
}
pub async fn shutdown(&mut self) -> Result<(), ClientError> {
self.coordinator._shutdown().await?;
Ok(())
}
pub async fn scale_throughput(&mut self, router_scale: i32) -> Result<(), ClientError> {
match router_scale {
add if router_scale > 0 => {
self.coordinator._scale_out(add as u32).await?;
Ok(())
}
remove if router_scale < 0 => {
self.coordinator._scale_in(remove.unsigned_abs()).await?;
Ok(())
}
_ => Err(ClientError::NoopRouterScale(
"Noop router scale: `router_scale` set to zero in `scale_throughput()`".to_string(),
)),
}
}
pub async fn request_action(
&self,
ids: Vec<Uuid>,
observation: Tensor<B, D_IN, KindIn>,
mask: Option<Tensor<B, D_OUT, KindOut>>,
reward: f32,
) -> Result<Vec<(ActorUuid, Arc<RelayRLAction>)>, ClientError>
where
Tensor<B, D_IN, KindIn>: ToAnyBurnTensor<B, D_IN>,
Tensor<B, D_OUT, KindOut>: ToAnyBurnTensor<B, D_OUT>,
{
match B::matches_backend(&self.supported_backend) {
true => {
if let (Some(input_dtype), Some(output_dtype)) =
(self.input_dtype.clone(), self.output_dtype.clone())
{
let obs_tensor: Arc<AnyBurnTensor<B, D_IN>> =
Arc::new(observation.to_any_burn_tensor(input_dtype));
let mask_tensor: Option<Arc<AnyBurnTensor<B, D_OUT>>> =
mask.map(|tensor| Arc::new(tensor.to_any_burn_tensor(output_dtype)));
let result = self
.coordinator
._request_action(ids, obs_tensor, mask_tensor, reward)
.await?;
Ok(result)
} else {
Err(ClientError::NoInputOrOutputDtypeSet(
"No input or output dtype set in agent".to_string(),
))
}
}
false => Err(ClientError::BackendMismatchError(
"Backend mismatch".to_string(),
)),
}
}
pub async fn flag_last_action(
&self,
ids: Vec<Uuid>,
reward: Option<f32>,
) -> Result<(), ClientError> {
self.coordinator._flag_last_action(ids, reward).await?;
Ok(())
}
pub async fn get_model_version(&self, ids: Vec<Uuid>) -> Result<Vec<(Uuid, i64)>, ClientError> {
Ok(self.coordinator._get_model_version(ids).await?)
}
#[deprecated(note = "Not implemented")]
#[cfg(any(feature = "metrics", feature = "logging"))]
pub fn runtime_statistics(
&self,
return_type: RuntimeStatisticsReturnType,
) -> Result<RuntimeStatisticsReturnType, ClientError> {
Ok(RuntimeStatisticsReturnType::Hashmap(HashMap::new()))
}
pub async fn get_config(&self) -> Result<ClientConfigLoader, ClientError> {
Ok(self.coordinator._get_config().await?)
}
pub async fn set_config_path(&self, config_path: PathBuf) -> Result<(), ClientError> {
self.coordinator._set_config_path(config_path).await?;
Ok(())
}
}
/// Actor lifecycle management for an agent.
///
/// Methods return boxed futures so the trait stays object-safe.
/// NOTE(review): this trait declares `Uuid` for `remove_actor`/`set_actor_id`
/// while the impl below uses `ActorUuid` — presumably `ActorUuid` is an alias
/// for `Uuid`; confirm in `state_manager`.
pub trait RelayRLAgentActors<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B>,
    KindOut: TensorKind<B>,
>
{
    /// Spawns one actor on `device`, optionally seeded with a model.
    fn new_actor(
        &mut self,
        device: DeviceType,
        default_model: Option<ModelModule<B>>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>>;
    /// Spawns `count` actors on `device`; `count` must be non-zero.
    fn new_actors(
        &mut self,
        count: u32,
        device: DeviceType,
        default_model: Option<ModelModule<B>>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>>;
    /// Removes the actor with the given id.
    fn remove_actor(
        &mut self,
        id: Uuid,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>>;
    /// Lists ids of all registered actors.
    fn get_actor_ids(&mut self) -> Result<Vec<ActorUuid>, ClientError>;
    /// Renames an actor from `current_id` to `new_id`.
    fn set_actor_id(
        &mut self,
        current_id: Uuid,
        new_id: Uuid,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>>;
}
impl<
    B: Backend + BackendMatcher<Backend = B>,
    const D_IN: usize,
    const D_OUT: usize,
    KindIn: TensorKind<B> + Send + Sync,
    KindOut: TensorKind<B> + Send + Sync,
> RelayRLAgentActors<B, D_IN, D_OUT, KindIn, KindOut>
    for RelayRLAgent<B, D_IN, D_OUT, KindIn, KindOut>
{
    /// Spawns a single actor and — under a transport feature — notifies the
    /// server immediately (the trailing `true` argument).
    fn new_actor(
        &mut self,
        device: DeviceType,
        default_model: Option<ModelModule<B>>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>> {
        Box::pin(async move {
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            self.coordinator
                ._new_actor(device, default_model, true)
                .await?;
            #[cfg(not(any(feature = "async_transport", feature = "sync_transport")))]
            self.coordinator._new_actor(device, default_model).await?;
            Ok(())
        })
    }
    /// Spawns `count` actors, then — under a transport feature — registers
    /// all ids with the server in one batch (each `_new_actor` call passes
    /// `false` to defer per-actor notification).
    fn new_actors(
        &mut self,
        count: u32,
        device: DeviceType,
        default_model: Option<ModelModule<B>>,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>> {
        if count == 0 {
            return Box::pin(async move {
                Err(ClientError::NoopActorCount(
                    "Noop actor count: `count` set to zero".to_string(),
                ))
            });
        }
        Box::pin(async move {
            for _ in 0..count {
                #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
                self.coordinator
                    ._new_actor(device.clone(), default_model.clone(), false)
                    .await?;
                #[cfg(not(any(feature = "async_transport", feature = "sync_transport")))]
                self.coordinator
                    ._new_actor(device.clone(), default_model.clone())
                    .await?;
            }
            // `?` already applies the `#[from] UuidPoolError` conversion; the
            // previous explicit `map_err(ClientError::from)` was redundant.
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            let actor_ids = get("actor")?;
            #[cfg(any(feature = "async_transport", feature = "sync_transport"))]
            self.coordinator
                ._send_client_ids_to_server(actor_ids)
                .await?;
            Ok(())
        })
    }
    /// Removes the actor with the given id.
    fn remove_actor(
        &mut self,
        id: ActorUuid,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>> {
        Box::pin(async move {
            self.coordinator._remove_actor(id).await?;
            Ok(())
        })
    }
    /// Lists ids of all registered actors.
    fn get_actor_ids(&mut self) -> Result<Vec<ActorUuid>, ClientError> {
        // Consume the owned registry snapshot instead of cloning each id
        // out of it (the previous `iter().map(|(_, id)| id.clone())`).
        let ids = get("actor")?;
        Ok(ids.into_iter().map(|(_, id)| id).collect())
    }
    /// Renames an actor from `current_id` to `new_id`.
    fn set_actor_id(
        &mut self,
        current_id: ActorUuid,
        new_id: ActorUuid,
    ) -> Pin<Box<dyn Future<Output = Result<(), ClientError>> + Send + '_>> {
        Box::pin(async move {
            self.coordinator._set_actor_id(current_id, new_id).await?;
            Ok(())
        })
    }
}