use std::cell::RefCell;
use std::rc::Rc;
use nalgebra::DVector;
/// Phase tag attached to every [`RecordedFrame`].
///
/// The variants are declared in the order `Init`, `Mds`, `Final`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stage {
    /// Frame tagged as the initial state.
    Init,
    /// Frame tagged as belonging to the "mds" phase
    /// (NOTE(review): presumably multidimensional scaling; confirm).
    Mds,
    /// Frame tagged as belonging to the final phase.
    Final,
}

impl Stage {
    /// Returns the fixed lowercase label of this stage:
    /// `"init"`, `"mds"` or `"final"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Init => "init",
            Self::Mds => "mds",
            Self::Final => "final",
        }
    }
}
/// A single recorded frame, generic over the shape type `S`.
#[derive(Debug, Clone)]
pub struct RecordedFrame<S> {
/// Stage this frame is tagged with.
pub stage: Stage,
/// Iteration number stored with the frame.
/// NOTE(review): presumably the optimizer's iteration counter; confirm
/// against the code that builds frames.
pub iteration: u64,
/// Cost value stored with the frame.
pub cost: f64,
/// Shapes stored with the frame.
/// NOTE(review): presumably one shape per set; confirm.
pub shapes: Vec<S>,
}
/// A sequence of [`RecordedFrame`]s together with a list of set names.
#[derive(Debug, Clone)]
pub struct FitRecording<S> {
/// The recorded frames.
pub frames: Vec<RecordedFrame<S>>,
/// Names of the sets.
/// NOTE(review): presumably in the same order as each frame's `shapes`;
/// confirm against the code that fills this in.
pub set_names: Vec<String>,
}
/// Crate-internal snapshot of an optimizer state, as captured by
/// `FrameRecorder::push`.
pub(crate) struct RawFrame {
/// Value returned by the state's `iter()` when the snapshot was taken.
pub iteration: u64,
/// Owned copy of the state's parameter vector at that moment.
pub params: Vec<f64>,
}
/// Observer that appends [`RawFrame`] snapshots to a shared sink.
pub(crate) struct FrameRecorder {
// Shared, interior-mutable buffer the snapshots are appended to.
sink: Rc<RefCell<Vec<RawFrame>>>,
// Cap on the sink's length: once the sink holds this many frames,
// further snapshots are dropped.
max_frames: usize,
}
impl FrameRecorder {
pub(crate) fn new(sink: Rc<RefCell<Vec<RawFrame>>>, max_frames: usize) -> Self {
Self { sink, max_frames }
}
fn push<S>(&self, state: &S)
where
S: basin::State<Param = DVector<f64>, Float = f64>,
{
let mut sink = self.sink.borrow_mut();
if sink.len() >= self.max_frames {
return;
}
sink.push(RawFrame {
iteration: state.iter(),
params: state.param().as_slice().to_vec(),
});
}
}
/// Hooks [`FrameRecorder`] into `basin` as an observer: every state handed
/// to either callback is forwarded to `FrameRecorder::push`.
impl<S> basin::Observe<S> for FrameRecorder
where
    S: basin::State<Param = DVector<f64>, Float = f64>,
{
    /// Records the state passed to the init callback.
    fn observe_init(&mut self, state: &S) {
        Self::push(self, state);
    }

    /// Records the state passed to the per-iteration callback.
    fn observe_iter(&mut self, state: &S) {
        Self::push(self, state);
    }
}
#[cfg(test)]
mod tests {
    use super::Stage;
    use crate::DiagramSpecBuilder;
    use crate::Fitter;
    use crate::geometry::shapes::{Circle, Ellipse};

    /// Builds the spec shared by all tests: sets `A`, `B`, `C` with every
    /// pairwise intersection and the triple intersection specified.
    fn three_set_spec() -> crate::spec::DiagramSpec {
        DiagramSpecBuilder::new()
            .set("A", 10.0)
            .set("B", 9.0)
            .set("C", 8.0)
            .intersection(&["A", "B"], 3.0)
            .intersection(&["A", "C"], 2.5)
            .intersection(&["B", "C"], 2.0)
            .intersection(&["A", "B", "C"], 1.0)
            .build()
            .unwrap()
    }

    /// A recording starts with exactly one `Init` frame, contains `Mds` and
    /// `Final` frames, never goes back to an earlier stage, carries three
    /// shapes per frame, and ends the final stage at a cost no higher than
    /// where that stage started.
    #[test]
    fn recording_walks_init_mds_final_in_order() {
        let spec = three_set_spec();
        let rec = Fitter::<Circle>::new(&spec)
            .seed(42)
            .fit_recording()
            .unwrap();

        assert!(!rec.frames.is_empty(), "expected recorded frames");
        assert_eq!(rec.frames[0].stage, Stage::Init);
        assert_eq!(
            rec.frames.iter().filter(|f| f.stage == Stage::Init).count(),
            1
        );
        assert!(rec.frames.iter().any(|f| f.stage == Stage::Mds));
        assert!(rec.frames.iter().any(|f| f.stage == Stage::Final));

        // Map each stage to its expected position so ordering can be
        // checked with a simple monotonicity test.
        let rank = |s: Stage| match s {
            Stage::Init => 0,
            Stage::Mds => 1,
            Stage::Final => 2,
        };
        let mut prev = 0;
        for f in &rec.frames {
            let r = rank(f.stage);
            assert!(r >= prev, "stages out of order");
            prev = r;
            assert_eq!(f.shapes.len(), 3);
        }

        let final_costs: Vec<f64> = rec
            .frames
            .iter()
            .filter(|f| f.stage == Stage::Final)
            .map(|f| f.cost)
            .collect();
        assert!(final_costs.len() >= 2);
        // Only the first and last final-stage costs are compared; individual
        // steps in between are not required to be monotone.
        assert!(
            *final_costs.last().unwrap() <= final_costs[0] + 1e-12,
            "final stage cost should not increase overall"
        );
    }

    /// Two runs with the same seed yield the same number of frames and an
    /// identical last frame.
    #[test]
    fn recording_is_deterministic_for_a_seed() {
        let spec = three_set_spec();
        let a = Fitter::<Circle>::new(&spec)
            .seed(7)
            .fit_recording()
            .unwrap();
        let b = Fitter::<Circle>::new(&spec)
            .seed(7)
            .fit_recording()
            .unwrap();

        assert_eq!(a.frames.len(), b.frames.len());
        let fa = a.frames.last().unwrap();
        let fb = b.frames.last().unwrap();
        // `zip` stops at the shorter side, so a differing shape count would
        // otherwise go unnoticed by the pairwise comparison below.
        assert_eq!(
            fa.shapes.len(),
            fb.shapes.len(),
            "last frames should hold the same number of shapes"
        );
        for (ca, cb) in fa.shapes.iter().zip(fb.shapes.iter()) {
            assert_eq!(ca.center().x(), cb.center().x());
            assert_eq!(ca.center().y(), cb.center().y());
            assert_eq!(ca.radius(), cb.radius());
        }
    }

    /// Recording also works with ellipses: a `Final` frame is produced and
    /// every shape in the first frame has equal semi-axes (within `1e-9`).
    #[test]
    fn recording_works_for_ellipses() {
        let spec = three_set_spec();
        let rec = Fitter::<Ellipse>::new(&spec)
            .seed(1)
            .fit_recording()
            .unwrap();

        assert!(rec.frames.iter().any(|f| f.stage == Stage::Final));
        let init = &rec.frames[0];
        for e in &init.shapes {
            assert!((e.semi_major() - e.semi_minor()).abs() < 1e-9);
        }
    }
}