Skip to main content

foundation_models/generation/
mod.rs

//! Knobs that control how the model produces text.
2
3use crate::ffi;
4
/// How the next token is chosen during generation.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions written outside
/// this crate need a wildcard arm.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[non_exhaustive]
pub enum SamplingMode {
    /// Leave the choice to `FoundationModels`' built-in sampling strategy.
    #[default]
    Default,
    /// Deterministic decoding: always take the most probable token.
    Greedy,
    /// Draw only from the `k` most probable tokens.
    TopK(u32),
    /// Nucleus sampling: draw from the smallest set of tokens whose
    /// cumulative probability exceeds `p`. `p` must lie in `0.0..=1.0`;
    /// the range is documented here but not enforced by this type.
    TopP(f64),
}
21
22/// Generation knobs. All fields are optional; unset fields keep the
23/// model's defaults.
24///
25/// # Examples
26///
27/// ```rust
28/// use foundation_models::{GenerationOptions, SamplingMode};
29///
30/// let opts = GenerationOptions::new()
31///     .with_temperature(0.7)
32///     .with_maximum_response_tokens(500)
33///     .with_sampling(SamplingMode::TopP(0.9));
34/// ```
35#[derive(Debug, Clone, Copy, Default)]
36pub struct GenerationOptions {
37    temperature: Option<f64>,
38    max_tokens: Option<u32>,
39    sampling: SamplingMode,
40}
41
42impl GenerationOptions {
43    /// Create options with all fields set to their defaults.
44    #[must_use]
45    pub const fn new() -> Self {
46        Self {
47            temperature: None,
48            max_tokens: None,
49            sampling: SamplingMode::Default,
50        }
51    }
52
53    /// Sampling temperature; higher values produce more varied output.
54    /// `FoundationModels` accepts values in `0.0..=2.0`.
55    #[must_use]
56    pub const fn with_temperature(mut self, temperature: f64) -> Self {
57        self.temperature = Some(temperature);
58        self
59    }
60
61    /// Hard cap on the number of tokens the model may emit.
62    #[must_use]
63    pub const fn with_maximum_response_tokens(mut self, tokens: u32) -> Self {
64        self.max_tokens = Some(tokens);
65        self
66    }
67
68    /// Override the sampling strategy.
69    #[must_use]
70    pub const fn with_sampling(mut self, mode: SamplingMode) -> Self {
71        self.sampling = mode;
72        self
73    }
74
75    /// Lower into the C-compatible struct shared with Swift.
76    pub(crate) fn to_ffi(self) -> ffi::FFIGenerationOptions {
77        let (mode_code, top_k, top_p) = match self.sampling {
78            SamplingMode::Default => (0, 0, 0.0),
79            SamplingMode::Greedy => (1, 0, 0.0),
80            SamplingMode::TopK(k) => (2, i32::try_from(k).unwrap_or(i32::MAX), 0.0),
81            SamplingMode::TopP(p) => (3, 0, p),
82        };
83        ffi::FFIGenerationOptions {
84            temperature: self.temperature.unwrap_or(f64::NAN),
85            maximum_response_tokens: self
86                .max_tokens
87                .map_or(0, |t| i32::try_from(t).unwrap_or(i32::MAX)),
88            sampling_mode: mode_code,
89            top_k,
90            top_p,
91        }
92    }
93}