use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;
use super::{
common::{default_true, GenerationRequest, InputIds},
sampling_params::SamplingParams,
};
use crate::validated::Normalizable;
/// Request payload for a generation call.
///
/// Serialization notes (all driven by the serde attributes below):
/// - `Option` fields tagged `skip_serializing_if = "Option::is_none"` are left out of the
///   serialized output when they are `None`.
/// - `bool` fields tagged `#[serde(default)]` fall back to `false` when absent from the input.
///
/// Struct-level validation is delegated to `validate_generate_request`, which accepts the
/// request only when exactly one of `text` / `input_ids` is set.
#[derive(Clone, Debug, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest {
/// Textual input. Mutually exclusive with `input_ids` (enforced by the schema validator).
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Optional model name, exposed through `GenerationRequest::get_model`.
/// NOTE(review): unlike the other optional fields this one has no `skip_serializing_if`,
/// so `None` is serialized as an explicit `null` — confirm that this is intended.
pub model: Option<String>,
/// Pre-tokenized input, either a single sequence or a batch (see `InputIds`).
/// Mutually exclusive with `text` (enforced by the schema validator).
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
// The following four fields are free-form JSON; this file does not inspect or validate
// their contents.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_embeds: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_data: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub video_data: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub audio_data: Option<Value>,
/// Sampling configuration (see `SamplingParams`).
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
// Log-probability related options. Their exact semantics are defined by the consumer of
// this request, not by this file — TODO confirm against the backend documentation.
#[serde(skip_serializing_if = "Option::is_none")]
pub return_logprob: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprob_start_len: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs_num: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids_logprob: Option<Vec<u32>>,
#[serde(default)]
pub return_text_in_logprobs: bool,
/// Streaming flag; this is the value reported by `GenerationRequest::is_stream`.
#[serde(default)]
pub stream: bool,
/// When absent from the input, the value comes from `default_true`
/// (presumably `true` — confirm in `common`).
#[serde(default = "default_true")]
pub log_metrics: bool,
#[serde(default)]
pub return_hidden_states: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<String>>,
/// Arbitrary string-keyed JSON values; not interpreted in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_logit_processor: Option<String>,
// `bootstrap_*` fields: optional connection/pairing details. Nothing in this file reads
// them — NOTE(review): document their contract where they are consumed.
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_host: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_room: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bootstrap_pair_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub data_parallel_rank: Option<i32>,
#[serde(default)]
pub background: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub conversation_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub extra_key: Option<String>,
#[serde(default)]
pub no_logs: bool,
/// String-to-string labels; not interpreted in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_labels: Option<HashMap<String, String>>,
#[serde(default)]
pub return_bytes: bool,
#[serde(default)]
pub return_entropy: bool,
/// Optional request identifier (presumably "request id" — confirm with the consumer).
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
/// Marks `GenerateRequest` as `Normalizable` without overriding anything: the impl body is
/// empty, so only the trait's default behavior (if any) applies.
impl Normalizable for GenerateRequest {
}
/// Schema-level validator for [`GenerateRequest`].
///
/// Accepts the request only when exactly one of `text` and `input_ids` is present.
///
/// # Errors
///
/// Returns a `ValidationError` with the code
/// `"Either text or input_ids should be provided."` when neither field is set, and the
/// same error when both are set.
fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
    match (req.text.is_some(), req.input_ids.is_some()) {
        // Exactly one input source: valid.
        (true, false) | (false, true) => Ok(()),
        // Neither or both: rejected with the same error code in either case.
        (false, false) | (true, true) => Err(validator::ValidationError::new(
            "Either text or input_ids should be provided.",
        )),
    }
}
impl GenerationRequest for GenerateRequest {
    /// Reports the request's `stream` flag.
    fn is_stream(&self) -> bool {
        self.stream
    }

    /// Borrows the model name, or `None` when no model was specified.
    fn get_model(&self) -> Option<&str> {
        self.model.as_deref()
    }

    /// Builds the string used for routing decisions.
    ///
    /// Precedence:
    /// 1. `text`, returned as an owned copy when present;
    /// 2. otherwise `input_ids`, rendered as decimal ids separated by single spaces
    ///    (batched ids are flattened in order into one sequence);
    /// 3. otherwise an empty string.
    fn extract_text_for_routing(&self) -> String {
        if let Some(text) = self.text.as_ref() {
            return text.clone();
        }

        let rendered: Vec<String> = match self.input_ids.as_ref() {
            None => return String::new(),
            Some(InputIds::Single(ids)) => ids.iter().map(|id| id.to_string()).collect(),
            Some(InputIds::Batch(batches)) => batches
                .iter()
                .flat_map(|batch| batch.iter())
                .map(|id| id.to_string())
                .collect(),
        };
        rendered.join(" ")
    }
}
/// Response payload for a generation call: the produced text, the corresponding output ids,
/// and accompanying metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateResponse {
/// Generated text.
pub text: String,
/// Output ids (presumably the generated token ids — confirm with the producer).
pub output_ids: Vec<u32>,
/// Per-response metadata; see `GenerateMetaInfo`.
pub meta_info: GenerateMetaInfo,
}
/// Metadata attached to a `GenerateResponse`.
///
/// The three `Option` fields are omitted from serialized output when `None`; every other
/// field is always serialized and has no serde default, so it is required on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateMetaInfo {
/// Identifier string (NOTE(review): likely the request id — confirm with the producer).
pub id: String,
/// Finish reason; see `GenerateFinishReason` for the wire representation.
pub finish_reason: GenerateFinishReason,
pub prompt_tokens: u32,
pub weight_version: String,
// Nested log-probability data; the layout of the inner vectors is not established by this
// file — TODO confirm against the producer.
#[serde(skip_serializing_if = "Option::is_none")]
pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
pub completion_tokens: u32,
pub cached_tokens: u32,
/// End-to-end latency (unit not specified here — presumably seconds; confirm).
pub e2e_latency: f64,
/// Free-form JSON describing the matched stop condition, if any; not interpreted here.
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
}
/// Finish reason reported in `GenerateMetaInfo`.
///
/// Wire format: internally tagged — a `type` field carries the variant name in lowercase
/// (`"length"`, `"stop"`). The `Other` variant is marked untagged and is therefore
/// (de)serialized as its bare JSON value, without the `type` tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum GenerateFinishReason {
/// Tagged `"length"`; carries a `length` value next to the tag.
Length {
length: u32,
},
/// Tagged `"stop"`; no payload.
Stop,
/// Untagged catch-all holding an arbitrary JSON value.
#[serde(untagged)]
Other(Value),
}