use crate::attention;
use crate::complex::Complex;
use crate::density_matrix::DensityMatrixN;
use crate::train::EpochMetrics;
use crate::transformer::QCT;
/// Intermediate state recorded for one block during a forward pass.
///
/// Every field except `values` holds one entry per token position, pushed in position order.
pub struct ForwardCache {
/// Density-matrix entries at each position after coupling to earlier positions, before the unitary step.
pub rho_before: Vec<Vec<Complex>>,
/// Unitary applied at each position; an empty `Vec` marks a position where the free-energy gate skipped evolution.
pub unitaries: Vec<Vec<Complex>>,
/// Density-matrix entries at each position after the (possibly skipped) unitary step.
pub rho_after: Vec<Vec<Complex>>,
/// Result of `populations()` on the post-step state at each position.
pub populations: Vec<Vec<f32>>,
/// Copy of the value vectors as they were when the block started.
pub values: Vec<Vec<f32>>,
}
/// Gradients for every parameter group of the model.
pub struct AllGradients {
    /// Gradient for the embedding parameters.
    pub embed_grad: Vec<f32>,
    /// One gradient record per block.
    pub block_grads: Vec<BlockGrad>,
    /// Gradient for the output projection weights.
    pub output_grad: Vec<f32>,
}

/// Gradients belonging to a single block.
pub struct BlockGrad {
    /// Gradient for the block's Hamiltonian parameters.
    pub hamiltonian_grad: Vec<f32>,
    /// Gradient for the block's value-projection weights.
    pub value_weight_grad: Vec<f32>,
}

impl AllGradients {
    /// Concatenates all gradients into one flat vector.
    ///
    /// Layout: `embed_grad`, then for each block its `hamiltonian_grad` followed by its
    /// `value_weight_grad`, then `output_grad`.
    pub fn flatten(&self) -> Vec<f32> {
        let per_block = self
            .block_grads
            .iter()
            .flat_map(|grad| grad.hamiltonian_grad.iter().chain(grad.value_weight_grad.iter()));
        self.embed_grad
            .iter()
            .chain(per_block)
            .chain(self.output_grad.iter())
            .copied()
            .collect()
    }
}
/// Runs the cached forward pass without a converter.
///
/// Delegates to `forward_with_cache_converter` with `converter` set to `None` and returns its
/// result unchanged: the per-position logits, the averaged free energy, and one `ForwardCache`
/// per block.
pub fn forward_with_cache(model: &QCT, tokens: &[usize]) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
forward_with_cache_converter(model, tokens, None)
}
/// Runs the forward pass over `tokens` and records a [`ForwardCache`] for every block.
///
/// Each token is embedded into a density matrix. When `converter` is `Some`, a token whose
/// `dephasing_rate` exceeds `1e-6` has its embedded state dephased by that rate. Then, for
/// every block and position, the state is coupled to up to `COHERENCE_WINDOW` preceding
/// positions, evolved by the block's unitary only when the magnitude of its free energy
/// exceeds `0.236 / dim`, and its populations are passed through
/// `attention::attention_project`. After each block all states are dephased by
/// `0.236 / max(num_blocks, 1)`. The final values are multiplied by `model.output_weights`
/// to produce the logits.
///
/// # Returns
///
/// `(logits, mean_free_energy, caches)`: one logit row per token, the post-step free energy
/// summed over all blocks and positions divided by the token count, and one cache per block.
/// An empty `tokens` slice yields a mean free energy of `0.0` (not `NaN`).
pub fn forward_with_cache_converter(
    model: &QCT,
    tokens: &[usize],
    converter: Option<&crate::golden_ratio_converter::GoldenRatioConverter>,
) -> (Vec<Vec<f32>>, f32, Vec<ForwardCache>) {
    /// Maximum number of preceding positions a state is coupled to.
    const COHERENCE_WINDOW: usize = 8;

    let dim = model.config.dim;
    let t = tokens.len();

    let mut current_states: Vec<DensityMatrixN> =
        tokens.iter().map(|&tok| model.embedding.embed(tok)).collect();
    if let Some(conv) = converter {
        for (state, &tok) in current_states.iter_mut().zip(tokens) {
            let eps = conv.dephasing_rate(tok);
            if eps > 1e-6 {
                state.dephase(eps);
            }
        }
    }
    let mut values: Vec<Vec<f32>> = current_states.iter().map(|s| s.populations()).collect();

    // Both thresholds depend only on the model shape, so compute them once.
    let f_gate = 0.236 / dim as f32;
    let eps_block = 0.236 / model.blocks.len().max(1) as f32;

    let mut caches = Vec::with_capacity(model.blocks.len());
    let mut total_f = 0.0f32;
    for block in &model.blocks {
        let mut cache = ForwardCache {
            rho_before: Vec::with_capacity(t),
            unitaries: Vec::with_capacity(t),
            rho_after: Vec::with_capacity(t),
            populations: Vec::with_capacity(t),
            values: values.clone(),
        };

        // The unitaries depend only on the block and the position, so build them in parallel.
        let precomputed_unitaries: Vec<Vec<Complex>> = {
            use rayon::prelude::*;
            (0..t)
                .into_par_iter()
                .map(|i| {
                    let h_matrix = block.hamiltonian.build_matrix(i);
                    DensityMatrixN::hamiltonian_unitary(&h_matrix, dim, 0.090)
                })
                .collect()
        };

        for i in 0..t {
            let mut rho = current_states[i].clone();
            let window_start = i.saturating_sub(COHERENCE_WINDOW);
            for j in window_start..i {
                let eps = block.hamiltonian.causal_dephasing(i - j);
                rho.couple_dephase(&current_states[j], eps);
            }
            cache.rho_before.push(rho.entries.clone());

            // Free-energy gate: only evolve states whose |F| exceeds the threshold. An empty
            // unitary in the cache tells the backward pass that this position was skipped.
            let f_before = rho.free_energy(&block.hamiltonian.bias);
            if f_before.abs() > f_gate {
                let unitary = &precomputed_unitaries[i];
                cache.unitaries.push(unitary.clone());
                rho.evolve(unitary);
            } else {
                cache.unitaries.push(Vec::new());
            }
            cache.rho_after.push(rho.entries.clone());

            total_f += rho.free_energy(&block.hamiltonian.bias);
            cache.populations.push(rho.populations());
            // NOTE(review): unlike `forward_backward_epoch_gpu`, the evolved `rho` is not
            // written back into `current_states`, so later positions and blocks see the
            // un-evolved state. Behaviour kept as-is; confirm which variant is intended.
        }

        let attn_output = attention::AttentionOutput {
            populations: cache.populations.clone(),
            free_energies: vec![0.0; t],
            coherences: vec![0.0; t],
        };
        values = attention::attention_project(&attn_output, &values, dim);

        for state in &mut current_states {
            state.dephase(eps_block);
        }
        caches.push(cache);
    }

    // Output projection; `output_weights` is indexed as `d * vocab + v`.
    let vocab = model.config.vocab_size;
    let mut logits = Vec::with_capacity(t);
    for i in 0..t {
        let mut token_logits = vec![0.0f32; vocab];
        for v in 0..vocab {
            for d in 0..dim {
                token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
            }
        }
        logits.push(token_logits);
    }

    // `max(1)` keeps the average finite for an empty input (0.0 / 0.0 would be NaN).
    (logits, total_f / t.max(1) as f32, caches)
}
/// Computes the gradients for one window from the cached forward pass (CPU path).
///
/// `tokens` is the full window: position `i` is scored against `tokens[i + 1]`, so
/// `t = tokens.len() - 1` positions contribute. `logits` and `caches` are expected to come
/// from a forward pass over the first `t` tokens.
///
/// What is computed:
/// * the softmax cross-entropy gradient of every logit row, divided by `t`;
/// * the output-weight gradient, as the outer product of the last block's cached
///   populations with that logit gradient;
/// * per block, the value-weight gradient (outer product of cached populations and the
///   value gradient) and the Hamiltonian gradient, taken from the imaginary part of the
///   commutator between the back-propagated matrix and `rho_before`, scaled by `-dt`.
///
/// The embedding gradient is always returned as zeros, and the same value gradient is used
/// for every block.
///
/// When fewer than two tokens are given, all gradients are zero. If `caches` is empty the
/// output gradient is left at zero instead of panicking.
///
/// # Panics
///
/// Panics if `logits` has fewer than `t` rows, if a row is shorter than the vocabulary size,
/// or if `caches` is non-empty but has fewer entries than the model has blocks.
pub fn qug_backward(model: &QCT, tokens: &[usize], logits: &[Vec<f32>], caches: &[ForwardCache]) -> AllGradients {
    let dim = model.config.dim;
    let vocab = model.config.vocab_size;
    let t = tokens.len().saturating_sub(1);
    if t == 0 {
        return AllGradients {
            embed_grad: vec![0.0; model.embedding.num_params()],
            block_grads: model
                .blocks
                .iter()
                .map(|b| BlockGrad {
                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
                    value_weight_grad: vec![0.0; b.value_weights.len()],
                })
                .collect(),
            output_grad: vec![0.0; model.output_weights.len()],
        };
    }

    // d(loss)/d(logits): (softmax - one_hot(target)) / t, using the max-shifted softmax.
    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
    for i in 0..t {
        let target = tokens[i + 1];
        let row = &logits[i];
        let max_l = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = row.iter().map(|&l| (l - max_l).exp()).sum();
        let mut d_log = vec![0.0f32; vocab];
        for v in 0..vocab {
            let softmax_v = (row[v] - max_l).exp() / exp_sum;
            let one_hot = if v == target { 1.0 } else { 0.0 };
            d_log[v] = (softmax_v - one_hot) / t as f32;
        }
        d_logits.push(d_log);
    }

    // Output-weight gradient from the last block's populations. The cache lookup is done
    // once; positions beyond the cached range are skipped by the `zip`.
    let mut d_output = vec![0.0f32; dim * vocab];
    if let Some(last_cache) = caches.last() {
        for (vals, d_log) in last_cache.populations.iter().zip(&d_logits) {
            for d_idx in 0..dim {
                let pop = vals.get(d_idx).copied().unwrap_or(0.0);
                for v in 0..vocab {
                    d_output[d_idx * vocab + v] += pop * d_log[v];
                }
            }
        }
    }

    // Gradient with respect to the final values: d_logits projected back through the
    // output weights.
    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t];
    for i in 0..t {
        for d_idx in 0..dim {
            for v in 0..vocab {
                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
            }
        }
    }

    let dt = 0.090f32;
    let mut block_grads = Vec::with_capacity(model.blocks.len());
    // Blocks are visited last-to-first and the result is reversed afterwards.
    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
        let cache = &caches[block_idx];
        let num_h = block.hamiltonian.num_params();

        // Value-weight gradient: outer product of populations and d_values, summed over positions.
        let mut d_vw = vec![0.0f32; dim * dim];
        for i in 0..t.min(cache.populations.len()) {
            for d_idx in 0..dim {
                for s in 0..dim {
                    let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
                    let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
                    d_vw[d_idx * dim + s] += pop * dv;
                }
            }
        }

        let len = t.min(cache.unitaries.len());
        // Each position's Hamiltonian gradient is independent, so compute them in parallel
        // and sum sequentially below.
        let position_grads: Vec<Vec<f32>> = {
            use rayon::prelude::*;
            (0..len)
                .into_par_iter()
                .map(|i| {
                    let mut local_d_h = vec![0.0f32; num_h];
                    let u = &cache.unitaries[i];
                    // An empty unitary means the forward gate skipped this position.
                    if u.is_empty() {
                        return local_d_h;
                    }

                    // Diagonal seed: the value gradient placed on the populations.
                    let mut d_rho = vec![Complex::ZERO; dim * dim];
                    for k in 0..dim {
                        let dp = d_values[i].get(k).copied().unwrap_or(0.0);
                        d_rho[k * dim + k] = Complex::new(dp, 0.0);
                    }

                    // Conjugate transpose of the unitary.
                    let mut u_dag = vec![Complex::ZERO; dim * dim];
                    for r in 0..dim {
                        for c in 0..dim {
                            u_dag[r * dim + c] = u[c * dim + r].conj();
                        }
                    }

                    // Two matrix products; presumably scratch_b = U† · dρ · U (confirm the
                    // argument convention of `cgemm`).
                    let mut scratch_a = vec![Complex::ZERO; dim * dim];
                    let mut scratch_b = vec![Complex::ZERO; dim * dim];
                    dreamwell_math::linalg::cgemm(&u_dag, &d_rho, &mut scratch_a, dim, dim, dim);
                    dreamwell_math::linalg::cgemm(&scratch_a, u, &mut scratch_b, dim, dim, dim);

                    let rho_before = &cache.rho_before[i];
                    let mut h_idx = 0;
                    // The first `dim` parameters receive the diagonal commutator terms.
                    for k in 0..dim {
                        let mut comm_diag = 0.0f32;
                        for j in 0..dim {
                            let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
                            let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
                            comm_diag += (ab.sub(ba)).im;
                        }
                        if h_idx < local_d_h.len() {
                            local_d_h[h_idx] += -dt * comm_diag;
                        }
                        h_idx += 1;
                    }
                    // The remaining parameters receive the upper-triangle (p < q) terms.
                    for p in 0..dim {
                        for q in (p + 1)..dim {
                            if h_idx >= local_d_h.len() {
                                break;
                            }
                            let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
                            let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
                            let comm_pq = ab_pq.sub(ba_pq);
                            local_d_h[h_idx] += -dt * 2.0 * comm_pq.im;
                            h_idx += 1;
                        }
                    }
                    local_d_h
                })
                .collect()
        };

        let mut d_h = vec![0.0f32; num_h];
        for pg in &position_grads {
            for (k, &v) in pg.iter().enumerate() {
                d_h[k] += v;
            }
        }
        block_grads.push(BlockGrad {
            hamiltonian_grad: d_h,
            value_weight_grad: d_vw,
        });
    }
    block_grads.reverse();

    AllGradients {
        embed_grad: vec![0.0f32; model.embedding.num_params()],
        block_grads,
        output_grad: d_output,
    }
}
/// Runs forward and backward passes for a batch of token windows, using the GPU context for
/// the batched matrix exponentials and adjoints, and returns the window-averaged results.
///
/// Each `(start, end)` pair in `windows` selects `tokens[start..end]`; all but the last token
/// of a window are fed forward. The work happens in three phases:
/// 1. per window, embed the inputs and compute every block's unitaries with `gpu.batched_expm`;
/// 2. per window, run the gated forward pass, building a `ForwardCache` per block, the logits
///    and the loss (`QCT::loss_from_logits`);
/// 3. per window, run `qug_backward_gpu` and accumulate the flattened gradients.
///
/// # Returns
///
/// `(gradient, mean_loss, mean_free_energy)`, each averaged over the number of windows. The
/// gradient has `model.num_params()` entries. When `windows` is empty, a zero gradient and
/// `0.0` for both means are returned (instead of dividing by zero and producing `NaN`).
///
/// # Panics
///
/// Panics if a window range lies outside `tokens`.
pub fn forward_backward_epoch_gpu(
    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
    model: &QCT,
    windows: &[(usize, usize)],
    tokens: &[usize],
) -> (Vec<f32>, f32, f32) {
    use dreamwell_math::Complex;

    /// Maximum number of preceding positions a state is coupled to.
    const COHERENCE_WINDOW: usize = 8;

    let num_params = model.num_params();
    let num_windows = windows.len();
    if num_windows == 0 {
        // Nothing to average over; avoid the 0/0 division below.
        return (vec![0.0f32; num_params], 0.0, 0.0);
    }

    let dim = model.config.dim;
    let vocab = model.config.vocab_size;
    let stride = dim * dim;
    let num_blocks = model.blocks.len();
    let dt = 0.090f32;

    // Phase 1: embeddings plus GPU-batched unitaries for every window and block.
    let mut all_window_data: Vec<WindowForwardState> = Vec::with_capacity(num_windows);
    for &(ws, we) in windows {
        let window_tokens = &tokens[ws..we];
        let input = &window_tokens[..window_tokens.len().saturating_sub(1)];
        let t = input.len();
        let states: Vec<DensityMatrixN> = input.iter().map(|&tok| model.embedding.embed(tok)).collect();
        let values: Vec<Vec<f32>> = states.iter().map(|s| s.populations()).collect();

        let mut block_unitaries: Vec<Vec<Vec<Complex>>> = Vec::with_capacity(num_blocks);
        for block in &model.blocks {
            // Pack all positions' Hamiltonians into one flat buffer for a single GPU call.
            let mut all_h = vec![0.0f32; t * stride];
            for i in 0..t {
                let h = block.hamiltonian.build_matrix(i);
                all_h[i * stride..(i + 1) * stride].copy_from_slice(&h);
            }
            let flat_unitaries = gpu.batched_expm(&all_h, dt, t);
            let per_pos: Vec<Vec<Complex>> = (0..t)
                .map(|i| flat_unitaries[i * stride..(i + 1) * stride].to_vec())
                .collect();
            block_unitaries.push(per_pos);
        }

        all_window_data.push(WindowForwardState {
            window_tokens: window_tokens.to_vec(),
            t,
            states,
            values,
            block_unitaries,
        });
    }

    // Phase 2: gated forward pass per window.
    let f_gate = 0.236 / dim as f32;
    let eps_block = 0.236 / num_blocks.max(1) as f32;
    let mut all_window_results: Vec<WindowResult> = Vec::with_capacity(num_windows);
    for wdata in &mut all_window_data {
        let t = wdata.t;
        // The embedded states and initial values are not read again after this phase, so
        // move them out instead of deep-cloning.
        let mut current_states = std::mem::take(&mut wdata.states);
        let mut values = std::mem::take(&mut wdata.values);
        let mut caches: Vec<ForwardCache> = Vec::with_capacity(num_blocks);
        let mut total_f = 0.0f32;

        for (block_idx, block) in model.blocks.iter().enumerate() {
            let mut cache = ForwardCache {
                rho_before: Vec::with_capacity(t),
                unitaries: Vec::with_capacity(t),
                rho_after: Vec::with_capacity(t),
                populations: Vec::with_capacity(t),
                values: values.clone(),
            };
            let precomputed_unitaries = &wdata.block_unitaries[block_idx];

            for i in 0..t {
                let mut rho = current_states[i].clone();
                let window_start = i.saturating_sub(COHERENCE_WINDOW);
                for j in window_start..i {
                    let eps = block.hamiltonian.causal_dephasing(i - j);
                    rho.couple_dephase(&current_states[j], eps);
                }
                cache.rho_before.push(rho.entries.clone());

                // Free-energy gate: an empty unitary in the cache marks a skipped position.
                let f_before = rho.free_energy(&block.hamiltonian.bias);
                if f_before.abs() > f_gate {
                    let unitary = &precomputed_unitaries[i];
                    cache.unitaries.push(unitary.clone());
                    rho.evolve(unitary);
                } else {
                    cache.unitaries.push(Vec::new());
                }
                cache.rho_after.push(rho.entries.clone());

                total_f += rho.free_energy(&block.hamiltonian.bias);
                cache.populations.push(rho.populations());
                // Later positions (and later blocks) see the updated state.
                current_states[i] = rho;
            }

            let attn_output = attention::AttentionOutput {
                populations: cache.populations.clone(),
                free_energies: vec![0.0; t],
                coherences: vec![0.0; t],
            };
            values = attention::attention_project(&attn_output, &values, dim);

            for state in &mut current_states {
                state.dephase(eps_block);
            }
            caches.push(cache);
        }

        // Output projection; `output_weights` is indexed as `d * vocab + v`.
        let mut logits = Vec::with_capacity(t);
        for i in 0..t {
            let mut token_logits = vec![0.0f32; vocab];
            for v in 0..vocab {
                for d in 0..dim {
                    token_logits[v] += values[i][d] * model.output_weights[d * vocab + v];
                }
            }
            logits.push(token_logits);
        }

        let avg_f = total_f / t.max(1) as f32;
        let loss = QCT::loss_from_logits(&logits, &wdata.window_tokens, avg_f);
        all_window_results.push(WindowResult {
            logits,
            caches,
            loss,
            avg_f,
        });
    }

    // Phase 3: backward pass per window, accumulated into one flat gradient.
    let mut total_grad = vec![0.0f32; num_params];
    let mut total_loss = 0.0f32;
    let mut total_f = 0.0f32;
    for (wdata, wr) in all_window_data.iter().zip(&all_window_results) {
        let grads = qug_backward_gpu(gpu, model, &wdata.window_tokens, &wr.logits, &wr.caches);
        let grad_flat = grads.flatten();
        total_loss += wr.loss;
        total_f += wr.avg_f;
        // `zip` ignores any flattened entries beyond `num_params`.
        for (acc, &g) in total_grad.iter_mut().zip(&grad_flat) {
            *acc += g;
        }
    }

    let n = num_windows as f32;
    for g in &mut total_grad {
        *g /= n;
    }
    (total_grad, total_loss / n, total_f / n)
}
/// GPU-assisted counterpart of `qug_backward`: computes the gradients for one window.
///
/// `tokens` is the full window: position `i` is scored against `tokens[i + 1]`, so
/// `t = tokens.len() - 1` positions contribute. The logit, output-weight, value and
/// value-weight gradients are computed on the CPU. For the Hamiltonian gradient, the
/// unitaries and diagonal seeds of all non-skipped positions of a block are packed into flat
/// buffers and handed to `gpu.batched_adjoint` in a single call; the commutator terms are
/// then accumulated on the CPU.
///
/// The embedding gradient is always returned as zeros. When fewer than two tokens are given,
/// all gradients are zero. If `caches` is empty the output gradient is left at zero instead
/// of panicking.
///
/// # Panics
///
/// Panics if `logits` has fewer than `t` rows, if a row is shorter than the vocabulary size,
/// if `caches` is non-empty but has fewer entries than the model has blocks, or if the GPU
/// result is shorter than `active positions * dim * dim`.
fn qug_backward_gpu(
    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
    model: &QCT,
    tokens: &[usize],
    logits: &[Vec<f32>],
    caches: &[ForwardCache],
) -> AllGradients {
    use dreamwell_math::Complex;

    let dim = model.config.dim;
    let vocab = model.config.vocab_size;
    let stride = dim * dim;
    let t = tokens.len().saturating_sub(1);
    if t == 0 {
        return AllGradients {
            embed_grad: vec![0.0; model.embedding.num_params()],
            block_grads: model
                .blocks
                .iter()
                .map(|b| BlockGrad {
                    hamiltonian_grad: vec![0.0; b.hamiltonian.num_params()],
                    value_weight_grad: vec![0.0; b.value_weights.len()],
                })
                .collect(),
            output_grad: vec![0.0; model.output_weights.len()],
        };
    }

    // d(loss)/d(logits): (softmax - one_hot(target)) / t, using the max-shifted softmax.
    let mut d_logits: Vec<Vec<f32>> = Vec::with_capacity(t);
    for i in 0..t {
        let target = tokens[i + 1];
        let row = &logits[i];
        let max_l = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = row.iter().map(|&l| (l - max_l).exp()).sum();
        let mut d_log = vec![0.0f32; vocab];
        for v in 0..vocab {
            let softmax_v = (row[v] - max_l).exp() / exp_sum;
            let one_hot = if v == target { 1.0 } else { 0.0 };
            d_log[v] = (softmax_v - one_hot) / t as f32;
        }
        d_logits.push(d_log);
    }

    // Output-weight gradient from the last block's populations. The cache lookup is done
    // once; positions beyond the cached range are skipped by the `zip`.
    let mut d_output = vec![0.0f32; dim * vocab];
    if let Some(last_cache) = caches.last() {
        for (vals, d_log) in last_cache.populations.iter().zip(&d_logits) {
            for d_idx in 0..dim {
                let pop = vals.get(d_idx).copied().unwrap_or(0.0);
                for v in 0..vocab {
                    d_output[d_idx * vocab + v] += pop * d_log[v];
                }
            }
        }
    }

    // Gradient with respect to the final values: d_logits projected back through the
    // output weights.
    let mut d_values: Vec<Vec<f32>> = vec![vec![0.0f32; dim]; t];
    for i in 0..t {
        for d_idx in 0..dim {
            for v in 0..vocab {
                d_values[i][d_idx] += model.output_weights[d_idx * vocab + v] * d_logits[i][v];
            }
        }
    }

    let dt_val = 0.090f32;
    let mut block_grads = Vec::with_capacity(model.blocks.len());
    // Blocks are visited last-to-first and the result is reversed afterwards.
    for (block_idx, block) in model.blocks.iter().enumerate().rev() {
        let cache = &caches[block_idx];
        let num_h = block.hamiltonian.num_params();

        // Value-weight gradient: outer product of populations and d_values, summed over positions.
        let mut d_vw = vec![0.0f32; dim * dim];
        for i in 0..t.min(cache.populations.len()) {
            for d_idx in 0..dim {
                for s in 0..dim {
                    let pop = cache.populations[i].get(s).copied().unwrap_or(0.0);
                    let dv = d_values[i].get(d_idx).copied().unwrap_or(0.0);
                    d_vw[d_idx * dim + s] += pop * dv;
                }
            }
        }

        // Gather the positions that were actually evolved (non-empty unitary) into flat
        // batches for one GPU call.
        let len = t.min(cache.unitaries.len());
        let mut active_indices: Vec<usize> = Vec::new();
        let mut u_batch: Vec<Complex> = Vec::new();
        let mut d_rho_batch: Vec<Complex> = Vec::new();
        for i in 0..len {
            let u = &cache.unitaries[i];
            if u.is_empty() {
                continue;
            }
            active_indices.push(i);
            u_batch.extend_from_slice(u);

            // Diagonal seed: the value gradient placed on the populations.
            let mut d_rho = vec![Complex::ZERO; stride];
            for k in 0..dim {
                let dp = d_values[i].get(k).copied().unwrap_or(0.0);
                d_rho[k * dim + k] = Complex::new(dp, 0.0);
            }
            d_rho_batch.extend_from_slice(&d_rho);
        }

        // Presumably yields U† · dρ · U per batch item, mirroring the CPU path — confirm
        // against `GpuTrainingContext::batched_adjoint`.
        let adjoint_results = if active_indices.is_empty() {
            Vec::new()
        } else {
            gpu.batched_adjoint(&u_batch, &d_rho_batch, active_indices.len())
        };

        let mut d_h = vec![0.0f32; num_h];
        for (batch_idx, &pos_idx) in active_indices.iter().enumerate() {
            let base = batch_idx * stride;
            let scratch_b = &adjoint_results[base..base + stride];
            let rho_before = &cache.rho_before[pos_idx];
            let mut h_idx = 0;
            // The first `dim` parameters receive the diagonal commutator terms.
            for k in 0..dim {
                let mut comm_diag = 0.0f32;
                for j in 0..dim {
                    let ab = scratch_b[k * dim + j].mul(rho_before[j * dim + k]);
                    let ba = rho_before[k * dim + j].mul(scratch_b[j * dim + k]);
                    comm_diag += (ab.sub(ba)).im;
                }
                if h_idx < d_h.len() {
                    d_h[h_idx] += -dt_val * comm_diag;
                }
                h_idx += 1;
            }
            // The remaining parameters receive the upper-triangle (p < q) terms.
            for p in 0..dim {
                for q in (p + 1)..dim {
                    if h_idx >= d_h.len() {
                        break;
                    }
                    let ab_pq = scratch_b[p * dim + q].mul(rho_before[q * dim + p]);
                    let ba_pq = rho_before[p * dim + q].mul(scratch_b[q * dim + p]);
                    let comm_pq = ab_pq.sub(ba_pq);
                    d_h[h_idx] += -dt_val * 2.0 * comm_pq.im;
                    h_idx += 1;
                }
            }
        }

        block_grads.push(BlockGrad {
            hamiltonian_grad: d_h,
            value_weight_grad: d_vw,
        });
    }
    block_grads.reverse();

    AllGradients {
        embed_grad: vec![0.0f32; model.embedding.num_params()],
        block_grads,
        output_grad: d_output,
    }
}
/// Per-window inputs prepared by `forward_backward_epoch_gpu` before its forward pass.
struct WindowForwardState {
/// Owned copy of the window's tokens, including the final token that is only used as a target.
window_tokens: Vec<usize>,
/// Number of input positions, i.e. the window length minus one (saturating at zero).
t: usize,
/// Embedded state for each input token.
states: Vec<DensityMatrixN>,
/// `populations()` of each embedded state.
values: Vec<Vec<f32>>,
/// Unitaries indexed by block, then by position; each is one `dim * dim` slice of the `batched_expm` output.
block_unitaries: Vec<Vec<Vec<dreamwell_math::Complex>>>,
}
/// Forward-pass outputs of one window, kept for the backward pass in `forward_backward_epoch_gpu`.
struct WindowResult {
/// One logit row per input position.
logits: Vec<Vec<f32>>,
/// One forward cache per block.
caches: Vec<ForwardCache>,
/// Value returned by `QCT::loss_from_logits` for this window.
loss: f32,
/// Free energy summed over blocks and positions, divided by the number of positions (at least 1).
avg_f: f32,
}
/// Builds the unitaries of `block` for positions `0..t` with a single batched GPU call.
///
/// The Hamiltonian matrix of every position is written into one flat buffer of
/// `t * dim * dim` values, which is passed to `gpu.batched_expm` together with `dt`. The
/// flat result is then split into one `dim * dim` vector per position.
///
/// # Panics
///
/// Panics if a Hamiltonian matrix does not have exactly `dim * dim` entries or if the GPU
/// result is shorter than `t * dim * dim`.
pub fn gpu_precompute_unitaries(
    gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
    block: &crate::transformer::QCTBlock,
    dim: usize,
    t: usize,
    dt: f32,
) -> Vec<Vec<dreamwell_math::Complex>> {
    let stride = dim * dim;

    // Pack the per-position Hamiltonians back to back.
    let mut packed = vec![0.0f32; t * stride];
    for pos in 0..t {
        let begin = pos * stride;
        let matrix = block.hamiltonian.build_matrix(pos);
        packed[begin..begin + stride].copy_from_slice(&matrix);
    }

    let exponentials = gpu.batched_expm(&packed, dt, t);

    // Unpack the flat result into one owned matrix per position.
    let mut per_position = Vec::with_capacity(t);
    for pos in 0..t {
        let begin = pos * stride;
        per_position.push(exponentials[begin..begin + stride].to_vec());
    }
    per_position
}
/// Public pass-through to `GpuTrainingContext::batched_evolve`.
///
/// All arguments are forwarded unchanged and the GPU result is returned as-is.
/// NOTE(review): presumably both slices hold `batch_count` flattened square matrices back to
/// back — confirm against the `dreamwell_math` GPU API.
pub fn gpu_batch_evolve(
gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
unitaries_flat: &[dreamwell_math::Complex],
rhos_flat: &[dreamwell_math::Complex],
batch_count: usize,
) -> Vec<dreamwell_math::Complex> {
gpu.batched_evolve(unitaries_flat, rhos_flat, batch_count)
}
/// Public pass-through to `GpuTrainingContext::batched_adjoint`.
///
/// All arguments are forwarded unchanged and the GPU result is returned as-is.
/// NOTE(review): presumably both slices hold `batch_count` flattened square matrices back to
/// back — confirm against the `dreamwell_math` GPU API.
pub fn gpu_batch_adjoint(
gpu: &dreamwell_math::gpu_training::GpuTrainingContext,
unitaries_flat: &[dreamwell_math::Complex],
d_rho_flat: &[dreamwell_math::Complex],
batch_count: usize,
) -> Vec<dreamwell_math::Complex> {
gpu.batched_adjoint(unitaries_flat, d_rho_flat, batch_count)
}
/// Trains `model` on `tokens` for `config.num_epochs` epochs using the CPU forward/backward pass.
///
/// Each epoch processes one window of up to `config.context_length + 1` tokens: all but the
/// last token are fed forward, the loss comes from `QCT::loss_from_logits`, the gradients
/// from `qug_backward`. The flattened gradient is scaled down to `config.grad_clip` when its
/// L2 norm exceeds that value, and applied via `model.apply_gradient_update`.
///
/// Returns the metrics of every epoch whose index is a multiple of `config.log_interval`,
/// plus the final epoch; each recorded epoch is also logged at info level. A `log_interval`
/// of `0` records only the final epoch instead of panicking with a division by zero. An
/// empty `tokens` slice no longer panics on the window-length subtraction.
pub fn train_qug(model: &mut QCT, tokens: &[usize], config: &crate::train::TrainConfig) -> Vec<EpochMetrics> {
    let mut metrics = Vec::new();
    for epoch in 0..config.num_epochs {
        let start = std::time::Instant::now();
        let lr = crate::train::learning_rate_pub(config, epoch);

        // Slide the training window by one token per epoch, wrapping around.
        // NOTE(review): `epoch % max_start` never selects offset `max_start` itself, although
        // a full window still fits there — confirm whether that is intended.
        let max_start = tokens.len().saturating_sub(config.context_length + 1);
        let window_start = if max_start > 0 { epoch % max_start } else { 0 };
        let window_end = (window_start + config.context_length + 1).min(tokens.len());
        let window = &tokens[window_start..window_end];

        // The last token of the window is only a target; `saturating_sub` keeps an empty
        // window from underflowing.
        let input = &window[..window.len().saturating_sub(1)];
        let (logits, avg_f, caches) = forward_with_cache(model, input);
        let loss = QCT::loss_from_logits(&logits, window, avg_f);
        let grads = qug_backward(model, window, &logits, &caches);
        let grad_flat = grads.flatten();

        // Global-norm gradient clipping.
        let grad_norm: f32 = grad_flat.iter().map(|g| g * g).sum::<f32>().sqrt();
        let scale = if grad_norm > config.grad_clip && grad_norm > 0.0 {
            config.grad_clip / grad_norm
        } else {
            1.0
        };
        model.apply_gradient_update(&grad_flat, lr, scale);

        let elapsed = start.elapsed().as_secs_f32() * 1000.0;
        let is_last_epoch = epoch == config.num_epochs - 1;
        let on_interval = config.log_interval != 0 && epoch % config.log_interval == 0;
        if on_interval || is_last_epoch {
            let m = EpochMetrics {
                epoch,
                loss,
                free_energy: avg_f,
                grad_norm,
                elapsed_ms: elapsed,
                learning_rate: lr,
                params_trained: grad_flat.len(),
            };
            log::info!(
                "QUG Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} ({:.1}ms)",
                m.epoch,
                m.loss,
                m.free_energy,
                m.grad_norm,
                m.learning_rate,
                m.elapsed_ms
            );
            metrics.push(m);
        }
    }
    metrics
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Builds the small single-block model shared by every test in this module.
    fn tiny_model() -> QCT {
        QCT::new(QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        })
    }

    /// The cached forward pass yields as many logit rows as `QCT::forward` and a free
    /// energy within 0.5 of it.
    #[test]
    fn forward_cache_matches_forward() {
        let model = tiny_model();
        let tokens = vec![0, 1, 2, 3, 4, 5];

        let (logits_normal, f_normal) = model.forward(&tokens);
        let (logits_cached, f_cached, caches) = forward_with_cache(&model, &tokens);

        assert_eq!(logits_normal.len(), logits_cached.len());
        let free_energy_gap = (f_normal - f_cached).abs();
        assert!(
            free_energy_gap < 0.5,
            "free energy mismatch: {} vs {}",
            f_normal,
            f_cached
        );
        assert!(!caches.is_empty());
    }

    /// The backward pass produces a finite, non-zero gradient.
    #[test]
    fn qug_gradient_nonzero() {
        let model = tiny_model();
        let tokens = vec![0, 1, 2, 3, 4, 5];

        let (logits, _, caches) = forward_with_cache(&model, &tokens[..5]);
        let flat = qug_backward(&model, &tokens, &logits, &caches).flatten();

        let squared_sum: f32 = flat.iter().map(|g| g * g).sum();
        let norm = squared_sum.sqrt();
        assert!(norm > 0.0, "QUG gradient should be nonzero");
        assert!(norm.is_finite(), "QUG gradient should be finite");
    }

    /// A short training run records one metrics entry per epoch.
    #[test]
    fn qug_training_runs() {
        let mut model = tiny_model();
        let tokens: Vec<usize> = (0..10).collect();
        let train_config = crate::train::TrainConfig {
            learning_rate: 0.01,
            num_epochs: 3,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };

        let metrics = train_qug(&mut model, &tokens, &train_config);

        assert_eq!(metrics.len(), 3);
        let first = &metrics[0];
        assert!(first.loss.is_finite());
        assert!(first.grad_norm > 0.0);
        eprintln!(
            "QUG training: {:.1}ms/epoch (vs ~7400ms for PSR)",
            first.elapsed_ms
        );
    }
}