use super::traj_it::TrajIterator;
use super::{ExportCfg, INTERPOLATION_SAMPLES, InterpolationSnafu};
use super::{Interpolatable, TrajError};
use crate::errors::{NyxError, StateError};
use crate::io::InputOutputError;
use crate::io::watermark::pq_writer;
use crate::linalg::DefaultAllocator;
use crate::linalg::allocator::Allocator;
use crate::md::prelude::{GuidanceMode, StateParameter};
use crate::md::trajectory::smooth_state_diff_in_place;
use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
use anise::analysis::AnalysisError;
use anise::analysis::specs::StateSpecTrait;
use anise::astro::orbit::Orbit;
use anise::errors::PhysicsError;
use anise::prelude::{Aberration, Almanac};
use arrow::array::{Array, Float64Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use hifitime::TimeScale;
use log::{info, warn};
use parquet::arrow::ArrowWriter;
use snafu::ResultExt;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::fs::File;
use std::iter::Iterator;
use std::ops;
use std::ops::Bound::{Excluded, Included, Unbounded};
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// A trajectory: a sequence of states of type `S`, optionally named.
///
/// Lookups such as `Traj::at` binary-search `states` by epoch, so the states are expected to be
/// sorted chronologically (which `Traj::finalize` takes care of).
#[derive(Clone, PartialEq)]
pub struct Traj<S: Interpolatable>
where
DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
/// Optional name of this trajectory, shown by the `Display` implementation ("Trajectory of <name> ...").
pub name: Option<String>,
/// The states of this trajectory; expected to be sorted by epoch (see `Traj::finalize`).
pub states: Vec<S>,
}
impl<S: Interpolatable> Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Creates an empty, unnamed trajectory.
    pub fn new() -> Self {
        Self {
            name: None,
            states: Vec::new(),
        }
    }

    /// Sorts the states chronologically and drops every state whose epoch duplicates another's.
    ///
    /// The sort is stable and `dedup_by` keeps the first element of each run, so among states
    /// sharing an epoch, the one appearing first in `states` is kept.
    pub fn finalize(&mut self) {
        // Sort *before* deduplicating: `dedup_by` only collapses consecutive elements, so
        // duplicate epochs must first be made adjacent.
        self.states.sort_by_key(|a| a.epoch());
        self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
    }

    /// Returns the state at `epoch`.
    ///
    /// If a stored state has exactly this epoch, it is returned as-is. Otherwise the state is
    /// interpolated from a window of neighboring states (at most `INTERPOLATION_SAMPLES` of them)
    /// around the position where `epoch` would be inserted.
    ///
    /// The lookup is a binary search, so the states must be sorted by epoch (see
    /// [`Traj::finalize`]).
    ///
    /// # Errors
    /// Returns [`TrajError::NoInterpolationData`] if the trajectory is empty or `epoch` lies
    /// outside of `[start_epoch, end_epoch]`. Returns an interpolation error (`InterpolationSnafu`
    /// context) if interpolating fails.
    pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
        if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
            return Err(TrajError::NoInterpolationData { epoch });
        }
        match self
            .states
            .binary_search_by(|state| state.epoch().cmp(&epoch))
        {
            Ok(idx) => Ok(self.states[idx]),
            Err(idx) => {
                // `idx` is where `epoch` would be inserted: either end means we are out of
                // bounds, which the range check above should already have caught.
                if idx == 0 || idx >= self.states.len() {
                    return Err(TrajError::NoInterpolationData { epoch });
                }
                // Center the sample window on the insertion point...
                let num_left = INTERPOLATION_SAMPLES / 2;
                let mut first_idx = idx.saturating_sub(num_left);
                let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
                // ...and when it is clipped by the end of the trajectory, slide it back so it
                // still holds (up to) `2 * num_left` samples.
                if last_idx == self.states.len() {
                    first_idx = last_idx.saturating_sub(2 * num_left);
                }
                let window = self.states[first_idx..last_idx].to_vec();
                self.states[idx]
                    .interpolate(epoch, &window)
                    .context(InterpolationSnafu)
            }
        }
    }

    /// Returns the first state of the trajectory.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn first(&self) -> &S {
        self.states
            .first()
            .expect("trajectory is empty: it has no first state")
    }

    /// Returns the last state of the trajectory.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn last(&self) -> &S {
        self.states
            .last()
            .expect("trajectory is empty: it has no last state")
    }

    /// Returns the epoch of the first state.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn start_epoch(&self) -> Epoch {
        self.first().epoch()
    }

    /// Returns the epoch of the last state.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn end_epoch(&self) -> Epoch {
        self.last().epoch()
    }

    /// Iterates over the whole trajectory, sampled every `step` from the first to the last state.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
        self.every_between(step, self.first().epoch(), self.last().epoch())
    }

    /// Iterates over the trajectory, sampled every `step` over an inclusive time series from
    /// `start` to `end`. Both bounds are first clamped to the trajectory's own time span.
    ///
    /// # Panics
    /// Panics if the trajectory is empty.
    pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
        TrajIterator {
            time_series: TimeSeries::inclusive(
                start.max(self.first().epoch()),
                end.min(self.last().epoch()),
                step,
            ),
            traj: self,
        }
    }

    /// Keeps only the states whose epoch lies within `bound`, preserving their order.
    pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
        self.states.retain(|s| bound.contains(&s.epoch()));
        self
    }

    /// Keeps only the states whose offset from the first state's epoch lies within `bound`.
    ///
    /// Inclusive and exclusive bounds are both honored. For example, `..1.hours()` drops a state
    /// exactly one hour after the first state, whereas `..=1.hours()` keeps it.
    pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
        if self.states.is_empty() {
            return self;
        }
        let first_epoch = self.first().epoch();
        // Translate each duration bound into an epoch bound, keeping its inclusivity
        // (previously exclusive bounds were silently turned into inclusive ones).
        let to_epoch_bound = |b: ops::Bound<&Duration>| match b {
            Unbounded => Unbounded,
            Included(offset) => Included(first_epoch + *offset),
            Excluded(offset) => Excluded(first_epoch + *offset),
        };
        let start = to_epoch_bound(bound.start_bound());
        let end = to_epoch_bound(bound.end_bound());
        self.filter_by_epoch((start, end))
    }

    /// Exports this trajectory to parquet with the default [`ExportCfg`].
    /// Returns the path actually written to.
    pub fn to_parquet_simple<P: AsRef<Path>>(&self, path: P) -> Result<PathBuf, Box<dyn Error>> {
        self.to_parquet(path, ExportCfg::default())
    }

    /// Exports this trajectory to parquet with the provided [`ExportCfg`].
    /// Returns the path actually written to.
    pub fn to_parquet_with_cfg<P: AsRef<Path>>(
        &self,
        path: P,
        cfg: ExportCfg,
    ) -> Result<PathBuf, Box<dyn Error>> {
        self.to_parquet(path, cfg)
    }

    /// Exports this trajectory to parquet, resampled every `step` over its whole time span.
    pub fn to_parquet_with_step<P: AsRef<Path>>(
        &self,
        path: P,
        step: Duration,
    ) -> Result<(), Box<dyn Error>> {
        self.to_parquet_with_cfg(
            path,
            ExportCfg {
                step: Some(step),
                ..Default::default()
            },
        )?;
        Ok(())
    }

    /// Exports this trajectory to a parquet file and returns the path actually written to.
    ///
    /// If any of `cfg.start_epoch`, `cfg.end_epoch` or `cfg.step` is set, the trajectory is
    /// resampled over that window. Missing values default to the trajectory bounds and a
    /// one-minute step. Otherwise the stored states are exported as-is.
    ///
    /// The columns are the UTC epoch followed by each requested parameter (by default
    /// `S::export_params()`) that at least one exported state can compute. A column is nullable
    /// when some state cannot compute it.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the trajectory is empty;
    /// - no state falls within the export window;
    /// - the frame cannot be serialized;
    /// - creating or writing the file fails.
    pub fn to_parquet<P: AsRef<Path>>(
        &self,
        path: P,
        cfg: ExportCfg,
    ) -> Result<PathBuf, Box<dyn Error>> {
        // The frame metadata below is read from the first state, so guard up front instead of
        // panicking on an empty trajectory.
        if self.states.is_empty() {
            return Err(TrajError::CreationError {
                msg: "Cannot export an empty trajectory".to_string(),
            }
            .into());
        }

        let tick = Epoch::now().unwrap();
        info!("Exporting trajectory to parquet file...");
        let path_buf = cfg.actual_path(path);

        // Resample only when the configuration requests a specific window or step.
        let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
            let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
            let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
            let step = cfg.step.unwrap_or_else(|| 1.minutes());
            self.every_between(step, start, end).collect::<Vec<S>>()
        } else {
            self.states.to_vec()
        };
        if states.is_empty() {
            return Err(TrajError::CreationError {
                msg: "No trajectory state falls within the requested export window".to_string(),
            }
            .into());
        }

        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
        let frame = self.first().frame();
        let more_meta = Some(vec![(
            "Frame".to_string(),
            serde_dhall::serialize(&frame)
                .static_type_annotation()
                .to_string()
                .map_err(|e| {
                    Box::new(InputOutputError::SerializeDhall {
                        what: format!("frame `{frame}`"),
                        err: e.to_string(),
                    })
                })?,
        )]);

        let requested_fields = match cfg.fields {
            Some(fields) => fields,
            None => S::export_params(),
        };

        // Keep each field that at least one state can compute, together with whether some
        // state fails to compute it (i.e. whether the column must be nullable).
        let mut fields = Vec::new();
        for field in requested_fields {
            let computable = states.iter().filter(|s| s.value(field).is_ok()).count();
            if computable > 0 {
                fields.push((field, computable < states.len()));
            }
        }
        for (field, nullable) in &fields {
            hdrs.push(field.to_field(more_meta.clone()).with_nullable(*nullable));
        }

        let schema = Arc::new(Schema::new(hdrs));
        let mut record: Vec<Arc<dyn Array>> = Vec::new();

        let mut utc_epoch = StringBuilder::new();
        for s in &states {
            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
        }
        record.push(Arc::new(utc_epoch.finish()));

        for (field, _) in fields {
            if field == StateParameter::GuidanceMode() {
                // The guidance mode value is numeric; export the mode's name instead.
                let mut guid_mode = StringBuilder::new();
                for s in &states {
                    match s.value(field) {
                        Ok(value) => {
                            guid_mode.append_value(format!("{:?}", GuidanceMode::from(value)));
                        }
                        Err(_) => guid_mode.append_null(),
                    }
                }
                record.push(Arc::new(guid_mode.finish()));
            } else {
                let mut data = Float64Builder::new();
                for s in &states {
                    match s.value(field) {
                        Ok(value) => data.append_value(value),
                        Err(_) => data.append_null(),
                    };
                }
                record.push(Arc::new(data.finish()));
            }
        }

        info!(
            "Serialized {} states from {} to {}",
            states.len(),
            states[0].epoch(),
            states[states.len() - 1].epoch()
        );

        let mut metadata = HashMap::new();
        metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
        if let Some(add_meta) = cfg.metadata {
            for (k, v) in add_meta {
                metadata.insert(k, v);
            }
        }

        let props = pq_writer(Some(metadata));
        let file = File::create(&path_buf)?;
        // Propagate writer-creation failures instead of panicking.
        let mut writer = ArrowWriter::try_new(file, schema.clone(), props)?;
        let batch = RecordBatch::try_new(schema, record)?;
        writer.write(&batch)?;
        writer.close()?;

        let tock_time = Epoch::now().unwrap() - tick;
        info!(
            "Trajectory written to {} in {tock_time}",
            path_buf.display()
        );
        Ok(path_buf)
    }

    /// Builds a new, unnamed trajectory by sampling this one every `step` over its whole span.
    ///
    /// # Errors
    /// Returns [`NyxError::Trajectory`] if this trajectory is empty.
    pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
        if self.states.is_empty() {
            return Err(NyxError::Trajectory {
                source: TrajError::CreationError {
                    msg: "No trajectory to convert".to_string(),
                },
            });
        }
        let mut traj = Self::new();
        traj.states = self.every(step).collect();
        traj.finalize();
        Ok(traj)
    }

    /// Builds a new, unnamed trajectory by evaluating this one (see [`Traj::at`]) at each of
    /// `epochs`. The result is finalized, i.e. sorted and free of duplicate epochs.
    ///
    /// # Errors
    /// Returns [`NyxError::Trajectory`] if this trajectory is empty, or the error of the first
    /// epoch that cannot be evaluated.
    pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
        if self.states.is_empty() {
            return Err(NyxError::Trajectory {
                source: TrajError::CreationError {
                    msg: "No trajectory to convert".to_string(),
                },
            });
        }
        let mut traj = Self::new();
        traj.states = epochs
            .iter()
            .map(|&epoch| self.at(epoch))
            .collect::<Result<Vec<S>, TrajError>>()?;
        traj.finalize();
        Ok(traj)
    }

    /// Exports the difference `self - other` to a parquet file and returns the path written to.
    ///
    /// Both trajectories are sampled every `cfg.step` (default: one minute) over their common
    /// time span. If `cfg.start_epoch`/`cfg.end_epoch` are set, they narrow that span further.
    ///
    /// The columns are:
    /// - the UTC epoch;
    /// - the six RIC position/velocity differences, smoothed with `smooth_state_diff_in_place`;
    /// - the difference of each requested parameter that the first state of `self` can compute.
    ///   The guidance mode is always excluded.
    ///
    /// # Errors
    /// Returns a [`TrajError`] if:
    /// - either trajectory is empty;
    /// - the trajectories do not overlap within the requested window;
    /// - a RIC difference or parameter cannot be computed;
    /// - writing the file fails.
    pub fn ric_diff_to_parquet<P: AsRef<Path>>(
        &self,
        other: &Self,
        path: P,
        cfg: ExportCfg,
    ) -> Result<PathBuf, TrajError> {
        // Wraps any debuggable error (I/O, Arrow, Parquet) into a generic trajectory error.
        fn generic_err<E: fmt::Debug>(err: E) -> TrajError {
            TrajError::TrajGeneric {
                err: format!("{err:?}"),
            }
        }

        if self.states.is_empty() || other.states.is_empty() {
            return Err(TrajError::TrajGeneric {
                err: "Cannot compute the RIC difference with an empty trajectory".to_string(),
            });
        }

        let tick = Epoch::now().unwrap();
        info!("Exporting trajectory to parquet file...");
        let path_buf = cfg.actual_path(path);

        // Only the span covered by both trajectories can be compared. A start/end epoch set in
        // `cfg` narrows that span instead of being silently overwritten.
        let mut start = self.first().epoch().max(other.first().epoch());
        let mut end = self.last().epoch().min(other.last().epoch());
        if let Some(cfg_start) = cfg.start_epoch {
            start = start.max(cfg_start);
        }
        if let Some(cfg_end) = cfg.end_epoch {
            end = end.min(cfg_end);
        }
        if start > end {
            return Err(TrajError::TrajGeneric {
                err: format!("No overlapping time span to compare (from {start} to {end})"),
            });
        }

        let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
        for coord in ["X", "Y", "Z"] {
            let mut meta = HashMap::new();
            meta.insert("unit".to_string(), "km".to_string());
            let field = Field::new(
                format!("Delta {coord} (RIC) (km)"),
                DataType::Float64,
                false,
            )
            .with_metadata(meta);
            hdrs.push(field);
        }
        for coord in ["x", "y", "z"] {
            let mut meta = HashMap::new();
            meta.insert("unit".to_string(), "km/s".to_string());
            let field = Field::new(
                format!("Delta V{coord} (RIC) (km/s)"),
                DataType::Float64,
                false,
            )
            .with_metadata(meta);
            hdrs.push(field);
        }

        let frame = self.first().frame();
        let more_meta = Some(vec![(
            "Frame".to_string(),
            serde_dhall::serialize(&frame)
                .static_type_annotation()
                .to_string()
                // Only format the fallback when serialization actually failed.
                .unwrap_or_else(|_| frame.to_string()),
        )]);

        let mut fields = match cfg.fields {
            Some(fields) => fields,
            None => S::export_params(),
        };
        // The guidance mode is categorical, so a difference of it is meaningless.
        fields.retain(|param| {
            param != &StateParameter::GuidanceMode() && self.first().value(*param).is_ok()
        });
        for field in &fields {
            hdrs.push(field.to_field(more_meta.clone()));
        }

        let schema = Arc::new(Schema::new(hdrs));
        let mut record: Vec<Arc<dyn Array>> = Vec::new();

        let step = cfg.step.unwrap_or_else(|| 1.minutes());
        let self_states = self.every_between(step, start, end).collect::<Vec<S>>();
        let other_states = other.every_between(step, start, end).collect::<Vec<S>>();

        let mut ric_diff = Vec::with_capacity(other_states.len());
        for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
            let self_orbit = self_state.orbit();
            let other_orbit = other_state.orbit();
            let this_ric_diff = self_orbit
                .ric_difference(&other_orbit)
                .map_err(|source: PhysicsError| TrajError::TrajPhysics { source })?;
            ric_diff.push(this_ric_diff);
        }
        smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });

        let mut utc_epoch = StringBuilder::new();
        for s in &self_states {
            utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
        }
        record.push(Arc::new(utc_epoch.finish()));

        // Convert each difference once rather than once per exported coordinate.
        let ric_vectors: Vec<_> = ric_diff
            .iter()
            .map(|diff| diff.to_cartesian_pos_vel())
            .collect();
        for coord_no in 0..6 {
            let mut data = Float64Builder::new();
            for ric_vector in &ric_vectors {
                data.append_value(ric_vector[coord_no]);
            }
            record.push(Arc::new(data.finish()));
        }

        for field in fields {
            let mut data = Float64Builder::new();
            for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
                let self_val =
                    self_state
                        .value(field)
                        .map_err(|err: StateError| TrajError::TrajGeneric {
                            err: err.to_string(),
                        })?;
                let other_val =
                    other_state
                        .value(field)
                        .map_err(|err: StateError| TrajError::TrajGeneric {
                            err: err.to_string(),
                        })?;
                data.append_value(self_val - other_val);
            }
            record.push(Arc::new(data.finish()));
        }

        info!("Serialized {} states differences", self_states.len());

        let mut metadata = HashMap::new();
        metadata.insert(
            "Purpose".to_string(),
            "Trajectory difference data".to_string(),
        );
        if let Some(add_meta) = cfg.metadata {
            for (k, v) in add_meta {
                metadata.insert(k, v);
            }
        }

        let props = pq_writer(Some(metadata));
        let file = File::create(&path_buf).map_err(generic_err)?;
        // Propagate writer-creation failures instead of panicking.
        let mut writer = ArrowWriter::try_new(file, schema.clone(), props).map_err(generic_err)?;
        let batch = RecordBatch::try_new(schema, record).map_err(generic_err)?;
        writer.write(&batch).map_err(generic_err)?;
        writer.close().map_err(generic_err)?;

        let tock_time = Epoch::now().unwrap() - tick;
        info!(
            "Trajectory written to {} in {tock_time}",
            path_buf.display()
        );
        Ok(path_buf)
    }
}
impl<S: Interpolatable> ops::Add for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    type Output = Result<Traj<S>, NyxError>;

    /// Concatenates two owned trajectories by forwarding to the by-reference implementation.
    fn add(self, other: Traj<S>) -> Self::Output {
        <&Traj<S> as ops::Add<&Traj<S>>>::add(&self, &other)
    }
}
impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
where
DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
type Output = Result<Traj<S>, NyxError>;
fn add(self, other: &Traj<S>) -> Self::Output {
if self.first().frame() != other.first().frame() {
Err(NyxError::Trajectory {
source: TrajError::CreationError {
msg: format!(
"Frame mismatch in add operation: {} != {}",
self.first().frame(),
other.first().frame()
),
},
})
} else {
if self.last().epoch() < other.first().epoch() {
let gap = other.first().epoch() - self.last().epoch();
warn!(
"Resulting merged trajectory will have a time-gap of {} starting at {}",
gap,
self.last().epoch()
);
}
let mut me = self.clone();
for state in &other
.states
.iter()
.copied()
.filter(|s| s.epoch() > self.last().epoch())
.collect::<Vec<S>>()
{
me.states.push(*state);
}
me.finalize();
Ok(me)
}
}
}
impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Appends `rhs` to this trajectory in place, with the same semantics as `&Traj + &Traj`.
    ///
    /// # Panics
    /// Panics if the addition fails (a frame mismatch), since `AddAssign` cannot return an error.
    fn add_assign(&mut self, rhs: &Self) {
        // Add by reference: the original cloned both operands before the addition cloned
        // `self` yet again.
        *self = (&*self + rhs).expect("cannot add trajectories defined in different frames");
    }
}
impl<S: Interpolatable> fmt::Display for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Writes a one-line summary.
    ///
    /// The summary gives the name (if any), the frame, the first and last epochs, the covered
    /// duration and the number of states. An empty trajectory is written as
    /// "Empty Trajectory!".
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match (self.states.first(), self.states.last()) {
            (Some(first), Some(last)) => {
                let span = last.epoch() - first.epoch();
                let label = self
                    .name
                    .as_ref()
                    .map(|name| format!("of {name} "))
                    .unwrap_or_default();
                write!(
                    f,
                    "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
                    label,
                    first.frame(),
                    first.epoch(),
                    last.epoch(),
                    span,
                    span.to_seconds(),
                    self.states.len()
                )
            }
            _ => write!(f, "Empty Trajectory!"),
        }
    }
}
impl<S: Interpolatable> fmt::Debug for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Debug output is the same one-line summary as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl<S: Interpolatable> Default for Traj<S>
where
DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
fn default() -> Self {
Self::new()
}
}
impl<S: Interpolatable> StateSpecTrait for Traj<S>
where
    DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
{
    /// Returns `None`: no aberration correction is associated with a trajectory.
    fn ab_corr(&self) -> Option<Aberration> {
        Option::None
    }

    /// Evaluates the trajectory at `epoch` (see `Traj::at`) and returns the orbit of that state.
    /// The almanac is not used. Any trajectory error is converted to a generic analysis error
    /// carrying its message.
    fn evaluate(&self, epoch: Epoch, _almanac: &Almanac) -> Result<Orbit, AnalysisError> {
        match self.at(epoch) {
            Ok(state) => Ok(state.orbit()),
            Err(e) => Err(AnalysisError::GenericAnalysisError { err: e.to_string() }),
        }
    }
}