Skip to main content

foundation_models/generation/
mod.rs

1//! Knobs that control how the model produces text.
2
3use serde_json::{Map, Value};
4
5use crate::ffi;
6
/// How the next token is chosen while the model generates text.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[non_exhaustive]
pub enum SamplingMode {
    /// Leave the choice to `FoundationModels`' own default strategy.
    #[default]
    Default,
    /// Deterministic decoding: the most probable token is always taken.
    Greedy,
    /// Draw from only the `k` most probable tokens.
    TopK(u32),
    /// Nucleus sampling: draw from the smallest set of tokens whose cumulative
    /// probability exceeds `p`. `p` must lie in `0.0..=1.0`.
    TopP(f64),
}
22
23/// Generation knobs. All fields are optional; unset fields keep the model's
24/// defaults.
25#[derive(Debug, Clone, Copy, Default, PartialEq)]
26pub struct GenerationOptions {
27    temperature: Option<f64>,
28    max_tokens: Option<u32>,
29    sampling: SamplingMode,
30    sampling_seed: Option<u64>,
31}
32
33impl GenerationOptions {
34    /// Create options with all fields set to their defaults.
35    #[must_use]
36    pub const fn new() -> Self {
37        Self {
38            temperature: None,
39            max_tokens: None,
40            sampling: SamplingMode::Default,
41            sampling_seed: None,
42        }
43    }
44
45    /// Sampling temperature; higher values produce more varied output.
46    /// `FoundationModels` accepts values in `0.0..=2.0`.
47    #[must_use]
48    pub const fn with_temperature(mut self, temperature: f64) -> Self {
49        self.temperature = Some(temperature);
50        self
51    }
52
53    /// Hard cap on the number of tokens the model may emit.
54    #[must_use]
55    pub const fn with_maximum_response_tokens(mut self, tokens: u32) -> Self {
56        self.max_tokens = Some(tokens);
57        self
58    }
59
60    /// Override the sampling strategy.
61    #[must_use]
62    pub const fn with_sampling(mut self, mode: SamplingMode) -> Self {
63        self.sampling = mode;
64        self
65    }
66
67    /// Use a deterministic random seed for non-greedy sampling.
68    #[must_use]
69    pub const fn with_sampling_seed(mut self, seed: u64) -> Self {
70        self.sampling_seed = Some(seed);
71        self
72    }
73
74    /// Sampling temperature, if explicitly set.
75    #[must_use]
76    pub const fn temperature(self) -> Option<f64> {
77        self.temperature
78    }
79
80    /// The explicit token cap, if any.
81    #[must_use]
82    pub const fn maximum_response_tokens(self) -> Option<u32> {
83        self.max_tokens
84    }
85
86    /// The configured sampling strategy.
87    #[must_use]
88    pub const fn sampling(self) -> SamplingMode {
89        self.sampling
90    }
91
92    /// The deterministic random seed for top-k / top-p sampling.
93    #[must_use]
94    pub const fn sampling_seed(self) -> Option<u64> {
95        self.sampling_seed
96    }
97
98    /// Lower into the C-compatible struct shared with Swift.
99    pub(crate) fn to_ffi(self) -> ffi::FFIGenerationOptions {
100        let (mode_code, top_k, top_p) = match self.sampling {
101            SamplingMode::Default => (0, 0, 0.0),
102            SamplingMode::Greedy => (1, 0, 0.0),
103            SamplingMode::TopK(k) => (2, i32::try_from(k).unwrap_or(i32::MAX), 0.0),
104            SamplingMode::TopP(p) => (3, 0, p),
105        };
106        ffi::FFIGenerationOptions {
107            temperature: self.temperature.unwrap_or(f64::NAN),
108            maximum_response_tokens: self
109                .max_tokens
110                .map_or(0, |tokens| i32::try_from(tokens).unwrap_or(i32::MAX)),
111            sampling_mode: mode_code,
112            top_k,
113            top_p,
114            random_seed: self.sampling_seed.unwrap_or(0),
115            has_random_seed: self.sampling_seed.is_some(),
116        }
117    }
118
119    pub(crate) fn to_transcript_json_value(self) -> Value {
120        let mut map = Map::new();
121        if let Some(temperature) = self.temperature {
122            map.insert("temperature".into(), Value::from(temperature));
123        }
124        if let Some(max_tokens) = self.max_tokens {
125            map.insert("maximumResponseTokens".into(), Value::from(max_tokens));
126        }
127        if let Some(seed) = self.sampling_seed {
128            map.insert("randomSeed".into(), Value::from(seed));
129        }
130        match self.sampling {
131            SamplingMode::Default | SamplingMode::Greedy => {}
132            SamplingMode::TopK(k) => {
133                map.insert("topK".into(), Value::from(k));
134            }
135            SamplingMode::TopP(p) => {
136                map.insert("topP".into(), Value::from(p));
137            }
138        }
139        Value::Object(map)
140    }
141
142    #[must_use]
143    pub(crate) fn from_transcript_json_value(value: Option<&Value>) -> Self {
144        let Some(Value::Object(map)) = value else {
145            return Self::new();
146        };
147        let sampling = if let Some(top_k) = map.get("topK").and_then(Value::as_u64) {
148            SamplingMode::TopK(u32::try_from(top_k).unwrap_or(u32::MAX))
149        } else if let Some(top_p) = map.get("topP").and_then(Value::as_f64) {
150            SamplingMode::TopP(top_p)
151        } else {
152            SamplingMode::Default
153        };
154        Self {
155            temperature: map.get("temperature").and_then(Value::as_f64),
156            max_tokens: map
157                .get("maximumResponseTokens")
158                .and_then(Value::as_u64)
159                .and_then(|tokens| u32::try_from(tokens).ok()),
160            sampling,
161            sampling_seed: map.get("randomSeed").and_then(Value::as_u64),
162        }
163    }
164}