use swarmkit::{Evolution, Particle};
use swarmkit_sailing::{Floats, Path, PathXY, Time, Topology};
use crate::config::topology_serde;
use crate::route::{RouteEvolution, WaypointCount};
use crate::waypoint_match;
/// Serializable snapshot of one saved route solution plus the settings stored with it.
///
/// `seed`, `topology`, and the three `path_kick_*` fields carry serde defaults, so
/// input that omits them (an older saved format) still deserializes. All other
/// fields are required.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct SavedSolution {
/// Waypoint count. `xs`, `ys`, and `ts` must each have exactly this many entries
/// for `to_route_evolution` to accept the solution.
pub n: usize,
/// Per-waypoint values stored as the first component of the path's `PathXY`.
pub xs: Vec<f64>,
/// Per-waypoint values stored as the second component of the path's `PathXY`.
pub ys: Vec<f64>,
/// Per-waypoint values stored as the path's `Time` component.
pub ts: Vec<f64>,
/// Fitness of the saved path. On load it becomes both the particle's current
/// and best fitness.
pub best_fit: f64,
/// NOTE(review): presumably the weight on travel time in the objective — confirm.
pub time_weight: f64,
/// NOTE(review): presumably the weight on fuel use in the objective — confirm.
pub fuel_weight: f64,
/// NOTE(review): presumably the particle count of the spatial search — confirm.
pub particles_space: usize,
/// NOTE(review): presumably the particle count of the time search — confirm.
pub particles_time: usize,
/// NOTE(review): presumably the iteration count of the spatial search — confirm.
pub iter_space: usize,
/// NOTE(review): presumably the iteration count of the time search — confirm.
pub iter_time: usize,
/// Optional seed (presumably for the search RNG). Deserializes to `None` when absent.
#[serde(default)]
pub seed: Option<u64>,
/// (De)serialized through `crate::config::topology_serde`. Falls back to
/// `Topology::default()` when absent.
#[serde(default, with = "topology_serde")]
pub topology: Topology,
/// Falls back to `SearchConfig::default().path_kick_probability` when absent.
#[serde(default = "default_path_kick_probability")]
pub path_kick_probability: f64,
/// Falls back to `SearchConfig::default().path_kick_gamma_0_fraction` when absent.
#[serde(default = "default_path_kick_gamma_0_fraction")]
pub path_kick_gamma_0_fraction: f64,
/// Falls back to `SearchConfig::default().path_kick_gamma_min_fraction` when absent.
#[serde(default = "default_path_kick_gamma_min_fraction")]
pub path_kick_gamma_min_fraction: f64,
}
/// Serde fallback for `SavedSolution::path_kick_probability`.
///
/// Returns the value from `SearchConfig::default()`, so older saved files inherit the
/// current search default.
fn default_path_kick_probability() -> f64 {
    let defaults = crate::config::SearchConfig::default();
    defaults.path_kick_probability
}

/// Serde fallback for `SavedSolution::path_kick_gamma_0_fraction`.
///
/// Returns the value from `SearchConfig::default()`.
fn default_path_kick_gamma_0_fraction() -> f64 {
    let defaults = crate::config::SearchConfig::default();
    defaults.path_kick_gamma_0_fraction
}

/// Serde fallback for `SavedSolution::path_kick_gamma_min_fraction`.
///
/// Returns the value from `SearchConfig::default()`.
fn default_path_kick_gamma_min_fraction() -> f64 {
    let defaults = crate::config::SearchConfig::default();
    defaults.path_kick_gamma_min_fraction
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A saved solution in the original format, written before `seed`, `topology`,
    /// and the `path_kick_*` fields existed.
    const LEGACY_JSON: &str = r#"{
"n": 5,
"xs": [0.0, 1.0, 2.0, 3.0, 4.0],
"ys": [0.0, 0.5, 1.0, 1.5, 2.0],
"ts": [0.0, 100.0, 200.0, 300.0, 400.0],
"best_fit": -1234.5,
"time_weight": 1.0,
"fuel_weight": 10.0,
"particles_space": 40,
"particles_time": 40,
"iter_space": 40,
"iter_time": 30
}"#;

    /// Asserts two floats agree to within `1e-12`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-12);
    }

    /// Missing optional fields must come back as their serde defaults.
    #[test]
    fn old_format_round_trips_with_defaults() {
        let saved: SavedSolution =
            serde_json::from_str(LEGACY_JSON).expect("old format must parse");

        assert_eq!(saved.n, 5);
        assert_eq!(saved.seed, None);
        assert_eq!(saved.topology, Topology::default());

        let expected = crate::config::SearchConfig::default();
        assert_close(saved.path_kick_probability, expected.path_kick_probability);
        assert_close(
            saved.path_kick_gamma_0_fraction,
            expected.path_kick_gamma_0_fraction,
        );
        assert_close(
            saved.path_kick_gamma_min_fraction,
            expected.path_kick_gamma_min_fraction,
        );
    }
}
/// Reasons [`SavedSolution::to_route_evolution`] can reject a saved solution.
///
/// Derives `PartialEq`/`Eq`/`Clone` so callers and tests can compare and keep errors.
/// All payloads are plain counts.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadError {
    /// `WaypointCount::from_usize` did not accept the saved `n`. Carries that `n`.
    UnsupportedWaypointCount(usize),
    /// At least one of `xs`, `ys`, `ts` does not have exactly `n` entries.
    LengthMismatch {
        /// Declared waypoint count.
        n: usize,
        /// Actual length of `xs`.
        xs: usize,
        /// Actual length of `ys`.
        ys: usize,
        /// Actual length of `ts`.
        ts: usize,
    },
    /// Converting the saved slices to fixed-size arrays failed. Carries the saved `n`.
    PathConversion(usize),
}
/// Human-readable, single-line descriptions of each [`LoadError`] variant.
impl std::fmt::Display for LoadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload is a `usize`, so the bindings below copy out of `*self`.
        match *self {
            Self::UnsupportedWaypointCount(n) => {
                write!(f, "Unsupported waypoint count {n} in saved solution")
            }
            Self::LengthMismatch { n, xs, ys, ts } => {
                write!(
                    f,
                    "Saved solution arrays do not match n={n}: xs={xs}, ys={ys}, ts={ts}"
                )
            }
            Self::PathConversion(n) => {
                write!(
                    f,
                    "Internal error converting saved arrays of length {n} to fixed-size path"
                )
            }
        }
    }
}

/// `LoadError` wraps no underlying cause, so the trait's default `source()` is kept.
impl std::error::Error for LoadError {}
impl SavedSolution {
    /// Rebuilds a single-frame [`RouteEvolution`] from this saved solution.
    ///
    /// `xs`/`ys` become the path's `PathXY` and `ts` its `Time`. The path is placed in
    /// one particle, which uses it as both the current and best position, uses
    /// `best_fit` as both the current and best fitness, and starts with
    /// `Path::default()` velocity. That particle is the only member of the only frame.
    /// The resolved [`WaypointCount`] is returned alongside the evolution.
    ///
    /// # Errors
    ///
    /// - [`LoadError::UnsupportedWaypointCount`] if `WaypointCount::from_usize(self.n)`
    ///   yields `None`.
    /// - [`LoadError::LengthMismatch`] if any of `xs`, `ys`, `ts` does not have
    ///   exactly `n` entries.
    /// - [`LoadError::PathConversion`] if converting the slices to `[f64; N]` fails.
    ///   NOTE(review): this looks reachable only if the `N` chosen by
    ///   `waypoint_match!` differs from `n`. Confirm against `WaypointCount`.
    pub fn to_route_evolution(&self) -> Result<(WaypointCount, RouteEvolution), LoadError> {
        let Some(wc) = WaypointCount::from_usize(self.n) else {
            return Err(LoadError::UnsupportedWaypointCount(self.n));
        };

        let lengths_agree =
            self.xs.len() == self.n && self.ys.len() == self.n && self.ts.len() == self.n;
        if !lengths_agree {
            return Err(LoadError::LengthMismatch {
                n: self.n,
                xs: self.xs.len(),
                ys: self.ys.len(),
                ts: self.ts.len(),
            });
        }

        let route_evolution = waypoint_match!(wc, N, wrap, {
            // All three conversions are attempted before any result is inspected.
            let converted = (
                <[f64; N]>::try_from(self.xs.as_slice()),
                <[f64; N]>::try_from(self.ys.as_slice()),
                <[f64; N]>::try_from(self.ts.as_slice()),
            );
            let (xs_arr, ys_arr, ts_arr) = match converted {
                (Ok(xs_arr), Ok(ys_arr), Ok(ts_arr)) => (xs_arr, ys_arr, ts_arr),
                _ => return Err(LoadError::PathConversion(self.n)),
            };

            let saved_path = Path::<N> {
                xy: PathXY(Floats(xs_arr), Floats(ys_arr)),
                t: Time(Floats(ts_arr)),
            };
            let lone_particle = Particle::<Path<N>> {
                pos: saved_path,
                vel: Path::default(),
                fit: self.best_fit,
                best_pos: saved_path,
                best_fit: self.best_fit,
            };
            let only_frame = vec![lone_particle];
            wrap(Evolution::from_frames(vec![only_frame]))
        });

        Ok((wc, route_evolution))
    }
}