Skip to main content

rlx_bbo/
flow_map.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15//! Simplified flow-map policy: affine one-step map + diagonal CFM training (§3.1–3.2).
16
17use crate::Bbox;
18use serde::{Deserialize, Serialize};
19use std::path::Path;
20
21use crate::trajectory::{TrajectoryRecord, diagonal_flow_pairs, load_jsonl};
22
23/// One-step flow map X_{0,1}(a₀) = a₀ + W·a₀ + b (linear MVP analogue).
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub struct LinearFlowMap {
26    pub dim: usize,
27    pub velocity_weights: Vec<f64>,
28    pub velocity_bias: Vec<f64>,
29    pub topology: String,
30}
31
32impl LinearFlowMap {
33    pub fn one_step(&self, noise: &[f64]) -> Vec<f64> {
34        noise
35            .iter()
36            .enumerate()
37            .map(|(d, &a0)| {
38                a0 + self.velocity_bias.get(d).copied().unwrap_or(0.0)
39                    + a0 * self.velocity_weights.get(d).copied().unwrap_or(0.0)
40            })
41            .collect()
42    }
43
44    pub fn train_diagonal(records: &[TrajectoryRecord], topology: &str) -> Option<Self> {
45        let pairs = diagonal_flow_pairs(records);
46        if pairs.is_empty() {
47            let actions: Vec<_> = records
48                .iter()
49                .filter(|r| r.topology == topology)
50                .map(|r| r.action.clone())
51                .collect();
52            if actions.is_empty() {
53                return None;
54            }
55            let dim = actions[0].len();
56            return Some(Self {
57                dim,
58                velocity_weights: vec![0.0; dim],
59                velocity_bias: mean_action(&actions, dim),
60                topology: topology.to_string(),
61            });
62        }
63        let dim = pairs[0].0.len();
64        let mut vel_sum = vec![0.0; dim];
65        let mut count = 0usize;
66        for (_, v) in &pairs {
67            if v.len() != dim {
68                continue;
69            }
70            for d in 0..dim {
71                vel_sum[d] += v[d];
72            }
73            count += 1;
74        }
75        if count == 0 {
76            return None;
77        }
78        let velocity_bias: Vec<f64> = vel_sum.iter().map(|s| s / count as f64).collect();
79        Some(Self {
80            dim,
81            velocity_weights: vec![0.0; dim],
82            velocity_bias,
83            topology: topology.to_string(),
84        })
85    }
86}
87
/// Per-coordinate mean of `actions`, always returning exactly `dim` entries.
///
/// Each action contributes only its first `min(dim, len)` coordinates. Every action still
/// counts toward the divisor, so short actions pull their missing coordinates toward zero.
/// An empty `actions` slice yields `dim` zeros.
fn mean_action(actions: &[Vec<f64>], dim: usize) -> Vec<f64> {
    let denom = actions.len().max(1) as f64;
    let totals = actions.iter().fold(vec![0.0; dim], |mut acc, action| {
        // `zip` stops at the shorter of `acc` (len `dim`) and `action`.
        for (slot, &value) in acc.iter_mut().zip(action.iter()) {
            *slot += value;
        }
        acc
    });
    totals.into_iter().map(|total| total / denom).collect()
}
98
99/// Offline train from JSONL trajectories; returns flow map + training MSE.
100pub fn train_from_jsonl(
101    path: &Path,
102    topology: &str,
103) -> std::io::Result<Option<(LinearFlowMap, f64)>> {
104    let recs = load_jsonl(path)?;
105    let fm = LinearFlowMap::train_diagonal(&recs, topology);
106    let Some(fm) = fm else {
107        return Ok(None);
108    };
109    let pairs = diagonal_flow_pairs(&recs);
110    let mse = if pairs.is_empty() {
111        0.0
112    } else {
113        let mut err = 0.0;
114        let mut n = 0usize;
115        for (a1, v_star) in pairs {
116            if let Some(a0) = recs
117                .iter()
118                .find(|r| r.action == a1)
119                .and_then(|r| r.noise.clone())
120            {
121                let pred = fm.one_step(&a0);
122                for d in 0..v_star.len().min(pred.len()).min(a0.len()) {
123                    let v_pred = pred[d] - a0[d];
124                    err += (v_pred - v_star[d]).powi(2);
125                    n += 1;
126                }
127            }
128        }
129        if n > 0 { err / n as f64 } else { 0.0 }
130    };
131    Ok(Some((fm, mse)))
132}
133
134pub fn save_flow_map(path: &Path, fm: &LinearFlowMap) -> std::io::Result<()> {
135    let json = serde_json::to_string_pretty(fm)
136        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
137    std::fs::write(path, json)
138}
139
140pub fn load_flow_map(path: &Path) -> std::io::Result<LinearFlowMap> {
141    let text = std::fs::read_to_string(path)?;
142    serde_json::from_str(&text).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
143}
144
145/// FMQ-style online step using surrogate grad: x ← x − η ∇Q / ‖∇Q‖.
146pub fn fmq_surrogate_step(
147    x: &[f64],
148    x_ref: &[f64],
149    grad_q: &[f64],
150    bbox: &Bbox,
151    eta: f64,
152    kappa: f64,
153) -> Vec<f64> {
154    let _ = x_ref;
155    crate::trust_region_q_step(x, grad_q, bbox, eta, true, kappa)
156}