Skip to main content

Variable

Struct Variable 

Source
pub struct Variable { /* private fields */ }
Expand description

A tensor with automatic differentiation support.

Variable wraps a Tensor and tracks operations performed on it to enable automatic gradient computation. When requires_grad is true, all operations are recorded in a computational graph.

Implementations

Source

impl Variable

Source

pub fn new(data: Tensor<f32>, requires_grad: bool) -> Variable

Creates a new variable from a tensor.

§Arguments
  • data - The tensor data
  • requires_grad - Whether to track gradients for this variable
Examples found in repository:
examples/hvac_model.rs (line 190)
165    fn mean_pool(&self, x: &Variable) -> Variable {
166        let data = x.data();
167        let shape = data.shape();
168        let batch_size = shape[0];
169        let seq_len = shape[1];
170        let hidden = shape[2];
171
172        // Reshape to [batch * seq, hidden] then back
173        let values = data.to_vec();
174
175        // Calculate mean over sequence dimension
176        let mut pooled = vec![0.0f32; batch_size * hidden];
177        for b in 0..batch_size {
178            for h in 0..hidden {
179                let mut sum = 0.0;
180                for s in 0..seq_len {
181                    let idx = b * seq_len * hidden + s * hidden + h;
182                    sum += values[idx];
183                }
184                pooled[b * hidden + h] = sum / seq_len as f32;
185            }
186        }
187
188        let pooled_tensor = Tensor::from_vec(pooled, &[batch_size, hidden])
189            .expect("Failed to create pooled tensor");
190        Variable::new(pooled_tensor, x.requires_grad())
191    }
192
193    /// Forward pass returning logits for all 3 horizons
194    pub fn forward_multi(&self, x: &Variable) -> HvacOutput {
195        let x_data = x.data();
196        let shape = x_data.shape();
197        let batch_size = shape[0];
198        let seq_len = shape[1];
199        drop(x_data); // Release borrow
200
201        // Input projection: [batch, seq, features] -> [batch, seq, hidden]
202        // Reshape for linear: [batch * seq, features]
203        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
204        let proj = self.input_proj.forward(&x_flat);
205        let proj = self.input_norm.forward(&proj);
206        let proj = self.input_relu.forward(&proj);
207        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
208
209        // GRU encoding: [batch, seq, hidden] -> [batch, seq, hidden]
210        let encoded = self.gru.forward(&proj);
211
212        // Mean pooling: [batch, seq, hidden] -> [batch, hidden]
213        let pooled = self.mean_pool(&encoded);
214
215        // Prediction heads
216        let imminent_logits = self.head_imminent.forward(&pooled);
217        let warning_logits = self.head_warning.forward(&pooled);
218        let early_logits = self.head_early.forward(&pooled);
219
220        HvacOutput {
221            imminent_logits,
222            warning_logits,
223            early_logits,
224        }
225    }
226
227    /// Get predicted classes (argmax of logits)
228    pub fn predict(&self, x: &Variable) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
229        let output = self.forward_multi(x);
230
231        let imminent_probs = self.softmax.forward(&output.imminent_logits);
232        let warning_probs = self.softmax.forward(&output.warning_logits);
233        let early_probs = self.softmax.forward(&output.early_logits);
234
235        (
236            argmax_batch(&imminent_probs),
237            argmax_batch(&warning_probs),
238            argmax_batch(&early_probs),
239        )
240    }
241
242    /// Returns the model configuration
243    pub fn config(&self) -> &HvacConfig {
244        &self.config
245    }
246
247    /// Returns the number of trainable parameters
248    pub fn num_parameters(&self) -> usize {
249        self.parameters()
250            .iter()
251            .map(|p| p.variable().data().numel())
252            .sum()
253    }
254}
255
256impl Module for HvacPredictor {
257    fn forward(&self, x: &Variable) -> Variable {
258        // Return concatenated logits for all horizons
259        let output = self.forward_multi(x);
260        // For single output, return imminent predictions
261        output.imminent_logits
262    }
263
264    fn parameters(&self) -> Vec<Parameter> {
265        let mut params = self.input_proj.parameters();
266        params.extend(self.input_norm.parameters());
267        params.extend(self.gru.parameters());
268        params.extend(self.head_imminent.parameters());
269        params.extend(self.head_warning.parameters());
270        params.extend(self.head_early.parameters());
271        params
272    }
273}
274
275// =============================================================================
276// Helper Functions
277// =============================================================================
278
279/// Get argmax for each sample in batch
280fn argmax_batch(x: &Variable) -> Vec<usize> {
281    let data = x.data();
282    let shape = data.shape();
283    let batch_size = shape[0];
284    let num_classes = shape[1];
285    let values = data.to_vec();
286
287    let mut results = Vec::with_capacity(batch_size);
288    for b in 0..batch_size {
289        let start = b * num_classes;
290        let end = start + num_classes;
291        let slice = &values[start..end];
292
293        let mut max_idx = 0;
294        let mut max_val = slice[0];
295        for (i, &v) in slice.iter().enumerate() {
296            if v > max_val {
297                max_val = v;
298                max_idx = i;
299            }
300        }
301        results.push(max_idx);
302    }
303    results
304}
305
306/// Failure type names
307pub const FAILURE_TYPES: [&str; 20] = [
308    "normal",
309    "pump_failure_hw_5",
310    "pump_failure_hw_6",
311    "pump_failure_cw_3",
312    "pump_failure_cw_4",
313    "pump_failure_2pipe_a",
314    "pump_failure_2pipe_b",
315    "pressure_low_hw",
316    "pressure_high_hw",
317    "pressure_low_cw",
318    "pressure_high_cw",
319    "temp_anomaly_hw_supply",
320    "temp_anomaly_cw_supply",
321    "temp_anomaly_space",
322    "valve_stuck_1_3",
323    "valve_stuck_2_3",
324    "vfd_fault",
325    "sensor_drift",
326    "chiller_fault",
327    "interlock_violation",
328];
329
330/// Feature names for the 28 sensor inputs
331pub const FEATURE_NAMES: [&str; 28] = [
332    "hw_pump_5_current",
333    "hw_pump_6_current",
334    "cw_pump_3_current",
335    "cw_pump_4_current",
336    "2pipe_pump_a_current",
337    "2pipe_pump_b_current",
338    "hw_supply_4pipe_temp",
339    "cw_supply_4pipe_temp",
340    "hw_supply_2pipe_temp",
341    "cw_return_2pipe_temp",
342    "outdoor_air_temp",
343    "mech_room_temp",
344    "space_sensor_1_temp",
345    "space_sensor_2_temp",
346    "hw_pressure_4pipe",
347    "cw_pressure_4pipe",
348    "hw_pump_5_vfd_speed",
349    "hw_pump_6_vfd_speed",
350    "cw_pump_3_vfd_speed",
351    "cw_pump_4_vfd_speed",
352    "2pipe_pump_a_vfd_speed",
353    "2pipe_pump_b_vfd_speed",
354    "steam_valve_1_3_pos",
355    "steam_valve_2_3_pos",
356    "summer_winter_mode",
357    "hw_lead_pump_id",
358    "cw_lead_pump_id",
359    "2pipe_lead_pump_id",
360];
361
362// =============================================================================
363// Main
364// =============================================================================
365
366fn main() {
367    println!("╔════════════════════════════════════════════════════════════╗");
368    println!("║     HVAC Multi-Horizon Predictor - AxonML Native           ║");
369    println!("╚════════════════════════════════════════════════════════════╝");
370    println!();
371
372    // Create model with default config
373    let config = HvacConfig::default();
374    println!("Model Configuration:");
375    println!("  Input features: {}", config.num_features);
376    println!("  Sequence length: {}", config.seq_len);
377    println!("  Hidden size: {}", config.hidden_size);
378    println!("  GRU layers: {}", config.num_layers);
379    println!("  Output classes: {}", config.num_classes);
380    println!("  Dropout: {}", config.dropout);
381    println!();
382
383    let model = HvacPredictor::new(config.clone());
384    println!("Model created!");
385    println!("  Total parameters: {}", model.num_parameters());
386    println!();
387
388    // Create sample input
389    let batch_size = 2;
390    let mut input_data = vec![0.5f32; batch_size * config.seq_len * config.num_features];
391
392    // Simulate normal HVAC readings
393    for b in 0..batch_size {
394        for t in 0..config.seq_len {
395            let base = (b * config.seq_len + t) * config.num_features;
396            // Pump currents ~25A (normalized)
397            for i in 0..6 {
398                input_data[base + i] = 0.5;
399            }
400            // Temperatures (normalized)
401            input_data[base + 6] = 0.83; // HW supply ~180F
402            input_data[base + 7] = 0.375; // CW supply ~55F
403            // VFD speeds ~60%
404            for i in 16..22 {
405                input_data[base + i] = 0.6;
406            }
407        }
408    }
409
410    let input = Tensor::from_vec(
411        input_data,
412        &[batch_size, config.seq_len, config.num_features],
413    )
414    .expect("Failed to create input tensor");
415
416    let input_var = Variable::new(input, false);
417    println!("Input shape: {:?}", input_var.data().shape());
418
419    // Run inference
420    println!();
421    println!("Running inference...");
422    let (imminent, warning, early) = model.predict(&input_var);
423
424    println!();
425    println!("Predictions:");
426    println!("────────────────────────────────────────────────────────────");
427    for b in 0..batch_size {
428        println!("Sample {}:", b);
429        println!(
430            "  5 min (Imminent): {} - {}",
431            imminent[b], FAILURE_TYPES[imminent[b]]
432        );
433        println!(
434            "  15 min (Warning): {} - {}",
435            warning[b], FAILURE_TYPES[warning[b]]
436        );
437        println!(
438            "  30 min (Early):   {} - {}",
439            early[b], FAILURE_TYPES[early[b]]
440        );
441    }
442    println!("────────────────────────────────────────────────────────────");
443    println!();
444    println!("Model ready for training with your HVAC sensor data!");
445}
More examples
Hide additional examples
examples/hvac_training.rs (line 779)
737fn train_epoch(
738    model: &HvacPredictor,
739    optimizer: &mut Adam,
740    loss_fn: &CrossEntropyLoss,
741    x_data: &[f32],
742    y_imminent: &[i64],
743    y_warning: &[i64],
744    y_early: &[i64],
745    batch_size: usize,
746    seq_len: usize,
747    num_features: usize,
748) -> (f32, f32, f32, f32) {
749    let n_sequences = y_imminent.len();
750    let n_batches = n_sequences / batch_size;
751
752    let mut total_loss = 0.0f32;
753    let mut total_acc_imm = 0.0f32;
754    let mut total_acc_warn = 0.0f32;
755    let mut total_acc_early = 0.0f32;
756
757    for batch_idx in 0..n_batches {
758        let start = batch_idx * batch_size;
759
760        // Prepare batch data
761        let mut batch_x = vec![0.0f32; batch_size * seq_len * num_features];
762        let mut batch_y_imm = vec![0i64; batch_size];
763        let mut batch_y_warn = vec![0i64; batch_size];
764        let mut batch_y_early = vec![0i64; batch_size];
765
766        for b in 0..batch_size {
767            let seq_start = (start + b) * seq_len * num_features;
768            for i in 0..(seq_len * num_features) {
769                batch_x[b * seq_len * num_features + i] = x_data[seq_start + i];
770            }
771            batch_y_imm[b] = y_imminent[start + b];
772            batch_y_warn[b] = y_warning[start + b];
773            batch_y_early[b] = y_early[start + b];
774        }
775
776        // Create tensors
777        let x_tensor = Tensor::from_vec(batch_x, &[batch_size, seq_len, num_features])
778            .expect("Failed to create input tensor");
779        let x_var = Variable::new(x_tensor, true);
780
781        // Forward pass
782        let (logits_imm, logits_warn, logits_early) = model.forward_multi(&x_var);
783
784        // Calculate losses (simplified - just use imminent for now)
785        let y_imm_tensor = Tensor::from_vec(
786            batch_y_imm.iter().map(|&y| y as f32).collect(),
787            &[batch_size],
788        )
789        .expect("Failed to create label tensor");
790        let y_imm_var = Variable::new(y_imm_tensor, false);
791
792        let loss = loss_fn.compute(&logits_imm, &y_imm_var);
793
794        // Backward pass
795        optimizer.zero_grad();
796        loss.backward();
797        optimizer.step();
798
799        // Metrics
800        total_loss += loss.data().to_vec()[0];
801        total_acc_imm += calculate_accuracy(&logits_imm, &batch_y_imm);
802        total_acc_warn += calculate_accuracy(&logits_warn, &batch_y_warn);
803        total_acc_early += calculate_accuracy(&logits_early, &batch_y_early);
804    }
805
806    let n = n_batches as f32;
807    (
808        total_loss / n,
809        total_acc_imm / n,
810        total_acc_warn / n,
811        total_acc_early / n,
812    )
813}
examples/simple_training.rs (line 62)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
examples/mnist_training.rs (lines 95-102)
20fn main() {
21    println!("=== AxonML - MNIST Training (LeNet) ===\n");
22
23    // Detect device
24    #[cfg(feature = "cuda")]
25    let device = {
26        let cuda = Device::Cuda(0);
27        if cuda.is_available() {
28            println!("GPU detected: using CUDA device 0");
29            cuda
30        } else {
31            println!("CUDA feature enabled but no GPU available, using CPU");
32            Device::Cpu
33        }
34    };
35    #[cfg(not(feature = "cuda"))]
36    let device = {
37        println!("Using CPU (compile with --features cuda for GPU)");
38        Device::Cpu
39    };
40
41    // 1. Create dataset
42    let num_train = 2000;
43    let num_test = 400;
44    println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
45    let train_dataset = SyntheticMNIST::new(num_train);
46    let test_dataset = SyntheticMNIST::new(num_test);
47
48    // 2. Create DataLoader
49    let batch_size = 64;
50    println!("2. Creating DataLoader (batch_size={batch_size})...");
51    let train_loader = DataLoader::new(train_dataset, batch_size);
52    let test_loader = DataLoader::new(test_dataset, batch_size);
53    println!("   Training batches: {}", train_loader.len());
54
55    // 3. Create LeNet model and move to device
56    println!("3. Creating LeNet model...");
57    let model = LeNet::new();
58    model.to_device(device);
59    let params = model.parameters();
60    let total_params: usize = params
61        .iter()
62        .map(|p| p.variable().data().to_vec().len())
63        .sum();
64    println!(
65        "   Parameters: {} ({} total weights)",
66        params.len(),
67        total_params
68    );
69    println!("   Device: {:?}", device);
70
71    // 4. Create optimizer and loss
72    println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
73    let mut optimizer = Adam::new(params, 0.001);
74    let criterion = CrossEntropyLoss::new();
75
76    // 5. Training loop
77    let epochs = 10;
78    println!("5. Training for {epochs} epochs...\n");
79
80    let train_start = Instant::now();
81
82    for epoch in 0..epochs {
83        let epoch_start = Instant::now();
84        let mut total_loss = 0.0;
85        let mut correct = 0usize;
86        let mut total = 0usize;
87        let mut batch_count = 0;
88
89        for batch in train_loader.iter() {
90            let bs = batch.data.shape()[0];
91
92            // Reshape to [N, 1, 28, 28] and create Variable
93            let input_data = batch.data.to_vec();
94            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
95            let input = Variable::new(
96                if device.is_gpu() {
97                    input_tensor.to_device(device).unwrap()
98                } else {
99                    input_tensor
100                },
101                true,
102            );
103
104            // Target: convert one-hot [N, 10] to class indices [N]
105            let target_onehot = batch.targets.to_vec();
106            let mut target_indices = vec![0.0f32; bs];
107            for i in 0..bs {
108                let offset = i * 10;
109                let mut max_idx = 0;
110                let mut max_val = f32::NEG_INFINITY;
111                for c in 0..10 {
112                    if target_onehot[offset + c] > max_val {
113                        max_val = target_onehot[offset + c];
114                        max_idx = c;
115                    }
116                }
117                target_indices[i] = max_idx as f32;
118            }
119            let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
120            let target = Variable::new(
121                if device.is_gpu() {
122                    target_tensor.to_device(device).unwrap()
123                } else {
124                    target_tensor
125                },
126                false,
127            );
128
129            // Forward pass
130            let output = model.forward(&input);
131
132            // Cross-entropy loss
133            let loss = criterion.compute(&output, &target);
134
135            let loss_val = loss.data().to_vec()[0];
136            total_loss += loss_val;
137            batch_count += 1;
138
139            // Compute training accuracy
140            let out_data = output.data().to_vec();
141            for i in 0..bs {
142                let offset = i * 10;
143                let mut pred = 0;
144                let mut pred_val = f32::NEG_INFINITY;
145                for c in 0..10 {
146                    if out_data[offset + c] > pred_val {
147                        pred_val = out_data[offset + c];
148                        pred = c;
149                    }
150                }
151                if pred == target_indices[i] as usize {
152                    correct += 1;
153                }
154                total += 1;
155            }
156
157            // Backward pass
158            loss.backward();
159
160            // Update weights
161            optimizer.step();
162            optimizer.zero_grad();
163        }
164
165        let epoch_time = epoch_start.elapsed();
166        let avg_loss = total_loss / batch_count as f32;
167        let accuracy = 100.0 * correct as f32 / total as f32;
168        let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
169
170        println!(
171            "   Epoch {:2}/{}: Loss={:.4}  Acc={:.1}%  ({:.0} samples/s, {:.2}s)",
172            epoch + 1,
173            epochs,
174            avg_loss,
175            accuracy,
176            samples_per_sec,
177            epoch_time.as_secs_f64(),
178        );
179    }
180
181    let train_time = train_start.elapsed();
182    println!("\n   Total training time: {:.2}s", train_time.as_secs_f64());
183
184    // 6. Test evaluation
185    println!("\n6. Evaluating on test set...");
186
187    // Disable gradient computation for evaluation
188    let (correct, total) = no_grad(|| {
189        let mut correct = 0usize;
190        let mut total = 0usize;
191
192        for batch in test_loader.iter() {
193            let bs = batch.data.shape()[0];
194
195            let input_data = batch.data.to_vec();
196            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
197            let input = Variable::new(
198                if device.is_gpu() {
199                    input_tensor.to_device(device).unwrap()
200                } else {
201                    input_tensor
202                },
203                false,
204            );
205
206            let target_onehot = batch.targets.to_vec();
207            let output = model.forward(&input);
208            let out_data = output.data().to_vec();
209
210            for i in 0..bs {
211                // Prediction: argmax of output
212                let offset = i * 10;
213                let mut pred = 0;
214                let mut pred_val = f32::NEG_INFINITY;
215                for c in 0..10 {
216                    if out_data[offset + c] > pred_val {
217                        pred_val = out_data[offset + c];
218                        pred = c;
219                    }
220                }
221
222                // True label: argmax of one-hot target
223                let mut true_label = 0;
224                let mut true_val = f32::NEG_INFINITY;
225                for c in 0..10 {
226                    if target_onehot[i * 10 + c] > true_val {
227                        true_val = target_onehot[i * 10 + c];
228                        true_label = c;
229                    }
230                }
231
232                if pred == true_label {
233                    correct += 1;
234                }
235                total += 1;
236            }
237        }
238
239        (correct, total)
240    });
241
242    let test_accuracy = 100.0 * correct as f32 / total as f32;
243    println!(
244        "   Test Accuracy: {}/{} ({:.2}%)",
245        correct, total, test_accuracy
246    );
247
248    println!("\n=== Training Complete! ===");
249    println!("   Device: {:?}", device);
250    println!("   Final test accuracy: {:.2}%", test_accuracy);
251}
examples/train_panoptes.rs (lines 98-101)
56fn main() {
57    println!("╔══════════════════════════════════════════════════════════════╗");
58    println!("║     PANOPTES — Facility-Wide Anomaly Detection Training     ║");
59    println!("║     Heritage Pointe of Warren (59 equipment)                ║");
60    println!("╚══════════════════════════════════════════════════════════════╝");
61    println!();
62
63    // =========================================================================
64    // Generate training data
65    // =========================================================================
66    println!("[data] Generating physics-informed training data...");
67    let t0 = Instant::now();
68
69    let sim = WarrenSimulator::new(SEED);
70    let normal_train = sim.generate_normal(NORMAL_SAMPLES);
71    let fault_data = sim.generate_with_faults(FAULT_SAMPLES, 1.0);
72
73    // Validation set (different seed)
74    let val_sim = WarrenSimulator::new(SEED + 999);
75    let normal_val = val_sim.generate_normal(200);
76    let fault_val = val_sim.generate_with_faults(100, 1.0);
77
78    println!("  Normal train: {} samples", normal_train.len());
79    println!("  Fault train:  {} samples", fault_data.len());
80    println!("  Normal val:   {} samples", normal_val.len());
81    println!("  Fault val:    {} samples", fault_val.len());
82    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
83    println!();
84
85    // =========================================================================
86    // Create model
87    // =========================================================================
88    let model = Panoptes::new(NUM_EQUIPMENT);
89    println!("[model] Panoptes created");
90    println!("  Equipment slots: {NUM_EQUIPMENT}");
91    println!("  Parameters: {}", model.num_parameters());
92    println!("  Embed dim: {EMBED_DIM}");
93    println!();
94
95    let mse = MSELoss::new();
96
97    // Zero target for normal operation
98    let zero_target = Variable::new(
99        Tensor::from_vec(vec![0.0; NUM_EQUIPMENT], &[1, NUM_EQUIPMENT]).unwrap(),
100        false,
101    );
102
103    // =========================================================================
104    // Phase 1: Learn normal operation
105    // =========================================================================
106    println!("═══════════════════════════════════════════════════════════════");
107    println!(" PHASE 1: Learning Normal Operation ({PHASE1_EPOCHS} epochs)");
108    println!("═══════════════════════════════════════════════════════════════");
109    println!(
110        "  {:>5}  {:>12}  {:>12}  {:>8}",
111        "Epoch", "Train Loss", "Val Loss", "Time"
112    );
113    println!("  {:-<5}  {:-<12}  {:-<12}  {:-<8}", "", "", "", "");
114
115    let params = model.parameters();
116    let mut optimizer = Adam::new(params, LR);
117
118    for epoch in 1..=PHASE1_EPOCHS {
119        let epoch_start = Instant::now();
120        let mut epoch_loss = 0.0f32;
121        let mut batch_count = 0;
122
123        // Train on normal data: target = all zeros
124        for batch_start in (0..normal_train.len()).step_by(BATCH_SIZE) {
125            let batch_end = (batch_start + BATCH_SIZE).min(normal_train.len());
126
127            for i in batch_start..batch_end {
128                optimizer.zero_grad();
129
130                let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
131                let loss = mse.compute(&equip_scores, &zero_target);
132                let loss_val = loss.data().to_vec()[0];
133                epoch_loss += loss_val;
134                batch_count += 1;
135
136                if loss.requires_grad() {
137                    loss.backward();
138                    optimizer.step();
139                }
140            }
141        }
142
143        // Validation
144        let val_loss = evaluate_normal(&model, &normal_val, &mse, &zero_target);
145
146        let avg_loss = epoch_loss / batch_count as f32;
147        let elapsed = epoch_start.elapsed().as_secs_f32();
148
149        println!(
150            "  {:>5}  {:>12.6}  {:>12.6}  {:>6.1}s",
151            epoch, avg_loss, val_loss, elapsed
152        );
153    }
154
155    println!();
156
157    // =========================================================================
158    // Phase 2: Learn fault signatures
159    // =========================================================================
160    println!("═══════════════════════════════════════════════════════════════");
161    println!(" PHASE 2: Learning Fault Signatures ({PHASE2_EPOCHS} epochs)");
162    println!("═══════════════════════════════════════════════════════════════");
163    println!(
164        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
165        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
166    );
167    println!(
168        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
169        "", "", "", "", ""
170    );
171
172    // Reset optimizer with lower LR for phase 2
173    let params = model.parameters();
174    let mut optimizer = Adam::new(params, LR * 0.5);
175
176    for epoch in 1..=PHASE2_EPOCHS {
177        let epoch_start = Instant::now();
178        let mut normal_loss_sum = 0.0f32;
179        let mut fault_loss_sum = 0.0f32;
180        let mut normal_count = 0;
181        let mut fault_count = 0;
182
183        // Interleave normal + fault samples
184        let normal_per_epoch = NORMAL_SAMPLES / 2; // Use half of normal data
185        let fault_per_epoch = fault_data.len();
186
187        // Normal samples: target = zeros
188        for i in 0..normal_per_epoch.min(normal_train.len()) {
189            optimizer.zero_grad();
190            let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
191            let loss = mse.compute(&equip_scores, &zero_target);
192            normal_loss_sum += loss.data().to_vec()[0];
193            normal_count += 1;
194
195            if loss.requires_grad() {
196                loss.backward();
197                optimizer.step();
198            }
199        }
200
201        // Fault samples: target = 1.0 for affected equipment
202        for i in 0..fault_per_epoch {
203            let (ref snap, ref _fault, ref affected) = fault_data[i];
204
205            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
206            let fault_target = Variable::new(
207                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
208                false,
209            );
210
211            optimizer.zero_grad();
212            let (equip_scores, _) = model.forward_snapshot(snap);
213            let loss = mse.compute(&equip_scores, &fault_target);
214            fault_loss_sum += loss.data().to_vec()[0];
215            fault_count += 1;
216
217            if loss.requires_grad() {
218                loss.backward();
219                optimizer.step();
220            }
221        }
222
223        // Validation
224        let val_loss = evaluate_mixed(&model, &normal_val, &fault_val, &mse, &zero_target);
225
226        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
227        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
228        let elapsed = epoch_start.elapsed().as_secs_f32();
229
230        println!(
231            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
232            epoch, avg_normal, avg_fault, val_loss, elapsed
233        );
234    }
235
236    println!();
237
238    // =========================================================================
239    // Phase 3: Temporal training
240    // =========================================================================
241    println!("═══════════════════════════════════════════════════════════════");
242    println!(" PHASE 3: Temporal Training ({PHASE3_EPOCHS} epochs, window={TEMPORAL_WINDOW})");
243    println!("═══════════════════════════════════════════════════════════════");
244
245    // Generate temporal sequences
246    println!("[data] Generating temporal sequences...");
247    let t0 = Instant::now();
248
249    // Normal temporal sequences: varied starting OAT, slow drift
250    let mut normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
251    for i in 0..TEMPORAL_NORMAL_SEQS {
252        let start_oat = -5.0 + (i as f32 / TEMPORAL_NORMAL_SEQS as f32) * 100.0;
253        let drift = if start_oat < 50.0 { 0.2 } else { -0.1 }; // warming up or cooling down
254        let seq_sim = WarrenSimulator::new(SEED + 5000 + i as u64);
255        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, drift);
256        normal_seqs.push(seq);
257    }
258
259    // Fault temporal sequences: fault injected mid-sequence
260    let mut fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
261    for i in 0..TEMPORAL_FAULT_SEQS {
262        let start_oat = -5.0 + (i as f32 / TEMPORAL_FAULT_SEQS as f32) * 100.0;
263        let drift = 0.1;
264        let seq_sim = WarrenSimulator::new(SEED + 8000 + i as u64);
265        let seq_data =
266            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, drift, i as u64);
267        fault_seqs.push(seq_data);
268    }
269
270    // Validation temporal sequences
271    let mut val_normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
272    for i in 0..20 {
273        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
274        let seq_sim = WarrenSimulator::new(SEED + 9000 + i as u64);
275        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, 0.15);
276        val_normal_seqs.push(seq);
277    }
278
279    let mut val_fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
280    for i in 0..20 {
281        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
282        let seq_sim = WarrenSimulator::new(SEED + 9500 + i as u64);
283        let seq_data =
284            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, 0.1, i as u64);
285        val_fault_seqs.push(seq_data);
286    }
287
288    println!("  Normal temporal seqs: {}", normal_seqs.len());
289    println!("  Fault temporal seqs:  {}", fault_seqs.len());
290    println!("  Val normal seqs:      {}", val_normal_seqs.len());
291    println!("  Val fault seqs:       {}", val_fault_seqs.len());
292    println!("  Window size: {TEMPORAL_WINDOW} snapshots (1 hour)");
293    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
294    println!();
295
296    println!(
297        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
298        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
299    );
300    println!(
301        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
302        "", "", "", "", ""
303    );
304
305    // Lower LR for temporal fine-tuning
306    let params = model.parameters();
307    let mut optimizer = Adam::new(params, LR * 0.3);
308
309    for epoch in 1..=PHASE3_EPOCHS {
310        let epoch_start = Instant::now();
311        let mut normal_loss_sum = 0.0f32;
312        let mut fault_loss_sum = 0.0f32;
313        let mut normal_count = 0;
314        let mut fault_count = 0;
315
316        // Normal temporal sequences: target = all zeros
317        for seq in &normal_seqs {
318            optimizer.zero_grad();
319            let (equip_scores, _) = model.forward_temporal(seq);
320            let loss = mse.compute(&equip_scores, &zero_target);
321            normal_loss_sum += loss.data().to_vec()[0];
322            normal_count += 1;
323
324            if loss.requires_grad() {
325                loss.backward();
326                optimizer.step();
327            }
328        }
329
330        // Fault temporal sequences: target = 1.0 for affected equipment
331        for (seq, _onset, _fault, affected) in &fault_seqs {
332            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
333            let fault_target = Variable::new(
334                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
335                false,
336            );
337
338            optimizer.zero_grad();
339            let (equip_scores, _) = model.forward_temporal(seq);
340            let loss = mse.compute(&equip_scores, &fault_target);
341            fault_loss_sum += loss.data().to_vec()[0];
342            fault_count += 1;
343
344            if loss.requires_grad() {
345                loss.backward();
346                optimizer.step();
347            }
348        }
349
350        // Validation
351        let val_loss = evaluate_temporal_mixed(
352            &model,
353            &val_normal_seqs,
354            &val_fault_seqs,
355            &mse,
356            &zero_target,
357        );
358
359        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
360        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
361        let elapsed = epoch_start.elapsed().as_secs_f32();
362
363        println!(
364            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
365            epoch, avg_normal, avg_fault, val_loss, elapsed
366        );
367    }
368
369    println!();
370
371    // =========================================================================
372    // Final evaluation
373    // =========================================================================
374    println!("═══════════════════════════════════════════════════════════════");
375    println!(" FINAL EVALUATION");
376    println!("═══════════════════════════════════════════════════════════════");
377
378    // Test on normal data — scores should be near zero
379    let config = FacilityConfig::warren();
380    println!("\n  Normal operation (should be low scores):");
381    for i in [0, 50, 100, 150] {
382        if i >= normal_val.len() {
383            break;
384        }
385        let (equip_scores, fac_score) = model.forward_snapshot(&normal_val[i]);
386        let scores = equip_scores.data().to_vec();
387        let fac = fac_score.data().to_vec()[0];
388        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
389        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
390        println!(
391            "    Sample {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
392        );
393    }
394
395    // Test on fault data — affected equipment should have higher scores
396    println!("\n  Fault samples (affected equipment should score higher):");
397    for i in 0..5.min(fault_val.len()) {
398        let (ref snap, ref fault, ref affected) = fault_val[i];
399        let (equip_scores, fac_score) = model.forward_snapshot(snap);
400        let scores = equip_scores.data().to_vec();
401        let fac = fac_score.data().to_vec()[0];
402
403        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
404
405        // Get scores for affected vs unaffected
406        let affected_avg: f32 = if !affected.is_empty() {
407            affected
408                .iter()
409                .filter(|&&s| s < scores.len())
410                .map(|&s| scores[s])
411                .sum::<f32>()
412                / affected.len() as f32
413        } else {
414            0.0
415        };
416
417        println!("    Fault {:?}:", fault);
418        println!(
419            "      facility={fac:.4}, affected_avg={affected_avg:.4}, alerts={}",
420            output.alerts.len()
421        );
422    }
423
424    // Temporal evaluation
425    println!("\n  Temporal normal (should be low scores):");
426    for i in 0..3.min(val_normal_seqs.len()) {
427        let (equip_scores, fac_score) = model.forward_temporal(&val_normal_seqs[i]);
428        let scores = equip_scores.data().to_vec();
429        let fac = fac_score.data().to_vec()[0];
430        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
431        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
432        println!(
433            "    Seq {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
434        );
435    }
436
437    println!("\n  Temporal fault (fault injected mid-sequence):");
438    for i in 0..5.min(val_fault_seqs.len()) {
439        let (ref seq, onset, ref fault, ref affected) = val_fault_seqs[i];
440        let (equip_scores, fac_score) = model.forward_temporal(seq);
441        let scores = equip_scores.data().to_vec();
442        let fac = fac_score.data().to_vec()[0];
443
444        let affected_avg: f32 = if !affected.is_empty() {
445            affected
446                .iter()
447                .filter(|&&s| s < scores.len())
448                .map(|&s| scores[s])
449                .sum::<f32>()
450                / affected.len() as f32
451        } else {
452            0.0
453        };
454        let unaffected_avg: f32 = {
455            let unaffected: Vec<f32> = scores
456                .iter()
457                .enumerate()
458                .filter(|(idx, _)| !affected.contains(idx))
459                .map(|(_, &s)| s)
460                .collect();
461            if unaffected.is_empty() {
462                0.0
463            } else {
464                unaffected.iter().sum::<f32>() / unaffected.len() as f32
465            }
466        };
467
468        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
469        println!(
470            "    Fault {:?} (onset step {onset}/{TEMPORAL_WINDOW}):",
471            fault
472        );
473        println!(
474            "      facility={fac:.4}, affected={affected_avg:.4}, unaffected={unaffected_avg:.4}, alerts={}",
475            output.alerts.len()
476        );
477    }
478
479    println!();
480    println!("Training complete.");
481}
482
483// =============================================================================
484// Evaluation helpers
485// =============================================================================
486
487fn evaluate_normal(
488    model: &Panoptes,
489    val_data: &[FacilitySnapshot],
490    mse: &MSELoss,
491    zero_target: &Variable,
492) -> f32 {
493    let mut total_loss = 0.0f32;
494    for snap in val_data {
495        let (equip_scores, _) = model.forward_snapshot(snap);
496        let loss = mse.compute(&equip_scores, zero_target);
497        total_loss += loss.data().to_vec()[0];
498    }
499    total_loss / val_data.len() as f32
500}
501
502fn evaluate_mixed(
503    model: &Panoptes,
504    normal_val: &[FacilitySnapshot],
505    fault_val: &[(FacilitySnapshot, FaultType, Vec<usize>)],
506    mse: &MSELoss,
507    zero_target: &Variable,
508) -> f32 {
509    let mut total_loss = 0.0f32;
510    let mut count = 0;
511
512    for snap in normal_val {
513        let (equip_scores, _) = model.forward_snapshot(snap);
514        let loss = mse.compute(&equip_scores, zero_target);
515        total_loss += loss.data().to_vec()[0];
516        count += 1;
517    }
518
519    for (snap, _, affected) in fault_val {
520        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
521        let fault_target = Variable::new(
522            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
523            false,
524        );
525        let (equip_scores, _) = model.forward_snapshot(snap);
526        let loss = mse.compute(&equip_scores, &fault_target);
527        total_loss += loss.data().to_vec()[0];
528        count += 1;
529    }
530
531    total_loss / count as f32
532}
533
534fn evaluate_temporal_mixed(
535    model: &Panoptes,
536    normal_seqs: &[Vec<FacilitySnapshot>],
537    fault_seqs: &[(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)],
538    mse: &MSELoss,
539    zero_target: &Variable,
540) -> f32 {
541    let mut total_loss = 0.0f32;
542    let mut count = 0;
543
544    for seq in normal_seqs {
545        let (equip_scores, _) = model.forward_temporal(seq);
546        let loss = mse.compute(&equip_scores, zero_target);
547        total_loss += loss.data().to_vec()[0];
548        count += 1;
549    }
550
551    for (seq, _, _, affected) in fault_seqs {
552        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
553        let fault_target = Variable::new(
554            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
555            false,
556        );
557        let (equip_scores, _) = model.forward_temporal(seq);
558        let loss = mse.compute(&equip_scores, &fault_target);
559        total_loss += loss.data().to_vec()[0];
560        count += 1;
561    }
562
563    total_loss / count as f32
564}
Source

pub fn from_tensor(data: Tensor<f32>) -> Variable

Creates a variable that doesn’t require gradients.

Source

pub fn from_operation( data: Tensor<f32>, grad_fn: GradFn, requires_grad: bool, ) -> Variable

Creates a new variable from an operation result with an attached gradient function.

This connects the variable to the computational graph, allowing gradients to flow backward through the operation that produced this variable.

Source

pub fn data(&self) -> Tensor<f32>

Returns the underlying tensor data as a `Tensor<f32>` value. (Note: despite appearances, this is not a borrowed reference — the signature returns `Tensor<f32>`, not `&Tensor<f32>`.)

Examples found in repository?
examples/hvac_training.rs (line 617)
616    pub fn forward_multi(&self, x: &Variable) -> (Variable, Variable, Variable) {
617        let x_data = x.data();
618        let shape = x_data.shape();
619        let batch_size = shape[0];
620        let seq_len = shape[1];
621        drop(x_data);
622
623        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
624        let proj = self.input_proj.forward(&x_flat);
625        let proj = self.input_norm.forward(&proj);
626        let proj = self.input_relu.forward(&proj);
627        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
628
629        // Use forward_mean for proper gradient flow (equivalent to forward + mean_pool)
630        let pooled = self.gru.forward_mean(&proj);
631
632        let imminent = self.head_imminent.forward(&pooled);
633        let warning = self.head_warning.forward(&pooled);
634        let early = self.head_early.forward(&pooled);
635
636        (imminent, warning, early)
637    }
638
639    pub fn num_parameters(&self) -> usize {
640        self.parameters()
641            .iter()
642            .map(|p| p.variable().data().numel())
643            .sum()
644    }
645}
646
647impl Module for HvacPredictor {
648    fn forward(&self, x: &Variable) -> Variable {
649        let (imminent, _, _) = self.forward_multi(x);
650        imminent
651    }
652
653    fn parameters(&self) -> Vec<Parameter> {
654        let mut params = self.input_proj.parameters();
655        params.extend(self.input_norm.parameters());
656        params.extend(self.gru.parameters());
657        params.extend(self.head_imminent.parameters());
658        params.extend(self.head_warning.parameters());
659        params.extend(self.head_early.parameters());
660        params
661    }
662}
663
664// =============================================================================
665// Training Loop
666// =============================================================================
667
668/// Normalize data to [0, 1] range based on sensor ranges
669fn normalize_data(data: &mut [f32], n_samples: usize) {
670    let sensor_ranges: [(f32, f32); 28] = [
671        (0.0, 50.0),    // 0: hw_pump_5_current
672        (0.0, 50.0),    // 1: hw_pump_6_current
673        (0.0, 50.0),    // 2: cw_pump_3_current
674        (0.0, 50.0),    // 3: cw_pump_4_current
675        (0.0, 50.0),    // 4: 2pipe_pump_a_current
676        (0.0, 50.0),    // 5: 2pipe_pump_b_current
677        (80.0, 200.0),  // 6: hw_supply_4pipe_temp
678        (40.0, 80.0),   // 7: cw_supply_4pipe_temp
679        (115.0, 155.0), // 8: hw_supply_2pipe_temp
680        (70.0, 120.0),  // 9: cw_return_2pipe_temp
681        (-20.0, 120.0), // 10: outdoor_air_temp
682        (50.0, 90.0),   // 11: mech_room_temp
683        (65.0, 85.0),   // 12: space_sensor_1_temp
684        (65.0, 85.0),   // 13: space_sensor_2_temp
685        (0.0, 200.0),   // 14: hw_pressure_4pipe
686        (0.0, 200.0),   // 15: cw_pressure_4pipe
687        (0.0, 100.0),   // 16-21: VFD speeds
688        (0.0, 100.0),
689        (0.0, 100.0),
690        (0.0, 100.0),
691        (0.0, 100.0),
692        (0.0, 100.0),
693        (0.0, 100.0), // 22: steam_valve_1_3_pos
694        (0.0, 100.0), // 23: steam_valve_2_3_pos
695        (0.0, 1.0),   // 24: summer_winter_mode
696        (0.0, 1.0),   // 25: hw_lead_pump_id
697        (0.0, 1.0),   // 26: cw_lead_pump_id
698        (0.0, 1.0),   // 27: 2pipe_lead_pump_id
699    ];
700
701    for i in 0..n_samples {
702        for f in 0..28 {
703            let (min_val, max_val) = sensor_ranges[f];
704            let idx = i * 28 + f;
705            data[idx] = ((data[idx] - min_val) / (max_val - min_val)).clamp(0.0, 1.0);
706        }
707    }
708}
709
710/// Calculate accuracy
711fn calculate_accuracy(logits: &Variable, labels: &[i64]) -> f32 {
712    let data = logits.data();
713    let shape = data.shape();
714    let batch_size = shape[0];
715    let num_classes = shape[1];
716    let values = data.to_vec();
717
718    let mut correct = 0;
719    for b in 0..batch_size {
720        let start = b * num_classes;
721        let mut max_idx = 0;
722        let mut max_val = values[start];
723        for c in 1..num_classes {
724            if values[start + c] > max_val {
725                max_val = values[start + c];
726                max_idx = c;
727            }
728        }
729        if max_idx == labels[b] as usize {
730            correct += 1;
731        }
732    }
733    correct as f32 / batch_size as f32
734}
735
736/// Training function
737fn train_epoch(
738    model: &HvacPredictor,
739    optimizer: &mut Adam,
740    loss_fn: &CrossEntropyLoss,
741    x_data: &[f32],
742    y_imminent: &[i64],
743    y_warning: &[i64],
744    y_early: &[i64],
745    batch_size: usize,
746    seq_len: usize,
747    num_features: usize,
748) -> (f32, f32, f32, f32) {
749    let n_sequences = y_imminent.len();
750    let n_batches = n_sequences / batch_size;
751
752    let mut total_loss = 0.0f32;
753    let mut total_acc_imm = 0.0f32;
754    let mut total_acc_warn = 0.0f32;
755    let mut total_acc_early = 0.0f32;
756
757    for batch_idx in 0..n_batches {
758        let start = batch_idx * batch_size;
759
760        // Prepare batch data
761        let mut batch_x = vec![0.0f32; batch_size * seq_len * num_features];
762        let mut batch_y_imm = vec![0i64; batch_size];
763        let mut batch_y_warn = vec![0i64; batch_size];
764        let mut batch_y_early = vec![0i64; batch_size];
765
766        for b in 0..batch_size {
767            let seq_start = (start + b) * seq_len * num_features;
768            for i in 0..(seq_len * num_features) {
769                batch_x[b * seq_len * num_features + i] = x_data[seq_start + i];
770            }
771            batch_y_imm[b] = y_imminent[start + b];
772            batch_y_warn[b] = y_warning[start + b];
773            batch_y_early[b] = y_early[start + b];
774        }
775
776        // Create tensors
777        let x_tensor = Tensor::from_vec(batch_x, &[batch_size, seq_len, num_features])
778            .expect("Failed to create input tensor");
779        let x_var = Variable::new(x_tensor, true);
780
781        // Forward pass
782        let (logits_imm, logits_warn, logits_early) = model.forward_multi(&x_var);
783
784        // Calculate losses (simplified - just use imminent for now)
785        let y_imm_tensor = Tensor::from_vec(
786            batch_y_imm.iter().map(|&y| y as f32).collect(),
787            &[batch_size],
788        )
789        .expect("Failed to create label tensor");
790        let y_imm_var = Variable::new(y_imm_tensor, false);
791
792        let loss = loss_fn.compute(&logits_imm, &y_imm_var);
793
794        // Backward pass
795        optimizer.zero_grad();
796        loss.backward();
797        optimizer.step();
798
799        // Metrics
800        total_loss += loss.data().to_vec()[0];
801        total_acc_imm += calculate_accuracy(&logits_imm, &batch_y_imm);
802        total_acc_warn += calculate_accuracy(&logits_warn, &batch_y_warn);
803        total_acc_early += calculate_accuracy(&logits_early, &batch_y_early);
804    }
805
806    let n = n_batches as f32;
807    (
808        total_loss / n,
809        total_acc_imm / n,
810        total_acc_warn / n,
811        total_acc_early / n,
812    )
813}
More examples
Hide additional examples
examples/hvac_model.rs (line 166)
165    fn mean_pool(&self, x: &Variable) -> Variable {
166        let data = x.data();
167        let shape = data.shape();
168        let batch_size = shape[0];
169        let seq_len = shape[1];
170        let hidden = shape[2];
171
172        // Reshape to [batch * seq, hidden] then back
173        let values = data.to_vec();
174
175        // Calculate mean over sequence dimension
176        let mut pooled = vec![0.0f32; batch_size * hidden];
177        for b in 0..batch_size {
178            for h in 0..hidden {
179                let mut sum = 0.0;
180                for s in 0..seq_len {
181                    let idx = b * seq_len * hidden + s * hidden + h;
182                    sum += values[idx];
183                }
184                pooled[b * hidden + h] = sum / seq_len as f32;
185            }
186        }
187
188        let pooled_tensor = Tensor::from_vec(pooled, &[batch_size, hidden])
189            .expect("Failed to create pooled tensor");
190        Variable::new(pooled_tensor, x.requires_grad())
191    }
192
193    /// Forward pass returning logits for all 3 horizons
194    pub fn forward_multi(&self, x: &Variable) -> HvacOutput {
195        let x_data = x.data();
196        let shape = x_data.shape();
197        let batch_size = shape[0];
198        let seq_len = shape[1];
199        drop(x_data); // Release borrow
200
201        // Input projection: [batch, seq, features] -> [batch, seq, hidden]
202        // Reshape for linear: [batch * seq, features]
203        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
204        let proj = self.input_proj.forward(&x_flat);
205        let proj = self.input_norm.forward(&proj);
206        let proj = self.input_relu.forward(&proj);
207        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
208
209        // GRU encoding: [batch, seq, hidden] -> [batch, seq, hidden]
210        let encoded = self.gru.forward(&proj);
211
212        // Mean pooling: [batch, seq, hidden] -> [batch, hidden]
213        let pooled = self.mean_pool(&encoded);
214
215        // Prediction heads
216        let imminent_logits = self.head_imminent.forward(&pooled);
217        let warning_logits = self.head_warning.forward(&pooled);
218        let early_logits = self.head_early.forward(&pooled);
219
220        HvacOutput {
221            imminent_logits,
222            warning_logits,
223            early_logits,
224        }
225    }
226
227    /// Get predicted classes (argmax of logits)
228    pub fn predict(&self, x: &Variable) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
229        let output = self.forward_multi(x);
230
231        let imminent_probs = self.softmax.forward(&output.imminent_logits);
232        let warning_probs = self.softmax.forward(&output.warning_logits);
233        let early_probs = self.softmax.forward(&output.early_logits);
234
235        (
236            argmax_batch(&imminent_probs),
237            argmax_batch(&warning_probs),
238            argmax_batch(&early_probs),
239        )
240    }
241
242    /// Returns the model configuration
243    pub fn config(&self) -> &HvacConfig {
244        &self.config
245    }
246
247    /// Returns the number of trainable parameters
248    pub fn num_parameters(&self) -> usize {
249        self.parameters()
250            .iter()
251            .map(|p| p.variable().data().numel())
252            .sum()
253    }
254}
255
256impl Module for HvacPredictor {
257    fn forward(&self, x: &Variable) -> Variable {
258        // Return concatenated logits for all horizons
259        let output = self.forward_multi(x);
260        // For single output, return imminent predictions
261        output.imminent_logits
262    }
263
264    fn parameters(&self) -> Vec<Parameter> {
265        let mut params = self.input_proj.parameters();
266        params.extend(self.input_norm.parameters());
267        params.extend(self.gru.parameters());
268        params.extend(self.head_imminent.parameters());
269        params.extend(self.head_warning.parameters());
270        params.extend(self.head_early.parameters());
271        params
272    }
273}
274
275// =============================================================================
276// Helper Functions
277// =============================================================================
278
279/// Get argmax for each sample in batch
280fn argmax_batch(x: &Variable) -> Vec<usize> {
281    let data = x.data();
282    let shape = data.shape();
283    let batch_size = shape[0];
284    let num_classes = shape[1];
285    let values = data.to_vec();
286
287    let mut results = Vec::with_capacity(batch_size);
288    for b in 0..batch_size {
289        let start = b * num_classes;
290        let end = start + num_classes;
291        let slice = &values[start..end];
292
293        let mut max_idx = 0;
294        let mut max_val = slice[0];
295        for (i, &v) in slice.iter().enumerate() {
296            if v > max_val {
297                max_val = v;
298                max_idx = i;
299            }
300        }
301        results.push(max_idx);
302    }
303    results
304}
305
306/// Failure type names
307pub const FAILURE_TYPES: [&str; 20] = [
308    "normal",
309    "pump_failure_hw_5",
310    "pump_failure_hw_6",
311    "pump_failure_cw_3",
312    "pump_failure_cw_4",
313    "pump_failure_2pipe_a",
314    "pump_failure_2pipe_b",
315    "pressure_low_hw",
316    "pressure_high_hw",
317    "pressure_low_cw",
318    "pressure_high_cw",
319    "temp_anomaly_hw_supply",
320    "temp_anomaly_cw_supply",
321    "temp_anomaly_space",
322    "valve_stuck_1_3",
323    "valve_stuck_2_3",
324    "vfd_fault",
325    "sensor_drift",
326    "chiller_fault",
327    "interlock_violation",
328];
329
330/// Feature names for the 28 sensor inputs
331pub const FEATURE_NAMES: [&str; 28] = [
332    "hw_pump_5_current",
333    "hw_pump_6_current",
334    "cw_pump_3_current",
335    "cw_pump_4_current",
336    "2pipe_pump_a_current",
337    "2pipe_pump_b_current",
338    "hw_supply_4pipe_temp",
339    "cw_supply_4pipe_temp",
340    "hw_supply_2pipe_temp",
341    "cw_return_2pipe_temp",
342    "outdoor_air_temp",
343    "mech_room_temp",
344    "space_sensor_1_temp",
345    "space_sensor_2_temp",
346    "hw_pressure_4pipe",
347    "cw_pressure_4pipe",
348    "hw_pump_5_vfd_speed",
349    "hw_pump_6_vfd_speed",
350    "cw_pump_3_vfd_speed",
351    "cw_pump_4_vfd_speed",
352    "2pipe_pump_a_vfd_speed",
353    "2pipe_pump_b_vfd_speed",
354    "steam_valve_1_3_pos",
355    "steam_valve_2_3_pos",
356    "summer_winter_mode",
357    "hw_lead_pump_id",
358    "cw_lead_pump_id",
359    "2pipe_lead_pump_id",
360];
361
362// =============================================================================
363// Main
364// =============================================================================
365
366fn main() {
367    println!("╔════════════════════════════════════════════════════════════╗");
368    println!("║     HVAC Multi-Horizon Predictor - AxonML Native           ║");
369    println!("╚════════════════════════════════════════════════════════════╝");
370    println!();
371
372    // Create model with default config
373    let config = HvacConfig::default();
374    println!("Model Configuration:");
375    println!("  Input features: {}", config.num_features);
376    println!("  Sequence length: {}", config.seq_len);
377    println!("  Hidden size: {}", config.hidden_size);
378    println!("  GRU layers: {}", config.num_layers);
379    println!("  Output classes: {}", config.num_classes);
380    println!("  Dropout: {}", config.dropout);
381    println!();
382
383    let model = HvacPredictor::new(config.clone());
384    println!("Model created!");
385    println!("  Total parameters: {}", model.num_parameters());
386    println!();
387
388    // Create sample input
389    let batch_size = 2;
390    let mut input_data = vec![0.5f32; batch_size * config.seq_len * config.num_features];
391
392    // Simulate normal HVAC readings
393    for b in 0..batch_size {
394        for t in 0..config.seq_len {
395            let base = (b * config.seq_len + t) * config.num_features;
396            // Pump currents ~25A (normalized)
397            for i in 0..6 {
398                input_data[base + i] = 0.5;
399            }
400            // Temperatures (normalized)
401            input_data[base + 6] = 0.83; // HW supply ~180F
402            input_data[base + 7] = 0.375; // CW supply ~55F
403            // VFD speeds ~60%
404            for i in 16..22 {
405                input_data[base + i] = 0.6;
406            }
407        }
408    }
409
410    let input = Tensor::from_vec(
411        input_data,
412        &[batch_size, config.seq_len, config.num_features],
413    )
414    .expect("Failed to create input tensor");
415
416    let input_var = Variable::new(input, false);
417    println!("Input shape: {:?}", input_var.data().shape());
418
419    // Run inference
420    println!();
421    println!("Running inference...");
422    let (imminent, warning, early) = model.predict(&input_var);
423
424    println!();
425    println!("Predictions:");
426    println!("────────────────────────────────────────────────────────────");
427    for b in 0..batch_size {
428        println!("Sample {}:", b);
429        println!(
430            "  5 min (Imminent): {} - {}",
431            imminent[b], FAILURE_TYPES[imminent[b]]
432        );
433        println!(
434            "  15 min (Warning): {} - {}",
435            warning[b], FAILURE_TYPES[warning[b]]
436        );
437        println!(
438            "  30 min (Early):   {} - {}",
439            early[b], FAILURE_TYPES[early[b]]
440        );
441    }
442    println!("────────────────────────────────────────────────────────────");
443    println!();
444    println!("Model ready for training with your HVAC sensor data!");
445}
examples/simple_training.rs (line 77)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
examples/mnist_training.rs (line 62)
20fn main() {
21    println!("=== AxonML - MNIST Training (LeNet) ===\n");
22
23    // Detect device
24    #[cfg(feature = "cuda")]
25    let device = {
26        let cuda = Device::Cuda(0);
27        if cuda.is_available() {
28            println!("GPU detected: using CUDA device 0");
29            cuda
30        } else {
31            println!("CUDA feature enabled but no GPU available, using CPU");
32            Device::Cpu
33        }
34    };
35    #[cfg(not(feature = "cuda"))]
36    let device = {
37        println!("Using CPU (compile with --features cuda for GPU)");
38        Device::Cpu
39    };
40
41    // 1. Create dataset
42    let num_train = 2000;
43    let num_test = 400;
44    println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
45    let train_dataset = SyntheticMNIST::new(num_train);
46    let test_dataset = SyntheticMNIST::new(num_test);
47
48    // 2. Create DataLoader
49    let batch_size = 64;
50    println!("2. Creating DataLoader (batch_size={batch_size})...");
51    let train_loader = DataLoader::new(train_dataset, batch_size);
52    let test_loader = DataLoader::new(test_dataset, batch_size);
53    println!("   Training batches: {}", train_loader.len());
54
55    // 3. Create LeNet model and move to device
56    println!("3. Creating LeNet model...");
57    let model = LeNet::new();
58    model.to_device(device);
59    let params = model.parameters();
60    let total_params: usize = params
61        .iter()
62        .map(|p| p.variable().data().to_vec().len())
63        .sum();
64    println!(
65        "   Parameters: {} ({} total weights)",
66        params.len(),
67        total_params
68    );
69    println!("   Device: {:?}", device);
70
71    // 4. Create optimizer and loss
72    println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
73    let mut optimizer = Adam::new(params, 0.001);
74    let criterion = CrossEntropyLoss::new();
75
76    // 5. Training loop
77    let epochs = 10;
78    println!("5. Training for {epochs} epochs...\n");
79
80    let train_start = Instant::now();
81
82    for epoch in 0..epochs {
83        let epoch_start = Instant::now();
84        let mut total_loss = 0.0;
85        let mut correct = 0usize;
86        let mut total = 0usize;
87        let mut batch_count = 0;
88
89        for batch in train_loader.iter() {
90            let bs = batch.data.shape()[0];
91
92            // Reshape to [N, 1, 28, 28] and create Variable
93            let input_data = batch.data.to_vec();
94            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
95            let input = Variable::new(
96                if device.is_gpu() {
97                    input_tensor.to_device(device).unwrap()
98                } else {
99                    input_tensor
100                },
101                true,
102            );
103
104            // Target: convert one-hot [N, 10] to class indices [N]
105            let target_onehot = batch.targets.to_vec();
106            let mut target_indices = vec![0.0f32; bs];
107            for i in 0..bs {
108                let offset = i * 10;
109                let mut max_idx = 0;
110                let mut max_val = f32::NEG_INFINITY;
111                for c in 0..10 {
112                    if target_onehot[offset + c] > max_val {
113                        max_val = target_onehot[offset + c];
114                        max_idx = c;
115                    }
116                }
117                target_indices[i] = max_idx as f32;
118            }
119            let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
120            let target = Variable::new(
121                if device.is_gpu() {
122                    target_tensor.to_device(device).unwrap()
123                } else {
124                    target_tensor
125                },
126                false,
127            );
128
129            // Forward pass
130            let output = model.forward(&input);
131
132            // Cross-entropy loss
133            let loss = criterion.compute(&output, &target);
134
135            let loss_val = loss.data().to_vec()[0];
136            total_loss += loss_val;
137            batch_count += 1;
138
139            // Compute training accuracy
140            let out_data = output.data().to_vec();
141            for i in 0..bs {
142                let offset = i * 10;
143                let mut pred = 0;
144                let mut pred_val = f32::NEG_INFINITY;
145                for c in 0..10 {
146                    if out_data[offset + c] > pred_val {
147                        pred_val = out_data[offset + c];
148                        pred = c;
149                    }
150                }
151                if pred == target_indices[i] as usize {
152                    correct += 1;
153                }
154                total += 1;
155            }
156
157            // Backward pass
158            loss.backward();
159
160            // Update weights
161            optimizer.step();
162            optimizer.zero_grad();
163        }
164
165        let epoch_time = epoch_start.elapsed();
166        let avg_loss = total_loss / batch_count as f32;
167        let accuracy = 100.0 * correct as f32 / total as f32;
168        let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
169
170        println!(
171            "   Epoch {:2}/{}: Loss={:.4}  Acc={:.1}%  ({:.0} samples/s, {:.2}s)",
172            epoch + 1,
173            epochs,
174            avg_loss,
175            accuracy,
176            samples_per_sec,
177            epoch_time.as_secs_f64(),
178        );
179    }
180
181    let train_time = train_start.elapsed();
182    println!("\n   Total training time: {:.2}s", train_time.as_secs_f64());
183
184    // 6. Test evaluation
185    println!("\n6. Evaluating on test set...");
186
187    // Disable gradient computation for evaluation
188    let (correct, total) = no_grad(|| {
189        let mut correct = 0usize;
190        let mut total = 0usize;
191
192        for batch in test_loader.iter() {
193            let bs = batch.data.shape()[0];
194
195            let input_data = batch.data.to_vec();
196            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
197            let input = Variable::new(
198                if device.is_gpu() {
199                    input_tensor.to_device(device).unwrap()
200                } else {
201                    input_tensor
202                },
203                false,
204            );
205
206            let target_onehot = batch.targets.to_vec();
207            let output = model.forward(&input);
208            let out_data = output.data().to_vec();
209
210            for i in 0..bs {
211                // Prediction: argmax of output
212                let offset = i * 10;
213                let mut pred = 0;
214                let mut pred_val = f32::NEG_INFINITY;
215                for c in 0..10 {
216                    if out_data[offset + c] > pred_val {
217                        pred_val = out_data[offset + c];
218                        pred = c;
219                    }
220                }
221
222                // True label: argmax of one-hot target
223                let mut true_label = 0;
224                let mut true_val = f32::NEG_INFINITY;
225                for c in 0..10 {
226                    if target_onehot[i * 10 + c] > true_val {
227                        true_val = target_onehot[i * 10 + c];
228                        true_label = c;
229                    }
230                }
231
232                if pred == true_label {
233                    correct += 1;
234                }
235                total += 1;
236            }
237        }
238
239        (correct, total)
240    });
241
242    let test_accuracy = 100.0 * correct as f32 / total as f32;
243    println!(
244        "   Test Accuracy: {}/{} ({:.2}%)",
245        correct, total, test_accuracy
246    );
247
248    println!("\n=== Training Complete! ===");
249    println!("   Device: {:?}", device);
250    println!("   Final test accuracy: {:.2}%", test_accuracy);
251}
examples/train_panoptes.rs (line 132)
56fn main() {
57    println!("╔══════════════════════════════════════════════════════════════╗");
58    println!("║     PANOPTES — Facility-Wide Anomaly Detection Training     ║");
59    println!("║     Heritage Pointe of Warren (59 equipment)                ║");
60    println!("╚══════════════════════════════════════════════════════════════╝");
61    println!();
62
63    // =========================================================================
64    // Generate training data
65    // =========================================================================
66    println!("[data] Generating physics-informed training data...");
67    let t0 = Instant::now();
68
69    let sim = WarrenSimulator::new(SEED);
70    let normal_train = sim.generate_normal(NORMAL_SAMPLES);
71    let fault_data = sim.generate_with_faults(FAULT_SAMPLES, 1.0);
72
73    // Validation set (different seed)
74    let val_sim = WarrenSimulator::new(SEED + 999);
75    let normal_val = val_sim.generate_normal(200);
76    let fault_val = val_sim.generate_with_faults(100, 1.0);
77
78    println!("  Normal train: {} samples", normal_train.len());
79    println!("  Fault train:  {} samples", fault_data.len());
80    println!("  Normal val:   {} samples", normal_val.len());
81    println!("  Fault val:    {} samples", fault_val.len());
82    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
83    println!();
84
85    // =========================================================================
86    // Create model
87    // =========================================================================
88    let model = Panoptes::new(NUM_EQUIPMENT);
89    println!("[model] Panoptes created");
90    println!("  Equipment slots: {NUM_EQUIPMENT}");
91    println!("  Parameters: {}", model.num_parameters());
92    println!("  Embed dim: {EMBED_DIM}");
93    println!();
94
95    let mse = MSELoss::new();
96
97    // Zero target for normal operation
98    let zero_target = Variable::new(
99        Tensor::from_vec(vec![0.0; NUM_EQUIPMENT], &[1, NUM_EQUIPMENT]).unwrap(),
100        false,
101    );
102
103    // =========================================================================
104    // Phase 1: Learn normal operation
105    // =========================================================================
106    println!("═══════════════════════════════════════════════════════════════");
107    println!(" PHASE 1: Learning Normal Operation ({PHASE1_EPOCHS} epochs)");
108    println!("═══════════════════════════════════════════════════════════════");
109    println!(
110        "  {:>5}  {:>12}  {:>12}  {:>8}",
111        "Epoch", "Train Loss", "Val Loss", "Time"
112    );
113    println!("  {:-<5}  {:-<12}  {:-<12}  {:-<8}", "", "", "", "");
114
115    let params = model.parameters();
116    let mut optimizer = Adam::new(params, LR);
117
118    for epoch in 1..=PHASE1_EPOCHS {
119        let epoch_start = Instant::now();
120        let mut epoch_loss = 0.0f32;
121        let mut batch_count = 0;
122
123        // Train on normal data: target = all zeros
124        for batch_start in (0..normal_train.len()).step_by(BATCH_SIZE) {
125            let batch_end = (batch_start + BATCH_SIZE).min(normal_train.len());
126
127            for i in batch_start..batch_end {
128                optimizer.zero_grad();
129
130                let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
131                let loss = mse.compute(&equip_scores, &zero_target);
132                let loss_val = loss.data().to_vec()[0];
133                epoch_loss += loss_val;
134                batch_count += 1;
135
136                if loss.requires_grad() {
137                    loss.backward();
138                    optimizer.step();
139                }
140            }
141        }
142
143        // Validation
144        let val_loss = evaluate_normal(&model, &normal_val, &mse, &zero_target);
145
146        let avg_loss = epoch_loss / batch_count as f32;
147        let elapsed = epoch_start.elapsed().as_secs_f32();
148
149        println!(
150            "  {:>5}  {:>12.6}  {:>12.6}  {:>6.1}s",
151            epoch, avg_loss, val_loss, elapsed
152        );
153    }
154
155    println!();
156
157    // =========================================================================
158    // Phase 2: Learn fault signatures
159    // =========================================================================
160    println!("═══════════════════════════════════════════════════════════════");
161    println!(" PHASE 2: Learning Fault Signatures ({PHASE2_EPOCHS} epochs)");
162    println!("═══════════════════════════════════════════════════════════════");
163    println!(
164        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
165        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
166    );
167    println!(
168        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
169        "", "", "", "", ""
170    );
171
172    // Reset optimizer with lower LR for phase 2
173    let params = model.parameters();
174    let mut optimizer = Adam::new(params, LR * 0.5);
175
176    for epoch in 1..=PHASE2_EPOCHS {
177        let epoch_start = Instant::now();
178        let mut normal_loss_sum = 0.0f32;
179        let mut fault_loss_sum = 0.0f32;
180        let mut normal_count = 0;
181        let mut fault_count = 0;
182
183        // Interleave normal + fault samples
184        let normal_per_epoch = NORMAL_SAMPLES / 2; // Use half of normal data
185        let fault_per_epoch = fault_data.len();
186
187        // Normal samples: target = zeros
188        for i in 0..normal_per_epoch.min(normal_train.len()) {
189            optimizer.zero_grad();
190            let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
191            let loss = mse.compute(&equip_scores, &zero_target);
192            normal_loss_sum += loss.data().to_vec()[0];
193            normal_count += 1;
194
195            if loss.requires_grad() {
196                loss.backward();
197                optimizer.step();
198            }
199        }
200
201        // Fault samples: target = 1.0 for affected equipment
202        for i in 0..fault_per_epoch {
203            let (ref snap, ref _fault, ref affected) = fault_data[i];
204
205            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
206            let fault_target = Variable::new(
207                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
208                false,
209            );
210
211            optimizer.zero_grad();
212            let (equip_scores, _) = model.forward_snapshot(snap);
213            let loss = mse.compute(&equip_scores, &fault_target);
214            fault_loss_sum += loss.data().to_vec()[0];
215            fault_count += 1;
216
217            if loss.requires_grad() {
218                loss.backward();
219                optimizer.step();
220            }
221        }
222
223        // Validation
224        let val_loss = evaluate_mixed(&model, &normal_val, &fault_val, &mse, &zero_target);
225
226        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
227        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
228        let elapsed = epoch_start.elapsed().as_secs_f32();
229
230        println!(
231            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
232            epoch, avg_normal, avg_fault, val_loss, elapsed
233        );
234    }
235
236    println!();
237
238    // =========================================================================
239    // Phase 3: Temporal training
240    // =========================================================================
241    println!("═══════════════════════════════════════════════════════════════");
242    println!(" PHASE 3: Temporal Training ({PHASE3_EPOCHS} epochs, window={TEMPORAL_WINDOW})");
243    println!("═══════════════════════════════════════════════════════════════");
244
245    // Generate temporal sequences
246    println!("[data] Generating temporal sequences...");
247    let t0 = Instant::now();
248
249    // Normal temporal sequences: varied starting OAT, slow drift
250    let mut normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
251    for i in 0..TEMPORAL_NORMAL_SEQS {
252        let start_oat = -5.0 + (i as f32 / TEMPORAL_NORMAL_SEQS as f32) * 100.0;
253        let drift = if start_oat < 50.0 { 0.2 } else { -0.1 }; // warming up or cooling down
254        let seq_sim = WarrenSimulator::new(SEED + 5000 + i as u64);
255        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, drift);
256        normal_seqs.push(seq);
257    }
258
259    // Fault temporal sequences: fault injected mid-sequence
260    let mut fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
261    for i in 0..TEMPORAL_FAULT_SEQS {
262        let start_oat = -5.0 + (i as f32 / TEMPORAL_FAULT_SEQS as f32) * 100.0;
263        let drift = 0.1;
264        let seq_sim = WarrenSimulator::new(SEED + 8000 + i as u64);
265        let seq_data =
266            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, drift, i as u64);
267        fault_seqs.push(seq_data);
268    }
269
270    // Validation temporal sequences
271    let mut val_normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
272    for i in 0..20 {
273        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
274        let seq_sim = WarrenSimulator::new(SEED + 9000 + i as u64);
275        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, 0.15);
276        val_normal_seqs.push(seq);
277    }
278
279    let mut val_fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
280    for i in 0..20 {
281        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
282        let seq_sim = WarrenSimulator::new(SEED + 9500 + i as u64);
283        let seq_data =
284            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, 0.1, i as u64);
285        val_fault_seqs.push(seq_data);
286    }
287
288    println!("  Normal temporal seqs: {}", normal_seqs.len());
289    println!("  Fault temporal seqs:  {}", fault_seqs.len());
290    println!("  Val normal seqs:      {}", val_normal_seqs.len());
291    println!("  Val fault seqs:       {}", val_fault_seqs.len());
292    println!("  Window size: {TEMPORAL_WINDOW} snapshots (1 hour)");
293    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
294    println!();
295
296    println!(
297        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
298        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
299    );
300    println!(
301        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
302        "", "", "", "", ""
303    );
304
305    // Lower LR for temporal fine-tuning
306    let params = model.parameters();
307    let mut optimizer = Adam::new(params, LR * 0.3);
308
309    for epoch in 1..=PHASE3_EPOCHS {
310        let epoch_start = Instant::now();
311        let mut normal_loss_sum = 0.0f32;
312        let mut fault_loss_sum = 0.0f32;
313        let mut normal_count = 0;
314        let mut fault_count = 0;
315
316        // Normal temporal sequences: target = all zeros
317        for seq in &normal_seqs {
318            optimizer.zero_grad();
319            let (equip_scores, _) = model.forward_temporal(seq);
320            let loss = mse.compute(&equip_scores, &zero_target);
321            normal_loss_sum += loss.data().to_vec()[0];
322            normal_count += 1;
323
324            if loss.requires_grad() {
325                loss.backward();
326                optimizer.step();
327            }
328        }
329
330        // Fault temporal sequences: target = 1.0 for affected equipment
331        for (seq, _onset, _fault, affected) in &fault_seqs {
332            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
333            let fault_target = Variable::new(
334                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
335                false,
336            );
337
338            optimizer.zero_grad();
339            let (equip_scores, _) = model.forward_temporal(seq);
340            let loss = mse.compute(&equip_scores, &fault_target);
341            fault_loss_sum += loss.data().to_vec()[0];
342            fault_count += 1;
343
344            if loss.requires_grad() {
345                loss.backward();
346                optimizer.step();
347            }
348        }
349
350        // Validation
351        let val_loss = evaluate_temporal_mixed(
352            &model,
353            &val_normal_seqs,
354            &val_fault_seqs,
355            &mse,
356            &zero_target,
357        );
358
359        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
360        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
361        let elapsed = epoch_start.elapsed().as_secs_f32();
362
363        println!(
364            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
365            epoch, avg_normal, avg_fault, val_loss, elapsed
366        );
367    }
368
369    println!();
370
371    // =========================================================================
372    // Final evaluation
373    // =========================================================================
374    println!("═══════════════════════════════════════════════════════════════");
375    println!(" FINAL EVALUATION");
376    println!("═══════════════════════════════════════════════════════════════");
377
378    // Test on normal data — scores should be near zero
379    let config = FacilityConfig::warren();
380    println!("\n  Normal operation (should be low scores):");
381    for i in [0, 50, 100, 150] {
382        if i >= normal_val.len() {
383            break;
384        }
385        let (equip_scores, fac_score) = model.forward_snapshot(&normal_val[i]);
386        let scores = equip_scores.data().to_vec();
387        let fac = fac_score.data().to_vec()[0];
388        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
389        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
390        println!(
391            "    Sample {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
392        );
393    }
394
395    // Test on fault data — affected equipment should have higher scores
396    println!("\n  Fault samples (affected equipment should score higher):");
397    for i in 0..5.min(fault_val.len()) {
398        let (ref snap, ref fault, ref affected) = fault_val[i];
399        let (equip_scores, fac_score) = model.forward_snapshot(snap);
400        let scores = equip_scores.data().to_vec();
401        let fac = fac_score.data().to_vec()[0];
402
403        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
404
405        // Get scores for affected vs unaffected
406        let affected_avg: f32 = if !affected.is_empty() {
407            affected
408                .iter()
409                .filter(|&&s| s < scores.len())
410                .map(|&s| scores[s])
411                .sum::<f32>()
412                / affected.len() as f32
413        } else {
414            0.0
415        };
416
417        println!("    Fault {:?}:", fault);
418        println!(
419            "      facility={fac:.4}, affected_avg={affected_avg:.4}, alerts={}",
420            output.alerts.len()
421        );
422    }
423
424    // Temporal evaluation
425    println!("\n  Temporal normal (should be low scores):");
426    for i in 0..3.min(val_normal_seqs.len()) {
427        let (equip_scores, fac_score) = model.forward_temporal(&val_normal_seqs[i]);
428        let scores = equip_scores.data().to_vec();
429        let fac = fac_score.data().to_vec()[0];
430        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
431        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
432        println!(
433            "    Seq {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
434        );
435    }
436
437    println!("\n  Temporal fault (fault injected mid-sequence):");
438    for i in 0..5.min(val_fault_seqs.len()) {
439        let (ref seq, onset, ref fault, ref affected) = val_fault_seqs[i];
440        let (equip_scores, fac_score) = model.forward_temporal(seq);
441        let scores = equip_scores.data().to_vec();
442        let fac = fac_score.data().to_vec()[0];
443
444        let affected_avg: f32 = if !affected.is_empty() {
445            affected
446                .iter()
447                .filter(|&&s| s < scores.len())
448                .map(|&s| scores[s])
449                .sum::<f32>()
450                / affected.len() as f32
451        } else {
452            0.0
453        };
454        let unaffected_avg: f32 = {
455            let unaffected: Vec<f32> = scores
456                .iter()
457                .enumerate()
458                .filter(|(idx, _)| !affected.contains(idx))
459                .map(|(_, &s)| s)
460                .collect();
461            if unaffected.is_empty() {
462                0.0
463            } else {
464                unaffected.iter().sum::<f32>() / unaffected.len() as f32
465            }
466        };
467
468        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
469        println!(
470            "    Fault {:?} (onset step {onset}/{TEMPORAL_WINDOW}):",
471            fault
472        );
473        println!(
474            "      facility={fac:.4}, affected={affected_avg:.4}, unaffected={unaffected_avg:.4}, alerts={}",
475            output.alerts.len()
476        );
477    }
478
479    println!();
480    println!("Training complete.");
481}
482
483// =============================================================================
484// Evaluation helpers
485// =============================================================================
486
487fn evaluate_normal(
488    model: &Panoptes,
489    val_data: &[FacilitySnapshot],
490    mse: &MSELoss,
491    zero_target: &Variable,
492) -> f32 {
493    let mut total_loss = 0.0f32;
494    for snap in val_data {
495        let (equip_scores, _) = model.forward_snapshot(snap);
496        let loss = mse.compute(&equip_scores, zero_target);
497        total_loss += loss.data().to_vec()[0];
498    }
499    total_loss / val_data.len() as f32
500}
501
502fn evaluate_mixed(
503    model: &Panoptes,
504    normal_val: &[FacilitySnapshot],
505    fault_val: &[(FacilitySnapshot, FaultType, Vec<usize>)],
506    mse: &MSELoss,
507    zero_target: &Variable,
508) -> f32 {
509    let mut total_loss = 0.0f32;
510    let mut count = 0;
511
512    for snap in normal_val {
513        let (equip_scores, _) = model.forward_snapshot(snap);
514        let loss = mse.compute(&equip_scores, zero_target);
515        total_loss += loss.data().to_vec()[0];
516        count += 1;
517    }
518
519    for (snap, _, affected) in fault_val {
520        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
521        let fault_target = Variable::new(
522            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
523            false,
524        );
525        let (equip_scores, _) = model.forward_snapshot(snap);
526        let loss = mse.compute(&equip_scores, &fault_target);
527        total_loss += loss.data().to_vec()[0];
528        count += 1;
529    }
530
531    total_loss / count as f32
532}
533
534fn evaluate_temporal_mixed(
535    model: &Panoptes,
536    normal_seqs: &[Vec<FacilitySnapshot>],
537    fault_seqs: &[(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)],
538    mse: &MSELoss,
539    zero_target: &Variable,
540) -> f32 {
541    let mut total_loss = 0.0f32;
542    let mut count = 0;
543
544    for seq in normal_seqs {
545        let (equip_scores, _) = model.forward_temporal(seq);
546        let loss = mse.compute(&equip_scores, zero_target);
547        total_loss += loss.data().to_vec()[0];
548        count += 1;
549    }
550
551    for (seq, _, _, affected) in fault_seqs {
552        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
553        let fault_target = Variable::new(
554            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
555            false,
556        );
557        let (equip_scores, _) = model.forward_temporal(seq);
558        let loss = mse.compute(&equip_scores, &fault_target);
559        total_loss += loss.data().to_vec()[0];
560        count += 1;
561    }
562
563    total_loss / count as f32
564}
Source

pub fn shape(&self) -> Vec<usize>

Returns the shape of the tensor.

Source

pub fn ndim(&self) -> usize

Returns the number of dimensions.

Source

pub fn numel(&self) -> usize

Returns the total number of elements.

Source

pub fn device(&self) -> Device

Returns the device this variable’s data is on.

Source

pub fn to_device(&self, device: Device) -> Variable

Moves this variable’s data to the specified device.

Creates a new leaf Variable on the target device. Used for moving inputs to GPU before forward pass.

Source

pub fn requires_grad(&self) -> bool

Returns whether this variable requires gradients.

Examples found in repository?
examples/hvac_model.rs (line 190)
165    fn mean_pool(&self, x: &Variable) -> Variable {
166        let data = x.data();
167        let shape = data.shape();
168        let batch_size = shape[0];
169        let seq_len = shape[1];
170        let hidden = shape[2];
171
172        // Reshape to [batch * seq, hidden] then back
173        let values = data.to_vec();
174
175        // Calculate mean over sequence dimension
176        let mut pooled = vec![0.0f32; batch_size * hidden];
177        for b in 0..batch_size {
178            for h in 0..hidden {
179                let mut sum = 0.0;
180                for s in 0..seq_len {
181                    let idx = b * seq_len * hidden + s * hidden + h;
182                    sum += values[idx];
183                }
184                pooled[b * hidden + h] = sum / seq_len as f32;
185            }
186        }
187
188        let pooled_tensor = Tensor::from_vec(pooled, &[batch_size, hidden])
189            .expect("Failed to create pooled tensor");
190        Variable::new(pooled_tensor, x.requires_grad())
191    }
More examples
Hide additional examples
examples/train_panoptes.rs (line 136)
56fn main() {
57    println!("╔══════════════════════════════════════════════════════════════╗");
58    println!("║     PANOPTES — Facility-Wide Anomaly Detection Training     ║");
59    println!("║     Heritage Pointe of Warren (59 equipment)                ║");
60    println!("╚══════════════════════════════════════════════════════════════╝");
61    println!();
62
63    // =========================================================================
64    // Generate training data
65    // =========================================================================
66    println!("[data] Generating physics-informed training data...");
67    let t0 = Instant::now();
68
69    let sim = WarrenSimulator::new(SEED);
70    let normal_train = sim.generate_normal(NORMAL_SAMPLES);
71    let fault_data = sim.generate_with_faults(FAULT_SAMPLES, 1.0);
72
73    // Validation set (different seed)
74    let val_sim = WarrenSimulator::new(SEED + 999);
75    let normal_val = val_sim.generate_normal(200);
76    let fault_val = val_sim.generate_with_faults(100, 1.0);
77
78    println!("  Normal train: {} samples", normal_train.len());
79    println!("  Fault train:  {} samples", fault_data.len());
80    println!("  Normal val:   {} samples", normal_val.len());
81    println!("  Fault val:    {} samples", fault_val.len());
82    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
83    println!();
84
85    // =========================================================================
86    // Create model
87    // =========================================================================
88    let model = Panoptes::new(NUM_EQUIPMENT);
89    println!("[model] Panoptes created");
90    println!("  Equipment slots: {NUM_EQUIPMENT}");
91    println!("  Parameters: {}", model.num_parameters());
92    println!("  Embed dim: {EMBED_DIM}");
93    println!();
94
95    let mse = MSELoss::new();
96
97    // Zero target for normal operation
98    let zero_target = Variable::new(
99        Tensor::from_vec(vec![0.0; NUM_EQUIPMENT], &[1, NUM_EQUIPMENT]).unwrap(),
100        false,
101    );
102
103    // =========================================================================
104    // Phase 1: Learn normal operation
105    // =========================================================================
106    println!("═══════════════════════════════════════════════════════════════");
107    println!(" PHASE 1: Learning Normal Operation ({PHASE1_EPOCHS} epochs)");
108    println!("═══════════════════════════════════════════════════════════════");
109    println!(
110        "  {:>5}  {:>12}  {:>12}  {:>8}",
111        "Epoch", "Train Loss", "Val Loss", "Time"
112    );
113    println!("  {:-<5}  {:-<12}  {:-<12}  {:-<8}", "", "", "", "");
114
115    let params = model.parameters();
116    let mut optimizer = Adam::new(params, LR);
117
118    for epoch in 1..=PHASE1_EPOCHS {
119        let epoch_start = Instant::now();
120        let mut epoch_loss = 0.0f32;
121        let mut batch_count = 0;
122
123        // Train on normal data: target = all zeros
124        for batch_start in (0..normal_train.len()).step_by(BATCH_SIZE) {
125            let batch_end = (batch_start + BATCH_SIZE).min(normal_train.len());
126
127            for i in batch_start..batch_end {
128                optimizer.zero_grad();
129
130                let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
131                let loss = mse.compute(&equip_scores, &zero_target);
132                let loss_val = loss.data().to_vec()[0];
133                epoch_loss += loss_val;
134                batch_count += 1;
135
136                if loss.requires_grad() {
137                    loss.backward();
138                    optimizer.step();
139                }
140            }
141        }
142
143        // Validation
144        let val_loss = evaluate_normal(&model, &normal_val, &mse, &zero_target);
145
146        let avg_loss = epoch_loss / batch_count as f32;
147        let elapsed = epoch_start.elapsed().as_secs_f32();
148
149        println!(
150            "  {:>5}  {:>12.6}  {:>12.6}  {:>6.1}s",
151            epoch, avg_loss, val_loss, elapsed
152        );
153    }
154
155    println!();
156
157    // =========================================================================
158    // Phase 2: Learn fault signatures
159    // =========================================================================
160    println!("═══════════════════════════════════════════════════════════════");
161    println!(" PHASE 2: Learning Fault Signatures ({PHASE2_EPOCHS} epochs)");
162    println!("═══════════════════════════════════════════════════════════════");
163    println!(
164        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
165        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
166    );
167    println!(
168        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
169        "", "", "", "", ""
170    );
171
172    // Reset optimizer with lower LR for phase 2
173    let params = model.parameters();
174    let mut optimizer = Adam::new(params, LR * 0.5);
175
176    for epoch in 1..=PHASE2_EPOCHS {
177        let epoch_start = Instant::now();
178        let mut normal_loss_sum = 0.0f32;
179        let mut fault_loss_sum = 0.0f32;
180        let mut normal_count = 0;
181        let mut fault_count = 0;
182
183        // Interleave normal + fault samples
184        let normal_per_epoch = NORMAL_SAMPLES / 2; // Use half of normal data
185        let fault_per_epoch = fault_data.len();
186
187        // Normal samples: target = zeros
188        for i in 0..normal_per_epoch.min(normal_train.len()) {
189            optimizer.zero_grad();
190            let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
191            let loss = mse.compute(&equip_scores, &zero_target);
192            normal_loss_sum += loss.data().to_vec()[0];
193            normal_count += 1;
194
195            if loss.requires_grad() {
196                loss.backward();
197                optimizer.step();
198            }
199        }
200
201        // Fault samples: target = 1.0 for affected equipment
202        for i in 0..fault_per_epoch {
203            let (ref snap, ref _fault, ref affected) = fault_data[i];
204
205            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
206            let fault_target = Variable::new(
207                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
208                false,
209            );
210
211            optimizer.zero_grad();
212            let (equip_scores, _) = model.forward_snapshot(snap);
213            let loss = mse.compute(&equip_scores, &fault_target);
214            fault_loss_sum += loss.data().to_vec()[0];
215            fault_count += 1;
216
217            if loss.requires_grad() {
218                loss.backward();
219                optimizer.step();
220            }
221        }
222
223        // Validation
224        let val_loss = evaluate_mixed(&model, &normal_val, &fault_val, &mse, &zero_target);
225
226        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
227        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
228        let elapsed = epoch_start.elapsed().as_secs_f32();
229
230        println!(
231            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
232            epoch, avg_normal, avg_fault, val_loss, elapsed
233        );
234    }
235
236    println!();
237
238    // =========================================================================
239    // Phase 3: Temporal training
240    // =========================================================================
241    println!("═══════════════════════════════════════════════════════════════");
242    println!(" PHASE 3: Temporal Training ({PHASE3_EPOCHS} epochs, window={TEMPORAL_WINDOW})");
243    println!("═══════════════════════════════════════════════════════════════");
244
245    // Generate temporal sequences
246    println!("[data] Generating temporal sequences...");
247    let t0 = Instant::now();
248
249    // Normal temporal sequences: varied starting OAT, slow drift
250    let mut normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
251    for i in 0..TEMPORAL_NORMAL_SEQS {
252        let start_oat = -5.0 + (i as f32 / TEMPORAL_NORMAL_SEQS as f32) * 100.0;
253        let drift = if start_oat < 50.0 { 0.2 } else { -0.1 }; // warming up or cooling down
254        let seq_sim = WarrenSimulator::new(SEED + 5000 + i as u64);
255        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, drift);
256        normal_seqs.push(seq);
257    }
258
259    // Fault temporal sequences: fault injected mid-sequence
260    let mut fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
261    for i in 0..TEMPORAL_FAULT_SEQS {
262        let start_oat = -5.0 + (i as f32 / TEMPORAL_FAULT_SEQS as f32) * 100.0;
263        let drift = 0.1;
264        let seq_sim = WarrenSimulator::new(SEED + 8000 + i as u64);
265        let seq_data =
266            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, drift, i as u64);
267        fault_seqs.push(seq_data);
268    }
269
270    // Validation temporal sequences
271    let mut val_normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
272    for i in 0..20 {
273        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
274        let seq_sim = WarrenSimulator::new(SEED + 9000 + i as u64);
275        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, 0.15);
276        val_normal_seqs.push(seq);
277    }
278
279    let mut val_fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
280    for i in 0..20 {
281        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
282        let seq_sim = WarrenSimulator::new(SEED + 9500 + i as u64);
283        let seq_data =
284            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, 0.1, i as u64);
285        val_fault_seqs.push(seq_data);
286    }
287
288    println!("  Normal temporal seqs: {}", normal_seqs.len());
289    println!("  Fault temporal seqs:  {}", fault_seqs.len());
290    println!("  Val normal seqs:      {}", val_normal_seqs.len());
291    println!("  Val fault seqs:       {}", val_fault_seqs.len());
292    println!("  Window size: {TEMPORAL_WINDOW} snapshots (1 hour)");
293    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
294    println!();
295
296    println!(
297        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
298        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
299    );
300    println!(
301        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
302        "", "", "", "", ""
303    );
304
305    // Lower LR for temporal fine-tuning
306    let params = model.parameters();
307    let mut optimizer = Adam::new(params, LR * 0.3);
308
309    for epoch in 1..=PHASE3_EPOCHS {
310        let epoch_start = Instant::now();
311        let mut normal_loss_sum = 0.0f32;
312        let mut fault_loss_sum = 0.0f32;
313        let mut normal_count = 0;
314        let mut fault_count = 0;
315
316        // Normal temporal sequences: target = all zeros
317        for seq in &normal_seqs {
318            optimizer.zero_grad();
319            let (equip_scores, _) = model.forward_temporal(seq);
320            let loss = mse.compute(&equip_scores, &zero_target);
321            normal_loss_sum += loss.data().to_vec()[0];
322            normal_count += 1;
323
324            if loss.requires_grad() {
325                loss.backward();
326                optimizer.step();
327            }
328        }
329
330        // Fault temporal sequences: target = 1.0 for affected equipment
331        for (seq, _onset, _fault, affected) in &fault_seqs {
332            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
333            let fault_target = Variable::new(
334                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
335                false,
336            );
337
338            optimizer.zero_grad();
339            let (equip_scores, _) = model.forward_temporal(seq);
340            let loss = mse.compute(&equip_scores, &fault_target);
341            fault_loss_sum += loss.data().to_vec()[0];
342            fault_count += 1;
343
344            if loss.requires_grad() {
345                loss.backward();
346                optimizer.step();
347            }
348        }
349
350        // Validation
351        let val_loss = evaluate_temporal_mixed(
352            &model,
353            &val_normal_seqs,
354            &val_fault_seqs,
355            &mse,
356            &zero_target,
357        );
358
359        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
360        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
361        let elapsed = epoch_start.elapsed().as_secs_f32();
362
363        println!(
364            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
365            epoch, avg_normal, avg_fault, val_loss, elapsed
366        );
367    }
368
369    println!();
370
371    // =========================================================================
372    // Final evaluation
373    // =========================================================================
374    println!("═══════════════════════════════════════════════════════════════");
375    println!(" FINAL EVALUATION");
376    println!("═══════════════════════════════════════════════════════════════");
377
378    // Test on normal data — scores should be near zero
379    let config = FacilityConfig::warren();
380    println!("\n  Normal operation (should be low scores):");
381    for i in [0, 50, 100, 150] {
382        if i >= normal_val.len() {
383            break;
384        }
385        let (equip_scores, fac_score) = model.forward_snapshot(&normal_val[i]);
386        let scores = equip_scores.data().to_vec();
387        let fac = fac_score.data().to_vec()[0];
388        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
389        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
390        println!(
391            "    Sample {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
392        );
393    }
394
395    // Test on fault data — affected equipment should have higher scores
396    println!("\n  Fault samples (affected equipment should score higher):");
397    for i in 0..5.min(fault_val.len()) {
398        let (ref snap, ref fault, ref affected) = fault_val[i];
399        let (equip_scores, fac_score) = model.forward_snapshot(snap);
400        let scores = equip_scores.data().to_vec();
401        let fac = fac_score.data().to_vec()[0];
402
403        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
404
405        // Get scores for affected vs unaffected
406        let affected_avg: f32 = if !affected.is_empty() {
407            affected
408                .iter()
409                .filter(|&&s| s < scores.len())
410                .map(|&s| scores[s])
411                .sum::<f32>()
412                / affected.len() as f32
413        } else {
414            0.0
415        };
416
417        println!("    Fault {:?}:", fault);
418        println!(
419            "      facility={fac:.4}, affected_avg={affected_avg:.4}, alerts={}",
420            output.alerts.len()
421        );
422    }
423
424    // Temporal evaluation
425    println!("\n  Temporal normal (should be low scores):");
426    for i in 0..3.min(val_normal_seqs.len()) {
427        let (equip_scores, fac_score) = model.forward_temporal(&val_normal_seqs[i]);
428        let scores = equip_scores.data().to_vec();
429        let fac = fac_score.data().to_vec()[0];
430        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
431        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
432        println!(
433            "    Seq {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
434        );
435    }
436
437    println!("\n  Temporal fault (fault injected mid-sequence):");
438    for i in 0..5.min(val_fault_seqs.len()) {
439        let (ref seq, onset, ref fault, ref affected) = val_fault_seqs[i];
440        let (equip_scores, fac_score) = model.forward_temporal(seq);
441        let scores = equip_scores.data().to_vec();
442        let fac = fac_score.data().to_vec()[0];
443
444        let affected_avg: f32 = if !affected.is_empty() {
445            affected
446                .iter()
447                .filter(|&&s| s < scores.len())
448                .map(|&s| scores[s])
449                .sum::<f32>()
450                / affected.len() as f32
451        } else {
452            0.0
453        };
454        let unaffected_avg: f32 = {
455            let unaffected: Vec<f32> = scores
456                .iter()
457                .enumerate()
458                .filter(|(idx, _)| !affected.contains(idx))
459                .map(|(_, &s)| s)
460                .collect();
461            if unaffected.is_empty() {
462                0.0
463            } else {
464                unaffected.iter().sum::<f32>() / unaffected.len() as f32
465            }
466        };
467
468        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
469        println!(
470            "    Fault {:?} (onset step {onset}/{TEMPORAL_WINDOW}):",
471            fault
472        );
473        println!(
474            "      facility={fac:.4}, affected={affected_avg:.4}, unaffected={unaffected_avg:.4}, alerts={}",
475            output.alerts.len()
476        );
477    }
478
479    println!();
480    println!("Training complete.");
481}
Source

pub fn is_leaf(&self) -> bool

Returns whether this is a leaf variable.

Source

pub fn grad(&self) -> Option<Tensor<f32>>

Returns the gradient of this variable.

Only available for leaf variables after backward() has been called.

Source

pub fn grad_fn(&self) -> Option<&GradFn>

Returns the gradient function.

Source

pub fn set_grad(&self, grad: Tensor<f32>)

Sets the gradient (used during backward pass).

Source

pub fn accumulate_grad(&self, grad: &Tensor<f32>)

Accumulates the gradient, adding it to any existing gradient rather than replacing it.

Source

pub fn zero_grad(&self)

Clears the gradient.

Source

pub fn detach(&self) -> Variable

Detaches this variable from the computation graph.

Returns a new variable with the same data but no gradient history.

Source

pub fn requires_grad_(self, requires_grad: bool) -> Variable

Returns a new variable with requires_grad set to the given value.

Source

pub fn backward(&self)

Computes gradients via backpropagation.

This should only be called on scalar (single-element) tensors, typically the loss value.

Examples found in repository?
examples/hvac_training.rs (line 796)
737fn train_epoch(
738    model: &HvacPredictor,
739    optimizer: &mut Adam,
740    loss_fn: &CrossEntropyLoss,
741    x_data: &[f32],
742    y_imminent: &[i64],
743    y_warning: &[i64],
744    y_early: &[i64],
745    batch_size: usize,
746    seq_len: usize,
747    num_features: usize,
748) -> (f32, f32, f32, f32) {
749    let n_sequences = y_imminent.len();
750    let n_batches = n_sequences / batch_size;
751
752    let mut total_loss = 0.0f32;
753    let mut total_acc_imm = 0.0f32;
754    let mut total_acc_warn = 0.0f32;
755    let mut total_acc_early = 0.0f32;
756
757    for batch_idx in 0..n_batches {
758        let start = batch_idx * batch_size;
759
760        // Prepare batch data
761        let mut batch_x = vec![0.0f32; batch_size * seq_len * num_features];
762        let mut batch_y_imm = vec![0i64; batch_size];
763        let mut batch_y_warn = vec![0i64; batch_size];
764        let mut batch_y_early = vec![0i64; batch_size];
765
766        for b in 0..batch_size {
767            let seq_start = (start + b) * seq_len * num_features;
768            for i in 0..(seq_len * num_features) {
769                batch_x[b * seq_len * num_features + i] = x_data[seq_start + i];
770            }
771            batch_y_imm[b] = y_imminent[start + b];
772            batch_y_warn[b] = y_warning[start + b];
773            batch_y_early[b] = y_early[start + b];
774        }
775
776        // Create tensors
777        let x_tensor = Tensor::from_vec(batch_x, &[batch_size, seq_len, num_features])
778            .expect("Failed to create input tensor");
779        let x_var = Variable::new(x_tensor, true);
780
781        // Forward pass
782        let (logits_imm, logits_warn, logits_early) = model.forward_multi(&x_var);
783
784        // Calculate losses (simplified - just use imminent for now)
785        let y_imm_tensor = Tensor::from_vec(
786            batch_y_imm.iter().map(|&y| y as f32).collect(),
787            &[batch_size],
788        )
789        .expect("Failed to create label tensor");
790        let y_imm_var = Variable::new(y_imm_tensor, false);
791
792        let loss = loss_fn.compute(&logits_imm, &y_imm_var);
793
794        // Backward pass
795        optimizer.zero_grad();
796        loss.backward();
797        optimizer.step();
798
799        // Metrics
800        total_loss += loss.data().to_vec()[0];
801        total_acc_imm += calculate_accuracy(&logits_imm, &batch_y_imm);
802        total_acc_warn += calculate_accuracy(&logits_warn, &batch_y_warn);
803        total_acc_early += calculate_accuracy(&logits_early, &batch_y_early);
804    }
805
806    let n = n_batches as f32;
807    (
808        total_loss / n,
809        total_acc_imm / n,
810        total_acc_warn / n,
811        total_acc_early / n,
812    )
813}
More examples
Hide additional examples
examples/simple_training.rs (line 80)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
examples/mnist_training.rs (line 158)
20fn main() {
21    println!("=== AxonML - MNIST Training (LeNet) ===\n");
22
23    // Detect device
24    #[cfg(feature = "cuda")]
25    let device = {
26        let cuda = Device::Cuda(0);
27        if cuda.is_available() {
28            println!("GPU detected: using CUDA device 0");
29            cuda
30        } else {
31            println!("CUDA feature enabled but no GPU available, using CPU");
32            Device::Cpu
33        }
34    };
35    #[cfg(not(feature = "cuda"))]
36    let device = {
37        println!("Using CPU (compile with --features cuda for GPU)");
38        Device::Cpu
39    };
40
41    // 1. Create dataset
42    let num_train = 2000;
43    let num_test = 400;
44    println!("\n1. Creating SyntheticMNIST dataset ({num_train} train, {num_test} test)...");
45    let train_dataset = SyntheticMNIST::new(num_train);
46    let test_dataset = SyntheticMNIST::new(num_test);
47
48    // 2. Create DataLoader
49    let batch_size = 64;
50    println!("2. Creating DataLoader (batch_size={batch_size})...");
51    let train_loader = DataLoader::new(train_dataset, batch_size);
52    let test_loader = DataLoader::new(test_dataset, batch_size);
53    println!("   Training batches: {}", train_loader.len());
54
55    // 3. Create LeNet model and move to device
56    println!("3. Creating LeNet model...");
57    let model = LeNet::new();
58    model.to_device(device);
59    let params = model.parameters();
60    let total_params: usize = params
61        .iter()
62        .map(|p| p.variable().data().to_vec().len())
63        .sum();
64    println!(
65        "   Parameters: {} ({} total weights)",
66        params.len(),
67        total_params
68    );
69    println!("   Device: {:?}", device);
70
71    // 4. Create optimizer and loss
72    println!("4. Creating Adam optimizer (lr=0.001) + CrossEntropyLoss...");
73    let mut optimizer = Adam::new(params, 0.001);
74    let criterion = CrossEntropyLoss::new();
75
76    // 5. Training loop
77    let epochs = 10;
78    println!("5. Training for {epochs} epochs...\n");
79
80    let train_start = Instant::now();
81
82    for epoch in 0..epochs {
83        let epoch_start = Instant::now();
84        let mut total_loss = 0.0;
85        let mut correct = 0usize;
86        let mut total = 0usize;
87        let mut batch_count = 0;
88
89        for batch in train_loader.iter() {
90            let bs = batch.data.shape()[0];
91
92            // Reshape to [N, 1, 28, 28] and create Variable
93            let input_data = batch.data.to_vec();
94            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
95            let input = Variable::new(
96                if device.is_gpu() {
97                    input_tensor.to_device(device).unwrap()
98                } else {
99                    input_tensor
100                },
101                true,
102            );
103
104            // Target: convert one-hot [N, 10] to class indices [N]
105            let target_onehot = batch.targets.to_vec();
106            let mut target_indices = vec![0.0f32; bs];
107            for i in 0..bs {
108                let offset = i * 10;
109                let mut max_idx = 0;
110                let mut max_val = f32::NEG_INFINITY;
111                for c in 0..10 {
112                    if target_onehot[offset + c] > max_val {
113                        max_val = target_onehot[offset + c];
114                        max_idx = c;
115                    }
116                }
117                target_indices[i] = max_idx as f32;
118            }
119            let target_tensor = Tensor::from_vec(target_indices.clone(), &[bs]).unwrap();
120            let target = Variable::new(
121                if device.is_gpu() {
122                    target_tensor.to_device(device).unwrap()
123                } else {
124                    target_tensor
125                },
126                false,
127            );
128
129            // Forward pass
130            let output = model.forward(&input);
131
132            // Cross-entropy loss
133            let loss = criterion.compute(&output, &target);
134
135            let loss_val = loss.data().to_vec()[0];
136            total_loss += loss_val;
137            batch_count += 1;
138
139            // Compute training accuracy
140            let out_data = output.data().to_vec();
141            for i in 0..bs {
142                let offset = i * 10;
143                let mut pred = 0;
144                let mut pred_val = f32::NEG_INFINITY;
145                for c in 0..10 {
146                    if out_data[offset + c] > pred_val {
147                        pred_val = out_data[offset + c];
148                        pred = c;
149                    }
150                }
151                if pred == target_indices[i] as usize {
152                    correct += 1;
153                }
154                total += 1;
155            }
156
157            // Backward pass
158            loss.backward();
159
160            // Update weights
161            optimizer.step();
162            optimizer.zero_grad();
163        }
164
165        let epoch_time = epoch_start.elapsed();
166        let avg_loss = total_loss / batch_count as f32;
167        let accuracy = 100.0 * correct as f32 / total as f32;
168        let samples_per_sec = total as f64 / epoch_time.as_secs_f64();
169
170        println!(
171            "   Epoch {:2}/{}: Loss={:.4}  Acc={:.1}%  ({:.0} samples/s, {:.2}s)",
172            epoch + 1,
173            epochs,
174            avg_loss,
175            accuracy,
176            samples_per_sec,
177            epoch_time.as_secs_f64(),
178        );
179    }
180
181    let train_time = train_start.elapsed();
182    println!("\n   Total training time: {:.2}s", train_time.as_secs_f64());
183
184    // 6. Test evaluation
185    println!("\n6. Evaluating on test set...");
186
187    // Disable gradient computation for evaluation
188    let (correct, total) = no_grad(|| {
189        let mut correct = 0usize;
190        let mut total = 0usize;
191
192        for batch in test_loader.iter() {
193            let bs = batch.data.shape()[0];
194
195            let input_data = batch.data.to_vec();
196            let input_tensor = Tensor::from_vec(input_data, &[bs, 1, 28, 28]).unwrap();
197            let input = Variable::new(
198                if device.is_gpu() {
199                    input_tensor.to_device(device).unwrap()
200                } else {
201                    input_tensor
202                },
203                false,
204            );
205
206            let target_onehot = batch.targets.to_vec();
207            let output = model.forward(&input);
208            let out_data = output.data().to_vec();
209
210            for i in 0..bs {
211                // Prediction: argmax of output
212                let offset = i * 10;
213                let mut pred = 0;
214                let mut pred_val = f32::NEG_INFINITY;
215                for c in 0..10 {
216                    if out_data[offset + c] > pred_val {
217                        pred_val = out_data[offset + c];
218                        pred = c;
219                    }
220                }
221
222                // True label: argmax of one-hot target
223                let mut true_label = 0;
224                let mut true_val = f32::NEG_INFINITY;
225                for c in 0..10 {
226                    if target_onehot[i * 10 + c] > true_val {
227                        true_val = target_onehot[i * 10 + c];
228                        true_label = c;
229                    }
230                }
231
232                if pred == true_label {
233                    correct += 1;
234                }
235                total += 1;
236            }
237        }
238
239        (correct, total)
240    });
241
242    let test_accuracy = 100.0 * correct as f32 / total as f32;
243    println!(
244        "   Test Accuracy: {}/{} ({:.2}%)",
245        correct, total, test_accuracy
246    );
247
248    println!("\n=== Training Complete! ===");
249    println!("   Device: {:?}", device);
250    println!("   Final test accuracy: {:.2}%", test_accuracy);
251}
examples/train_panoptes.rs (line 137)
56fn main() {
57    println!("╔══════════════════════════════════════════════════════════════╗");
58    println!("║     PANOPTES — Facility-Wide Anomaly Detection Training     ║");
59    println!("║     Heritage Pointe of Warren (59 equipment)                ║");
60    println!("╚══════════════════════════════════════════════════════════════╝");
61    println!();
62
63    // =========================================================================
64    // Generate training data
65    // =========================================================================
66    println!("[data] Generating physics-informed training data...");
67    let t0 = Instant::now();
68
69    let sim = WarrenSimulator::new(SEED);
70    let normal_train = sim.generate_normal(NORMAL_SAMPLES);
71    let fault_data = sim.generate_with_faults(FAULT_SAMPLES, 1.0);
72
73    // Validation set (different seed)
74    let val_sim = WarrenSimulator::new(SEED + 999);
75    let normal_val = val_sim.generate_normal(200);
76    let fault_val = val_sim.generate_with_faults(100, 1.0);
77
78    println!("  Normal train: {} samples", normal_train.len());
79    println!("  Fault train:  {} samples", fault_data.len());
80    println!("  Normal val:   {} samples", normal_val.len());
81    println!("  Fault val:    {} samples", fault_val.len());
82    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
83    println!();
84
85    // =========================================================================
86    // Create model
87    // =========================================================================
88    let model = Panoptes::new(NUM_EQUIPMENT);
89    println!("[model] Panoptes created");
90    println!("  Equipment slots: {NUM_EQUIPMENT}");
91    println!("  Parameters: {}", model.num_parameters());
92    println!("  Embed dim: {EMBED_DIM}");
93    println!();
94
95    let mse = MSELoss::new();
96
97    // Zero target for normal operation
98    let zero_target = Variable::new(
99        Tensor::from_vec(vec![0.0; NUM_EQUIPMENT], &[1, NUM_EQUIPMENT]).unwrap(),
100        false,
101    );
102
103    // =========================================================================
104    // Phase 1: Learn normal operation
105    // =========================================================================
106    println!("═══════════════════════════════════════════════════════════════");
107    println!(" PHASE 1: Learning Normal Operation ({PHASE1_EPOCHS} epochs)");
108    println!("═══════════════════════════════════════════════════════════════");
109    println!(
110        "  {:>5}  {:>12}  {:>12}  {:>8}",
111        "Epoch", "Train Loss", "Val Loss", "Time"
112    );
113    println!("  {:-<5}  {:-<12}  {:-<12}  {:-<8}", "", "", "", "");
114
115    let params = model.parameters();
116    let mut optimizer = Adam::new(params, LR);
117
118    for epoch in 1..=PHASE1_EPOCHS {
119        let epoch_start = Instant::now();
120        let mut epoch_loss = 0.0f32;
121        let mut batch_count = 0;
122
123        // Train on normal data: target = all zeros
124        for batch_start in (0..normal_train.len()).step_by(BATCH_SIZE) {
125            let batch_end = (batch_start + BATCH_SIZE).min(normal_train.len());
126
127            for i in batch_start..batch_end {
128                optimizer.zero_grad();
129
130                let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
131                let loss = mse.compute(&equip_scores, &zero_target);
132                let loss_val = loss.data().to_vec()[0];
133                epoch_loss += loss_val;
134                batch_count += 1;
135
136                if loss.requires_grad() {
137                    loss.backward();
138                    optimizer.step();
139                }
140            }
141        }
142
143        // Validation
144        let val_loss = evaluate_normal(&model, &normal_val, &mse, &zero_target);
145
146        let avg_loss = epoch_loss / batch_count as f32;
147        let elapsed = epoch_start.elapsed().as_secs_f32();
148
149        println!(
150            "  {:>5}  {:>12.6}  {:>12.6}  {:>6.1}s",
151            epoch, avg_loss, val_loss, elapsed
152        );
153    }
154
155    println!();
156
157    // =========================================================================
158    // Phase 2: Learn fault signatures
159    // =========================================================================
160    println!("═══════════════════════════════════════════════════════════════");
161    println!(" PHASE 2: Learning Fault Signatures ({PHASE2_EPOCHS} epochs)");
162    println!("═══════════════════════════════════════════════════════════════");
163    println!(
164        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
165        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
166    );
167    println!(
168        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
169        "", "", "", "", ""
170    );
171
172    // Reset optimizer with lower LR for phase 2
173    let params = model.parameters();
174    let mut optimizer = Adam::new(params, LR * 0.5);
175
176    for epoch in 1..=PHASE2_EPOCHS {
177        let epoch_start = Instant::now();
178        let mut normal_loss_sum = 0.0f32;
179        let mut fault_loss_sum = 0.0f32;
180        let mut normal_count = 0;
181        let mut fault_count = 0;
182
183        // Interleave normal + fault samples
184        let normal_per_epoch = NORMAL_SAMPLES / 2; // Use half of normal data
185        let fault_per_epoch = fault_data.len();
186
187        // Normal samples: target = zeros
188        for i in 0..normal_per_epoch.min(normal_train.len()) {
189            optimizer.zero_grad();
190            let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
191            let loss = mse.compute(&equip_scores, &zero_target);
192            normal_loss_sum += loss.data().to_vec()[0];
193            normal_count += 1;
194
195            if loss.requires_grad() {
196                loss.backward();
197                optimizer.step();
198            }
199        }
200
201        // Fault samples: target = 1.0 for affected equipment
202        for i in 0..fault_per_epoch {
203            let (ref snap, ref _fault, ref affected) = fault_data[i];
204
205            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
206            let fault_target = Variable::new(
207                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
208                false,
209            );
210
211            optimizer.zero_grad();
212            let (equip_scores, _) = model.forward_snapshot(snap);
213            let loss = mse.compute(&equip_scores, &fault_target);
214            fault_loss_sum += loss.data().to_vec()[0];
215            fault_count += 1;
216
217            if loss.requires_grad() {
218                loss.backward();
219                optimizer.step();
220            }
221        }
222
223        // Validation
224        let val_loss = evaluate_mixed(&model, &normal_val, &fault_val, &mse, &zero_target);
225
226        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
227        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
228        let elapsed = epoch_start.elapsed().as_secs_f32();
229
230        println!(
231            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
232            epoch, avg_normal, avg_fault, val_loss, elapsed
233        );
234    }
235
236    println!();
237
238    // =========================================================================
239    // Phase 3: Temporal training
240    // =========================================================================
241    println!("═══════════════════════════════════════════════════════════════");
242    println!(" PHASE 3: Temporal Training ({PHASE3_EPOCHS} epochs, window={TEMPORAL_WINDOW})");
243    println!("═══════════════════════════════════════════════════════════════");
244
245    // Generate temporal sequences
246    println!("[data] Generating temporal sequences...");
247    let t0 = Instant::now();
248
249    // Normal temporal sequences: varied starting OAT, slow drift
250    let mut normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
251    for i in 0..TEMPORAL_NORMAL_SEQS {
252        let start_oat = -5.0 + (i as f32 / TEMPORAL_NORMAL_SEQS as f32) * 100.0;
253        let drift = if start_oat < 50.0 { 0.2 } else { -0.1 }; // warming up or cooling down
254        let seq_sim = WarrenSimulator::new(SEED + 5000 + i as u64);
255        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, drift);
256        normal_seqs.push(seq);
257    }
258
259    // Fault temporal sequences: fault injected mid-sequence
260    let mut fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
261    for i in 0..TEMPORAL_FAULT_SEQS {
262        let start_oat = -5.0 + (i as f32 / TEMPORAL_FAULT_SEQS as f32) * 100.0;
263        let drift = 0.1;
264        let seq_sim = WarrenSimulator::new(SEED + 8000 + i as u64);
265        let seq_data =
266            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, drift, i as u64);
267        fault_seqs.push(seq_data);
268    }
269
270    // Validation temporal sequences
271    let mut val_normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
272    for i in 0..20 {
273        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
274        let seq_sim = WarrenSimulator::new(SEED + 9000 + i as u64);
275        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, 0.15);
276        val_normal_seqs.push(seq);
277    }
278
279    let mut val_fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
280    for i in 0..20 {
281        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
282        let seq_sim = WarrenSimulator::new(SEED + 9500 + i as u64);
283        let seq_data =
284            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, 0.1, i as u64);
285        val_fault_seqs.push(seq_data);
286    }
287
288    println!("  Normal temporal seqs: {}", normal_seqs.len());
289    println!("  Fault temporal seqs:  {}", fault_seqs.len());
290    println!("  Val normal seqs:      {}", val_normal_seqs.len());
291    println!("  Val fault seqs:       {}", val_fault_seqs.len());
292    println!("  Window size: {TEMPORAL_WINDOW} snapshots (1 hour)");
293    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
294    println!();
295
296    println!(
297        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
298        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
299    );
300    println!(
301        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
302        "", "", "", "", ""
303    );
304
305    // Lower LR for temporal fine-tuning
306    let params = model.parameters();
307    let mut optimizer = Adam::new(params, LR * 0.3);
308
309    for epoch in 1..=PHASE3_EPOCHS {
310        let epoch_start = Instant::now();
311        let mut normal_loss_sum = 0.0f32;
312        let mut fault_loss_sum = 0.0f32;
313        let mut normal_count = 0;
314        let mut fault_count = 0;
315
316        // Normal temporal sequences: target = all zeros
317        for seq in &normal_seqs {
318            optimizer.zero_grad();
319            let (equip_scores, _) = model.forward_temporal(seq);
320            let loss = mse.compute(&equip_scores, &zero_target);
321            normal_loss_sum += loss.data().to_vec()[0];
322            normal_count += 1;
323
324            if loss.requires_grad() {
325                loss.backward();
326                optimizer.step();
327            }
328        }
329
330        // Fault temporal sequences: target = 1.0 for affected equipment
331        for (seq, _onset, _fault, affected) in &fault_seqs {
332            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
333            let fault_target = Variable::new(
334                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
335                false,
336            );
337
338            optimizer.zero_grad();
339            let (equip_scores, _) = model.forward_temporal(seq);
340            let loss = mse.compute(&equip_scores, &fault_target);
341            fault_loss_sum += loss.data().to_vec()[0];
342            fault_count += 1;
343
344            if loss.requires_grad() {
345                loss.backward();
346                optimizer.step();
347            }
348        }
349
350        // Validation
351        let val_loss = evaluate_temporal_mixed(
352            &model,
353            &val_normal_seqs,
354            &val_fault_seqs,
355            &mse,
356            &zero_target,
357        );
358
359        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
360        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
361        let elapsed = epoch_start.elapsed().as_secs_f32();
362
363        println!(
364            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
365            epoch, avg_normal, avg_fault, val_loss, elapsed
366        );
367    }
368
369    println!();
370
371    // =========================================================================
372    // Final evaluation
373    // =========================================================================
374    println!("═══════════════════════════════════════════════════════════════");
375    println!(" FINAL EVALUATION");
376    println!("═══════════════════════════════════════════════════════════════");
377
378    // Test on normal data — scores should be near zero
379    let config = FacilityConfig::warren();
380    println!("\n  Normal operation (should be low scores):");
381    for i in [0, 50, 100, 150] {
382        if i >= normal_val.len() {
383            break;
384        }
385        let (equip_scores, fac_score) = model.forward_snapshot(&normal_val[i]);
386        let scores = equip_scores.data().to_vec();
387        let fac = fac_score.data().to_vec()[0];
388        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
389        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
390        println!(
391            "    Sample {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
392        );
393    }
394
395    // Test on fault data — affected equipment should have higher scores
396    println!("\n  Fault samples (affected equipment should score higher):");
397    for i in 0..5.min(fault_val.len()) {
398        let (ref snap, ref fault, ref affected) = fault_val[i];
399        let (equip_scores, fac_score) = model.forward_snapshot(snap);
400        let scores = equip_scores.data().to_vec();
401        let fac = fac_score.data().to_vec()[0];
402
403        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
404
405        // Get scores for affected vs unaffected
406        let affected_avg: f32 = if !affected.is_empty() {
407            affected
408                .iter()
409                .filter(|&&s| s < scores.len())
410                .map(|&s| scores[s])
411                .sum::<f32>()
412                / affected.len() as f32
413        } else {
414            0.0
415        };
416
417        println!("    Fault {:?}:", fault);
418        println!(
419            "      facility={fac:.4}, affected_avg={affected_avg:.4}, alerts={}",
420            output.alerts.len()
421        );
422    }
423
424    // Temporal evaluation
425    println!("\n  Temporal normal (should be low scores):");
426    for i in 0..3.min(val_normal_seqs.len()) {
427        let (equip_scores, fac_score) = model.forward_temporal(&val_normal_seqs[i]);
428        let scores = equip_scores.data().to_vec();
429        let fac = fac_score.data().to_vec()[0];
430        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
431        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
432        println!(
433            "    Seq {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
434        );
435    }
436
437    println!("\n  Temporal fault (fault injected mid-sequence):");
438    for i in 0..5.min(val_fault_seqs.len()) {
439        let (ref seq, onset, ref fault, ref affected) = val_fault_seqs[i];
440        let (equip_scores, fac_score) = model.forward_temporal(seq);
441        let scores = equip_scores.data().to_vec();
442        let fac = fac_score.data().to_vec()[0];
443
444        let affected_avg: f32 = if !affected.is_empty() {
445            affected
446                .iter()
447                .filter(|&&s| s < scores.len())
448                .map(|&s| scores[s])
449                .sum::<f32>()
450                / affected.len() as f32
451        } else {
452            0.0
453        };
454        let unaffected_avg: f32 = {
455            let unaffected: Vec<f32> = scores
456                .iter()
457                .enumerate()
458                .filter(|(idx, _)| !affected.contains(idx))
459                .map(|(_, &s)| s)
460                .collect();
461            if unaffected.is_empty() {
462                0.0
463            } else {
464                unaffected.iter().sum::<f32>() / unaffected.len() as f32
465            }
466        };
467
468        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
469        println!(
470            "    Fault {:?} (onset step {onset}/{TEMPORAL_WINDOW}):",
471            fault
472        );
473        println!(
474            "      facility={fac:.4}, affected={affected_avg:.4}, unaffected={unaffected_avg:.4}, alerts={}",
475            output.alerts.len()
476        );
477    }
478
479    println!();
480    println!("Training complete.");
481}
Source

pub fn backward_with_grad(&self, grad_output: &Tensor<f32>)

Runs the backward pass with a provided gradient tensor.

Unlike backward(), this does not require the variable to be a scalar. The provided gradient tensor must match the shape of this variable.

Source

pub fn add_var(&self, other: &Variable) -> Variable

Element-wise addition.

Source

pub fn sub_var(&self, other: &Variable) -> Variable

Element-wise subtraction.

Examples found in repository?
examples/simple_training.rs (line 74)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
Source

pub fn mul_var(&self, other: &Variable) -> Variable

Element-wise multiplication.

Examples found in repository?
examples/simple_training.rs (line 75)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
Source

pub fn div_var(&self, other: &Variable) -> Variable

Element-wise division.

Source

pub fn neg_var(&self) -> Variable

Negation.

Source

pub fn matmul(&self, other: &Variable) -> Variable

Matrix multiplication.

Source

pub fn pow(&self, exponent: f32) -> Variable

Power operation.

Source

pub fn relu(&self) -> Variable

ReLU activation.

Source

pub fn leaky_relu(&self, negative_slope: f32) -> Variable

Leaky ReLU activation.

Source

pub fn elu(&self, alpha: f32) -> Variable

ELU activation.

Source

pub fn sigmoid(&self) -> Variable

Sigmoid activation.

Examples found in repository?
examples/simple_training.rs (line 66)
19fn main() {
20    println!("=== Axonml ML Framework - Simple Training Example ===\n");
21
22    // Print version and features
23    println!("Version: {}", axonml::version());
24    println!("Features: {}\n", axonml::features());
25
26    // 1. Create a simple dataset (XOR problem)
27    println!("1. Creating XOR dataset...");
28    let inputs = vec![
29        vec![0.0, 0.0],
30        vec![0.0, 1.0],
31        vec![1.0, 0.0],
32        vec![1.0, 1.0],
33    ];
34    let targets = vec![0.0, 1.0, 1.0, 0.0]; // XOR outputs
35
36    println!("   Inputs: {inputs:?}");
37    println!("   Targets: {targets:?}\n");
38
39    // 2. Create a simple MLP model
40    println!("2. Creating MLP model (2 -> 4 -> 1)...");
41    let linear1 = Linear::new(2, 4);
42    let linear2 = Linear::new(4, 1);
43
44    println!("   Layer 1: Linear(2, 4)");
45    println!("   Layer 2: Linear(4, 1)\n");
46
47    // 3. Create optimizer
48    println!("3. Creating Adam optimizer (lr=0.1)...");
49    let params = [linear1.parameters(), linear2.parameters()].concat();
50    let mut optimizer = Adam::new(params, 0.1);
51    println!("   Optimizer created!\n");
52
53    // 4. Training loop
54    println!("4. Training for 1000 epochs...");
55    let epochs = 1000;
56
57    for epoch in 0..epochs {
58        let mut total_loss = 0.0;
59
60        for (input, &target) in inputs.iter().zip(targets.iter()) {
61            // Create input tensor
62            let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), true);
63
64            // Forward pass
65            let h = linear1.forward(&x);
66            let h = h.sigmoid();
67            let output = linear2.forward(&h);
68            let output = output.sigmoid();
69
70            // Create target tensor
71            let y = Variable::new(Tensor::from_vec(vec![target], &[1, 1]).unwrap(), false);
72
73            // Compute MSE loss manually: (output - target)^2
74            let diff = output.sub_var(&y);
75            let loss = diff.mul_var(&diff);
76
77            total_loss += loss.data().to_vec()[0];
78
79            // Backward pass
80            loss.backward();
81
82            // Update weights
83            optimizer.step();
84            optimizer.zero_grad();
85        }
86
87        if epoch % 200 == 0 || epoch == epochs - 1 {
88            println!("   Epoch {}: Loss = {:.6}", epoch, total_loss / 4.0);
89        }
90    }
91
92    // 5. Test the trained model
93    println!("\n5. Testing trained model...");
94    for (input, &expected) in inputs.iter().zip(targets.iter()) {
95        let x = Variable::new(Tensor::from_vec(input.clone(), &[1, 2]).unwrap(), false);
96
97        let h = linear1.forward(&x);
98        let h = h.sigmoid();
99        let output = linear2.forward(&h);
100        let output = output.sigmoid();
101
102        let pred = output.data().to_vec()[0];
103        let rounded = if pred > 0.5 { 1.0 } else { 0.0 };
104
105        println!(
106            "   Input: {input:?} -> Predicted: {pred:.4} (rounded: {rounded}) | Expected: {expected}"
107        );
108    }
109
110    println!("\n=== Training Complete! ===");
111}
Source

pub fn tanh(&self) -> Variable

Tanh activation.

Source

pub fn exp(&self) -> Variable

Element-wise exponential.

Source

pub fn log(&self) -> Variable

Element-wise natural logarithm.

Source

pub fn clamp(&self, min_val: f32, max_val: f32) -> Variable

Element-wise clamp to [min_val, max_val].

Source

pub fn sum(&self) -> Variable

Sum all elements.

Source

pub fn sum_dim(&self, dim: usize) -> Variable

Sum along a dimension, removing that dimension.

Source

pub fn mean(&self) -> Variable

Mean of all elements.

Source

pub fn mse_loss(&self, target: &Variable) -> Variable

Mean Squared Error loss.

Source

pub fn binary_cross_entropy(&self, target: &Variable) -> Variable

Binary cross-entropy loss (expects the input to already be sigmoid-activated probabilities in [0, 1]).

Source

pub fn reshape(&self, shape: &[usize]) -> Variable

Reshapes the variable to a new shape.

Examples found in repository?
examples/hvac_training.rs (line 623)
616    pub fn forward_multi(&self, x: &Variable) -> (Variable, Variable, Variable) {
617        let x_data = x.data();
618        let shape = x_data.shape();
619        let batch_size = shape[0];
620        let seq_len = shape[1];
621        drop(x_data);
622
623        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
624        let proj = self.input_proj.forward(&x_flat);
625        let proj = self.input_norm.forward(&proj);
626        let proj = self.input_relu.forward(&proj);
627        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
628
629        // Use forward_mean for proper gradient flow (equivalent to forward + mean_pool)
630        let pooled = self.gru.forward_mean(&proj);
631
632        let imminent = self.head_imminent.forward(&pooled);
633        let warning = self.head_warning.forward(&pooled);
634        let early = self.head_early.forward(&pooled);
635
636        (imminent, warning, early)
637    }
More examples
Hide additional examples
examples/hvac_model.rs (line 203)
194    pub fn forward_multi(&self, x: &Variable) -> HvacOutput {
195        let x_data = x.data();
196        let shape = x_data.shape();
197        let batch_size = shape[0];
198        let seq_len = shape[1];
199        drop(x_data); // Release borrow
200
201        // Input projection: [batch, seq, features] -> [batch, seq, hidden]
202        // Reshape for linear: [batch * seq, features]
203        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
204        let proj = self.input_proj.forward(&x_flat);
205        let proj = self.input_norm.forward(&proj);
206        let proj = self.input_relu.forward(&proj);
207        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
208
209        // GRU encoding: [batch, seq, hidden] -> [batch, seq, hidden]
210        let encoded = self.gru.forward(&proj);
211
212        // Mean pooling: [batch, seq, hidden] -> [batch, hidden]
213        let pooled = self.mean_pool(&encoded);
214
215        // Prediction heads
216        let imminent_logits = self.head_imminent.forward(&pooled);
217        let warning_logits = self.head_warning.forward(&pooled);
218        let early_logits = self.head_early.forward(&pooled);
219
220        HvacOutput {
221            imminent_logits,
222            warning_logits,
223            early_logits,
224        }
225    }
Source

pub fn flatten(&self, start_dim: usize) -> Variable

Flattens all dimensions from start_dim to the end into a single dimension.

flatten(1) on a [batch, C, H, W] tensor produces [batch, C*H*W]. flatten(0) flattens everything into a 1D vector.

Source

pub fn transpose(&self, dim0: usize, dim1: usize) -> Variable

Transposes two dimensions.

Source

pub fn slice(&self, ranges: &[Range<usize>]) -> Variable

Slices the variable along specified ranges.

Source

pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Variable

Narrows the variable along a dimension.

Returns a view of the tensor containing elements from start to start + length along the specified dimension. This operation preserves gradients for backpropagation.

Source

pub fn expand(&self, shape: &[usize]) -> Variable

Expands the variable to a new shape (broadcast).

Tracks the computational graph for backward pass.

Source

pub fn select(&self, dim: usize, index: usize) -> Variable

Selects a single index along a dimension, reducing rank by 1.

For a tensor of shape (A, B, C), select(1, i) returns shape (A, C). Tracks the computational graph for backward pass.

Source

pub fn unsqueeze(&self, dim: usize) -> Variable

Adds a dimension of size 1 at the given position.

Tracks the computational graph for backward pass.

Source

pub fn cat(variables: &[&Variable], dim: usize) -> Variable

Concatenates variables along a dimension.

All variables must have the same shape except along the cat dimension. Tracks the computational graph for backpropagation.

Source

pub fn mul_scalar(&self, scalar: f32) -> Variable

Multiplies by a scalar.

Source

pub fn add_scalar(&self, scalar: f32) -> Variable

Adds a scalar.

Source

pub fn sub_scalar(&self, scalar: f32) -> Variable

Subtracts a scalar.

Source

pub fn div_scalar(&self, scalar: f32) -> Variable

Divides by a scalar.

Source

pub fn gelu(&self) -> Variable

GELU activation function (Gaussian Error Linear Unit).

Source

pub fn silu(&self) -> Variable

SiLU/Swish activation function (x * sigmoid(x)).

Source

pub fn sqrt(&self) -> Variable

Square root.

Source

pub fn softmax(&self, dim: i32) -> Variable

Softmax along specified dimension.

Source

pub fn log_softmax(&self, dim: i32) -> Variable

Log softmax along specified dimension.

Source

pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Variable

Mean along a dimension, optionally keeping the dimension.

Source

pub fn var_dim(&self, dim: i32, keepdim: bool) -> Variable

Variance along a dimension, optionally keeping the dimension.

Source

pub fn from_tensor_with_grad(data: Tensor<f32>, requires_grad: bool) -> Variable

Creates a Variable from a tensor and requires_grad flag (for weight access). This is typically used internally by Parameter types.

Source

pub fn clone_var(&self) -> Variable

Clones the variable (alias for Clone trait).

Source

pub fn add(&self, other: &Variable) -> Variable

Adds another variable (alias for add_var for method chaining).

Source

pub fn sub(&self, other: &Variable) -> Variable

Subtracts another variable (alias for sub_var for method chaining).

Source

pub fn mul(&self, other: &Variable) -> Variable

Multiplies by another variable (alias for mul_var for method chaining).

Source

pub fn div(&self, other: &Variable) -> Variable

Divides by another variable (alias for div_var for method chaining).

Trait Implementations§

Source§

impl Add for &Variable

Source§

type Output = Variable

The resulting type after applying the + operator.
Source§

fn add(self, other: &Variable) -> Variable

Performs the + operation. Read more
Source§

impl Clone for Variable

Source§

fn clone(&self) -> Variable

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Variable

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter. Read more
Source§

impl Div for &Variable

Source§

type Output = Variable

The resulting type after applying the / operator.
Source§

fn div(self, other: &Variable) -> Variable

Performs the / operation. Read more
Source§

impl Mul for &Variable

Source§

type Output = Variable

The resulting type after applying the * operator.
Source§

fn mul(self, other: &Variable) -> Variable

Performs the * operation. Read more
Source§

impl Neg for &Variable

Source§

type Output = Variable

The resulting type after applying the - operator.
Source§

fn neg(self) -> Variable

Performs the unary - operation. Read more
Source§

impl Sub for &Variable

Source§

type Output = Variable

The resulting type after applying the - operator.
Source§

fn sub(self, other: &Variable) -> Variable

Performs the - operation. Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a pointer with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow. Read more
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more