1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{default_true, deserialize_null_as_false, GenerationRequest, InputIds},
9 sampling_params::SamplingParams,
10};
11use crate::validated::Normalizable;
12
// Request payload for a `generate` call.
//
// Prompt source: exactly one of `text` or `input_ids` must be set. The
// struct-level `validate_generate_request` check rejects requests that set
// neither or both. `input_embeds` is not part of that check.
//
// Serialization: every `Option` field is omitted from serialized output when
// it is `None`. Plain `bool` fields use `#[serde(default)]` (i.e. `false`)
// when absent, except where a field's comment says otherwise.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest {
    // Prompt as raw text. Mutually exclusive with `input_ids`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    // Model name. Filled in by `default_unknown_model` when absent from the input.
    #[serde(default = "super::common::default_unknown_model")]
    pub model: String,

    // Pre-tokenized prompt, either a single sequence or a batch
    // (`InputIds::Single` / `InputIds::Batch`). Mutually exclusive with `text`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,

    // Untyped JSON. Not inspected by `validate_generate_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_embeds: Option<Value>,

    // Multimodal inputs, kept as untyped JSON. Their shape is not validated
    // by this struct.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_data: Option<Value>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub video_data: Option<Value>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_data: Option<Value>,

    // Sampling configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,

    // Log-probability options. None of these are range-checked by this struct.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_logprob: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprob_start_len: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs_num: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_ids_logprob: Option<Vec<u32>>,

    #[serde(default)]
    pub return_text_in_logprobs: bool,

    // Streaming flag, reported by `GenerationRequest::is_stream`. A missing
    // field becomes `false` via `default`. `deserialize_null_as_false`
    // presumably also maps an explicit JSON `null` to `false` (confirm in
    // `common`).
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub stream: bool,

    // Defaults to `true` (via `default_true`) when absent, unlike the other flags.
    #[serde(default = "default_true")]
    pub log_metrics: bool,

    #[serde(default)]
    pub return_hidden_states: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,

    // Free-form session parameters as a JSON object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, Value>>,

    // LoRA adapter selection.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_id: Option<String>,

    // Opaque string. Its contents/encoding are not checked here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_logit_processor: Option<String>,

    // Bootstrap coordination fields (presumably for disaggregated
    // prefill/decode serving; confirm).
    // NOTE(review): `bootstrap_room` is `i32`. Confirm no producer sends
    // values above `i32::MAX`, or deserialization will fail.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_host: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_port: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_room: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap_pair_key: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_parallel_rank: Option<i32>,

    #[serde(default)]
    pub background: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub conversation_id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra_key: Option<String>,

    #[serde(default)]
    pub no_logs: bool,

    // String-to-string labels. Their use is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_labels: Option<HashMap<String, String>>,

    #[serde(default)]
    pub return_bytes: bool,

    #[serde(default)]
    pub return_entropy: bool,

    // Presumably a client-supplied request id (confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,
}
172
173impl Normalizable for GenerateRequest {
174 }
176
177fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
179 let has_text = req.text.is_some();
182 let has_input_ids = req.input_ids.is_some();
183
184 let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
185
186 if count == 0 {
187 return Err(validator::ValidationError::new(
188 "Either text or input_ids should be provided.",
189 ));
190 }
191
192 if count > 1 {
193 return Err(validator::ValidationError::new(
194 "Either text or input_ids should be provided.",
195 ));
196 }
197
198 Ok(())
199}
200
201impl GenerationRequest for GenerateRequest {
202 fn is_stream(&self) -> bool {
203 self.stream
204 }
205
206 fn get_model(&self) -> Option<&str> {
207 Some(self.model.as_str())
208 }
209
210 fn extract_text_for_routing(&self) -> String {
211 if let Some(ref text) = self.text {
213 return text.clone();
214 }
215
216 if let Some(ref input_ids) = self.input_ids {
217 return match input_ids {
218 InputIds::Single(ids) => ids
219 .iter()
220 .map(|&id| id.to_string())
221 .collect::<Vec<String>>()
222 .join(" "),
223 InputIds::Batch(batches) => batches
224 .iter()
225 .flat_map(|batch| batch.iter().map(|&id| id.to_string()))
226 .collect::<Vec<String>>()
227 .join(" "),
228 };
229 }
230
231 String::new()
233 }
234}
235
// Result of a `generate` call: the output `text`, the output token ids
// (`output_ids`), and per-request metadata (`meta_info`).
// Field order here is the serialized JSON key order.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct GenerateResponse {
    pub text: String,
    pub output_ids: Vec<u32>,
    pub meta_info: GenerateMetaInfo,
}
264
// Metadata attached to a `GenerateResponse`.
//
// `skip_serializing_none` omits every `None` field from serialized output.
// It must stay above the `derive` attribute so it is applied before
// `Serialize` is derived.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct GenerateMetaInfo {
    pub id: String,
    pub finish_reason: GenerateFinishReason,
    pub prompt_tokens: u32,
    pub weight_version: String,
    // Per-token logprob rows. The meaning of each inner column is not defined
    // in this file. NOTE(review): confirm the column layout with the producer.
    pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
    pub completion_tokens: u32,
    pub cached_tokens: u32,
    // NOTE(review): units are not stated here (presumably seconds); confirm.
    pub e2e_latency: f64,
    // Untyped JSON, presumably the stop string/token that ended generation
    // (confirm).
    pub matched_stop: Option<Value>,
}
280
// Why generation finished, as reported in `GenerateMetaInfo`.
//
// `#[serde(untagged)]`: deserialization tries the variants in declaration
// order and keeps the first one that fits:
//   - `type` ("length"/"stop") plus a numeric `length` -> `Length`
//   - `type` ("length"/"stop") without `length`        -> `Stop`
//   - anything else (e.g. an unrecognized `type`)      -> `Other` (raw JSON)
// Serialization writes the variant's fields (or the raw `Value`) with no tag.
//
// NOTE(review): the variant is not tied to `finish_type`. For example,
// `{"type":"length"}` lands in `Stop` and `{"type":"stop","length":1}` lands
// in `Length`. Check `finish_type` as well as the variant if that matters.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum GenerateFinishReason {
    Length {
        #[serde(rename = "type")]
        finish_type: GenerateFinishType,
        length: u32,
    },
    Stop {
        #[serde(rename = "type")]
        finish_type: GenerateFinishType,
    },
    Other(Value),
}
296
// Value of the `type` field inside `GenerateFinishReason`. Variants are
// (de)serialized in lowercase: "length" and "stop".
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum GenerateFinishType {
    Length,
    Stop,
}