1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
impl Model {
    /// Create a new transformer model.
    ///
    /// Builds, in order: the token embedding (`vocab_size` x `hidden_dim`),
    /// `num_layers` transformer blocks, the final layer norm, and the
    /// `hidden_dim -> vocab_size` LM head. The configuration is moved into the
    /// model and is available afterwards via [`Model::config`].
    ///
    /// # Arguments
    ///
    /// * `config` - Model configuration
    ///
    /// # Errors
    ///
    /// Returns the first error reported by any sub-layer constructor
    /// (`Embedding::new`, `TransformerBlock::new`, `LayerNorm::new`,
    /// `Linear::new`), e.g. when the configuration is invalid.
    pub fn new(config: ModelConfig) -> Result<Self> {
        let embedding = Embedding::new(config.vocab_size, config.hidden_dim)?;
        // Blocks are built in layer order; the first constructor error aborts.
        let blocks = (0..config.num_layers)
            .map(|_| {
                TransformerBlock::new(
                    config.hidden_dim,
                    config.num_heads,
                    config.intermediate_dim,
                    config.eps,
                )
            })
            .collect::<Result<Vec<_>>>()?;
        let final_norm = LayerNorm::new(config.hidden_dim, config.eps)?;
        let lm_head = Linear::new(config.hidden_dim, config.vocab_size)?;
        Ok(Self {
            embedding,
            blocks,
            final_norm,
            lm_head,
            config,
        })
    }

    /// Forward pass returning the final-layer hidden state (residual stream output
    /// AFTER `final_norm` but BEFORE the `lm_head` projection).
    ///
    /// This is exactly the tensor that `lm_head` consumes to produce logits — i.e.
    /// the model's contextual hidden representation of each token. It is the correct
    /// source for model-backed embeddings (PMAT-803): pool these per-token vectors
    /// (mean-pool is the standard default) and L2-normalize.
    ///
    /// # Arguments
    ///
    /// * `token_ids` - Input token IDs
    ///
    /// # Returns
    ///
    /// Hidden-state tensor with shape `[seq_len, hidden_dim]`
    ///
    /// # Errors
    ///
    /// Returns error if input is invalid (propagated from the embedding, any
    /// transformer block, or the final layer norm).
    pub fn forward_hidden(&self, token_ids: &[usize]) -> Result<Tensor<f32>> {
        // Embed tokens
        let mut hidden = self.embedding.forward(token_ids)?;
        // Pass through transformer blocks in order
        for block in &self.blocks {
            hidden = block.forward(&hidden)?;
        }
        // Final layer norm — this is the residual-stream output that lm_head consumes.
        self.final_norm.forward(&hidden)
    }

    /// Forward pass through the model.
    ///
    /// Equivalent to [`Model::forward_hidden`] followed by the `lm_head` projection.
    ///
    /// # Arguments
    ///
    /// * `token_ids` - Input token IDs
    ///
    /// # Returns
    ///
    /// Logits tensor with shape `[seq_len, vocab_size]`
    ///
    /// # Errors
    ///
    /// Returns error if input is invalid
    pub fn forward(&self, token_ids: &[usize]) -> Result<Tensor<f32>> {
        // Compute the pre-lm_head hidden state, then project to vocabulary.
        let hidden = self.forward_hidden(token_ids)?;
        self.lm_head.forward(&hidden)
    }

    /// Get model configuration
    #[must_use]
    pub fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Get mutable reference to embedding layer
    pub fn embedding_mut(&mut self) -> &mut Embedding {
        &mut self.embedding
    }

    /// Get mutable reference to transformer blocks
    pub fn blocks_mut(&mut self) -> &mut [TransformerBlock] {
        &mut self.blocks
    }

    /// Get mutable reference to final layer norm
    pub fn final_norm_mut(&mut self) -> &mut LayerNorm {
        &mut self.final_norm
    }

    /// Get mutable reference to LM head
    pub fn lm_head_mut(&mut self) -> &mut Linear {
        &mut self.lm_head
    }

    /// Get number of parameters in the model (approximate).
    ///
    /// Counts, from the configuration only:
    /// * the embedding table: `vocab_size * hidden_dim`
    /// * per block: `2 * hidden_dim` layer-norm weights plus the two FFN weight
    ///   matrices (`hidden_dim * intermediate_dim` each)
    /// * the LM head: `hidden_dim * vocab_size`
    ///
    /// Attention projections (Q, K, V, O), the final layer norm, and any bias
    /// vectors are NOT included, so this undercounts the real model.
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        let hidden = self.config.hidden_dim;
        let vocab = self.config.vocab_size;
        let intermediate = self.config.intermediate_dim;

        let embed_params = vocab * hidden;
        // Layer-norm weights + fc1 + fc2; attention is not counted yet.
        let per_block = 2 * hidden + hidden * intermediate + intermediate * hidden;
        let block_params = self.config.num_layers * per_block;
        let head_params = hidden * vocab;
        embed_params + block_params + head_params
    }

    /// Generate tokens autoregressively.
    ///
    /// Each step re-runs [`Model::forward`] over the full token sequence, takes the
    /// logits row of the last position, and samples the next token with
    /// `sample_token`. Randomness comes from a 64-bit LCG seeded with
    /// `config.seed` (defaulting to 42), so a given seed is reproducible.
    /// Generation stops after `config.max_tokens` steps, or earlier when the
    /// sampled token equals `config.eos_token_id` — the EOS token itself is
    /// NOT appended to the output.
    ///
    /// # Arguments
    ///
    /// * `prompt` - Initial token IDs (must be non-empty)
    /// * `config` - Generation configuration
    ///
    /// # Returns
    ///
    /// Vector of generated token IDs (including prompt)
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` if `prompt` is empty or if the
    /// logits buffer is too short to hold a full `vocab_size` row for the last
    /// position (e.g. the LM head was replaced via [`Model::lm_head_mut`] with a
    /// different output width). Errors from the forward pass or sampling are
    /// propagated.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let generated = model.generate(&[1, 2, 3], &GenerationConfig::greedy())?;
    /// ```
    pub fn generate(&self, prompt: &[usize], config: &GenerationConfig) -> Result<Vec<usize>> {
        if prompt.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Prompt cannot be empty".to_string(),
            });
        }
        let mut tokens = prompt.to_vec();
        let mut rng_state = config.seed.unwrap_or(42);
        for _ in 0..config.max_tokens {
            // Forward pass
            let logits = self.forward(&tokens)?;
            // Get logits for last position; `tokens` is never empty here.
            let last_pos = tokens.len() - 1;
            let vocab_size = self.config.vocab_size;
            let row_start = last_pos * vocab_size;
            // Checked slice: a mismatched lm_head width must surface as an error,
            // not an out-of-bounds panic.
            let last_logits = logits
                .data()
                .get(row_start..row_start + vocab_size)
                .ok_or_else(|| RealizarError::InvalidShape {
                    reason: format!(
                        "Logits buffer of length {} has no full row of {} at position {}",
                        logits.data().len(),
                        vocab_size,
                        last_pos
                    ),
                })?;
            let last_logits_tensor = Tensor::from_vec(vec![vocab_size], last_logits.to_vec())?;
            // Simple LCG for random number generation
            rng_state = rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            // PMAT-757: f32-safe [0,1) mapping. The old `(state >> 33)/(1<<31)` rounded its
            // max numerator UP to 2^31 in f32 -> rng_value == 1.0 -> biased last-token draw.
            let rng_value = crate::generate::lcg_state_to_unit_f32(rng_state);
            // Sample next token
            let next_token = sample_token(&last_logits_tensor, config, rng_value)?;
            // Stop on EOS without appending it.
            if config.eos_token_id == Some(next_token) {
                break;
            }
            tokens.push(next_token);
        }
        Ok(tokens)
    }
}