Skip to main content

hvac_training/
hvac_training.rs

1//! HVAC Multi-Horizon Predictor - Training with Synthetic Data
2//!
3//! # File
4//! `crates/axonml/examples/hvac_training.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml::autograd::Variable;
18use axonml::nn::{CrossEntropyLoss, Dropout, GRU, LayerNorm, Linear, Module, Parameter, ReLU};
19use axonml::optim::{Adam, Optimizer};
20use axonml::tensor::Tensor;
21use std::time::Instant;
22
23// =============================================================================
24// Configuration
25// =============================================================================
26
/// Hyper-parameters for the training loop.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of sequences per optimizer step.
    pub batch_size: usize,
    /// Number of passes over the training split.
    pub epochs: usize,
    /// Optimizer learning rate.
    pub learning_rate: f32,
    /// Weight-decay coefficient.
    pub weight_decay: f32,
    /// Fraction of sequences held out for validation.
    pub val_split: f32,
    /// Progress is logged every `print_every` epochs.
    pub print_every: usize,
}

impl Default for TrainingConfig {
    /// Full-size training settings (batch 32, 50 epochs, lr 1e-3).
    fn default() -> Self {
        let batch_size = 32;
        let epochs = 50;
        let learning_rate = 1e-3;
        let weight_decay = 1e-4;
        let val_split = 0.2;
        let print_every = 10;
        Self {
            batch_size,
            epochs,
            learning_rate,
            weight_decay,
            val_split,
            print_every,
        }
    }
}
50
/// Architecture settings for the HVAC multi-horizon predictor.
#[derive(Debug, Clone)]
pub struct HvacConfig {
    /// Sensor features per time step (input width of the projection layer).
    pub num_features: usize,
    /// Time steps per input window.
    pub seq_len: usize,
    /// Width of the projection, GRU state and head hidden layer.
    pub hidden_size: usize,
    /// Number of stacked GRU layers.
    pub num_layers: usize,
    /// Output classes of each prediction head.
    pub num_classes: usize,
    /// Dropout probability used inside the prediction heads.
    pub dropout: f32,
}

impl Default for HvacConfig {
    /// Full-size model: 28 features, 120-step windows, 2x128 GRU, 20 classes.
    fn default() -> Self {
        let (num_features, seq_len) = (28, 120);
        let (hidden_size, num_layers) = (128, 2);
        let (num_classes, dropout) = (20, 0.1);
        Self {
            num_features,
            seq_len,
            hidden_size,
            num_layers,
            num_classes,
            dropout,
        }
    }
}
74
75// =============================================================================
76// Synthetic Data Generator
77// =============================================================================
78
/// Operating conditions for one simulated scenario.
///
/// `outdoor_temp` is the base outdoor air temperature; the generator adds a
/// diurnal swing and noise on top of it. Each `*_lead_pump` field selects which
/// pump of a redundant pair (0 or 1) is the lead pump.
#[derive(Debug, Clone)]
pub struct OperatingConditions {
    /// Base outdoor air temperature (heating below 65, cooling above 72).
    pub outdoor_temp: f32,
    /// Winter mode; the 2-pipe pumps only run when this is set.
    pub is_winter: bool,
    /// Lead pump of the 4-pipe hot-water pair (features 0/1).
    pub hw_lead_pump: usize,
    /// Lead pump of the 4-pipe chilled-water pair (features 2/3).
    pub cw_lead_pump: usize,
    /// Lead pump of the 2-pipe pair (features 4/5).
    pub pipe2_lead_pump: usize,
}

impl Default for OperatingConditions {
    fn default() -> Self {
        Self {
            outdoor_temp: 70.0,
            is_winter: false,
            hw_lead_pump: 0,
            cw_lead_pump: 0,
            pipe2_lead_pump: 0,
        }
    }
}

/// Deterministic synthetic HVAC sensor-data generator.
///
/// Produces 28-feature samples for normal operation and injects pump, pressure
/// and temperature faults, labelling each sample with a class id
/// (0 = normal, 1-6 = pump failure, 7-10 = pressure anomaly,
/// 11-13 = temperature anomaly). All randomness comes from an internal 64-bit
/// LCG seeded in `new`, so the same seed always yields the same data.
pub struct HvacDataGenerator {
    _seed: u64,
    rng_state: u64,
}

impl HvacDataGenerator {
    /// Number of sensor features per sample (row width of every data buffer).
    const NUM_FEATURES: usize = 28;

    /// Create a generator whose random stream is fully determined by `seed`.
    pub fn new(seed: u64) -> Self {
        Self {
            _seed: seed,
            rng_state: seed,
        }
    }

    /// Uniform sample in `[0, 1)` from a 64-bit linear congruential generator.
    ///
    /// Only the top 24 bits of the state are used, and they are divided by 2^24.
    /// The quotient is therefore exactly representable in `f32`, covers the whole
    /// unit interval, and can never round up to 1.0. Dividing a 31-bit value by
    /// `u32::MAX` would only reach 0.5.
    fn rand(&mut self) -> f32 {
        self.rng_state = self
            .rng_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1);
        ((self.rng_state >> 40) as f32) / ((1u64 << 24) as f32)
    }

    /// Standard-normal sample via the cosine branch of Box-Muller.
    ///
    /// `u1` is clamped away from zero so `ln` stays finite.
    fn randn(&mut self) -> f32 {
        let u1 = self.rand().max(1e-10);
        let u2 = self.rand();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
    }

    /// Zero-mean Gaussian noise with standard deviation `scale`.
    fn noise(&mut self, scale: f32) -> f32 {
        self.randn() * scale
    }

    /// Generate `n_samples` consecutive samples of fault-free operation.
    ///
    /// Sample `t` is treated as second `t` of the scenario (the diurnal outdoor
    /// swing uses `hour = t / 3600`). Feature layout per sample:
    /// - 0-5: pump currents (HW pair, CW pair, 2-pipe pair); the lead pump
    ///   carries load current only when its service is active
    /// - 6-9: HW/CW supply (4-pipe), HW supply and CW return (2-pipe)
    /// - 10: outdoor air, 11: mechanical room, 12-13: space sensors
    /// - 14-15: HW / CW pressures
    /// - 16-21: VFD speeds derived from pump currents 0-5
    /// - 22-23: valve positions
    /// - 24-27: winter flag and lead-pump ids
    ///
    /// Returns `(data, labels)`: `data` is row-major `[n_samples, 28]` and every
    /// label is 0 (normal).
    pub fn generate_normal_operation(
        &mut self,
        n_samples: usize,
        conditions: &OperatingConditions,
    ) -> (Vec<f32>, Vec<i64>) {
        let mut data = vec![0.0f32; n_samples * Self::NUM_FEATURES];
        let labels = vec![0i64; n_samples]; // All normal

        let oat = conditions.outdoor_temp;
        let is_heating = oat < 65.0;
        let is_cooling = oat > 72.0;
        let base_current = 18.0;

        for t in 0..n_samples {
            let base = t * Self::NUM_FEATURES;

            // Diurnal outdoor temp variation
            let hour = (t as f32 / 3600.0) % 24.0;
            let oat_var = 8.0 * (2.0 * std::f32::consts::PI * hour / 24.0).sin();
            let outdoor_temp = oat + oat_var + self.noise(1.0);
            data[base + 10] = outdoor_temp;

            // ===== PUMP CURRENTS (0-5) =====
            // 4-Pipe HW Pumps
            if conditions.hw_lead_pump == 0 {
                data[base + 0] = if is_heating {
                    base_current + self.noise(2.0)
                } else {
                    0.5
                };
                data[base + 1] = 0.3 + self.noise(0.1);
            } else {
                data[base + 1] = if is_heating {
                    base_current + self.noise(2.0)
                } else {
                    0.5
                };
                data[base + 0] = 0.3 + self.noise(0.1);
            }

            // 4-Pipe CW Pumps
            if conditions.cw_lead_pump == 0 {
                data[base + 2] = if is_cooling {
                    base_current + self.noise(2.0)
                } else {
                    0.5
                };
                data[base + 3] = 0.3 + self.noise(0.1);
            } else {
                data[base + 3] = if is_cooling {
                    base_current + self.noise(2.0)
                } else {
                    0.5
                };
                data[base + 2] = 0.3 + self.noise(0.1);
            }

            // 2-Pipe Pumps (winter only)
            if conditions.is_winter {
                if conditions.pipe2_lead_pump == 0 {
                    data[base + 4] = base_current * 0.8 + self.noise(1.5);
                    data[base + 5] = 0.3 + self.noise(0.1);
                } else {
                    data[base + 5] = base_current * 0.8 + self.noise(1.5);
                    data[base + 4] = 0.3 + self.noise(0.1);
                }
            } else {
                data[base + 4] = 0.3 + self.noise(0.1);
                data[base + 5] = 0.3 + self.noise(0.1);
            }

            // ===== TEMPERATURES (6-13) =====
            // 4-Pipe HW Supply
            data[base + 6] = if is_heating {
                160.0 + self.noise(3.0)
            } else {
                90.0 + self.noise(2.0)
            };

            // 4-Pipe CW Supply
            data[base + 7] = if is_cooling {
                45.0 + self.noise(1.5)
            } else {
                55.0 + self.noise(2.0)
            };

            // 2-Pipe HW Supply (OA reset)
            let oa_reset = 155.0 - (outdoor_temp - 40.0) * (155.0 - 115.0) / 32.0;
            data[base + 8] = oa_reset.clamp(115.0, 155.0) + self.noise(2.0);

            // 2-Pipe CW Return
            data[base + 9] = data[base + 8] - 20.0 + self.noise(3.0);

            // Mech room temp
            data[base + 11] = 72.0 + self.noise(2.0);

            // Space sensors
            data[base + 12] = 72.0 + self.noise(0.5);
            data[base + 13] = 72.0 + self.noise(0.5);

            // ===== PRESSURES (14-15) =====
            data[base + 14] = 16.0 + self.noise(0.5); // HW
            data[base + 15] = 12.0 + self.noise(0.4); // CW

            // ===== VFD SPEEDS (16-21) =====
            for i in 0..6 {
                let current = data[base + i];
                data[base + 16 + i] = if current > 1.0 {
                    40.0 + (current / 25.0) * 40.0 + self.noise(2.0)
                } else {
                    0.0
                };
            }

            // ===== VALVE POSITIONS (22-23) =====
            let heating_load = ((65.0 - outdoor_temp) / 25.0).clamp(0.0, 1.0);
            data[base + 22] = (heating_load * 70.0 + self.noise(3.0)).clamp(0.0, 100.0);
            data[base + 23] = if data[base + 22] > 90.0 {
                ((data[base + 22] - 90.0) * 5.0 + self.noise(2.0)).clamp(0.0, 100.0)
            } else {
                0.0
            };

            // ===== SYSTEM STATES (24-27) =====
            data[base + 24] = if conditions.is_winter { 1.0 } else { 0.0 };
            data[base + 25] = conditions.hw_lead_pump as f32;
            data[base + 26] = conditions.cw_lead_pump as f32;
            data[base + 27] = conditions.pipe2_lead_pump as f32;
        }

        (data, labels)
    }

    /// Inject a gradual failure into pump `pump_index` starting at `failure_start`.
    ///
    /// From `failure_start` on, the pump current drops by
    /// `(i - failure_start) * degradation_rate`, and its VFD speed drops by twice
    /// that; both are floored at 0. Once the current is below 3.0, the sample is
    /// labelled `pump_index + 1`.
    pub fn inject_pump_failure(
        &mut self,
        data: &mut [f32],
        labels: &mut [i64],
        n_samples: usize,
        pump_index: usize,
        failure_start: usize,
        degradation_rate: f32,
    ) {
        let failure_type = (pump_index + 1) as i64;

        for i in failure_start..n_samples {
            let base = i * Self::NUM_FEATURES;
            let degradation = (i - failure_start) as f32 * degradation_rate;

            // Reduce pump current
            data[base + pump_index] = (data[base + pump_index] - degradation).max(0.0);

            // VFD speed also decreases
            data[base + 16 + pump_index] =
                (data[base + 16 + pump_index] - degradation * 2.0).max(0.0);

            // Mark as failure when current drops below threshold
            if data[base + pump_index] < 3.0 {
                labels[i] = failure_type;
            }
        }
    }

    /// Ramp the HW (feature 14) or CW (feature 15) pressure away from nominal.
    ///
    /// The drift grows by `0.005 * magnitude` per sample after `anomaly_start`,
    /// downward when `is_low` and upward otherwise. A sample is labelled 7-10
    /// (low/high HW, low/high CW) once its reading is more than 3.0 from the
    /// nominal setpoint (16 for HW, 12 for CW).
    pub fn inject_pressure_anomaly(
        &mut self,
        data: &mut [f32],
        labels: &mut [i64],
        n_samples: usize,
        is_hw: bool,
        is_low: bool,
        anomaly_start: usize,
        magnitude: f32,
    ) {
        let feat_idx = if is_hw { 14 } else { 15 };
        let failure_type = match (is_hw, is_low) {
            (true, true) => 7,    // pressure_low_hw
            (true, false) => 8,   // pressure_high_hw
            (false, true) => 9,   // pressure_low_cw
            (false, false) => 10, // pressure_high_cw
        };
        let setpoint = if is_hw { 16.0 } else { 12.0 };

        for i in anomaly_start..n_samples {
            let base = i * Self::NUM_FEATURES;
            let drift = (i - anomaly_start) as f32 * 0.005 * magnitude;

            if is_low {
                data[base + feat_idx] -= drift;
            } else {
                data[base + feat_idx] += drift;
            }

            if (data[base + feat_idx] - setpoint).abs() > 3.0 {
                labels[i] = failure_type;
            }
        }
    }

    /// Add a rising drift (0.02 per sample plus noise) to a temperature reading.
    ///
    /// `temp_type` selects `"hw_supply"` (feature 6, label 11), `"cw_supply"`
    /// (feature 7, label 12) or `"space"` (features 12 and 13, label 13). Samples
    /// are labelled once the drift exceeds 5.0. An unknown `temp_type` leaves the
    /// data untouched.
    pub fn inject_temperature_anomaly(
        &mut self,
        data: &mut [f32],
        labels: &mut [i64],
        n_samples: usize,
        temp_type: &str,
        anomaly_start: usize,
    ) {
        let (feat_indices, failure_type): (Vec<usize>, i64) = match temp_type {
            "hw_supply" => (vec![6], 11),
            "cw_supply" => (vec![7], 12),
            "space" => (vec![12, 13], 13),
            _ => return,
        };

        for i in anomaly_start..n_samples {
            let base = i * Self::NUM_FEATURES;
            let drift = (i - anomaly_start) as f32 * 0.02;

            for &feat_idx in &feat_indices {
                data[base + feat_idx] += drift + self.noise(0.5);
            }

            if drift > 5.0 {
                labels[i] = failure_type;
            }
        }
    }

    /// Operating conditions under which pump `pump_idx` is the running lead pump.
    ///
    /// Pumps come in pairs: 0/1 HW, 2/3 CW, 4/5 2-pipe. A failure must be simulated
    /// on a pump that is actually drawing load current. An idle or standby pump
    /// reads about 0.3-0.5, which is already below the 3.0 failure threshold, so
    /// it would be labelled "failed" without any degradation.
    fn pump_failure_conditions(&mut self, pump_idx: usize) -> OperatingConditions {
        let lead = pump_idx % 2;
        let outdoor_temp = match pump_idx {
            // HW pumps only carry load while heating (base outdoor temp < 65).
            0 | 1 => 30.0 + self.rand() * 30.0,
            // CW pumps only carry load while cooling (base outdoor temp > 72).
            2 | 3 => 75.0 + self.rand() * 10.0,
            _ => 30.0 + self.rand() * 55.0,
        };
        // 2-pipe pumps (4, 5) only run in winter mode.
        let is_winter = pump_idx >= 4 || self.rand() > 0.5;
        OperatingConditions {
            outdoor_temp,
            is_winter,
            hw_lead_pump: if pump_idx < 2 { lead } else { 0 },
            cw_lead_pump: if (2..4).contains(&pump_idx) { lead } else { 0 },
            pipe2_lead_pump: if pump_idx >= 4 { lead } else { 0 },
        }
    }

    /// Generate the full training stream: normal operation plus fault scenarios.
    ///
    /// Normal data covers 2 seasons x 10 outdoor temperatures with
    /// `n_normal_samples / 20` samples each. `n_failure_scenarios` is shared
    /// equally among the 6 pump, 4 pressure and 3 temperature subtypes. Each
    /// requested family yields at least one scenario per subtype. Every scenario
    /// lasts `failure_duration` samples, and its fault starts one third of the
    /// way in.
    ///
    /// Returns row-major `[total, 28]` data and per-sample labels. NOTE: scenarios
    /// are concatenated back to back, so windows cut from this stream can straddle
    /// two unrelated scenarios.
    pub fn generate_training_dataset(
        &mut self,
        n_normal_samples: usize,
        n_failure_scenarios: usize,
        failure_duration: usize,
    ) -> (Vec<f32>, Vec<i64>) {
        let mut all_data = Vec::new();
        let mut all_labels = Vec::new();

        println!("Generating normal operation data...");
        // Generate normal data for various conditions
        for season in [false, true] {
            // summer, winter
            for oat_idx in 0..10 {
                let oat = 20.0 + oat_idx as f32 * 7.5;
                let conditions = OperatingConditions {
                    outdoor_temp: oat,
                    is_winter: season,
                    hw_lead_pump: (self.rand() * 2.0) as usize,
                    cw_lead_pump: (self.rand() * 2.0) as usize,
                    pipe2_lead_pump: (self.rand() * 2.0) as usize,
                };
                let samples = n_normal_samples / 20;
                let (data, labels) = self.generate_normal_operation(samples, &conditions);
                all_data.extend(data);
                all_labels.extend(labels);
            }
        }

        // Plain integer division would give zero scenarios for a subtype whenever
        // n_failure_scenarios < subtypes, e.g. no pump failures at all for n = 5.
        let per_subtype = |subtypes: usize| -> usize {
            if n_failure_scenarios == 0 {
                0
            } else {
                (n_failure_scenarios / subtypes).max(1)
            }
        };

        // Generate pump failure scenarios
        println!("Generating pump failure scenarios...");
        for pump_idx in 0..6 {
            for _ in 0..per_subtype(6) {
                let conditions = self.pump_failure_conditions(pump_idx);
                let (mut data, mut labels) =
                    self.generate_normal_operation(failure_duration, &conditions);
                let failure_start = failure_duration / 3;
                let rate = 0.005 + self.rand() * 0.015;
                self.inject_pump_failure(
                    &mut data,
                    &mut labels,
                    failure_duration,
                    pump_idx,
                    failure_start,
                    rate,
                );
                all_data.extend(data);
                all_labels.extend(labels);
            }
        }

        // Generate pressure anomalies
        println!("Generating pressure anomaly scenarios...");
        for (is_hw, is_low) in [(true, true), (true, false), (false, true), (false, false)] {
            for _ in 0..per_subtype(4) {
                let conditions = OperatingConditions {
                    outdoor_temp: 30.0 + self.rand() * 55.0,
                    ..Default::default()
                };
                let (mut data, mut labels) =
                    self.generate_normal_operation(failure_duration, &conditions);
                let failure_start = failure_duration / 3;
                let magnitude = 3.0 + self.rand() * 5.0;
                self.inject_pressure_anomaly(
                    &mut data,
                    &mut labels,
                    failure_duration,
                    is_hw,
                    is_low,
                    failure_start,
                    magnitude,
                );
                all_data.extend(data);
                all_labels.extend(labels);
            }
        }

        // Generate temperature anomalies
        println!("Generating temperature anomaly scenarios...");
        for temp_type in ["hw_supply", "cw_supply", "space"] {
            for _ in 0..per_subtype(3) {
                let conditions = OperatingConditions {
                    outdoor_temp: 30.0 + self.rand() * 55.0,
                    ..Default::default()
                };
                let (mut data, mut labels) =
                    self.generate_normal_operation(failure_duration, &conditions);
                let failure_start = failure_duration / 3;
                self.inject_temperature_anomaly(
                    &mut data,
                    &mut labels,
                    failure_duration,
                    temp_type,
                    failure_start,
                );
                all_data.extend(data);
                all_labels.extend(labels);
            }
        }

        let total_samples = all_labels.len();
        println!("Generated {} total samples", total_samples);

        // Print label distribution
        let mut label_counts = vec![0usize; 20];
        for &label in &all_labels {
            label_counts[label as usize] += 1;
        }
        println!(
            "Label distribution: Normal={}, Failures={}",
            label_counts[0],
            label_counts[1..].iter().sum::<usize>()
        );

        (all_data, all_labels)
    }

    /// Cut a sample stream into fixed-length windows with per-horizon labels.
    ///
    /// Window `i` covers samples `[i*stride, i*stride + sequence_length)`. For each
    /// of the first three entries of `horizons` (imminent, warning, early), the
    /// label is the largest label among the next `horizon` samples after the
    /// window, floored at 0. Further horizon entries are ignored. Only windows
    /// whose longest prediction window fits inside the data are emitted. Data that
    /// is too short produces empty outputs instead of an underflow.
    ///
    /// Returns `(x, y_imminent, y_warning, y_early)`, where `x` is row-major
    /// `[n_sequences, sequence_length, 28]`.
    ///
    /// # Panics
    /// Panics if `stride` is zero.
    pub fn generate_multi_horizon_sequences(
        &self,
        data: &[f32],
        labels: &[i64],
        sequence_length: usize,
        horizons: &[usize], // [300, 900, 1800] for 5/15/30 min
        stride: usize,
    ) -> (Vec<f32>, Vec<i64>, Vec<i64>, Vec<i64>) {
        assert!(stride > 0, "stride must be positive");
        let features = Self::NUM_FEATURES;
        let n_samples = labels.len();
        let max_horizon = horizons.iter().copied().max().unwrap_or(0);

        // Valid starts s satisfy s + sequence_length + max_horizon <= n_samples,
        // i.e. s in {0, stride, ..., span}; that is span / stride + 1 windows.
        let n_sequences = n_samples
            .checked_sub(sequence_length + max_horizon)
            .map_or(0, |span| span / stride + 1);

        let window = sequence_length * features;
        let mut x_data = vec![0.0f32; n_sequences * window];
        let mut y_imminent = vec![0i64; n_sequences];
        let mut y_warning = vec![0i64; n_sequences];
        let mut y_early = vec![0i64; n_sequences];

        for i in 0..n_sequences {
            let start_idx = i * stride;
            let end_idx = start_idx + sequence_length;

            x_data[i * window..(i + 1) * window]
                .copy_from_slice(&data[start_idx * features..end_idx * features]);

            for (h_idx, &horizon) in horizons.iter().take(3).enumerate() {
                let label_end = (end_idx + horizon).min(n_samples);
                // Any failure inside the prediction window outranks "normal" (0).
                let max_label = labels[end_idx..label_end]
                    .iter()
                    .copied()
                    .fold(0, i64::max);
                match h_idx {
                    0 => y_imminent[i] = max_label,
                    1 => y_warning[i] = max_label,
                    2 => y_early[i] = max_label,
                    _ => unreachable!("at most three horizons are consumed"),
                }
            }
        }

        (x_data, y_imminent, y_warning, y_early)
    }
}
539
540// =============================================================================
541// Model Components (same as hvac_model.rs)
542// =============================================================================
543
544pub struct PredictionHead {
545    fc1: Linear,
546    fc2: Linear,
547    fc3: Linear,
548    relu: ReLU,
549    dropout: Dropout,
550}
551
552impl PredictionHead {
553    pub fn new(hidden_size: usize, num_classes: usize, dropout: f32) -> Self {
554        Self {
555            fc1: Linear::new(hidden_size, hidden_size),
556            fc2: Linear::new(hidden_size, 64),
557            fc3: Linear::new(64, num_classes),
558            relu: ReLU,
559            dropout: Dropout::new(dropout),
560        }
561    }
562}
563
564impl Module for PredictionHead {
565    fn forward(&self, x: &Variable) -> Variable {
566        let x = self.fc1.forward(x);
567        let x = self.relu.forward(&x);
568        let x = self.dropout.forward(&x);
569        let x = self.fc2.forward(&x);
570        let x = self.relu.forward(&x);
571        let x = self.dropout.forward(&x);
572        self.fc3.forward(&x)
573    }
574
575    fn parameters(&self) -> Vec<Parameter> {
576        let mut params = self.fc1.parameters();
577        params.extend(self.fc2.parameters());
578        params.extend(self.fc3.parameters());
579        params
580    }
581}
582
583pub struct HvacPredictor {
584    config: HvacConfig,
585    input_proj: Linear,
586    input_norm: LayerNorm,
587    input_relu: ReLU,
588    gru: GRU,
589    head_imminent: PredictionHead,
590    head_warning: PredictionHead,
591    head_early: PredictionHead,
592}
593
594impl HvacPredictor {
595    pub fn new(config: HvacConfig) -> Self {
596        Self {
597            input_proj: Linear::new(config.num_features, config.hidden_size),
598            input_norm: LayerNorm::new(vec![config.hidden_size]),
599            input_relu: ReLU,
600            gru: GRU::new(config.hidden_size, config.hidden_size, config.num_layers),
601            head_imminent: PredictionHead::new(
602                config.hidden_size,
603                config.num_classes,
604                config.dropout,
605            ),
606            head_warning: PredictionHead::new(
607                config.hidden_size,
608                config.num_classes,
609                config.dropout,
610            ),
611            head_early: PredictionHead::new(config.hidden_size, config.num_classes, config.dropout),
612            config,
613        }
614    }
615
616    pub fn forward_multi(&self, x: &Variable) -> (Variable, Variable, Variable) {
617        let x_data = x.data();
618        let shape = x_data.shape();
619        let batch_size = shape[0];
620        let seq_len = shape[1];
621        drop(x_data);
622
623        let x_flat = x.reshape(&[batch_size * seq_len, self.config.num_features]);
624        let proj = self.input_proj.forward(&x_flat);
625        let proj = self.input_norm.forward(&proj);
626        let proj = self.input_relu.forward(&proj);
627        let proj = proj.reshape(&[batch_size, seq_len, self.config.hidden_size]);
628
629        // Use forward_mean for proper gradient flow (equivalent to forward + mean_pool)
630        let pooled = self.gru.forward_mean(&proj);
631
632        let imminent = self.head_imminent.forward(&pooled);
633        let warning = self.head_warning.forward(&pooled);
634        let early = self.head_early.forward(&pooled);
635
636        (imminent, warning, early)
637    }
638
639    pub fn num_parameters(&self) -> usize {
640        self.parameters()
641            .iter()
642            .map(|p| p.variable().data().numel())
643            .sum()
644    }
645}
646
647impl Module for HvacPredictor {
648    fn forward(&self, x: &Variable) -> Variable {
649        let (imminent, _, _) = self.forward_multi(x);
650        imminent
651    }
652
653    fn parameters(&self) -> Vec<Parameter> {
654        let mut params = self.input_proj.parameters();
655        params.extend(self.input_norm.parameters());
656        params.extend(self.gru.parameters());
657        params.extend(self.head_imminent.parameters());
658        params.extend(self.head_warning.parameters());
659        params.extend(self.head_early.parameters());
660        params
661    }
662}
663
664// =============================================================================
665// Training Loop
666// =============================================================================
667
/// Min-max scale each of the 28 sensor columns into `[0, 1]`, in place.
///
/// Each feature uses a fixed `(min, max)` range from the table below. Values
/// outside that range are clamped. Only the first `n_samples` rows of `data`
/// (row-major, 28 values per row) are modified.
///
/// # Panics
/// Panics if `data` holds fewer than `n_samples * 28` values.
fn normalize_data(data: &mut [f32], n_samples: usize) {
    const SENSOR_RANGES: [(f32, f32); 28] = [
        (0.0, 50.0),    // 0: hw_pump_5_current
        (0.0, 50.0),    // 1: hw_pump_6_current
        (0.0, 50.0),    // 2: cw_pump_3_current
        (0.0, 50.0),    // 3: cw_pump_4_current
        (0.0, 50.0),    // 4: 2pipe_pump_a_current
        (0.0, 50.0),    // 5: 2pipe_pump_b_current
        (80.0, 200.0),  // 6: hw_supply_4pipe_temp
        (40.0, 80.0),   // 7: cw_supply_4pipe_temp
        (115.0, 155.0), // 8: hw_supply_2pipe_temp
        (70.0, 120.0),  // 9: cw_return_2pipe_temp
        (-20.0, 120.0), // 10: outdoor_air_temp
        (50.0, 90.0),   // 11: mech_room_temp
        (65.0, 85.0),   // 12: space_sensor_1_temp
        (65.0, 85.0),   // 13: space_sensor_2_temp
        (0.0, 200.0),   // 14: hw_pressure_4pipe
        (0.0, 200.0),   // 15: cw_pressure_4pipe
        (0.0, 100.0),   // 16-21: VFD speeds
        (0.0, 100.0),
        (0.0, 100.0),
        (0.0, 100.0),
        (0.0, 100.0),
        (0.0, 100.0),
        (0.0, 100.0), // 22: steam_valve_1_3_pos
        (0.0, 100.0), // 23: steam_valve_2_3_pos
        (0.0, 1.0),   // 24: summer_winter_mode
        (0.0, 1.0),   // 25: hw_lead_pump_id
        (0.0, 1.0),   // 26: cw_lead_pump_id
        (0.0, 1.0),   // 27: 2pipe_lead_pump_id
    ];
    let width = SENSOR_RANGES.len();

    for row in data[..n_samples * width].chunks_exact_mut(width) {
        for (value, &(lo, hi)) in row.iter_mut().zip(SENSOR_RANGES.iter()) {
            *value = ((*value - lo) / (hi - lo)).clamp(0.0, 1.0);
        }
    }
}
709
710/// Calculate accuracy
711fn calculate_accuracy(logits: &Variable, labels: &[i64]) -> f32 {
712    let data = logits.data();
713    let shape = data.shape();
714    let batch_size = shape[0];
715    let num_classes = shape[1];
716    let values = data.to_vec();
717
718    let mut correct = 0;
719    for b in 0..batch_size {
720        let start = b * num_classes;
721        let mut max_idx = 0;
722        let mut max_val = values[start];
723        for c in 1..num_classes {
724            if values[start + c] > max_val {
725                max_val = values[start + c];
726                max_idx = c;
727            }
728        }
729        if max_idx == labels[b] as usize {
730            correct += 1;
731        }
732    }
733    correct as f32 / batch_size as f32
734}
735
736/// Training function
737fn train_epoch(
738    model: &HvacPredictor,
739    optimizer: &mut Adam,
740    loss_fn: &CrossEntropyLoss,
741    x_data: &[f32],
742    y_imminent: &[i64],
743    y_warning: &[i64],
744    y_early: &[i64],
745    batch_size: usize,
746    seq_len: usize,
747    num_features: usize,
748) -> (f32, f32, f32, f32) {
749    let n_sequences = y_imminent.len();
750    let n_batches = n_sequences / batch_size;
751
752    let mut total_loss = 0.0f32;
753    let mut total_acc_imm = 0.0f32;
754    let mut total_acc_warn = 0.0f32;
755    let mut total_acc_early = 0.0f32;
756
757    for batch_idx in 0..n_batches {
758        let start = batch_idx * batch_size;
759
760        // Prepare batch data
761        let mut batch_x = vec![0.0f32; batch_size * seq_len * num_features];
762        let mut batch_y_imm = vec![0i64; batch_size];
763        let mut batch_y_warn = vec![0i64; batch_size];
764        let mut batch_y_early = vec![0i64; batch_size];
765
766        for b in 0..batch_size {
767            let seq_start = (start + b) * seq_len * num_features;
768            for i in 0..(seq_len * num_features) {
769                batch_x[b * seq_len * num_features + i] = x_data[seq_start + i];
770            }
771            batch_y_imm[b] = y_imminent[start + b];
772            batch_y_warn[b] = y_warning[start + b];
773            batch_y_early[b] = y_early[start + b];
774        }
775
776        // Create tensors
777        let x_tensor = Tensor::from_vec(batch_x, &[batch_size, seq_len, num_features])
778            .expect("Failed to create input tensor");
779        let x_var = Variable::new(x_tensor, true);
780
781        // Forward pass
782        let (logits_imm, logits_warn, logits_early) = model.forward_multi(&x_var);
783
784        // Calculate losses (simplified - just use imminent for now)
785        let y_imm_tensor = Tensor::from_vec(
786            batch_y_imm.iter().map(|&y| y as f32).collect(),
787            &[batch_size],
788        )
789        .expect("Failed to create label tensor");
790        let y_imm_var = Variable::new(y_imm_tensor, false);
791
792        let loss = loss_fn.compute(&logits_imm, &y_imm_var);
793
794        // Backward pass
795        optimizer.zero_grad();
796        loss.backward();
797        optimizer.step();
798
799        // Metrics
800        total_loss += loss.data().to_vec()[0];
801        total_acc_imm += calculate_accuracy(&logits_imm, &batch_y_imm);
802        total_acc_warn += calculate_accuracy(&logits_warn, &batch_y_warn);
803        total_acc_early += calculate_accuracy(&logits_early, &batch_y_early);
804    }
805
806    let n = n_batches as f32;
807    (
808        total_loss / n,
809        total_acc_imm / n,
810        total_acc_warn / n,
811        total_acc_early / n,
812    )
813}
814
815// =============================================================================
816// Main
817// =============================================================================
818
819fn main() {
820    println!("╔════════════════════════════════════════════════════════════╗");
821    println!("║     HVAC Multi-Horizon Predictor - Training Pipeline       ║");
822    println!("╚════════════════════════════════════════════════════════════╝");
823    println!();
824
825    // Check for quick mode via environment variable
826    let quick_mode = std::env::var("HVAC_QUICK").is_ok();
827
828    let train_config = if quick_mode {
829        TrainingConfig {
830            batch_size: 16,
831            epochs: 100,
832            learning_rate: 0.01,
833            weight_decay: 0.0001,
834            val_split: 0.2,
835            print_every: 10,
836        }
837    } else {
838        TrainingConfig::default()
839    };
840
841    let model_config = if quick_mode {
842        HvacConfig {
843            num_features: 28,
844            seq_len: 30,     // Reduced from 120
845            hidden_size: 64, // Reduced from 128
846            num_layers: 1,   // Reduced from 2
847            num_classes: 20,
848            dropout: 0.1,
849        }
850    } else {
851        HvacConfig::default()
852    };
853
854    println!("Training Configuration:");
855    println!("  Batch size: {}", train_config.batch_size);
856    println!("  Epochs: {}", train_config.epochs);
857    println!("  Learning rate: {}", train_config.learning_rate);
858    println!();
859
860    // Generate synthetic training data
861    println!("=== Data Generation ===");
862    if quick_mode {
863        println!("(Quick mode enabled - reduced data and model size)");
864    }
865    let start_time = Instant::now();
866    let mut generator = HvacDataGenerator::new(42);
867
868    let (normal_samples, failure_scenarios, failure_duration) = if quick_mode {
869        (5000, 5, 300) // Much smaller for quick testing
870    } else {
871        (50000, 30, 1800) // Full dataset
872    };
873
874    let (mut raw_data, raw_labels) =
875        generator.generate_training_dataset(normal_samples, failure_scenarios, failure_duration);
876
877    // Normalize data
878    let n_samples = raw_labels.len();
879    normalize_data(&mut raw_data, n_samples);
880
881    // Generate sequences
882    println!("Creating multi-horizon sequences...");
883    let stride = if quick_mode { 30 } else { 10 };
884    let (x_data, y_imminent, y_warning, y_early) = generator.generate_multi_horizon_sequences(
885        &raw_data,
886        &raw_labels,
887        model_config.seq_len,
888        &[300, 900, 1800], // 5, 15, 30 minutes
889        stride,
890    );
891
892    let n_sequences = y_imminent.len();
893    println!("Generated {} sequences", n_sequences);
894    println!("Data generation took: {:?}", start_time.elapsed());
895    println!();
896
897    // Split train/val
898    let val_size = (n_sequences as f32 * train_config.val_split) as usize;
899    let train_size = n_sequences - val_size;
900    println!(
901        "Train: {} sequences, Val: {} sequences",
902        train_size, val_size
903    );
904
905    // Create model
906    println!();
907    println!("=== Model ===");
908    let model = HvacPredictor::new(model_config.clone());
909    println!("Parameters: {}", model.num_parameters());
910
911    // Create optimizer and loss
912    let mut optimizer = Adam::new(model.parameters(), train_config.learning_rate);
913    let loss_fn = CrossEntropyLoss::new();
914
915    // Training loop
916    println!();
917    println!("=== Training ===");
918    let training_start = Instant::now();
919
920    for epoch in 0..train_config.epochs {
921        let epoch_start = Instant::now();
922
923        let (loss, acc_imm, acc_warn, acc_early) = train_epoch(
924            &model,
925            &mut optimizer,
926            &loss_fn,
927            &x_data[..(train_size * model_config.seq_len * model_config.num_features)],
928            &y_imminent[..train_size],
929            &y_warning[..train_size],
930            &y_early[..train_size],
931            train_config.batch_size,
932            model_config.seq_len,
933            model_config.num_features,
934        );
935
936        if epoch % train_config.print_every == 0 || epoch == train_config.epochs - 1 {
937            println!(
938                "Epoch {:3}/{}: Loss={:.4}, Acc(5m)={:.2}%, Acc(15m)={:.2}%, Acc(30m)={:.2}% [{:?}]",
939                epoch + 1,
940                train_config.epochs,
941                loss,
942                acc_imm * 100.0,
943                acc_warn * 100.0,
944                acc_early * 100.0,
945                epoch_start.elapsed()
946            );
947        }
948    }
949
950    println!();
951    println!("Training completed in {:?}", training_start.elapsed());
952    println!();
953    println!("Model ready for deployment!");
954}