1pub mod chat_completions;
17pub mod completions;
18pub mod embeddings;
19pub mod models;
20pub mod nvext;
21
22use anyhow::Result;
23use serde::{Deserialize, Serialize};
24use std::{
25 fmt::Display,
26 ops::{Add, Div, Mul, Sub},
27};
28
29use super::{
30 common::{self, SamplingOptionsProvider, StopConditionsProvider},
31 ContentProvider,
32};
33
/// Inclusive `(min, max)` bounds accepted for the OpenAI `temperature` parameter.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 2.0);

/// Smallest accepted `temperature` (lower bound of `TEMPERATURE_RANGE`).
pub const MIN_TEMPERATURE: f32 = TEMPERATURE_RANGE.0;

/// Largest accepted `temperature` (upper bound of `TEMPERATURE_RANGE`).
pub const MAX_TEMPERATURE: f32 = TEMPERATURE_RANGE.1;

/// Inclusive `(min, max)` bounds accepted for the OpenAI `top_p` parameter.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);

/// Smallest accepted `top_p` (lower bound of `TOP_P_RANGE`).
pub const MIN_TOP_P: f32 = TOP_P_RANGE.0;

/// Largest accepted `top_p` (upper bound of `TOP_P_RANGE`).
pub const MAX_TOP_P: f32 = TOP_P_RANGE.1;

/// Inclusive `(min, max)` bounds accepted for the OpenAI `frequency_penalty` parameter.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0);

/// Smallest accepted `frequency_penalty` (lower bound of `FREQUENCY_PENALTY_RANGE`).
pub const MIN_FREQUENCY_PENALTY: f32 = FREQUENCY_PENALTY_RANGE.0;

/// Largest accepted `frequency_penalty` (upper bound of `FREQUENCY_PENALTY_RANGE`).
pub const MAX_FREQUENCY_PENALTY: f32 = FREQUENCY_PENALTY_RANGE.1;

/// Inclusive `(min, max)` bounds accepted for the OpenAI `presence_penalty` parameter.
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0);

/// Smallest accepted `presence_penalty` (lower bound of `PRESENCE_PENALTY_RANGE`).
pub const MIN_PRESENCE_PENALTY: f32 = PRESENCE_PENALTY_RANGE.0;

/// Largest accepted `presence_penalty` (upper bound of `PRESENCE_PENALTY_RANGE`).
pub const MAX_PRESENCE_PENALTY: f32 = PRESENCE_PENALTY_RANGE.1;
69
/// Token accounting attached to a completion response.
///
/// The two optional `*_details` breakdowns are left out of the serialized output
/// entirely when they are `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CompletionUsage {
    /// Count of completion tokens.
    pub completion_tokens: i32,

    /// Count of prompt tokens.
    pub prompt_tokens: i32,

    /// Total token count for the request.
    // NOTE(review): presumably `prompt_tokens + completion_tokens`; nothing in this type enforces it.
    pub total_tokens: i32,

    /// Optional per-category breakdown of the completion tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,

    /// Optional per-category breakdown of the prompt tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
90
/// Breakdown of completion tokens by category; every count is optional.
///
/// There is no `skip_serializing_if` here, so a `None` count is still emitted when
/// serializing (as `null` in JSON).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionTokensDetails {
    /// Count of audio tokens, if reported.
    pub audio_tokens: Option<i32>,

    /// Count of reasoning tokens, if reported.
    pub reasoning_tokens: Option<i32>,
}
100
/// Breakdown of prompt tokens by category; every count is optional.
///
/// There is no `skip_serializing_if` here, so a `None` count is still emitted when
/// serializing (as `null` in JSON).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PromptTokensDetails {
    /// Count of audio tokens, if reported.
    pub audio_tokens: Option<i32>,

    /// Count of cached tokens, if reported.
    pub cached_tokens: Option<i32>,
}
110
/// One item of a streamed response: either a payload delta or a standalone comment.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingDelta<R> {
    /// A delta carrying a payload of type `R`.
    Delta(R),
    /// A text comment carried in place of a payload.
    // NOTE(review): presumably maps to an SSE comment line — confirm with the stream encoder.
    Comment(String),
}
126
/// A delta payload bundled with optional per-item metadata.
// NOTE(review): `id` / `event` / `comment` look like SSE event fields — confirm against the encoder.
#[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> {
    /// The payload itself.
    pub delta: R,
    /// Optional identifier attached to this item.
    pub id: Option<String>,
    /// Optional event name attached to this item.
    pub event: Option<String>,
    /// Optional comment attached to this item.
    pub comment: Option<String>,
}
134
/// Read access to the OpenAI-style sampling parameters of a request.
///
/// Every implementor also gets `SamplingOptionsProvider` through the blanket impl in
/// this module. That impl range-checks each value against its `*_RANGE` constant and
/// honours `nvext().greed_sampling`.
trait OpenAISamplingOptionsProvider {
    /// Requested `temperature`, if any.
    fn get_temperature(&self) -> Option<f32>;

    /// Requested `top_p`, if any.
    fn get_top_p(&self) -> Option<f32>;

    /// Requested `frequency_penalty`, if any.
    fn get_frequency_penalty(&self) -> Option<f32>;

    /// Requested `presence_penalty`, if any.
    fn get_presence_penalty(&self) -> Option<f32>;

    /// Request extension block (`nvext`), if present.
    fn nvext(&self) -> Option<&nvext::NvExt>;
}
146
/// Read access to the OpenAI-style stop conditions of a request.
///
/// Every implementor also gets `StopConditionsProvider` through the blanket impl in
/// this module. That impl limits the number of stop sequences and honours
/// `nvext().ignore_eos`.
trait OpenAIStopConditionsProvider {
    /// Requested maximum number of tokens to generate, if any.
    fn get_max_tokens(&self) -> Option<u32>;

    /// Requested minimum number of tokens to generate, if any.
    fn get_min_tokens(&self) -> Option<u32>;

    /// Requested stop sequences, if any.
    fn get_stop(&self) -> Option<Vec<String>>;

    /// Request extension block (`nvext`), if present.
    fn nvext(&self) -> Option<&nvext::NvExt>;
}
156
157impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
158 fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
159 let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
165 .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
166 let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
167 .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
168 let frequency_penalty =
169 validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
170 .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
171 let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
172 .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
173
174 if let Some(nvext) = self.nvext() {
175 let greedy = nvext.greed_sampling.unwrap_or(false);
176 if greedy {
177 top_p = None;
178 temperature = None;
179 }
180 }
181
182 Ok(common::SamplingOptions {
183 n: None,
184 best_of: None,
185 frequency_penalty,
186 presence_penalty,
187 repetition_penalty: None,
188 temperature,
189 top_p,
190 top_k: None,
191 min_p: None,
192 seed: None,
193 use_beam_search: None,
194 length_penalty: None,
195 })
196 }
197}
198
199impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
200 fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
201 let max_tokens = self.get_max_tokens();
202 let min_tokens = self.get_min_tokens();
203 let stop = self.get_stop();
204
205 if let Some(stop) = &stop {
206 if stop.len() > 4 {
207 anyhow::bail!("stop conditions must be less than 4")
208 }
209 }
210
211 let mut ignore_eos = None;
212
213 if let Some(nvext) = self.nvext() {
214 ignore_eos = nvext.ignore_eos;
215 }
216
217 Ok(common::StopConditions {
218 max_tokens,
219 min_tokens,
220 stop,
221 stop_token_ids_hidden: None,
222 ignore_eos,
223 })
224 }
225}
226
/// A completion response envelope, generic over the choice type `C`.
///
/// The `Option` fields have no `skip_serializing_if`, so `None` is still emitted when
/// serializing (as `null` in JSON).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenericCompletionResponse<C>
{
    /// Identifier of this response.
    pub id: String,

    /// The returned choices.
    pub choices: Vec<C>,

    /// Creation timestamp.
    // NOTE(review): presumably Unix seconds, as in the OpenAI API — confirm with the producer.
    pub created: u64,

    /// Name of the model that produced the response.
    pub model: String,

    /// Object-type tag of the response.
    pub object: String,

    /// Token usage, if reported.
    pub usage: Option<CompletionUsage>,

    /// System fingerprint, if reported.
    pub system_fingerprint: Option<String>,
}
263
264fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
266where
267 T: PartialOrd + Display,
268{
269 if value.is_none() {
270 return Ok(None);
271 }
272 let value = value.unwrap();
273 if value < range.0 || value > range.1 {
274 anyhow::bail!("Value {} is out of range [{}, {}]", value, range.0, range.1);
275 }
276 Ok(Some(value))
277}
278
279pub fn scale_value<T>(value: &T, src: &(T, T), dst: &(T, T)) -> Result<T>
282where
283 T: Copy
284 + PartialOrd
285 + Add<Output = T>
286 + Sub<Output = T>
287 + Mul<Output = T>
288 + Div<Output = T>
289 + From<f32>,
290{
291 let dst_range = dst.1 - dst.0;
292 let src_range = src.1 - src.0;
293 if dst_range == T::from(0.0) {
294 anyhow::bail!("dst range is 0");
295 }
296 if src_range == T::from(0.0) {
297 anyhow::bail!("src range is 0");
298 }
299 let value_scaled = (*value - src.0) / src_range;
300 Ok(dst.0 + (value_scaled * dst_range))
301}
302
/// Turns post-processed backend outputs into responses of type `ResponseType`.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
    Send + Sync + 'static
{
    /// Builds one `ResponseType` from a single backend output.
    ///
    /// Takes `&mut self`, so an implementor may update its own state on each call.
    fn choice_from_postprocessor(
        &mut self,
        response: common::llm_backend::BackendOutput,
    ) -> Result<ResponseType>;

    /// Returns the ISL, if known.
    // NOTE(review): "ISL" presumably means input sequence length (prompt tokens) — confirm.
    fn get_isl(&self) -> Option<u32>;
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// `validate_range` returns in-range values unchanged and rejects the rest with a
    /// message naming the value and the bounds.
    #[test]
    fn test_validate_range() {
        // Bounds are inclusive, including a degenerate single-point range.
        let accepted: [(f64, (f64, f64)); 3] =
            [(0.5, (0.0, 1.0)), (0.0, (0.0, 1.0)), (1.0, (1.0, 1.0))];
        for (value, range) in accepted {
            assert_eq!(validate_range(Some(value), &range).unwrap(), Some(value));
        }

        // Integer types are supported as well.
        assert_eq!(validate_range(Some(1_i32), &(1, 1)).unwrap(), Some(1));

        let rejected: [(f64, &str); 2] = [
            (1.1, "Value 1.1 is out of range [0, 1]"),
            (-0.1, "Value -0.1 is out of range [0, 1]"),
        ];
        for (value, message) in rejected {
            let err = validate_range(Some(value), &(0.0, 1.0)).unwrap_err();
            assert_eq!(err.to_string(), message);
        }
    }

    /// `scale_value` maps linearly between intervals and refuses zero-width ones.
    #[test]
    fn test_scaled_value() {
        let cases: [(f64, (f64, f64), (f64, f64), f64); 3] = [
            (0.5, (0.0, 1.0), (0.0, 2.0), 1.0),
            (0.0, (0.0, 1.0), (0.0, 2.0), 0.0),
            (-1.0, (-2.0, 2.0), (1.0, 2.0), 1.25),
        ];
        for (value, src, dst, expected) in cases {
            assert_eq!(scale_value(&value, &src, &dst).unwrap(), expected);
        }

        // A zero-width source interval is an error.
        assert!(scale_value(&1.0, &(1.0, 1.0), &(0.0, 2.0)).is_err());
    }
}