reservoir_train/
codegen.rs

1use crate::float::RealScalar;
2use reservoir_core::types::Scalar;
3use reservoir_infer::{EchoStateNetwork, LassoReadout, RidgeReadout, SparseReservoir};
4use std::fmt::Write;
5
/// Static model code generation utilities.
///
/// [`StaticModelGenerator`] exports a **trained model** (currently: sparse ESN)
/// as **Rust source code** that defines the model weights as `const` arrays.
///
/// The generated code is intended for **embedded / `no_std` inference** use cases,
/// where dynamic allocation and file I/O are undesirable:
///
/// - CSR matrices (`W_in`, `W_res`) are emitted as `u16` index arrays plus `f32` values.
/// - The readout matrix (`W_out`) is emitted as a flattened `f32` array in
///   column-major order (the iteration order of `nalgebra` matrices).
/// - The initial reservoir state is emitted as a flattened `f32` array.
///
/// # Output format
/// `generate_sparse_code` returns a single `String` containing Rust code with:
/// - `INPUT_DIM`, `RESERVOIR_SIZE`, `OUTPUT_DIM`, `EXTENDED_SIZE`
/// - `LEAKING_RATE`
/// - CSR metadata (`*_NROWS`, `*_NCOLS`, `*_NNZ`)
/// - CSR arrays: `*_ROW_PTR`, `*_COL_IDX`, `*_VALUES`
/// - `W_OUT_DATA` and `INITIAL_STATE_DATA`
///
/// The generated snippet includes `use` statements for `reservoir-infer` static
/// reservoirs/readouts and `nalgebra` fixed-size types, so it can be pasted into
/// a target crate/module with minimal editing.
///
/// # Limits / invariants
/// - CSR row pointers and column indices are emitted as `u16`. Row pointers are
///   bounded by the number of non-zeros (NNZ), so NNZ must not exceed `u16::MAX`
///   for both `W_in` and `W_res`; every column index must fit into `u16` as well.
/// - All scalar values are formatted with 8 decimal digits and emitted as `f32`
///   literals. This is a deliberate tradeoff for portability and code size:
///   magnitudes below `5e-9` round to zero, and every value must be finite and
///   within `f32` range for the generated code to compile.
/// - The extended state layout assumed here matches `reservoir-infer` reservoirs:
///   `[bias(1), input(input_dim), reservoir_state(reservoir_size)]`.
///
/// # Feature gating
/// This code generator uses `std::fmt::Write` and returns an owned `String`,
/// so it is typically compiled behind the `std` feature of `reservoir-train`.
///
/// # Example
/// ```no_run
/// # use reservoir_train::{ESNBuilder, ESNFitRidge};
/// # use reservoir_train::StaticModelGenerator;
/// // Train a sparse ESN (ridge readout).
/// let mut esn = ESNBuilder::<f32>::new(1, 1)
///     .units(200)
///     .connectivity(8)
///     .input_connectivity(1)
///     .spectral_radius(1.2)
///     .leaking_rate(0.8)
///     .seed(42)
///     .build_sparse();
///
/// // Fit (dummy example; supply your actual data here).
/// // esn.fit(&inputs, &targets, 1e-6, 50);
///
/// // Export as Rust code (paste into your embedded inference crate).
/// let code = StaticModelGenerator::generate_sparse_code(&esn).unwrap();
/// println!("{}", code);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct StaticModelGenerator;
64
65impl StaticModelGenerator {
66    /// Generate Rust source code for a trained **sparse** Echo State Network.
67    ///
68    /// This function inspects the provided ESN and serializes its parameters into
69    /// Rust `const` definitions suitable for `no_std` inference:
70    ///
71    /// - `W_in` and `W_res` (reservoir matrices) are emitted in CSR form, using `u16`
72    ///   indices (`ROW_PTR`, `COL_IDX`) and `f32` values (`VALUES`).
73    /// - The readout weight matrix `W_out` is emitted as a flattened `f32` array.
74    /// - The initial reservoir state is emitted as a flattened `f32` array.
75    ///
76    /// The output includes enough metadata (dimensions / NNZ counts) to validate
77    /// the arrays at compile time and to reconstruct the static reservoir/readout
78    /// types in a downstream crate.
79    ///
80    /// # Type parameters
81    /// - `S`: training scalar type. Must be convertible/printable (`RealScalar + Display`)
82    ///   because values are formatted into source code.
83    /// - `O`: readout type. Must implement `reservoir_core::Readout<S>` and [`GetWeights`]
84    ///   so this generator can access the dense output weight matrix.
85    ///
86    /// # Errors
87    /// Returns `Err(String)` if the number of non-zeros (NNZ) in `W_in` or `W_res`
88    /// exceeds `u16::MAX`, because the generated CSR arrays use `u16` indices.
89    ///
90    /// # Notes
91    /// - All emitted numeric values are formatted as `f32` with 8 decimals.
92    /// - This function does not write files; it returns a `String` to give callers
93    ///   full control over where the generated code is stored.
94    pub fn generate_sparse_code<S, O>(
95        esn: &EchoStateNetwork<S, SparseReservoir<S>, O>,
96    ) -> Result<String, String>
97    where
98        S: RealScalar + std::fmt::Display,
99        O: reservoir_core::Readout<S> + GetWeights<S>,
100    {
101        let w_in = &esn.reservoir.w_in;
102        let w_res = &esn.reservoir.w;
103        let w_out = esn.readout.weights();
104
105        let input_dim = esn.reservoir.input_dim;
106        let reservoir_size = esn.reservoir.res_state.len();
107        let output_dim = esn.readout.output_dim();
108        let ext_size = 1 + input_dim + reservoir_size;
109        let leaking_rate = esn.reservoir.leaking_rate;
110        let initial_state = &esn.reservoir.res_state;
111
112        if w_in.values.len() > u16::MAX as usize {
113            return Err(format!("W_in NNZ ({}) exceeds u16::MAX", w_in.values.len()));
114        }
115        if w_res.values.len() > u16::MAX as usize {
116            return Err(format!(
117                "W_res NNZ ({}) exceeds u16::MAX",
118                w_res.values.len()
119            ));
120        }
121
122        let mut code = String::new();
123
124        writeln!(code, "// Auto-generated by reservoir-train::codegen").unwrap();
125        writeln!(code, "use reservoir_infer::reservoir::static_sparse_reservoir::{{StaticCsrMatrix, StaticSparseReservoir}};").unwrap();
126        writeln!(
127            code,
128            "use reservoir_infer::readout::static_readout::StaticReadout;"
129        )
130        .unwrap();
131        writeln!(code, "use nalgebra::{{SMatrix, SVector}};").unwrap();
132        writeln!(code).unwrap();
133
134        writeln!(code, "pub const INPUT_DIM: usize = {};", input_dim).unwrap();
135        writeln!(
136            code,
137            "pub const RESERVOIR_SIZE: usize = {};",
138            reservoir_size
139        )
140        .unwrap();
141        writeln!(code, "pub const OUTPUT_DIM: usize = {};", output_dim).unwrap();
142        writeln!(code, "pub const EXTENDED_SIZE: usize = {};", ext_size).unwrap();
143        writeln!(code, "pub const LEAKING_RATE: f32 = {:.8};", leaking_rate).unwrap();
144        writeln!(code).unwrap();
145
146        writeln!(code, "pub const W_IN_NROWS: usize = {};", w_in.nrows).unwrap();
147        writeln!(code, "pub const W_IN_NCOLS: usize = {};", w_in.ncols).unwrap();
148        writeln!(code, "pub const W_RES_NROWS: usize = {};", w_res.nrows).unwrap();
149        writeln!(code, "pub const W_RES_NCOLS: usize = {};", w_res.ncols).unwrap();
150        writeln!(code, "pub const W_IN_NNZ: usize = {};", w_in.values.len()).unwrap();
151        writeln!(code, "pub const W_RES_NNZ: usize = {};", w_res.values.len()).unwrap();
152        writeln!(code).unwrap();
153
154        let fmt_u16 = |v: &[usize]| -> String {
155            v.iter()
156                .map(|&x| format!("{}", x as u16))
157                .collect::<Vec<_>>()
158                .join(", ")
159        };
160        let fmt_scalar = |v: &[S]| -> String {
161            v.iter()
162                .map(|x| format!("{:.8}", x))
163                .collect::<Vec<_>>()
164                .join(", ")
165        };
166
167        writeln!(
168            code,
169            "pub const W_IN_ROW_PTR: [u16; {}] = [{}];",
170            w_in.row_ptr.len(),
171            fmt_u16(&w_in.row_ptr)
172        )
173        .unwrap();
174        writeln!(
175            code,
176            "pub const W_IN_COL_IDX: [u16; {}] = [{}];",
177            w_in.col_idx.len(),
178            fmt_u16(&w_in.col_idx)
179        )
180        .unwrap();
181        writeln!(
182            code,
183            "pub const W_IN_VALUES: [f32; {}] = [{}];",
184            w_in.values.len(),
185            fmt_scalar(&w_in.values)
186        )
187        .unwrap();
188        writeln!(code).unwrap();
189
190        writeln!(
191            code,
192            "pub const W_RES_ROW_PTR: [u16; {}] = [{}];",
193            w_res.row_ptr.len(),
194            fmt_u16(&w_res.row_ptr)
195        )
196        .unwrap();
197        writeln!(
198            code,
199            "pub const W_RES_COL_IDX: [u16; {}] = [{}];",
200            w_res.col_idx.len(),
201            fmt_u16(&w_res.col_idx)
202        )
203        .unwrap();
204        writeln!(
205            code,
206            "pub const W_RES_VALUES: [f32; {}] = [{}];",
207            w_res.values.len(),
208            fmt_scalar(&w_res.values)
209        )
210        .unwrap();
211        writeln!(code).unwrap();
212
213        let w_out_flat: Vec<S> = w_out.iter().cloned().collect();
214        writeln!(
215            code,
216            "pub const W_OUT_DATA: [f32; {}] = [{}];",
217            w_out_flat.len(),
218            fmt_scalar(&w_out_flat)
219        )
220        .unwrap();
221        writeln!(code).unwrap();
222
223        let state_flat: Vec<S> = initial_state.iter().cloned().collect();
224        writeln!(
225            code,
226            "pub const INITIAL_STATE_DATA: [f32; {}] = [{}];",
227            state_flat.len(),
228            fmt_scalar(&state_flat)
229        )
230        .unwrap();
231
232        Ok(code)
233    }
234}
235
236/// Trait for accessing readout weights as a dense matrix.
237///
238/// The code generator needs a uniform way to retrieve the output weight matrix
239/// (`W_out`) from different readout implementations.
240///
241/// Implementations are provided for [`RidgeReadout`] and [`LassoReadout`].
242pub trait GetWeights<S: Scalar> {
243    /// Borrow the readout weight matrix.
244    ///
245    /// The matrix shape is `(output_dim, extended_state_dim)` for the readouts in
246    /// `reservoir-infer`.
247    fn weights(&self) -> &nalgebra::DMatrix<S>;
248}
249
250impl<S: Scalar> GetWeights<S> for RidgeReadout<S> {
251    fn weights(&self) -> &nalgebra::DMatrix<S> {
252        &self.w_out
253    }
254}
255
256impl<S: Scalar> GetWeights<S> for LassoReadout<S> {
257    fn weights(&self) -> &nalgebra::DMatrix<S> {
258        &self.w_out
259    }
260}