use crate::inference::row_measure::{RowMeasure, per_row_fisher_mass};
use crate::inference::row_metric::{MetricProvenance, RowMetric};
/// Fisher-enriched subset of corpus rows held by a `TieredHarvest`.
///
/// The three fields are index-aligned: entry `t` of `rows`, entry `t` of
/// `inclusion`, and row `t` of `metric` all describe the same tier row.
/// `TieredHarvest::with_designed_tier` checks the lengths agree before
/// building one of these.
struct FisherTier {
/// Corpus row indices covered by the tier; validated strictly ascending and
/// `< n_rows`, which is what lets `tier_row_for` binary-search them.
rows: Vec<usize>,
/// Inclusion probability of each tier row, validated finite and in `(0, 1]`;
/// `corpus_measure` divides each row's Fisher mass by this value.
inclusion: Vec<f64>,
/// Tier metric with one row per entry of `rows`.
metric: RowMetric,
}
/// A corpus of `n_rows` rows, optionally paired with a Fisher tier that
/// covers a validated subset of those rows.
///
/// Without a tier, `corpus_measure` returns a uniform measure over the corpus.
pub struct TieredHarvest {
/// Total number of corpus rows; every tier row index is below this.
n_rows: usize,
/// The Fisher tier, or `None` for a harvest built with `activations_only`.
fisher: Option<FisherTier>,
}
impl TieredHarvest {
    /// Creates a harvest with no Fisher tier.
    ///
    /// Every tier query on the result reports "no tier", and
    /// [`corpus_measure`](Self::corpus_measure) is uniform over `n_rows`.
    pub fn activations_only(n_rows: usize) -> Self {
        Self { fisher: None, n_rows }
    }

    /// Creates a harvest whose Fisher tier was drawn with known per-row
    /// inclusion probabilities.
    ///
    /// Entry `t` of `tier_rows`, entry `t` of `inclusion`, and row `t` of
    /// `metric` all describe the same tier row.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` when:
    /// - `metric.n_rows()` differs from `tier_rows.len()`;
    /// - `inclusion.len()` differs from `tier_rows.len()`;
    /// - a tier row is `>= n_rows`;
    /// - `tier_rows` is not strictly ascending (unsorted or duplicated);
    /// - an inclusion probability is not a finite value in `(0, 1]`.
    ///
    /// The two length checks run first. Each tier row is then checked for
    /// range and then for ordering against its predecessor. The inclusion
    /// probabilities are checked last.
    pub fn with_designed_tier(
        n_rows: usize,
        tier_rows: Vec<usize>,
        inclusion: Vec<f64>,
        metric: RowMetric,
    ) -> Result<Self, String> {
        let tier_len = tier_rows.len();
        let metric_len = metric.n_rows();
        if metric_len != tier_len {
            return Err(format!(
                "TieredHarvest: metric covers {} rows but the tier names {}",
                metric_len, tier_len
            ));
        }
        if inclusion.len() != tier_len {
            return Err(format!(
                "TieredHarvest: {} inclusion probabilities for {} tier rows",
                inclusion.len(),
                tier_len
            ));
        }

        // Bound-check each row before comparing it with its predecessor, so an
        // out-of-range row is reported as such rather than as an ordering error.
        let mut previous: Option<usize> = None;
        for &row in &tier_rows {
            if row >= n_rows {
                return Err(format!(
                    "TieredHarvest: tier row {row} out of corpus range (n_rows = {n_rows})"
                ));
            }
            if matches!(previous, Some(prev) if prev >= row) {
                return Err(
                    "TieredHarvest: tier rows must be strictly ascending (sorted, deduplicated)"
                        .to_string(),
                );
            }
            previous = Some(row);
        }

        // NaN and the infinities fail `is_finite`; everything else must lie in
        // the half-open interval (0, 1].
        let is_probability = |p: f64| p.is_finite() && p > 0.0 && p <= 1.0;
        let first_invalid = tier_rows
            .iter()
            .zip(&inclusion)
            .find(|&(_, &p)| !is_probability(p));
        if let Some((&row, &p)) = first_invalid {
            return Err(format!(
                "TieredHarvest: tier row {} has invalid inclusion probability {p}",
                row
            ));
        }

        Ok(Self {
            n_rows,
            fisher: Some(FisherTier {
                rows: tier_rows,
                inclusion,
                metric,
            }),
        })
    }

    /// Creates a harvest whose tier rows were all included with certainty.
    ///
    /// Every inclusion probability is `1.0`, so the design correction in
    /// [`corpus_measure`](Self::corpus_measure) leaves the masses unchanged.
    ///
    /// # Errors
    ///
    /// Same as [`with_designed_tier`](Self::with_designed_tier).
    pub fn with_unweighted_tier(
        n_rows: usize,
        tier_rows: Vec<usize>,
        metric: RowMetric,
    ) -> Result<Self, String> {
        let certain = vec![1.0_f64; tier_rows.len()];
        Self::with_designed_tier(n_rows, tier_rows, certain, metric)
    }

    /// Total number of corpus rows.
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Whether this harvest carries a Fisher tier.
    pub fn has_fisher_tier(&self) -> bool {
        matches!(self.fisher, Some(_))
    }

    /// Fraction of corpus rows covered by the Fisher tier.
    ///
    /// Returns `0.0` when there is no tier or the corpus is empty.
    pub fn coverage(&self) -> f64 {
        match &self.fisher {
            Some(tier) if self.n_rows > 0 => tier.rows.len() as f64 / self.n_rows as f64,
            _ => 0.0,
        }
    }

    /// Corpus rows covered by the tier, in strictly ascending order.
    ///
    /// The slice is empty when there is no tier.
    pub fn tier_rows(&self) -> &[usize] {
        match &self.fisher {
            Some(tier) => tier.rows.as_slice(),
            None => &[],
        }
    }

    /// The tier's metric, or `None` when there is no tier.
    pub fn tier_metric(&self) -> Option<&RowMetric> {
        Some(&self.fisher.as_ref()?.metric)
    }

    /// Provenance of the tier's metric, or `None` when there is no tier.
    pub fn tier_provenance(&self) -> Option<MetricProvenance> {
        self.tier_metric().map(|metric| metric.provenance())
    }

    /// Position of `corpus_row` within the tier, which is also its row in the
    /// tier metric.
    ///
    /// Returns `None` when there is no tier or the row is not covered by it.
    pub fn tier_row_for(&self, corpus_row: usize) -> Option<usize> {
        self.fisher
            .as_ref()
            .and_then(|tier| tier.rows.binary_search(&corpus_row).ok())
    }

    /// Whether `corpus_row` is covered by the Fisher tier.
    pub fn has_factors(&self, corpus_row: usize) -> bool {
        matches!(self.tier_row_for(corpus_row), Some(_))
    }

    /// Lifts the tier's Fisher masses to a measure over all `n_rows` corpus rows.
    ///
    /// Each tier row's mass comes from `per_row_fisher_mass`, divided by that
    /// row's inclusion probability. This is the HT-style design correction;
    /// non-positive masses contribute `0.0`. Every off-tier row gets the mean
    /// of the corrected tier masses. The result carries the tier metric's
    /// provenance.
    ///
    /// Falls back to `RowMeasure::uniform(n_rows)` in any of these cases:
    /// - there is no tier;
    /// - the corpus is empty;
    /// - any tier mass is non-finite;
    /// - the corrected masses do not sum to a positive value.
    pub fn corpus_measure(&self) -> RowMeasure {
        let tier = match &self.fisher {
            Some(tier) if self.n_rows > 0 => tier,
            _ => return RowMeasure::uniform(self.n_rows),
        };

        let raw_mass = per_row_fisher_mass(&tier.metric);
        let mut lifted = vec![0.0_f64; tier.rows.len()];
        let mut lifted_total = 0.0_f64;
        for (t, &mass) in raw_mass.iter().enumerate() {
            // A single non-finite mass poisons the whole tier.
            if !mass.is_finite() {
                return RowMeasure::uniform(self.n_rows);
            }
            let value = if mass > 0.0 { mass / tier.inclusion[t] } else { 0.0 };
            lifted[t] = value;
            lifted_total += value;
        }
        // An all-zero tier carries no signal.
        if lifted_total.is_nan() || lifted_total <= 0.0 {
            return RowMeasure::uniform(self.n_rows);
        }

        // Impute the mean corrected tier mass everywhere, then overwrite the
        // rows the tier actually covers.
        let fill = lifted_total / tier.rows.len() as f64;
        let mut masses = vec![fill; self.n_rows];
        for (&row, &value) in tier.rows.iter().zip(&lifted) {
            masses[row] = value;
        }
        RowMeasure::from_masses(tier.metric.provenance(), masses)
    }

    /// Plans the next tier by drawing `budget` rows from
    /// [`corpus_measure`](Self::corpus_measure) using `seed`.
    ///
    /// The draw itself is delegated to `RowMeasure::designed_subsample`.
    pub fn plan_next_tier(
        &self,
        budget: usize,
        seed: u64,
    ) -> crate::inference::row_measure::DesignedRowSample {
        let measure = self.corpus_measure();
        measure.designed_subsample(budget, seed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::row_measure::MeasureProvenance;
    use ndarray::Array2;
    use std::sync::Arc;

    /// Builds a one-column output-Fisher tier metric whose row `i` holds the
    /// single factor `sqrt(masses[i])`.
    ///
    /// The lifted-measure tests below rely on `per_row_fisher_mass`
    /// recovering `masses[i]` from it.
    fn tier_metric(masses: &[f64]) -> RowMetric {
        let n = masses.len();
        let mut u = Array2::<f64>::zeros((n, 1));
        for (i, &m) in masses.iter().enumerate() {
            u[[i, 0]] = m.sqrt();
        }
        RowMetric::output_fisher(Arc::new(u), 1, 1).expect("tier metric")
    }

    #[test]
    fn activations_only_degrades_everywhere() {
        let h = TieredHarvest::activations_only(10);
        assert!(!h.has_fisher_tier());
        assert_eq!(h.coverage(), 0.0);
        assert!(h.tier_rows().is_empty());
        assert!(h.tier_metric().is_none());
        assert!(h.tier_provenance().is_none());
        assert_eq!(h.tier_row_for(3), None);
        assert!(!h.has_factors(3));
        let m = h.corpus_measure();
        assert_eq!(m.provenance(), MeasureProvenance::Uniform);
        assert_eq!(m.n_rows(), 10);
        // An empty corpus must not divide by zero in `coverage`.
        assert_eq!(TieredHarvest::activations_only(0).coverage(), 0.0);
    }

    #[test]
    fn tier_mapping_and_coverage() {
        let metric = tier_metric(&[1.0, 4.0, 1.0]);
        let h = TieredHarvest::with_unweighted_tier(10, vec![2, 5, 9], metric).expect("harvest");
        assert!(h.has_fisher_tier());
        assert!((h.coverage() - 0.3).abs() < 1e-12);
        assert_eq!(h.tier_rows(), &[2, 5, 9][..]);
        assert_eq!(h.tier_row_for(5), Some(1));
        assert_eq!(h.tier_row_for(4), None);
        assert!(h.has_factors(9));
        assert!(!h.has_factors(0));
        assert_eq!(
            h.tier_provenance(),
            Some(h.tier_metric().unwrap().provenance())
        );
    }

    #[test]
    fn lifted_measure_imputes_mean_mass_off_tier() {
        let metric = tier_metric(&[1.0, 9.0]);
        let h = TieredHarvest::with_unweighted_tier(4, vec![2, 3], metric).expect("harvest");
        let m = h.corpus_measure();
        assert!(m.is_enriched());
        let w = m.weights();
        // Masses [5, 5, 1, 9] (mean 5 imputed on rows 0 and 1) over a total of 20.
        assert!((w[0] - 0.25).abs() < 1e-12);
        assert!((w[2] - 0.05).abs() < 1e-12);
        assert!((w[3] - 0.45).abs() < 1e-12);
    }

    #[test]
    fn inclusion_correction_undoes_design_bias() {
        let metric = tier_metric(&[4.0, 4.0]);
        let h = TieredHarvest::with_designed_tier(2, vec![0, 1], vec![0.5, 1.0], metric)
            .expect("harvest");
        let m = h.corpus_measure();
        let w = m.weights();
        assert!(
            (w[0] - 2.0 * w[1]).abs() < 1e-12,
            "HT lift must double the half-inclusion row: {w:?}"
        );
    }

    #[test]
    fn flat_tier_collapses_to_uniform_attention() {
        let metric = tier_metric(&[2.0, 2.0]);
        let h = TieredHarvest::with_unweighted_tier(6, vec![1, 4], metric).expect("harvest");
        let m = h.corpus_measure();
        let w = m.weights();
        for &x in w {
            assert!((x - 1.0 / 6.0).abs() < 1e-12, "flat tier must lift uniform");
        }
    }

    #[test]
    fn validation_rejects_malformed_tiers() {
        // Unsorted tier rows.
        assert!(
            TieredHarvest::with_unweighted_tier(5, vec![3, 1], tier_metric(&[1.0, 2.0])).is_err()
        );
        // Duplicated tier rows: ascending order must be strict.
        assert!(
            TieredHarvest::with_unweighted_tier(5, vec![1, 1], tier_metric(&[1.0, 2.0])).is_err()
        );
        // Tier row outside the corpus.
        assert!(
            TieredHarvest::with_unweighted_tier(3, vec![1, 3], tier_metric(&[1.0, 2.0])).is_err()
        );
        // Metric covers fewer rows than the tier names.
        assert!(
            TieredHarvest::with_unweighted_tier(5, vec![0, 1, 2], tier_metric(&[1.0, 2.0]))
                .is_err()
        );
        // Inclusion vector of the wrong length.
        assert!(
            TieredHarvest::with_designed_tier(5, vec![0, 1], vec![1.0], tier_metric(&[1.0, 2.0]))
                .is_err()
        );
        // Inclusion probabilities outside (0, 1], or not finite.
        for bad in [0.0, -0.5, 1.5, f64::NAN, f64::INFINITY] {
            assert!(
                TieredHarvest::with_designed_tier(
                    5,
                    vec![0, 1],
                    vec![bad, 1.0],
                    tier_metric(&[1.0, 2.0])
                )
                .is_err(),
                "inclusion probability {bad} must be rejected"
            );
        }
    }

    #[test]
    fn validation_accepts_boundary_inclusion() {
        let h = TieredHarvest::with_designed_tier(
            5,
            vec![0, 4],
            vec![1e-9, 1.0],
            tier_metric(&[1.0, 2.0]),
        )
        .expect("probabilities in (0, 1] must be accepted");
        assert_eq!(h.n_rows(), 5);
        assert_eq!(h.tier_rows(), &[0, 4][..]);
        assert_eq!(h.tier_row_for(4), Some(1));
        assert!((h.coverage() - 0.4).abs() < 1e-12);
    }

    #[test]
    fn plan_next_tier_cold_start_is_uniform_design() {
        let h = TieredHarvest::activations_only(50);
        let plan = h.plan_next_tier(10, 7);
        assert_eq!(plan.provenance, MeasureProvenance::Uniform);
        assert_eq!(plan.len(), 10);
        let metric = tier_metric(&[1.0, 100.0]);
        let h2 = TieredHarvest::with_unweighted_tier(50, vec![10, 20], metric).expect("harvest");
        let plan2 = h2.plan_next_tier(10, 7);
        assert!(
            plan2.rows.contains(&20),
            "the loud previously-harvested row must be re-designed in"
        );
    }
}