use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use super::pyastrotime::ToTimeVec;
use super::pytle::PyTLE;
use crate::sgp4 as psgp4;
use numpy::PyArray1;
use numpy::PyArrayMethods;
/// Status codes reported by the SGP4 propagator, exposed to Python as `sgp4_error`.
///
/// Each variant's integer value is the discriminant of the matching
/// `psgp4::SGP4Error` variant. The integer error codes produced by `sgp4(..., errflag=True)`
/// are those same discriminants, and `eq_int` lets Python compare them against these members.
// Variant names are lower-case on purpose: they are the Python-visible member names.
#[allow(non_camel_case_types)]
#[pyclass(name = "sgp4_error", eq, eq_int)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PySGP4Error {
    success = psgp4::SGP4Error::SGP4Success as isize,
    eccen = psgp4::SGP4Error::SGP4ErrorEccen as isize,
    mean_motion = psgp4::SGP4Error::SGP4ErrorMeanMotion as isize,
    perturb_eccen = psgp4::SGP4Error::SGP4ErrorPerturbEccen as isize,
    semi_latus_rectum = psgp4::SGP4Error::SGP4ErrorSemiLatusRectum as isize,
    unused = psgp4::SGP4Error::SGP4ErrorUnused as isize,
    orbit_decay = psgp4::SGP4Error::SGP4ErrorOrbitDecay as isize,
}
/// Gravity-constant set selectable from Python as `sgp4_gravconst`.
///
/// Integer values mirror the discriminants of `psgp4::GravConst`; values of this type are
/// converted to `psgp4::GravConst` through a `From` impl before being passed to the propagator.
// Variant names are lower-case on purpose: they are the Python-visible member names.
#[allow(non_camel_case_types)]
#[pyclass(name = "sgp4_gravconst", eq, eq_int)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GravConst {
    wgs72 = psgp4::GravConst::WGS72 as isize,
    wgs72old = psgp4::GravConst::WGS72OLD as isize,
    wgs84 = psgp4::GravConst::WGS84 as isize,
}
impl From<GravConst> for psgp4::GravConst {
    /// Translate the Python-facing gravity-constant selector into the propagator's enum.
    ///
    /// The match is exhaustive, so a new `GravConst` variant will not compile until mapped.
    fn from(choice: GravConst) -> Self {
        match choice {
            GravConst::wgs72 => Self::WGS72,
            GravConst::wgs72old => Self::WGS72OLD,
            GravConst::wgs84 => Self::WGS84,
        }
    }
}
/// SGP4 operation mode selectable from Python as `sgp4_opsmode`.
///
/// Integer values mirror the discriminants of `psgp4::OpsMode`; values of this type are
/// converted to `psgp4::OpsMode` through a `From` impl before being passed to the propagator.
// Variant names are lower-case on purpose: they are the Python-visible member names.
#[allow(non_camel_case_types)]
#[pyclass(name = "sgp4_opsmode", eq, eq_int)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OpsMode {
    afspc = psgp4::OpsMode::AFSPC as isize,
    improved = psgp4::OpsMode::IMPROVED as isize,
}
impl From<OpsMode> for psgp4::OpsMode {
    /// Translate the Python-facing operation-mode selector into the propagator's enum.
    ///
    /// The match is exhaustive, so a new `OpsMode` variant will not compile until mapped.
    fn from(mode: OpsMode) -> Self {
        match mode {
            OpsMode::afspc => Self::AFSPC,
            OpsMode::improved => Self::IMPROVED,
        }
    }
}
impl From<psgp4::SGP4Error> for PySGP4Error {
    /// Map a propagator status code onto its Python-facing counterpart.
    ///
    /// The mapping is one-to-one and exhaustive, so adding a variant to
    /// `psgp4::SGP4Error` fails to compile here until it is handled.
    fn from(status: psgp4::SGP4Error) -> Self {
        match status {
            psgp4::SGP4Error::SGP4Success => Self::success,
            psgp4::SGP4Error::SGP4ErrorEccen => Self::eccen,
            psgp4::SGP4Error::SGP4ErrorMeanMotion => Self::mean_motion,
            psgp4::SGP4Error::SGP4ErrorPerturbEccen => Self::perturb_eccen,
            psgp4::SGP4Error::SGP4ErrorSemiLatusRectum => Self::semi_latus_rectum,
            psgp4::SGP4Error::SGP4ErrorUnused => Self::unused,
            psgp4::SGP4Error::SGP4ErrorOrbitDecay => Self::orbit_decay,
        }
    }
}
#[pyfunction]
#[pyo3(signature=(tle, time, **kwds))]
pub fn sgp4(
tle: &Bound<'_, PyAny>,
time: &Bound<'_, PyAny>,
kwds: Option<&Bound<'_, PyDict>>,
) -> PyResult<PyObject> {
let mut output_err = false;
let mut opsmode: OpsMode = OpsMode::afspc;
let mut gravconst: GravConst = GravConst::wgs72;
if kwds.is_some() {
let kw = kwds.unwrap();
match kw.get_item("errflag").unwrap() {
Some(v) => output_err = v.extract::<bool>()?,
None => {}
}
match kw.get_item("opsmode").unwrap() {
Some(v) => opsmode = v.extract::<OpsMode>()?,
None => {}
}
match kw.get_item("gravconst").unwrap() {
Some(v) => gravconst = v.extract::<GravConst>()?,
None => {}
}
}
if tle.is_instance_of::<PyTLE>() {
let mut stle: PyRefMut<PyTLE> = tle.extract()?;
let (r, v, e) = psgp4::sgp4_full(
&mut stle.inner,
time.to_time_vec()?.as_slice(),
gravconst.into(),
opsmode.into(),
);
pyo3::Python::with_gil(|py| -> PyResult<PyObject> {
let mut dims = vec![r.len()];
if r.nrows() > 1 && r.ncols() > 1 {
dims = vec![r.ncols(), r.nrows()];
}
if output_err == false {
Ok((
PyArray1::from_slice_bound(py, r.data.as_slice())
.reshape(dims.clone())?
.to_object(py),
PyArray1::from_slice_bound(py, v.data.as_slice())
.reshape(dims)?
.to_object(py),
)
.to_object(py))
} else {
let eint: Vec<i32> = e.into_iter().map(|x| x as i32).collect();
Ok((
PyArray1::from_slice_bound(py, r.data.as_slice()).reshape(dims.clone())?,
PyArray1::from_slice_bound(py, v.data.as_slice()).reshape(dims.clone())?,
PyArray1::from_slice_bound(py, eint.as_slice()),
)
.to_object(py))
}
})
} else if tle.is_instance_of::<PyList>() {
let mut tles = tle.extract::<Vec<PyRefMut<PyTLE>>>()?;
let tmarray = time.to_time_vec()?;
let results: Vec<psgp4::SGP4State> = tles
.iter_mut()
.map(|tle| psgp4::sgp4(&mut tle.inner, tmarray.as_slice()))
.collect();
pyo3::Python::with_gil(|py| -> PyResult<PyObject> {
let n = tles.len() * tmarray.len() * 3;
let parr = PyArray1::zeros_bound(py, [n], false);
let varr = PyArray1::zeros_bound(py, [n], false);
let ntimes = tmarray.len();
let mut eint = Vec::new();
eint.resize(ntimes * tle.len()?, 0);
results.iter().enumerate().for_each(|(idx, (p, v, e))| {
unsafe {
let pdata: *mut f64 = parr.data();
std::ptr::copy_nonoverlapping(
p.as_ptr(),
pdata.add(idx * ntimes * 3),
ntimes * 3,
);
let vdata: *mut f64 = varr.data();
std::ptr::copy_nonoverlapping(
v.as_ptr(),
vdata.add(idx * ntimes * 3),
ntimes * 3,
);
eint[idx] = e[idx].clone() as i32;
}
});
let dims = match (tles.len() > 1, ntimes > 1) {
(true, true) => vec![tles.len(), ntimes, 3],
(true, false) => vec![tles.len(), 3],
(false, true) => vec![ntimes, 3],
(false, false) => vec![3],
};
let edims = match (tles.len() > 1, ntimes > 1) {
(true, true) => vec![tles.len(), ntimes],
(true, false) => vec![tles.len()],
(false, true) => vec![ntimes],
(false, false) => vec![1],
};
if output_err == false {
Ok((
parr.reshape(dims.clone()).unwrap(),
varr.reshape(dims).unwrap(),
)
.to_object(py))
} else {
Ok((
parr.reshape(dims.clone()).unwrap(),
varr.reshape(dims).unwrap(),
PyArray1::from_slice_bound(py, eint.as_slice()).reshape(edims)?,
)
.to_object(py))
}
})
} else {
Err(pyo3::exceptions::PyRuntimeError::new_err(
"Invalid input type for argument 1",
))
}
}