1#![allow(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25 backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
26 module::{Module, Param},
27 nn::loss::{MseLoss, Reduction},
28 optim::{AdamConfig, GradientsParams, Optimizer},
29 prelude::Backend,
30 record::{
31 BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32 Record, Recorder, RecorderError,
33 },
34 tensor::{
35 backend::AutodiffBackend, Distribution, Element as BurnElement, Float, Tensor, TensorData,
36 },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43 Codec, StaticCodec, StaticCodecConfig,
44};
45use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49#[cfg(test)]
50use ::serde_json as _;
51
52mod modules;
53
54use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
55
/// Codec that encodes floating point data by training a Fourier feature
/// neural network on it and storing the trained network as bytes.
///
/// Decoding evaluates the stored network on the coordinate grid of the output
/// array, so the output array must be provided when decoding.
///
/// Only `f32` and `f64` data are supported.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of Fourier features that the data coordinates are projected
    /// onto (the network input has a sine and a cosine column per feature)
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the zero-mean normal distribution from which
    /// the Fourier projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the network
    pub num_blocks: NonZeroUsize,
    /// The learning rate passed to the `Adam` optimizer
    pub learning_rate: Positive<f64>,
    /// The number of epochs for which the network is trained
    pub num_epochs: usize,
    /// The optional mini-batch size used during training
    ///
    /// `None` disables mini-batching, i.e. every epoch trains on one single
    /// batch that contains the full data.
    ///
    /// The field must be present in the configuration, but may be `null`.
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during encoding
    pub seed: u64,
}
86
87fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
89 deserializer: D,
90) -> Result<Option<T>, D::Error> {
91 Option::<T>::deserialize(deserializer)
92}
93
94impl Codec for FourierNetworkCodec {
95 type Error = FourierNetworkCodecError;
96
97 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
98 match data {
99 AnyCowArray::F32(data) => Ok(AnyArray::U8(
100 encode::<f32, _, _, Autodiff<NdArray<f32>>>(
101 &NdArrayDevice::Cpu,
102 data,
103 self.fourier_features,
104 self.fourier_scale,
105 self.num_blocks,
106 self.learning_rate,
107 self.num_epochs,
108 self.mini_batch_size,
109 self.seed,
110 )?
111 .into_dyn(),
112 )),
113 AnyCowArray::F64(data) => Ok(AnyArray::U8(
114 encode::<f64, _, _, Autodiff<NdArray<f64>>>(
115 &NdArrayDevice::Cpu,
116 data,
117 self.fourier_features,
118 self.fourier_scale,
119 self.num_blocks,
120 self.learning_rate,
121 self.num_epochs,
122 self.mini_batch_size,
123 self.seed,
124 )?
125 .into_dyn(),
126 )),
127 encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
128 }
129 }
130
131 fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
132 Err(FourierNetworkCodecError::MissingDecodingOutput)
133 }
134
135 fn decode_into(
136 &self,
137 encoded: AnyArrayView,
138 decoded: AnyArrayViewMut,
139 ) -> Result<(), Self::Error> {
140 let AnyArrayView::U8(encoded) = encoded else {
141 return Err(FourierNetworkCodecError::EncodedDataNotBytes {
142 dtype: encoded.dtype(),
143 });
144 };
145
146 let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
147 return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
148 shape: encoded.shape().to_vec(),
149 });
150 };
151
152 match decoded {
153 AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
154 &NdArrayDevice::Cpu,
155 encoded,
156 decoded,
157 self.fourier_features,
158 self.num_blocks,
159 ),
160 AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
161 &NdArrayDevice::Cpu,
162 encoded,
163 decoded,
164 self.fourier_features,
165 self.num_blocks,
166 ),
167 decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
168 }
169 }
170}
171
impl StaticCodec for FourierNetworkCodec {
    /// Identifier under which this codec is registered
    const CODEC_ID: &'static str = "fourier-network";

    /// The codec is its own configuration type
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration for this codec
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
185
186#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
188pub struct Positive<T: FloatTrait>(T);
190
191impl Serialize for Positive<f64> {
192 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
193 serializer.serialize_f64(self.0)
194 }
195}
196
197impl<'de> Deserialize<'de> for Positive<f64> {
198 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
199 let x = f64::deserialize(deserializer)?;
200
201 if x > 0.0 {
202 Ok(Self(x))
203 } else {
204 Err(serde::de::Error::invalid_value(
205 serde::de::Unexpected::Float(x),
206 &"a positive value",
207 ))
208 }
209 }
210}
211
212impl JsonSchema for Positive<f64> {
213 fn schema_name() -> Cow<'static, str> {
214 Cow::Borrowed("PositiveF64")
215 }
216
217 fn schema_id() -> Cow<'static, str> {
218 Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
219 }
220
221 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
222 json_schema!({
223 "type": "number",
224 "exclusiveMinimum": 0.0
225 })
226 }
227}
228
/// Errors that may occur when applying the [`FourierNetworkCodec`].
#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN)
    /// floating point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed while recording or loading the neural
    /// network
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The source of the error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during
    /// decoding
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays
    /// but received an array of a different dtype
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays
    /// but received a byte array of a different shape
    #[error("FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
273
/// Opaque error for when recording or loading the neural network fails.
///
/// Wraps the underlying [`RecorderError`] and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);
278
/// Floating point types that the Fourier network can be trained on.
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The precision settings with which the network is recorded for this
    /// floating point type
    type Precision: PrecisionSettings;

    /// Converts a `usize` into this floating point type
    fn from_usize(x: usize) -> Self;
}
289
290impl FloatExt for f32 {
291 type Precision = FullPrecisionSettings;
292
293 #[allow(clippy::cast_precision_loss)]
294 fn from_usize(x: usize) -> Self {
295 x as Self
296 }
297}
298
299impl FloatExt for f64 {
300 type Precision = DoublePrecisionSettings;
301
302 #[allow(clippy::cast_precision_loss)]
303 fn from_usize(x: usize) -> Self {
304 x as Self
305 }
306}
307
308#[allow(clippy::similar_names)] #[allow(clippy::missing_panics_doc)] #[allow(clippy::too_many_arguments)] pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
331 device: &B::Device,
332 data: ArrayBase<S, D>,
333 fourier_features: NonZeroUsize,
334 fourier_scale: Positive<f64>,
335 num_blocks: NonZeroUsize,
336 learning_rate: Positive<f64>,
337 num_epochs: usize,
338 mini_batch_size: Option<NonZeroUsize>,
339 seed: u64,
340) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
341 let Some(mean) = data.mean() else {
342 return Ok(Array::from_vec(Vec::new()));
343 };
344 let stdv = data.std(T::ZERO);
345 let stdv = if stdv == T::ZERO { T::ONE } else { stdv };
346
347 if !Zip::from(&data).all(|x| x.is_finite()) {
348 return Err(FourierNetworkCodecError::NonFiniteData);
349 }
350
351 B::seed(seed);
352
353 let b_t = Tensor::<B, 2, Float>::random(
354 [data.ndim(), fourier_features.get()],
355 Distribution::Normal(0.0, fourier_scale.0),
356 device,
357 );
358
359 let train_xs = flat_grid_like(&data, device);
360 let train_xs = fourier_mapping(train_xs, b_t.clone());
361
362 let train_ys_shape = [data.len(), 1];
363 let mut train_ys = data.into_owned();
364 train_ys.mapv_inplace(|x| (x - mean) / stdv);
365 #[allow(clippy::unwrap_used)] let train_ys = train_ys
367 .into_shape_clone((train_ys_shape, Order::RowMajor))
368 .unwrap();
369 let train_ys = Tensor::from_data(
370 TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
371 device,
372 );
373
374 let model = train(
375 device,
376 &train_xs,
377 &train_ys,
378 fourier_features,
379 num_blocks,
380 learning_rate,
381 num_epochs,
382 mini_batch_size,
383 stdv,
384 );
385
386 let extra = ModelExtra {
387 model,
388 b_t: Param::from_tensor(b_t).set_require_grad(false),
389 mean: Param::from_tensor(Tensor::from_data(
390 TensorData::new(vec![mean], vec![1]),
391 device,
392 ))
393 .set_require_grad(false),
394 stdv: Param::from_tensor(Tensor::from_data(
395 TensorData::new(vec![stdv], vec![1]),
396 device,
397 ))
398 .set_require_grad(false),
399 };
400
401 let recorder = BinBytesRecorder::<T::Precision>::new();
402 let encoded = recorder
403 .record(extra.into_record(), ())
404 .map_err(NeuralNetworkError)?;
405
406 Ok(Array::from_vec(encoded))
407}
408
409#[allow(clippy::missing_panics_doc)] pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
424 device: &B::Device,
425 encoded: ArrayBase<S, Ix1>,
426 mut decoded: ArrayViewMut<T, D>,
427 fourier_features: NonZeroUsize,
428 num_blocks: NonZeroUsize,
429) -> Result<(), FourierNetworkCodecError> {
430 if encoded.is_empty() {
431 if decoded.is_empty() {
432 return Ok(());
433 }
434
435 return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
436 source: AnyArrayAssignError::ShapeMismatch {
437 src: encoded.shape().to_vec(),
438 dst: decoded.shape().to_vec(),
439 },
440 });
441 }
442
443 let encoded = encoded.into_owned().into_raw_vec_and_offset().0;
444
445 let recorder = BinBytesRecorder::<T::Precision>::new();
446 let record = recorder.load(encoded, device).map_err(NeuralNetworkError)?;
447
448 let extra = ModelExtra::<B> {
449 model: ModelConfig::new(fourier_features, num_blocks).init(device),
450 b_t: Param::from_tensor(Tensor::zeros(
451 [decoded.ndim(), fourier_features.get()],
452 device,
453 ))
454 .set_require_grad(false),
455 mean: Param::from_tensor(Tensor::zeros([1], device)).set_require_grad(false),
456 stdv: Param::from_tensor(Tensor::ones([1], device)).set_require_grad(false),
457 }
458 .load_record(record);
459
460 let model = extra.model;
461 let b_t = extra.b_t.into_value();
462 let mean = extra.mean.into_value().into_scalar();
463 let stdv = extra.stdv.into_value().into_scalar();
464
465 let test_xs = flat_grid_like(&decoded, device);
466 let test_xs = fourier_mapping(test_xs, b_t);
467
468 let prediction = model.forward(test_xs).into_data();
469 #[allow(clippy::unwrap_used)] let prediction = prediction.as_slice().unwrap();
471
472 #[allow(clippy::unwrap_used)] decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
474 decoded.mapv_inplace(|x| (x * stdv) + mean);
475
476 Ok(())
477}
478
479fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
480 a: &ArrayBase<S, D>,
481 device: &B::Device,
482) -> Tensor<B, 2, Float> {
483 let grid = a
484 .shape()
485 .iter()
486 .copied()
487 .map(|s| {
488 #[allow(clippy::useless_conversion)] (0..s)
490 .into_iter()
491 .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
492 })
493 .multi_cartesian_product()
494 .flatten()
495 .collect::<Vec<_>>();
496
497 Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
498}
499
500fn fourier_mapping<B: Backend>(
501 xs: Tensor<B, 2, Float>,
502 b_t: Tensor<B, 2, Float>,
503) -> Tensor<B, 2, Float> {
504 let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
505
506 Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
507}
508
509#[allow(clippy::similar_names)] #[allow(clippy::too_many_arguments)] fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
512 device: &B::Device,
513 train_xs: &Tensor<B, 2, Float>,
514 train_ys: &Tensor<B, 2, Float>,
515 fourier_features: NonZeroUsize,
516 num_blocks: NonZeroUsize,
517 learning_rate: Positive<f64>,
518 num_epochs: usize,
519 mini_batch_size: Option<NonZeroUsize>,
520 stdv: T,
521) -> Model<B> {
522 let num_samples = train_ys.shape().num_elements();
523 let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));
524
525 let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
526 let mut optim = AdamConfig::new().init();
527
528 let mut best_loss = T::infinity();
529 let mut best_epoch = 0;
530 let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
531
532 for epoch in 1..=num_epochs {
533 #[allow(clippy::option_if_let_else)]
534 let (train_xs_batches, train_ys_batches) = match num_batches {
535 Some(num_batches) => {
536 let shuffle = Tensor::<B, 1, Float>::random(
537 [num_samples],
538 Distribution::Uniform(0.0, 1.0),
539 device,
540 );
541 let shuffle_indices = shuffle.argsort(0);
542
543 let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
544 let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);
545
546 (
547 train_xs_shuffled.chunk(num_batches, 0),
548 train_ys_shuffled.chunk(num_batches, 0),
549 )
550 }
551 None => (vec![train_xs.clone()], vec![train_ys.clone()]),
552 };
553
554 let mut loss_sum = T::ZERO;
555
556 let mut se_sum = T::ZERO;
557 let mut ae_sum = T::ZERO;
558 let mut l_inf = T::ZERO;
559
560 for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
561 let prediction = model.forward(train_xs_batch);
562 let loss =
563 MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);
564
565 let grads = GradientsParams::from_grads(loss.backward(), &model);
566 model = optim.step(learning_rate.0, model, grads);
567
568 loss_sum += loss.into_scalar();
569
570 let err = prediction - train_ys_batch;
571
572 se_sum += (err.clone() * err.clone()).sum().into_scalar();
573 ae_sum += err.clone().abs().sum().into_scalar();
574 l_inf = l_inf.max(err.abs().max().into_scalar());
575 }
576
577 let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));
578
579 if loss_mean < best_loss {
580 best_loss = loss_mean;
581 best_epoch = epoch;
582 best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
583 }
584
585 let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
586 let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
587 let l_inf = stdv * l_inf;
588
589 log::info!("[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}");
590 }
591
592 if best_epoch != num_epochs {
593 model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));
594
595 log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
596 }
597
598 model
599}
600
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encodes the `f32` array `data` with a minimal network (one Fourier
    /// feature, one block, ten epochs), decodes the result into a zeroed
    /// array of the same shape, and returns the encoded bytes.
    fn roundtrip_minimal_f32<D: Dimension>(
        data: Array<f32, D>,
        mini_batch_size: Option<NonZeroUsize>,
    ) -> Array<u8, Ix1> {
        std::mem::drop(simple_logger::init());

        let shape = data.raw_dim();

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            data,
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            mini_batch_size,
            42,
        )
        .unwrap();

        let mut decoded = Array::<f32, D>::zeros(shape);
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded.view(),
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();

        encoded
    }

    /// Empty data must encode to empty bytes and decode into an empty array
    #[test]
    fn empty() {
        let encoded = roundtrip_minimal_f32(Array::<f32, _>::zeros((0,)), None);

        assert!(encoded.is_empty());
    }

    /// A single element in an array whose axes all have length one
    #[test]
    fn ones() {
        let _encoded = roundtrip_minimal_f32(Array::<f32, _>::zeros((1, 1, 1, 1)), None);
    }

    /// Constant data, which has a standard deviation of zero
    #[test]
    fn r#const() {
        let _encoded =
            roundtrip_minimal_f32(Array::<f32, _>::from_elem((2, 1, 3), 42.0), None);
    }

    /// Constant data that is trained in mini-batches of two samples
    #[test]
    fn const_batched() {
        let _encoded = roundtrip_minimal_f32(
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            Some(NonZeroUsize::MIN.saturating_add(1)),
        );
    }

    /// A linear ramp of 100 `f64` values with several mini-batch sizes
    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        let mini_batch_sizes = [
            // no mini-batching
            None,
            // one sample per batch
            NonZeroUsize::new(1),
            // batch size that does not divide the number of samples
            NonZeroUsize::new(7),
            // batch size that divides the number of samples
            NonZeroUsize::new(10),
            // batch size that exceeds the number of samples
            NonZeroUsize::new(1001),
        ];

        for mini_batch_size in mini_batch_sizes {
            let mut decoded = Array::<f64, _>::zeros(data.shape());

            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}