1use crate::Bbox;
18use serde::{Deserialize, Serialize};
19use std::path::Path;
20
21use crate::trajectory::{TrajectoryRecord, diagonal_flow_pairs, load_jsonl};
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
25pub struct LinearFlowMap {
26 pub dim: usize,
27 pub velocity_weights: Vec<f64>,
28 pub velocity_bias: Vec<f64>,
29 pub topology: String,
30}
31
32impl LinearFlowMap {
33 pub fn one_step(&self, noise: &[f64]) -> Vec<f64> {
34 noise
35 .iter()
36 .enumerate()
37 .map(|(d, &a0)| {
38 a0 + self.velocity_bias.get(d).copied().unwrap_or(0.0)
39 + a0 * self.velocity_weights.get(d).copied().unwrap_or(0.0)
40 })
41 .collect()
42 }
43
44 pub fn train_diagonal(records: &[TrajectoryRecord], topology: &str) -> Option<Self> {
45 let pairs = diagonal_flow_pairs(records);
46 if pairs.is_empty() {
47 let actions: Vec<_> = records
48 .iter()
49 .filter(|r| r.topology == topology)
50 .map(|r| r.action.clone())
51 .collect();
52 if actions.is_empty() {
53 return None;
54 }
55 let dim = actions[0].len();
56 return Some(Self {
57 dim,
58 velocity_weights: vec![0.0; dim],
59 velocity_bias: mean_action(&actions, dim),
60 topology: topology.to_string(),
61 });
62 }
63 let dim = pairs[0].0.len();
64 let mut vel_sum = vec![0.0; dim];
65 let mut count = 0usize;
66 for (_, v) in &pairs {
67 if v.len() != dim {
68 continue;
69 }
70 for d in 0..dim {
71 vel_sum[d] += v[d];
72 }
73 count += 1;
74 }
75 if count == 0 {
76 return None;
77 }
78 let velocity_bias: Vec<f64> = vel_sum.iter().map(|s| s / count as f64).collect();
79 Some(Self {
80 dim,
81 velocity_weights: vec![0.0; dim],
82 velocity_bias,
83 topology: topology.to_string(),
84 })
85 }
86}
87
/// Returns the per-dimension mean of `actions` over the first `dim` components.
///
/// Actions shorter than `dim` contribute nothing to the dimensions they lack, yet the
/// divisor is always the total number of actions (or 1 when `actions` is empty, which
/// yields a vector of zeros).
fn mean_action(actions: &[Vec<f64>], dim: usize) -> Vec<f64> {
    let mut totals = vec![0.0; dim];
    for action in actions {
        // `zip` stops at the shorter side, i.e. at min(dim, action.len()).
        for (total, &component) in totals.iter_mut().zip(action) {
            *total += component;
        }
    }
    let divisor = actions.len().max(1) as f64;
    totals.into_iter().map(|total| total / divisor).collect()
}
98
99pub fn train_from_jsonl(
101 path: &Path,
102 topology: &str,
103) -> std::io::Result<Option<(LinearFlowMap, f64)>> {
104 let recs = load_jsonl(path)?;
105 let fm = LinearFlowMap::train_diagonal(&recs, topology);
106 let Some(fm) = fm else {
107 return Ok(None);
108 };
109 let pairs = diagonal_flow_pairs(&recs);
110 let mse = if pairs.is_empty() {
111 0.0
112 } else {
113 let mut err = 0.0;
114 let mut n = 0usize;
115 for (a1, v_star) in pairs {
116 if let Some(a0) = recs
117 .iter()
118 .find(|r| r.action == a1)
119 .and_then(|r| r.noise.clone())
120 {
121 let pred = fm.one_step(&a0);
122 for d in 0..v_star.len().min(pred.len()).min(a0.len()) {
123 let v_pred = pred[d] - a0[d];
124 err += (v_pred - v_star[d]).powi(2);
125 n += 1;
126 }
127 }
128 }
129 if n > 0 { err / n as f64 } else { 0.0 }
130 };
131 Ok(Some((fm, mse)))
132}
133
134pub fn save_flow_map(path: &Path, fm: &LinearFlowMap) -> std::io::Result<()> {
135 let json = serde_json::to_string_pretty(fm)
136 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
137 std::fs::write(path, json)
138}
139
140pub fn load_flow_map(path: &Path) -> std::io::Result<LinearFlowMap> {
141 let text = std::fs::read_to_string(path)?;
142 serde_json::from_str(&text).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
143}
144
145pub fn fmq_surrogate_step(
147 x: &[f64],
148 x_ref: &[f64],
149 grad_q: &[f64],
150 bbox: &Bbox,
151 eta: f64,
152 kappa: f64,
153) -> Vec<f64> {
154 let _ = x_ref;
155 crate::trust_region_q_step(x, grad_q, bbox, eta, true, kappa)
156}