use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use validator::Validate;
use super::{
common::*,
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::validated::Normalizable;
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_completion_cross_parameters"))]
pub struct CompletionRequest {
pub model: String,
#[validate(custom(function = "validate_completion_prompt"))]
pub prompt: StringOrArray,
#[validate(range(min = 0, max = 20))]
pub best_of: Option<u32>,
#[serde(default)]
pub echo: bool,
#[validate(range(min = -2.0, max = 2.0))]
pub frequency_penalty: Option<f32>,
pub logit_bias: Option<HashMap<String, f32>>,
#[validate(range(min = 0, max = 5))]
pub logprobs: Option<u32>,
#[validate(range(min = 0))]
pub max_tokens: Option<u32>,
#[validate(range(min = 1, max = 128))]
pub n: Option<u32>,
#[validate(range(min = -2.0, max = 2.0))]
pub presence_penalty: Option<f32>,
pub seed: Option<i64>,
#[validate(custom(function = "validate_stop"))]
pub stop: Option<StringOrArray>,
#[serde(default, deserialize_with = "deserialize_null_as_false")]
pub stream: bool,
pub stream_options: Option<StreamOptions>,
pub suffix: Option<String>,
#[validate(range(min = 0.0, max = 2.0))]
pub temperature: Option<f32>,
#[validate(custom(function = "validate_top_p_value"))]
pub top_p: Option<f32>,
pub user: Option<String>,
#[validate(custom(function = "validate_top_k_value"))]
pub top_k: Option<i32>,
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
#[validate(range(min = 1))]
pub min_tokens: Option<u32>,
#[validate(range(min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
pub regex: Option<String>,
pub ebnf: Option<String>,
pub json_schema: Option<String>,
pub stop_token_ids: Option<Vec<u32>>,
#[serde(default)]
pub no_stop_trim: bool,
#[serde(default)]
pub ignore_eos: bool,
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
pub lora_path: Option<String>,
pub session_params: Option<HashMap<String, Value>>,
#[serde(default)]
pub return_hidden_states: bool,
pub sampling_seed: Option<u64>,
#[serde(flatten)]
pub other: Map<String, Value>,
}
impl Normalizable for CompletionRequest {}
fn validate_completion_prompt(prompt: &StringOrArray) -> Result<(), validator::ValidationError> {
match prompt {
StringOrArray::String(_) => {}
StringOrArray::Array(arr) => {
if arr.is_empty() {
let mut error = validator::ValidationError::new("prompt_empty");
error.message = Some("prompt array cannot be empty".into());
return Err(error);
}
}
}
Ok(())
}
fn validate_completion_cross_parameters(
req: &CompletionRequest,
) -> Result<(), validator::ValidationError> {
if req.stream_options.is_some() && !req.stream {
let mut error = validator::ValidationError::new("stream_options_requires_stream");
error.message =
Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
return Err(error);
}
if let (Some(min), Some(max)) = (req.min_tokens, req.max_tokens) {
if min > max {
let mut error = validator::ValidationError::new("min_tokens_exceeds_max");
error.message = Some("min_tokens cannot exceed max_tokens".into());
return Err(error);
}
}
let constraint_count =
req.regex.is_some() as u8 + req.ebnf.is_some() as u8 + req.json_schema.is_some() as u8;
if constraint_count > 1 {
let mut error = validator::ValidationError::new("multiple_constraints");
error.message = Some(
"only one structured output constraint (regex, ebnf, or json_schema) can be active at a time"
.into(),
);
return Err(error);
}
if let (Some(best_of), Some(n)) = (req.best_of, req.n) {
if best_of <= n {
let mut error = validator::ValidationError::new("best_of_less_than_n");
error.message = Some("best_of must be greater than n".into());
return Err(error);
}
}
if req.stream && req.best_of.is_some() {
let mut error = validator::ValidationError::new("best_of_not_supported_with_stream");
error.message = Some("best_of is not supported when stream is enabled".into());
return Err(error);
}
Ok(())
}
impl GenerationRequest for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
match &self.prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
}
}
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionResponse {
pub id: String,
pub object: String, pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Option<Usage>,
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>, }
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionStreamResponse {
pub id: String,
pub object: String, pub created: u64,
pub choices: Vec<CompletionStreamChoice>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
}