// crates/axonml/examples/train_panoptes.rs
//! Train Panoptes — Facility-Wide Anomaly Detection Model
//!
//! # File
//! `crates/axonml/examples/train_panoptes.rs`
//!
//! # Description
//! Trains the Panoptes model on physics-informed synthetic data generated
//! from Heritage Pointe of Warren BAS control logic. Two-phase training:
//! Phase 1 learns normal operation patterns, Phase 2 learns fault signatures.
//! A third phase fine-tunes on temporal sequences with mid-sequence faults.
//!
//! # Usage
//! ```bash
//! cargo run --release -p axonml --example train_panoptes
//! ```
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 9, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use axonml::hvac::panoptes::*;
use axonml::hvac::panoptes_datagen::*;
use axonml_autograd::Variable;
use axonml_nn::MSELoss;
use axonml_optim::{Adam, Optimizer};
use axonml_tensor::Tensor;
use std::time::Instant;

// =============================================================================
// Configuration
// =============================================================================

/// Equipment slots modeled for the Heritage Pointe of Warren facility.
const NUM_EQUIPMENT: usize = 59;
/// Normal-operation snapshots generated for training.
const NORMAL_SAMPLES: usize = 2000;
/// Fault-injected snapshots generated for training.
const FAULT_SAMPLES: usize = 1000;
/// Phase 1 epochs: learn normal operation (targets all zeros).
const PHASE1_EPOCHS: usize = 30;
/// Phase 2 epochs: learn fault signatures (targets 1.0 on affected equipment).
const PHASE2_EPOCHS: usize = 20;
/// Phase 3 epochs: temporal fine-tuning over snapshot sequences.
const PHASE3_EPOCHS: usize = 15;
const TEMPORAL_WINDOW: usize = 12; // 12 snapshots = 1 hour at 5-min intervals
/// Normal temporal sequences generated for Phase 3.
const TEMPORAL_NORMAL_SEQS: usize = 100;
/// Fault temporal sequences generated for Phase 3.
const TEMPORAL_FAULT_SEQS: usize = 80;
/// Chunk size for Phase 1's iteration over the normal set.
/// NOTE(review): updates are applied per sample, so this only chunks the
/// loop — it is not gradient-accumulation batching; confirm intended.
const BATCH_SIZE: usize = 16;
/// Base learning rate; Phase 2 uses LR*0.5 and Phase 3 uses LR*0.3.
const LR: f32 = 1e-3;
/// Base RNG seed; validation and temporal generators use offsets of it.
const SEED: u64 = 42;

// =============================================================================
// Main
// =============================================================================
// Entry point: three-phase curriculum — (1) snapshot-level normal operation,
// (2) snapshot-level fault signatures, (3) temporal sequences — followed by a
// qualitative final evaluation on held-out snapshots and sequences.
fn main() {
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║     PANOPTES — Facility-Wide Anomaly Detection Training     ║");
    println!("║     Heritage Pointe of Warren (59 equipment)                ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!();

    // =========================================================================
    // Generate training data
    // =========================================================================
    println!("[data] Generating physics-informed training data...");
    let t0 = Instant::now();

    let sim = WarrenSimulator::new(SEED);
    let normal_train = sim.generate_normal(NORMAL_SAMPLES);
    let fault_data = sim.generate_with_faults(FAULT_SAMPLES, 1.0);

    // Validation set (different seed so it cannot overlap the training draws)
    let val_sim = WarrenSimulator::new(SEED + 999);
    let normal_val = val_sim.generate_normal(200);
    let fault_val = val_sim.generate_with_faults(100, 1.0);

    println!("  Normal train: {} samples", normal_train.len());
    println!("  Fault train:  {} samples", fault_data.len());
    println!("  Normal val:   {} samples", normal_val.len());
    println!("  Fault val:    {} samples", fault_val.len());
    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
    println!();

    // =========================================================================
    // Create model
    // =========================================================================
    let model = Panoptes::new(NUM_EQUIPMENT);
    println!("[model] Panoptes created");
    println!("  Equipment slots: {NUM_EQUIPMENT}");
    println!("  Parameters: {}", model.num_parameters());
    println!("  Embed dim: {EMBED_DIM}");
    println!();

    let mse = MSELoss::new();

    // Zero target for normal operation: one score per equipment slot, all 0.0.
    // Built once and reused across all phases (requires_grad = false).
    let zero_target = Variable::new(
        Tensor::from_vec(vec![0.0; NUM_EQUIPMENT], &[1, NUM_EQUIPMENT]).unwrap(),
        false,
    );

    // =========================================================================
    // Phase 1: Learn normal operation
    // =========================================================================
    println!("═══════════════════════════════════════════════════════════════");
    println!(" PHASE 1: Learning Normal Operation ({PHASE1_EPOCHS} epochs)");
    println!("═══════════════════════════════════════════════════════════════");
    println!(
        "  {:>5}  {:>12}  {:>12}  {:>8}",
        "Epoch", "Train Loss", "Val Loss", "Time"
    );
    println!("  {:-<5}  {:-<12}  {:-<12}  {:-<8}", "", "", "", "");

    let params = model.parameters();
    let mut optimizer = Adam::new(params, LR);

    for epoch in 1..=PHASE1_EPOCHS {
        let epoch_start = Instant::now();
        let mut epoch_loss = 0.0f32;
        let mut batch_count = 0;

        // Train on normal data: target = all zeros.
        // NOTE(review): zero_grad/step run once per sample inside the inner
        // loop, so the outer BATCH_SIZE stepping only chunks iteration order —
        // no gradients are accumulated across a batch. Confirm intended.
        for batch_start in (0..normal_train.len()).step_by(BATCH_SIZE) {
            let batch_end = (batch_start + BATCH_SIZE).min(normal_train.len());

            for i in batch_start..batch_end {
                optimizer.zero_grad();

                let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
                let loss = mse.compute(&equip_scores, &zero_target);
                let loss_val = loss.data().to_vec()[0];
                epoch_loss += loss_val;
                batch_count += 1;

                if loss.requires_grad() {
                    loss.backward();
                    optimizer.step();
                }
            }
        }

        // Validation against the held-out normal set (target = zeros)
        let val_loss = evaluate_normal(&model, &normal_val, &mse, &zero_target);

        let avg_loss = epoch_loss / batch_count as f32;
        let elapsed = epoch_start.elapsed().as_secs_f32();

        println!(
            "  {:>5}  {:>12.6}  {:>12.6}  {:>6.1}s",
            epoch, avg_loss, val_loss, elapsed
        );
    }

    println!();

    // =========================================================================
    // Phase 2: Learn fault signatures
    // =========================================================================
    println!("═══════════════════════════════════════════════════════════════");
    println!(" PHASE 2: Learning Fault Signatures ({PHASE2_EPOCHS} epochs)");
    println!("═══════════════════════════════════════════════════════════════");
    println!(
        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
    );
    println!(
        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
        "", "", "", "", ""
    );

    // Reset optimizer with lower LR for phase 2 (fine-tuning on mixed data)
    let params = model.parameters();
    let mut optimizer = Adam::new(params, LR * 0.5);

    for epoch in 1..=PHASE2_EPOCHS {
        let epoch_start = Instant::now();
        let mut normal_loss_sum = 0.0f32;
        let mut fault_loss_sum = 0.0f32;
        let mut normal_count = 0;
        let mut fault_count = 0;

        // Interleave normal + fault samples
        let normal_per_epoch = NORMAL_SAMPLES / 2; // Use half of normal data
        let fault_per_epoch = fault_data.len();

        // Normal samples: target = zeros (keep reinforcing normal behavior so
        // fault training does not raise scores everywhere)
        for i in 0..normal_per_epoch.min(normal_train.len()) {
            optimizer.zero_grad();
            let (equip_scores, _) = model.forward_snapshot(&normal_train[i]);
            let loss = mse.compute(&equip_scores, &zero_target);
            normal_loss_sum += loss.data().to_vec()[0];
            normal_count += 1;

            if loss.requires_grad() {
                loss.backward();
                optimizer.step();
            }
        }

        // Fault samples: target = 1.0 for affected equipment, 0.0 elsewhere
        for i in 0..fault_per_epoch {
            let (ref snap, ref _fault, ref affected) = fault_data[i];

            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
            let fault_target = Variable::new(
                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
                false,
            );

            optimizer.zero_grad();
            let (equip_scores, _) = model.forward_snapshot(snap);
            let loss = mse.compute(&equip_scores, &fault_target);
            fault_loss_sum += loss.data().to_vec()[0];
            fault_count += 1;

            if loss.requires_grad() {
                loss.backward();
                optimizer.step();
            }
        }

        // Validation on held-out normal + fault snapshots
        let val_loss = evaluate_mixed(&model, &normal_val, &fault_val, &mse, &zero_target);

        // .max(1) guards the division when a split is empty
        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
        let elapsed = epoch_start.elapsed().as_secs_f32();

        println!(
            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
            epoch, avg_normal, avg_fault, val_loss, elapsed
        );
    }

    println!();

    // =========================================================================
    // Phase 3: Temporal training
    // =========================================================================
    println!("═══════════════════════════════════════════════════════════════");
    println!(" PHASE 3: Temporal Training ({PHASE3_EPOCHS} epochs, window={TEMPORAL_WINDOW})");
    println!("═══════════════════════════════════════════════════════════════");

    // Generate temporal sequences
    println!("[data] Generating temporal sequences...");
    let t0 = Instant::now();

    // Normal temporal sequences: varied starting OAT, slow drift.
    // start_oat sweeps roughly -5..95 across the sequence set.
    let mut normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
    for i in 0..TEMPORAL_NORMAL_SEQS {
        let start_oat = -5.0 + (i as f32 / TEMPORAL_NORMAL_SEQS as f32) * 100.0;
        let drift = if start_oat < 50.0 { 0.2 } else { -0.1 }; // warming up or cooling down
        let seq_sim = WarrenSimulator::new(SEED + 5000 + i as u64);
        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, drift);
        normal_seqs.push(seq);
    }

    // Fault temporal sequences: fault injected mid-sequence
    let mut fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
    for i in 0..TEMPORAL_FAULT_SEQS {
        let start_oat = -5.0 + (i as f32 / TEMPORAL_FAULT_SEQS as f32) * 100.0;
        let drift = 0.1;
        let seq_sim = WarrenSimulator::new(SEED + 8000 + i as u64);
        let seq_data =
            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, drift, i as u64);
        fault_seqs.push(seq_data);
    }

    // Validation temporal sequences (separate seed offsets from training)
    let mut val_normal_seqs: Vec<Vec<FacilitySnapshot>> = Vec::new();
    for i in 0..20 {
        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
        let seq_sim = WarrenSimulator::new(SEED + 9000 + i as u64);
        let seq = seq_sim.generate_temporal_sequence(TEMPORAL_WINDOW, start_oat, 0.15);
        val_normal_seqs.push(seq);
    }

    let mut val_fault_seqs: Vec<(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)> = Vec::new();
    for i in 0..20 {
        let start_oat = 10.0 + (i as f32 / 20.0) * 80.0;
        let seq_sim = WarrenSimulator::new(SEED + 9500 + i as u64);
        let seq_data =
            seq_sim.generate_temporal_with_fault(TEMPORAL_WINDOW, start_oat, 0.1, i as u64);
        val_fault_seqs.push(seq_data);
    }

    println!("  Normal temporal seqs: {}", normal_seqs.len());
    println!("  Fault temporal seqs:  {}", fault_seqs.len());
    println!("  Val normal seqs:      {}", val_normal_seqs.len());
    println!("  Val fault seqs:       {}", val_fault_seqs.len());
    println!("  Window size: {TEMPORAL_WINDOW} snapshots (1 hour)");
    println!("  Generated in {:.1}s", t0.elapsed().as_secs_f32());
    println!();

    println!(
        "  {:>5}  {:>12}  {:>12}  {:>12}  {:>8}",
        "Epoch", "Normal Loss", "Fault Loss", "Val Loss", "Time"
    );
    println!(
        "  {:-<5}  {:-<12}  {:-<12}  {:-<12}  {:-<8}",
        "", "", "", "", ""
    );

    // Lower LR again for temporal fine-tuning
    let params = model.parameters();
    let mut optimizer = Adam::new(params, LR * 0.3);

    for epoch in 1..=PHASE3_EPOCHS {
        let epoch_start = Instant::now();
        let mut normal_loss_sum = 0.0f32;
        let mut fault_loss_sum = 0.0f32;
        let mut normal_count = 0;
        let mut fault_count = 0;

        // Normal temporal sequences: target = all zeros
        for seq in &normal_seqs {
            optimizer.zero_grad();
            let (equip_scores, _) = model.forward_temporal(seq);
            let loss = mse.compute(&equip_scores, &zero_target);
            normal_loss_sum += loss.data().to_vec()[0];
            normal_count += 1;

            if loss.requires_grad() {
                loss.backward();
                optimizer.step();
            }
        }

        // Fault temporal sequences: target = 1.0 for affected equipment
        // (onset step and fault type are not needed for the loss here)
        for (seq, _onset, _fault, affected) in &fault_seqs {
            let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
            let fault_target = Variable::new(
                Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
                false,
            );

            optimizer.zero_grad();
            let (equip_scores, _) = model.forward_temporal(seq);
            let loss = mse.compute(&equip_scores, &fault_target);
            fault_loss_sum += loss.data().to_vec()[0];
            fault_count += 1;

            if loss.requires_grad() {
                loss.backward();
                optimizer.step();
            }
        }

        // Validation on held-out temporal sequences
        let val_loss = evaluate_temporal_mixed(
            &model,
            &val_normal_seqs,
            &val_fault_seqs,
            &mse,
            &zero_target,
        );

        let avg_normal = normal_loss_sum / normal_count.max(1) as f32;
        let avg_fault = fault_loss_sum / fault_count.max(1) as f32;
        let elapsed = epoch_start.elapsed().as_secs_f32();

        println!(
            "  {:>5}  {:>12.6}  {:>12.6}  {:>12.6}  {:>6.1}s",
            epoch, avg_normal, avg_fault, val_loss, elapsed
        );
    }

    println!();

    // =========================================================================
    // Final evaluation
    // =========================================================================
    println!("═══════════════════════════════════════════════════════════════");
    println!(" FINAL EVALUATION");
    println!("═══════════════════════════════════════════════════════════════");

    // Test on normal data — scores should be near zero
    let config = FacilityConfig::warren();
    println!("\n  Normal operation (should be low scores):");
    for i in [0, 50, 100, 150] {
        if i >= normal_val.len() {
            break;
        }
        let (equip_scores, fac_score) = model.forward_snapshot(&normal_val[i]);
        let scores = equip_scores.data().to_vec();
        let fac = fac_score.data().to_vec()[0];
        // fold with 0.0 seed clamps max at 0 — assumes scores are non-negative
        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
        println!(
            "    Sample {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
        );
    }

    // Test on fault data — affected equipment should have higher scores
    println!("\n  Fault samples (affected equipment should score higher):");
    for i in 0..5.min(fault_val.len()) {
        let (ref snap, ref fault, ref affected) = fault_val[i];
        let (equip_scores, fac_score) = model.forward_snapshot(snap);
        let scores = equip_scores.data().to_vec();
        let fac = fac_score.data().to_vec()[0];

        // 0.3 is the alert threshold passed to from_scores
        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);

        // Average score over the equipment this fault actually affects
        let affected_avg: f32 = if !affected.is_empty() {
            affected
                .iter()
                .filter(|&&s| s < scores.len())
                .map(|&s| scores[s])
                .sum::<f32>()
                / affected.len() as f32
        } else {
            0.0
        };

        println!("    Fault {:?}:", fault);
        println!(
            "      facility={fac:.4}, affected_avg={affected_avg:.4}, alerts={}",
            output.alerts.len()
        );
    }

    // Temporal evaluation
    println!("\n  Temporal normal (should be low scores):");
    for i in 0..3.min(val_normal_seqs.len()) {
        let (equip_scores, fac_score) = model.forward_temporal(&val_normal_seqs[i]);
        let scores = equip_scores.data().to_vec();
        let fac = fac_score.data().to_vec()[0];
        let max_score = scores.iter().cloned().fold(0.0f32, f32::max);
        let avg_score: f32 = scores.iter().sum::<f32>() / scores.len() as f32;
        println!(
            "    Seq {i:>3}: facility={fac:.4}, avg_equip={avg_score:.4}, max_equip={max_score:.4}"
        );
    }

    println!("\n  Temporal fault (fault injected mid-sequence):");
    for i in 0..5.min(val_fault_seqs.len()) {
        let (ref seq, onset, ref fault, ref affected) = val_fault_seqs[i];
        let (equip_scores, fac_score) = model.forward_temporal(seq);
        let scores = equip_scores.data().to_vec();
        let fac = fac_score.data().to_vec()[0];

        // Average over affected equipment (guard indices against score length)
        let affected_avg: f32 = if !affected.is_empty() {
            affected
                .iter()
                .filter(|&&s| s < scores.len())
                .map(|&s| scores[s])
                .sum::<f32>()
                / affected.len() as f32
        } else {
            0.0
        };
        // Average over everything NOT in the affected list — a good model
        // should separate these two averages clearly
        let unaffected_avg: f32 = {
            let unaffected: Vec<f32> = scores
                .iter()
                .enumerate()
                .filter(|(idx, _)| !affected.contains(idx))
                .map(|(_, &s)| s)
                .collect();
            if unaffected.is_empty() {
                0.0
            } else {
                unaffected.iter().sum::<f32>() / unaffected.len() as f32
            }
        };

        let output = PanoptesOutput::from_scores(&scores, fac, &config, 0.3);
        println!(
            "    Fault {:?} (onset step {onset}/{TEMPORAL_WINDOW}):",
            fault
        );
        println!(
            "      facility={fac:.4}, affected={affected_avg:.4}, unaffected={unaffected_avg:.4}, alerts={}",
            output.alerts.len()
        );
    }

    println!();
    println!("Training complete.");
}
482
// =============================================================================
// Evaluation helpers
// =============================================================================
486
487fn evaluate_normal(
488    model: &Panoptes,
489    val_data: &[FacilitySnapshot],
490    mse: &MSELoss,
491    zero_target: &Variable,
492) -> f32 {
493    let mut total_loss = 0.0f32;
494    for snap in val_data {
495        let (equip_scores, _) = model.forward_snapshot(snap);
496        let loss = mse.compute(&equip_scores, zero_target);
497        total_loss += loss.data().to_vec()[0];
498    }
499    total_loss / val_data.len() as f32
500}
501
502fn evaluate_mixed(
503    model: &Panoptes,
504    normal_val: &[FacilitySnapshot],
505    fault_val: &[(FacilitySnapshot, FaultType, Vec<usize>)],
506    mse: &MSELoss,
507    zero_target: &Variable,
508) -> f32 {
509    let mut total_loss = 0.0f32;
510    let mut count = 0;
511
512    for snap in normal_val {
513        let (equip_scores, _) = model.forward_snapshot(snap);
514        let loss = mse.compute(&equip_scores, zero_target);
515        total_loss += loss.data().to_vec()[0];
516        count += 1;
517    }
518
519    for (snap, _, affected) in fault_val {
520        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
521        let fault_target = Variable::new(
522            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
523            false,
524        );
525        let (equip_scores, _) = model.forward_snapshot(snap);
526        let loss = mse.compute(&equip_scores, &fault_target);
527        total_loss += loss.data().to_vec()[0];
528        count += 1;
529    }
530
531    total_loss / count as f32
532}
533
534fn evaluate_temporal_mixed(
535    model: &Panoptes,
536    normal_seqs: &[Vec<FacilitySnapshot>],
537    fault_seqs: &[(Vec<FacilitySnapshot>, usize, FaultType, Vec<usize>)],
538    mse: &MSELoss,
539    zero_target: &Variable,
540) -> f32 {
541    let mut total_loss = 0.0f32;
542    let mut count = 0;
543
544    for seq in normal_seqs {
545        let (equip_scores, _) = model.forward_temporal(seq);
546        let loss = mse.compute(&equip_scores, zero_target);
547        total_loss += loss.data().to_vec()[0];
548        count += 1;
549    }
550
551    for (seq, _, _, affected) in fault_seqs {
552        let target_vec = PanoptesTrainingData::fault_target(NUM_EQUIPMENT, affected);
553        let fault_target = Variable::new(
554            Tensor::from_vec(target_vec, &[1, NUM_EQUIPMENT]).unwrap(),
555            false,
556        );
557        let (equip_scores, _) = model.forward_temporal(seq);
558        let loss = mse.compute(&equip_scores, &fault_target);
559        total_loss += loss.data().to_vec()[0];
560        count += 1;
561    }
562
563    total_loss / count as f32
564}