// trident/neural/training/supervised.rs

//! Stage 1: Supervised pre-training with cross-entropy loss.
//!
//! Teacher forcing over (TirGraph, TASM) pairs trains the composite
//! model (GNN encoder + Transformer decoder). Grammar information is
//! supplied to the decoder as input features (stack depth, type state);
//! grammar masks are reserved for inference-time decoding.

use burn::grad_clipping::GradientClippingConfig;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::tensor::activation;

use crate::neural::data::pairs::TrainingPair;
use crate::neural::data::tir_graph::NODE_FEATURE_DIM;
use crate::neural::model::composite::NeuralCompilerV2;
use crate::neural::model::grammar::precompute_sequence_state;

16/// Supervised training configuration.
17pub struct SupervisedConfig {
18    /// Initial learning rate.
19    pub lr: f64,
20    /// Minimum learning rate (cosine decay target).
21    pub lr_min: f64,
22    /// Weight decay.
23    pub weight_decay: f64,
24    /// Gradient clipping norm.
25    pub grad_clip: f32,
26    /// Maximum epochs.
27    pub max_epochs: usize,
28    /// Early stopping patience (epochs without improvement).
29    pub patience: usize,
30}
32impl Default for SupervisedConfig {
33    fn default() -> Self {
34        Self {
35            lr: 3e-4,
36            lr_min: 1e-5,
37            weight_decay: 0.01,
38            grad_clip: 1.0,
39            max_epochs: 100,
40            patience: 3,
41        }
42    }
43}
45/// Cosine annealing learning rate: lr_min + 0.5*(lr - lr_min)*(1 + cos(pi*t/T))
46pub fn cosine_lr(config: &SupervisedConfig, epoch: usize, total_epochs: usize) -> f64 {
47    if total_epochs <= 1 {
48        return config.lr;
49    }
50    let t = epoch as f64 / total_epochs as f64;
51    config.lr_min + 0.5 * (config.lr - config.lr_min) * (1.0 + (std::f64::consts::PI * t).cos())
52}
54/// Result of one training epoch.
55pub struct EpochResult {
56    /// Average cross-entropy loss over all pairs.
57    pub avg_loss: f32,
58    /// Number of training pairs processed.
59    pub num_pairs: usize,
60}
62/// Train one epoch of supervised learning on the given pairs.
63///
64/// Uses teacher forcing: at each step, the ground-truth previous token
65/// is provided as input. Grammar masks are applied as logit penalties.
66///
67/// Returns the model with updated weights and the epoch result.
68pub fn train_epoch<B: burn::tensor::backend::AutodiffBackend>(
69    model: NeuralCompilerV2<B>,
70    pairs: &[TrainingPair],
71    optimizer: &mut impl Optimizer<NeuralCompilerV2<B>, B>,
72    lr: f64,
73    device: &B::Device,
74) -> (NeuralCompilerV2<B>, EpochResult) {
75    let mut total_loss = 0.0f32;
76    let mut model = model;
77
78    for pair in pairs {
79        // 1. Prepare graph inputs
80        let node_features = graph_to_features::<B>(&pair.graph, device);
81        let (edge_src, edge_dst, edge_types) = graph_to_edges::<B>(&pair.graph, device);
82
83        // 2. Encode graph
84        let (node_emb, _global) =
85            model
86                .encoder
87                .forward(node_features, edge_src, edge_dst, edge_types);
88        // node_emb: [N, d_model] → expand to [1, N, d_model] for batch=1
89        let d_model = node_emb.dims()[1];
90        let num_nodes = node_emb.dims()[0];
91        let memory = node_emb.unsqueeze_dim::<3>(0);
92
93        // 3. Prepare decoder inputs (teacher forcing)
94        // Truncate to max_seq=256 to fit position embedding table
95        const MAX_SEQ: usize = 256;
96        let tokens = if pair.target_tokens.len() > MAX_SEQ {
97            &pair.target_tokens[..MAX_SEQ]
98        } else {
99            &pair.target_tokens
100        };
101        let seq_len = tokens.len();
102        if seq_len < 2 {
103            continue; // Need at least input + one target
104        }
105
106        // Input tokens: [0, t0, t1, ..., t_{n-2}] (shifted right, prepend EOS=0)
107        let mut input_tokens = vec![0i32]; // Start with EOS
108        for &t in &tokens[..seq_len - 1] {
109            input_tokens.push(t as i32);
110        }
111        let token_ids =
112            Tensor::<B, 2, Int>::from_data(TensorData::new(input_tokens, [1, seq_len]), device);
113
114        // Positions: [0, 1, 2, ...]
115        let positions = Tensor::<B, 2, Int>::from_data(
116            TensorData::new((0..seq_len as i32).collect::<Vec<_>>(), [1, seq_len]),
117            device,
118        );
119
120        // Precompute grammar state for the (truncated) target sequence
121        let state = precompute_sequence_state(tokens, 0);
122
123        let stack_depths = Tensor::<B, 2, Int>::from_data(
124            TensorData::new(
125                state
126                    .depths
127                    .iter()
128                    .map(|&d| (d as i32).min(64))
129                    .collect::<Vec<_>>(),
130                [1, seq_len],
131            ),
132            device,
133        );
134
135        let type_data: Vec<f32> = state.type_states.into_iter().flatten().collect();
136        let type_states =
137            Tensor::<B, 3>::from_data(TensorData::new(type_data, [1, seq_len, 24]), device);
138
139        // 4. Forward pass
140        let memory_expanded = memory.expand([1, num_nodes, d_model]);
141        let logits = model.decoder.forward(
142            token_ids,
143            positions,
144            stack_depths,
145            type_states,
146            memory_expanded,
147        );
148        // logits: [1, seq_len, VOCAB_SIZE]
149
150        // 5. Cross-entropy loss (no grammar mask during teacher forcing)
151        //
152        // Grammar masks are for inference-time beam search. During supervised
153        // training with teacher forcing, the target tokens ARE correct — applying
154        // -1e9 penalties to them causes loss explosion (~1e8). The stack depth
155        // and type state features above already provide grammar awareness to the
156        // decoder as input conditioning.
157        let targets = Tensor::<B, 2, Int>::from_data(
158            TensorData::new(
159                tokens.iter().map(|&t| t as i32).collect::<Vec<_>>(),
160                [1, seq_len],
161            ),
162            device,
163        );
164
165        let loss = cross_entropy_loss(logits, targets);
166        let loss_val: f32 = loss.clone().into_data().to_vec::<f32>().unwrap()[0];
167        total_loss += loss_val;
168
169        // 7. Backward pass + optimizer step
170        let grads = loss.backward();
171        let grads = GradientsParams::from_grads(grads, &model);
172        model = optimizer.step(lr, model, grads);
173    }
174
175    let avg_loss = if pairs.is_empty() {
176        0.0
177    } else {
178        total_loss / pairs.len() as f32
179    };
180
181    (
182        model,
183        EpochResult {
184            avg_loss,
185            num_pairs: pairs.len(),
186        },
187    )
188}
190/// Cross-entropy loss between logits and targets.
191/// logits: [batch, seq, vocab], targets: [batch, seq]
192fn cross_entropy_loss<B: Backend>(
193    logits: Tensor<B, 3>,
194    targets: Tensor<B, 2, Int>,
195) -> Tensor<B, 1> {
196    let [batch, seq, vocab] = logits.dims();
197
198    // Reshape to [batch*seq, vocab] for softmax
199    let logits_flat = logits.reshape([batch * seq, vocab]);
200    let targets_flat = targets.reshape([batch * seq]);
201
202    // Log-softmax
203    let log_probs = activation::log_softmax(logits_flat, 1);
204
205    // Gather the log-prob of the target class
206    let targets_2d: Tensor<B, 2, Int> = targets_flat.unsqueeze_dim::<2>(1);
207    let selected = log_probs.gather(1, targets_2d); // [batch*seq, 1]
208
209    // Negative mean
210    selected.mean().neg().unsqueeze()
211}
213/// Convert TirGraph nodes to a feature tensor.
214pub fn graph_to_features<B: Backend>(
215    graph: &crate::neural::data::tir_graph::TirGraph,
216    device: &B::Device,
217) -> Tensor<B, 2> {
218    let num_nodes = graph.nodes.len();
219    let mut data = vec![0.0f32; num_nodes * NODE_FEATURE_DIM];
220    for (i, node) in graph.nodes.iter().enumerate() {
221        let fv = node.feature_vector();
222        data[i * NODE_FEATURE_DIM..(i + 1) * NODE_FEATURE_DIM].copy_from_slice(&fv);
223    }
224    Tensor::from_data(TensorData::new(data, [num_nodes, NODE_FEATURE_DIM]), device)
225}
227/// Convert TirGraph edges to index tensors.
228pub fn graph_to_edges<B: Backend>(
229    graph: &crate::neural::data::tir_graph::TirGraph,
230    device: &B::Device,
231) -> (Tensor<B, 1, Int>, Tensor<B, 1, Int>, Tensor<B, 1, Int>) {
232    let num_edges = graph.edges.len().max(1); // Need at least 1 edge for burn
233    let mut src = vec![0i32; num_edges];
234    let mut dst = vec![0i32; num_edges];
235    let mut types = vec![0i32; num_edges];
236
237    for (i, &(s, d, ref kind)) in graph.edges.iter().enumerate() {
238        src[i] = s as i32;
239        dst[i] = d as i32;
240        types[i] = match kind {
241            crate::neural::data::tir_graph::EdgeKind::DataDep => 0,
242            crate::neural::data::tir_graph::EdgeKind::ControlFlow => 1,
243            crate::neural::data::tir_graph::EdgeKind::MemOrder => 2,
244        };
245    }
246
247    (
248        Tensor::from_data(TensorData::new(src, [num_edges]), device),
249        Tensor::from_data(TensorData::new(dst, [num_edges]), device),
250        Tensor::from_data(TensorData::new(types, [num_edges]), device),
251    )
252}
254/// Create an AdamW optimizer with gradient clipping.
255pub fn create_optimizer<B: burn::tensor::backend::AutodiffBackend>(
256    config: &SupervisedConfig,
257) -> impl Optimizer<NeuralCompilerV2<B>, B> {
258    AdamWConfig::new()
259        .with_weight_decay(config.weight_decay as f32)
260        .with_grad_clipping(Some(GradientClippingConfig::Norm(config.grad_clip)))
261        .init()
262}
264#[cfg(test)]
265mod tests {
266    use super::*;
267    use crate::ir::tir::TIROp;
268    use crate::neural::data::pairs::extract_pairs;
269    use crate::neural::model::composite::NeuralCompilerConfig;
270    use crate::neural::model::vocab::Vocab;
271    use burn::backend::Autodiff;
272    use burn::backend::NdArray;
273
274    type B = Autodiff<NdArray>;
275
276    #[test]
277    fn train_epoch_runs() {
278        let device = Default::default();
279
280        let config = NeuralCompilerConfig {
281            d_model: 32,
282            d_edge: 8,
283            gnn_layers: 1,
284            decoder_layers: 1,
285            n_heads: 4,
286            d_ff: 64,
287            max_seq: 32,
288            dropout: 0.0,
289        };
290        let model = config.init::<B>(&device);
291
292        let vocab = Vocab::new();
293        let blocks = vec![(
294            vec![TIROp::Push(1), TIROp::Push(2), TIROp::Add],
295            vec!["push 1".into(), "push 2".into(), "add".into()],
296            "test:0..3".into(),
297            3u64,
298        )];
299        let pairs = extract_pairs(&blocks, &vocab);
300
301        let supervised_config = SupervisedConfig::default();
302        let mut optimizer = create_optimizer::<B>(&supervised_config);
303
304        let lr = supervised_config.lr;
305        let (model, result) = train_epoch(model, &pairs, &mut optimizer, lr, &device);
306        assert_eq!(result.num_pairs, 1);
307        assert!(result.avg_loss > 0.0, "loss should be positive");
308        assert!(result.avg_loss.is_finite(), "loss should be finite");
309
310        // Train a second epoch — loss should change
311        let (_model2, result2) = train_epoch(model, &pairs, &mut optimizer, lr, &device);
312        assert!(result2.avg_loss.is_finite());
313    }
314}