use crate::warm_start::key::{Fingerprint, Fingerprinter};
use serde::{Deserialize, Serialize};
/// Schema version stamped on fit artifacts. `FitArtifact::is_usable` rejects
/// any artifact whose `schema` field differs from this value.
pub(crate) const FIT_ARTIFACT_SCHEMA: u32 = 1;
/// Saturation bound associated with rho values.
// NOTE(review): not referenced in the visible part of this file; presumably a
// clamp/saturation threshold applied to rho elsewhere in the crate — confirm
// at the use sites.
pub(crate) const RHO_SATURATION: f64 = 9.0;
/// Role tag attached to a model block/term.
///
/// The role's numeric discriminant is absorbed into the block identity
/// fingerprint by `term_identity_from_block`, so two blocks that differ only
/// in role get different identities. A role can be guessed from a block name
/// with `TermRole::from_block_name`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TermRole {
/// Chosen by `from_block_name` for names containing "mean", "loc" or
/// "marginal" (when no log-slope keyword matches first).
Mean,
/// Chosen by `from_block_name` for names containing "logslope",
/// "log_slope", "scale", "sigma", "dispersion" or "disp".
LogSlope,
/// Fallback when no keyword matches.
Generic,
}
impl TermRole {
fn discriminant(self) -> u8 {
match self {
TermRole::Mean => 0,
TermRole::LogSlope => 1,
TermRole::Generic => 2,
}
}
pub fn from_block_name(name: &str) -> TermRole {
let lower = name.to_ascii_lowercase();
if lower.contains("logslope")
|| lower.contains("log_slope")
|| lower.contains("scale")
|| lower.contains("sigma")
|| lower.contains("dispersion")
|| lower.contains("disp")
{
TermRole::LogSlope
} else if lower.contains("mean") || lower.contains("loc") || lower.contains("marginal") {
TermRole::Mean
} else {
TermRole::Generic
}
}
}
/// Newtype around a `Fingerprint` that identifies one term/block.
///
/// Produced by `term_identity_from_block`; `FitDescriptor::descriptor_key`
/// sorts the raw bytes of these keys before absorbing them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TermIdentityKey(pub Fingerprint);
/// Fingerprints the structural identity of a block.
///
/// The following are absorbed, in this order, under the domain tag
/// `fit-artifact-block-identity-v2`:
///
/// 1. the role discriminant,
/// 2. the block name,
/// 3. the number of penalties, then one entry per penalty (its label, or a
///    `label-none` tag when the label is absent),
/// 4. the number of nullspace entries, then each nullspace dimension,
/// 5. the reduced width.
///
/// Changing any of these inputs changes the absorbed transcript; identical
/// inputs always produce the same transcript.
pub fn term_identity_from_block(
    role: TermRole,
    block_name: &str,
    precision_labels: &[Option<String>],
    nullspace_dims: &[usize],
    reduced_width: usize,
) -> TermIdentityKey {
    let mut hasher = Fingerprinter::new();
    hasher.absorb_tag(b"fit-artifact-block-identity-v2");
    hasher.absorb_u64(b"role", u64::from(role.discriminant()));
    hasher.absorb_str(b"block_name", block_name);

    // Length prefixes keep the variable-length sections unambiguous.
    let penalty_count = precision_labels.len() as u64;
    hasher.absorb_u64(b"n_penalties", penalty_count);
    for maybe_label in precision_labels.iter() {
        if let Some(text) = maybe_label {
            hasher.absorb_str(b"label", text);
        } else {
            hasher.absorb_tag(b"label-none");
        }
    }

    let nullspace_count = nullspace_dims.len() as u64;
    hasher.absorb_u64(b"n_nullspace", nullspace_count);
    for &dim in nullspace_dims.iter() {
        hasher.absorb_u64(b"nullspace_dim", dim as u64);
    }

    hasher.absorb_u64(b"reduced_width", reduced_width as u64);

    let digest = hasher.finalize();
    TermIdentityKey(digest)
}
/// Signature of the response side of a fit.
///
/// Stored on `FitDescriptor` but not absorbed by
/// `FitDescriptor::descriptor_key`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResponseSig {
/// Family identifier as a free-form string (tests in this file use
/// "gaussian").
pub family_kind: String,
/// Number of response channels (tests in this file use 1).
pub n_response_channels: usize,
}
/// Tag describing which rows a fit was run on.
///
/// Stored on `FitDescriptor` but not absorbed by
/// `FitDescriptor::descriptor_key`, so descriptors that differ only in this
/// tag share a key.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RowPopulationTag {
/// Row count of the population.
pub n_rows: usize,
/// Optional free-form label (tests in this file use "full" and "fold-3").
pub label: Option<String>,
}
/// Description of a fit: family, term identities, response signature and an
/// optional row-population tag.
///
/// Only `family_kind` and `term_identities` enter `descriptor_key`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FitDescriptor {
/// Family identifier; absorbed into the descriptor key.
pub family_kind: String,
/// Identities of the fit's terms; absorbed into the descriptor key in
/// sorted order, so their order here does not affect the key.
pub term_identities: Vec<TermIdentityKey>,
/// Response signature; not part of the descriptor key.
pub response_signature: ResponseSig,
/// Optional row-population tag; not part of the descriptor key.
pub row_population: Option<RowPopulationTag>,
}
impl FitDescriptor {
    /// Fingerprint of the descriptor under the domain tag
    /// `fit-artifact-descriptor-v1`.
    ///
    /// Absorbs `family_kind`, the number of terms, and the 32-byte term
    /// identities in ascending byte order, so the key does not depend on the
    /// order of `term_identities`. `response_signature` and `row_population`
    /// are not absorbed.
    pub fn descriptor_key(&self) -> Fingerprint {
        let mut hasher = Fingerprinter::new();
        hasher.absorb_tag(b"fit-artifact-descriptor-v1");
        hasher.absorb_str(b"family_kind", &self.family_kind);

        // Copy the raw digests out and sort them to canonicalize term order.
        let mut sorted_terms: Vec<[u8; 32]> = Vec::with_capacity(self.term_identities.len());
        sorted_terms.extend(
            self.term_identities
                .iter()
                .map(|identity| *identity.0.as_bytes()),
        );
        sorted_terms.sort_unstable();

        hasher.absorb_u64(b"n_terms", sorted_terms.len() as u64);
        for digest in sorted_terms.iter() {
            hasher.absorb_bytes(b"term", digest);
        }

        hasher.finalize()
    }
}
/// Per-term payload stored inside a `FitArtifact`.
///
/// `is_finite` checks `raw_beta`, `rho_for_term` and, when present, the
/// entries of `joint_null_rotation`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TermArtifact {
/// Identity key of the term this payload belongs to.
pub identity: TermIdentityKey,
/// Role of the term.
pub role: TermRole,
/// Serializable description of the term's basis.
pub basis_meta: SerializableBasisMeta,
/// Optional rotation matrix; `None` is accepted by `is_finite`.
// NOTE(review): presumably the joint-nullspace rotation used when the term
// was fitted — confirm against the producer.
pub joint_null_rotation: Option<SerializableMatrix>,
/// Coefficient values for the term; every entry must be finite for the
/// artifact to be usable.
pub raw_beta: Vec<f64>,
/// Rho values for the term; every entry must be finite for the artifact to
/// be usable.
// NOTE(review): presumably one log smoothing parameter per penalty —
// confirm against the producer.
pub rho_for_term: Vec<f64>,
}
impl TermArtifact {
    /// Returns `true` when every stored number is finite (no NaN, no ±∞).
    ///
    /// Covers `raw_beta`, `rho_for_term` and the entries of
    /// `joint_null_rotation` when a rotation is present. A missing rotation
    /// and empty vectors both count as finite.
    pub fn is_finite(&self) -> bool {
        fn all_finite(values: &[f64]) -> bool {
            values.iter().all(|value| value.is_finite())
        }

        let rotation_ok = match &self.joint_null_rotation {
            Some(rotation) => all_finite(&rotation.data),
            None => true,
        };

        all_finite(&self.raw_beta) && all_finite(&self.rho_for_term) && rotation_ok
    }
}
/// Dense `f64` matrix in a serde-friendly flat form.
///
/// Nothing in this file checks that `data.len() == nrows * ncols`.
// NOTE(review): the test in this file fills `data` from `Array2::iter()`;
// presumably consumers expect that same element order — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerializableMatrix {
/// Number of rows.
pub nrows: usize,
/// Number of columns.
pub ncols: usize,
/// Flattened matrix entries.
pub data: Vec<f64>,
}
/// Serializable summary of a term's basis.
///
/// All numeric fields are optional; the test stub in this file uses
/// `kind = "block-spec"` and sets only `n_centers`.
// NOTE(review): which optional fields are meaningful presumably depends on
// `kind` — confirm against the code that builds this struct.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SerializableBasisMeta {
/// Free-form basis kind identifier.
pub kind: String,
/// Basis degree, when applicable.
pub degree: Option<u64>,
/// Number of knots, when applicable.
pub num_knots: Option<u64>,
/// Number of centers, when applicable.
pub n_centers: Option<u64>,
/// Nullspace order, when applicable.
pub nullspace_order: Option<u64>,
/// Matérn `nu` stored as an integer, when applicable.
// NOTE(review): the integer encoding of nu is not visible here — confirm.
pub matern_nu: Option<u64>,
/// Whether the basis is periodic.
pub periodic: bool,
}
/// Fit-wide summary stored on a `FitArtifact`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GlobalFitSummary {
/// Value of the outer objective; `FitArtifact::is_usable` requires it to be
/// finite.
pub outer_objective: f64,
/// Convergence flag. Not consulted by `FitArtifact::is_usable`.
pub converged: bool,
/// Row count of the fit.
pub n_rows: usize,
}
/// How a warm start was obtained from a stored artifact.
// NOTE(review): this enum is not used in the visible part of the file; the
// variant descriptions below are inferred from the names — confirm against the
// code that produces these values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransferProvenance {
/// Presumably: stored state was projected onto the new fit.
Projected,
/// Presumably: only rho values were transferred.
RhoOnly,
/// Presumably: nothing was transferred (cold start).
Cold,
}
/// Serializable record of a completed fit: descriptor, per-term payloads and a
/// global summary.
///
/// Use `is_usable` to check the schema version and numeric sanity before
/// relying on the contents.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FitArtifact {
/// Schema version; must equal `FIT_ARTIFACT_SCHEMA` for the artifact to be
/// usable.
pub schema: u32,
/// Creation time in seconds.
// NOTE(review): presumably seconds since the Unix epoch, per the field
// name — confirm against the writer.
pub created_unix_secs: u64,
/// Descriptor of the fit that produced this artifact.
pub descriptor: FitDescriptor,
/// Per-term payloads.
pub terms: Vec<TermArtifact>,
/// Fit-wide summary.
pub global: GlobalFitSummary,
}
impl FitArtifact {
    /// Returns `true` when the artifact can be consumed.
    ///
    /// All of the following must hold:
    /// - `schema` equals `FIT_ARTIFACT_SCHEMA`,
    /// - the global outer objective is finite,
    /// - every term passes `TermArtifact::is_finite`.
    pub fn is_usable(&self) -> bool {
        if self.schema != FIT_ARTIFACT_SCHEMA {
            return false;
        }
        if !self.global.outer_objective.is_finite() {
            return false;
        }
        self.terms.iter().all(|term| term.is_finite())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Identity of a block with one unlabeled penalty, one nullspace entry of
    /// dimension 1 and reduced width 10.
    fn block_id(role: TermRole, block_name: &str) -> TermIdentityKey {
        term_identity_from_block(role, block_name, &[None], &[1], 10)
    }

    /// Minimal basis metadata with only `n_centers` populated.
    fn basis_meta_stub(n_centers: u64) -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(n_centers),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Single-channel gaussian descriptor over the given terms.
    fn gaussian_descriptor(
        term_identities: Vec<TermIdentityKey>,
        row_population: Option<RowPopulationTag>,
    ) -> FitDescriptor {
        FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities,
            response_signature: ResponseSig {
                family_kind: "gaussian".to_string(),
                n_response_channels: 1,
            },
            row_population,
        }
    }

    /// Row-population tag with the given row count and label.
    fn row_tag(n_rows: usize, label: &str) -> RowPopulationTag {
        RowPopulationTag {
            n_rows,
            label: Some(label.to_string()),
        }
    }

    #[test]
    fn role_is_inferred_from_block_name() {
        assert_eq!(TermRole::from_block_name("LogSlope"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("log_sigma"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("Mean"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("marginal"), TermRole::Mean);
        // Log-slope keywords are checked before mean keywords.
        assert_eq!(TermRole::from_block_name("scale_mean"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("s(x)"), TermRole::Generic);
    }

    #[test]
    fn block_identity_splits_on_block_name() {
        let ka = block_id(TermRole::Mean, "s(x)");
        let kb = block_id(TermRole::Mean, "s(z)");
        assert_ne!(ka, kb, "different block name must split identity");
    }

    #[test]
    fn block_identity_splits_on_role() {
        let mean = block_id(TermRole::Mean, "s(x)");
        let slope = block_id(TermRole::LogSlope, "s(x)");
        assert_ne!(mean, slope, "different role must split identity");
    }

    #[test]
    fn block_identity_splits_on_penalty_structure() {
        let one = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 10);
        let two = term_identity_from_block(TermRole::Mean, "s(x)", &[None, None], &[1], 10);
        assert_ne!(one, two, "different #penalties must split identity");

        let labeled = term_identity_from_block(
            TermRole::Mean,
            "s(x)",
            &[Some("ridge".to_string())],
            &[1],
            10,
        );
        assert_ne!(one, labeled, "a penalty label must split identity");

        let wider_null = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[2], 10);
        assert_ne!(
            one, wider_null,
            "different nullspace dimension must split identity"
        );
    }

    #[test]
    fn block_identity_splits_on_reduced_width() {
        let wide = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let narrow = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 21);
        assert_ne!(
            wide, narrow,
            "different realized reduced width must split identity"
        );
    }

    #[test]
    fn block_identity_matches_across_folds_at_equal_width() {
        let fold_a = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let fold_b = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        assert_eq!(
            fold_a, fold_b,
            "same model at equal width must share identity across folds"
        );
    }

    #[test]
    fn descriptor_key_excludes_rows_and_response() {
        let id = block_id(TermRole::Mean, "s(x)");
        let full = gaussian_descriptor(vec![id], Some(row_tag(1000, "full")));
        let fold = gaussian_descriptor(vec![id], Some(row_tag(900, "fold-3")));
        assert_eq!(
            full.descriptor_key(),
            fold.descriptor_key(),
            "LOSO fold must share its full-data parent's descriptor key"
        );

        // The response signature is deliberately left out of the key as well,
        // so this half of the test actually varies it.
        let mut other_response = gaussian_descriptor(vec![id], Some(row_tag(1000, "full")));
        other_response.response_signature.n_response_channels = 2;
        assert_eq!(
            full.descriptor_key(),
            other_response.descriptor_key(),
            "response signature must not enter the descriptor key"
        );
    }

    #[test]
    fn descriptor_key_splits_on_family_kind() {
        let id = block_id(TermRole::Mean, "s(x)");
        let gaussian = gaussian_descriptor(vec![id], None);
        let mut binomial = gaussian.clone();
        binomial.family_kind = "binomial".to_string();
        assert_ne!(
            gaussian.descriptor_key(),
            binomial.descriptor_key(),
            "different family must split the descriptor key"
        );
    }

    #[test]
    fn descriptor_key_invariant_to_term_order() {
        let a = block_id(TermRole::Mean, "s(x)");
        let b = block_id(TermRole::Mean, "s(z)");
        let d1 = gaussian_descriptor(vec![a, b], None);
        let d2 = gaussian_descriptor(vec![b, a], None);
        assert_eq!(d1.descriptor_key(), d2.descriptor_key());
    }

    #[test]
    fn artifact_usable_guard_rejects_nonfinite() {
        let id = block_id(TermRole::Mean, "s(x)");
        let mut artifact = FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: gaussian_descriptor(vec![id], None),
            terms: vec![TermArtifact {
                identity: id,
                role: TermRole::Mean,
                basis_meta: basis_meta_stub(4),
                joint_null_rotation: None,
                raw_beta: vec![0.1, 0.2, 0.3, 0.4],
                rho_for_term: vec![1.0],
            }],
            global: GlobalFitSummary {
                outer_objective: -123.4,
                converged: true,
                n_rows: 100,
            },
        };
        assert!(artifact.is_usable());

        artifact.terms[0].raw_beta[2] = f64::NAN;
        assert!(
            !artifact.is_usable(),
            "non-finite β must fail the usable guard"
        );
        artifact.terms[0].raw_beta[2] = 0.3;

        artifact.global.outer_objective = f64::INFINITY;
        assert!(
            !artifact.is_usable(),
            "non-finite objective must fail the usable guard"
        );
        artifact.global.outer_objective = -123.4;

        artifact.terms[0].rho_for_term[0] = f64::NEG_INFINITY;
        assert!(
            !artifact.is_usable(),
            "non-finite rho must fail the usable guard"
        );
        artifact.terms[0].rho_for_term[0] = 1.0;

        artifact.terms[0].joint_null_rotation = Some(SerializableMatrix {
            nrows: 1,
            ncols: 1,
            data: vec![f64::NAN],
        });
        assert!(
            !artifact.is_usable(),
            "non-finite rotation entry must fail the usable guard"
        );
        artifact.terms[0].joint_null_rotation = None;

        artifact.schema = FIT_ARTIFACT_SCHEMA + 1;
        assert!(
            !artifact.is_usable(),
            "schema mismatch must fail the usable guard"
        );
        artifact.schema = FIT_ARTIFACT_SCHEMA;

        // Every mutation was reverted, so the artifact is usable again.
        assert!(artifact.is_usable());
    }

    #[test]
    fn serializable_basis_meta_roundtrips() {
        let meta = basis_meta_stub(7);
        let bytes = serde_json::to_vec(&meta).expect("serialize");
        let back: SerializableBasisMeta = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(meta, back);
        assert_eq!(back.n_centers, Some(7));
        assert_eq!(back.kind, "block-spec");
    }

    #[test]
    fn serializable_matrix_can_carry_rotation() {
        let q = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let m = SerializableMatrix {
            nrows: q.nrows(),
            ncols: q.ncols(),
            data: q.iter().copied().collect(),
        };
        let bytes = serde_json::to_vec(&m).expect("serialize");
        let back: SerializableMatrix = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(back.nrows, 2);
        assert_eq!(back.data, vec![1.0, 0.0, 0.0, 1.0]);
    }
}