use crate::bindings::python::o3_action::PyRelayRLAction;
use crate::types::action::{RelayRLAction, TensorData};
use crate::network::client::agent_grpc::RelayRLAgentGrpcTrait;
use crate::network::client::agent_wrapper::RelayRLAgent;
use crate::network::client::agent_zmq::RelayRLAgentZmqTrait;
use crate::sys_utils::tokio_utils::get_or_init_tokio_runtime;
use pyo3::types::PyAnyMethods;
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pyclass, pyfunction, pymethods};
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::sync::Arc;
use tch::{CModule, Tensor};
use tokio::runtime::Runtime as TokioRuntime;
use tokio::sync::{RwLock as TokioRwLock, RwLock, RwLockWriteGuard};
#[cfg(feature = "console-subscriber")]
use crate::sys_utils::tokio::utils::get_or_init_console_subscriber;
/// Python-facing wrapper around [`RelayRLAgent`], exposed to Python under the
/// class name `RelayRLAgent`.
///
/// The type is `Clone`; clones share the same underlying agent and runtime
/// because both fields are reference-counted.
#[pyclass(name = "RelayRLAgent")]
#[derive(Clone)]
pub struct PyRelayRLAgent {
/// The wrapped agent, guarded by an async read/write lock so methods can
/// obtain exclusive access before calling into the gRPC or ZMQ backend.
inner: Arc<TokioRwLock<RelayRLAgent>>,
/// Tokio runtime used to drive the agent's async operations from the
/// synchronous Python-callable methods.
tokio_runtime: Arc<TokioRuntime>,
}
#[pymethods]
impl PyRelayRLAgent {
/// Builds the Python-visible agent wrapper.
///
/// # Arguments
/// * `model_path` - optional path passed to `CModule::load`; `None` skips model loading.
/// * `config_path` - optional configuration path (defaults to `"./config.json"`),
///   converted to a `PathBuf` before being forwarded.
/// * `server_type` - transport selector (defaults to `"zmq"`), forwarded to
///   `RelayRLAgent::new`.
/// * `training_port`, `training_prefix`, `training_host` - forwarded unchanged to
///   `RelayRLAgent::new`.
///
/// # Errors
/// * `ValueError` if `server_type` is explicitly passed as `None`.
/// * `IOError` if the model file cannot be loaded.
#[new]
#[pyo3(signature = (
    model_path = None,
    config_path = "./config.json",
    server_type = "zmq",
    training_port = None,
    training_prefix = None,
    training_host = None,
))]
fn new(
    py: Python,
    model_path: Option<String>,
    config_path: Option<&str>,
    server_type: Option<&str>,
    training_port: Option<String>,
    training_prefix: Option<String>,
    training_host: Option<String>,
) -> PyResult<Self> {
    #[cfg(feature = "console-subscriber")]
    get_or_init_console_subscriber();

    // A caller can still pass `server_type=None` explicitly; report that as a
    // regular Python exception instead of panicking.
    let server_type: String = server_type
        .ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyValueError, _>("[PyO3] server_type is required")
        })?
        .to_string();

    let tokio_runtime: Arc<TokioRuntime> = get_or_init_tokio_runtime();

    let model: Option<CModule> = model_path
        .map(|path| {
            CModule::load(path).map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                    "[PyO3] Failed to load model: {}",
                    e
                ))
            })
        })
        .transpose()?;

    let config_path: Option<PathBuf> = config_path.map(PathBuf::from);

    // Agent construction is driven on the Tokio runtime with the GIL released
    // so other Python threads are not stalled while it runs.
    let agent: RelayRLAgent = py.allow_threads(|| {
        tokio_runtime.block_on(RelayRLAgent::new(
            model,
            config_path,
            Some(server_type),
            training_port,
            training_prefix,
            training_host,
        ))
    });

    Ok(PyRelayRLAgent {
        inner: Arc::new(TokioRwLock::new(agent)),
        tokio_runtime,
    })
}
#[pyo3(signature = (obs, mask, reward))]
fn request_for_action(
&self,
py: Python,
obs: &Bound<'_, PyAny>,
mask: &Bound<'_, PyAny>,
reward: &Bound<'_, PyAny>,
) -> PyResult<Py<PyRelayRLAction>> {
let obs_tensor_data: TensorData = PyRelayRLAction::ndarray_to_tensor_data(Some(obs))
.expect("[PyO3] Failed to convert obs");
let mask_tensor_data: TensorData = PyRelayRLAction::ndarray_to_tensor_data(Some(mask))
.expect("[PyO3] Failed to convert mask");
let obs_tensor: Option<Tensor> = Tensor::try_from(obs_tensor_data).ok();
let mask_tensor: Option<Tensor> = Tensor::try_from(mask_tensor_data).ok();
let reward_value: f32 = reward.extract::<f32>()?;
let action_arc_option: Option<Arc<RelayRLAction>> = {
let mut agent_guard: RwLockWriteGuard<RelayRLAgent> = self
.tokio_runtime
.block_on(async { self.inner.write().await });
if let Some(grpc_agent) = &mut agent_guard.agent_grpc {
match self.tokio_runtime.block_on(grpc_agent.request_for_action(
obs_tensor.expect("[PyO3] Failed to convert obs"),
mask_tensor.expect("[PyO3] Failed to convert mask"),
reward_value,
)) {
Ok(action_arc) => Some(action_arc),
Err(e) => {
eprintln!("[PyO3] gRPC agent request_for_action error: {:?}", e);
None
}
}
} else if let Some(zmq_agent) = &mut agent_guard.agent_zmq {
match zmq_agent.request_for_action(
&obs_tensor.expect("[PyO3] Failed to convert obs"),
&mask_tensor.expect("[PyO3] Failed to convert mask"),
reward_value,
) {
Ok(action_arc) => Some(action_arc),
Err(e) => {
eprintln!("[PyO3] ZMQ agent request_for_action error: {:?}", e);
None
}
}
} else {
eprintln!("No agent initialized");
None
}
};
match action_arc_option {
Some(action_arc) => Ok(Py::new(
py,
PyRelayRLAction {
inner: (*action_arc).clone(),
},
)?),
None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"[PyO3] No action returned",
)),
}
}
/// Forwards `reward` (default `0.0`) to the active agent's `flag_last_action`.
///
/// The ZMQ agent is used when present, otherwise the gRPC agent; if neither
/// is initialized the call does nothing.
#[pyo3(signature = (reward = 0.0))]
fn flag_last_action(&self, py: Python, reward: f32) {
    // The write lock is taken and released entirely inside the GIL-free
    // section. Acquiring it while holding the GIL (or keeping it until after
    // the GIL is re-acquired) can deadlock two Python threads sharing this agent.
    py.allow_threads(|| {
        let mut agent: RwLockWriteGuard<RelayRLAgent> =
            self.tokio_runtime.block_on(self.inner.write());

        if let Some(zmq_agent) = &mut agent.agent_zmq {
            zmq_agent.flag_last_action(reward);
        } else if let Some(grpc_agent) = &mut agent.agent_grpc {
            self.tokio_runtime
                .block_on(grpc_agent.flag_last_action(reward));
        }

        drop(agent);
    });
}
/// Restarts the active agent, optionally pointing it at a new training server.
///
/// `training_server_address` is forwarded unchanged to the agent's
/// `restart_agent`. The ZMQ agent is used when present, otherwise the gRPC agent.
///
/// # Returns
/// `Ok(true)` when every result reported by the agent's `restart_agent` is
/// `Ok`; `Ok(false)` when any result is an error or when no agent is
/// initialized (a message is printed to stderr in that case).
#[pyo3(signature = (training_server_address = None))]
fn restart_agent(
    &self,
    py: Python<'_>,
    training_server_address: Option<String>,
) -> PyResult<bool> {
    // Lock acquisition and release both happen with the GIL released, so a
    // thread waiting on the lock never blocks a lock holder that is waiting
    // for the GIL.
    let restarted: bool = py.allow_threads(|| {
        let mut agent: RwLockWriteGuard<RelayRLAgent> =
            self.tokio_runtime.block_on(self.inner.write());

        let all_ok: bool = if let Some(zmq_agent) = &mut agent.agent_zmq {
            let results = self
                .tokio_runtime
                .block_on(zmq_agent.restart_agent(training_server_address));
            results.iter().all(|r| r.is_ok())
        } else if let Some(grpc_agent) = &mut agent.agent_grpc {
            let results = self
                .tokio_runtime
                .block_on(grpc_agent.restart_agent(training_server_address));
            results.iter().all(|r| r.is_ok())
        } else {
            eprintln!("[PyO3] No agent initialized");
            false
        };

        drop(agent);
        all_ok
    });

    Ok(restarted)
}
#[pyo3(signature = ())]
fn disable_agent(&self, py: Python) -> PyResult<()> {
let mut agent: RwLockWriteGuard<RelayRLAgent> = self
.tokio_runtime
.block_on(async { self.inner.write().await });
let mut result: PyResult<()> = Ok(());
py.allow_threads(|| {
if agent.agent_zmq.is_some() {
self.tokio_runtime.block_on(async {
result = Ok(agent
.agent_zmq
.as_mut()
.expect("[PyO3] ZMQ agent not initialized")
.disable_agent()
.await
.expect("[PyO3] Failed to disable ZMQ agent"));
});
} else if agent.agent_grpc.is_some() {
self.tokio_runtime.block_on(async {
result = Ok(agent
.agent_grpc
.as_mut()
.expect("[PyO3] gRPC agent not initialized")
.disable_agent()
.await
.expect("[PyO3] Failed to disable gRPC agent"));
});
} else {
eprintln!("[PyO3] No agent initialized");
}
});
result
}
#[pyo3(signature = (training_server_address = None))]
fn enable_agent(&self, py: Python, training_server_address: Option<String>) -> PyResult<()> {
let mut agent: RwLockWriteGuard<RelayRLAgent> = self
.tokio_runtime
.block_on(async { self.inner.write().await });
let mut result: PyResult<()> = Ok(());
py.allow_threads(|| {
if agent.agent_zmq.is_some() {
self.tokio_runtime.block_on(async {
result = Ok(agent
.agent_zmq
.as_mut()
.expect("[PyO3] ZMQ agent not initialized")
.enable_agent(training_server_address)
.await
.expect("[PyO3] Failed to enable ZMQ agent"));
});
} else if agent.agent_grpc.is_some() {
self.tokio_runtime.block_on(async {
result = Ok(agent
.agent_grpc
.as_mut()
.expect("[PyO3] gRPC agent not initialized")
.enable_agent(training_server_address)
.await
.expect("[PyO3] Failed to enable gRPC agent"));
});
} else {
eprintln!("[PyO3] No agent initialized");
}
});
result
}
}