// reservoir_train/codegen.rs
1use crate::float::RealScalar;
2use reservoir_core::types::Scalar;
3use reservoir_infer::{EchoStateNetwork, LassoReadout, RidgeReadout, SparseReservoir};
4use std::fmt::Write;
5
/// Static model code generation utilities.
///
/// Helpers to **export a trained model** (currently: sparse ESN) into **Rust
/// source code** that defines the model weights as `const` arrays.
///
/// The generated code is intended for **embedded / `no_std` inference** use
/// cases, where dynamic allocation and file I/O are undesirable:
///
/// - CSR matrices (`W_in`, `W_res`) are emitted as `u16` index arrays plus `f32` values.
/// - The readout matrix (`W_out`) is emitted as a flattened `f32` array.
/// - The initial reservoir state is emitted as a flattened `f32` array.
///
/// # Output format
/// [`StaticModelGenerator::generate_sparse_code`] returns a single `String`
/// containing Rust code with:
/// - `INPUT_DIM`, `RESERVOIR_SIZE`, `OUTPUT_DIM`, `EXTENDED_SIZE`
/// - `LEAKING_RATE`
/// - CSR metadata (`*_NROWS`, `*_NCOLS`, `*_NNZ`)
/// - CSR arrays: `*_ROW_PTR`, `*_COL_IDX`, `*_VALUES`
/// - `W_OUT_DATA` and `INITIAL_STATE_DATA`
///
/// The generated snippet includes `use` statements for `reservoir-infer`
/// static reservoirs/readouts and `nalgebra` fixed-size types, so it can be
/// pasted into a target crate/module with minimal editing.
///
/// # Limits / invariants
/// - CSR indices are emitted as `u16`. Therefore the number of non-zeros
///   (NNZ) must fit into `u16::MAX` for both `W_in` and `W_res`.
/// - All scalar values are formatted as `f32` with 8 decimal digits.
///   This is a deliberate tradeoff for portability and code size.
/// - The extended state layout assumed here matches `reservoir-infer`
///   reservoirs: `[bias(1), input(input_dim), reservoir_state(reservoir_size)]`.
///
/// # Feature gating
/// This code generator uses `std::fmt::Write` and returns an owned `String`,
/// so it is typically compiled behind the `std` feature of `reservoir-train`.
///
/// # Example
/// ```no_run
/// # use reservoir_train::{ESNBuilder, ESNFitRidge};
/// # use reservoir_train::StaticModelGenerator;
/// // Train a sparse ESN (ridge readout).
/// let mut esn = ESNBuilder::<f32>::new(1, 1)
///     .units(200)
///     .connectivity(8)
///     .input_connectivity(1)
///     .spectral_radius(1.2)
///     .leaking_rate(0.8)
///     .seed(42)
///     .build_sparse();
///
/// // Fit (dummy example; supply your actual data here).
/// // esn.fit(&inputs, &targets, 1e-6, 50);
///
/// // Export as Rust code (paste into your embedded inference crate).
/// let code = StaticModelGenerator::generate_sparse_code(&esn).unwrap();
/// println!("{}", code);
/// ```
pub struct StaticModelGenerator;
64
65impl StaticModelGenerator {
66 /// Generate Rust source code for a trained **sparse** Echo State Network.
67 ///
68 /// This function inspects the provided ESN and serializes its parameters into
69 /// Rust `const` definitions suitable for `no_std` inference:
70 ///
71 /// - `W_in` and `W_res` (reservoir matrices) are emitted in CSR form, using `u16`
72 /// indices (`ROW_PTR`, `COL_IDX`) and `f32` values (`VALUES`).
73 /// - The readout weight matrix `W_out` is emitted as a flattened `f32` array.
74 /// - The initial reservoir state is emitted as a flattened `f32` array.
75 ///
76 /// The output includes enough metadata (dimensions / NNZ counts) to validate
77 /// the arrays at compile time and to reconstruct the static reservoir/readout
78 /// types in a downstream crate.
79 ///
80 /// # Type parameters
81 /// - `S`: training scalar type. Must be convertible/printable (`RealScalar + Display`)
82 /// because values are formatted into source code.
83 /// - `O`: readout type. Must implement `reservoir_core::Readout<S>` and [`GetWeights`]
84 /// so this generator can access the dense output weight matrix.
85 ///
86 /// # Errors
87 /// Returns `Err(String)` if the number of non-zeros (NNZ) in `W_in` or `W_res`
88 /// exceeds `u16::MAX`, because the generated CSR arrays use `u16` indices.
89 ///
90 /// # Notes
91 /// - All emitted numeric values are formatted as `f32` with 8 decimals.
92 /// - This function does not write files; it returns a `String` to give callers
93 /// full control over where the generated code is stored.
94 pub fn generate_sparse_code<S, O>(
95 esn: &EchoStateNetwork<S, SparseReservoir<S>, O>,
96 ) -> Result<String, String>
97 where
98 S: RealScalar + std::fmt::Display,
99 O: reservoir_core::Readout<S> + GetWeights<S>,
100 {
101 let w_in = &esn.reservoir.w_in;
102 let w_res = &esn.reservoir.w;
103 let w_out = esn.readout.weights();
104
105 let input_dim = esn.reservoir.input_dim;
106 let reservoir_size = esn.reservoir.res_state.len();
107 let output_dim = esn.readout.output_dim();
108 let ext_size = 1 + input_dim + reservoir_size;
109 let leaking_rate = esn.reservoir.leaking_rate;
110 let initial_state = &esn.reservoir.res_state;
111
112 if w_in.values.len() > u16::MAX as usize {
113 return Err(format!("W_in NNZ ({}) exceeds u16::MAX", w_in.values.len()));
114 }
115 if w_res.values.len() > u16::MAX as usize {
116 return Err(format!(
117 "W_res NNZ ({}) exceeds u16::MAX",
118 w_res.values.len()
119 ));
120 }
121
122 let mut code = String::new();
123
124 writeln!(code, "// Auto-generated by reservoir-train::codegen").unwrap();
125 writeln!(code, "use reservoir_infer::reservoir::static_sparse_reservoir::{{StaticCsrMatrix, StaticSparseReservoir}};").unwrap();
126 writeln!(
127 code,
128 "use reservoir_infer::readout::static_readout::StaticReadout;"
129 )
130 .unwrap();
131 writeln!(code, "use nalgebra::{{SMatrix, SVector}};").unwrap();
132 writeln!(code).unwrap();
133
134 writeln!(code, "pub const INPUT_DIM: usize = {};", input_dim).unwrap();
135 writeln!(
136 code,
137 "pub const RESERVOIR_SIZE: usize = {};",
138 reservoir_size
139 )
140 .unwrap();
141 writeln!(code, "pub const OUTPUT_DIM: usize = {};", output_dim).unwrap();
142 writeln!(code, "pub const EXTENDED_SIZE: usize = {};", ext_size).unwrap();
143 writeln!(code, "pub const LEAKING_RATE: f32 = {:.8};", leaking_rate).unwrap();
144 writeln!(code).unwrap();
145
146 writeln!(code, "pub const W_IN_NROWS: usize = {};", w_in.nrows).unwrap();
147 writeln!(code, "pub const W_IN_NCOLS: usize = {};", w_in.ncols).unwrap();
148 writeln!(code, "pub const W_RES_NROWS: usize = {};", w_res.nrows).unwrap();
149 writeln!(code, "pub const W_RES_NCOLS: usize = {};", w_res.ncols).unwrap();
150 writeln!(code, "pub const W_IN_NNZ: usize = {};", w_in.values.len()).unwrap();
151 writeln!(code, "pub const W_RES_NNZ: usize = {};", w_res.values.len()).unwrap();
152 writeln!(code).unwrap();
153
154 let fmt_u16 = |v: &[usize]| -> String {
155 v.iter()
156 .map(|&x| format!("{}", x as u16))
157 .collect::<Vec<_>>()
158 .join(", ")
159 };
160 let fmt_scalar = |v: &[S]| -> String {
161 v.iter()
162 .map(|x| format!("{:.8}", x))
163 .collect::<Vec<_>>()
164 .join(", ")
165 };
166
167 writeln!(
168 code,
169 "pub const W_IN_ROW_PTR: [u16; {}] = [{}];",
170 w_in.row_ptr.len(),
171 fmt_u16(&w_in.row_ptr)
172 )
173 .unwrap();
174 writeln!(
175 code,
176 "pub const W_IN_COL_IDX: [u16; {}] = [{}];",
177 w_in.col_idx.len(),
178 fmt_u16(&w_in.col_idx)
179 )
180 .unwrap();
181 writeln!(
182 code,
183 "pub const W_IN_VALUES: [f32; {}] = [{}];",
184 w_in.values.len(),
185 fmt_scalar(&w_in.values)
186 )
187 .unwrap();
188 writeln!(code).unwrap();
189
190 writeln!(
191 code,
192 "pub const W_RES_ROW_PTR: [u16; {}] = [{}];",
193 w_res.row_ptr.len(),
194 fmt_u16(&w_res.row_ptr)
195 )
196 .unwrap();
197 writeln!(
198 code,
199 "pub const W_RES_COL_IDX: [u16; {}] = [{}];",
200 w_res.col_idx.len(),
201 fmt_u16(&w_res.col_idx)
202 )
203 .unwrap();
204 writeln!(
205 code,
206 "pub const W_RES_VALUES: [f32; {}] = [{}];",
207 w_res.values.len(),
208 fmt_scalar(&w_res.values)
209 )
210 .unwrap();
211 writeln!(code).unwrap();
212
213 let w_out_flat: Vec<S> = w_out.iter().cloned().collect();
214 writeln!(
215 code,
216 "pub const W_OUT_DATA: [f32; {}] = [{}];",
217 w_out_flat.len(),
218 fmt_scalar(&w_out_flat)
219 )
220 .unwrap();
221 writeln!(code).unwrap();
222
223 let state_flat: Vec<S> = initial_state.iter().cloned().collect();
224 writeln!(
225 code,
226 "pub const INITIAL_STATE_DATA: [f32; {}] = [{}];",
227 state_flat.len(),
228 fmt_scalar(&state_flat)
229 )
230 .unwrap();
231
232 Ok(code)
233 }
234}
235
/// Trait for accessing readout weights as a dense matrix.
///
/// The code generator needs a uniform way to retrieve the output weight
/// matrix (`W_out`) from different readout implementations.
///
/// Implementations are provided for [`RidgeReadout`] and [`LassoReadout`].
pub trait GetWeights<S: Scalar> {
    /// Borrow the readout weight matrix.
    ///
    /// For the readouts in `reservoir-infer`, the matrix shape is
    /// `(output_dim, extended_state_dim)`.
    fn weights(&self) -> &nalgebra::DMatrix<S>;
}
249
impl<S: Scalar> GetWeights<S> for RidgeReadout<S> {
    /// Borrow the ridge-trained output weight matrix without copying.
    fn weights(&self) -> &nalgebra::DMatrix<S> {
        &self.w_out
    }
}
255
impl<S: Scalar> GetWeights<S> for LassoReadout<S> {
    /// Borrow the lasso-trained output weight matrix without copying.
    fn weights(&self) -> &nalgebra::DMatrix<S> {
        &self.w_out
    }
}