//! Greedy (argmax) decoding for the local Gemma 3 270M backend.
//!
//! `v1` deliberately keeps this as simple as the architecture allows: no KV
//! cache, no sampling, no beam search. Each step re-runs the FULL forward pass
//! over the running token sequence, reads the logits for the final position,
//! takes the `argmax` token, appends it, and repeats — until the model emits
//! the Gemma EOS token (`1`) or `max_new` tokens have been generated. The
//! prompt is tokenized with a leading BOS (`2`) by [`GemmaTokenizer::encode`].
//!
//! Recomputing the whole sequence every step processes O(n²) token positions
//! over n generated tokens and is slower than a cached decode, but it is
//! correct, allocation-light, and, critically, identical on native and
//! `wasm32`, which is all `v1` needs. A KV cache is a later optimisation that
//! does not change this public surface.
//!
//! Compiles on native and `wasm32-unknown-unknown`. The whole module is gated
//! on `feature = "local"` (see `super`).
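//!
//! The loop described above can be sketched without Burn at all. This is a
//! minimal, illustrative sketch and not this module's implementation:
//! `argmax`, `greedy_decode`, and `toy_forward` are hypothetical names, and a
//! plain function returning final-position logits stands in (as an assumption)
//! for the `GemmaModel` forward pass.
//!
//! ```rust
//! /// Index of the largest logit; ties resolve to the lowest index.
//! fn argmax(logits: &[f32]) -> usize {
//!     let mut best = 0;
//!     for (i, &v) in logits.iter().enumerate() {
//!         if v > logits[best] {
//!             best = i;
//!         }
//!     }
//!     best
//! }
//!
//! /// Greedy decode: re-run `forward` over the WHOLE sequence each step
//! /// (no cache), take the argmax, stop at `eos` or after `max_new` tokens.
//! /// Returns only the newly generated tokens, without the trailing EOS.
//! fn greedy_decode<F>(prompt: &[usize], max_new: usize, eos: usize, forward: F) -> Vec<usize>
//! where
//!     F: Fn(&[usize]) -> Vec<f32>,
//! {
//!     let mut tokens = prompt.to_vec();
//!     let mut generated = Vec::new();
//!     for _ in 0..max_new {
//!         let next = argmax(&forward(&tokens));
//!         if next == eos {
//!             break;
//!         }
//!         tokens.push(next);
//!         generated.push(next);
//!     }
//!     generated
//! }
//!
//! /// Hypothetical stand-in model over a 5-token vocabulary: it always
//! /// predicts `(last + 1) % 5`. Expects a non-empty sequence.
//! fn toy_forward(seq: &[usize]) -> Vec<f32> {
//!     let mut logits = vec![0.0f32; 5];
//!     let last = *seq.last().expect("sequence must not be empty");
//!     logits[(last + 1) % 5] = 1.0;
//!     logits
//! }
//! ```
//!
//! With this toy model, `greedy_decode(&[2], 16, 1, toy_forward)` starts from
//! a BOS-only prompt, generates `3, 4, 0`, and stops when the next prediction
//! is the EOS id `1`, so the returned continuation is `[3, 4, 0]`.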
use ;
use GemmaModel;
use GemmaTokenizer;
/// Gemma end-of-sequence token id. Generation stops as soon as the model emits
/// this (it is NOT included in the decoded output).
const EOS_ID: i64 = 1;
/// Greedy argmax decode.
///
/// Tokenizes `prompt` (the tokenizer prepends BOS), then autoregressively
/// appends the highest-probability next token — recomputing the full forward
/// pass each step (no KV cache in v1) — stopping at the first EOS (`1`) or after
/// `max_new` new tokens. Returns the decoded continuation (the prompt tokens are
/// not re-emitted; a trailing EOS is dropped).
///
/// `device` is the Burn device the running token tensor is built on; it must
/// match the device `model` lives on.
pub async