use core::marker::PhantomData;
use core::ops::{Add, Div, Mul, Neg, Sub};
use nalgebra::{UnitQuaternion, Vector3};
use serde::{Deserialize, Serialize};
use crate::epoch::{Epoch, Ut1, Utc};
/// Coarse classification of a reference frame, as returned by
/// [`FrameDescriptor::category`].
///
/// Keep the variant order stable: the derived `Hash` and
/// `Serialize`/`Deserialize` impls may depend on it (e.g. serde formats that
/// encode variants by index).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FrameCategory {
    /// Category of `SimpleEci`, `Gcrs` and `Cirs`.
    Eci,
    /// Category of `SimpleEcef`, `Tirs` and `Itrs`.
    Ecef,
    /// Category of `Rsw`.
    LocalOrbital,
    /// Category of `Body`.
    Body,
}
/// Runtime identifier for each concrete frame marker type.
///
/// Every marker type exposes its descriptor via [`Frame::DESCRIPTOR`], so code
/// that only has a value (not the compile-time frame parameter) can still
/// inspect which frame it is dealing with.
///
/// Keep the variant order stable: the derived `Hash` and
/// `Serialize`/`Deserialize` impls may depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FrameDescriptor {
    /// Descriptor of the `SimpleEci` marker type.
    SimpleEci,
    /// Descriptor of the `SimpleEcef` marker type.
    SimpleEcef,
    /// Descriptor of the `Gcrs` marker type.
    Gcrs,
    /// Descriptor of the `Cirs` marker type.
    Cirs,
    /// Descriptor of the `Tirs` marker type.
    Tirs,
    /// Descriptor of the `Itrs` marker type.
    Itrs,
    /// Descriptor of the `Rsw` marker type.
    Rsw,
    /// Descriptor of the `Body` marker type.
    Body,
}
impl FrameDescriptor {
    /// Short display name of the frame; always identical to the variant's
    /// identifier. Usable in `const` contexts.
    pub const fn name(self) -> &'static str {
        match self {
            Self::SimpleEci => "SimpleEci",
            Self::SimpleEcef => "SimpleEcef",
            Self::Gcrs => "Gcrs",
            Self::Cirs => "Cirs",
            Self::Tirs => "Tirs",
            Self::Itrs => "Itrs",
            Self::Rsw => "Rsw",
            Self::Body => "Body",
        }
    }

    /// The coarse [`FrameCategory`] this frame belongs to. Usable in `const`
    /// contexts.
    pub const fn category(self) -> FrameCategory {
        match self {
            Self::Body => FrameCategory::Body,
            Self::Rsw => FrameCategory::LocalOrbital,
            Self::SimpleEcef | Self::Tirs | Self::Itrs => FrameCategory::Ecef,
            Self::SimpleEci | Self::Gcrs | Self::Cirs => FrameCategory::Eci,
        }
    }
}
/// Private module holding the supertrait that seals [`Frame`].
///
/// Because `sealed` is not exported, code outside this module cannot name
/// `Sealed`, and therefore cannot implement `Frame` for its own types.
mod sealed {
    /// Supertrait of `Frame`; can only be implemented from within this module.
    pub trait Sealed {}
}
/// Compile-time tag identifying a reference frame.
///
/// Implemented by the zero-sized marker types in this module. It is intended
/// as the type parameter of [`Vec3`] and [`Rotation`], so that quantities
/// expressed in different frames cannot be combined by accident. The trait is
/// sealed through `sealed::Sealed`, so only this module can add frames.
pub trait Frame: sealed::Sealed {
    /// Short name of the frame. For every implementor in this module it equals
    /// `Self::DESCRIPTOR.name()`.
    const NAME: &'static str;
    /// Runtime descriptor of the frame.
    const DESCRIPTOR: FrameDescriptor;
}
/// Bound for frames in the [`FrameCategory::Eci`] category (implemented by
/// `SimpleEci`, `Gcrs`, `Cirs`). Sealed indirectly through `Frame`.
pub trait Eci: Frame {}
/// Bound for frames in the [`FrameCategory::Ecef`] category (implemented by
/// `SimpleEcef`, `Tirs`, `Itrs`). Sealed indirectly through `Frame`.
pub trait Ecef: Frame {}
/// Bound for frames in the [`FrameCategory::LocalOrbital`] category
/// (implemented by `Rsw`). Sealed indirectly through `Frame`.
pub trait LocalOrbital: Frame {}
/// Marker type for the `SimpleEci` frame (category [`FrameCategory::Eci`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimpleEci;

impl sealed::Sealed for SimpleEci {}

impl Frame for SimpleEci {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::SimpleEci;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Eci for SimpleEci {}
/// Marker type for the `SimpleEcef` frame (category [`FrameCategory::Ecef`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimpleEcef;

impl sealed::Sealed for SimpleEcef {}

impl Frame for SimpleEcef {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::SimpleEcef;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Ecef for SimpleEcef {}
/// Marker type for the `Gcrs` frame (category [`FrameCategory::Eci`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gcrs;

impl sealed::Sealed for Gcrs {}

impl Frame for Gcrs {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Gcrs;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Eci for Gcrs {}
/// Marker type for the `Cirs` frame (category [`FrameCategory::Eci`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cirs;

impl sealed::Sealed for Cirs {}

impl Frame for Cirs {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Cirs;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Eci for Cirs {}
/// Marker type for the `Tirs` frame (category [`FrameCategory::Ecef`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tirs;

impl sealed::Sealed for Tirs {}

impl Frame for Tirs {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Tirs;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Ecef for Tirs {}
/// Marker type for the `Itrs` frame (category [`FrameCategory::Ecef`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Itrs;

impl sealed::Sealed for Itrs {}

impl Frame for Itrs {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Itrs;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl Ecef for Itrs {}
/// Marker type for the `Rsw` frame (category [`FrameCategory::LocalOrbital`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Rsw;

impl sealed::Sealed for Rsw {}

impl Frame for Rsw {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Rsw;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}

impl LocalOrbital for Rsw {}
/// Marker type for the `Body` frame (category [`FrameCategory::Body`]).
///
/// It implements no category trait: `Eci`, `Ecef` and `LocalOrbital` do not
/// apply to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Body;

impl sealed::Sealed for Body {}

impl Frame for Body {
    const DESCRIPTOR: FrameDescriptor = FrameDescriptor::Body;
    // Read from the descriptor so the two spellings cannot drift apart.
    const NAME: &'static str = Self::DESCRIPTOR.name();
}
/// A 3-vector whose components are expressed in frame `F`.
///
/// `F` is a zero-sized frame tag stored only as `PhantomData`. The arithmetic
/// operators require both operands to carry the same `F`, so, for example,
/// adding a `Vec3<Gcrs>` to a `Vec3<Itrs>` is a compile error.
///
/// `Clone`, `Copy` and `PartialEq` are implemented by hand rather than derived.
/// `#[derive]` would add `F: Clone` / `F: Copy` / `F: PartialEq` bounds even
/// though no `F` value is stored. That would make `Vec3<F>` non-`Copy` inside
/// generic code such as `fn f<F: Frame>(v: Vec3<F>)`.
pub struct Vec3<F>(Vector3<f64>, PhantomData<F>);

impl<F> Clone for Vec3<F> {
    fn clone(&self) -> Self {
        *self
    }
}

// Both fields are `Copy` for every `F` (`PhantomData<F>` is unconditionally `Copy`).
impl<F> Copy for Vec3<F> {}

impl<F> PartialEq for Vec3<F> {
    /// Component-wise `f64` equality (so any NaN component compares unequal);
    /// the frame tag carries no data.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}
impl<F> Vec3<F> {
    /// Creates a vector from its three components in frame `F`.
    pub fn new(x: f64, y: f64, z: f64) -> Self {
        Self(Vector3::new(x, y, z), PhantomData)
    }

    /// Tags an untagged nalgebra vector as being expressed in frame `F`.
    ///
    /// No check is possible: the caller is responsible for the tag being right.
    pub fn from_raw(v: Vector3<f64>) -> Self {
        Self(v, PhantomData)
    }

    /// The zero vector in frame `F`.
    pub fn zeros() -> Self {
        Self(Vector3::zeros(), PhantomData)
    }

    /// Borrows the underlying untagged vector.
    pub fn inner(&self) -> &Vector3<f64> {
        &self.0
    }

    /// Consumes the wrapper and returns the untagged vector.
    pub fn into_inner(self) -> Vector3<f64> {
        self.0
    }

    /// First component.
    pub fn x(&self) -> f64 {
        self.0.x
    }

    /// Second component.
    pub fn y(&self) -> f64 {
        self.0.y
    }

    /// Third component.
    pub fn z(&self) -> f64 {
        self.0.z
    }

    /// Euclidean norm.
    pub fn magnitude(&self) -> f64 {
        self.0.magnitude()
    }

    /// Squared Euclidean norm (avoids the square root).
    pub fn magnitude_squared(&self) -> f64 {
        self.0.magnitude_squared()
    }

    /// Unit vector pointing in the same direction.
    ///
    /// The vector is divided by its norm without any check, so a zero vector
    /// yields NaN components. Use [`Vec3::try_normalize`] when the input may
    /// be degenerate.
    pub fn normalize(&self) -> Self {
        Self(self.0.normalize(), PhantomData)
    }

    /// Unit vector pointing in the same direction, or `None` when the
    /// magnitude is `<= min_norm`. This guards against the NaN result that
    /// [`Vec3::normalize`] produces for a zero vector.
    pub fn try_normalize(&self, min_norm: f64) -> Option<Self> {
        self.0.try_normalize(min_norm).map(Self::from_raw)
    }

    /// Dot product with another vector in the same frame.
    pub fn dot(&self, other: &Self) -> f64 {
        self.0.dot(&other.0)
    }

    /// Cross product with another vector in the same frame.
    pub fn cross(&self, other: &Self) -> Self {
        Self(self.0.cross(&other.0), PhantomData)
    }

    /// `true` when no component is NaN or infinite.
    pub fn is_finite(&self) -> bool {
        self.0.iter().all(|x| x.is_finite())
    }
}
impl<F: Frame> Vec3<F> {
    /// Runtime descriptor of the frame `F` this vector type is tagged with.
    ///
    /// Being a `const fn`, it can initialise constants, e.g.
    /// `const D: FrameDescriptor = Vec3::<Gcrs>::frame_descriptor();`.
    pub const fn frame_descriptor() -> FrameDescriptor {
        <F as Frame>::DESCRIPTOR
    }
}
impl<F> core::fmt::Debug for Vec3<F> {
    /// Formats as `Vec3<Frame>(x, y, z)`, where `Frame` is the last path
    /// segment of `F`'s type name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let qualified = core::any::type_name::<F>();
        let frame = qualified.rsplit("::").next().unwrap_or("?");
        let (x, y, z) = (self.0.x, self.0.y, self.0.z);
        write!(f, "Vec3<{frame}>({x}, {y}, {z})")
    }
}
/// Component-wise addition; both operands must be in the same frame `F`.
impl<F> Add for Vec3<F> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0, PhantomData)
    }
}

/// Component-wise subtraction; both operands must be in the same frame `F`.
impl<F> Sub for Vec3<F> {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        Self(self.0 - rhs.0, PhantomData)
    }
}

/// Negates every component; the frame is unchanged.
impl<F> Neg for Vec3<F> {
    type Output = Self;
    fn neg(self) -> Self {
        Self(-self.0, PhantomData)
    }
}

/// Scales every component by a scalar (`v * s`).
impl<F> Mul<f64> for Vec3<F> {
    type Output = Self;
    fn mul(self, rhs: f64) -> Self {
        Self(self.0 * rhs, PhantomData)
    }
}

/// Scales every component by a scalar (`s * v`).
impl<F> Mul<Vec3<F>> for f64 {
    type Output = Vec3<F>;
    fn mul(self, rhs: Vec3<F>) -> Vec3<F> {
        Vec3(self * rhs.0, PhantomData)
    }
}

/// Divides every component by a scalar. No zero check is done; dividing by
/// `0.0` follows IEEE 754 and produces non-finite components.
impl<F> Div<f64> for Vec3<F> {
    type Output = Self;
    fn div(self, rhs: f64) -> Self {
        Self(self.0 / rhs, PhantomData)
    }
}

/// In-place `+`; both operands must be in the same frame `F`.
impl<F> core::ops::AddAssign for Vec3<F> {
    fn add_assign(&mut self, rhs: Self) {
        self.0 += rhs.0;
    }
}

/// In-place `-`; both operands must be in the same frame `F`.
impl<F> core::ops::SubAssign for Vec3<F> {
    fn sub_assign(&mut self, rhs: Self) {
        self.0 -= rhs.0;
    }
}

/// In-place scalar scaling, the compound form of `Vec3 * f64`.
impl<F> core::ops::MulAssign<f64> for Vec3<F> {
    fn mul_assign(&mut self, rhs: f64) {
        self.0 *= rhs;
    }
}

/// In-place scalar division, the compound form of `Vec3 / f64` (same IEEE 754
/// behaviour for a zero divisor).
impl<F> core::ops::DivAssign<f64> for Vec3<F> {
    fn div_assign(&mut self, rhs: f64) {
        self.0 /= rhs;
    }
}
/// A rotation that maps vectors expressed in frame `From` into frame `To`.
///
/// Stored as a unit quaternion; the frame tags exist only in the type system.
///
/// `Clone`, `Copy` and `PartialEq` are implemented by hand rather than derived.
/// `#[derive]` would require `From` and `To` themselves to implement those
/// traits, even though only `PhantomData<(From, To)>` is stored. That would
/// make the rotation non-`Copy` in generic code over arbitrary frames.
pub struct Rotation<From, To>(UnitQuaternion<f64>, PhantomData<(From, To)>);

impl<From, To> Clone for Rotation<From, To> {
    fn clone(&self) -> Self {
        *self
    }
}

// Both fields are `Copy` regardless of `From`/`To`.
impl<From, To> Copy for Rotation<From, To> {}

impl<From, To> PartialEq for Rotation<From, To> {
    /// Compares quaternion components, exactly as the former derive did.
    /// Note that `q` and `-q` describe the same rotation but compare unequal.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}
impl<From, To> Rotation<From, To> {
pub fn from_raw(q: UnitQuaternion<f64>) -> Self {
Self(q, PhantomData)
}
pub fn inner(&self) -> &UnitQuaternion<f64> {
&self.0
}
pub fn into_inner(self) -> UnitQuaternion<f64> {
self.0
}
pub fn transform(&self, v: &Vec3<From>) -> Vec3<To> {
Vec3(self.0.transform_vector(&v.0), PhantomData)
}
pub fn inverse(&self) -> Rotation<To, From> {
Rotation(self.0.inverse(), PhantomData)
}
pub fn then<C>(&self, other: &Rotation<To, C>) -> Rotation<From, C> {
Rotation(other.0 * self.0, PhantomData)
}
}
impl Rotation<SimpleEci, SimpleEcef> {
    /// Inertial-to-fixed rotation for a UT1 epoch, built from `epoch.era()`.
    ///
    /// NOTE(review): `era()` presumably returns the Earth Rotation Angle in
    /// radians; confirm against `crate::epoch`.
    pub fn from_ut1(epoch: &Epoch<Ut1>) -> Self {
        let era = epoch.era();
        Self::from_era(era)
    }

    /// Like [`Self::from_ut1`], but converts the UTC epoch with
    /// `to_ut1_naive`. As the name says, this treats UT1 as equal to UTC.
    pub fn from_utc_assuming_ut1_eq_utc(epoch: &Epoch<Utc>) -> Self {
        let ut1 = epoch.to_ut1_naive();
        Self::from_ut1(&ut1)
    }

    /// Rotation by `-era` radians about the +z axis.
    ///
    /// The negative sign makes this the inverse of the `SimpleEcef -> SimpleEci`
    /// constructor with the same angle.
    pub fn from_era(era: f64) -> Self {
        let spin_axis = Vector3::<f64>::z_axis();
        let q = UnitQuaternion::from_axis_angle(&spin_axis, -era);
        Self::from_raw(q)
    }
}
impl Rotation<SimpleEcef, SimpleEci> {
    /// Fixed-to-inertial rotation for a UT1 epoch, built from `epoch.era()`.
    ///
    /// NOTE(review): `era()` presumably returns the Earth Rotation Angle in
    /// radians; confirm against `crate::epoch`.
    pub fn from_ut1(epoch: &Epoch<Ut1>) -> Self {
        let era = epoch.era();
        Self::from_era(era)
    }

    /// Like [`Self::from_ut1`], but converts the UTC epoch with
    /// `to_ut1_naive`. As the name says, this treats UT1 as equal to UTC.
    pub fn from_utc_assuming_ut1_eq_utc(epoch: &Epoch<Utc>) -> Self {
        let ut1 = epoch.to_ut1_naive();
        Self::from_ut1(&ut1)
    }

    /// Rotation by `+era` radians about the +z axis.
    ///
    /// This is the inverse of the `SimpleEci -> SimpleEcef` constructor with the
    /// same angle.
    pub fn from_era(era: f64) -> Self {
        let spin_axis = Vector3::<f64>::z_axis();
        let q = UnitQuaternion::from_axis_angle(&spin_axis, era);
        Self::from_raw(q)
    }
}
impl<From, To> core::fmt::Debug for Rotation<From, To> {
    /// Formats as `Rotation<From, To>(<quaternion debug>)`, where the frame
    /// names are the last path segments of the type names.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Drop module paths so a frame prints as e.g. `Gcrs` rather than its
        // fully qualified path.
        fn short_name<T>() -> &'static str {
            core::any::type_name::<T>()
                .rsplit("::")
                .next()
                .unwrap_or("?")
        }

        write!(
            f,
            "Rotation<{}, {}>({:?})",
            short_name::<From>(),
            short_name::<To>(),
            self.0
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Rotation by `angle` radians about +z between two arbitrary frame tags.
    fn z_rotation<A, B>(angle: f64) -> Rotation<A, B> {
        Rotation::from_raw(UnitQuaternion::from_axis_angle(
            &Vector3::<f64>::z_axis(),
            angle,
        ))
    }

    #[test]
    fn vec3_basic_ops() {
        let a = Vec3::<Gcrs>::new(1.0, 2.0, 3.0);
        let b = Vec3::<Gcrs>::new(4.0, 5.0, 6.0);

        let sum = a + b;
        assert_eq!((sum.x(), sum.y(), sum.z()), (5.0, 7.0, 9.0));
        assert_eq!((b - a).x(), 3.0);
        assert_eq!((-a).x(), -1.0);
        assert_eq!((a * 2.0).x(), 2.0);
        assert_eq!((3.0 * a).x(), 3.0);
        assert_eq!((a / 2.0).x(), 0.5);
    }

    #[test]
    fn vec3_magnitude_and_normalize() {
        let v = Vec3::<Body>::new(3.0, 4.0, 0.0);
        assert!(close(v.magnitude(), 5.0, 1e-15));
        assert!(close(v.magnitude_squared(), 25.0, 1e-15));

        let unit = v.normalize();
        assert!(close(unit.magnitude(), 1.0, 1e-15));
        assert!(close(unit.x(), 0.6, 1e-15));
    }

    #[test]
    fn vec3_dot_and_cross() {
        let ex = Vec3::<Gcrs>::new(1.0, 0.0, 0.0);
        let ey = Vec3::<Gcrs>::new(0.0, 1.0, 0.0);
        assert!(close(ex.dot(&ey), 0.0, 1e-15));
        assert!(close(ex.cross(&ey).z(), 1.0, 1e-15));
    }

    #[test]
    fn vec3_add_assign() {
        let mut acc = Vec3::<Gcrs>::new(1.0, 2.0, 3.0);
        acc += Vec3::new(10.0, 20.0, 30.0);
        assert_eq!(acc.x(), 11.0);
    }

    #[test]
    fn vec3_is_finite() {
        assert!(Vec3::<Gcrs>::new(1.0, 2.0, 3.0).is_finite());
        let non_finite = [
            Vec3::<Gcrs>::new(f64::NAN, 0.0, 0.0),
            Vec3::<Gcrs>::new(0.0, f64::INFINITY, 0.0),
        ];
        assert!(non_finite.iter().all(|v| !v.is_finite()));
    }

    #[test]
    fn rotation_identity_is_noop() {
        let r = Rotation::<Gcrs, Body>::from_raw(UnitQuaternion::identity());
        let out = r.transform(&Vec3::<Gcrs>::new(1.0, 2.0, 3.0));
        assert!(close(out.x(), 1.0, 1e-15));
        assert!(close(out.y(), 2.0, 1e-15));
        assert!(close(out.z(), 3.0, 1e-15));
    }

    #[test]
    fn rotation_90deg_about_z() {
        let r: Rotation<Gcrs, Body> = z_rotation(PI / 2.0);
        let out = r.transform(&Vec3::new(1.0, 0.0, 0.0));
        assert!(close(out.x(), 0.0, 1e-15));
        assert!(close(out.y(), 1.0, 1e-15));
        assert!(close(out.z(), 0.0, 1e-15));
    }

    #[test]
    fn rotation_inverse() {
        let r: Rotation<Gcrs, Body> = z_rotation(PI / 4.0);
        let start = Vec3::<Gcrs>::new(1.0, 0.0, 0.0);
        let back = r.inverse().transform(&r.transform(&start));
        assert!(close(back.x(), 1.0, 1e-14));
        assert!(close(back.y(), 0.0, 1e-14));
    }

    #[test]
    fn rotation_compose() {
        let r_ab: Rotation<Gcrs, Body> = z_rotation(PI / 4.0);
        let r_bc: Rotation<Body, Rsw> = z_rotation(PI / 4.0);
        let r_ac: Rotation<Gcrs, Rsw> = r_ab.then(&r_bc);
        let out = r_ac.transform(&Vec3::new(1.0, 0.0, 0.0));
        assert!(close(out.x(), 0.0, 1e-14));
        assert!(close(out.y(), 1.0, 1e-14));
    }

    #[test]
    fn frame_descriptor_name() {
        let expected = [
            (FrameDescriptor::SimpleEci, "SimpleEci"),
            (FrameDescriptor::SimpleEcef, "SimpleEcef"),
            (FrameDescriptor::Gcrs, "Gcrs"),
            (FrameDescriptor::Cirs, "Cirs"),
            (FrameDescriptor::Tirs, "Tirs"),
            (FrameDescriptor::Itrs, "Itrs"),
            (FrameDescriptor::Rsw, "Rsw"),
            (FrameDescriptor::Body, "Body"),
        ];
        for (descriptor, name) in expected {
            assert_eq!(descriptor.name(), name);
        }
    }

    #[test]
    fn frame_descriptor_category() {
        let expected = [
            (FrameDescriptor::SimpleEci, FrameCategory::Eci),
            (FrameDescriptor::Gcrs, FrameCategory::Eci),
            (FrameDescriptor::Cirs, FrameCategory::Eci),
            (FrameDescriptor::SimpleEcef, FrameCategory::Ecef),
            (FrameDescriptor::Tirs, FrameCategory::Ecef),
            (FrameDescriptor::Itrs, FrameCategory::Ecef),
            (FrameDescriptor::Rsw, FrameCategory::LocalOrbital),
            (FrameDescriptor::Body, FrameCategory::Body),
        ];
        for (descriptor, category) in expected {
            assert_eq!(descriptor.category(), category);
        }
    }

    #[test]
    fn frame_descriptor_via_trait() {
        let pairs = [
            (<SimpleEci as Frame>::DESCRIPTOR, FrameDescriptor::SimpleEci),
            (<Gcrs as Frame>::DESCRIPTOR, FrameDescriptor::Gcrs),
            (<Cirs as Frame>::DESCRIPTOR, FrameDescriptor::Cirs),
            (<Tirs as Frame>::DESCRIPTOR, FrameDescriptor::Tirs),
            (<Itrs as Frame>::DESCRIPTOR, FrameDescriptor::Itrs),
            (<SimpleEcef as Frame>::DESCRIPTOR, FrameDescriptor::SimpleEcef),
            (Vec3::<SimpleEci>::frame_descriptor(), FrameDescriptor::SimpleEci),
        ];
        for (actual, expected) in pairs {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn category_trait_bounds_gate_generic_api() {
        fn magnitude_eci<F: Eci>(v: Vec3<F>) -> f64 {
            v.magnitude()
        }
        fn magnitude_ecef<F: Ecef>(v: Vec3<F>) -> f64 {
            v.magnitude()
        }

        assert_eq!(magnitude_eci(Vec3::<SimpleEci>::new(3.0, 4.0, 0.0)), 5.0);
        assert_eq!(magnitude_eci(Vec3::<Gcrs>::new(0.0, 0.0, 7.0)), 7.0);
        assert_eq!(magnitude_eci(Vec3::<Cirs>::new(5.0, 0.0, 12.0)), 13.0);

        assert_eq!(magnitude_ecef(Vec3::<SimpleEcef>::new(3.0, 4.0, 0.0)), 5.0);
        assert_eq!(magnitude_ecef(Vec3::<Tirs>::new(0.0, 0.0, 7.0)), 7.0);
        assert_eq!(magnitude_ecef(Vec3::<Itrs>::new(5.0, 0.0, 12.0)), 13.0);
    }

    #[test]
    fn from_era_zero_is_identity() {
        let r = Rotation::<SimpleEci, SimpleEcef>::from_era(0.0);
        let out = r.transform(&Vec3::<SimpleEci>::new(1.0, 2.0, 3.0));
        assert!(close(out.x(), 1.0, 1e-14));
        assert!(close(out.y(), 2.0, 1e-14));
        assert!(close(out.z(), 3.0, 1e-14));
    }

    #[test]
    fn from_era_90deg() {
        let r = Rotation::<SimpleEci, SimpleEcef>::from_era(PI / 2.0);
        let out = r.transform(&Vec3::<SimpleEci>::new(1.0, 0.0, 0.0));
        assert!(close(out.x(), 0.0, 1e-14));
        assert!(close(out.y(), -1.0, 1e-14));
        assert!(close(out.z(), 0.0, 1e-14));
    }

    #[test]
    fn from_era_roundtrip() {
        let era = 1.234;
        let to_fixed = Rotation::<SimpleEci, SimpleEcef>::from_era(era);
        let to_inertial = Rotation::<SimpleEcef, SimpleEci>::from_era(era);
        let v = Vec3::<SimpleEci>::new(100.0, 200.0, 300.0);
        let back = to_inertial.transform(&to_fixed.transform(&v));
        assert!(close(back.x(), v.x(), 1e-10));
        assert!(close(back.y(), v.y(), 1e-10));
        assert!(close(back.z(), v.z(), 1e-10));
    }

    #[test]
    fn from_ut1_matches_from_era() {
        let ut1 = Epoch::<Ut1>::from_jd_ut1(2460390.5);
        let direct = Rotation::<SimpleEci, SimpleEcef>::from_era(ut1.era());
        let via_ut1 = Rotation::<SimpleEci, SimpleEcef>::from_ut1(&ut1);
        let v = Vec3::<SimpleEci>::new(6778.0, 0.0, 0.0);
        let (a, b) = (direct.transform(&v), via_ut1.transform(&v));
        assert!(close(a.x(), b.x(), 1e-14));
        assert!(close(a.y(), b.y(), 1e-14));
        assert!(close(a.z(), b.z(), 1e-14));
    }

    #[test]
    fn from_utc_assuming_ut1_eq_utc_matches_legacy_gmst() {
        let utc = Epoch::from_gregorian(2024, 3, 20, 12, 0, 0.0);
        let legacy_gmst = utc.gmst();
        let new_path = Rotation::<SimpleEci, SimpleEcef>::from_utc_assuming_ut1_eq_utc(&utc);
        let legacy_path = Rotation::<SimpleEci, SimpleEcef>::from_era(legacy_gmst);
        let v = Vec3::<SimpleEci>::new(7000.0, 1000.0, 500.0);
        let (a, b) = (new_path.transform(&v), legacy_path.transform(&v));
        assert!(close(a.x(), b.x(), 1e-14));
        assert!(close(a.y(), b.y(), 1e-14));
        assert!(close(a.z(), b.z(), 1e-14));
    }
}