pub mod chat_completions;
pub mod completions;
pub mod models;
pub mod nvext;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
ops::{Add, Div, Mul, Sub},
};
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider,
};
// Accepted ranges for the OpenAI-style sampling parameters. Each `*_RANGE` tuple is
// `(min, max)` and is checked with inclusive bounds by `validate_range` in this module.

/// Smallest accepted `temperature` value.
pub const MIN_TEMPERATURE: f32 = 0.0;
/// Largest accepted `temperature` value.
pub const MAX_TEMPERATURE: f32 = 2.0;
/// `(min, max)` pair used to validate `temperature`.
pub const TEMPERATURE_RANGE: (f32, f32) = (MIN_TEMPERATURE, MAX_TEMPERATURE);
/// Smallest accepted `top_p` value.
pub const MIN_TOP_P: f32 = 0.0;
/// Largest accepted `top_p` value.
pub const MAX_TOP_P: f32 = 1.0;
/// `(min, max)` pair used to validate `top_p`.
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);
/// Smallest accepted `frequency_penalty` value.
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;
/// Largest accepted `frequency_penalty` value.
pub const MAX_FREQUENCY_PENALTY: f32 = 2.0;
/// `(min, max)` pair used to validate `frequency_penalty`.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (MIN_FREQUENCY_PENALTY, MAX_FREQUENCY_PENALTY);
/// Smallest accepted `presence_penalty` value.
pub const MIN_PRESENCE_PENALTY: f32 = -2.0;
/// Largest accepted `presence_penalty` value.
pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// `(min, max)` pair used to validate `presence_penalty`.
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Token accounting carried in the optional `usage` field of [`GenericCompletionResponse`].
///
/// `Default` yields zero counts and no detail breakdowns.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CompletionUsage {
/// Number of completion tokens.
pub completion_tokens: i32,
/// Number of prompt tokens.
pub prompt_tokens: i32,
/// Total number of tokens.
/// NOTE(review): presumably `prompt_tokens + completion_tokens`; nothing in this file
/// enforces that relationship.
pub total_tokens: i32,
/// Optional breakdown of the completion tokens; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
/// Optional breakdown of the prompt tokens; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokensDetails>,
}
/// Optional per-category breakdown of completion tokens, referenced from
/// [`CompletionUsage::completion_tokens_details`].
///
/// Unlike the fields of `CompletionUsage`, these carry no `skip_serializing_if`, so a `None`
/// field is still handed to the serializer.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionTokensDetails {
/// Audio token count, when reported.
pub audio_tokens: Option<i32>,
/// Reasoning token count, when reported.
pub reasoning_tokens: Option<i32>,
}
/// Optional per-category breakdown of prompt tokens, referenced from
/// [`CompletionUsage::prompt_tokens_details`].
///
/// These fields carry no `skip_serializing_if`, so a `None` field is still handed to the
/// serializer.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PromptTokensDetails {
/// Audio token count, when reported.
pub audio_tokens: Option<i32>,
/// Cached token count, when reported.
pub cached_tokens: Option<i32>,
}
/// One item of a stream: either a payload of type `R` or a free-form comment string.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingDelta<R> {
/// A payload item.
Delta(R),
/// A textual comment carried in place of a payload.
Comment(String),
}
/// A payload of type `R` bundled with optional `id`, `event` and `comment` annotations.
///
/// NOTE(review): the annotation names resemble Server-Sent Events fields — confirm against
/// the code that produces and consumes this type.
#[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> {
/// The wrapped payload.
pub delta: R,
/// Optional identifier attached to the payload.
pub id: Option<String>,
/// Optional event name attached to the payload.
pub event: Option<String>,
/// Optional comment attached to the payload.
pub comment: Option<String>,
}
/// Accessors for the OpenAI-style sampling parameters of a request.
///
/// Every implementor receives `SamplingOptionsProvider` through the blanket impl in this
/// module, which range-checks the values returned by these getters.
trait OpenAISamplingOptionsProvider {
/// Requested temperature, if any; the blanket impl validates it against `TEMPERATURE_RANGE`.
fn get_temperature(&self) -> Option<f32>;
/// Requested `top_p`, if any; the blanket impl validates it against `TOP_P_RANGE`.
fn get_top_p(&self) -> Option<f32>;
/// Requested frequency penalty, if any; the blanket impl validates it against
/// `FREQUENCY_PENALTY_RANGE`.
fn get_frequency_penalty(&self) -> Option<f32>;
/// Requested presence penalty, if any; the blanket impl validates it against
/// `PRESENCE_PENALTY_RANGE`.
fn get_presence_penalty(&self) -> Option<f32>;
/// Extension options, if present; the blanket impl reads the `greed_sampling` flag from them.
fn nvext(&self) -> Option<&nvext::NvExt>;
}
/// Accessors for the OpenAI-style stop conditions of a request.
///
/// Every implementor receives `StopConditionsProvider` through the blanket impl in this
/// module.
trait OpenAIStopConditionsProvider {
/// Requested maximum token count, if any; forwarded unchanged by the blanket impl.
fn get_max_tokens(&self) -> Option<u32>;
/// Requested minimum token count, if any; forwarded unchanged by the blanket impl.
fn get_min_tokens(&self) -> Option<u32>;
/// Requested stop strings, if any; the blanket impl rejects lists with more than 4 entries.
fn get_stop(&self) -> Option<Vec<String>>;
/// Extension options, if present; the blanket impl reads the `ignore_eos` value from them.
fn nvext(&self) -> Option<&nvext::NvExt>;
}
impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
let frequency_penalty =
validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false);
if greedy {
top_p = None;
temperature = None;
}
}
Ok(common::SamplingOptions {
n: None,
best_of: None,
frequency_penalty,
presence_penalty,
repetition_penalty: None,
temperature,
top_p,
top_k: None,
min_p: None,
seed: None,
use_beam_search: None,
length_penalty: None,
})
}
}
/// Blanket implementation: any type exposing OpenAI-style stop conditions can be turned into
/// the common [`common::StopConditions`] representation.
impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
    /// Converts the OpenAI stop parameters to common stop conditions.
    ///
    /// `max_tokens`, `min_tokens` and `stop` are forwarded as provided; `ignore_eos` is taken
    /// from the extension options when they are present. `stop_token_ids_hidden` is always
    /// left as `None`.
    ///
    /// # Errors
    ///
    /// Returns an error when more than 4 stop strings are supplied.
    fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
        /// Largest number of stop strings accepted in a single request.
        const MAX_STOP_SEQUENCES: usize = 4;

        let max_tokens = self.get_max_tokens();
        let min_tokens = self.get_min_tokens();
        let stop = self.get_stop();

        if let Some(sequences) = &stop {
            // Exactly `MAX_STOP_SEQUENCES` entries are allowed, so the message says
            // "at most" rather than "less than".
            if sequences.len() > MAX_STOP_SEQUENCES {
                anyhow::bail!(
                    "at most {} stop conditions are allowed, got {}",
                    MAX_STOP_SEQUENCES,
                    sequences.len()
                );
            }
        }

        let ignore_eos = self.nvext().and_then(|ext| ext.ignore_eos);

        Ok(common::StopConditions {
            max_tokens,
            min_tokens,
            stop,
            stop_token_ids_hidden: None,
            ignore_eos,
        })
    }
}
/// Completion response envelope, generic over the choice type `C`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenericCompletionResponse<C>
{
/// Identifier of the response.
pub id: String,
/// The choices carried by this response.
pub choices: Vec<C>,
/// Creation time of the response.
/// NOTE(review): presumably a Unix timestamp in seconds — confirm against the producer.
pub created: u64,
/// Name of the model.
pub model: String,
/// Object type tag of the response.
pub object: String,
/// Token usage statistics, when available.
pub usage: Option<CompletionUsage>,
/// System fingerprint, when available.
pub system_fingerprint: Option<String>,
}
/// Checks that an optional value lies within an inclusive `(min, max)` range.
///
/// `None` is passed through untouched, so absent parameters are never rejected.
///
/// # Errors
///
/// Returns an error when the value is smaller than `range.0`, larger than `range.1`, or
/// cannot be ordered against the bounds at all (for example a floating-point NaN).
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
where
    T: PartialOrd + Display,
{
    let Some(value) = value else {
        return Ok(None);
    };
    let (min, max) = range;

    // Test for containment rather than for "below min or above max": every comparison
    // involving an unordered value such as NaN is false, so the negative formulation would
    // silently accept it.
    let in_range = value >= *min && value <= *max;
    if !in_range {
        anyhow::bail!("Value {} is out of range [{}, {}]", value, min, max);
    }
    Ok(Some(value))
}
/// Linearly maps `value` from the `src` interval onto the `dst` interval.
///
/// Both intervals are given as `(start, end)` tuples. The value is first normalised to its
/// relative position inside `src` and then projected onto `dst`; it is not clamped, so a
/// value outside `src` maps outside `dst`.
///
/// # Errors
///
/// Returns an error when either interval has zero width. The destination interval is
/// checked first.
pub fn scale_value<T>(value: &T, src: &(T, T), dst: &(T, T)) -> Result<T>
where
    T: Copy
        + PartialOrd
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + From<f32>,
{
    let zero = T::from(0.0_f32);
    let (src_start, src_end) = *src;
    let (dst_start, dst_end) = *dst;

    let dst_width = dst_end - dst_start;
    let src_width = src_end - src_start;

    if dst_width == zero {
        anyhow::bail!("dst range is 0");
    }
    if src_width == zero {
        anyhow::bail!("src range is 0");
    }

    // Relative position of `value` inside the source interval.
    let fraction = (*value - src_start) / src_width;
    Ok(dst_start + (fraction * dst_width))
}
/// Turns backend output into response items of type `ResponseType`.
///
/// Implementors must be `Send + Sync + 'static`, and so must `ResponseType`, which also
/// has to implement `Debug`.
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
Send + Sync + 'static
{
/// Builds a `ResponseType` from one `BackendOutput`.
///
/// Takes `&mut self`, so an implementor is able to update its own state on each call.
///
/// # Errors
///
/// Returns an error when the implementor cannot produce a response from `response`.
fn choice_from_postprocessor(
&mut self,
response: common::llm_backend::BackendOutput,
) -> Result<ResponseType>;
}
// Unit tests for the range helpers defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// `validate_range` accepts values on or inside the bounds and reports a formatted error
// for values outside of them.
#[test]
fn test_validate_range() {
// Inside the range and exactly on the lower bound.
assert_eq!(validate_range(Some(0.5), &(0.0, 1.0)).unwrap(), Some(0.5));
assert_eq!(validate_range(Some(0.0), &(0.0, 1.0)).unwrap(), Some(0.0));
// A degenerate range (min == max) still accepts its single value, for floats and integers.
assert_eq!(validate_range(Some(1.0), &(1.0, 1.0)).unwrap(), Some(1.0));
assert_eq!(validate_range(Some(1_i32), &(1, 1)).unwrap(), Some(1));
// Above the upper bound; the bounds are rendered with `Display`, hence "[0, 1]".
assert_eq!(
validate_range(Some(1.1), &(0.0, 1.0))
.unwrap_err()
.to_string(),
"Value 1.1 is out of range [0, 1]"
);
// Below the lower bound.
assert_eq!(
validate_range(Some(-0.1), &(0.0, 1.0))
.unwrap_err()
.to_string(),
"Value -0.1 is out of range [0, 1]"
);
}
// `scale_value` maps linearly between intervals and rejects zero-width intervals.
#[test]
fn test_scaled_value() {
// Midpoint of [0, 1] maps to the midpoint of [0, 2].
assert_eq!(scale_value(&0.5, &(0.0, 1.0), &(0.0, 2.0)).unwrap(), 1.0);
// The start of the source interval maps to the start of the destination interval.
assert_eq!(scale_value(&0.0, &(0.0, 1.0), &(0.0, 2.0)).unwrap(), 0.0);
// -1 sits a quarter of the way through [-2, 2], i.e. at 1.25 in [1, 2].
assert_eq!(scale_value(&-1.0, &(-2.0, 2.0), &(1.0, 2.0)).unwrap(), 1.25);
// A zero-width source interval is an error.
assert!(scale_value(&1.0, &(1.0, 1.0), &(0.0, 2.0)).is_err());
}
}