fm_rs/
options.rs

1//! Generation options for controlling model output.
2
3use serde::{Deserialize, Serialize};
4
5/// Sampling strategy for token generation.
6#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Sampling {
9    /// Greedy sampling: always pick the most likely token.
10    Greedy,
11    /// Random sampling with temperature.
12    #[default]
13    Random,
14}
15
16/// Options that control how the model generates its response.
17///
18/// Use the builder pattern to configure options:
19///
20/// ```rust
21/// use fm_rs::GenerationOptions;
22///
23/// let options = GenerationOptions::builder()
24///     .temperature(0.7)
25///     .max_response_tokens(500)
26///     .build();
27/// ```
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct GenerationOptions {
30    /// Temperature for sampling (0.0-2.0).
31    /// Higher values produce more random outputs.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub temperature: Option<f64>,
34
35    /// Sampling strategy.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub sampling: Option<Sampling>,
38
39    /// Maximum number of tokens in the response.
40    #[serde(skip_serializing_if = "Option::is_none")]
41    #[serde(rename = "maximumResponseTokens")]
42    pub max_response_tokens: Option<u32>,
43
44    /// Random seed for reproducible generation.
45    ///
46    /// **Note**: This field is currently not supported by Apple's `GenerationOptions` API
47    /// and is ignored. It is included for potential future use.
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub seed: Option<u64>,
50}
51
52impl GenerationOptions {
53    /// Creates a new builder for configuring generation options.
54    pub fn builder() -> GenerationOptionsBuilder {
55        GenerationOptionsBuilder::default()
56    }
57
58    /// Serializes the options to JSON for FFI.
59    pub fn to_json(&self) -> String {
60        serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
61    }
62}
63
64/// Builder for configuring [`GenerationOptions`].
65#[derive(Debug, Default)]
66pub struct GenerationOptionsBuilder {
67    temperature: Option<f64>,
68    sampling: Option<Sampling>,
69    max_response_tokens: Option<u32>,
70    seed: Option<u64>,
71}
72
73impl GenerationOptionsBuilder {
74    /// Sets the temperature for generation.
75    ///
76    /// Temperature influences the confidence of the model's response.
77    /// Higher values (e.g., 1.5) produce more random outputs.
78    /// Lower values (e.g., 0.2) produce more deterministic outputs.
79    ///
80    /// Valid range: 0.0 to 2.0. Values outside this range are ignored
81    /// and the default temperature is used instead.
82    pub fn temperature(mut self, temp: f64) -> Self {
83        if (0.0..=2.0).contains(&temp) {
84            self.temperature = Some(temp);
85        }
86        self
87    }
88
89    /// Sets the temperature, returning an error if out of range.
90    ///
91    /// This is the fallible version of [`temperature`](Self::temperature).
92    /// Use this when you want to catch invalid temperature values at build time.
93    ///
94    /// # Errors
95    ///
96    /// Returns an error if `temp` is not in the range 0.0 to 2.0.
97    pub fn try_temperature(mut self, temp: f64) -> Result<Self, crate::Error> {
98        if (0.0..=2.0).contains(&temp) {
99            self.temperature = Some(temp);
100            Ok(self)
101        } else {
102            Err(crate::Error::InvalidInput(format!(
103                "Temperature must be between 0.0 and 2.0, got {temp}"
104            )))
105        }
106    }
107
108    /// Sets the sampling strategy.
109    pub fn sampling(mut self, sampling: Sampling) -> Self {
110        self.sampling = Some(sampling);
111        self
112    }
113
114    /// Sets the maximum number of tokens in the response.
115    ///
116    /// Only use this when you need to protect against unexpectedly verbose responses.
117    /// Enforcing a strict token limit can lead to malformed or grammatically incorrect output.
118    pub fn max_response_tokens(mut self, tokens: u32) -> Self {
119        if tokens > 0 {
120            self.max_response_tokens = Some(tokens);
121        }
122        self
123    }
124
125    /// Sets the random seed for reproducible generation.
126    ///
127    /// **Note**: This is currently not supported by Apple's `GenerationOptions` API
128    /// and will be ignored. Included for potential future use.
129    pub fn seed(mut self, seed: u64) -> Self {
130        self.seed = Some(seed);
131        self
132    }
133
134    /// Builds the [`GenerationOptions`].
135    pub fn build(self) -> GenerationOptions {
136        GenerationOptions {
137            temperature: self.temperature,
138            sampling: self.sampling,
139            max_response_tokens: self.max_response_tokens,
140            seed: self.seed,
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_options() {
        let options = GenerationOptions::default();
        assert!(options.temperature.is_none());
        assert!(options.sampling.is_none());
        assert!(options.max_response_tokens.is_none());
        assert!(options.seed.is_none());
    }

    #[test]
    fn test_builder() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .sampling(Sampling::Random)
            .max_response_tokens(500)
            .seed(42)
            .build();

        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.sampling, Some(Sampling::Random));
        assert_eq!(options.max_response_tokens, Some(500));
        assert_eq!(options.seed, Some(42));
    }

    #[test]
    fn test_temperature_bounds() {
        // Valid temperature
        let options = GenerationOptions::builder().temperature(1.5).build();
        assert_eq!(options.temperature, Some(1.5));

        // Both endpoints are inclusive
        let options = GenerationOptions::builder().temperature(0.0).build();
        assert_eq!(options.temperature, Some(0.0));
        let options = GenerationOptions::builder().temperature(2.0).build();
        assert_eq!(options.temperature, Some(2.0));

        // Out of bounds (negative)
        let options = GenerationOptions::builder().temperature(-0.5).build();
        assert!(options.temperature.is_none());

        // Out of bounds (too high)
        let options = GenerationOptions::builder().temperature(3.0).build();
        assert!(options.temperature.is_none());

        // NaN is never in range
        let options = GenerationOptions::builder().temperature(f64::NAN).build();
        assert!(options.temperature.is_none());
    }

    #[test]
    fn test_invalid_temperature_keeps_previous_value() {
        let options = GenerationOptions::builder()
            .temperature(0.5)
            .temperature(3.0)
            .build();
        assert_eq!(options.temperature, Some(0.5));
    }

    #[test]
    fn test_try_temperature() {
        let accepted = GenerationOptions::builder().try_temperature(1.0);
        assert_eq!(accepted.ok().map(|b| b.build().temperature), Some(Some(1.0)));

        assert!(GenerationOptions::builder().try_temperature(2.5).is_err());
        assert!(GenerationOptions::builder().try_temperature(-0.1).is_err());
        assert!(GenerationOptions::builder().try_temperature(f64::NAN).is_err());
    }

    #[test]
    fn test_zero_max_response_tokens_is_ignored() {
        let options = GenerationOptions::builder().max_response_tokens(0).build();
        assert!(options.max_response_tokens.is_none());

        let options = GenerationOptions::builder()
            .max_response_tokens(10)
            .max_response_tokens(0)
            .build();
        assert_eq!(options.max_response_tokens, Some(10));
    }

    #[test]
    fn test_json_serialization() {
        let options = GenerationOptions::builder()
            .temperature(0.7)
            .max_response_tokens(100)
            .build();

        let json = options.to_json();
        assert!(json.contains("temperature"));
        assert!(json.contains("0.7"));
        assert!(json.contains("maximumResponseTokens"));
        // Unset options are skipped entirely.
        assert!(!json.contains("sampling"));
        assert!(!json.contains("seed"));
    }

    #[test]
    fn test_default_serializes_to_empty_object() {
        assert_eq!(GenerationOptions::default().to_json(), "{}");
    }

    #[test]
    fn test_sampling_serializes_lowercase() {
        let json = GenerationOptions::builder()
            .sampling(Sampling::Greedy)
            .build()
            .to_json();
        assert_eq!(json, r#"{"sampling":"greedy"}"#);
    }

    #[test]
    fn test_json_round_trip() {
        let original = GenerationOptions::builder()
            .temperature(1.5)
            .sampling(Sampling::Greedy)
            .max_response_tokens(64)
            .seed(7)
            .build();

        let parsed: GenerationOptions =
            serde_json::from_str(&original.to_json()).expect("to_json must emit valid JSON");

        assert_eq!(parsed.temperature, original.temperature);
        assert_eq!(parsed.sampling, original.sampling);
        assert_eq!(parsed.max_response_tokens, original.max_response_tokens);
        assert_eq!(parsed.seed, original.seed);
    }
}