use burn::tensor::{backend::Backend, Int, Tensor, TensorData};
use super::gemma::{GemmaModel, ROPE_CACHE_LEN};
use super::tokenizer::GemmaTokenizer;
/// Token id that ends generation: `generate` stops as soon as the model selects it and does
/// not include it in the decoded output.
// NOTE(review): presumably the tokenizer's end-of-sequence id — confirm against the vocabulary.
const EOS_ID: i64 = 1;
/// Reports whether a sequence of `seq` tokens has used up the available context.
///
/// Returns `true` when `seq` is equal to or larger than `ROPE_CACHE_LEN`, and `false` for any
/// shorter sequence.
fn at_context_limit(seq: usize) -> bool {
    // Stated from the limit's point of view: the limit is hit once it no longer exceeds `seq`.
    ROPE_CACHE_LEN <= seq
}
pub async fn generate<B: Backend>(
model: &GemmaModel<B>,
tok: &GemmaTokenizer,
prompt: &str,
max_new: usize,
device: &B::Device,
) -> String {
let mut tokens: Vec<i64> = tok.encode(prompt);
let mut generated: Vec<i64> = Vec::with_capacity(max_new);
for _ in 0..max_new {
if at_context_limit(tokens.len()) {
break;
}
let seq = tokens.len();
let input = Tensor::<B, 1, Int>::from_data(TensorData::from(tokens.as_slice()), device)
.reshape([1, seq]);
let logits = model.forward(input);
let argmax = logits.argmax(2);
let data = match argmax.into_data_async().await {
Ok(d) => d,
Err(_) => break,
};
let ids: Vec<i64> = data.iter::<i64>().collect();
let next = match ids.last().copied() {
Some(id) => id,
None => break, };
if next == EOS_ID {
break;
}
tokens.push(next);
generated.push(next);
}
tok.decode(&generated)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The predicate must be false for every length below `ROPE_CACHE_LEN` and true from
    /// `ROPE_CACHE_LEN` upwards; the lengths on either side of the boundary are checked.
    #[test]
    fn context_limit_guards_the_rope_cache_boundary() {
        // Lengths that still leave room in the context.
        let under_limit = [0, ROPE_CACHE_LEN - 1];
        // Lengths that have reached or passed the limit.
        let at_or_over = [ROPE_CACHE_LEN, ROPE_CACHE_LEN + 1];

        assert!(under_limit.iter().all(|&seq| !at_context_limit(seq)));
        assert!(at_or_over.iter().all(|&seq| at_context_limit(seq)));
    }
}