1#[cfg(feature = "realizar")]
8use std::sync::Arc;
9
10#[cfg(feature = "realizar")]
11use realizar::gguf::{OwnedQuantizedKVCache, OwnedQuantizedModel};
12
13#[cfg(feature = "realizar")]
/// Final output of a blocking generation run (see `generate_sync`).
// Derives added for consistency with `SamplingParams`: public types should be
// `Debug` for diagnostics, and `Clone` lets callers retain results cheaply.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Decoded text of the generated tokens (the prompt is not included).
    pub text: String,
    /// Number of tokens generated (EOS, when sampled, is not counted).
    pub token_count: u32,
    /// `"stop"` when an EOS token was sampled, `"length"` when the
    /// `max_tokens` budget was exhausted.
    pub finish_reason: String,
}
20
21#[cfg(feature = "realizar")]
/// Decoding hyperparameters controlling how the next token is chosen.
#[derive(Debug, Clone)]
pub struct SamplingParams {
    /// Softmax temperature; values `<= 0.0` select greedy argmax decoding
    /// (see `sample_token`).
    pub temperature: f32,
    /// Number of highest-probability candidates kept for sampling; values
    /// `<= 1` select greedy argmax decoding.
    pub top_k: u32,
    /// Maximum number of new tokens to generate beyond the prompt.
    pub max_tokens: u32,
}
29
30#[cfg(feature = "realizar")]
31impl Default for SamplingParams {
32 fn default() -> Self {
33 Self { temperature: 0.7, top_k: 40, max_tokens: 256 }
34 }
35}
36
37#[cfg(feature = "realizar")]
42pub fn generate_sync(
43 model: &Arc<OwnedQuantizedModel>,
44 vocab: &[String],
45 prompt_tokens: &[u32],
46 params: &SamplingParams,
47) -> Result<GenerationResult, String> {
48 if prompt_tokens.is_empty() {
49 return Err("prompt_tokens must not be empty".to_string());
50 }
51
52 let config = model.config();
53 let num_kv_heads = config.num_kv_heads;
54 let head_dim = config.hidden_dim / config.num_heads;
55 let kv_dim = num_kv_heads * head_dim;
56 let max_seq = prompt_tokens.len() + params.max_tokens as usize;
57
58 let mut cache = OwnedQuantizedKVCache::new(config.num_layers, kv_dim, max_seq);
59
60 let mut logits = Vec::new();
62 for (pos, &token) in prompt_tokens.iter().enumerate() {
63 logits = model
64 .forward_single_with_cache(token, &mut cache, pos)
65 .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
66 }
67
68 let mut generated_tokens: Vec<u32> = Vec::new();
70 let mut pos = prompt_tokens.len();
71 let eos_token = find_eos_token(vocab);
72
73 for _ in 0..params.max_tokens {
74 let next_token = sample_token(&logits, params);
75
76 if Some(next_token) == eos_token {
78 return Ok(GenerationResult {
79 text: decode_tokens(vocab, &generated_tokens),
80 token_count: generated_tokens.len() as u32,
81 finish_reason: "stop".to_string(),
82 });
83 }
84
85 generated_tokens.push(next_token);
86
87 logits = model
89 .forward_single_with_cache(next_token, &mut cache, pos)
90 .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
91 pos += 1;
92 }
93
94 Ok(GenerationResult {
95 text: decode_tokens(vocab, &generated_tokens),
96 token_count: generated_tokens.len() as u32,
97 finish_reason: "length".to_string(),
98 })
99}
100
101#[cfg(feature = "realizar")]
106pub fn generate_stream_tokens(
107 model: &Arc<OwnedQuantizedModel>,
108 vocab: &[String],
109 prompt_tokens: &[u32],
110 params: &SamplingParams,
111) -> Result<Vec<StreamToken>, String> {
112 if prompt_tokens.is_empty() {
113 return Err("prompt_tokens must not be empty".to_string());
114 }
115
116 let config = model.config();
117 let num_kv_heads = config.num_kv_heads;
118 let head_dim = config.hidden_dim / config.num_heads;
119 let kv_dim = num_kv_heads * head_dim;
120 let max_seq = prompt_tokens.len() + params.max_tokens as usize;
121
122 let mut cache = OwnedQuantizedKVCache::new(config.num_layers, kv_dim, max_seq);
123
124 let mut logits = Vec::new();
126 for (pos, &token) in prompt_tokens.iter().enumerate() {
127 logits = model
128 .forward_single_with_cache(token, &mut cache, pos)
129 .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
130 }
131
132 let mut tokens = Vec::new();
134 let mut pos = prompt_tokens.len();
135 let eos_token = find_eos_token(vocab);
136
137 for _ in 0..params.max_tokens {
138 let next_token = sample_token(&logits, params);
139
140 if Some(next_token) == eos_token {
141 tokens
142 .push(StreamToken { text: String::new(), finish_reason: Some("stop".to_string()) });
143 return Ok(tokens);
144 }
145
146 let raw = vocab
147 .get(next_token as usize)
148 .cloned()
149 .unwrap_or_else(|| format!("<unk:{next_token}>"));
150 let text = decode_bpe_text(&raw);
151
152 tokens.push(StreamToken { text, finish_reason: None });
153
154 logits = model
155 .forward_single_with_cache(next_token, &mut cache, pos)
156 .map_err(|e| format!("forward error at pos {pos}: {e}"))?;
157 pos += 1;
158 }
159
160 tokens.push(StreamToken { text: String::new(), finish_reason: Some("length".to_string()) });
162
163 Ok(tokens)
164}
165
166#[cfg(feature = "realizar")]
/// One streamed generation step emitted by `generate_stream_tokens`.
// Derives added for consistency with `SamplingParams`: public types should be
// `Debug` for diagnostics, and `Clone` lets callers buffer tokens freely.
#[derive(Debug, Clone)]
pub struct StreamToken {
    /// Decoded text fragment for this token (empty on the terminal marker).
    pub text: String,
    /// `Some("stop")` / `Some("length")` on the final entry, `None` otherwise.
    pub finish_reason: Option<String>,
}
172
173#[cfg(feature = "realizar")]
175fn sample_token(logits: &[f32], params: &SamplingParams) -> u32 {
176 if params.temperature <= 0.0 || params.top_k <= 1 {
177 return argmax(logits);
179 }
180
181 let scaled: Vec<f32> = logits.iter().map(|&l| l / params.temperature).collect();
183
184 let k = (params.top_k as usize).min(scaled.len());
186 let mut indexed: Vec<(usize, f32)> = scaled.iter().enumerate().map(|(i, &v)| (i, v)).collect();
187 indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
188 let top_k = &indexed[..k];
189
190 let max_val = top_k[0].1;
192 let exps: Vec<f32> = top_k.iter().map(|(_, v)| (v - max_val).exp()).collect();
193 let sum: f32 = exps.iter().sum();
194 let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
195
196 let hash = logits_hash(logits);
199 let r = (hash as f32) / (u64::MAX as f32);
200 let mut cumulative = 0.0;
201 for (i, &p) in probs.iter().enumerate() {
202 cumulative += p;
203 if r < cumulative {
204 return top_k[i].0 as u32;
205 }
206 }
207
208 top_k[0].0 as u32
209}
210
211#[cfg(feature = "realizar")]
/// Index of the largest logit, or 0 for an empty slice.
///
/// NaN entries compare as equal to everything (via `partial_cmp` fallback),
/// matching the behavior callers already rely on.
fn argmax(logits: &[f32]) -> u32 {
    let compare = |a: &(usize, &f32), b: &(usize, &f32)| {
        a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal)
    };
    match logits.iter().enumerate().max_by(compare) {
        Some((idx, _)) => idx as u32,
        None => 0,
    }
}
221
222#[cfg(feature = "realizar")]
225fn decode_tokens(vocab: &[String], tokens: &[u32]) -> String {
226 let raw: String =
227 tokens.iter().map(|&id| vocab.get(id as usize).map(String::as_str).unwrap_or("")).collect();
228 decode_bpe_text(&raw)
229}
230
231#[cfg(feature = "realizar")]
/// Undo the byte-level BPE visible encoding used by the vocab.
///
/// Characters in U+0100..=U+01FF encode a raw byte as `0x100 + byte`
/// (e.g. `Ġ` = U+0120 decodes to the space byte 0x20); every other character
/// passes through as its UTF-8 bytes. Invalid byte sequences are replaced
/// with U+FFFD via `from_utf8_lossy`.
// NOTE(review): this is a simplified offset mapping, not the exact GPT-2
// byte-to-unicode table — confirm it matches the vocab's actual encoding.
fn decode_bpe_text(text: &str) -> String {
    let mut bytes = Vec::with_capacity(text.len());
    for ch in text.chars() {
        let cp = ch as u32;
        // BUG FIX: the original had two extra branches (`cp == 0x0100` and
        // `ch == 'Ā'`) that were unreachable — U+0100 is already matched by
        // the range check below. Dead code removed; behavior is unchanged.
        if (0x100..=0x1FF).contains(&cp) {
            // Byte-encoded character: recover the original byte.
            bytes.push((cp - 0x100) as u8);
        } else {
            // Regular character: append its UTF-8 encoding.
            let mut buf = [0u8; 4];
            bytes.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
        }
    }
    String::from_utf8_lossy(&bytes).to_string()
}
256
257#[cfg(feature = "realizar")]
/// Locate an end-of-sequence marker in `vocab`, if one is present.
///
/// Candidates are tried in priority order, so an earlier candidate wins even
/// when several markers exist in the vocab.
fn find_eos_token(vocab: &[String]) -> Option<u32> {
    const EOS_CANDIDATES: [&str; 6] =
        ["</s>", "<|endoftext|>", "<|end|>", "<eos>", "<|im_end|>", "<|eot_id|>"];
    EOS_CANDIDATES
        .iter()
        .find_map(|cand| vocab.iter().position(|tok| tok == cand).map(|p| p as u32))
}
269
270#[cfg(feature = "realizar")]
/// FNV-1a-style hash over the bit patterns of (at most) the first 64 logits.
///
/// Serves as the deterministic pseudo-randomness source for `sample_token`:
/// identical logits always hash identically.
fn logits_hash(logits: &[f32]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;
    logits.iter().take(64).fold(FNV_OFFSET_BASIS, |acc, &l| {
        (acc ^ u64::from(l.to_bits())).wrapping_mul(FNV_PRIME)
    })
}
280
281#[cfg(feature = "realizar")]
/// Greedy longest-match tokenization of `text` against `vocab`.
///
/// At each position the longest vocab entry (capped at 32 chars) matching the
/// remaining input is emitted. A character with no matching entry falls back
/// to id 0, which is assumed to be the unknown token — TODO confirm against
/// the loaded vocab.
pub fn encode_prompt(vocab: &[String], text: &str) -> Vec<u32> {
    if text.is_empty() {
        return Vec::new();
    }

    // Exact-string lookup table: token text -> id.
    let token_to_id: std::collections::HashMap<&str, u32> =
        vocab.iter().enumerate().map(|(i, t)| (t.as_str(), i as u32)).collect();

    let chars: Vec<char> = text.chars().collect();
    let mut tokens = Vec::new();
    let mut start = 0;

    while start < chars.len() {
        // Try the longest candidate first so the match is maximal.
        let limit = (chars.len() - start).min(32);
        let matched = (1..=limit).rev().find_map(|len| {
            let candidate: String = chars[start..start + len].iter().collect();
            token_to_id.get(candidate.as_str()).map(|&id| (id, len))
        });

        match matched {
            Some((id, len)) => {
                tokens.push(id);
                start += len;
            }
            None => {
                // No vocab entry covers this character: emit id 0, advance one.
                tokens.push(0);
                start += 1;
            }
        }
    }

    tokens
}
328
329#[cfg(feature = "realizar")]
337pub fn embed_tokens(model: &Arc<OwnedQuantizedModel>, token_ids: &[u32]) -> Option<Vec<f32>> {
338 if token_ids.is_empty() {
339 return None;
340 }
341
342 let raw = model.embed(token_ids);
344 let hidden_dim = model.config().hidden_dim;
345 let num_tokens = token_ids.len();
346
347 if raw.len() != num_tokens * hidden_dim {
348 return None;
349 }
350
351 let mut pooled = vec![0.0f32; hidden_dim];
353 for t in 0..num_tokens {
354 let offset = t * hidden_dim;
355 for d in 0..hidden_dim {
356 pooled[d] += raw[offset + d];
357 }
358 }
359 let scale = 1.0 / num_tokens as f32;
360 for val in &mut pooled {
361 *val *= scale;
362 }
363
364 let norm: f32 = pooled.iter().map(|v| v * v).sum::<f32>().sqrt();
366 if norm > f32::EPSILON {
367 for val in &mut pooled {
368 *val /= norm;
369 }
370 }
371
372 Some(pooled)
373}
374
375#[cfg(test)]
380#[cfg(feature = "realizar")]
mod tests {
    use super::*;

    /// Small vocab fixture: id 0 = unknown, id 1 = EOS, rest are word pieces.
    fn test_vocab() -> Vec<String> {
        vec![
            "<unk>".to_string(),
            "</s>".to_string(),
            "Hello".to_string(),
            " world".to_string(),
            "!".to_string(),
            "The".to_string(),
            " answer".to_string(),
            " is".to_string(),
            " 42".to_string(),
        ]
    }

    #[test]
    fn test_inf_001_argmax() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        assert_eq!(argmax(&logits), 3);
    }

    #[test]
    fn test_inf_002_argmax_empty() {
        let logits: Vec<f32> = Vec::new();
        assert_eq!(argmax(&logits), 0);
    }

    #[test]
    fn test_inf_003_decode_tokens() {
        let vocab = test_vocab();
        let tokens = vec![2, 3, 4];
        assert_eq!(decode_tokens(&vocab, &tokens), "Hello world!");
    }

    #[test]
    fn test_inf_004_decode_unknown_token() {
        let vocab = test_vocab();
        // Out-of-range ids decode to the empty string.
        let tokens = vec![2, 999];
        assert_eq!(decode_tokens(&vocab, &tokens), "Hello");
    }

    #[test]
    fn test_inf_005_find_eos_token() {
        let vocab = test_vocab();
        assert_eq!(find_eos_token(&vocab), Some(1));
    }

    #[test]
    fn test_inf_006_find_eos_missing() {
        let vocab = vec!["a".to_string(), "b".to_string()];
        assert_eq!(find_eos_token(&vocab), None);
    }

    #[test]
    fn test_inf_007_sample_greedy() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SamplingParams { temperature: 0.0, top_k: 1, max_tokens: 10 };
        // BUG FIX: `&params` had been corrupted to the mojibake `¶ms`
        // (an `&para;` HTML-entity artifact), which does not compile.
        assert_eq!(sample_token(&logits, &params), 3);
    }

    #[test]
    fn test_inf_008_encode_prompt() {
        let vocab = test_vocab();
        let tokens = encode_prompt(&vocab, "Hello world!");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_inf_009_encode_empty() {
        let vocab = test_vocab();
        assert!(encode_prompt(&vocab, "").is_empty());
    }

    #[test]
    fn test_inf_010_logits_hash_deterministic() {
        let logits = vec![0.1, 0.2, 0.3];
        let h1 = logits_hash(&logits);
        let h2 = logits_hash(&logits);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_inf_011_sampling_params_default() {
        let params = SamplingParams::default();
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 40);
        assert_eq!(params.max_tokens, 256);
    }
}