use crate::bindings::python::o3_action::PyRelayRLAction;
use crate::types::action::RelayRLAction;
use crate::types::trajectory;
use crate::types::trajectory::RelayRLTrajectoryTrait;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::{Bound, PyAny, pyclass, pymethods};
use serde::{Deserialize, Serialize};
#[pyclass(name = "RelayRLTrajectory")]
#[derive(Serialize, Deserialize, Debug)]
pub struct PyRelayRLTrajectory {
pub inner: trajectory::RelayRLTrajectory,
}
#[pymethods]
impl PyRelayRLTrajectory {
#[new]
#[pyo3(signature = (max_length = 1000, trajectory_server = "tcp://127.0.0.1:5556"))]
fn new(max_length: Option<u32>, trajectory_server: Option<&str>) -> Self {
let traj_server = trajectory_server.map(|server| server.to_string());
PyRelayRLTrajectory {
inner: trajectory::RelayRLTrajectory::new(max_length, traj_server),
}
}
#[pyo3(signature = ())]
fn get_actions(&self) -> Vec<PyRelayRLAction> {
let mut py_actions: Vec<PyRelayRLAction> = Vec::new();
let actions: &Vec<RelayRLAction> = &self.inner.actions;
for action in actions {
py_actions.push((*action).clone().into_py());
}
py_actions
}
#[pyo3(signature = (action))]
fn add_action(&mut self, action: &PyRelayRLAction) {
self.inner.add_action(&action.inner, true);
}
#[pyo3(signature = ())]
fn to_json(&self) -> String {
serde_json::to_string(&self).expect("Failed to serialize trajectory to JSON")
}
#[staticmethod]
#[pyo3(signature = (trajectory_dict))]
fn traj_from_json(trajectory_dict: &Bound<'_, PyDict>) -> PyRelayRLTrajectory {
let inner_any: Bound<PyAny> = trajectory_dict
.get_item("inner")
.expect("Missing 'inner' field")
.expect("Missing 'inner' field");
let inner_dict: &Bound<PyDict> = inner_any
.downcast::<PyDict>()
.expect("Expected 'inner' to be a dictionary");
let max_length: u32 = match inner_dict.get_item("max_length") {
Ok(Some(val)) => val.extract::<u32>().expect("Failed to extract max_length"),
_ => 1000,
};
let trajectory_binding: Bound<PyAny> = inner_dict
.get_item("trajectory_server")
.ok()
.flatten()
.expect("Missing 'trajectory_server' field");
let trajectory_server_result: Option<&str> = trajectory_binding
.downcast::<PyString>()
.ok()
.and_then(|py_str| py_str.to_str().ok());
let trajectory_server: Option<&str> = trajectory_server_result
.map(|_server| trajectory_server_result.expect("Failed to extract trajectory_server"));
let mut py_trajectory: PyRelayRLTrajectory =
PyRelayRLTrajectory::new(Some(max_length), trajectory_server);
if let Ok(Some(actions_obj)) = inner_dict.get_item("actions") {
let actions_list: &Bound<PyList> = actions_obj
.downcast::<PyList>()
.expect("'actions' must be a list");
for action_item in actions_list.iter() {
let action_dict: &Bound<PyDict> = action_item
.downcast::<PyDict>()
.expect("Action must be a dictionary");
let action = PyRelayRLAction::action_from_json(action_dict);
py_trajectory.inner.add_action(&action.inner, false);
}
}
py_trajectory
}
}