openai_protocol/
generate.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{default_true, GenerationRequest, InputIds},
9 sampling_params::SamplingParams,
10};
11use crate::validated::Normalizable;
12
13#[derive(Clone, Debug, Serialize, Deserialize, Validate)]
18#[validate(schema(function = "validate_generate_request"))]
19pub struct GenerateRequest {
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub text: Option<String>,
23
24 pub model: Option<String>,
25
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub input_ids: Option<InputIds>,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
34 pub input_embeds: Option<Value>,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
41 pub image_data: Option<Value>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
48 pub video_data: Option<Value>,
49
50 #[serde(skip_serializing_if = "Option::is_none")]
55 pub audio_data: Option<Value>,
56
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub sampling_params: Option<SamplingParams>,
60
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub return_logprob: Option<bool>,
64
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub logprob_start_len: Option<i32>,
68
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub top_logprobs_num: Option<i32>,
72
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub token_ids_logprob: Option<Vec<u32>>,
76
77 #[serde(default)]
79 pub return_text_in_logprobs: bool,
80
81 #[serde(default)]
83 pub stream: bool,
84
85 #[serde(default = "default_true")]
87 pub log_metrics: bool,
88
89 #[serde(default)]
91 pub return_hidden_states: bool,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
95 pub modalities: Option<Vec<String>>,
96
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub session_params: Option<HashMap<String, Value>>,
100
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub lora_path: Option<String>,
104
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub lora_id: Option<String>,
108
109 #[serde(skip_serializing_if = "Option::is_none")]
113 pub custom_logit_processor: Option<String>,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub bootstrap_host: Option<String>,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub bootstrap_port: Option<i32>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub bootstrap_room: Option<i32>,
126
127 #[serde(skip_serializing_if = "Option::is_none")]
129 pub bootstrap_pair_key: Option<String>,
130
131 #[serde(skip_serializing_if = "Option::is_none")]
133 pub data_parallel_rank: Option<i32>,
134
135 #[serde(default)]
137 pub background: bool,
138
139 #[serde(skip_serializing_if = "Option::is_none")]
141 pub conversation_id: Option<String>,
142
143 #[serde(skip_serializing_if = "Option::is_none")]
145 pub priority: Option<i32>,
146
147 #[serde(skip_serializing_if = "Option::is_none")]
149 pub extra_key: Option<String>,
150
151 #[serde(default)]
153 pub no_logs: bool,
154
155 #[serde(skip_serializing_if = "Option::is_none")]
157 pub custom_labels: Option<HashMap<String, String>>,
158
159 #[serde(default)]
161 pub return_bytes: bool,
162
163 #[serde(default)]
165 pub return_entropy: bool,
166
167 #[serde(skip_serializing_if = "Option::is_none")]
169 pub rid: Option<String>,
170}
171
172impl Normalizable for GenerateRequest {
173 }
175
176fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
178 let has_text = req.text.is_some();
181 let has_input_ids = req.input_ids.is_some();
182
183 let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
184
185 if count == 0 {
186 return Err(validator::ValidationError::new(
187 "Either text or input_ids should be provided.",
188 ));
189 }
190
191 if count > 1 {
192 return Err(validator::ValidationError::new(
193 "Either text or input_ids should be provided.",
194 ));
195 }
196
197 Ok(())
198}
199
200impl GenerationRequest for GenerateRequest {
201 fn is_stream(&self) -> bool {
202 self.stream
203 }
204
205 fn get_model(&self) -> Option<&str> {
206 if let Some(s) = &self.model {
208 Some(s.as_str())
209 } else {
210 None
211 }
212 }
213
214 fn extract_text_for_routing(&self) -> String {
215 if let Some(ref text) = self.text {
217 return text.clone();
218 }
219
220 if let Some(ref input_ids) = self.input_ids {
221 return match input_ids {
222 InputIds::Single(ids) => ids
223 .iter()
224 .map(|&id| id.to_string())
225 .collect::<Vec<String>>()
226 .join(" "),
227 InputIds::Batch(batches) => batches
228 .iter()
229 .flat_map(|batch| batch.iter().map(|&id| id.to_string()))
230 .collect::<Vec<String>>()
231 .join(" "),
232 };
233 }
234
235 String::new()
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct GenerateResponse {
264 pub text: String,
265 pub output_ids: Vec<u32>,
266 pub meta_info: GenerateMetaInfo,
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct GenerateMetaInfo {
272 pub id: String,
273 pub finish_reason: GenerateFinishReason,
274 pub prompt_tokens: u32,
275 pub weight_version: String,
276 #[serde(skip_serializing_if = "Option::is_none")]
277 pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
278 #[serde(skip_serializing_if = "Option::is_none")]
279 pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
280 pub completion_tokens: u32,
281 pub cached_tokens: u32,
282 pub e2e_latency: f64,
283 #[serde(skip_serializing_if = "Option::is_none")]
284 pub matched_stop: Option<Value>,
285}
286
287#[derive(Debug, Clone, Serialize, Deserialize)]
289#[serde(tag = "type", rename_all = "lowercase")]
290pub enum GenerateFinishReason {
291 Length {
292 length: u32,
293 },
294 Stop,
295 #[serde(untagged)]
296 Other(Value),
297}