Struct GemmaGenerator 

pub struct GemmaGenerator { /* private fields */ }

Stateful Gemma generation handle.

Holds the model config, the weight bytes, and the token history, and rebuilds a prefill graph on each Self::step call. Cheap to construct after the initial weight load; tokens stay in memory between calls.

Implementations

impl GemmaGenerator

pub fn from_loader( cfg: GemmaConfig, loader: &mut dyn WeightLoader, device: Device, ) -> Result<GemmaGenerator, Error>

Construct from any WeightLoader. The loader is drained into an internal cache, so it is no longer needed after this call.

pub fn from_loader_at( cfg: GemmaConfig, loader: &mut dyn WeightLoader, device: Device, weights_path: &Path, ) -> Result<GemmaGenerator, Error>

Like Self::from_loader, but also loads tier-1 compile profiles from gemma.rlx.toml in the weights directory when that file is present.

pub fn with_compile_profiles( self, prefill: CompileProfile, decode: CompileProfile, ) -> GemmaGenerator

Override tier-1 compile profiles explicitly.

pub fn prefill_profile(&self) -> &CompileProfile

pub fn decode_profile(&self) -> &CompileProfile

pub fn with_prefill_cache(self, capacity: usize) -> GemmaGenerator

Enable the prefill compile cache with the given LRU capacity. This pays off when the same prompt length recurs across generation runs: the second and later runs skip the compile and param-attach round trip (roughly 30–50 ms per call on CPU).
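A minimal sketch of a cache like this, assuming entries are keyed by exact prompt length and that a compiled graph can be reused for any prompt of that length. CompiledGraph, PrefillCache, and get_or_compile are hypothetical names for illustration, not this crate's types; LRU order is tracked with a VecDeque for clarity rather than speed.

```rust
use std::collections::{HashMap, VecDeque};

/// Hypothetical stand-in for a compiled prefill graph with parameters attached.
#[derive(Debug, Clone, PartialEq)]
pub struct CompiledGraph {
    pub seq_len: usize,
}

/// Minimal LRU cache keyed by prompt length (a sketch, not this crate's code).
pub struct PrefillCache {
    capacity: usize,
    graphs: HashMap<usize, CompiledGraph>,
    /// Keys ordered from least (front) to most (back) recently used.
    order: VecDeque<usize>,
    /// Number of misses, i.e. compile + param-attach round trips paid.
    pub compiles: usize,
}

impl PrefillCache {
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be non-zero");
        Self { capacity, graphs: HashMap::new(), order: VecDeque::new(), compiles: 0 }
    }

    /// Return the graph for `seq_len`, compiling it only on a cache miss.
    pub fn get_or_compile<F>(&mut self, seq_len: usize, compile: F) -> &CompiledGraph
    where
        F: FnOnce(usize) -> CompiledGraph,
    {
        if self.graphs.contains_key(&seq_len) {
            // Hit: drop the old position; it is re-added as most recent below.
            self.order.retain(|&k| k != seq_len);
        } else {
            // Miss: evict the least recently used entry when full, then compile.
            if self.graphs.len() == self.capacity {
                if let Some(lru) = self.order.pop_front() {
                    self.graphs.remove(&lru);
                }
            }
            self.graphs.insert(seq_len, compile(seq_len));
            self.compiles += 1;
        }
        self.order.push_back(seq_len);
        &self.graphs[&seq_len]
    }
}
```

With capacity 2, requesting a third distinct length evicts the least recently used one, so the next request for the evicted length pays the compile cost again.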

pub fn with_dynamic_prefill_cache(self, capacity: usize) -> GemmaGenerator

Compile the prefill graph once with the symbolic sequence dimension sym::SEQ, then specialize it per prompt length.

pub fn with_decode_cache(self, max_past: usize) -> GemmaGenerator

Enable the bucketed decode compile cache spanning past_seq values in [1, max_past]. Buckets grow geometrically (widths 1, 1, 2, 4, 8, …): [1..2, 2..3, 3..5, 5..9, 9..17, …]. Each bucket compiles one graph at its upper bound, so a steady-state generation loop over N tokens compiles O(log N) graphs instead of N.

Compute wasted on padding is bounded at 2×: the actual past_seq is at least half the bucket's upper bound (except possibly in the smallest bucket).
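The bucket arithmetic can be sketched as follows, assuming the listed ranges are half-open (as in Rust range syntax) and that a bucket's graph is compiled for its largest member, hi - 1. decode_bucket and compiled_past_len are hypothetical helpers, and the clamp to max_past is an assumption rather than documented behavior.

```rust
/// Hypothetical: map past_seq to its half-open bucket [lo, hi), following the
/// sequence [1..2, 2..3, 3..5, 5..9, 9..17, ...].
pub fn decode_bucket(past_seq: usize) -> (usize, usize) {
    assert!(past_seq >= 1, "past_seq must be at least 1");
    if past_seq == 1 {
        return (1, 2);
    }
    // For m = past_seq - 1 >= 1, k = floor(log2(m)) + 1, so 2^(k-1) <= m < 2^k,
    // i.e. past_seq lies in [2^(k-1) + 1, 2^k + 1).
    let m = past_seq - 1;
    let k = usize::BITS - m.leading_zeros();
    ((1usize << (k - 1)) + 1, (1usize << k) + 1)
}

/// Hypothetical: length the decode graph is compiled for, assumed to be the
/// bucket's largest member (hi - 1), clamped to max_past (also an assumption).
pub fn compiled_past_len(past_seq: usize, max_past: usize) -> usize {
    assert!(past_seq <= max_past, "past_seq exceeds max_past");
    let (_, hi) = decode_bucket(past_seq);
    (hi - 1).min(max_past)
}
```

Under these assumptions the compiled length before clamping equals past_seq.next_power_of_two(), and because every bucket starts just above a power of two, past_seq always exceeds half the compiled length.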

pub fn with_dynamic_decode_cache(self, capacity: usize) -> GemmaGenerator

Compile the decode graph once with the symbolic dimension sym::PAST_SEQ, then specialize it per prefix length.

pub fn with_inference_caches(self, max_seq: usize) -> GemmaGenerator

Enable the production inference caches: dynamic prefill (plus multimodal embed prefill) and bucketed decode. Bucketed decode avoids the per-step specialization overhead of the dynamic decode path enabled by RLX_GEMMA_DYNAMIC_DECODE=1, which is experimental and often much slower on Metal.

pub fn sync_device(&mut self)

Wait for in-flight Metal command buffers on all cached graphs. Call between heavy inference phases to avoid MPS lifecycle warnings.

pub fn from_path( cfg: GemmaConfig, path: &str, device: Device, ) -> Result<GemmaGenerator, Error>

Convenience constructor: load weights from a safetensors or GGUF file, choosing the loader by file extension (see rlx_core::weight_loader::load_from_path).
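For illustration only, extension-based dispatch could look like the sketch below. WeightFormat and detect_format are hypothetical and are not the API of rlx_core::weight_loader::load_from_path; the case-insensitive match is also an assumption.

```rust
use std::path::Path;

/// Hypothetical format tag; not a type from rlx_core.
#[derive(Debug, PartialEq, Eq)]
pub enum WeightFormat {
    Safetensors,
    Gguf,
}

/// Choose a format from the file extension (case-insensitive by assumption).
/// Returns None for a missing or unrecognized extension.
pub fn detect_format(path: &str) -> Option<WeightFormat> {
    let ext = Path::new(path).extension()?.to_str()?.to_ascii_lowercase();
    match ext.as_str() {
        "safetensors" => Some(WeightFormat::Safetensors),
        "gguf" => Some(WeightFormat::Gguf),
        _ => None,
    }
}
```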

pub fn from_path_with_mtp( cfg: GemmaConfig, path: &str, device: Device, include_mtp: bool, ) -> Result<GemmaGenerator, Error>

Same as Self::from_path, but controls whether MTP (multi-token prediction) head weights are loaded. When include_mtp is true and the file is GGUF, the MTP weights are drained into the generator's cache alongside the base weights. The base inference path still ignores them; they sit in the cache for a future MTP-aware decoder. Non-GGUF formats silently ignore the flag: safetensors files publish all tensors uniformly, and downstream code distinguishes them by name.

pub fn prefill(&mut self, prompt_ids: &[u32])

Replace the token history with prompt_ids. This does not run the model; the next Self::step call processes the full sequence. Clears any KV cache left from a prior generation.

pub fn prefill_from_embeds( &mut self, prompt_ids: &[u32], embeds: &[f32], attn_bias: Option<Vec<f32>>, ) -> Result<(), Error>

Like Self::prefill, but the next cached prefill consumes the fused inputs_embeds (prefill_hidden) instead of looking up token embeddings.

pub fn weights_cache(&self) -> &HashMap<String, (Vec<f32>, Vec<usize>)>

Weight table for CPU-side embedding / fusion helpers.
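A sketch of one such CPU-side helper: reading a single embedding row out of a table with the same type as weights_cache(). It assumes the tuple is (data, shape) and that the embedding tensor is row-major with shape [vocab, hidden]; embedding_row is a hypothetical helper, and the tensor name used with it must come from the actual checkpoint.

```rust
use std::collections::HashMap;

/// Same type as the map returned by weights_cache(); the tuple is assumed to be
/// (flat f32 data, shape).
pub type WeightTable = HashMap<String, (Vec<f32>, Vec<usize>)>;

/// Borrow the embedding row for `token` from a row-major [vocab, hidden] tensor.
/// Returns None if the tensor is missing, not 2-D, or the token is out of range.
pub fn embedding_row<'a>(table: &'a WeightTable, name: &str, token: u32) -> Option<&'a [f32]> {
    let (data, shape) = table.get(name)?;
    let [vocab, hidden] = shape.as_slice() else {
        return None;
    };
    let t = token as usize;
    if t >= *vocab {
        return None;
    }
    // Row t occupies the contiguous range [t * hidden, (t + 1) * hidden).
    data.get(t * hidden..(t + 1) * hidden)
}
```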

pub fn step(&mut self, opts: SampleOpts) -> Result<u32, Error>

Run one prefill over the current token history and sample the next token. The sampled token is appended to the history and returned. Call repeatedly to generate.

pub fn generate( &mut self, n: usize, opts: SampleOpts, ) -> Result<Vec<u32>, Error>

Run n steps and return the newly generated token ids (excludes the prefill prompt).

pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32, Error>

Cached step: O(L) work per token instead of O(L²). The first call seeds the KV cache from the prompt via prefill-with-cache; subsequent calls run the decode-mode graph on just the last token plus the cached past. Output matches Self::step up to floating-point differences from reduction order in the SDPA kernel.

Invariant after each call: cache.past_seq == tokens.len() - 1 (the just-sampled token is appended but not yet in the cache; it becomes the input for the next decode step).
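To make the cost difference and the invariant concrete, the sketch below counts token positions fed through the model (a rough proxy for work, not an attention FLOP count) and mirrors the past_seq bookkeeping with a toy struct. All names are hypothetical, and the toy takes the sampled token as an argument instead of running a model.

```rust
/// Hypothetical: positions processed when every step re-runs prefill over the
/// whole history (the step path): P + (P + 1) + ... + (P + n - 1).
pub fn positions_uncached(prompt_len: usize, n: usize) -> usize {
    (0..n).map(|i| prompt_len + i).sum()
}

/// Hypothetical: the first cached step processes the prompt once; each later
/// step feeds only the previously sampled token (the step_cached path).
pub fn positions_cached(prompt_len: usize, n: usize) -> usize {
    if n == 0 { 0 } else { prompt_len + (n - 1) }
}

/// Toy bookkeeping mirroring the invariant past_seq == tokens.len() - 1.
pub struct ToyCachedState {
    pub tokens: Vec<u32>,
    pub past_seq: usize,
}

impl ToyCachedState {
    pub fn new(prompt: &[u32]) -> Self {
        assert!(!prompt.is_empty(), "prompt must be non-empty");
        Self { tokens: prompt.to_vec(), past_seq: 0 }
    }

    /// `sampled` stands in for the token a real model would sample.
    pub fn step_cached(&mut self, sampled: u32) {
        if self.past_seq == 0 {
            // First call: prefill-with-cache writes every prompt position.
            self.past_seq = self.tokens.len();
        } else {
            // Decode: the previously sampled (last) token enters the cache.
            self.past_seq += 1;
        }
        // The new token is appended to the history but not yet cached.
        self.tokens.push(sampled);
        assert_eq!(self.past_seq, self.tokens.len() - 1);
    }
}
```

For example, a 512-token prompt followed by 256 generated tokens gives 163,712 positions on the uncached path versus 767 on the cached one.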

pub fn generate_from_embeds( &mut self, prompt_ids: &[u32], embeds: &[f32], n: usize, opts: SampleOpts, ) -> Result<Vec<u32>, Error>

Call Self::prefill_from_embeds with prompt_ids and embeds, then run n cached steps.

pub fn generate_from_embeds_with_bias( &mut self, prompt_ids: &[u32], embeds: &[f32], attn_bias: Option<Vec<f32>>, n: usize, opts: SampleOpts, ) -> Result<Vec<u32>, Error>

pub fn generate_from_embeds_with( &mut self, prompt_ids: &[u32], embeds: &[f32], n: usize, opts: SampleOpts, on_token: impl FnMut(u32), ) -> Result<Vec<u32>, Error>

Streaming variant of Self::generate_from_embeds: on_token is invoked once per sampled id.

pub fn generate_from_embeds_with_bias_and_callback( &mut self, prompt_ids: &[u32], embeds: &[f32], attn_bias: Option<Vec<f32>>, n: usize, opts: SampleOpts, on_token: impl FnMut(u32), ) -> Result<Vec<u32>, Error>

pub fn generate_cached( &mut self, n: usize, opts: SampleOpts, ) -> Result<Vec<u32>, Error>

Run n cached steps and return the newly generated tokens.

pub fn generate_cached_with( &mut self, n: usize, opts: SampleOpts, on_token: impl FnMut(u32), ) -> Result<Vec<u32>, Error>

Same as Self::generate_cached, but invokes on_token once per freshly sampled id, inside the decode loop. The whole n-step loop shares the bucketed compile cache, so callers building a streaming UI should prefer this to calling generate_cached(1, …) n times, which forces a fresh compile per token at bucket boundaries.
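Because on_token is an ordinary FnMut closure, it can mutate captured state such as a text buffer. The sketch uses a hypothetical mock_generate_with that has the same callback shape but emits counter values in place of sampled ids, plus a made-up vocabulary, so the pattern runs without a model.

```rust
/// Hypothetical mock with the same callback shape as generate_cached_with.
/// It emits counter values instead of sampled ids and reports each one
/// through `on_token` inside the loop, then returns the full list.
pub fn mock_generate_with(start: u32, n: usize, mut on_token: impl FnMut(u32)) -> Vec<u32> {
    let mut out = Vec::with_capacity(n);
    for i in 0..n as u32 {
        let id = start + i; // placeholder for a real sampling step
        on_token(id); // streamed as soon as it is produced
        out.push(id);
    }
    out
}

/// Example consumer: append each token's text to a buffer as it arrives.
pub fn stream_to_string(vocab: &[&str], start: u32, n: usize) -> (Vec<u32>, String) {
    let mut text = String::new();
    let ids = mock_generate_with(start, n, |id| text.push_str(vocab[id as usize]));
    (ids, text)
}
```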

pub fn tokens(&self) -> &[u32]

Full token history (prompt + generated).

pub fn config(&self) -> &GemmaConfig

pub fn prefill_get_last_logits( &mut self, context: &[u32], ) -> Result<Vec<f32>, Error>

Low-level primitive: reset internal state, run prefill-with-cache over context, and return the last position's logits row, P(next_token | context). Does NOT sample or append. The internal tokens buffer is set to context and the KV cache is populated to past_seq = context.len().

The returned vector is the first row of the logits output by prefill-with-cache, which corresponds to the last context position; no sampling is applied.

pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>, Error>

Low-level primitive: run one decode step with the caller-supplied input token (no sampling), advance the KV cache, and return the resulting logits row, P(next | history ++ input). Appends input to the tokens buffer, so the invariant cache.past_seq == tokens.len() holds after this call. (This differs from the Self::step_cached invariant because this method does not append a sampled token.)
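The two primitives compose into custom decoding loops. The sketch below writes greedy decoding against a hypothetical LogitsPrimitives trait that mirrors their documented behavior (with the Result return types dropped for brevity), plus a toy model whose bookkeeping follows the past_seq == tokens.len() invariant. None of these names are part of this crate.

```rust
/// Hypothetical trait mirroring the two documented primitives.
pub trait LogitsPrimitives {
    /// Reset state, cache the whole context, return P(next | context).
    fn prefill_get_last_logits(&mut self, context: &[u32]) -> Vec<f32>;
    /// Feed one caller-chosen token, return P(next | history ++ input).
    fn decode_get_logits(&mut self, input: u32) -> Vec<f32>;
}

/// Index of the largest logit, or None for an empty row.
pub fn argmax(row: &[f32]) -> Option<u32> {
    row.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
}

/// Greedy decoding: prefill once, then feed each chosen token back through
/// decode_get_logits. The last chosen token is returned but not fed back,
/// since no further logits are needed.
pub fn greedy<M: LogitsPrimitives>(model: &mut M, context: &[u32], n: usize) -> Vec<u32> {
    let mut out = Vec::with_capacity(n);
    if n == 0 {
        return out;
    }
    let mut logits = model.prefill_get_last_logits(context);
    for i in 0..n {
        let next = argmax(&logits).expect("empty logits row");
        out.push(next);
        if i + 1 < n {
            logits = model.decode_get_logits(next);
        }
    }
    out
}

/// Toy model whose "next token" is always (last token + 1) mod vocab.
pub struct ToyModel {
    pub vocab: usize,
    pub tokens: Vec<u32>,
    pub past_seq: usize,
}

impl ToyModel {
    fn one_hot_after(&self, last: u32) -> Vec<f32> {
        let mut row = vec![0.0; self.vocab];
        row[(last as usize + 1) % self.vocab] = 1.0;
        row
    }
}

impl LogitsPrimitives for ToyModel {
    fn prefill_get_last_logits(&mut self, context: &[u32]) -> Vec<f32> {
        self.tokens = context.to_vec();
        self.past_seq = context.len();
        self.one_hot_after(*context.last().expect("empty context"))
    }

    fn decode_get_logits(&mut self, input: u32) -> Vec<f32> {
        self.tokens.push(input);
        self.past_seq += 1;
        self.one_hot_after(input)
    }
}
```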

Trait Implementations

impl Drop for GemmaGenerator

fn drop(&mut self)

Executes the destructor for this type.

fn pin_drop(self: Pin<&mut Self>)

This is a nightly-only experimental API (pin_ergonomics).
Executes the destructor for this type; unlike Drop::drop, it requires self to be pinned.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
where ST: ?Sized, DT: ?Sized,

impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
where ST: ?Sized, DT: ?Sized,

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> Read<Exclusive, BecauseExclusive> for T
where T: ?Sized,

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V