Skip to main content

batuta/serve/banco/
inference.rs

//! Inference engine — bridges banco chat handler to realizar's forward pass.
//!
//! Gated behind `#[cfg(feature = "realizar")]`. Provides:
//! - `generate_sync()` — greedy/sampled token generation for non-streaming responses
//! - `generate_stream_tokens()` — the same generation, returned as per-token chunks for SSE
//! - `encode_prompt()` / `embed_tokens()` — simple prompt tokenization and mean-pooled embeddings
6
7#[cfg(feature = "realizar")]
8use std::sync::Arc;
9
10#[cfg(feature = "realizar")]
11use realizar::gguf::{OwnedQuantizedKVCache, OwnedQuantizedModel};
12
/// Result of a completed (non-streaming) generation.
#[cfg(feature = "realizar")]
#[derive(Debug, Clone, PartialEq)]
pub struct GenerationResult {
    /// Decoded text of the generated tokens (the stop token itself is not included).
    pub text: String,
    /// Number of generated tokens, excluding the stop token.
    pub token_count: u32,
    /// `"stop"` when a stop token was sampled, `"length"` when `max_tokens` was reached.
    pub finish_reason: String,
}
20
/// Knobs controlling how each next token is chosen during generation.
#[cfg(feature = "realizar")]
#[derive(Debug, Clone)]
pub struct SamplingParams {
    /// Softmax temperature; values `<= 0.0` select greedy (argmax) decoding.
    pub temperature: f32,
    /// Number of highest-scoring candidates kept before sampling; `<= 1` means greedy.
    pub top_k: u32,
    /// Upper bound on the number of newly generated tokens.
    pub max_tokens: u32,
}

#[cfg(feature = "realizar")]
impl Default for SamplingParams {
    /// `temperature = 0.7`, `top_k = 40`, `max_tokens = 256`.
    fn default() -> Self {
        let temperature = 0.7;
        let top_k = 40;
        let max_tokens = 256;
        SamplingParams { temperature, top_k, max_tokens }
    }
}
36
/// Generate a complete response synchronously (non-streaming).
///
/// Runs the autoregressive loop (prefill the prompt, then sample → forward until
/// a stop token or `params.max_tokens`) and decodes all generated tokens in one
/// go, so characters whose bytes span several tokens are reassembled correctly.
///
/// # Errors
/// Returns `Err` if `prompt_tokens` is empty, the model config reports
/// `num_heads == 0`, or a forward pass fails.
#[cfg(feature = "realizar")]
pub fn generate_sync(
    model: &Arc<OwnedQuantizedModel>,
    vocab: &[String],
    prompt_tokens: &[u32],
    params: &SamplingParams,
) -> Result<GenerationResult, String> {
    let mut generated_tokens: Vec<u32> = Vec::new();
    let finish_reason =
        run_generation(model, vocab, prompt_tokens, params, |token| generated_tokens.push(token))?;

    Ok(GenerationResult {
        text: decode_tokens(vocab, &generated_tokens),
        // Bounded by `params.max_tokens: u32`, so the cast cannot truncate.
        token_count: generated_tokens.len() as u32,
        finish_reason: finish_reason.to_string(),
    })
}

/// Generate a response as a sequence of per-token chunks for streaming (SSE).
///
/// The whole generation runs before this returns (true incremental streaming
/// would need a channel; this keeps Phase 2b simple). Each generated token
/// becomes one `StreamToken` with `finish_reason: None`, whose text is that
/// token decoded on its own (ids outside `vocab` render as `<unk:ID>`). Because
/// tokens are decoded individually, a character split across tokens may not
/// render correctly. The last element is always a sentinel with empty text and
/// `finish_reason` set to `"stop"` (stop token sampled) or `"length"`
/// (`params.max_tokens` reached).
///
/// # Errors
/// Same conditions as [`generate_sync`].
#[cfg(feature = "realizar")]
pub fn generate_stream_tokens(
    model: &Arc<OwnedQuantizedModel>,
    vocab: &[String],
    prompt_tokens: &[u32],
    params: &SamplingParams,
) -> Result<Vec<StreamToken>, String> {
    let mut tokens = Vec::new();
    let finish_reason = run_generation(model, vocab, prompt_tokens, params, |token| {
        let text = match vocab.get(token as usize) {
            Some(raw) => decode_bpe_text(raw),
            None => decode_bpe_text(&format!("<unk:{token}>")),
        };
        tokens.push(StreamToken { text, finish_reason: None });
    })?;

    tokens.push(StreamToken { text: String::new(), finish_reason: Some(finish_reason.to_string()) });
    Ok(tokens)
}

/// Shared prefill + decode loop behind both generation entry points.
///
/// Calls `on_token` for every generated (non-stop) token, in order, and returns
/// the finish reason: `"stop"` when a stop token is sampled, `"length"` when
/// `params.max_tokens` tokens were produced.
#[cfg(feature = "realizar")]
fn run_generation(
    model: &Arc<OwnedQuantizedModel>,
    vocab: &[String],
    prompt_tokens: &[u32],
    params: &SamplingParams,
    mut on_token: impl FnMut(u32),
) -> Result<&'static str, String> {
    if prompt_tokens.is_empty() {
        return Err("prompt_tokens must not be empty".to_string());
    }

    let config = model.config();
    // Guard the division below: a malformed config would otherwise panic.
    if config.num_heads == 0 {
        return Err("model config has num_heads == 0".to_string());
    }
    let head_dim = config.hidden_dim / config.num_heads;
    let kv_dim = config.num_kv_heads * head_dim;
    let max_new_tokens = params.max_tokens as usize;
    let max_seq = prompt_tokens.len() + max_new_tokens;

    let mut cache = OwnedQuantizedKVCache::new(config.num_layers, kv_dim, max_seq);

    // Prefill: process all prompt tokens; only the last logits are needed.
    let mut logits = Vec::new();
    for (pos, &token) in prompt_tokens.iter().enumerate() {
        logits = model
            .forward_single_with_cache(token, &mut cache, pos)
            .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
    }

    // Decode: generate new tokens autoregressively.
    let stop_tokens = stop_token_ids(vocab);
    let mut pos = prompt_tokens.len();
    for step in 0..max_new_tokens {
        let next_token = sample_token(&logits, params);
        if stop_tokens.contains(&next_token) {
            return Ok("stop");
        }
        on_token(next_token);

        // Logits after the final allowed token would never be sampled, so skip
        // that (full) forward pass.
        if step + 1 == max_new_tokens {
            break;
        }
        logits = model
            .forward_single_with_cache(next_token, &mut cache, pos)
            .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
        pos += 1;
    }

    Ok("length")
}

/// Token ids that end generation: the vocabulary's primary EOS token (as chosen
/// by `find_eos_token`) plus any chat end-of-turn markers it also defines.
///
/// `find_eos_token` yields a single id, but chat vocabularies often contain
/// several terminators (Qwen: `<|endoftext|>` and `<|im_end|>`; Phi-3:
/// `<|endoftext|>` and `<|end|>`). Stopping only on the first one lets
/// generation run past the end of the assistant turn.
#[cfg(feature = "realizar")]
fn stop_token_ids(vocab: &[String]) -> Vec<u32> {
    const END_OF_TURN: [&str; 3] = ["<|im_end|>", "<|eot_id|>", "<|end|>"];

    let mut ids: Vec<u32> = find_eos_token(vocab).into_iter().collect();
    for (idx, tok) in vocab.iter().enumerate() {
        let id = idx as u32;
        if END_OF_TURN.contains(&tok.as_str()) && !ids.contains(&id) {
            ids.push(id);
        }
    }
    ids
}
165
/// A single chunk of a streaming response.
#[cfg(feature = "realizar")]
#[derive(Debug, Clone, PartialEq)]
pub struct StreamToken {
    /// Decoded text for one generated token; empty on the final chunk.
    pub text: String,
    /// `None` for token chunks; `Some("stop")` or `Some("length")` on the final chunk.
    pub finish_reason: Option<String>,
}
172
/// Sample a token from logits using temperature + top-k.
///
/// Falls back to greedy argmax when `temperature <= 0`, `top_k <= 1`, or
/// `logits` is empty. Otherwise keeps the `top_k` highest logits, takes a
/// softmax over `logit / temperature`, and draws one candidate with a
/// pseudo-random value derived from `logits_hash`, so identical logits always
/// produce the same token (reproducible for testing; not a true RNG).
#[cfg(feature = "realizar")]
fn sample_token(logits: &[f32], params: &SamplingParams) -> u32 {
    if logits.is_empty() || params.temperature <= 0.0 || params.top_k <= 1 {
        // Greedy: argmax (also the safe answer for an empty vector).
        return argmax(logits);
    }

    // Top-k selection. Dividing by a positive temperature preserves order, so
    // candidates are chosen on raw logits and scaled afterwards. `total_cmp`
    // is a total order (NaN included), which the sort routines require.
    let k = (params.top_k as usize).min(logits.len());
    let by_logit_desc = |a: &(usize, f32), b: &(usize, f32)| b.1.total_cmp(&a.1);
    let mut candidates: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    if k < candidates.len() {
        // O(V) partial selection instead of sorting the whole vocabulary.
        candidates.select_nth_unstable_by(k - 1, by_logit_desc);
        candidates.truncate(k);
    }
    candidates.sort_unstable_by(by_logit_desc);

    // Softmax over the temperature-scaled top-k, shifted by the max for stability.
    let max_scaled = candidates[0].1 / params.temperature;
    let weights: Vec<f32> = candidates
        .iter()
        .map(|&(_, v)| (v / params.temperature - max_scaled).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // Pseudo-random draw in [0, 1): the top 24 hash bits fit an f32 mantissa
    // exactly, so `r` can never round up to 1.0.
    let r = (logits_hash(logits) >> 40) as f32 / (1u64 << 24) as f32;
    let mut cumulative = 0.0;
    for (&(id, _), &w) in candidates.iter().zip(&weights) {
        cumulative += w / total;
        if r < cumulative {
            return id as u32;
        }
    }

    // Rounding can leave the cumulative sum just below 1.0; that leftover mass
    // belongs to the last candidate.
    candidates[candidates.len() - 1].0 as u32
}
210
/// Index of the largest logit, or `0` for an empty slice.
///
/// Ties resolve to the last maximal element (the behavior of `Iterator::max_by`).
#[cfg(feature = "realizar")]
fn argmax(logits: &[f32]) -> u32 {
    let best = logits
        .iter()
        .enumerate()
        .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    best.map_or(0, |(idx, _)| idx as u32)
}
221
/// Decode token IDs back to text using the vocabulary.
///
/// The vocabulary pieces are concatenated first (ids outside `vocab` are
/// skipped) and the whole string is then passed through `decode_bpe_text`, so
/// byte sequences spanning token boundaries are decoded together.
#[cfg(feature = "realizar")]
fn decode_tokens(vocab: &[String], tokens: &[u32]) -> String {
    let mut joined = String::new();
    for piece in tokens.iter().filter_map(|&id| vocab.get(id as usize)) {
        joined.push_str(piece);
    }
    decode_bpe_text(&joined)
}
230
/// Decode GPT-2-style byte-level BPE text back to UTF-8.
///
/// Byte-level BPE vocabularies (GPT-2, Qwen, ...) spell every raw byte as one
/// Unicode character. The printable bytes `0x21..=0x7E`, `0xA1..=0xAC` and
/// `0xAE..=0xFF` stand for themselves (code points U+0021.., U+00A1.., U+00AE..).
/// The remaining 68 bytes are shifted to U+0100..=U+0143 in ascending byte
/// order, e.g. `Ġ` (U+0120) is a space (0x20) and `Ċ` (U+010A) a newline (0x0A).
///
/// Characters outside that alphabet (e.g. a literal space in a hand-written
/// vocab) are kept as their own UTF-8 encoding. Byte sequences that are not
/// valid UTF-8 (for instance half of a multi-byte character when a single
/// streamed token is decoded) become U+FFFD.
///
/// Limitation: a non-byte-level vocabulary containing literal Latin-1
/// characters (U+00A1..=U+00FF) will have them interpreted as raw bytes.
#[cfg(feature = "realizar")]
fn decode_bpe_text(text: &str) -> String {
    let mut bytes = Vec::with_capacity(text.len());
    for ch in text.chars() {
        match byte_level_char_to_byte(ch) {
            Some(byte) => bytes.push(byte),
            None => {
                let mut buf = [0u8; 4];
                bytes.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
            }
        }
    }
    String::from_utf8_lossy(&bytes).into_owned()
}

/// Inverse of GPT-2's `bytes_to_unicode` table: the raw byte a byte-level BPE
/// character stands for, or `None` if `ch` is not part of that alphabet.
#[cfg(feature = "realizar")]
fn byte_level_char_to_byte(ch: char) -> Option<u8> {
    let cp = ch as u32;
    match cp {
        // Printable bytes map to the identical code point.
        0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF => Some(cp as u8),
        // Non-printable bytes, renumbered from U+0100 in ascending byte order.
        0x100..=0x120 => Some((cp - 0x100) as u8), // bytes 0x00..=0x20
        0x121..=0x142 => Some((cp - 0x121 + 0x7F) as u8), // bytes 0x7F..=0xA0
        0x143 => Some(0xAD),
        _ => None,
    }
}
256
/// Find the EOS token ID in the vocabulary.
///
/// Candidates are tried in a fixed priority order (`</s>`, `<|endoftext|>`,
/// `<|end|>`, `<eos>`, `<|im_end|>`, `<|eot_id|>`). The first candidate present
/// anywhere in `vocab` wins, regardless of its position. Returns `None` when
/// none is present.
#[cfg(feature = "realizar")]
fn find_eos_token(vocab: &[String]) -> Option<u32> {
    const CANDIDATES: [&str; 6] =
        ["</s>", "<|endoftext|>", "<|end|>", "<eos>", "<|im_end|>", "<|eot_id|>"];

    CANDIDATES
        .iter()
        .find_map(|candidate| vocab.iter().position(|tok| tok == candidate))
        .map(|idx| idx as u32)
}
269
/// FNV-1a-style hash over the bit patterns of (at most) the first 64 logits.
///
/// Each 32-bit float pattern is XOR-ed in as a whole word, not byte by byte.
/// Used as a reproducible stand-in for a random number when sampling.
#[cfg(feature = "realizar")]
fn logits_hash(logits: &[f32]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;

    logits
        .iter()
        .take(64)
        .fold(FNV_OFFSET_BASIS, |acc, value| (acc ^ u64::from(value.to_bits())).wrapping_mul(FNV_PRIME))
}
280
/// Encode a text prompt into token IDs using the vocabulary.
///
/// Uses greedy longest-match tokenization over the literal text: at each
/// position the longest vocabulary entry (capped at 32 characters) that matches
/// is emitted. Characters with no match emit the `<unk>` id if the vocabulary
/// has one, else `0`. Because matching is literal, vocabularies that encode
/// spaces specially (e.g. `Ġ` in byte-level BPE, `▁` in SentencePiece) will not
/// match spaces here. For production, the GGUF vocab includes merge rules that
/// `realizar::tokenizer::BPETokenizer` handles, but for Phase 2b this simple
/// approach works for basic generation.
#[cfg(feature = "realizar")]
pub fn encode_prompt(vocab: &[String], text: &str) -> Vec<u32> {
    /// Longest candidate (in chars) tried at each position.
    const MAX_TOKEN_CHARS: usize = 32;

    if text.is_empty() {
        return Vec::new();
    }

    // Build token→id lookup (could be cached on ModelSlot, but keep simple for now).
    let token_to_id: std::collections::HashMap<&str, u32> =
        vocab.iter().enumerate().map(|(i, t)| (t.as_str(), i as u32)).collect();
    let unk_id = token_to_id.get("<unk>").copied().unwrap_or(0);

    // Byte offset of every char boundary (plus the end), so candidates can be
    // sliced straight out of `text` instead of allocating a String per attempt.
    let boundaries: Vec<usize> =
        text.char_indices().map(|(i, _)| i).chain(std::iter::once(text.len())).collect();
    let num_chars = boundaries.len() - 1;

    let mut tokens = Vec::new();
    let mut pos = 0;
    while pos < num_chars {
        let start = boundaries[pos];
        let max_len = (num_chars - pos).min(MAX_TOKEN_CHARS);
        let matched = (1..=max_len).rev().find_map(|len| {
            token_to_id.get(&text[start..boundaries[pos + len]]).map(|&id| (id, len))
        });

        match matched {
            Some((id, len)) => {
                tokens.push(id);
                pos += len;
            }
            None => {
                tokens.push(unk_id);
                pos += 1;
            }
        }
    }

    tokens
}
328
/// Mean-pool token embeddings for pre-tokenized input into one vector.
///
/// Fetches the embedding rows for `token_ids` via `model.embed()`, averages
/// them position-wise, then scales the result to unit L2 length (skipped when
/// the norm is not above `f32::EPSILON`). The output has `hidden_dim` entries.
///
/// Returns `None` when `token_ids` is empty, or when `model.embed()` does not
/// return exactly `token_ids.len() * hidden_dim` values.
///
/// Caller should use `ModelSlot::encode_text()` for proper BPE tokenization.
#[cfg(feature = "realizar")]
pub fn embed_tokens(model: &Arc<OwnedQuantizedModel>, token_ids: &[u32]) -> Option<Vec<f32>> {
    if token_ids.is_empty() {
        return None;
    }

    // Flat, row-major embeddings: [count * hidden_dim].
    let flat = model.embed(token_ids);
    let hidden_dim = model.config().hidden_dim;
    let count = token_ids.len();
    if flat.len() != count * hidden_dim {
        return None;
    }

    // Per-dimension running sum over all token rows, then divide by the count.
    let mut mean = vec![0.0f32; hidden_dim];
    for row in 0..count {
        let base = row * hidden_dim;
        for (dim, acc) in mean.iter_mut().enumerate() {
            *acc += flat[base + dim];
        }
    }
    let inv_count = 1.0 / count as f32;
    mean.iter_mut().for_each(|v| *v *= inv_count);

    // Normalize to unit length unless the vector is (numerically) zero.
    let norm = mean.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > f32::EPSILON {
        mean.iter_mut().for_each(|v| *v /= norm);
    }

    Some(mean)
}
374
375// ============================================================================
// Tests (only compiled with the `realizar` feature enabled)
377// ============================================================================
378
#[cfg(test)]
#[cfg(feature = "realizar")]
mod tests {
    use super::*;

    /// Small hand-written vocabulary with literal (non-`Ġ`) spaces.
    fn test_vocab() -> Vec<String> {
        vec![
            "<unk>".to_string(),
            "</s>".to_string(),
            "Hello".to_string(),
            " world".to_string(),
            "!".to_string(),
            "The".to_string(),
            " answer".to_string(),
            " is".to_string(),
            " 42".to_string(),
        ]
    }

    #[test]
    fn test_inf_001_argmax() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        assert_eq!(argmax(&logits), 3);
    }

    #[test]
    fn test_inf_002_argmax_empty() {
        let logits: Vec<f32> = Vec::new();
        assert_eq!(argmax(&logits), 0);
    }

    #[test]
    fn test_inf_003_decode_tokens() {
        let vocab = test_vocab();
        let tokens = vec![2, 3, 4]; // "Hello", " world", "!"
        assert_eq!(decode_tokens(&vocab, &tokens), "Hello world!");
    }

    #[test]
    fn test_inf_004_decode_unknown_token() {
        let vocab = test_vocab();
        let tokens = vec![2, 999]; // "Hello", out-of-range
        assert_eq!(decode_tokens(&vocab, &tokens), "Hello");
    }

    #[test]
    fn test_inf_005_find_eos_token() {
        let vocab = test_vocab();
        assert_eq!(find_eos_token(&vocab), Some(1)); // "</s>" is at index 1
    }

    #[test]
    fn test_inf_006_find_eos_missing() {
        let vocab = vec!["a".to_string(), "b".to_string()];
        assert_eq!(find_eos_token(&vocab), None);
    }

    #[test]
    fn test_inf_007_sample_greedy() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SamplingParams { temperature: 0.0, top_k: 1, max_tokens: 10 };
        assert_eq!(sample_token(&logits, &params), 3);
    }

    #[test]
    fn test_inf_008_encode_prompt() {
        let vocab = test_vocab();
        // Greedy longest match: "Hello", " world", "!"
        assert_eq!(encode_prompt(&vocab, "Hello world!"), vec![2, 3, 4]);
        assert_eq!(encode_prompt(&vocab, "The answer is 42"), vec![5, 6, 7, 8]);
    }

    #[test]
    fn test_inf_009_encode_empty() {
        let vocab = test_vocab();
        assert!(encode_prompt(&vocab, "").is_empty());
    }

    #[test]
    fn test_inf_010_logits_hash_deterministic() {
        let logits = vec![0.1, 0.2, 0.3];
        let h1 = logits_hash(&logits);
        let h2 = logits_hash(&logits);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_inf_011_sampling_params_default() {
        let params = SamplingParams::default();
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 40);
        assert_eq!(params.max_tokens, 256);
    }

    #[test]
    fn test_inf_012_decode_bpe_byte_markers() {
        // Ġ (U+0120) = space, Ċ (U+010A) = newline
        assert_eq!(decode_bpe_text("Hello\u{120}world\u{10a}"), "Hello world\n");
    }

    #[test]
    fn test_inf_013_find_eos_priority_order() {
        // "<|endoftext|>" outranks "<|im_end|>" even though it appears later.
        let vocab = vec!["<|im_end|>".to_string(), "<|endoftext|>".to_string()];
        assert_eq!(find_eos_token(&vocab), Some(1));
    }

    #[test]
    fn test_inf_014_sample_top_k_one_is_greedy() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SamplingParams { temperature: 1.0, top_k: 1, max_tokens: 10 };
        assert_eq!(sample_token(&logits, &params), 3);
    }

    #[test]
    fn test_inf_015_sample_stays_within_top_k() {
        let logits = vec![0.0, 10.0, 9.0, -5.0];
        let params = SamplingParams { temperature: 1.0, top_k: 2, max_tokens: 10 };
        let token = sample_token(&logits, &params);
        assert!(token == 1 || token == 2, "sampled {token} outside top-2");
    }

    #[test]
    fn test_inf_016_argmax_ties_pick_last() {
        assert_eq!(argmax(&[1.0, 3.0, 3.0]), 2);
    }

    #[test]
    fn test_inf_017_encode_unknown_char() {
        let vocab = test_vocab();
        assert_eq!(encode_prompt(&vocab, "Hello?"), vec![2, 0]);
    }

    #[test]
    fn test_inf_018_logits_hash_uses_first_64() {
        let base = vec![1.0f32; 64];
        let mut longer = base.clone();
        longer.push(2.0);
        assert_eq!(logits_hash(&base), logits_hash(&longer));
    }
}