use burn::backend::NdArray;
use burn::tensor::Tensor;
use ncps::rnn::CfC;
fn main() {
println!("=== NCPS Training Example ===\n");
type Backend = NdArray<f32>;
let device = Default::default();
let model = CfC::<Backend>::new(8, 16, &device);
println!("Training setup:");
println!(" Model: CfC with 8 input, 16 hidden units");
println!(" Task: Predict next value in sinusoidal sequence");
println!();
let seq_len = 20;
let batch_size = 4;
let input_features = 8;
println!(
"Generating {} synthetic sequences of length {}",
batch_size, seq_len
);
let input = Tensor::<Backend, 3>::random(
[batch_size, seq_len, input_features],
burn::tensor::Distribution::Uniform(-1.0, 1.0),
&device,
);
let target = Tensor::<Backend, 3>::random(
[batch_size, seq_len, input_features],
burn::tensor::Distribution::Uniform(-1.0, 1.0),
&device,
);
println!(
" Input shape: {:?}",
[batch_size, seq_len, input_features]
);
println!(
" Target shape: {:?}",
[batch_size, seq_len, input_features]
);
println!();
println!("Forward pass...");
let (predictions, _) = model.forward(input, None, None);
println!(" Prediction shape: {:?}", predictions.dims());
println!();
println!("Training loop structure:");
println!(" 1. Forward: predictions = model.forward(input)");
println!(" 2. Loss: loss = mse(predictions, target)");
println!(" 3. Backward: grads = loss.backward()");
println!(" 4. Update: model.update_params(grads, optimizer)");
println!();
println!("State persistence demo:");
let initial_state = None;
let (out1, state1) = model.forward(
Tensor::<Backend, 3>::random(
[1, 5, 8],
burn::tensor::Distribution::Uniform(-1.0, 1.0),
&device,
),
initial_state,
None,
);
let (out2, state2) = model.forward(
Tensor::<Backend, 3>::random(
[1, 5, 8],
burn::tensor::Distribution::Uniform(-1.0, 1.0),
&device,
),
Some(state1),
None,
);
println!(" Batch 1 processed, state updated");
println!(" Batch 2 processed with previous state");
println!(" This is useful for continuous learning!");
println!();
println!("=== Training Example completed! ===");
println!("\nNext steps:");
println!(" - Set up an autodiff backend (WGPU or Candle)");
println!(" - Define your loss function");
println!(" - Create training data loaders");
println!(" - Implement optimization loop");
}