use crate::{CustomFamilyError, ParameterBlockSpec};
use ndarray::Array2;
use std::collections::BTreeMap;
/// Weight floor for custom families.
///
/// Defined as an alias of [`crate::types::MIN_WEIGHT`], so it always tracks the crate-wide
/// minimum weight. NOTE(review): presumably used to clamp working weights away from zero —
/// confirm at the call sites.
pub const CUSTOM_FAMILY_WEIGHT_FLOOR: f64 = crate::types::MIN_WEIGHT;
/// Ridge floor for custom families, fixed at `1e-12`.
///
/// NOTE(review): presumably a minimal diagonal ridge used to keep penalized systems
/// numerically well-posed — confirm at the call sites.
pub const CUSTOM_FAMILY_RIDGE_FLOOR: f64 = 1.0e-12;
/// Checks that a set of parameter-block specifications is consistent, both across blocks
/// and within each block.
///
/// Validation runs in two passes, and the first violation found is returned:
///
/// 1. Every block `name` must be unique across `specs`.
/// 2. For each block, in index order:
///    - `offset.len()` equals `design.nrows()`;
///    - `stacked_design` and `stacked_offset` are either both `None`, or both `Some` with
///      `stacked_design.nrows() == stacked_offset.len()` and
///      `stacked_design.ncols() == design.ncols()`;
///    - `initial_beta`, when present, has `design.ncols()` entries;
///    - `initial_log_lambdas` has exactly one entry per penalty;
///    - every penalty is `p x p`, where `p = design.ncols()`.
///
/// # Returns
///
/// The number of penalties attached to each block, in the same order as `specs`. Empty input
/// yields an empty `Vec`.
///
/// # Errors
///
/// Returns a [`CustomFamilyError::ConstraintViolation`] or a
/// [`CustomFamilyError::DimensionMismatch`], converted into a `String`:
/// - `ConstraintViolation` covers a duplicate block name, or only one of the two stacked
///   inputs being supplied.
/// - `DimensionMismatch` covers every size disagreement.
pub fn validate_blockspec_consistency(
    specs: &[ParameterBlockSpec],
) -> Result<Vec<usize>, String> {
    /// Fails on the first block whose name repeats that of an earlier block.
    fn ensure_unique_block_names(specs: &[ParameterBlockSpec]) -> Result<(), String> {
        let mut index_by_name = BTreeMap::<String, usize>::new();
        for (b, spec) in specs.iter().enumerate() {
            // `insert` hands back the index stored for this name earlier, if there was one.
            let Some(prev) = index_by_name.insert(spec.name.clone(), b) else {
                continue;
            };
            return Err(CustomFamilyError::ConstraintViolation {
                reason: format!(
                    "duplicate parameter block name '{}' at indices {prev} and {b}: block names must be unique so coefficient labels resolved by name are unambiguous",
                    spec.name
                ),
            }
            .into());
        }
        Ok(())
    }

    /// Runs every per-block dimension check for block `b`, stopping at the first failure.
    fn check_single_block(b: usize, spec: &ParameterBlockSpec) -> Result<(), String> {
        let (n, p) = (spec.design.nrows(), spec.design.ncols());

        let offset_len = spec.offset.len();
        if offset_len != n {
            return Err(CustomFamilyError::DimensionMismatch {
                reason: format!(
                    "block {b} offset length mismatch: got {}, expected {}",
                    offset_len, n
                ),
            }
            .into());
        }

        // The stacked inputs are optional, but they must be supplied as a pair.
        match (spec.stacked_design.as_ref(), spec.stacked_offset.as_ref()) {
            (None, None) => {}
            (Some(_), None) | (None, Some(_)) => {
                return Err(CustomFamilyError::ConstraintViolation {
                    reason: format!(
                        "block {b} stacked_design and stacked_offset must be Some together \
                         or both None"
                    ),
                }
                .into());
            }
            (Some(sd), Some(so)) => {
                let (stacked_rows, stacked_offset_len) = (sd.nrows(), so.len());
                if stacked_rows != stacked_offset_len {
                    return Err(CustomFamilyError::DimensionMismatch {
                        reason: format!(
                            "block {b} stacked_design/stacked_offset row mismatch: \
                             stacked_design.nrows()={}, stacked_offset.len()={}",
                            stacked_rows, stacked_offset_len,
                        ),
                    }
                    .into());
                }
                let stacked_cols = sd.ncols();
                if stacked_cols != p {
                    return Err(CustomFamilyError::DimensionMismatch {
                        reason: format!(
                            "block {b} stacked_design column count {} disagrees with \
                             design column count {}",
                            stacked_cols, p,
                        ),
                    }
                    .into());
                }
            }
        }

        if let Some(beta0) = spec.initial_beta.as_ref() {
            if beta0.len() != p {
                return Err(CustomFamilyError::DimensionMismatch {
                    reason: format!(
                        "block {b} initial_beta length mismatch: got {}, expected {p}",
                        beta0.len()
                    ),
                }
                .into());
            }
        }

        let penalty_count = spec.penalties.len();
        let log_lambda_count = spec.initial_log_lambdas.len();
        if log_lambda_count != penalty_count {
            return Err(CustomFamilyError::DimensionMismatch {
                reason: format!(
                    "block {b} initial_log_lambdas length {} does not match penalties {}",
                    log_lambda_count, penalty_count
                ),
            }
            .into());
        }

        for (k, s) in spec.penalties.iter().enumerate() {
            let (r, c) = s.shape();
            if r == p && c == p {
                continue;
            }
            return Err(CustomFamilyError::DimensionMismatch {
                reason: format!("block {b} penalty {k} must be {p}x{p}, got {r}x{c}"),
            }
            .into());
        }

        Ok(())
    }

    ensure_unique_block_names(specs)?;

    // `collect` into `Result` stops at the first failing block, just like an early return.
    specs
        .iter()
        .enumerate()
        .map(|(b, spec)| check_single_block(b, spec).map(|()| spec.penalties.len()))
        .collect()
}
/// Curvature quantities bundled together for an exact-Newton "outer" step.
///
/// NOTE(review): the field descriptions below are inferred from the field names only.
/// Confirm them against the code that builds and consumes this struct.
pub struct ExactNewtonOuterCurvature {
    /// Dense `f64` Hessian matrix. NOTE(review): presumably taken with respect to the outer
    /// (log-smoothing, `rho`) parameters — confirm.
    pub hessian: Array2<f64>,
    /// Scalar scale factor associated with curvature in the `rho` coordinates.
    /// NOTE(review): presumably `rho = log(lambda)` — confirm how callers apply it.
    pub rho_curvature_scale: f64,
    /// Scalar correction associated with the Hessian's log-determinant.
    /// NOTE(review): confirm whether callers add it to, or subtract it from, the logdet term.
    pub hessian_logdet_correction: f64,
}