use crate::fock_forward::fock_forward;
use crate::transformer::QCT;
/// Maximum number of trailing tokens fed to the model at each generation step.
const CONTEXT_WINDOW: usize = 32;

/// Autoregressively extends `prompt` by `num_tokens` tokens using [`QCT::forward`].
///
/// At each step the model sees at most the last [`CONTEXT_WINDOW`] tokens of the
/// sequence. The logits of the final position are sampled with [`sample_token`]
/// at the given `temperature`, and the sampled token is appended. Given a
/// deterministic forward pass, the same `seed` always yields the same output.
///
/// Returns the prompt followed by the generated tokens, so the result has
/// length `prompt.len() + num_tokens`.
///
/// # Panics
///
/// Panics if the model returns no logits for the context. Presumably this
/// happens when `prompt` is empty and `num_tokens > 0` (TODO confirm against
/// `QCT::forward`).
pub fn generate(
    model: &QCT,
    prompt: &[usize],
    num_tokens: usize,
    temperature: f32,
    seed: u64,
) -> Vec<usize> {
    generate_with(prompt, num_tokens, seed, |context, rng_state| {
        let (logits, _) = model.forward(context);
        let last_logits = logits
            .last()
            .expect("model returned no logits for the context; is the prompt empty?");
        sample_token(last_logits, temperature, rng_state)
    })
}

/// Same as [`generate`], but computes logits with [`fock_forward`].
///
/// NOTE(review): the cache returned by `fock_forward` is discarded, so every
/// step still runs a full forward pass over the context window.
///
/// # Panics
///
/// Panics if `fock_forward` returns no logits for the context. Presumably this
/// happens when `prompt` is empty and `num_tokens > 0` (TODO confirm).
pub fn generate_fast(
    model: &QCT,
    prompt: &[usize],
    num_tokens: usize,
    temperature: f32,
    seed: u64,
) -> Vec<usize> {
    generate_with(prompt, num_tokens, seed, |context, rng_state| {
        let (logits, _, _cache) = fock_forward(model, context);
        let last_logits = logits
            .last()
            .expect("fock_forward returned no logits for the context; is the prompt empty?");
        sample_token(last_logits, temperature, rng_state)
    })
}

/// Shared decoding loop for [`generate`] and [`generate_fast`].
///
/// `step` receives the current context window and the RNG state for this step,
/// and it returns the next token. Before the loop moves on, the RNG state is
/// advanced with a 64-bit linear congruential step (`x * a + c`, wrapping).
fn generate_with(
    prompt: &[usize],
    num_tokens: usize,
    seed: u64,
    mut step: impl FnMut(&[usize], u64) -> usize,
) -> Vec<usize> {
    let mut sequence = prompt.to_vec();
    sequence.reserve(num_tokens);
    let mut rng_state = seed;
    for _ in 0..num_tokens {
        // Keep only the most recent CONTEXT_WINDOW tokens (or all, if fewer).
        let start = sequence.len().saturating_sub(CONTEXT_WINDOW);
        let token = step(&sequence[start..], rng_state);
        rng_state = rng_state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        sequence.push(token);
    }
    sequence
}
/// Samples a token index from `logits` using a temperature-scaled softmax.
///
/// `temperature` is clamped to at least `0.01`, so zero or tiny temperatures
/// approach greedy decoding instead of dividing by zero. The uniform draw comes
/// deterministically from `seed` via [`splitmix_f32`]. The same
/// `(logits, temperature, seed)` therefore always yields the same index.
///
/// If any logit is NaN the probabilities become NaN and the last index is
/// returned.
///
/// # Panics
///
/// Panics if `logits` is empty.
fn sample_token(logits: &[f32], temperature: f32, seed: u64) -> usize {
    assert!(!logits.is_empty(), "sample_token: logits must not be empty");
    let temp = temperature.max(0.01);
    // Subtract the max before exponentiating for numerical stability. The
    // largest finite logit maps to exp(0) = 1, so `sum` is at least 1.
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let scaled: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max_logit) / temp).exp())
        .collect();
    let sum: f32 = scaled.iter().sum();
    let probs: Vec<f32> = scaled.iter().map(|&s| s / sum).collect();
    let r = splitmix_f32(seed);
    let mut cumulative = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if r < cumulative {
            return i;
        }
    }
    // Rounding can leave the final cumulative sum just below `r`. In that case,
    // fall back to the last token that actually has probability mass. The final
    // index may have underflowed to zero (e.g. at low temperature).
    probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(probs.len() - 1)
}

/// Maps `seed` to a pseudo-random `f32` in the half-open range `[0, 1)`.
///
/// Uses the SplitMix64 mixing function, so nearby seeds (such as successive
/// LCG states) give well-spread outputs.
fn splitmix_f32(seed: u64) -> f32 {
    // Largest f32 strictly below 1.0 (1 - 2^-24).
    const BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;
    let mut z = seed.wrapping_add(0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
    z ^= z >> 31;
    // `z as f32` rounds to nearest, so values close to u64::MAX become 2^64 and
    // the ratio would be exactly 1.0. Clamp to keep the half-open contract that
    // `sample_token`'s `r < cumulative` comparison relies on.
    ((z as f32) / (u64::MAX as f32)).min(BELOW_ONE)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::{QCTConfig, QCT};

    /// Small, fixed-seed model shared by all generation tests.
    fn tiny_model() -> QCT {
        QCT::new(QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        })
    }

    #[test]
    fn generation_produces_tokens() {
        let model = tiny_model();
        let prompt = [0, 1, 2, 3];
        let generated = generate(&model, &prompt, 20, 1.0, 42);
        assert_eq!(generated.len(), 24);
        assert_eq!(
            &generated[..prompt.len()],
            &prompt[..],
            "prompt should be preserved as the prefix"
        );
        for &t in &generated {
            assert!(t < 10, "token should be in vocab range");
        }
    }

    #[test]
    fn zero_tokens_returns_prompt() {
        let model = tiny_model();
        let prompt = [4, 5, 6];
        assert_eq!(generate(&model, &prompt, 0, 1.0, 7), prompt.to_vec());
    }

    #[test]
    fn generation_deterministic() {
        let model = tiny_model();
        let prompt = [0, 1, 2];
        let a = generate(&model, &prompt, 10, 1.0, 42);
        let b = generate(&model, &prompt, 10, 1.0, 42);
        assert_eq!(a, b, "generation should be deterministic");
    }

    #[test]
    fn temperature_affects_distribution() {
        let model = tiny_model();
        let prompt = [0, 1, 2];
        let hot = generate(&model, &prompt, 20, 2.0, 42);
        let cold = generate(&model, &prompt, 20, 0.1, 42);
        assert!(hot.iter().all(|&t| t < 10));
        assert!(cold.iter().all(|&t| t < 10));
        let diffs = hot.iter().zip(&cold).filter(|(a, b)| a != b).count();
        assert!(diffs > 0, "different temperatures should produce different sequences");
    }

    #[test]
    fn sampling_cold_temperature_is_greedy() {
        for seed in 0..200u64 {
            assert_eq!(sample_token(&[0.0, 10.0, 0.0], 0.01, seed), 1);
        }
    }
}