use std::sync::Arc;
use ahash::AHashMap;
use arrow_array::{
Array, RecordBatch, StringArray, StringViewArray, UInt32Array,
cast::AsArray,
types::{Float64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type},
};
use datafusion::{
datasource::{
file_format::parquet::ParquetFormat,
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
},
error::DataFusionError,
prelude::*,
};
use object_store::{Error as ObjStoreError, ObjectStore};
use tokio::runtime::Runtime;
use url::Url;
use crate::{
NightId, TrajId,
coordinates::equatorial::EquCoord,
io::datafusion::{
input_uri::InputUri,
storage::{UriStoreError, resolve_input_uri},
},
observation_dataset::{
ObsDataset,
index::{NightIndexMap, ObsMapIndex, TrajIndexMap},
observation::ObservationInput,
},
observer::error_model::ObsErrorModel,
observer::{Observer, dataset::ObserverId, mpc::MpcCode},
photometry::{Filter, Photometry},
};
/// Errors that can occur while loading observations from a Parquet source.
#[derive(Debug)]
pub enum LoadObsError {
/// The input resource does not exist; carries the input URI string.
NotFound(String),
/// URI parsing, store resolution, or a store metadata lookup failed.
Resolve(String),
/// An error reported by DataFusion (exposed through `Error::source`).
DataFusion(DataFusionError),
/// A record batch did not have the expected layout or content
/// (missing column, unsupported column type, unexpected null, invalid value).
Arrow(String),
}
impl std::fmt::Display for LoadObsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LoadObsError::NotFound(s) => write!(f, "resource not found: {s}"),
LoadObsError::Resolve(s) => write!(f, "URI resolution error: {s}"),
LoadObsError::DataFusion(e) => write!(f, "DataFusion error: {e}"),
LoadObsError::Arrow(s) => write!(f, "Arrow conversion error: {s}"),
}
}
}
impl std::error::Error for LoadObsError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
LoadObsError::DataFusion(e) => Some(e),
_ => None,
}
}
}
impl From<DataFusionError> for LoadObsError {
fn from(e: DataFusionError) -> Self {
LoadObsError::DataFusion(e)
}
}
impl From<UriStoreError> for LoadObsError {
fn from(e: UriStoreError) -> Self {
LoadObsError::Resolve(e.to_string())
}
}
/// Selects which grouping column the loader sorts by and indexes as
/// contiguous row ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContiguousChoice {
    /// Sort by `night_id` and index nights as contiguous ranges.
    ContiguousNight,
    /// Sort by `traj_id` and index trajectories as contiguous ranges.
    ContiguousTraj,
}
/// Options controlling how observations are loaded into an `ObsDataset`.
pub struct LoadObsArgs {
/// Error model handed unchanged to `ObsDataset::new`.
pub error_model: Option<ObsErrorModel>,
/// Which grouping column (if any) is sorted on and indexed as contiguous
/// row ranges; with `None` no sort is applied and groups are indexed as
/// explicit row lists.
pub contiguous_choice: Option<ContiguousChoice>,
}
/// Defaults: no error model, nights indexed contiguously.
impl Default for LoadObsArgs {
    fn default() -> Self {
        let contiguous_choice = Some(ContiguousChoice::ContiguousNight);
        Self {
            error_model: None,
            contiguous_choice,
        }
    }
}
/// Blocking wrapper around [`load_obs_from_parquet_uri`].
///
/// Builds a fresh tokio runtime and drives the async loader to completion on
/// the calling thread.
///
/// # Errors
/// Returns whatever [`load_obs_from_parquet_uri`] returns.
///
/// # Panics
/// Panics if the tokio runtime cannot be built.
/// NOTE(review): presumably must not be called from inside an existing tokio
/// runtime, where blocking is rejected — confirm against callers.
pub fn load_obs_sync(input: &InputUri, args: LoadObsArgs) -> Result<ObsDataset, LoadObsError> {
    let load = load_obs_from_parquet_uri(input, args);
    Runtime::new()
        .expect("failed to build tokio runtime")
        .block_on(load)
}
pub async fn load_obs_from_parquet_uri(
input: &InputUri,
args: LoadObsArgs,
) -> Result<ObsDataset, LoadObsError> {
let url = input
.parse()
.map_err(|e| LoadObsError::Resolve(format!("invalid URI: {e}")))?;
let resolved = resolve_input_uri(input)?;
let scheme = url.scheme();
if scheme == "file" || scheme == "hdfs" {
match resolved.store.head(&resolved.path).await {
Ok(_) => {}
Err(ObjStoreError::NotFound { .. }) => {
return Err(LoadObsError::NotFound(input.0.clone()));
}
Err(e) => {
return Err(LoadObsError::Resolve(format!(
"store head() error for '{}': {e:?}",
input.0
)));
}
}
}
let ctx = build_session_context_with_store(&url, resolved.store)?;
let listing_url = ListingTableUrl::parse(input.0.as_str())
.map_err(|e| LoadObsError::Resolve(e.to_string()))?;
let listing_opts =
ListingOptions::new(Arc::new(ParquetFormat::default())).with_file_extension(".parquet");
let listing_cfg = ListingTableConfig::new(listing_url)
.with_listing_options(listing_opts)
.infer_schema(&ctx.state())
.await
.map_err(LoadObsError::DataFusion)?;
let table = ListingTable::try_new(listing_cfg).map_err(LoadObsError::DataFusion)?;
let df: DataFrame = ctx.read_table(Arc::new(table))?;
let sort_col: Option<&str> = match &args.contiguous_choice {
Some(ContiguousChoice::ContiguousNight) => Some("night_id"),
Some(ContiguousChoice::ContiguousTraj) => Some("traj_id"),
None => None,
};
let df = if let Some(col_name) = sort_col {
if df.schema().field_with_name(None, col_name).is_ok() {
df.sort(vec![col(col_name).sort(true, true)])?
} else {
df
}
} else {
df
};
let batches: Vec<RecordBatch> = df.collect().await?;
build_obs_dataset_from_batches(&batches, args)
}
/// Creates a new `SessionContext` and registers `store` for `url` on its
/// runtime environment.
///
/// The value returned by the registration call is discarded. This function
/// currently has no failing path; the `Result` return type is kept for callers.
fn build_session_context_with_store(
    url: &Url,
    store: Arc<dyn ObjectStore>,
) -> Result<SessionContext, LoadObsError> {
    let session = SessionContext::new();
    let runtime_env = session.runtime_env();
    runtime_env.register_object_store(url, store);
    Ok(session)
}
/// Streaming detector of runs of equal keys over consecutive rows.
///
/// Rows are fed one at a time through [`on_row`](Self::on_row); whenever a run
/// ends, the finished `(key, entry)` pair is returned, where `entry` is built
/// by `make_entry(start, end)` with `end` exclusive.
struct ContiguousGroupTracker<K, I> {
    /// Key and first row index of the run currently open, if any.
    current: Option<(K, usize)>,
    /// Builds the index entry for a finished run from `(start, end)`.
    make_entry: fn(usize, usize) -> I,
}

impl<K: Clone + Eq, I> ContiguousGroupTracker<K, I> {
    /// Creates a tracker with no open run.
    fn new(make_entry: fn(usize, usize) -> I) -> Self {
        Self {
            current: None,
            make_entry,
        }
    }

    /// Closes the open run (if any) at exclusive end `end` and returns it.
    fn close(&mut self, end: usize) -> Option<(K, I)> {
        let (open_key, start) = self.current.take()?;
        Some((open_key, (self.make_entry)(start, end)))
    }

    /// Feeds the key of row `row_idx`.
    ///
    /// * Same key as the open run: the run continues, nothing is returned.
    /// * Different key: the open run is returned, ending at `row_idx`, and a
    ///   new run starts at `row_idx`.
    /// * `None` key: the open run (if any) is returned and no run stays open.
    fn on_row(&mut self, row_idx: usize, key: Option<K>) -> Option<(K, I)> {
        let Some(incoming) = key else {
            return self.close(row_idx);
        };
        match self.current.take() {
            Some((open_key, start)) if open_key == incoming => {
                // Still inside the same run: restore the state untouched.
                self.current = Some((open_key, start));
                None
            }
            Some((open_key, start)) => {
                self.current = Some((incoming, row_idx));
                Some((open_key, (self.make_entry)(start, row_idx)))
            }
            None => {
                self.current = Some((incoming, row_idx));
                None
            }
        }
    }

    /// Consumes the tracker, returning the still-open run (if any) ended at `n`.
    fn finalize(mut self, n: usize) -> Option<(K, I)> {
        self.close(n)
    }
}
fn build_obs_dataset_from_batches(
batches: &[RecordBatch],
args: LoadObsArgs,
) -> Result<ObsDataset, LoadObsError> {
let mut observations: Vec<ObservationInput> = Vec::new();
let mut custom_observers: Vec<Observer> = Vec::with_capacity(16);
let mut observer_lookup: AHashMap<Observer, usize> = AHashMap::with_capacity(16);
let mut night_map: Option<NightIndexMap> = None;
let mut traj_map: Option<TrajIndexMap> = None;
let mut schema_checked = false;
let night_is_contiguous = matches!(
args.contiguous_choice,
Some(ContiguousChoice::ContiguousNight)
);
let traj_is_contiguous = matches!(
args.contiguous_choice,
Some(ContiguousChoice::ContiguousTraj)
);
let mut night_tracker: ContiguousGroupTracker<NightId, ObsMapIndex> =
ContiguousGroupTracker::new(|start, end| ObsMapIndex::Contiguous { start, end });
let mut traj_tracker: ContiguousGroupTracker<TrajId, ObsMapIndex> =
ContiguousGroupTracker::new(|start, end| ObsMapIndex::Contiguous { start, end });
let mut global_row = 0usize;
for batch in batches {
if !schema_checked {
schema_checked = true;
if batch.schema().index_of("night_id").is_ok() {
night_map = Some(NightIndexMap::new());
}
if batch.schema().index_of("traj_id").is_ok() {
traj_map = Some(TrajIndexMap::new());
}
}
process_batch(
batch,
&mut observations,
&mut custom_observers,
&mut observer_lookup,
&mut night_map,
&mut traj_map,
&mut night_tracker,
&mut traj_tracker,
night_is_contiguous,
traj_is_contiguous,
&mut global_row,
)?;
}
let total = observations.len();
if night_is_contiguous
&& let (Some(map), Some((key, entry))) = (&mut night_map, night_tracker.finalize(total))
{
map.insert(key, entry);
}
if traj_is_contiguous
&& let (Some(map), Some((key, entry))) = (&mut traj_map, traj_tracker.finalize(total))
{
map.insert(key, entry);
}
Ok(ObsDataset::new(
observations,
custom_observers,
args.error_model,
night_map,
traj_map,
))
}
#[allow(clippy::too_many_arguments)]
fn process_batch(
batch: &RecordBatch,
observations: &mut Vec<ObservationInput>,
custom_observers: &mut Vec<Observer>,
observer_lookup: &mut AHashMap<Observer, usize>,
night_map: &mut Option<NightIndexMap>,
traj_map: &mut Option<TrajIndexMap>,
night_tracker: &mut ContiguousGroupTracker<NightId, ObsMapIndex>,
traj_tracker: &mut ContiguousGroupTracker<TrajId, ObsMapIndex>,
night_is_contiguous: bool,
traj_is_contiguous: bool,
global_row: &mut usize,
) -> Result<(), LoadObsError> {
let n = batch.num_rows();
let ids = col_u64(batch, "id")?;
let ra = col_f64(batch, "ra")?;
let ra_err = col_f64(batch, "ra_err")?;
let dec = col_f64(batch, "dec")?;
let dec_err = col_f64(batch, "dec_err")?;
let magnitude = col_f64(batch, "magnitude")?;
let mag_err = col_f64(batch, "mag_err")?;
let mjd_tt = col_f64(batch, "mjd_tt")?;
let filter_col = col_filter(batch, "filter")?;
let obs_lon = opt_col_f64(batch, "obs_lon");
let obs_lat = opt_col_f64(batch, "obs_lat");
let obs_alt = opt_col_f64(batch, "obs_alt");
let obs_ra_acc = opt_col_f64(batch, "obs_ra_acc");
let obs_dec_acc = opt_col_f64(batch, "obs_dec_acc");
let mpc_code_col = opt_col_string(batch, "mpc_code_obs");
let night_id_col = opt_col_u32(batch, "night_id");
let traj_id_col = TrajIdCol::from_batch(batch, "traj_id");
for i in 0..n {
let row_idx = observations.len();
macro_rules! require_non_null {
($arr:expr, $name:literal) => {
if $arr.is_null(i) {
return Err(LoadObsError::Arrow(format!(
"null in required column '{}' at global row {}",
$name, *global_row
)));
}
};
}
require_non_null!(ids, "id");
require_non_null!(ra, "ra");
require_non_null!(ra_err, "ra_err");
require_non_null!(dec, "dec");
require_non_null!(dec_err, "dec_err");
require_non_null!(magnitude, "magnitude");
require_non_null!(mag_err, "mag_err");
require_non_null!(mjd_tt, "mjd_tt");
filter_col.require_non_null(i, *global_row)?;
let observer_id = resolve_and_intern_observer(
i,
*global_row,
obs_lon.as_ref(),
obs_lat.as_ref(),
obs_alt.as_ref(),
obs_ra_acc.as_ref(),
obs_dec_acc.as_ref(),
mpc_code_col.as_ref(),
custom_observers,
observer_lookup,
)?;
if let Some(map) = night_map.as_mut() {
let night_id = night_id_col
.as_ref()
.and_then(|c| c.value_at(i))
.map(NightId);
if night_is_contiguous {
if let Some((key, entry)) = night_tracker.on_row(row_idx, night_id) {
map.insert(key, entry);
}
} else if let Some(nid) = night_id {
map.entry(nid)
.or_insert_with(|| ObsMapIndex::Split(Vec::new()))
.push_split(row_idx);
}
}
if let Some(map) = traj_map.as_mut() {
let traj_id = traj_id_col.value_at(i);
if traj_is_contiguous {
if let Some((key, entry)) = traj_tracker.on_row(row_idx, traj_id) {
map.insert(key, entry);
}
} else if let Some(tid) = traj_id {
map.entry(tid)
.or_insert_with(|| ObsMapIndex::Split(Vec::new()))
.push_split(row_idx);
}
}
observations.push(ObservationInput {
id: ids.value(i),
equ_coord: EquCoord::new(ra.value(i), ra_err.value(i), dec.value(i), dec_err.value(i)),
photometry: Photometry {
magnitude: magnitude.value(i),
error: mag_err.value(i),
filter: filter_col.value_at(i),
},
mjd_tt: mjd_tt.value(i),
observer: observer_id,
});
*global_row += 1;
}
Ok(())
}
/// Determines the observer of row `i`.
///
/// Resolution order:
/// 1. a non-null MPC code yields `ObserverId::MpcCode`; the code must be
///    exactly 3 ASCII bytes;
/// 2. otherwise a fully specified geodetic triplet (`obs_lon`, `obs_lat`,
///    `obs_alt`) yields `ObserverId::IntId`, an index into `custom_observers`.
///    Identical observers are stored once, using `observer_lookup`;
/// 3. otherwise, when all three geodetic values are null/absent, `None`.
///
/// `global_row` is used only in error messages.
///
/// # Errors
/// Returns [`LoadObsError::Arrow`] when the MPC code is malformed, the
/// geodetic triplet is only partially set, an accuracy value is missing for a
/// full triplet, or `Observer::new` rejects the values.
// Each optional column is passed as its own borrow, hence the long
// parameter list.
#[allow(clippy::too_many_arguments)]
fn resolve_and_intern_observer(
    i: usize,
    global_row: usize,
    obs_lon: Option<&OptF64Col<'_>>,
    obs_lat: Option<&OptF64Col<'_>>,
    obs_alt: Option<&OptF64Col<'_>>,
    obs_ra_acc: Option<&OptF64Col<'_>>,
    obs_dec_acc: Option<&OptF64Col<'_>>,
    mpc_code: Option<&StringCol<'_>>,
    custom_observers: &mut Vec<Observer>,
    observer_lookup: &mut AHashMap<Observer, usize>,
) -> Result<Option<ObserverId>, LoadObsError> {
    if let Some(col) = mpc_code
        && let Some(code_str) = col.value_at(i)
    {
        let invalid_code = || {
            LoadObsError::Arrow(format!(
                "invalid MPC code '{code_str}' at global row {global_row}: must be exactly 3 ASCII bytes"
            ))
        };
        // The byte-length conversion below cannot tell ASCII from multi-byte
        // UTF-8 (e.g. "é1" is 3 bytes), so the ASCII requirement stated in the
        // error message is checked explicitly.
        if !code_str.is_ascii() {
            return Err(invalid_code());
        }
        let bytes: MpcCode = code_str
            .as_bytes()
            .try_into()
            .map_err(|_| invalid_code())?;
        return Ok(Some(ObserverId::MpcCode(bytes)));
    }

    let lon = obs_lon.and_then(|c| c.value_at(i));
    let lat = obs_lat.and_then(|c| c.value_at(i));
    let alt = obs_alt.and_then(|c| c.value_at(i));
    match (lon, lat, alt) {
        (Some(lon), Some(lat), Some(alt)) => {
            // A full triplet requires both accuracy values.
            let ra_acc = obs_ra_acc
                .and_then(|c| c.value_at(i))
                .ok_or_else(|| {
                    LoadObsError::Arrow(format!(
                        "obs_ra_acc is null at global row {global_row} but geodetic triplet is fully set"
                    ))
                })?;
            let dec_acc = obs_dec_acc
                .and_then(|c| c.value_at(i))
                .ok_or_else(|| {
                    LoadObsError::Arrow(format!(
                        "obs_dec_acc is null at global row {global_row} but geodetic triplet is fully set"
                    ))
                })?;
            let observer = Observer::new(lon, lat, alt, None, Some(ra_acc), Some(dec_acc))
                .map_err(|e| {
                    LoadObsError::Arrow(format!("invalid observer at global row {global_row}: {e}"))
                })?;
            // Reuse the index of an identical observer seen earlier, or
            // register this one at the end of `custom_observers`.
            let idx = match observer_lookup.get(&observer) {
                Some(&idx) => idx,
                None => {
                    let idx = custom_observers.len();
                    custom_observers.push(observer.clone());
                    observer_lookup.insert(observer, idx);
                    idx
                }
            };
            Ok(Some(ObserverId::IntId(idx)))
        }
        (None, None, None) => Ok(None),
        _ => Err(LoadObsError::Arrow(format!(
            "partial geodetic triplet (obs_lon/obs_lat/obs_alt) at global row {global_row}: \
             all three must be either all non-null or all null"
        ))),
    }
}
/// Returns the position of the required column `name` in `batch`.
///
/// # Errors
/// Returns [`LoadObsError::Arrow`] when the schema has no such column.
fn col_index(batch: &RecordBatch, name: &str) -> Result<usize, LoadObsError> {
    match batch.schema().index_of(name) {
        Ok(position) => Ok(position),
        Err(_) => Err(LoadObsError::Arrow(format!("missing required column '{name}'"))),
    }
}
/// Borrows the required column `name` as a `UInt64` primitive array.
///
/// # Errors
/// Returns [`LoadObsError::Arrow`] when the column is missing or is not `UInt64`.
fn col_u64<'a>(
    batch: &'a RecordBatch,
    name: &str,
) -> Result<&'a arrow_array::PrimitiveArray<UInt64Type>, LoadObsError> {
    let column = batch.column(col_index(batch, name)?);
    let Some(typed) = column.as_primitive_opt::<UInt64Type>() else {
        return Err(LoadObsError::Arrow(format!("column '{name}' is not UInt64")));
    };
    Ok(typed)
}
/// Borrows the required column `name` as a `Float64` primitive array.
///
/// # Errors
/// Returns [`LoadObsError::Arrow`] when the column is missing or is not `Float64`.
fn col_f64<'a>(
    batch: &'a RecordBatch,
    name: &str,
) -> Result<&'a arrow_array::PrimitiveArray<Float64Type>, LoadObsError> {
    let column = batch.column(col_index(batch, name)?);
    let Some(typed) = column.as_primitive_opt::<Float64Type>() else {
        return Err(LoadObsError::Arrow(format!("column '{name}' is not Float64")));
    };
    Ok(typed)
}
/// Borrowed `Float64` column whose cells may be null.
struct OptF64Col<'a>(&'a arrow_array::PrimitiveArray<Float64Type>);

impl OptF64Col<'_> {
    /// Value of cell `i`, or `None` when the cell is null.
    fn value_at(&self, i: usize) -> Option<f64> {
        let column = self.0;
        (!column.is_null(i)).then(|| column.value(i))
    }
}
fn opt_col_f64<'a>(batch: &'a RecordBatch, name: &str) -> Option<OptF64Col<'a>> {
let idx = batch.schema().index_of(name).ok()?;
let arr = batch.column(idx).as_primitive_opt::<Float64Type>()?;
Some(OptF64Col(arr))
}
/// Borrowed `UInt32` column whose cells may be null.
struct OptU32Col<'a>(&'a arrow_array::PrimitiveArray<UInt32Type>);

impl OptU32Col<'_> {
    /// Value of cell `i`, or `None` when the cell is null.
    fn value_at(&self, i: usize) -> Option<u32> {
        let column = self.0;
        (!column.is_null(i)).then(|| column.value(i))
    }
}
fn opt_col_u32<'a>(batch: &'a RecordBatch, name: &str) -> Option<OptU32Col<'a>> {
let idx = batch.schema().index_of(name).ok()?;
let arr = batch.column(idx).as_primitive_opt::<UInt32Type>()?;
Some(OptU32Col(arr))
}
/// Borrowed string column in either of the two supported Arrow encodings.
enum StringCol<'a> {
    /// Backed by a `StringArray`.
    Utf8(&'a StringArray),
    /// Backed by a `StringViewArray`.
    View(&'a StringViewArray),
}

impl StringCol<'_> {
    /// Text of cell `i`, or `None` when the cell is null.
    fn value_at(&self, i: usize) -> Option<&str> {
        match self {
            Self::Utf8(column) if !column.is_null(i) => Some(column.value(i)),
            Self::View(column) if !column.is_null(i) => Some(column.value(i)),
            Self::Utf8(_) | Self::View(_) => None,
        }
    }
}
/// Looks up the optional column `name` as a string column.
///
/// Returns `None` when the column is absent or is neither a `StringArray`
/// nor a `StringViewArray`.
fn opt_col_string<'a>(batch: &'a RecordBatch, name: &str) -> Option<StringCol<'a>> {
    let position = batch.schema().index_of(name).ok()?;
    let any = batch.column(position).as_any();
    any.downcast_ref::<StringArray>()
        .map(StringCol::Utf8)
        .or_else(|| any.downcast_ref::<StringViewArray>().map(StringCol::View))
}
/// Borrowed `filter` column in any of its accepted physical types.
enum FilterCol<'a> {
    /// String-encoded filter names.
    Str(StringCol<'a>),
    /// Integer filter codes stored as `UInt8`.
    U8(&'a arrow_array::PrimitiveArray<UInt8Type>),
    /// Integer filter codes stored as `UInt16`.
    U16(&'a arrow_array::PrimitiveArray<UInt16Type>),
    /// Integer filter codes stored as `UInt32`.
    U32(&'a arrow_array::PrimitiveArray<UInt32Type>),
}

impl FilterCol<'_> {
    /// Filter of cell `i`.
    ///
    /// String cells become `Filter::String` (a null cell maps to the empty
    /// string); integer cells are widened to `u32` and become `Filter::Int`.
    fn value_at(&self, i: usize) -> Filter {
        match self {
            Self::Str(text) => {
                let raw = text.value_at(i).unwrap_or("");
                Filter::String(raw.to_owned())
            }
            Self::U8(codes) => Filter::Int(codes.value(i).into()),
            Self::U16(codes) => Filter::Int(codes.value(i).into()),
            Self::U32(codes) => Filter::Int(codes.value(i)),
        }
    }

    /// Fails with [`LoadObsError::Arrow`] when cell `i` is null.
    ///
    /// `global_row` is used only in the error message.
    fn require_non_null(&self, i: usize, global_row: usize) -> Result<(), LoadObsError> {
        let missing = match self {
            // A string cell is null exactly when `value_at` yields `None`.
            Self::Str(text) => text.value_at(i).is_none(),
            Self::U8(codes) => codes.is_null(i),
            Self::U16(codes) => codes.is_null(i),
            Self::U32(codes) => codes.is_null(i),
        };
        if !missing {
            return Ok(());
        }
        Err(LoadObsError::Arrow(format!(
            "null in required column 'filter' at global row {global_row}"
        )))
    }
}
/// Borrows the required column `name` as a filter column.
///
/// Types are probed in this order: `StringArray`, `StringViewArray`, `UInt8`,
/// `UInt16`, `UInt32`.
///
/// # Errors
/// Returns [`LoadObsError::Arrow`] when the column is missing or has none of
/// the accepted types.
fn col_filter<'a>(batch: &'a RecordBatch, name: &str) -> Result<FilterCol<'a>, LoadObsError> {
    let column = batch.column(col_index(batch, name)?);
    let any = column.as_any();
    any.downcast_ref::<StringArray>()
        .map(|utf8| FilterCol::Str(StringCol::Utf8(utf8)))
        .or_else(|| {
            any.downcast_ref::<StringViewArray>()
                .map(|view| FilterCol::Str(StringCol::View(view)))
        })
        .or_else(|| column.as_primitive_opt::<UInt8Type>().map(FilterCol::U8))
        .or_else(|| column.as_primitive_opt::<UInt16Type>().map(FilterCol::U16))
        .or_else(|| column.as_primitive_opt::<UInt32Type>().map(FilterCol::U32))
        .ok_or_else(|| {
            LoadObsError::Arrow(format!(
                "column '{name}' has an unsupported type for filter \
                 (expected Utf8, Utf8View, UInt8, UInt16, or UInt32)"
            ))
        })
}
/// Borrowed trajectory-id column, or a marker that none is usable.
enum TrajIdCol<'a> {
    /// Integer trajectory ids (`UInt32`).
    Int(&'a UInt32Array),
    /// String trajectory ids.
    Str(StringCol<'a>),
    /// The column is missing or has an unsupported type.
    Absent,
}

impl TrajIdCol<'_> {
    /// Wraps column `name` of `batch`.
    ///
    /// Yields `Absent` when the column does not exist or is not one of
    /// `UInt32Array`, `StringArray`, `StringViewArray`.
    fn from_batch<'a>(batch: &'a RecordBatch, name: &str) -> TrajIdCol<'a> {
        let Ok(position) = batch.schema().index_of(name) else {
            return TrajIdCol::Absent;
        };
        let any = batch.column(position).as_any();
        if let Some(ints) = any.downcast_ref::<UInt32Array>() {
            TrajIdCol::Int(ints)
        } else if let Some(utf8) = any.downcast_ref::<StringArray>() {
            TrajIdCol::Str(StringCol::Utf8(utf8))
        } else if let Some(view) = any.downcast_ref::<StringViewArray>() {
            TrajIdCol::Str(StringCol::View(view))
        } else {
            TrajIdCol::Absent
        }
    }

    /// Trajectory id of cell `i`; `None` for null cells and absent columns.
    fn value_at(&self, i: usize) -> Option<TrajId> {
        match self {
            Self::Int(ints) if !ints.is_null(i) => Some(TrajId::Int(ints.value(i))),
            Self::Int(_) | Self::Absent => None,
            Self::Str(text) => text.value_at(i).map(|s| TrajId::Str(s.to_owned())),
        }
    }
}
// Unit tests for the batch-to-dataset conversion, plus one end-to-end test
// that reads a Parquet file from a temporary directory.
#[cfg(test)]
mod datafusion_loader_tests {
use super::*;
use arrow_array::{ArrayRef, Float64Array, StringArray, UInt32Array, UInt64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
// Schema containing only the nine required columns; `filter` is at index 7.
fn base_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt64, false),
Field::new("ra", DataType::Float64, false),
Field::new("ra_err", DataType::Float64, false),
Field::new("dec", DataType::Float64, false),
Field::new("dec_err", DataType::Float64, false),
Field::new("magnitude", DataType::Float64, false),
Field::new("mag_err", DataType::Float64, false),
Field::new("filter", DataType::Utf8, false),
Field::new("mjd_tt", DataType::Float64, false),
]))
}
// Builds an `n_rows` batch over `base_schema()`: ids 0..n, ra/dec equal to
// the row index, constant errors/magnitude/epoch, and filter "G".
fn make_base_batch(n_rows: usize) -> RecordBatch {
let schema = base_schema();
let ids: Vec<u64> = (0..n_rows as u64).collect();
let vals: Vec<f64> = (0..n_rows).map(|i| i as f64).collect();
let strs: Vec<&str> = (0..n_rows).map(|_| "G").collect();
RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(ids)) as ArrayRef,
Arc::new(Float64Array::from(vals.clone())) as ArrayRef,
Arc::new(Float64Array::from(
vals.iter().map(|_| 0.001).collect::<Vec<f64>>(),
)) as ArrayRef,
Arc::new(Float64Array::from(vals.clone())) as ArrayRef,
Arc::new(Float64Array::from(
vals.iter().map(|_| 0.001).collect::<Vec<f64>>(),
)) as ArrayRef,
Arc::new(Float64Array::from(
vals.iter().map(|_| 15.0).collect::<Vec<f64>>(),
)) as ArrayRef,
Arc::new(Float64Array::from(
vals.iter().map(|_| 0.05).collect::<Vec<f64>>(),
)) as ArrayRef,
Arc::new(StringArray::from(strs)) as ArrayRef,
Arc::new(Float64Array::from(
vals.iter().map(|_| 60000.0).collect::<Vec<f64>>(),
)) as ArrayRef,
],
)
.unwrap()
}
// Without any observer columns every observation has `observer == None`.
#[test]
fn base_columns_only_builds_dataset_with_no_observer() {
let batch = make_base_batch(3);
let ds = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap();
assert_eq!(ds.observation_count(), 3);
for obs in ds.iter_observations() {
assert!(obs.observer.is_none());
}
}
// A non-null `mpc_code_obs` cell resolves to `ObserverId::MpcCode`.
#[test]
fn mpc_code_obs_column_sets_mpc_observer() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("mpc_code_obs", DataType::Utf8, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(1);
let mpc: ArrayRef = Arc::new(StringArray::from(vec![Some("I41")]));
let mut cols = base.columns().to_vec();
cols.push(mpc);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap();
let obs: Vec<_> = ds.iter_observations().collect();
assert_eq!(obs.len(), 1);
assert!(
obs[0].observer == Some(ObserverId::MpcCode(*b"I41")),
"expected MpcCode(b\"I41\"), got {:?}",
obs[0].observer
);
}
// A schema lacking `magnitude` must fail with an Arrow error naming it.
#[test]
fn missing_required_column_returns_arrow_error() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt64, false),
Field::new("ra", DataType::Float64, false),
Field::new("ra_err", DataType::Float64, false),
Field::new("dec", DataType::Float64, false),
Field::new("dec_err", DataType::Float64, false),
Field::new("mag_err", DataType::Float64, false),
Field::new("filter", DataType::Utf8, false),
Field::new("mjd_tt", DataType::Float64, false),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(vec![1u64])) as ArrayRef,
Arc::new(Float64Array::from(vec![1.0f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.001f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![1.0f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.001f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.05f64])) as ArrayRef,
Arc::new(StringArray::from(vec!["G"])) as ArrayRef,
Arc::new(Float64Array::from(vec![60000.0f64])) as ArrayRef,
],
)
.unwrap();
let err = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap_err();
match err {
LoadObsError::Arrow(msg) => {
assert!(msg.contains("magnitude"), "msg={msg}");
}
other => panic!("expected Arrow error, got: {other:?}"),
}
}
// `obs_lon` and `obs_lat` without `obs_alt` is a partial triplet and must fail.
#[test]
fn partial_geodetic_triplet_returns_arrow_error() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("obs_lon", DataType::Float64, true)));
schema_fields.push(Arc::new(Field::new("obs_lat", DataType::Float64, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(1);
let lon: ArrayRef = Arc::new(Float64Array::from(vec![Some(0.1f64)]));
let lat: ArrayRef = Arc::new(Float64Array::from(vec![Some(0.2f64)]));
let mut cols = base.columns().to_vec();
cols.push(lon);
cols.push(lat);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let err = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap_err();
match err {
LoadObsError::Arrow(msg) => {
assert!(msg.contains("partial geodetic"), "msg={msg}");
}
other => panic!("expected Arrow error, got: {other:?}"),
}
}
// A `night_id` column produces a night index containing every night seen.
#[test]
fn night_id_column_builds_night_index() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("night_id", DataType::UInt32, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(3);
let nights: ArrayRef = Arc::new(UInt32Array::from(vec![1u32, 1u32, 2u32]));
let mut cols = base.columns().to_vec();
cols.push(nights);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap();
assert_eq!(ds.observation_count(), 3);
let index = ds.index_ref();
assert!(index.obs_index_by_night.is_some());
let night_map = index.obs_index_by_night.as_ref().unwrap();
assert!(night_map.contains_key(&NightId(1)));
assert!(night_map.contains_key(&NightId(2)));
}
// A `UInt8` filter column is accepted and surfaces as `Filter::Int`.
#[test]
fn filter_column_uint8_is_accepted() {
use arrow_array::UInt8Array;
use datafusion::arrow::datatypes::DataType;
let mut fields = base_schema().fields().to_vec();
let filter_pos = fields.iter().position(|f| f.name() == "filter").unwrap();
fields[filter_pos] = Arc::new(Field::new("filter", DataType::UInt8, false));
let schema = Arc::new(Schema::new(fields));
let base = make_base_batch(2);
let mut cols = base.columns().to_vec();
// Index 7 is the position of `filter` in `base_schema()`.
cols[7] = Arc::new(UInt8Array::from(vec![1u8, 2u8])) as ArrayRef;
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap();
let obs: Vec<_> = ds.iter_observations().collect();
assert_eq!(obs.len(), 2);
assert!(matches!(obs[0].photometry.filter, Filter::Int(1)));
assert!(matches!(obs[1].photometry.filter, Filter::Int(2)));
}
// A `UInt16` filter column is accepted and surfaces as `Filter::Int`.
#[test]
fn filter_column_uint16_is_accepted() {
use arrow_array::UInt16Array;
use datafusion::arrow::datatypes::DataType;
let mut fields = base_schema().fields().to_vec();
let filter_pos = fields.iter().position(|f| f.name() == "filter").unwrap();
fields[filter_pos] = Arc::new(Field::new("filter", DataType::UInt16, false));
let schema = Arc::new(Schema::new(fields));
let base = make_base_batch(2);
let mut cols = base.columns().to_vec();
// Index 7 is the position of `filter` in `base_schema()`.
cols[7] = Arc::new(UInt16Array::from(vec![10u16, 20u16])) as ArrayRef;
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(&[batch], LoadObsArgs::default()).unwrap();
let obs: Vec<_> = ds.iter_observations().collect();
assert_eq!(obs.len(), 2);
assert!(matches!(obs[0].photometry.filter, Filter::Int(10)));
assert!(matches!(obs[1].photometry.filter, Filter::Int(20)));
}
// With `ContiguousTraj`, each trajectory is indexed as a `[start, end)` range.
#[test]
fn traj_id_contiguous_choice_builds_contiguous_index() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("traj_id", DataType::UInt32, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(4);
let trajs: ArrayRef = Arc::new(UInt32Array::from(vec![
Some(7u32),
Some(7u32),
Some(9u32),
Some(9u32),
]));
let mut cols = base.columns().to_vec();
cols.push(trajs);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(
&[batch],
LoadObsArgs {
contiguous_choice: Some(ContiguousChoice::ContiguousTraj),
..Default::default()
},
)
.unwrap();
let index = ds.index_ref();
let traj_map = index.obs_index_by_trajectory.as_ref().unwrap();
match traj_map.get(&TrajId::Int(7)).unwrap() {
ObsMapIndex::Contiguous { start, end } => {
assert_eq!(*start, 0);
assert_eq!(*end, 2);
}
ObsMapIndex::Split(_) => panic!("expected Contiguous for traj 7"),
}
match traj_map.get(&TrajId::Int(9)).unwrap() {
ObsMapIndex::Contiguous { start, end } => {
assert_eq!(*start, 2);
assert_eq!(*end, 4);
}
ObsMapIndex::Split(_) => panic!("expected Contiguous for traj 9"),
}
}
// With `ContiguousNight`, each night is indexed as a `[start, end)` range.
#[test]
fn night_id_contiguous_choice_builds_contiguous_index() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("night_id", DataType::UInt32, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(3);
let nights: ArrayRef = Arc::new(UInt32Array::from(vec![1u32, 1u32, 2u32]));
let mut cols = base.columns().to_vec();
cols.push(nights);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(
&[batch],
LoadObsArgs {
contiguous_choice: Some(ContiguousChoice::ContiguousNight),
..Default::default()
},
)
.unwrap();
let index = ds.index_ref();
let night_map = index.obs_index_by_night.as_ref().unwrap();
match night_map.get(&NightId(1)).unwrap() {
ObsMapIndex::Contiguous { start, end } => {
assert_eq!(*start, 0);
assert_eq!(*end, 2);
}
ObsMapIndex::Split(_) => panic!("expected Contiguous for night 1"),
}
match night_map.get(&NightId(2)).unwrap() {
ObsMapIndex::Contiguous { start, end } => {
assert_eq!(*start, 2);
assert_eq!(*end, 3);
}
ObsMapIndex::Split(_) => panic!("expected Contiguous for night 2"),
}
}
// With no contiguous choice, nights are indexed as explicit row lists.
#[test]
fn night_id_no_contiguous_choice_builds_split_index() {
let mut schema_fields = base_schema().fields().to_vec();
schema_fields.push(Arc::new(Field::new("night_id", DataType::UInt32, true)));
let schema = Arc::new(Schema::new(schema_fields));
let base = make_base_batch(3);
let nights: ArrayRef = Arc::new(UInt32Array::from(vec![1u32, 1u32, 2u32]));
let mut cols = base.columns().to_vec();
cols.push(nights);
let batch = RecordBatch::try_new(schema, cols).unwrap();
let ds = build_obs_dataset_from_batches(
&[batch],
LoadObsArgs {
contiguous_choice: None,
..Default::default()
},
)
.unwrap();
let index = ds.index_ref();
let night_map = index.obs_index_by_night.as_ref().unwrap();
match night_map.get(&NightId(1)).unwrap() {
ObsMapIndex::Split(v) => assert_eq!(v, &[0, 1]),
ObsMapIndex::Contiguous { .. } => panic!("expected Split without contiguous choice"),
}
match night_map.get(&NightId(2)).unwrap() {
ObsMapIndex::Split(v) => assert_eq!(v, &[2]),
ObsMapIndex::Contiguous { .. } => panic!("expected Split without contiguous choice"),
}
}
// End-to-end: write a one-row Parquet file to a temp dir and load it back
// through the blocking entry point using a `file://` URI.
#[test]
fn load_obs_sync_reads_local_parquet() {
use crate::io::datafusion::loader::load_obs_sync;
use arrow_array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use parquet::arrow::ArrowWriter;
use std::fs::File;
use tempfile::tempdir;
let dir = tempdir().unwrap();
let path = dir.path().join("obs.parquet");
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt64, false),
Field::new("ra", DataType::Float64, false),
Field::new("ra_err", DataType::Float64, false),
Field::new("dec", DataType::Float64, false),
Field::new("dec_err", DataType::Float64, false),
Field::new("magnitude", DataType::Float64, false),
Field::new("mag_err", DataType::Float64, false),
Field::new("filter", DataType::Utf8, false),
Field::new("mjd_tt", DataType::Float64, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt64Array::from(vec![42u64])) as ArrayRef,
Arc::new(Float64Array::from(vec![1.0f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.001f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.5f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.001f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![15.5f64])) as ArrayRef,
Arc::new(Float64Array::from(vec![0.02f64])) as ArrayRef,
Arc::new(StringArray::from(vec!["G"])) as ArrayRef,
Arc::new(Float64Array::from(vec![60000.0f64])) as ArrayRef,
],
)
.unwrap();
let file = File::create(&path).unwrap();
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
let uri = InputUri(format!("file://{}", path.display()));
let ds =
load_obs_sync(&uri, LoadObsArgs::default()).expect("should load from local parquet");
assert_eq!(ds.observation_count(), 1);
let obs: Vec<_> = ds.iter_observations().collect();
assert_eq!(*obs[0].id(), 42u64);
}
}