use crate::estimation;
use crate::estimation::MeasurementModel;
use crate::estimation::DynamicsSource;
use crate::math::jacobian::{DifferenceMethod, PerturbationStrategy};
/// Process-noise configuration shared by the EKF and UKF bindings.
///
/// Wraps [`estimation::ProcessNoiseConfig`], which holds a square process-noise
/// matrix `q_matrix` and a `scale_with_dt` flag that is forwarded unchanged.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "ProcessNoiseConfig")]
#[derive(Clone)]
pub struct PyProcessNoiseConfig {
    pub(crate) config: estimation::ProcessNoiseConfig,
}

#[pymethods]
impl PyProcessNoiseConfig {
    /// Creates a process-noise configuration from a square 2-D float64 array.
    ///
    /// The array may use any memory layout (C order, Fortran order, or a strided
    /// view); element `q_matrix[i, j]` always ends up in row `i`, column `j`.
    ///
    /// Raises `ValueError` if `q_matrix` is not square.
    #[new]
    #[pyo3(signature = (q_matrix, scale_with_dt=false))]
    fn new(q_matrix: PyReadonlyArray2<f64>, scale_with_dt: bool) -> PyResult<Self> {
        // Index through an ndarray view instead of `as_slice()`: `as_slice()` also
        // accepts Fortran-ordered buffers (which `from_row_slice` would silently
        // transpose) and rejects strided views outright.
        let view = q_matrix.as_array();
        let (rows, cols) = view.dim();
        if rows != cols {
            return Err(exceptions::PyValueError::new_err(
                "q_matrix must be a square matrix",
            ));
        }
        let q = DMatrix::from_fn(rows, cols, |i, j| view[[i, j]]);
        Ok(PyProcessNoiseConfig {
            config: estimation::ProcessNoiseConfig {
                q_matrix: q,
                scale_with_dt,
            },
        })
    }

    /// The process-noise matrix as a 2-D numpy array.
    #[getter]
    fn q_matrix<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let n = self.config.q_matrix.nrows();
        let c = self.config.q_matrix.ncols();
        matrix_to_numpy!(py, self.config.q_matrix, n, c, f64)
    }

    /// The `scale_with_dt` flag given at construction.
    #[getter]
    fn scale_with_dt(&self) -> bool {
        self.config.scale_with_dt
    }

    fn __repr__(&self) -> String {
        format!(
            "ProcessNoiseConfig({}x{}, scale_with_dt={})",
            self.config.q_matrix.nrows(),
            self.config.q_matrix.ncols(),
            self.config.scale_with_dt,
        )
    }
}
/// Configuration for the Python-facing `ExtendedKalmanFilter`.
///
/// Wraps [`estimation::EKFConfig`]: optional process noise plus a flag that
/// controls whether filter records are stored.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "EKFConfig")]
#[derive(Clone)]
pub struct PyEKFConfig {
    pub(crate) config: estimation::EKFConfig,
}

#[pymethods]
impl PyEKFConfig {
    /// Creates a configuration; `process_noise` defaults to `None` and
    /// `store_records` to `True`.
    #[new]
    #[pyo3(signature = (process_noise=None, store_records=true))]
    fn new(process_noise: Option<PyProcessNoiseConfig>, store_records: bool) -> Self {
        let process_noise = process_noise.map(|wrapper| wrapper.config);
        Self {
            config: estimation::EKFConfig {
                process_noise,
                store_records,
            },
        }
    }

    /// Returns the Rust-side `EKFConfig::default()` configuration.
    #[staticmethod]
    #[pyo3(name = "default")]
    fn py_default() -> Self {
        let config = estimation::EKFConfig::default();
        Self { config }
    }

    /// Whether the filter keeps a record for each processed observation.
    #[getter]
    fn store_records(&self) -> bool {
        self.config.store_records
    }

    fn __repr__(&self) -> String {
        let noise_label = self
            .config
            .process_noise
            .as_ref()
            .map_or("None", |_| "set");
        format!(
            "EKFConfig(process_noise={}, store_records={})",
            noise_label, self.config.store_records,
        )
    }
}
/// A single measurement at an epoch, tagged with the index of the measurement
/// model that should process it.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "Observation")]
#[derive(Clone)]
pub struct PyObservation {
    pub(crate) observation: estimation::Observation,
}

#[pymethods]
impl PyObservation {
    /// Creates an observation from an epoch and a 1-D float64 measurement array.
    ///
    /// `measurement` may be any 1-D view, including strided slices such as
    /// `z[::2]`. `model_index` defaults to `0`.
    #[new]
    #[pyo3(signature = (epoch, measurement, model_index=0))]
    fn new(
        epoch: &PyEpoch,
        measurement: PyReadonlyArray1<f64>,
        model_index: usize,
    ) -> PyResult<Self> {
        // Copy through the ndarray view: `as_slice()` rejects non-contiguous arrays.
        let view = measurement.as_array();
        let meas_vec = DVector::from_iterator(view.len(), view.iter().copied());
        Ok(PyObservation {
            observation: estimation::Observation::new(epoch.obj, meas_vec, model_index),
        })
    }

    /// Epoch at which the measurement was taken.
    #[getter]
    fn epoch(&self) -> PyEpoch {
        PyEpoch {
            obj: self.observation.epoch,
        }
    }

    /// Measurement values as a 1-D numpy array.
    #[getter]
    fn measurement<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let n = self.observation.measurement.len();
        vector_to_numpy!(py, self.observation.measurement, n, f64)
    }

    /// Index of the measurement model this observation is routed to.
    #[getter]
    fn model_index(&self) -> usize {
        self.observation.model_index
    }

    fn __repr__(&self) -> String {
        format!(
            "Observation(epoch={}, dim={}, model_index={})",
            self.observation.epoch,
            self.observation.measurement.len(),
            self.observation.model_index,
        )
    }
}
/// Read-only view of one filter step, as stored by the Rust filter.
///
/// Every getter copies the corresponding field into a fresh numpy array.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "FilterRecord")]
#[derive(Clone)]
pub struct PyFilterRecord {
    pub(crate) record: estimation::FilterRecord,
}

#[pymethods]
impl PyFilterRecord {
    /// Epoch of the record.
    #[getter]
    fn epoch(&self) -> PyEpoch {
        let obj = self.record.epoch;
        PyEpoch { obj }
    }

    /// The record's `state_predicted` vector.
    #[getter]
    fn state_predicted<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.state_predicted.len();
        vector_to_numpy!(py, self.record.state_predicted, len, f64)
    }

    /// The record's `covariance_predicted` matrix.
    #[getter]
    fn covariance_predicted<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let (nrows, ncols) = (
            self.record.covariance_predicted.nrows(),
            self.record.covariance_predicted.ncols(),
        );
        matrix_to_numpy!(py, self.record.covariance_predicted, nrows, ncols, f64)
    }

    /// The record's `state_updated` vector.
    #[getter]
    fn state_updated<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.state_updated.len();
        vector_to_numpy!(py, self.record.state_updated, len, f64)
    }

    /// The record's `covariance_updated` matrix.
    #[getter]
    fn covariance_updated<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let (nrows, ncols) = (
            self.record.covariance_updated.nrows(),
            self.record.covariance_updated.ncols(),
        );
        matrix_to_numpy!(py, self.record.covariance_updated, nrows, ncols, f64)
    }

    /// The record's `prefit_residual` vector.
    #[getter]
    fn prefit_residual<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.prefit_residual.len();
        vector_to_numpy!(py, self.record.prefit_residual, len, f64)
    }

    /// The record's `postfit_residual` vector.
    #[getter]
    fn postfit_residual<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.postfit_residual.len();
        vector_to_numpy!(py, self.record.postfit_residual, len, f64)
    }

    /// The record's `kalman_gain` matrix.
    #[getter]
    fn kalman_gain<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let (nrows, ncols) = (
            self.record.kalman_gain.nrows(),
            self.record.kalman_gain.ncols(),
        );
        matrix_to_numpy!(py, self.record.kalman_gain, nrows, ncols, f64)
    }

    /// Name of the measurement model that produced this record.
    #[getter]
    fn measurement_name(&self) -> &str {
        &self.record.measurement_name
    }

    fn __repr__(&self) -> String {
        let epoch = &self.record.epoch;
        let model = &self.record.measurement_name;
        format!("FilterRecord(epoch={}, model={})", epoch, model)
    }
}
/// Generates a `#[pyclass]` wrapper around a native Rust measurement model.
///
/// The caller supplies the wrapper name, the wrapped Rust type, the Python class
/// name, the constructor methods (`constructors: [...]`) and the class docstring.
/// The macro adds `predict`, `jacobian`, `noise_covariance`, `measurement_dim`,
/// `name` and `__repr__`, all delegating to the wrapped model.
macro_rules! impl_measurement_model_binding {
    (
        $py_name:ident, $rust_type:ty, $python_name:expr,
        constructors: [$($ctor:tt)*],
        doc: $doc:expr
    ) => {
        #[doc = $doc]
        #[pyclass(module = "brahe._brahe")]
        #[pyo3(name = $python_name)]
        pub struct $py_name {
            pub(crate) model: $rust_type,
        }

        #[pymethods]
        impl $py_name {
            $($ctor)*

            /// Evaluates the model's predicted measurement at `epoch` for `state`.
            ///
            /// Raises `RuntimeError` if the Rust model reports an error.
            fn predict<'py>(
                &self,
                py: Python<'py>,
                epoch: &PyEpoch,
                state: PyReadonlyArray1<f64>,
            ) -> PyResult<Bound<'py, PyArray<f64, numpy::Ix1>>> {
                let x = nalgebra::DVector::from_column_slice(state.as_slice()?);
                let predicted = self
                    .model
                    .predict(&epoch.obj, &x, None)
                    .map_err(|err| exceptions::PyRuntimeError::new_err(err.to_string()))?;
                let values: Vec<f64> = predicted.iter().copied().collect();
                Ok(values.into_pyarray(py))
            }

            /// Evaluates the model's measurement Jacobian at `epoch` for `state`.
            ///
            /// Returns a row-major 2-D array. Raises `RuntimeError` on model error.
            fn jacobian<'py>(
                &self,
                py: Python<'py>,
                epoch: &PyEpoch,
                state: PyReadonlyArray1<f64>,
            ) -> PyResult<Bound<'py, PyArray<f64, numpy::Ix2>>> {
                let x = nalgebra::DVector::from_column_slice(state.as_slice()?);
                let jac = self
                    .model
                    .jacobian(&epoch.obj, &x, None)
                    .map_err(|err| exceptions::PyRuntimeError::new_err(err.to_string()))?;
                let shape = [jac.nrows(), jac.ncols()];
                // nalgebra is column-major; walking the transpose yields row-major order.
                let row_major: Vec<f64> = jac.transpose().iter().copied().collect();
                Ok(row_major.into_pyarray(py).reshape(shape).unwrap())
            }

            /// The model's measurement-noise covariance as a 2-D array.
            fn noise_covariance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
                let cov = self.model.noise_covariance();
                let shape = [cov.nrows(), cov.ncols()];
                // nalgebra is column-major; walking the transpose yields row-major order.
                let row_major: Vec<f64> = cov.transpose().iter().copied().collect();
                row_major.into_pyarray(py).reshape(shape).unwrap()
            }

            /// Length of the measurement vector produced by `predict`.
            fn measurement_dim(&self) -> usize {
                self.model.measurement_dim()
            }

            /// Human-readable model name.
            fn name(&self) -> &str {
                self.model.name()
            }

            fn __repr__(&self) -> String {
                let label = self.model.name();
                let dim = self.model.measurement_dim();
                format!("{}(dim={})", label, dim)
            }
        }
    };
}
impl_measurement_model_binding!(
    PyInertialPositionMeasurementModel,
    estimation::InertialPositionMeasurementModel,
    "InertialPositionMeasurementModel",
    constructors: [
        /// Creates the model with the same noise `sigma` on every axis.
        #[new]
        fn new(sigma: f64) -> Self {
            PyInertialPositionMeasurementModel {
                model: estimation::InertialPositionMeasurementModel::new(sigma),
            }
        }
        /// Creates the model with independent per-axis noise sigmas.
        #[staticmethod]
        fn per_axis(sigma_x: f64, sigma_y: f64, sigma_z: f64) -> Self {
            PyInertialPositionMeasurementModel {
                model: estimation::InertialPositionMeasurementModel::new_per_axis(sigma_x, sigma_y, sigma_z),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::InertialPositionMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialPositionMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::InertialPositionMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialPositionMeasurementModel { model })
        }
    ],
    doc: "Inertial position measurement model (3D ECI position).\n\nDirectly observes [x, y, z] from the state vector with Gaussian noise."
);
impl_measurement_model_binding!(
    PyInertialVelocityMeasurementModel,
    estimation::InertialVelocityMeasurementModel,
    "InertialVelocityMeasurementModel",
    constructors: [
        /// Creates the model with the same noise `sigma` on every axis.
        #[new]
        fn new(sigma: f64) -> Self {
            PyInertialVelocityMeasurementModel {
                model: estimation::InertialVelocityMeasurementModel::new(sigma),
            }
        }
        /// Creates the model with independent per-axis noise sigmas.
        #[staticmethod]
        fn per_axis(sigma_x: f64, sigma_y: f64, sigma_z: f64) -> Self {
            PyInertialVelocityMeasurementModel {
                model: estimation::InertialVelocityMeasurementModel::new_per_axis(sigma_x, sigma_y, sigma_z),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::InertialVelocityMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialVelocityMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::InertialVelocityMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialVelocityMeasurementModel { model })
        }
    ],
    doc: "Inertial velocity measurement model (3D ECI velocity).\n\nDirectly observes [vx, vy, vz] from the state vector with Gaussian noise."
);
impl_measurement_model_binding!(
    PyInertialStateMeasurementModel,
    estimation::InertialStateMeasurementModel,
    "InertialStateMeasurementModel",
    constructors: [
        /// Creates the model with one noise sigma for position and one for velocity.
        #[new]
        fn new(pos_sigma: f64, vel_sigma: f64) -> Self {
            PyInertialStateMeasurementModel {
                model: estimation::InertialStateMeasurementModel::new(pos_sigma, vel_sigma),
            }
        }
        /// Creates the model with independent per-axis position and velocity sigmas.
        #[staticmethod]
        fn per_axis(
            pos_sigma_x: f64, pos_sigma_y: f64, pos_sigma_z: f64,
            vel_sigma_x: f64, vel_sigma_y: f64, vel_sigma_z: f64,
        ) -> Self {
            PyInertialStateMeasurementModel {
                model: estimation::InertialStateMeasurementModel::new_per_axis(
                    pos_sigma_x, pos_sigma_y, pos_sigma_z,
                    vel_sigma_x, vel_sigma_y, vel_sigma_z,
                ),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::InertialStateMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialStateMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::InertialStateMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyInertialStateMeasurementModel { model })
        }
    ],
    doc: "Inertial state measurement model (6D ECI state).\n\nDirectly observes [x, y, z, vx, vy, vz] from the state vector with Gaussian noise."
);
impl_measurement_model_binding!(
    PyEcefPositionMeasurementModel,
    estimation::EcefPositionMeasurementModel,
    "ECEFPositionMeasurementModel",
    constructors: [
        /// Creates the model with the same noise `sigma` on every axis.
        #[new]
        fn new(sigma: f64) -> Self {
            PyEcefPositionMeasurementModel {
                model: estimation::EcefPositionMeasurementModel::new(sigma),
            }
        }
        /// Creates the model with independent per-axis noise sigmas.
        #[staticmethod]
        fn per_axis(sigma_x: f64, sigma_y: f64, sigma_z: f64) -> Self {
            PyEcefPositionMeasurementModel {
                model: estimation::EcefPositionMeasurementModel::new_per_axis(sigma_x, sigma_y, sigma_z),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::EcefPositionMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefPositionMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::EcefPositionMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefPositionMeasurementModel { model })
        }
    ],
    doc: "ECEF position measurement model (3D ECEF position from GNSS).\n\nInternally converts ECI state to ECEF position."
);
impl_measurement_model_binding!(
    PyEcefVelocityMeasurementModel,
    estimation::EcefVelocityMeasurementModel,
    "ECEFVelocityMeasurementModel",
    constructors: [
        /// Creates the model with the same noise `sigma` on every axis.
        #[new]
        fn new(sigma: f64) -> Self {
            PyEcefVelocityMeasurementModel {
                model: estimation::EcefVelocityMeasurementModel::new(sigma),
            }
        }
        /// Creates the model with independent per-axis noise sigmas.
        #[staticmethod]
        fn per_axis(sigma_x: f64, sigma_y: f64, sigma_z: f64) -> Self {
            PyEcefVelocityMeasurementModel {
                model: estimation::EcefVelocityMeasurementModel::new_per_axis(sigma_x, sigma_y, sigma_z),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::EcefVelocityMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefVelocityMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::EcefVelocityMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefVelocityMeasurementModel { model })
        }
    ],
    doc: "ECEF velocity measurement model (3D ECEF velocity from GNSS).\n\nInternally converts ECI state to ECEF velocity."
);
impl_measurement_model_binding!(
    PyEcefStateMeasurementModel,
    estimation::EcefStateMeasurementModel,
    "ECEFStateMeasurementModel",
    constructors: [
        /// Creates the model with one noise sigma for position and one for velocity.
        #[new]
        fn new(pos_sigma: f64, vel_sigma: f64) -> Self {
            PyEcefStateMeasurementModel {
                model: estimation::EcefStateMeasurementModel::new(pos_sigma, vel_sigma),
            }
        }
        /// Creates the model with independent per-axis position and velocity sigmas.
        #[staticmethod]
        fn per_axis(
            pos_sigma_x: f64, pos_sigma_y: f64, pos_sigma_z: f64,
            vel_sigma_x: f64, vel_sigma_y: f64, vel_sigma_z: f64,
        ) -> Self {
            PyEcefStateMeasurementModel {
                model: estimation::EcefStateMeasurementModel::new_per_axis(
                    pos_sigma_x, pos_sigma_y, pos_sigma_z,
                    vel_sigma_x, vel_sigma_y, vel_sigma_z,
                ),
            }
        }
        /// Creates the model from a full noise covariance matrix.
        ///
        /// Any memory layout is accepted (C order, Fortran order, strided views).
        /// Raises `ValueError` if the Rust constructor rejects the matrix.
        #[staticmethod]
        fn from_covariance(noise_cov: PyReadonlyArray2<f64>) -> PyResult<Self> {
            // Index the view: `as_slice()` + `from_row_slice` would transpose a
            // Fortran-ordered array and reject strided views.
            let view = noise_cov.as_array();
            let mat = nalgebra::DMatrix::from_fn(view.nrows(), view.ncols(), |i, j| view[[i, j]]);
            let model = estimation::EcefStateMeasurementModel::from_covariance(mat)
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefStateMeasurementModel { model })
        }
        /// Creates the model from an upper-triangular packing of the noise
        /// covariance (packing order defined by the Rust constructor).
        ///
        /// Strided 1-D views are accepted. Raises `ValueError` on invalid input.
        #[staticmethod]
        fn from_upper_triangular(upper: PyReadonlyArray1<f64>) -> PyResult<Self> {
            let data: Vec<f64> = upper.as_array().iter().copied().collect();
            let model = estimation::EcefStateMeasurementModel::from_upper_triangular(data.as_slice())
                .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?;
            Ok(PyEcefStateMeasurementModel { model })
        }
    ],
    doc: "ECEF state measurement model (6D ECEF state from GNSS).\n\nInternally converts ECI state to ECEF state."
);
/// Returns `crate::math::covariance::isotropic_covariance(dim, sigma)` as a
/// row-major 2-D numpy array.
#[pyfunction]
#[pyo3(name = "isotropic_covariance")]
fn py_isotropic_covariance<'py>(
    py: Python<'py>,
    dim: usize,
    sigma: f64,
) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
    let cov = crate::math::covariance::isotropic_covariance(dim, sigma);
    let shape = [cov.nrows(), cov.ncols()];
    // nalgebra stores column-major; iterating the transpose yields row-major order.
    let row_major: Vec<f64> = cov.transpose().iter().copied().collect();
    row_major.into_pyarray(py).reshape(shape).unwrap()
}
/// Returns `crate::math::covariance::diagonal_covariance(sigmas)` as a row-major
/// 2-D numpy array.
///
/// `sigmas` may be any 1-D float64 view, including strided slices such as `s[::2]`.
#[pyfunction]
#[pyo3(name = "diagonal_covariance")]
fn py_diagonal_covariance<'py>(
    py: Python<'py>,
    sigmas: PyReadonlyArray1<f64>,
) -> PyResult<Bound<'py, PyArray<f64, numpy::Ix2>>> {
    // Copy through the ndarray view: `as_slice()` rejects non-contiguous arrays.
    let sigma_values: Vec<f64> = sigmas.as_array().iter().copied().collect();
    let cov = crate::math::covariance::diagonal_covariance(sigma_values.as_slice());
    let shape = [cov.nrows(), cov.ncols()];
    // nalgebra stores column-major; iterating the transpose yields row-major order.
    let row_major: Vec<f64> = cov.transpose().iter().copied().collect();
    row_major.into_pyarray(py).reshape(shape)
}
/// Subclassable base class for measurement models written in Python.
///
/// Subclasses must override `predict`, `noise_covariance`, `measurement_dim`
/// and `name`. Overriding `jacobian` is optional: the default returns `None`,
/// and the Rust-side wrapper then falls back to a numerical Jacobian.
#[pyclass(module = "brahe._brahe", subclass)]
#[pyo3(name = "MeasurementModel")]
pub struct PyMeasurementModel {}

#[pymethods]
impl PyMeasurementModel {
    /// Accepts and ignores any arguments so subclasses may define their own `__init__`.
    #[new]
    #[pyo3(signature = (*_args, **_kwargs))]
    fn new(
        _args: &Bound<'_, pyo3::types::PyTuple>,
        _kwargs: Option<&Bound<'_, pyo3::types::PyDict>>,
    ) -> Self {
        Self {}
    }

    /// Abstract: always raises `NotImplementedError`.
    // The unused parameters exist to fix the Python-visible signature.
    #[allow(unused_variables)]
    fn predict(
        &self,
        epoch: &PyEpoch,
        state: PyReadonlyArray1<f64>,
    ) -> PyResult<Py<PyAny>> {
        Err(measurement_model_not_implemented("predict"))
    }

    /// Default: returns `None`, requesting the numerical-Jacobian fallback.
    // The unused parameters exist to fix the Python-visible signature.
    #[allow(unused_variables)]
    fn jacobian(
        &self,
        epoch: &PyEpoch,
        state: PyReadonlyArray1<f64>,
    ) -> PyResult<Py<PyAny>> {
        let none = Python::attach(|py| py.None());
        Ok(none)
    }

    /// Abstract: always raises `NotImplementedError`.
    fn noise_covariance(&self) -> PyResult<Py<PyAny>> {
        Err(measurement_model_not_implemented("noise_covariance"))
    }

    /// Abstract: always raises `NotImplementedError`.
    fn measurement_dim(&self) -> PyResult<usize> {
        Err(measurement_model_not_implemented("measurement_dim"))
    }

    /// Abstract: always raises `NotImplementedError`.
    fn name(&self) -> PyResult<String> {
        Err(measurement_model_not_implemented("name"))
    }
}

/// Builds the `NotImplementedError` raised by the abstract `MeasurementModel` methods.
fn measurement_model_not_implemented(method: &str) -> pyo3::PyErr {
    exceptions::PyNotImplementedError::new_err(format!("Subclasses must implement {method}()"))
}
/// Adapts a Python `MeasurementModel` implementation to the Rust
/// [`MeasurementModel`] trait.
///
/// `name`, `measurement_dim` and `noise_covariance` are read once at
/// construction and cached. `predict` and `jacobian` call back into Python on
/// every use.
//
// No `unsafe impl Send/Sync` is needed. `Py<PyAny>` is `Send + Sync` in PyO3,
// and the remaining fields are plain owned data, so the compiler derives both
// auto traits. An explicit unsafe impl would also silently bless any future
// non-thread-safe field.
pub(crate) struct RustMeasurementModelWrapper {
    /// Handle to the Python model object.
    py_model: Py<PyAny>,
    /// Value returned by `name()` at construction.
    cached_name: String,
    /// Value returned by `measurement_dim()` at construction.
    cached_dim: usize,
    /// Matrix returned by `noise_covariance()` at construction.
    cached_noise_cov: DMatrix<f64>,
}
impl RustMeasurementModelWrapper {
    /// Wraps a Python measurement-model object, caching its static properties.
    ///
    /// Calls `name()`, `measurement_dim()` and `noise_covariance()` once. The
    /// covariance array may use any memory layout; it is copied element-by-element
    /// in (row, column) order.
    ///
    /// # Errors
    /// Returns `ValueError` if any of those methods raises or returns the wrong
    /// type. Also returns it if `noise_covariance()` is not
    /// `measurement_dim() x measurement_dim()`.
    pub fn new(py: Python<'_>, py_model: Py<PyAny>) -> PyResult<Self> {
        let obj = py_model.bind(py);
        let cached_name: String = obj
            .call_method0("name")
            .and_then(|r| r.extract())
            .map_err(|e| {
                exceptions::PyValueError::new_err(format!(
                    "MeasurementModel.name() failed: {}",
                    e
                ))
            })?;
        let cached_dim: usize = obj
            .call_method0("measurement_dim")
            .and_then(|r| r.extract())
            .map_err(|e| {
                exceptions::PyValueError::new_err(format!(
                    "MeasurementModel.measurement_dim() failed: {}",
                    e
                ))
            })?;
        let noise_cov_result = obj.call_method0("noise_covariance").map_err(|e| {
            exceptions::PyValueError::new_err(format!(
                "MeasurementModel.noise_covariance() failed: {}",
                e
            ))
        })?;
        let noise_arr: PyReadonlyArray2<f64> = noise_cov_result.extract().map_err(|e| {
            exceptions::PyValueError::new_err(format!(
                "noise_covariance() must return a 2D numpy array: {}",
                e
            ))
        })?;
        // Index the view: `as_slice()` + `from_row_slice` would transpose a
        // Fortran-ordered array and reject strided views.
        let view = noise_arr.as_array();
        let (rows, cols) = view.dim();
        // Catch a mis-sized covariance here, with a clear message, instead of as a
        // dimension-mismatch panic inside the filter update.
        if rows != cached_dim || cols != cached_dim {
            return Err(exceptions::PyValueError::new_err(format!(
                "noise_covariance() must be {}x{} to match measurement_dim(), got {}x{}",
                cached_dim, cached_dim, rows, cols
            )));
        }
        let cached_noise_cov = DMatrix::from_fn(rows, cols, |i, j| view[[i, j]]);
        Ok(Self {
            py_model,
            cached_name,
            cached_dim,
            cached_noise_cov,
        })
    }
}
impl MeasurementModel for RustMeasurementModelWrapper {
fn predict(
&self,
epoch: &crate::time::Epoch,
state: &DVector<f64>,
_params: Option<&DVector<f64>>,
) -> Result<DVector<f64>, crate::utils::errors::BraheError> {
Python::attach(|py| {
let py_epoch = Py::new(py, PyEpoch { obj: *epoch })
.map_err(|e| crate::utils::errors::BraheError::Error(e.to_string()))?;
let state_np = state.as_slice().to_pyarray(py);
let result = self
.py_model
.bind(py)
.call_method1("predict", (py_epoch, state_np))
.map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"Python predict() failed: {}",
e
))
})?;
let res_arr: PyReadonlyArray1<f64> = result.extract().map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"predict() must return a numpy array: {}",
e
))
})?;
Ok(DVector::from_column_slice(
res_arr
.as_slice()
.map_err(|e| crate::utils::errors::BraheError::Error(e.to_string()))?,
))
})
}
fn jacobian(
&self,
epoch: &crate::time::Epoch,
state: &DVector<f64>,
params: Option<&DVector<f64>>,
) -> Result<DMatrix<f64>, crate::utils::errors::BraheError> {
let py_result: Result<Option<DMatrix<f64>>, crate::utils::errors::BraheError> =
Python::attach(|py| {
let py_epoch = Py::new(py, PyEpoch { obj: *epoch }).map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"Failed to create PyEpoch: {}",
e
))
})?;
let state_np = state.as_slice().to_pyarray(py);
let result = self
.py_model
.bind(py)
.call_method1("jacobian", (py_epoch, state_np))
.map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"Python jacobian() raised an exception: {}",
e
))
})?;
if result.is_none() {
return Ok(None);
}
let arr: PyReadonlyArray2<f64> = result.extract().map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"jacobian() must return a 2D numpy array or None: {}",
e
))
})?;
let shape = arr.shape();
let data: Vec<f64> = arr
.as_slice()
.map_err(|e| {
crate::utils::errors::BraheError::Error(format!(
"Failed to read jacobian array data: {}",
e
))
})?
.to_vec();
Ok(Some(DMatrix::from_row_slice(shape[0], shape[1], &data)))
});
match py_result {
Ok(Some(jacobian)) => Ok(jacobian),
Ok(None) => {
estimation::measurement_jacobian_numerical(
self,
epoch,
state,
params,
DifferenceMethod::Central,
PerturbationStrategy::Adaptive {
scale_factor: 1.0,
min_value: 1.0,
},
)
}
Err(e) => Err(e),
}
}
fn noise_covariance(&self) -> DMatrix<f64> {
self.cached_noise_cov.clone()
}
fn measurement_dim(&self) -> usize {
self.cached_dim
}
fn name(&self) -> &str {
&self.cached_name
}
}
/// Holds either a native Rust measurement model or a wrapped Python one behind a
/// single [`MeasurementModel`] implementation.
enum MeasurementModelHolder {
    /// A built-in model cloned out of its Python binding.
    RustNative(Box<dyn MeasurementModel>),
    /// A user-defined Python model adapted by [`RustMeasurementModelWrapper`].
    PythonWrapper(RustMeasurementModelWrapper),
}

impl MeasurementModelHolder {
    /// Borrows whichever model is held as a trait object, so every trait method
    /// below can simply forward to it.
    fn inner(&self) -> &dyn MeasurementModel {
        match self {
            Self::RustNative(model) => &**model,
            Self::PythonWrapper(wrapper) => wrapper as &dyn MeasurementModel,
        }
    }
}

impl MeasurementModel for MeasurementModelHolder {
    fn predict(
        &self,
        epoch: &crate::time::Epoch,
        state: &DVector<f64>,
        params: Option<&DVector<f64>>,
    ) -> Result<DVector<f64>, crate::utils::errors::BraheError> {
        self.inner().predict(epoch, state, params)
    }

    fn jacobian(
        &self,
        epoch: &crate::time::Epoch,
        state: &DVector<f64>,
        params: Option<&DVector<f64>>,
    ) -> Result<DMatrix<f64>, crate::utils::errors::BraheError> {
        self.inner().jacobian(epoch, state, params)
    }

    fn noise_covariance(&self) -> DMatrix<f64> {
        self.inner().noise_covariance()
    }

    fn measurement_dim(&self) -> usize {
        self.inner().measurement_dim()
    }

    fn name(&self) -> &str {
        self.inner().name()
    }
}
fn process_measurement_models(
py: Python<'_>,
models: Vec<Py<PyAny>>,
) -> PyResult<Vec<Box<dyn MeasurementModel>>> {
let mut result: Vec<Box<dyn MeasurementModel>> = Vec::with_capacity(models.len());
for py_model in models {
let obj = py_model.bind(py);
if let Ok(m) = obj.extract::<PyRef<PyInertialPositionMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else if let Ok(m) = obj.extract::<PyRef<PyInertialVelocityMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else if let Ok(m) = obj.extract::<PyRef<PyInertialStateMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else if let Ok(m) = obj.extract::<PyRef<PyEcefPositionMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else if let Ok(m) = obj.extract::<PyRef<PyEcefVelocityMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else if let Ok(m) = obj.extract::<PyRef<PyEcefStateMeasurementModel>>() {
result.push(Box::new(MeasurementModelHolder::RustNative(Box::new(
m.model.clone(),
))));
} else {
let wrapper = RustMeasurementModelWrapper::new(py, py_model)?;
result.push(Box::new(MeasurementModelHolder::PythonWrapper(wrapper)));
}
}
Ok(result)
}
/// Python binding for [`estimation::ExtendedKalmanFilter`] driven by a
/// `DNumericalOrbitPropagator`.
#[pyclass(module = "brahe._brahe")]
#[pyo3(name = "ExtendedKalmanFilter")]
pub struct PyExtendedKalmanFilter {
    ekf: estimation::ExtendedKalmanFilter,
}

/// Adapts a Python callable `f(t, x, params)` into a Rust state-function closure.
///
/// `x` is passed as a 1-D numpy array and `params` as a 1-D array or `None`. The
/// callable must return a 1-D float64 numpy array (any memory layout).
///
/// # Panics
/// The closure signature has no error channel, so a Python exception or a
/// non-array return value panics with a message naming `label`.
fn wrap_python_state_fn(
    callable: Py<PyAny>,
    label: &'static str,
) -> impl Fn(f64, &DVector<f64>, Option<&DVector<f64>>) -> DVector<f64> + Send + Sync + 'static {
    move |t: f64, x: &DVector<f64>, p: Option<&DVector<f64>>| -> DVector<f64> {
        Python::attach(|py| {
            let x_np = x.as_slice().to_pyarray(py);
            let call_result = match p {
                Some(pv) => callable.call1(py, (t, x_np, pv.as_slice().to_pyarray(py))),
                None => callable.call1(py, (t, x_np, py.None())),
            };
            let returned = match call_result {
                Ok(value) => value,
                Err(e) => panic!("Error calling {label}: {e}"),
            };
            let arr: PyReadonlyArray1<f64> = match returned.extract(py) {
                Ok(arr) => arr,
                Err(e) => panic!("{label} must return a 1-D float64 numpy array: {e:?}"),
            };
            // Copy through the view so strided return values are accepted too.
            let view = arr.as_array();
            DVector::from_iterator(view.len(), view.iter().copied())
        })
    }
}

#[pymethods]
impl PyExtendedKalmanFilter {
    /// Builds an EKF around a numerical orbit propagator.
    ///
    /// Array arguments may use any memory layout. `initial_covariance` must be
    /// `len(state) x len(state)`. STM propagation is always enabled on a copy of
    /// `propagation_config`. `additional_dynamics` and `control_input`, if given,
    /// are called as `f(t, x, params)` and must return 1-D float64 arrays.
    ///
    /// Raises `ValueError` for a mis-shaped covariance or a failing Python
    /// measurement model. Raises `RuntimeError` if the propagator or filter
    /// cannot be constructed.
    #[new]
    #[pyo3(signature = (
        epoch, state, initial_covariance,
        propagation_config, force_config,
        measurement_models,
        config=None, params=None, additional_dynamics=None, control_input=None
    ))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        py: Python<'_>,
        epoch: &PyEpoch,
        state: PyReadonlyArray1<f64>,
        initial_covariance: PyReadonlyArray2<f64>,
        propagation_config: &PyNumericalPropagationConfig,
        force_config: &PyForceModelConfig,
        measurement_models: Vec<Py<PyAny>>,
        config: Option<&PyEKFConfig>,
        params: Option<PyReadonlyArray1<f64>>,
        additional_dynamics: Option<Py<PyAny>>,
        control_input: Option<Py<PyAny>>,
    ) -> PyResult<Self> {
        // Read arrays through ndarray views. `as_slice()` rejects strided input,
        // and `from_row_slice` would transpose a Fortran-ordered covariance.
        let state_view = state.as_array();
        let state_vec = DVector::from_iterator(state_view.len(), state_view.iter().copied());
        let state_dim = state_vec.len();

        let cov_view = initial_covariance.as_array();
        let (cov_rows, cov_cols) = cov_view.dim();
        if cov_rows != state_dim || cov_cols != state_dim {
            return Err(exceptions::PyValueError::new_err(format!(
                "initial_covariance must be {}x{}, got {}x{}",
                state_dim, state_dim, cov_rows, cov_cols
            )));
        }
        let cov_matrix = DMatrix::from_fn(state_dim, state_dim, |i, j| cov_view[[i, j]]);

        // The EKF path always turns on STM generation in its copy of the propagation
        // config. Presumably the filter's covariance time update relies on it;
        // confirm in estimation::ExtendedKalmanFilter.
        let mut prop_config = propagation_config.config.clone();
        prop_config.variational.enable_stm = true;

        // No unwrap: any 1-D layout is accepted, so a non-contiguous `params`
        // array no longer panics.
        let params_vec = params.map(|p| {
            let view = p.as_array();
            DVector::from_iterator(view.len(), view.iter().copied())
        });

        let additional_dynamics_fn: Option<crate::integrators::traits::DStateDynamics> =
            additional_dynamics.map(|callable| {
                Box::new(wrap_python_state_fn(callable, "additional_dynamics"))
                    as crate::integrators::traits::DStateDynamics
            });
        let control_input_fn: crate::integrators::traits::DControlInput =
            control_input.map(|callable| {
                Box::new(wrap_python_state_fn(callable, "control_input"))
                    as Box<
                        dyn Fn(f64, &DVector<f64>, Option<&DVector<f64>>) -> DVector<f64>
                            + Send
                            + Sync,
                    >
            });

        let prop = propagators::DNumericalOrbitPropagator::new(
            epoch.obj,
            state_vec,
            prop_config,
            force_config.config.clone(),
            params_vec,
            additional_dynamics_fn,
            control_input_fn,
            Some(cov_matrix),
        )
        .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        let dynamics = DynamicsSource::OrbitPropagator(prop);
        let models = process_measurement_models(py, measurement_models)?;
        let ekf_config = config
            .map(|c| c.config.clone())
            .unwrap_or_default();
        let ekf = estimation::ExtendedKalmanFilter::from_propagator(dynamics, models, ekf_config)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(PyExtendedKalmanFilter { ekf })
    }

    /// Processes one observation and returns the resulting filter record.
    ///
    /// Raises `RuntimeError` if the Rust filter reports an error.
    fn process_observation(
        &mut self,
        observation: &PyObservation,
    ) -> PyResult<PyFilterRecord> {
        let record = self
            .ekf
            .process_observation(&observation.observation)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(PyFilterRecord { record })
    }

    /// Processes a sequence of observations via the Rust filter's batch entry point.
    ///
    /// Raises `RuntimeError` if the Rust filter reports an error.
    fn process_observations(
        &mut self,
        observations: Vec<PyRef<PyObservation>>,
    ) -> PyResult<()> {
        let obs_vec: Vec<estimation::Observation> = observations
            .iter()
            .map(|o| o.observation.clone())
            .collect();
        self.ekf
            .process_observations(&obs_vec)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(())
    }

    /// Current filter state as a 1-D numpy array.
    fn current_state<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let state = self.ekf.current_state();
        let n = state.len();
        vector_to_numpy!(py, state, n, f64)
    }

    /// Current filter covariance as a 2-D numpy array, or `None` if unavailable.
    fn current_covariance<'py>(
        &self,
        py: Python<'py>,
    ) -> Option<Bound<'py, PyArray<f64, numpy::Ix2>>> {
        self.ekf.current_covariance().map(|cov| {
            let r = cov.nrows();
            let c = cov.ncols();
            matrix_to_numpy!(py, cov, r, c, f64)
        })
    }

    /// Epoch of the current filter state.
    fn current_epoch(&self) -> PyEpoch {
        PyEpoch {
            obj: self.ekf.current_epoch(),
        }
    }

    /// Copies of all records stored by the filter so far.
    fn records(&self) -> Vec<PyFilterRecord> {
        self.ekf
            .records()
            .iter()
            .map(|r| PyFilterRecord { record: r.clone() })
            .collect()
    }

    fn __repr__(&self) -> String {
        format!(
            "ExtendedKalmanFilter(epoch={}, state_dim={}, records={})",
            self.ekf.current_epoch(),
            self.ekf.current_state().len(),
            self.ekf.records().len(),
        )
    }
}
/// Configuration for the Python-facing `UnscentedKalmanFilter`.
///
/// Wraps [`estimation::UKFConfig`]. Holds the state dimension, the
/// `alpha`/`beta`/`kappa` scaling parameters (forwarded unchanged), optional
/// process noise, and whether filter records are stored.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "UKFConfig")]
#[derive(Clone)]
pub struct PyUKFConfig {
    pub(crate) config: estimation::UKFConfig,
}

#[pymethods]
impl PyUKFConfig {
    /// Creates a configuration. Defaults: `state_dim=6`, `alpha=1e-3`,
    /// `beta=2.0`, `kappa=0.0`, `process_noise=None`, `store_records=True`.
    #[new]
    #[pyo3(signature = (state_dim=6, alpha=1e-3, beta=2.0, kappa=0.0, process_noise=None, store_records=true))]
    fn new(
        state_dim: usize,
        alpha: f64,
        beta: f64,
        kappa: f64,
        process_noise: Option<PyProcessNoiseConfig>,
        store_records: bool,
    ) -> Self {
        PyUKFConfig {
            config: estimation::UKFConfig {
                process_noise: process_noise.map(|pn| pn.config),
                state_dim,
                alpha,
                beta,
                kappa,
                store_records,
            },
        }
    }

    /// Returns the Rust-side `UKFConfig::default()` configuration.
    #[staticmethod]
    #[pyo3(name = "default")]
    fn py_default() -> Self {
        PyUKFConfig {
            config: estimation::UKFConfig::default(),
        }
    }

    /// Configured state dimension.
    #[getter]
    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    /// Configured `alpha` parameter.
    #[getter]
    fn alpha(&self) -> f64 {
        self.config.alpha
    }

    /// Configured `beta` parameter.
    #[getter]
    fn beta(&self) -> f64 {
        self.config.beta
    }

    /// Configured `kappa` parameter.
    #[getter]
    fn kappa(&self) -> f64 {
        self.config.kappa
    }

    /// Whether the filter keeps a record for each processed observation.
    #[getter]
    fn store_records(&self) -> bool {
        self.config.store_records
    }

    /// Shows every configured field, matching the detail level of `EKFConfig.__repr__`.
    fn __repr__(&self) -> String {
        format!(
            "UKFConfig(state_dim={}, alpha={}, beta={}, kappa={}, process_noise={}, store_records={})",
            self.config.state_dim,
            self.config.alpha,
            self.config.beta,
            self.config.kappa,
            if self.config.process_noise.is_some() { "set" } else { "None" },
            self.config.store_records,
        )
    }
}
/// Python-facing unscented Kalman filter (`brahe.UnscentedKalmanFilter`) wrapping
/// `estimation::UnscentedKalmanFilter`.
#[pyclass(module = "brahe._brahe")]
#[pyo3(name = "UnscentedKalmanFilter")]
pub struct PyUnscentedKalmanFilter {
/// The wrapped Rust filter instance.
ukf: estimation::UnscentedKalmanFilter,
}
#[pymethods]
impl PyUnscentedKalmanFilter {
    /// Create an unscented Kalman filter whose dynamics come from a numerical orbit
    /// propagator built from the given initial state and covariance.
    ///
    /// Args:
    ///     epoch: Epoch of the initial state.
    ///     state: Initial state vector; its length defines the filter state dimension.
    ///     initial_covariance: Initial covariance; must be `len(state) x len(state)`.
    ///     propagation_config: Numerical propagation configuration.
    ///     force_config: Force model configuration.
    ///     measurement_models: Measurement models, converted by `process_measurement_models`.
    ///     config: Optional `UKFConfig`. Its `state_dim` is replaced by `len(state)`;
    ///         when omitted, the default configuration is used with `state_dim = len(state)`.
    ///     params: Optional parameter vector passed to the propagator.
    ///     additional_dynamics: Optional callable `f(t, x, params) -> ndarray`; `params`
    ///         is `None` when the propagator supplies no parameter vector.
    ///     control_input: Optional callable with the same calling convention.
    ///
    /// Array arguments may use any numpy memory layout (C order, Fortran order, or
    /// strided views); elements are read by logical index.
    ///
    /// Raises:
    ///     ValueError: If `initial_covariance` does not match the state dimension.
    ///     RuntimeError: If the propagator or the filter cannot be constructed.
    #[new]
    #[pyo3(signature = (
        epoch, state, initial_covariance,
        propagation_config, force_config,
        measurement_models,
        config=None, params=None, additional_dynamics=None, control_input=None
    ))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        py: Python<'_>,
        epoch: &PyEpoch,
        state: PyReadonlyArray1<f64>,
        initial_covariance: PyReadonlyArray2<f64>,
        propagation_config: &PyNumericalPropagationConfig,
        force_config: &PyForceModelConfig,
        measurement_models: Vec<Py<PyAny>>,
        config: Option<&PyUKFConfig>,
        params: Option<PyReadonlyArray1<f64>>,
        additional_dynamics: Option<Py<PyAny>>,
        control_input: Option<Py<PyAny>>,
    ) -> PyResult<Self> {
        // Read through ndarray views rather than `as_slice`. `as_slice` rejects strided
        // inputs such as `P[:6, :6]`, and it also accepts Fortran-ordered data, which
        // `from_row_slice` would then silently transpose.
        let state_view = state.as_array();
        let state_vec =
            nalgebra::DVector::from_iterator(state_view.len(), state_view.iter().copied());
        let state_dim = state_vec.len();
        let cov_shape = initial_covariance.shape();
        if cov_shape[0] != state_dim || cov_shape[1] != state_dim {
            return Err(exceptions::PyValueError::new_err(format!(
                "initial_covariance must be {}x{}, got {}x{}",
                state_dim, state_dim, cov_shape[0], cov_shape[1]
            )));
        }
        let cov_view = initial_covariance.as_array();
        let cov_matrix =
            nalgebra::DMatrix::from_fn(state_dim, state_dim, |i, j| cov_view[[i, j]]);
        let prop_config = propagation_config.config.clone();
        // Previously `as_slice().unwrap()`, which panicked on non-contiguous input.
        let params_vec = params.map(|p| {
            let view = p.as_array();
            nalgebra::DVector::from_iterator(view.len(), view.iter().copied())
        });
        let additional_dynamics_fn: Option<crate::integrators::traits::DStateDynamics> =
            additional_dynamics.map(|dyn_py| {
                Box::new(
                    move |t: f64,
                          x: &nalgebra::DVector<f64>,
                          p: Option<&nalgebra::DVector<f64>>|
                          -> nalgebra::DVector<f64> {
                        Self::call_py_vector_fn(&dyn_py, "additional_dynamics", t, x, p)
                    },
                ) as crate::integrators::traits::DStateDynamics
            });
        let control_input_fn: crate::integrators::traits::DControlInput =
            control_input.map(|ctrl_py| {
                Box::new(
                    move |t: f64,
                          x: &nalgebra::DVector<f64>,
                          p: Option<&nalgebra::DVector<f64>>|
                          -> nalgebra::DVector<f64> {
                        Self::call_py_vector_fn(&ctrl_py, "control_input", t, x, p)
                    },
                )
                    as Box<
                        dyn Fn(
                                f64,
                                &nalgebra::DVector<f64>,
                                Option<&nalgebra::DVector<f64>>,
                            ) -> nalgebra::DVector<f64>
                            + Send
                            + Sync,
                    >
            });
        let prop = propagators::DNumericalOrbitPropagator::new(
            epoch.obj,
            state_vec,
            prop_config,
            force_config.config.clone(),
            params_vec,
            additional_dynamics_fn,
            control_input_fn,
            Some(cov_matrix),
        )
        .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        let dynamics = DynamicsSource::OrbitPropagator(prop);
        let models = process_measurement_models(py, measurement_models)?;
        // The filter state dimension always follows the supplied initial state.
        let ukf_config = config
            .map(|c| {
                let mut cfg = c.config.clone();
                cfg.state_dim = state_dim;
                cfg
            })
            .unwrap_or_else(|| estimation::UKFConfig {
                state_dim,
                ..estimation::UKFConfig::default()
            });
        let ukf = estimation::UnscentedKalmanFilter::from_propagator(dynamics, models, ukf_config)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(PyUnscentedKalmanFilter { ukf })
    }

    /// Process one observation and return the resulting filter record.
    ///
    /// Raises:
    ///     RuntimeError: If the filter update fails.
    fn process_observation(
        &mut self,
        observation: &PyObservation,
    ) -> PyResult<PyFilterRecord> {
        let record = self
            .ukf
            .process_observation(&observation.observation)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(PyFilterRecord { record })
    }

    /// Pass a list of observations to the filter's batch `process_observations`.
    ///
    /// Raises:
    ///     RuntimeError: If the filter update fails.
    fn process_observations(
        &mut self,
        observations: Vec<PyRef<PyObservation>>,
    ) -> PyResult<()> {
        let obs_vec: Vec<estimation::Observation> = observations
            .iter()
            .map(|o| o.observation.clone())
            .collect();
        self.ukf
            .process_observations(&obs_vec)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(())
    }

    /// Current state estimate as a 1-D numpy array.
    fn current_state<'py>(&self, py: Python<'py>) -> Bound<'py, numpy::PyArray<f64, numpy::Ix1>> {
        let state = self.ukf.current_state();
        let n = state.len();
        vector_to_numpy!(py, state, n, f64)
    }

    /// Current covariance as a 2-D numpy array.
    fn current_covariance<'py>(
        &self,
        py: Python<'py>,
    ) -> Bound<'py, numpy::PyArray<f64, numpy::Ix2>> {
        let cov = self.ukf.current_covariance();
        let r = cov.nrows();
        let c = cov.ncols();
        matrix_to_numpy!(py, cov, r, c, f64)
    }

    /// Epoch the current estimate refers to.
    fn current_epoch(&self) -> PyEpoch {
        PyEpoch {
            obj: self.ukf.current_epoch(),
        }
    }

    /// Copies of the filter records held by the UKF.
    fn records(&self) -> Vec<PyFilterRecord> {
        self.ukf
            .records()
            .iter()
            .map(|r| PyFilterRecord { record: r.clone() })
            .collect()
    }

    fn __repr__(&self) -> String {
        format!(
            "UnscentedKalmanFilter(epoch={}, state_dim={}, records={})",
            self.ukf.current_epoch(),
            self.ukf.current_state().len(),
            self.ukf.records().len(),
        )
    }
}

impl PyUnscentedKalmanFilter {
    /// Call a Python callback as `func(t, x, params)` and convert its return value into a
    /// `DVector`. `params` is passed as `None` when no parameter vector is available.
    ///
    /// `label` names the callback in panic messages.
    ///
    /// # Panics
    ///
    /// Panics if the callback raises, or if its result cannot be extracted as a 1-D
    /// float64 numpy array. The integrator callback signature has no way to return an
    /// error, so a panic is the only option.
    fn call_py_vector_fn(
        func: &Py<PyAny>,
        label: &str,
        t: f64,
        x: &nalgebra::DVector<f64>,
        p: Option<&nalgebra::DVector<f64>>,
    ) -> nalgebra::DVector<f64> {
        Python::attach(|py| {
            let x_np = x.as_slice().to_pyarray(py);
            let result = match p {
                Some(pv) => func.call1(py, (t, x_np, pv.as_slice().to_pyarray(py))),
                None => func.call1(py, (t, x_np, py.None())),
            };
            let res = result.unwrap_or_else(|e| panic!("Error calling {label}: {e}"));
            let res_arr: PyReadonlyArray1<f64> = res.extract(py).unwrap_or_else(|e| {
                panic!("{label} must return a 1-D float64 numpy array: {e:?}")
            });
            // Logical-order view: also handles non-contiguous arrays returned by the callback.
            let view = res_arr.as_array();
            nalgebra::DVector::from_iterator(view.len(), view.iter().copied())
        })
    }
}
/// Namespace (`brahe.BLSSolverMethod`) exposing, as class attributes, the integer codes
/// accepted by `BLSConfig(solver_method=...)`. The type itself carries no data.
#[pyclass(module = "brahe._brahe", skip_from_py_object)]
#[pyo3(name = "BLSSolverMethod")]
#[derive(Clone)]
pub struct PyBLSSolverMethod {}
#[pymethods]
impl PyBLSSolverMethod {
    /// Code mapped by `map_solver_method` to `BLSSolverMethod::NormalEquations`.
    #[classattr]
    const NORMAL_EQUATIONS: u8 = 0;

    /// Code mapped by `map_solver_method` to `BLSSolverMethod::StackedObservationMatrix`.
    #[classattr]
    const STACKED_OBSERVATION_MATRIX: u8 = 1;

    fn __repr__(&self) -> String {
        String::from("BLSSolverMethod")
    }
}
fn map_solver_method(value: u8) -> PyResult<estimation::BLSSolverMethod> {
match value {
0 => Ok(estimation::BLSSolverMethod::NormalEquations),
1 => Ok(estimation::BLSSolverMethod::StackedObservationMatrix),
_ => Err(exceptions::PyValueError::new_err(format!(
"Invalid solver_method: {}. Use BLSSolverMethod.NORMAL_EQUATIONS (0) or \
BLSSolverMethod.STACKED_OBSERVATION_MATRIX (1)",
value
))),
}
}
/// Python-facing consider-parameter settings (`brahe.ConsiderParameterConfig`) wrapping
/// `estimation::ConsiderParameterConfig`.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "ConsiderParameterConfig")]
#[derive(Clone)]
pub struct PyConsiderParameterConfig {
/// The wrapped Rust configuration.
pub(crate) config: estimation::ConsiderParameterConfig,
}
#[pymethods]
impl PyConsiderParameterConfig {
    /// Create a consider-parameter configuration.
    ///
    /// Args:
    ///     n_solve: Stored unchanged as `n_solve` in the Rust configuration.
    ///     consider_covariance: Square covariance matrix of the consider parameters.
    ///         Any numpy memory layout is accepted (C order, Fortran order, or a strided
    ///         slice such as `P[6:, 6:]`); elements are read by logical index.
    ///
    /// Raises:
    ///     ValueError: If `consider_covariance` is not square.
    #[new]
    fn new(n_solve: usize, consider_covariance: PyReadonlyArray2<f64>) -> PyResult<Self> {
        let shape = consider_covariance.shape();
        let (rows, cols) = (shape[0], shape[1]);
        if rows != cols {
            return Err(exceptions::PyValueError::new_err(
                "consider_covariance must be a square matrix",
            ));
        }
        // Index through an ndarray view. The previous `as_slice` + `from_row_slice` pair
        // rejected non-contiguous inputs and transposed Fortran-ordered ones, because
        // `as_slice` accepts F-contiguous data in memory order.
        let view = consider_covariance.as_array();
        let cov = DMatrix::from_fn(rows, cols, |i, j| view[[i, j]]);
        Ok(PyConsiderParameterConfig {
            config: estimation::ConsiderParameterConfig {
                n_solve,
                consider_covariance: cov,
            },
        })
    }

    /// Value of `n_solve` stored in the configuration.
    #[getter]
    fn n_solve(&self) -> usize {
        self.config.n_solve
    }

    /// Consider-parameter covariance as a 2-D numpy array.
    #[getter]
    fn consider_covariance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let r = self.config.consider_covariance.nrows();
        let c = self.config.consider_covariance.ncols();
        matrix_to_numpy!(py, self.config.consider_covariance, r, c, f64)
    }

    fn __repr__(&self) -> String {
        format!(
            "ConsiderParameterConfig(n_solve={}, consider_dim={})",
            self.config.n_solve,
            self.config.consider_covariance.nrows(),
        )
    }
}
/// Python-facing batch least-squares configuration (`brahe.BLSConfig`) wrapping
/// `estimation::BLSConfig`.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "BLSConfig")]
#[derive(Clone)]
pub struct PyBLSConfig {
/// The wrapped Rust configuration.
pub(crate) config: estimation::BLSConfig,
}
#[pymethods]
impl PyBLSConfig {
#[new]
#[pyo3(signature = (
solver_method=0,
max_iterations=10,
state_correction_threshold=Some(1e-8),
cost_convergence_threshold=None,
consider_params=None,
store_iteration_records=true,
store_observation_residuals=true
))]
#[allow(clippy::too_many_arguments)]
fn new(
solver_method: u8,
max_iterations: usize,
state_correction_threshold: Option<f64>,
cost_convergence_threshold: Option<f64>,
consider_params: Option<PyConsiderParameterConfig>,
store_iteration_records: bool,
store_observation_residuals: bool,
) -> PyResult<Self> {
let method = map_solver_method(solver_method)?;
Ok(PyBLSConfig {
config: estimation::BLSConfig {
solver_method: method,
max_iterations,
state_correction_threshold,
cost_convergence_threshold,
consider_params: consider_params.map(|cp| cp.config),
store_iteration_records,
store_observation_residuals,
},
})
}
#[staticmethod]
#[pyo3(name = "default")]
fn py_default() -> Self {
PyBLSConfig {
config: estimation::BLSConfig::default(),
}
}
#[getter]
fn solver_method(&self) -> u8 {
match self.config.solver_method {
estimation::BLSSolverMethod::NormalEquations => 0,
estimation::BLSSolverMethod::StackedObservationMatrix => 1,
}
}
#[getter]
fn max_iterations(&self) -> usize {
self.config.max_iterations
}
#[getter]
fn state_correction_threshold(&self) -> Option<f64> {
self.config.state_correction_threshold
}
#[getter]
fn cost_convergence_threshold(&self) -> Option<f64> {
self.config.cost_convergence_threshold
}
#[getter]
fn store_iteration_records(&self) -> bool {
self.config.store_iteration_records
}
#[getter]
fn store_observation_residuals(&self) -> bool {
self.config.store_observation_residuals
}
fn __repr__(&self) -> String {
let method_str = match self.config.solver_method {
estimation::BLSSolverMethod::NormalEquations => "NormalEquations",
estimation::BLSSolverMethod::StackedObservationMatrix => "StackedObservationMatrix",
};
format!(
"BLSConfig(solver={}, max_iter={}, state_thresh={:?}, cost_thresh={:?})",
method_str,
self.config.max_iterations,
self.config.state_correction_threshold,
self.config.cost_convergence_threshold,
)
}
}
/// Python-facing view (`brahe.BLSIterationRecord`) of one
/// `estimation::BLSIterationRecord` produced by the batch least-squares solver.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "BLSIterationRecord")]
#[derive(Clone)]
pub struct PyBLSIterationRecord {
/// The wrapped Rust record (a clone of the solver's stored record).
pub(crate) record: estimation::BLSIterationRecord,
}
#[pymethods]
impl PyBLSIterationRecord {
    /// Iteration number stored in the record.
    #[getter]
    fn iteration(&self) -> usize {
        self.record.iteration
    }

    /// Epoch stored in the record.
    #[getter]
    fn epoch(&self) -> PyEpoch {
        let obj = self.record.epoch;
        PyEpoch { obj }
    }

    /// State vector stored in the record, as a 1-D numpy array.
    #[getter]
    fn state<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.state.len();
        vector_to_numpy!(py, self.record.state, len, f64)
    }

    /// Covariance stored in the record, as a 2-D numpy array.
    #[getter]
    fn covariance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let (rows, cols) = self.record.covariance.shape();
        matrix_to_numpy!(py, self.record.covariance, rows, cols, f64)
    }

    /// State correction stored in the record, as a 1-D numpy array.
    #[getter]
    fn state_correction<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.state_correction.len();
        vector_to_numpy!(py, self.record.state_correction, len, f64)
    }

    /// Norm of the state correction, as stored in the record.
    #[getter]
    fn state_correction_norm(&self) -> f64 {
        self.record.state_correction_norm
    }

    /// Cost value stored in the record.
    #[getter]
    fn cost(&self) -> f64 {
        self.record.cost
    }

    /// RMS pre-fit residual stored in the record.
    #[getter]
    fn rms_prefit_residual(&self) -> f64 {
        self.record.rms_prefit_residual
    }

    /// RMS post-fit residual stored in the record.
    #[getter]
    fn rms_postfit_residual(&self) -> f64 {
        self.record.rms_postfit_residual
    }

    fn __repr__(&self) -> String {
        let rec = &self.record;
        format!(
            "BLSIterationRecord(iter={}, cost={:.6e}, dx_norm={:.6e}, rms_prefit={:.3e}, rms_postfit={:.3e})",
            rec.iteration,
            rec.cost,
            rec.state_correction_norm,
            rec.rms_prefit_residual,
            rec.rms_postfit_residual,
        )
    }
}
/// Python-facing view (`brahe.BLSObservationResidual`) of one
/// `estimation::BLSObservationResidual` produced by the batch least-squares solver.
#[pyclass(module = "brahe._brahe", from_py_object)]
#[pyo3(name = "BLSObservationResidual")]
#[derive(Clone)]
pub struct PyBLSObservationResidual {
/// The wrapped Rust residual entry (a clone of the solver's stored entry).
pub(crate) record: estimation::BLSObservationResidual,
}
#[pymethods]
impl PyBLSObservationResidual {
    /// Epoch of the observation this residual belongs to.
    #[getter]
    fn epoch(&self) -> PyEpoch {
        let obj = self.record.epoch;
        PyEpoch { obj }
    }

    /// Name of the measurement model recorded with this residual.
    #[getter]
    fn model_name(&self) -> &str {
        &self.record.model_name
    }

    /// Pre-fit residual vector as a 1-D numpy array.
    #[getter]
    fn prefit_residual<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.prefit_residual.len();
        vector_to_numpy!(py, self.record.prefit_residual, len, f64)
    }

    /// Post-fit residual vector as a 1-D numpy array.
    #[getter]
    fn postfit_residual<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let len = self.record.postfit_residual.len();
        vector_to_numpy!(py, self.record.postfit_residual, len, f64)
    }

    fn __repr__(&self) -> String {
        let epoch = &self.record.epoch;
        let model = &self.record.model_name;
        format!("BLSObservationResidual(epoch={epoch}, model={model})")
    }
}
/// Python-facing batch least-squares estimator (`brahe.BatchLeastSquares`) wrapping
/// `estimation::BatchLeastSquares`.
#[pyclass(module = "brahe._brahe")]
#[pyo3(name = "BatchLeastSquares")]
pub struct PyBatchLeastSquares {
/// The wrapped Rust estimator.
bls: estimation::BatchLeastSquares,
}
#[pymethods]
impl PyBatchLeastSquares {
    /// Create a batch least-squares estimator around a numerical orbit propagator.
    ///
    /// Args:
    ///     epoch: Epoch of the a-priori state.
    ///     initial_state: A-priori state vector; its length defines the state dimension.
    ///     initial_covariance: A-priori covariance; must be
    ///         `len(initial_state) x len(initial_state)`.
    ///     propagation_config: Numerical propagation configuration.
    ///     force_config: Force model configuration.
    ///     measurement_models: Measurement models, converted by `process_measurement_models`.
    ///     config: Optional `BLSConfig`; `BLSConfig.default()` values are used when omitted.
    ///     params: Optional parameter vector passed to the estimator.
    ///     additional_dynamics: Optional callable `f(t, x, params) -> ndarray`; `params`
    ///         is `None` when no parameter vector is supplied.
    ///     control_input: Optional callable with the same calling convention.
    ///
    /// Array arguments may use any numpy memory layout (C order, Fortran order, or
    /// strided views); elements are read by logical index.
    ///
    /// Raises:
    ///     ValueError: If `initial_covariance` does not match the state dimension.
    ///     RuntimeError: If the estimator cannot be constructed.
    #[new]
    #[pyo3(signature = (
        epoch, initial_state, initial_covariance,
        propagation_config, force_config,
        measurement_models,
        config=None, params=None, additional_dynamics=None, control_input=None
    ))]
    #[allow(clippy::too_many_arguments)]
    fn new(
        py: Python<'_>,
        epoch: &PyEpoch,
        initial_state: PyReadonlyArray1<f64>,
        initial_covariance: PyReadonlyArray2<f64>,
        propagation_config: &PyNumericalPropagationConfig,
        force_config: &PyForceModelConfig,
        measurement_models: Vec<Py<PyAny>>,
        config: Option<&PyBLSConfig>,
        params: Option<PyReadonlyArray1<f64>>,
        additional_dynamics: Option<Py<PyAny>>,
        control_input: Option<Py<PyAny>>,
    ) -> PyResult<Self> {
        // Read through ndarray views rather than `as_slice`. `as_slice` rejects strided
        // inputs, and it also accepts Fortran-ordered data, which `from_row_slice` would
        // then silently transpose.
        let state_view = initial_state.as_array();
        let state_vec = DVector::from_iterator(state_view.len(), state_view.iter().copied());
        let state_dim = state_vec.len();
        let cov_shape = initial_covariance.shape();
        if cov_shape[0] != state_dim || cov_shape[1] != state_dim {
            return Err(exceptions::PyValueError::new_err(format!(
                "initial_covariance must be {}x{}, got {}x{}",
                state_dim, state_dim, cov_shape[0], cov_shape[1]
            )));
        }
        let cov_view = initial_covariance.as_array();
        let cov_matrix = DMatrix::from_fn(state_dim, state_dim, |i, j| cov_view[[i, j]]);
        // Previously `as_slice().unwrap()`, which panicked on non-contiguous input.
        let params_vec = params.map(|p| {
            let view = p.as_array();
            DVector::from_iterator(view.len(), view.iter().copied())
        });
        let additional_dynamics_fn: Option<crate::integrators::traits::DStateDynamics> =
            additional_dynamics.map(|dyn_py| {
                Box::new(
                    move |t: f64, x: &DVector<f64>, p: Option<&DVector<f64>>| -> DVector<f64> {
                        Self::call_py_vector_fn(&dyn_py, "additional_dynamics", t, x, p)
                    },
                ) as crate::integrators::traits::DStateDynamics
            });
        let control_input_fn: crate::integrators::traits::DControlInput =
            control_input.map(|ctrl_py| {
                Box::new(
                    move |t: f64, x: &DVector<f64>, p: Option<&DVector<f64>>| -> DVector<f64> {
                        Self::call_py_vector_fn(&ctrl_py, "control_input", t, x, p)
                    },
                )
                    as Box<
                        dyn Fn(f64, &DVector<f64>, Option<&DVector<f64>>) -> DVector<f64>
                            + Send
                            + Sync,
                    >
            });
        let models = process_measurement_models(py, measurement_models)?;
        let bls_config = config
            .map(|c| c.config.clone())
            .unwrap_or_default();
        let bls = estimation::BatchLeastSquares::new(
            epoch.obj,
            state_vec,
            cov_matrix,
            propagation_config.config.clone(),
            force_config.config.clone(),
            params_vec,
            additional_dynamics_fn,
            control_input_fn,
            models,
            bls_config,
        )
        .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(PyBatchLeastSquares { bls })
    }

    /// Run the batch least-squares solve over the given observations.
    ///
    /// Raises:
    ///     RuntimeError: If the solver reports an error.
    fn solve(&mut self, observations: Vec<PyRef<PyObservation>>) -> PyResult<()> {
        let obs_vec: Vec<estimation::Observation> = observations
            .iter()
            .map(|o| o.observation.clone())
            .collect();
        self.bls
            .solve(&obs_vec)
            .map_err(|e| exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(())
    }

    /// Current state estimate as a 1-D numpy array.
    fn current_state<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray<f64, Ix1>> {
        let state = self.bls.current_state();
        let n = state.len();
        vector_to_numpy!(py, state, n, f64)
    }

    /// Covariance reported by the estimator's `total_covariance()`, as a 2-D numpy array.
    fn current_covariance<'py>(
        &self,
        py: Python<'py>,
    ) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let cov = self.bls.total_covariance();
        let r = cov.nrows();
        let c = cov.ncols();
        let cov_ref = &cov;
        matrix_to_numpy!(py, cov_ref, r, c, f64)
    }

    /// Epoch the current estimate refers to.
    fn current_epoch(&self) -> PyEpoch {
        PyEpoch {
            obj: self.bls.current_epoch(),
        }
    }

    /// Whether the solver reports convergence.
    fn converged(&self) -> bool {
        self.bls.converged()
    }

    /// Number of iterations the solver reports as completed.
    fn iterations_completed(&self) -> usize {
        self.bls.iterations_completed()
    }

    /// Final cost value reported by the solver.
    fn final_cost(&self) -> f64 {
        self.bls.final_cost()
    }

    /// Formal covariance as a 2-D numpy array.
    fn formal_covariance<'py>(
        &self,
        py: Python<'py>,
    ) -> Bound<'py, PyArray<f64, numpy::Ix2>> {
        let cov = self.bls.formal_covariance();
        let r = cov.nrows();
        let c = cov.ncols();
        matrix_to_numpy!(py, cov, r, c, f64)
    }

    /// Consider covariance as a 2-D numpy array, or `None` when the estimator reports none.
    fn consider_covariance<'py>(
        &self,
        py: Python<'py>,
    ) -> Option<Bound<'py, PyArray<f64, numpy::Ix2>>> {
        self.bls.consider_covariance().map(|cov| {
            let r = cov.nrows();
            let c = cov.ncols();
            let cov_ref = &cov;
            matrix_to_numpy!(py, cov_ref, r, c, f64)
        })
    }

    /// Copies of the stored iteration records.
    fn iteration_records(&self) -> Vec<PyBLSIterationRecord> {
        self.bls
            .iteration_records()
            .iter()
            .map(|r| PyBLSIterationRecord { record: r.clone() })
            .collect()
    }

    /// Copies of the stored observation residuals, as a nested list. The outer list
    /// mirrors the estimator's `observation_residuals()` (presumably one entry per
    /// iteration; confirm against the Rust estimator).
    fn observation_residuals(&self) -> Vec<Vec<PyBLSObservationResidual>> {
        self.bls
            .observation_residuals()
            .iter()
            .map(|iter_residuals| {
                iter_residuals
                    .iter()
                    .map(|r| PyBLSObservationResidual { record: r.clone() })
                    .collect()
            })
            .collect()
    }

    fn __repr__(&self) -> String {
        format!(
            "BatchLeastSquares(epoch={}, state_dim={}, converged={}, iterations={})",
            self.bls.current_epoch(),
            self.bls.current_state().len(),
            self.bls.converged(),
            self.bls.iterations_completed(),
        )
    }
}

impl PyBatchLeastSquares {
    /// Call a Python callback as `func(t, x, params)` and convert its return value into a
    /// `DVector`. `params` is passed as `None` when no parameter vector is available.
    ///
    /// `label` names the callback in panic messages.
    ///
    /// # Panics
    ///
    /// Panics if the callback raises, or if its result cannot be extracted as a 1-D
    /// float64 numpy array. The integrator callback signature has no way to return an
    /// error, so a panic is the only option.
    fn call_py_vector_fn(
        func: &Py<PyAny>,
        label: &str,
        t: f64,
        x: &DVector<f64>,
        p: Option<&DVector<f64>>,
    ) -> DVector<f64> {
        Python::attach(|py| {
            let x_np = x.as_slice().to_pyarray(py);
            let result = match p {
                Some(pv) => func.call1(py, (t, x_np, pv.as_slice().to_pyarray(py))),
                None => func.call1(py, (t, x_np, py.None())),
            };
            let res = result.unwrap_or_else(|e| panic!("Error calling {label}: {e}"));
            let res_arr: PyReadonlyArray1<f64> = res.extract(py).unwrap_or_else(|e| {
                panic!("{label} must return a 1-D float64 numpy array: {e:?}")
            });
            // Logical-order view: also handles non-contiguous arrays returned by the callback.
            let view = res_arr.as_array();
            DVector::from_iterator(view.len(), view.iter().copied())
        })
    }
}