use crate::block::TransformerStack;
use crate::embedding::{Embedding, PositionalEncoding};
use crate::error::Result;
use crate::metrics;
use crate::sample::{self, SamplingConfig};
use crate::tensor::Tensor;
/// Why a call to `GenerativeModel::generate` / `generate_gated` stopped
/// producing tokens.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The `max_new_tokens` budget was exhausted (also reported when the
    /// budget is zero and the loop never runs).
    MaxTokens,
    /// The configured stop token was sampled; carries that token id.
    StopToken(usize),
    /// Confidence gating tripped before sampling at this step.
    LowConfidence {
        /// 0-based generation step at which the gate fired.
        step: usize,
        /// Measured confidence for the step's logits.
        confidence: f64,
        /// The `min_confidence` threshold that was not met.
        threshold: f64,
    },
    /// The running sequence reached the positional encoding's maximum
    /// supported length before the token budget ran out.
    MaxSeqLen,
}
/// An autoregressive language model: token embedding + positional encoding
/// feeding a transformer stack, with a linear projection back to vocab logits.
#[derive(Debug, Clone)]
pub struct GenerativeModel {
    // Maps token ids to `model_dim`-sized vectors.
    pub embedding: Embedding,
    // Adds position information; also defines the maximum sequence length
    // generation will run to (`pos_encoding.max_seq_len`).
    pub pos_encoding: PositionalEncoding,
    // The stack of transformer blocks applied to the embedded sequence.
    pub transformer: TransformerStack,
    // Output projection of shape [model_dim, vocab_size] mapping hidden
    // states to logits.
    pub output_proj: Tensor,
    pub vocab_size: usize,
    pub model_dim: usize,
}
impl GenerativeModel {
    /// Builds a model using the conventional feed-forward width of
    /// `4 * model_dim`. See [`Self::with_ffn_dim`] for explicit control.
    ///
    /// # Errors
    /// Propagates any construction failure from [`Self::with_ffn_dim`].
    pub fn new(
        vocab_size: usize,
        model_dim: usize,
        num_heads: usize,
        num_layers: usize,
        max_seq_len: usize,
        rng: &mut impl rand::Rng,
    ) -> Result<Self> {
        let ffn_inner_dim = 4 * model_dim;
        Self::with_ffn_dim(
            vocab_size,
            model_dim,
            num_heads,
            num_layers,
            ffn_inner_dim,
            max_seq_len,
            rng,
        )
    }

    /// Builds a model with an explicit feed-forward inner dimension.
    /// Weights are drawn from `rng`; the output projection uses
    /// Xavier-uniform initialization.
    ///
    /// # Errors
    /// Returns any error from transformer-stack or tensor construction.
    pub fn with_ffn_dim(
        vocab_size: usize,
        model_dim: usize,
        num_heads: usize,
        num_layers: usize,
        ffn_inner_dim: usize,
        max_seq_len: usize,
        rng: &mut impl rand::Rng,
    ) -> Result<Self> {
        // Field initializers evaluate top-to-bottom, so `rng` is consumed in
        // the same order a sequential let-chain would consume it.
        Ok(Self {
            embedding: Embedding::new(vocab_size, model_dim, rng),
            pos_encoding: PositionalEncoding::new(max_seq_len, model_dim),
            transformer: TransformerStack::with_ffn_dim(
                num_layers,
                model_dim,
                num_heads,
                ffn_inner_dim,
                rng,
            )?,
            output_proj: Tensor::xavier_uniform(&[model_dim, vocab_size], rng)?,
            vocab_size,
            model_dim,
        })
    }

    /// Runs the model over `token_ids` with a causal attention mask
    /// (the standard setting for autoregressive decoding).
    pub fn forward(&self, token_ids: &[usize]) -> Result<ForwardOutput> {
        self.forward_with_mask(token_ids, true)
    }

    /// Full forward pass: embed -> add positions -> transformer -> project.
    /// `causal` selects whether the transformer masks future positions.
    ///
    /// # Errors
    /// Returns any error surfaced by the embedding, encoding, transformer,
    /// or the final matmul.
    pub fn forward_with_mask(&self, token_ids: &[usize], causal: bool) -> Result<ForwardOutput> {
        let positioned = self
            .pos_encoding
            .forward(&self.embedding.forward_batch(token_ids)?)?;
        let stack_out = self.transformer.forward(&positioned, causal)?;
        Ok(ForwardOutput {
            // Project hidden states to vocab logits before moving `hidden`.
            logits: stack_out.hidden.matmul(&self.output_proj)?,
            hidden: stack_out.hidden,
            attention_weights: stack_out.all_attention_weights,
        })
    }

    /// Autoregressively samples up to `max_new_tokens` tokens after `prompt`.
    /// Equivalent to [`Self::generate_gated`] with no confidence gate.
    ///
    /// # Panics
    /// As with `generate_gated`, an empty `prompt` underflows the last-row
    /// index (`seq_len - 1`) and panics in debug builds.
    pub fn generate(
        &self,
        prompt: &[usize],
        max_new_tokens: usize,
        sampling: &SamplingConfig,
        stop_token: Option<usize>,
        rng: &mut impl rand::Rng,
    ) -> Result<GenerationResult> {
        self.generate_gated(prompt, max_new_tokens, sampling, stop_token, None, rng)
    }

    /// Autoregressive generation with an optional per-step confidence gate.
    ///
    /// Each step re-runs the full forward pass over the sequence so far
    /// (no KV cache), takes the logits of the last position, optionally
    /// checks `metrics::generation_confidence` against `min_confidence`
    /// (stopping *before* sampling when below threshold), then samples the
    /// next token. Generation halts on the first of: token budget spent,
    /// sequence hitting `pos_encoding.max_seq_len`, low confidence, or the
    /// sampled token matching `stop_token`. Logits are recorded only for
    /// steps that actually sampled a token.
    ///
    /// # Errors
    /// Propagates failures from the forward pass, row extraction, and
    /// sampling.
    ///
    /// # Panics
    /// An empty `prompt` makes `sequence.len() - 1` underflow, which panics
    /// in debug builds; callers should supply at least one prompt token.
    pub fn generate_gated(
        &self,
        prompt: &[usize],
        max_new_tokens: usize,
        sampling: &SamplingConfig,
        stop_token: Option<usize>,
        min_confidence: Option<f64>,
        rng: &mut impl rand::Rng,
    ) -> Result<GenerationResult> {
        let seq_limit = self.pos_encoding.max_seq_len;
        let mut sequence = prompt.to_vec();
        let mut step_logits = Vec::with_capacity(max_new_tokens);
        // Default reason: the budget ran out without any other stop firing.
        let mut reason = StopReason::MaxTokens;
        for step in 0..max_new_tokens {
            if sequence.len() >= seq_limit {
                reason = StopReason::MaxSeqLen;
                break;
            }
            // Recompute the whole sequence and keep only the last position's
            // logits for sampling.
            let output = self.forward(&sequence)?;
            let last_logits = output.logits.row(sequence.len() - 1)?;
            if let Some(threshold) = min_confidence {
                // A failed confidence computation counts as zero confidence,
                // which always trips the gate.
                let confidence = metrics::generation_confidence(&last_logits).unwrap_or(0.0);
                if confidence < threshold {
                    reason = StopReason::LowConfidence {
                        step,
                        confidence,
                        threshold,
                    };
                    break;
                }
            }
            let chosen = sample::sample_token_with_context(&last_logits, sampling, &sequence, rng)?;
            step_logits.push(last_logits);
            sequence.push(chosen);
            // The stop token itself is kept in the output sequence.
            if stop_token == Some(chosen) {
                reason = StopReason::StopToken(chosen);
                break;
            }
        }
        Ok(GenerationResult {
            tokens: sequence,
            prompt_len: prompt.len(),
            generated_logits: step_logits,
            stop_reason: reason,
        })
    }
}
/// Result of a single forward pass over a token sequence.
#[derive(Debug, Clone)]
pub struct ForwardOutput {
    // Per-position vocab logits, shape [seq_len, vocab_size].
    pub logits: Tensor,
    // Final transformer hidden states, shape [seq_len, model_dim].
    pub hidden: Tensor,
    // Attention weights from the stack; outer Vec is per layer, inner Vec is
    // presumably per head — confirm against TransformerStack.
    pub attention_weights: Vec<Vec<Tensor>>,
}
/// Outcome of one generation run: prompt plus generated tokens, the logits
/// that produced each generated token, and why generation stopped.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    // Full sequence: the original prompt followed by generated tokens.
    pub tokens: Vec<usize>,
    // Number of leading tokens in `tokens` that came from the prompt.
    pub prompt_len: usize,
    // One logits row per generated token (none for gated-out steps).
    pub generated_logits: Vec<Tensor>,
    // The condition that ended generation.
    pub stop_reason: StopReason,
}
impl GenerationResult {
    /// The tokens produced after the prompt (may be empty).
    ///
    /// # Panics
    /// Panics if `prompt_len` exceeds `tokens.len()` — never the case for
    /// results built by `generate`/`generate_gated`.
    pub fn generated_tokens(&self) -> &[usize] {
        let Self {
            tokens, prompt_len, ..
        } = self;
        &tokens[*prompt_len..]
    }

    /// How many tokens were generated beyond the prompt.
    pub fn num_generated(&self) -> usize {
        self.tokens.len() - self.prompt_len
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Logits should be [seq_len, vocab]; hidden should be [seq_len, model_dim].
    #[test]
    fn test_forward_shape() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 2, 64, &mut rng).unwrap();
        let out = model.forward(&[0, 1, 2]).unwrap();
        assert_eq!(out.logits.shape(), &[3, 50]);
        assert_eq!(out.hidden.shape(), &[3, 16]);
    }

    // With no stop token and a roomy max_seq_len, exactly the requested
    // number of tokens is generated.
    #[test]
    fn test_generate() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 2, 64, &mut rng).unwrap();
        let result = model
            .generate(&[0, 1, 2], 5, &SamplingConfig::greedy(), None, &mut rng)
            .unwrap();
        assert_eq!(result.prompt_len, 3);
        assert_eq!(result.num_generated(), 5);
        assert_eq!(result.tokens.len(), 8);
    }

    // Stop token 999 is outside the vocab of 50, so it can never be sampled:
    // generation must run the full budget and report MaxTokens.
    #[test]
    fn test_stop_token() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 1, 64, &mut rng).unwrap();
        let result = model
            .generate(&[0], 10, &SamplingConfig::greedy(), Some(999), &mut rng)
            .unwrap();
        assert_eq!(result.num_generated(), 10);
        assert_eq!(result.stop_reason, StopReason::MaxTokens);
    }

    // Exhausting the budget without any other stop condition yields MaxTokens.
    #[test]
    fn test_stop_reason_reported() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 2, 64, &mut rng).unwrap();
        let result = model
            .generate(&[0, 1, 2], 5, &SamplingConfig::greedy(), None, &mut rng)
            .unwrap();
        assert_eq!(result.stop_reason, StopReason::MaxTokens);
    }

    // A very high threshold on an untrained model should usually trip the
    // gate; MaxTokens is tolerated since confidence depends on random init.
    #[test]
    fn test_confidence_gated_generation() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 2, 64, &mut rng).unwrap();
        let result = model
            .generate_gated(
                &[0, 1, 2],
                20,
                &SamplingConfig::greedy(),
                None,
                Some(0.99),
                &mut rng,
            )
            .unwrap();
        match &result.stop_reason {
            StopReason::LowConfidence { threshold, .. } => {
                assert!((*threshold - 0.99).abs() < 1e-10);
            }
            StopReason::MaxTokens => {}
            other => panic!("unexpected stop reason: {other:?}"),
        }
    }

    // Causal and bidirectional masking must produce the same shapes but
    // different logit values for a multi-token input.
    #[test]
    fn test_forward_with_mask() {
        let mut rng = rand::rng();
        let model = GenerativeModel::new(50, 16, 2, 2, 64, &mut rng).unwrap();
        let causal_out = model.forward_with_mask(&[0, 1, 2], true).unwrap();
        let bidir_out = model.forward_with_mask(&[0, 1, 2], false).unwrap();
        assert_eq!(causal_out.logits.shape(), &[3, 50]);
        assert_eq!(bidir_out.logits.shape(), &[3, 50]);
        assert_ne!(causal_out.logits.data(), bidir_out.logits.data());
    }
}