pub struct GemmaGenerator { /* private fields */ }
Stateful Gemma generation handle.
Holds the config, the loaded weights, and the token history, and
rebuilds a prefill graph on each step call. Construction is cheap
once the weights have been loaded, and tokens stay in memory between calls.
Implementations
impl GemmaGenerator
pub fn from_loader(
    cfg: GemmaConfig,
    loader: &mut dyn WeightLoader,
    device: Device,
) -> Result<GemmaGenerator, Error>
Construct from any WeightLoader. The loader is drained into an
internal cache, so it is no longer needed after this call.
pub fn from_loader_at(
    cfg: GemmaConfig,
    loader: &mut dyn WeightLoader,
    device: Device,
    weights_path: &Path,
) -> Result<GemmaGenerator, Error>
Like Self::from_loader, but also loads the tier-1 compile profiles from
gemma.rlx.toml in the weights directory when that file is present.
pub fn with_compile_profiles(
    self,
    prefill: CompileProfile,
    decode: CompileProfile,
) -> GemmaGenerator
Override tier-1 compile profiles explicitly.
pub fn prefill_profile(&self) -> &CompileProfile
pub fn decode_profile(&self) -> &CompileProfile
pub fn with_prefill_cache(self, capacity: usize) -> GemmaGenerator
Enable the prefill compile cache with the given LRU capacity. This is useful when the same prompt length recurs across multiple generation runs: the second and later runs skip the compile and parameter-attach round trip (roughly 30-50 ms per call on CPU).
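To make the cache behaviour concrete, here is a minimal std-only sketch of an LRU compile cache keyed by prompt length. CompiledGraph and PrefillCache are hypothetical stand-ins, not this crate's types, and the real cache may key on more than the length. The sketch only counts how often the compile cost would be paid.

```rust
use std::collections::{HashMap, VecDeque};

/// Hypothetical stand-in for a compiled prefill graph.
#[derive(Clone, Debug, PartialEq)]
pub struct CompiledGraph {
    pub seq_len: usize,
}

/// Minimal LRU cache keyed by prompt length (illustrative sketch only).
pub struct PrefillCache {
    capacity: usize,
    graphs: HashMap<usize, CompiledGraph>,
    order: VecDeque<usize>, // front = least recently used key
    pub compiles: usize,    // number of times the compile cost was paid
}

impl PrefillCache {
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self { capacity, graphs: HashMap::new(), order: VecDeque::new(), compiles: 0 }
    }

    pub fn get_or_compile(&mut self, seq_len: usize) -> &CompiledGraph {
        if self.graphs.contains_key(&seq_len) {
            // Hit: drop the old position; it is re-pushed as most recent below.
            self.order.retain(|&k| k != seq_len);
        } else {
            // Miss: evict the least recently used graph if full, then "compile".
            if self.graphs.len() == self.capacity {
                if let Some(lru) = self.order.pop_front() {
                    self.graphs.remove(&lru);
                }
            }
            self.compiles += 1;
            self.graphs.insert(seq_len, CompiledGraph { seq_len });
        }
        self.order.push_back(seq_len);
        &self.graphs[&seq_len]
    }
}
```

With capacity 2, alternating among three prompt lengths keeps evicting, so the capacity should cover the set of prompt lengths you expect to reuse.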
pub fn with_dynamic_prefill_cache(self, capacity: usize) -> GemmaGenerator
Compile the prefill graph once with the symbolic sequence length sym::SEQ, then specialize it per prompt length.
pub fn with_decode_cache(self, max_past: usize) -> GemmaGenerator
Enable the bucketed decode compile cache for past_seq values
in [1, max_past]. Buckets grow in powers of two:
[1..2, 2..3, 3..5, 5..9, 9..17, …]. Each bucket compiles
one graph at its upper bound, so a steady-state generation loop
over N tokens compiles O(log N) graphs instead of N.
Padding waste is bounded at 2×: the actual past_seq is
at least half the bucket's upper bound (except possibly in the
smallest bucket).
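The bucket rule can be sketched as follows, assuming a bucket's upper bound is the largest past_seq it holds (so each past_seq maps to its next power of two) and, as an added assumption, that the last bucket is clamped to max_past. decode_bucket and graphs_compiled are illustrative names, not part of this API.

```rust
/// Upper bound of the power-of-two decode bucket containing `past_seq`.
/// Buckets [1..2, 2..3, 3..5, 5..9, ...] have inclusive upper bounds
/// 1, 2, 4, 8, ..., i.e. the next power of two >= past_seq.
/// Returns None outside [1, max_past]. Clamping to max_past is an assumption.
pub fn decode_bucket(past_seq: usize, max_past: usize) -> Option<usize> {
    if past_seq == 0 || past_seq > max_past {
        return None;
    }
    Some(past_seq.next_power_of_two().min(max_past))
}

/// Number of distinct decode graphs compiled while past_seq runs over 1..=n.
pub fn graphs_compiled(n: usize, max_past: usize) -> usize {
    let mut buckets: Vec<usize> = (1..=n)
        .filter_map(|p| decode_bucket(p, max_past))
        .collect();
    // Bucket bounds never decrease as p grows, so dedup removes all repeats.
    buckets.dedup();
    buckets.len()
}
```

For example, 1000 decode steps with max_past = 1024 touch 11 buckets (bounds 1, 2, 4, up to 1024) instead of compiling 1000 graphs, and every bucket bound is at most twice the past_seq it serves.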
pub fn with_dynamic_decode_cache(self, capacity: usize) -> GemmaGenerator
Compile the decode graph once with the symbolic past length sym::PAST_SEQ, then specialize it per prefix length.
pub fn with_inference_caches(self, max_seq: usize) -> GemmaGenerator
Enable the production inference caches: dynamic prefill (plus the multimodal
embed prefill) and bucketed decode. Bucketed decode avoids the per-step
specialization overhead of RLX_GEMMA_DYNAMIC_DECODE=1 (experimental; often much slower on Metal).
pub fn sync_device(&mut self)
Wait for in-flight Metal command buffers on all cached graphs. Call between heavy inference phases to avoid MPS lifecycle warnings.
pub fn from_path(
    cfg: GemmaConfig,
    path: &str,
    device: Device,
) -> Result<GemmaGenerator, Error>
Convenience constructor: load weights from a safetensors or GGUF path,
choosing the format by file extension (see rlx_core::weight_loader::load_from_path).
pub fn from_path_with_mtp(
    cfg: GemmaConfig,
    path: &str,
    device: Device,
    include_mtp: bool,
) -> Result<GemmaGenerator, Error>
Same as from_path, but with control over MTP-head visibility.
When include_mtp=true and the file is GGUF, the MTP weights are
drained into the generator's cache alongside the base
weights. The base inference path still ignores them; they
stay in the cache for a future MTP-aware decoder. Non-GGUF formats
silently ignore the flag (safetensors files publish all
tensors uniformly; downstream code distinguishes them by name).
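A rough sketch of extension-based dispatch and of when include_mtp takes effect. WeightFormat, weight_format, and loads_mtp are hypothetical helpers; the actual dispatch lives in rlx_core::weight_loader::load_from_path and may, for instance, treat letter case or other extensions differently.

```rust
use std::path::Path;

/// Hypothetical format tag for this sketch.
#[derive(Debug, PartialEq)]
pub enum WeightFormat {
    Safetensors,
    Gguf,
}

/// Pick a format from the file extension. An exact, case-sensitive match
/// is an assumption of this sketch.
pub fn weight_format(path: &str) -> Option<WeightFormat> {
    match Path::new(path).extension()?.to_str()? {
        "safetensors" => Some(WeightFormat::Safetensors),
        "gguf" => Some(WeightFormat::Gguf),
        _ => None,
    }
}

/// include_mtp only has an effect for GGUF; other formats ignore it.
pub fn loads_mtp(path: &str, include_mtp: bool) -> bool {
    include_mtp && weight_format(path) == Some(WeightFormat::Gguf)
}
```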
pub fn prefill(&mut self, prompt_ids: &[u32])
Replace the token history with prompt_ids and clear any KV cache
from a prior generation. This does not run the model; the next
step call processes the full sequence.
pub fn prefill_from_embeds(
    &mut self,
    prompt_ids: &[u32],
    embeds: &[f32],
    attn_bias: Option<Vec<f32>>,
) -> Result<(), Error>
Like prefill, but the next cached prefill consumes the fused
inputs_embeds (prefill_hidden) instead of a token-embedding lookup.
pub fn weights_cache(&self) -> &HashMap<String, (Vec<f32>, Vec<usize>)>
Weight table used by CPU-side embedding and fusion helpers.
pub fn step(&mut self, opts: SampleOpts) -> Result<u32, Error>
Run one prefill over the current token history and sample the next token. The sampled token is appended to the history and returned. Call repeatedly to generate.
pub fn generate(
    &mut self,
    n: usize,
    opts: SampleOpts,
) -> Result<Vec<u32>, Error>
Run n steps and return the newly generated token ids
(excludes the prefill prompt).
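The history bookkeeping of prefill, step, and generate can be modelled with a toy type. ToyGenerator and its next closure are hypothetical stand-ins for the model and sampler; only the token-history semantics are meant to match the descriptions above (SampleOpts and errors are omitted).

```rust
/// Hypothetical toy: `next` plays the role of model + sampler and sees the
/// full history on every call, like the uncached `step`.
pub struct ToyGenerator<F: FnMut(&[u32]) -> u32> {
    tokens: Vec<u32>,
    next: F,
}

impl<F: FnMut(&[u32]) -> u32> ToyGenerator<F> {
    pub fn new(next: F) -> Self {
        Self { tokens: Vec::new(), next }
    }

    /// Replace the history; no model call happens here.
    pub fn prefill(&mut self, prompt_ids: &[u32]) {
        self.tokens = prompt_ids.to_vec();
    }

    /// Look at the whole history, append the sampled id, and return it.
    pub fn step(&mut self) -> u32 {
        let id = (self.next)(&self.tokens[..]);
        self.tokens.push(id);
        id
    }

    /// Run n steps and return only the new ids; the prompt is excluded.
    pub fn generate(&mut self, n: usize) -> Vec<u32> {
        (0..n).map(|_| self.step()).collect()
    }

    pub fn tokens(&self) -> &[u32] {
        &self.tokens
    }
}
```

With a "model" that returns the current history length, prefilling [7, 8, 9] and generating two tokens returns [3, 4], while the stored history becomes [7, 8, 9, 3, 4].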
pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32, Error>
Cached step: O(L) per token instead of O(L²). The first call seeds
the KV cache from the prompt via prefill-with-cache; subsequent
calls run the decode-mode graph on just the last token plus the cached
past. Output matches step up to floating-point differences caused by
reduction order in the SDPA kernel.
Invariant after each call: cache.past_seq == tokens.len() - 1
(the just-sampled token is appended to the history but is not yet in
the cache; it becomes the input for the next decode step).
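A counters-only sketch of the step_cached invariant. Here past_seq stands in for the KV cache length and no attention is computed; ToyCachedGen is hypothetical and takes the sampled id as an argument in place of a real sampler.

```rust
/// Counters-only toy of the cached path (hypothetical type).
pub struct ToyCachedGen {
    pub tokens: Vec<u32>,
    pub past_seq: usize,
    seeded: bool,
}

impl ToyCachedGen {
    /// Like `prefill`: set the history and clear any cache.
    pub fn prefill(prompt_ids: &[u32]) -> Self {
        Self { tokens: prompt_ids.to_vec(), past_seq: 0, seeded: false }
    }

    /// Like `step_cached`, with the sampled id supplied by the caller.
    pub fn step_cached(&mut self, sampled: u32) -> u32 {
        if !self.seeded {
            // First call: prefill-with-cache puts the whole prompt in the cache.
            self.past_seq = self.tokens.len();
            self.seeded = true;
        } else {
            // Later calls: decode the last token, which then joins the cache.
            self.past_seq += 1;
        }
        // The new token is appended but only cached on the next call.
        self.tokens.push(sampled);
        sampled
    }
}
```

After every call the cache is exactly one token behind the history, which is why each decode step only has to process a single new position.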
pub fn generate_from_embeds(
    &mut self,
    prompt_ids: &[u32],
    embeds: &[f32],
    n: usize,
    opts: SampleOpts,
) -> Result<Vec<u32>, Error>
Run n cached steps after prefill_from_embeds.
pub fn generate_from_embeds_with_bias( &mut self, prompt_ids: &[u32], embeds: &[f32], attn_bias: Option<Vec<f32>>, n: usize, opts: SampleOpts, ) -> Result<Vec<u32>, Error>
pub fn generate_from_embeds_with(
    &mut self,
    prompt_ids: &[u32],
    embeds: &[f32],
    n: usize,
    opts: SampleOpts,
    on_token: impl FnMut(u32),
) -> Result<Vec<u32>, Error>
Streaming variant of generate_from_embeds.
pub fn generate_from_embeds_with_bias_and_callback( &mut self, prompt_ids: &[u32], embeds: &[f32], attn_bias: Option<Vec<f32>>, n: usize, opts: SampleOpts, on_token: impl FnMut(u32), ) -> Result<Vec<u32>, Error>
pub fn generate_cached(
    &mut self,
    n: usize,
    opts: SampleOpts,
) -> Result<Vec<u32>, Error>
Run n cached steps and return the newly generated tokens.
pub fn generate_cached_with(
    &mut self,
    n: usize,
    opts: SampleOpts,
    on_token: impl FnMut(u32),
) -> Result<Vec<u32>, Error>
Same as generate_cached, but invokes on_token once per
freshly sampled id, inside the decode loop. The whole n-step
loop shares the bucketed compile cache, so callers wanting a
streaming UI should prefer this to calling
generate_cached(1, …) n times (which forces a fresh
compile per token at the bucket boundaries).
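The streaming pattern amounts to a single loop that hands each id to the callback as soon as it is sampled and also collects it for the return value. generate_with and sample_next are hypothetical names for this sketch, which omits SampleOpts, the compile cache, and error handling.

```rust
/// Sketch of the streaming pattern: one loop, one callback per sampled id,
/// and the same ids returned at the end.
pub fn generate_with<S, C>(n: usize, mut sample_next: S, mut on_token: C) -> Vec<u32>
where
    S: FnMut(&[u32]) -> u32, // hypothetical model + sampler
    C: FnMut(u32),           // caller's per-token hook, e.g. a UI update
{
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        let id = sample_next(out.as_slice());
        on_token(id); // fires inside the loop, before the next step runs
        out.push(id);
    }
    out
}
```

Because the callback runs inside one call, any per-call setup (such as the bucketed compile cache lookups) is shared across all n steps rather than repeated per token.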
pub fn config(&self) -> &GemmaConfig
pub fn prefill_get_last_logits(
    &mut self,
    context: &[u32],
) -> Result<Vec<f32>, Error>
Low-level primitive: reset internal state, run prefill-with-cache
over context, and return the logits row for the last position,
i.e. P(next_token | context). It does not sample or append a
token. Afterwards the internal tokens buffer equals context and
the KV cache is populated to past_seq = context.len().
pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>, Error>
Low-level primitive: run one decode step with the caller-supplied
input token (no sampling), advance the KV cache, and return the
resulting logits row P(next | history ++ input).
The input token is appended to the tokens buffer, so
cache.past_seq == tokens.len() holds after this call. This
differs from the step_cached invariant because this method does
not append a sampled token.
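These two primitives are enough to score a fixed continuation, for example in a log-likelihood evaluation. The sketch assumes a hypothetical LogitsSource trait whose methods mirror prefill_get_last_logits and decode_get_logits but return Vec<f32> directly instead of Result; log_prob is a standard numerically stable log-softmax.

```rust
/// Hypothetical trait mirroring the two primitives (errors omitted).
pub trait LogitsSource {
    fn prefill_get_last_logits(&mut self, context: &[u32]) -> Vec<f32>;
    fn decode_get_logits(&mut self, input: u32) -> Vec<f32>;
}

/// Numerically stable log-softmax evaluated at one index.
pub fn log_prob(logits: &[f32], index: usize) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&x| (x - max).exp()).sum();
    logits[index] - (max + sum_exp.ln())
}

/// Sum of log P(continuation[i] | context ++ continuation[..i]).
pub fn score_continuation<M: LogitsSource>(
    m: &mut M,
    context: &[u32],
    continuation: &[u32],
) -> f32 {
    // Cache now holds the whole context; logits predict the first token.
    let mut logits = m.prefill_get_last_logits(context);
    let mut total = 0.0;
    for (i, &tok) in continuation.iter().enumerate() {
        total += log_prob(&logits, tok as usize);
        // Feed the scored token back so the next row conditions on it;
        // no decode is needed after the final token.
        if i + 1 < continuation.len() {
            logits = m.decode_get_logits(tok);
        }
    }
    total
}
```

Scoring k continuation tokens therefore costs one prefill plus k - 1 decode steps.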
Trait Implementations
impl Drop for GemmaGenerator
Auto Trait Implementations
impl !RefUnwindSafe for GemmaGenerator
impl !Sync for GemmaGenerator
impl !UnwindSafe for GemmaGenerator
impl Freeze for GemmaGenerator
impl Send for GemmaGenerator
impl Unpin for GemmaGenerator
impl UnsafeUnpin for GemmaGenerator
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.