use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::unit_ref::EtlUnitRef;
use super::bindings::CodomainBinding;
use super::source_context::{SourceContext, SourceKey};
/// One column carried by a [`SourceJoin`]: a unit reference paired with the
/// codomain binding used for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JoinColumn {
    /// Reference to the unit this column belongs to.
    pub unit: EtlUnitRef,
    /// Binding for this column. Its `join_null_fill` is what
    /// [`JoinColumn::has_join_fill`] inspects.
    pub binding: CodomainBinding,
}
impl JoinColumn {
pub fn new(unit: EtlUnitRef, binding: CodomainBinding) -> Self {
Self { unit, binding }
}
pub fn has_join_fill(&self) -> bool {
self.binding.join_null_fill.is_some()
}
}
/// Which key set a [`SourceJoin`] uses.
///
/// Serialized in `snake_case`, i.e. as `"subject_time"` and `"subject"`.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum JoinKeys {
    /// Presumably joins on both the subject and the time column.
    /// NOTE(review): confirm against the code that executes the join.
    SubjectTime,
    /// Presumably joins on the subject column only.
    /// NOTE(review): confirm against the code that executes the join.
    Subject,
}
/// Optional signal settings attached to a [`SourceJoin`].
///
/// Both fields are optional. A `None` field is omitted when serializing, and
/// a field that is absent from the input deserializes to `None`.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct GroupSignalConfig {
    /// Optional time-to-live. The `_ms` suffix suggests milliseconds.
    /// NOTE(review): confirm the unit and exact semantics with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ttl_ms: Option<i64>,
    /// Optional sample rate. The `_ms` suffix suggests milliseconds.
    /// NOTE(review): confirm the unit and exact semantics with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sample_rate_ms: Option<i64>,
}
impl GroupSignalConfig {
pub fn new(ttl_ms: Option<i64>, sample_rate_ms: Option<i64>) -> Self {
Self {
ttl_ms,
sample_rate_ms,
}
}
}
/// A single join against one right-hand source, together with every column
/// taken from that source.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SourceJoin {
    /// Shared handle to the context of the source being joined in.
    pub right_source: Arc<SourceContext>,
    /// Key set used for the join.
    pub keys: JoinKeys,
    /// Optional signal settings for this join.
    pub signal_config: GroupSignalConfig,
    /// Columns carried by this join.
    pub columns: Vec<JoinColumn>,
}
impl SourceJoin {
pub fn new(
right_source: Arc<SourceContext>,
keys: JoinKeys,
signal_config: GroupSignalConfig,
columns: Vec<JoinColumn>,
) -> Self {
Self {
right_source,
keys,
signal_config,
columns,
}
}
pub fn right_source_key(&self) -> SourceKey {
self.right_source.source_key
}
pub fn columns_with_join_fills(&self) -> impl Iterator<Item = &JoinColumn> {
self.columns.iter().filter(|c| c.has_join_fill())
}
}
/// An ordered collection of [`SourceJoin`]s.
///
/// The `Default` value is a plan with no joins.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct JoinPlan {
    /// The joins making up the plan, in order.
    pub joins: Vec<SourceJoin>,
}
impl JoinPlan {
    /// Returns a plan with no joins (same as the `Default` value).
    pub fn empty() -> Self {
        Default::default()
    }

    /// Number of joins in the plan.
    pub fn op_count(&self) -> usize {
        let joins = &self.joins;
        joins.len()
    }

    /// Total number of columns across all joins in the plan.
    pub fn column_count(&self) -> usize {
        self.joins
            .iter()
            .fold(0, |total, join| total + join.columns.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plan::bindings::{CodomainBinding, ColumnBinding};
    use crate::plan::source_context::SourceMember;
    use crate::unit::NullValue;
    use crate::universe::measurement_storage::DataSourceName;

    // ---- small builders shared by the tests below ----

    /// Measurement column with no join null-fill.
    fn measurement_col(source_col: &'static str, unit: &'static str) -> JoinColumn {
        JoinColumn::new(
            EtlUnitRef::measurement(unit),
            CodomainBinding::new(source_col, unit),
        )
    }

    /// Measurement column whose binding carries the given join null-fill.
    fn filled_measurement_col(
        source_col: &'static str,
        unit: &'static str,
        fill: NullValue,
    ) -> JoinColumn {
        JoinColumn::new(
            EtlUnitRef::measurement(unit),
            CodomainBinding::new(source_col, unit).with_join_null_fill(fill),
        )
    }

    /// Source member for a measurement unit with no join null-fill.
    fn measurement_member(source_col: &'static str, unit: &'static str) -> SourceMember {
        SourceMember::new(
            EtlUnitRef::measurement(unit),
            CodomainBinding::new(source_col, unit),
        )
    }

    /// Signal config with neither setting present.
    fn no_signal() -> GroupSignalConfig {
        GroupSignalConfig::new(None, None)
    }

    // ---- source fixtures ----

    fn scada_source() -> Arc<SourceContext> {
        let context = SourceContext {
            source_name: DataSourceName::new("scada"),
            source_key: SourceKey::from_raw(0xA1),
            subject: ColumnBinding::new("station_id", "station_name"),
            time: ColumnBinding::new("obs_time", "timestamp"),
            components: Vec::new(),
            members: vec![
                measurement_member("sump_reading", "sump"),
                measurement_member("discharge_reading", "discharge"),
            ],
        };
        Arc::new(context)
    }

    fn mrms_source() -> Arc<SourceContext> {
        let context = SourceContext {
            source_name: DataSourceName::new("mrms"),
            source_key: SourceKey::from_raw(0xB2),
            subject: ColumnBinding::identity("station_name"),
            time: ColumnBinding::identity("timestamp"),
            components: Vec::new(),
            members: vec![SourceMember::new(
                EtlUnitRef::measurement("historical_precip"),
                CodomainBinding::new("value_mm", "historical_precip")
                    .with_join_null_fill(NullValue::Float(0.0)),
            )],
        };
        Arc::new(context)
    }

    fn quality_source() -> Arc<SourceContext> {
        let context = SourceContext {
            source_name: DataSourceName::new("scada_qualities"),
            source_key: SourceKey::from_raw(0xC3),
            subject: ColumnBinding::identity("station_name"),
            time: ColumnBinding::identity("timestamp"),
            components: Vec::new(),
            members: vec![SourceMember::new(
                EtlUnitRef::quality("station_label"),
                CodomainBinding::new("display_name", "station_label"),
            )],
        };
        Arc::new(context)
    }

    // ---- tests ----

    #[test]
    fn join_column_has_join_fill_reflects_binding() {
        let plain = measurement_col("sump_reading", "sump");
        assert!(!plain.has_join_fill());

        let filled =
            filled_measurement_col("engine_on", "engines_on_count", NullValue::Integer(0));
        assert!(filled.has_join_fill());
    }

    #[test]
    fn one_source_join_carries_all_its_columns() {
        let source = scada_source();
        // Read the key before the Arc is moved into the join.
        let expected_key = source.source_key;

        let join = SourceJoin::new(
            source,
            JoinKeys::SubjectTime,
            no_signal(),
            vec![
                measurement_col("sump_reading", "sump"),
                measurement_col("discharge_reading", "discharge"),
            ],
        );

        assert_eq!(join.right_source_key(), expected_key);
        assert_eq!(join.columns.len(), 2);
        assert_eq!(join.keys, JoinKeys::SubjectTime);
    }

    #[test]
    fn columns_with_join_fills_filters_correctly() {
        let columns = vec![
            filled_measurement_col("value_mm", "historical_precip", NullValue::Float(0.0)),
            measurement_col("d", "dummy"),
        ];
        let join = SourceJoin::new(mrms_source(), JoinKeys::SubjectTime, no_signal(), columns);

        let filled: Vec<&JoinColumn> = join.columns_with_join_fills().collect();
        assert_eq!(filled.len(), 1);
        assert_eq!(filled[0].unit.as_str(), "historical_precip");
    }

    #[test]
    fn quality_join_uses_subject_only_keys() {
        let label_column = JoinColumn::new(
            EtlUnitRef::quality("station_label"),
            CodomainBinding::new("display_name", "station_label"),
        );
        let join = SourceJoin::new(
            quality_source(),
            JoinKeys::Subject,
            no_signal(),
            vec![label_column],
        );

        assert_eq!(join.keys, JoinKeys::Subject);
    }

    #[test]
    fn empty_plan_counts() {
        let empty = JoinPlan::empty();
        assert_eq!(empty.op_count(), 0);
        assert_eq!(empty.column_count(), 0);
    }

    #[test]
    fn op_count_vs_column_count_is_the_optimization_metric() {
        let scada_join = SourceJoin::new(
            scada_source(),
            JoinKeys::SubjectTime,
            no_signal(),
            vec![
                measurement_col("sump_reading", "sump"),
                measurement_col("discharge_reading", "discharge"),
            ],
        );
        let mrms_join = SourceJoin::new(
            mrms_source(),
            JoinKeys::SubjectTime,
            no_signal(),
            vec![measurement_col("value_mm", "historical_precip")],
        );
        let plan = JoinPlan {
            joins: vec![scada_join, mrms_join],
        };

        // Two joins carry three columns in total.
        assert_eq!(plan.op_count(), 2);
        assert_eq!(plan.column_count(), 3);
    }

    #[test]
    fn serde_roundtrip_join_plan() {
        let original = JoinPlan {
            joins: vec![SourceJoin::new(
                scada_source(),
                JoinKeys::SubjectTime,
                no_signal(),
                vec![measurement_col("sump_reading", "sump")],
            )],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: JoinPlan = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.op_count(), 1);
        assert_eq!(decoded.column_count(), 1);
    }
}