foundation_models/generation/
mod.rs1use serde_json::{Map, Value};
4
5use crate::ffi;
6
/// Strategy used to pick tokens during generation.
///
/// Stored inside `GenerationOptions`; the [`Default`] implementation yields
/// [`SamplingMode::Default`].
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[non_exhaustive]
pub enum SamplingMode {
    /// No explicit sampling strategy requested.
    #[default]
    Default,
    /// Greedy sampling.
    Greedy,
    /// Top-k sampling; the payload is the `k` value.
    TopK(u32),
    /// Top-p sampling; the payload is the `p` value.
    TopP(f64),
}
22
23#[derive(Debug, Clone, Copy, Default, PartialEq)]
26pub struct GenerationOptions {
27 temperature: Option<f64>,
28 max_tokens: Option<u32>,
29 sampling: SamplingMode,
30 sampling_seed: Option<u64>,
31}
32
33impl GenerationOptions {
34 #[must_use]
36 pub const fn new() -> Self {
37 Self {
38 temperature: None,
39 max_tokens: None,
40 sampling: SamplingMode::Default,
41 sampling_seed: None,
42 }
43 }
44
45 #[must_use]
48 pub const fn with_temperature(mut self, temperature: f64) -> Self {
49 self.temperature = Some(temperature);
50 self
51 }
52
53 #[must_use]
55 pub const fn with_maximum_response_tokens(mut self, tokens: u32) -> Self {
56 self.max_tokens = Some(tokens);
57 self
58 }
59
60 #[must_use]
62 pub const fn with_sampling(mut self, mode: SamplingMode) -> Self {
63 self.sampling = mode;
64 self
65 }
66
67 #[must_use]
69 pub const fn with_sampling_seed(mut self, seed: u64) -> Self {
70 self.sampling_seed = Some(seed);
71 self
72 }
73
74 #[must_use]
76 pub const fn temperature(self) -> Option<f64> {
77 self.temperature
78 }
79
80 #[must_use]
82 pub const fn maximum_response_tokens(self) -> Option<u32> {
83 self.max_tokens
84 }
85
86 #[must_use]
88 pub const fn sampling(self) -> SamplingMode {
89 self.sampling
90 }
91
92 #[must_use]
94 pub const fn sampling_seed(self) -> Option<u64> {
95 self.sampling_seed
96 }
97
98 pub(crate) fn to_ffi(self) -> ffi::FFIGenerationOptions {
100 let (mode_code, top_k, top_p) = match self.sampling {
101 SamplingMode::Default => (0, 0, 0.0),
102 SamplingMode::Greedy => (1, 0, 0.0),
103 SamplingMode::TopK(k) => (2, i32::try_from(k).unwrap_or(i32::MAX), 0.0),
104 SamplingMode::TopP(p) => (3, 0, p),
105 };
106 ffi::FFIGenerationOptions {
107 temperature: self.temperature.unwrap_or(f64::NAN),
108 maximum_response_tokens: self
109 .max_tokens
110 .map_or(0, |tokens| i32::try_from(tokens).unwrap_or(i32::MAX)),
111 sampling_mode: mode_code,
112 top_k,
113 top_p,
114 random_seed: self.sampling_seed.unwrap_or(0),
115 has_random_seed: self.sampling_seed.is_some(),
116 }
117 }
118
119 pub(crate) fn to_transcript_json_value(self) -> Value {
120 let mut map = Map::new();
121 if let Some(temperature) = self.temperature {
122 map.insert("temperature".into(), Value::from(temperature));
123 }
124 if let Some(max_tokens) = self.max_tokens {
125 map.insert("maximumResponseTokens".into(), Value::from(max_tokens));
126 }
127 if let Some(seed) = self.sampling_seed {
128 map.insert("randomSeed".into(), Value::from(seed));
129 }
130 match self.sampling {
131 SamplingMode::Default | SamplingMode::Greedy => {}
132 SamplingMode::TopK(k) => {
133 map.insert("topK".into(), Value::from(k));
134 }
135 SamplingMode::TopP(p) => {
136 map.insert("topP".into(), Value::from(p));
137 }
138 }
139 Value::Object(map)
140 }
141
142 #[must_use]
143 pub(crate) fn from_transcript_json_value(value: Option<&Value>) -> Self {
144 let Some(Value::Object(map)) = value else {
145 return Self::new();
146 };
147 let sampling = if let Some(top_k) = map.get("topK").and_then(Value::as_u64) {
148 SamplingMode::TopK(u32::try_from(top_k).unwrap_or(u32::MAX))
149 } else if let Some(top_p) = map.get("topP").and_then(Value::as_f64) {
150 SamplingMode::TopP(top_p)
151 } else {
152 SamplingMode::Default
153 };
154 Self {
155 temperature: map.get("temperature").and_then(Value::as_f64),
156 max_tokens: map
157 .get("maximumResponseTokens")
158 .and_then(Value::as_u64)
159 .and_then(|tokens| u32::try_from(tokens).ok()),
160 sampling,
161 sampling_seed: map.get("randomSeed").and_then(Value::as_u64),
162 }
163 }
164}