use arrow_array::Array;
use hifitime::Epoch;
use nalgebra::Vector3;
use ordered_float::OrderedFloat;
use parquet::errors::ParquetError;
use smallvec::SmallVec;
use std::collections::hash_map::Entry;
use std::io;
use std::sync::Arc;
use crate::constants::ArcSec;
use crate::conversion::arcsec_to_rad;
use crate::observers::Observer;
use crate::outfit::Outfit;
use crate::outfit_errors::OutfitError;
use crate::TrajectorySet;
use crate::{
constants::{ObjectNumber, JDTOMJD},
observations::Observation,
};
use arrow_array::array::{Float64Array, UInt32Array};
use camino::Utf8Path;
use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ProjectionMask};
use ahash::RandomState;
use std::collections::HashMap;
/// `HashMap` specialized with [`ahash::RandomState`]: a faster, non-DoS-resistant
/// hasher, appropriate here because keys (f64 epochs) are not attacker-controlled.
pub type FastHashMap<K, V> = HashMap<K, V, RandomState>;
/// Load astrometric observations from a Parquet file into `trajectories`.
///
/// The file must contain a flat schema with the columns `ra`, `dec`
/// (Float64, degrees), `jd` (Float64, Julian Date — assumed to be on the
/// TT time scale; NOTE(review): confirm with the data producer) and
/// `trajectory_id` (UInt32). Rows with a null in any of these columns are
/// skipped. Geocentric/heliocentric observer positions are cached per
/// distinct epoch, so files with many observations sharing the same `jd`
/// avoid recomputing ephemerides.
///
/// # Arguments
/// * `trajectories` - destination map, updated in place (one entry per trajectory id).
/// * `env_state` - global environment (observer registry, UT1 provider, ephemerides).
/// * `parquet` - path to the Parquet file.
/// * `observer` - observing site shared by every row of the file.
/// * `error_ra` / `error_dec` - 1-sigma astrometric uncertainties, in arcseconds.
/// * `batch_size` - Arrow record-batch size; defaults to 8192 rows.
///
/// # Errors
/// Returns an [`OutfitError`] if the file cannot be opened or parsed, if a
/// required column is missing or has an unexpected Arrow type, or if an
/// observer-position computation fails.
pub(crate) fn parquet_to_trajset(
    trajectories: &mut TrajectorySet,
    env_state: &mut Outfit,
    parquet: &Utf8Path,
    observer: Arc<Observer>,
    error_ra: ArcSec,
    error_dec: ArcSec,
    batch_size: Option<usize>,
) -> Result<(), OutfitError> {
    // Uncertainties are applied uniformly to every row, so convert once.
    let error_ra_rad = arcsec_to_rad(error_ra);
    let error_dec_rad = arcsec_to_rad(error_dec);
    let uint16_obs = env_state.uint16_from_observer(observer);

    let file = std::fs::File::open(parquet)?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    let parquet_metadata = builder.metadata();
    let schema_descr = parquet_metadata.file_metadata().schema_descr();
    // Leaf-column lookup: valid because the expected schema is flat
    // (nested/repeated fields would make leaf indices diverge from field
    // positions — assumed not to occur here).
    let all_fields = schema_descr.columns();
    let column_names = ["ra", "dec", "jd", "trajectory_id"];
    let projection_indices: Vec<usize> = column_names
        .iter()
        .map(|name| {
            all_fields
                .iter()
                .position(|f| f.name() == *name)
                .ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::NotFound,
                        format!("Column '{name}' not found in schema"),
                    )
                })
        })
        .collect::<Result<_, _>>()?;
    let mask = ProjectionMask::leaves(schema_descr, projection_indices);
    let batch_size = batch_size.unwrap_or(8192);
    let reader = builder
        .with_projection(mask)
        .with_batch_size(batch_size)
        .build()?;

    let ut1 = env_state.get_ut1_provider();
    let obs_ref = env_state.get_observer_from_uint16(uint16_obs);
    // Cache of (geocentric, heliocentric) observer positions keyed by epoch.
    // OrderedFloat makes the f64 MJD hashable; exact bit-equality is the
    // intended cache key (identical `jd` values in the file).
    let mut pos_cache: FastHashMap<OrderedFloat<f64>, (Vector3<f64>, Vector3<f64>)> =
        FastHashMap::with_capacity_and_hasher(4096, RandomState::default());

    for maybe_batch in reader {
        let batch = maybe_batch.map_err(ParquetError::from)?;
        let len = batch.num_rows();

        // BUGFIX: the projected batch exposes columns in *file-schema*
        // order, not in `column_names` order, so positional access
        // (`batch.column(0)` == "ra", ...) silently mixes up ra/dec/jd when
        // the file stores the columns in a different order. Resolve by name
        // instead.
        let ra_arr = batch
            .column_by_name("ra")
            .and_then(|c| c.as_any().downcast_ref::<Float64Array>())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "ra must be Float64Array"))?;
        let dec_arr = batch
            .column_by_name("dec")
            .and_then(|c| c.as_any().downcast_ref::<Float64Array>())
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidData, "dec must be Float64Array")
            })?;
        let jd_arr = batch
            .column_by_name("jd")
            .and_then(|c| c.as_any().downcast_ref::<Float64Array>())
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "jd must be Float64Array"))?;
        let tid_arr = batch
            .column_by_name("trajectory_id")
            .and_then(|c| c.as_any().downcast_ref::<UInt32Array>())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    "trajectory_id must be UInt32Array",
                )
            })?;

        let no_nulls = ra_arr.nulls().is_none()
            && dec_arr.nulls().is_none()
            && jd_arr.nulls().is_none()
            && tid_arr.nulls().is_none();

        if no_nulls {
            // Fast path: no validity bitmaps, so raw value slices can be
            // read directly without per-row null checks.
            let ra_vals: &[f64] = ra_arr.values();
            let dec_vals: &[f64] = dec_arr.values();
            let jd_vals: &[f64] = jd_arr.values();
            let tid_vals: &[u32] = tid_arr.values();
            for i in 0..len {
                let ra_rad = ra_vals[i].to_radians();
                let dec_rad = dec_vals[i].to_radians();
                let mjd_time = jd_vals[i] - JDTOMJD;
                let key = OrderedFloat(mjd_time);
                // Compute observer positions once per distinct epoch.
                let (geo_pos, helio_pos) = match pos_cache.entry(key) {
                    Entry::Occupied(e) => *e.get(),
                    Entry::Vacant(v) => {
                        let epoch =
                            Epoch::from_mjd_in_time_scale(mjd_time, hifitime::TimeScale::TT);
                        let (geo, _vel) = obs_ref.pvobs(&epoch, ut1)?;
                        let helio = obs_ref.helio_position(env_state, &epoch, &geo)?;
                        v.insert((geo, helio));
                        (geo, helio)
                    }
                };
                let obs = Observation::with_positions(
                    uint16_obs,
                    ra_rad,
                    error_ra_rad,
                    dec_rad,
                    error_dec_rad,
                    mjd_time,
                    geo_pos,
                    helio_pos,
                );
                let obj = ObjectNumber::Int(tid_vals[i]);
                trajectories
                    .entry(obj)
                    .or_insert_with(|| SmallVec::with_capacity(32))
                    .push(obs);
            }
        } else {
            // Slow path: at least one column carries nulls; skip any row
            // with a null in a required field.
            for i in 0..len {
                if ra_arr.is_null(i)
                    || dec_arr.is_null(i)
                    || jd_arr.is_null(i)
                    || tid_arr.is_null(i)
                {
                    continue;
                }
                let ra_rad = ra_arr.value(i).to_radians();
                let dec_rad = dec_arr.value(i).to_radians();
                let mjd_time = jd_arr.value(i) - JDTOMJD;
                let tid = tid_arr.value(i);
                let key = OrderedFloat(mjd_time);
                let (geo_pos, helio_pos) = match pos_cache.entry(key) {
                    Entry::Occupied(e) => *e.get(),
                    Entry::Vacant(v) => {
                        let epoch =
                            Epoch::from_mjd_in_time_scale(mjd_time, hifitime::TimeScale::TT);
                        let (geo, _vel) = obs_ref.pvobs(&epoch, ut1)?;
                        let helio = obs_ref.helio_position(env_state, &epoch, &geo)?;
                        v.insert((geo, helio));
                        (geo, helio)
                    }
                };
                let obs = Observation::with_positions(
                    uint16_obs,
                    ra_rad,
                    error_ra_rad,
                    dec_rad,
                    error_dec_rad,
                    mjd_time,
                    geo_pos,
                    helio_pos,
                );
                trajectories
                    .entry(ObjectNumber::Int(tid))
                    .or_insert_with(|| SmallVec::with_capacity(32))
                    .push(obs);
            }
        }
    }
    Ok(())
}