1use crate::defaults;
6use crate::error::{Error, Result};
7use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
8use serde_json::{json, Value};
9
/// One prompt of a batch request, together with its attachments and
/// generation settings.
///
/// Instances are turned into a per-prompt metadata object plus binary blobs by
/// `build_batch_request_payload`; most scalar fields are copied into that
/// metadata unchanged.
///
/// NOTE(review): the derived `Default` zero-fills every field, so
/// `PromptPayload::default()` does NOT go through the `#[serde(default = ...)]`
/// functions used during deserialization. For example, `tool_choice` is `""`
/// instead of `"auto"`. Likewise, `temperature`, `top_p`, `top_k`,
/// `num_candidates`, `repetition_penalty` and `repetition_context_size` are
/// zero rather than whatever `defaults::*` returns. Confirm whether callers
/// rely on this before changing it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PromptPayload {
    /// Prompt text; its UTF-8 bytes become the prompt's text blob.
    pub prompt: String,
    /// Raw image buffers, one per image. Their bytes are concatenated into the
    /// image data blob and their lengths recorded in a separate span table.
    #[serde(default)]
    pub image_buffers: Vec<Vec<u8>>,
    /// Named binary attachments; see [`CapabilityEntry`].
    #[serde(default)]
    pub capabilities: Vec<CapabilityEntry>,
    /// Explicit segment layout. When empty, the builder uses one text segment
    /// followed by one image segment per entry of `image_buffers`.
    #[serde(default)]
    pub layout: Vec<LayoutEntry>,
    /// Forwarded unchanged into the metadata.
    #[serde(default)]
    pub max_generated_tokens: i32,
    // Sampling parameters below are forwarded unchanged into the metadata.
    #[serde(default = "defaults::temperature")]
    pub temperature: f64,
    #[serde(default = "defaults::top_p")]
    pub top_p: f64,
    #[serde(default = "defaults::top_k")]
    pub top_k: i32,
    #[serde(default)]
    pub min_p: f64,
    /// Only the low 32 bits are sent in the metadata.
    #[serde(default)]
    pub rng_seed: u64,
    #[serde(default)]
    pub stop_sequences: Vec<String>,
    /// Clamped to at least 1 when the request is built.
    #[serde(default = "defaults::num_candidates")]
    pub num_candidates: i32,
    /// When `None`, the builder uses the clamped `num_candidates`.
    #[serde(default)]
    pub best_of: Option<i32>,
    /// When `None`, the builder uses the effective `best_of`.
    #[serde(default)]
    pub final_candidates: Option<i32>,
    #[serde(default)]
    pub frequency_penalty: f64,
    #[serde(default)]
    pub presence_penalty: f64,
    #[serde(default = "defaults::repetition_penalty")]
    pub repetition_penalty: f64,
    #[serde(default = "defaults::repetition_context_size")]
    pub repetition_context_size: i32,
    #[serde(default)]
    pub top_logprobs: i32,
    /// Bias per token id; sent as a list of `{"token", "bias"}` objects.
    #[serde(default)]
    pub logit_bias: std::collections::HashMap<i32, f64>,
    /// JSON text forwarded as a string; it is not parsed here.
    #[serde(default)]
    pub tool_schemas_json: String,
    #[serde(default)]
    pub tool_calling_tokens: ToolCallingTokens,
    /// Defaults to `"auto"` when absent during deserialization.
    #[serde(default = "default_tool_choice")]
    pub tool_choice: String,
    #[serde(default)]
    pub max_tool_calls: i32,
    /// JSON text forwarded as a string; it is not parsed here.
    #[serde(default)]
    pub response_format_json: String,
    /// Optional; serialized as `null` in the metadata when `None`.
    #[serde(default)]
    pub task_name: Option<String>,
    /// Optional; serialized as `null` in the metadata when `None`.
    #[serde(default)]
    pub reasoning_effort: Option<String>,
}
67
/// One tool-call format, copied into the metadata under
/// `tool_calling_tokens.formats`.
///
/// NOTE(review): `call_start`/`call_end` are presumably the marker strings
/// that open and close a tool call in generated text; confirm with the
/// consumer. All fields default to `""` when absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallFormat {
    #[serde(default)]
    pub name: String,
    #[serde(default)]
    pub call_start: String,
    #[serde(default)]
    pub call_end: String,
}
77
/// Tool-calling marker configuration, sent in the metadata under
/// `tool_calling_tokens`.
///
/// NOTE(review): `section_start`/`section_end` presumably delimit a block of
/// tool calls; confirm with the consumer. All fields default to empty when
/// absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallingTokens {
    #[serde(default)]
    pub formats: Vec<ToolCallFormat>,
    #[serde(default)]
    pub section_start: String,
    #[serde(default)]
    pub section_end: String,
}
87
/// Serde default for `PromptPayload::tool_choice` when the field is absent.
fn default_tool_choice() -> String {
    String::from("auto")
}
91
/// A named binary attachment for a prompt.
///
/// All fields are required when deserializing (there is no `#[serde(default)]`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilityEntry {
    /// Capability name, copied into the prompt metadata.
    pub name: String,
    /// Copied verbatim into the metadata and not validated here.
    /// NOTE(review): presumably a location within the prompt; confirm.
    pub position: usize,
    /// Raw payload bytes. They are concatenated, in entry order, into the
    /// prompt's capability data blob, and each length is reported as
    /// `payload_size`.
    pub payload: Vec<u8>,
}
99
/// One entry of an explicit prompt layout.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LayoutEntry {
    /// Segment kind; appears as `"type"` in serialized form. Recognized values
    /// are `"text"`, `"image"` and `"capability"`.
    #[serde(rename = "type")]
    pub segment_type: String,
    /// Segment length. For `"text"` and `"image"` entries this is a byte count.
    /// Across the layout, those counts must add up to the prompt's text size
    /// and total image size respectively.
    pub length: usize,
}
107
/// Byte alignment applied to every blob offset inside the request payload.
const PAYLOAD_ALIGNMENT: usize = 16;

/// Size in bytes of one encoded layout record: a 1-byte segment tag, 7 zero
/// padding bytes, and an 8-byte little-endian length.
const LAYOUT_SEGMENT_SIZE: usize = 16;
113
/// Segment tag written as the first byte of each encoded layout record.
///
/// The explicit discriminants are emitted as raw bytes; do not renumber them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SegmentType {
    Text = 0,
    Image = 1,
    Capability = 2,
}
122
/// Kind of request, sent as the integer `request_type` field of the batch
/// metadata.
///
/// The explicit discriminants are emitted via `request_type as i32`; do not
/// renumber them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum RequestType {
    Generation = 0,
    Embedding = 1,
    Query = 2,
    Point = 3,
    Detect = 4,
    Agent = 5,
    Omni = 6,
}
135
136fn align(offset: usize) -> usize {
138 let remainder = offset % PAYLOAD_ALIGNMENT;
139 if remainder == 0 {
140 offset
141 } else {
142 offset + (PAYLOAD_ALIGNMENT - remainder)
143 }
144}
145
146fn serialize_json_compact(value: &Value) -> Result<Vec<u8>> {
148 let mut buffer = Vec::new();
149 let formatter = serde_json::ser::CompactFormatter;
150 let mut serializer = serde_json::Serializer::with_formatter(&mut buffer, formatter);
151 SerializeTrait::serialize(value, &mut serializer)?;
152 Ok(buffer)
153}
154
155fn encode_layout(segments: &[(SegmentType, usize)]) -> Vec<u8> {
157 let mut buffer = Vec::with_capacity(segments.len() * LAYOUT_SEGMENT_SIZE);
158
159 for (segment_type, length) in segments {
160 buffer.push(*segment_type as u8);
162 buffer.extend_from_slice(&[0u8; 7]);
164 buffer.extend_from_slice(&(*length as u64).to_le_bytes());
166 }
167
168 buffer
169}
170
171pub fn parse_response_delta(data: &[u8]) -> Result<Value> {
173 serde_json::from_slice(data).map_err(Error::from)
174}
175
176#[allow(clippy::too_many_arguments)]
181pub fn build_batch_request_payload(
182 request_id: u64,
183 model_id: &str,
184 model_path: &str,
185 request_type: RequestType,
186 response_channel_id: u64,
187 prompts: &[PromptPayload],
188) -> Result<Vec<u8>> {
189 if prompts.is_empty() {
190 return Err(Error::Serialization(
191 "At least one prompt is required".to_string(),
192 ));
193 }
194
195 let mut blob_fragments: Vec<(usize, Vec<u8>)> = Vec::new();
197 let mut total_size = 0usize;
198
199 let mut reserve_blob = |data: Vec<u8>| -> (usize, usize) {
201 if data.is_empty() {
202 return (0, 0);
203 }
204 total_size = align(total_size);
205 let offset = total_size;
206 let size = data.len();
207 blob_fragments.push((offset, data));
208 total_size += size;
209 (offset, size)
210 };
211
212 let mut prompt_metadata_list: Vec<Value> = Vec::with_capacity(prompts.len());
214
215 for (index, prompt) in prompts.iter().enumerate() {
216 let text_bytes = prompt.prompt.as_bytes().to_vec();
217
218 let (image_span_bytes, image_count, image_data_bytes) =
220 encode_image_buffers(&prompt.image_buffers);
221
222 let (capability_metadata, capability_data_bytes) =
224 encode_capabilities(&prompt.capabilities);
225
226 let layout_data = if prompt.layout.is_empty() {
228 let mut segments = vec![(SegmentType::Text, text_bytes.len())];
230 for img in &prompt.image_buffers {
231 segments.push((SegmentType::Image, img.len()));
232 }
233 encode_layout(&segments)
234 } else {
235 let segments: Vec<(SegmentType, usize)> = prompt
236 .layout
237 .iter()
238 .map(|e| {
239 let seg_type = match e.segment_type.as_str() {
240 "image" => SegmentType::Image,
241 "capability" => SegmentType::Capability,
242 _ => SegmentType::Text,
243 };
244 (seg_type, e.length)
245 })
246 .collect();
247 encode_layout(&segments)
248 };
249 let layout_count = if prompt.layout.is_empty() {
250 1 + prompt.image_buffers.len()
251 } else {
252 prompt.layout.len()
253 };
254
255 if !prompt.layout.is_empty() {
257 let total_image_size: usize = prompt.image_buffers.iter().map(|b| b.len()).sum();
258 validate_layout(text_bytes.len(), total_image_size, &prompt.layout, index)?;
259 }
260
261 let (text_offset, text_size) = reserve_blob(text_bytes);
263 let (image_sizes_offset, _) = reserve_blob(image_span_bytes);
264 let (image_data_offset, image_data_size) = reserve_blob(image_data_bytes);
265 let (capability_data_offset, capability_data_size) = reserve_blob(capability_data_bytes);
266 let (layout_offset, _) = reserve_blob(layout_data);
267
268 let best_of = prompt.best_of.unwrap_or(prompt.num_candidates.max(1));
270 let final_candidates = prompt.final_candidates.unwrap_or(best_of);
271
272 let logit_bias: Vec<Value> = prompt
274 .logit_bias
275 .iter()
276 .map(|(&k, &v)| json!({"token": k, "bias": v}))
277 .collect();
278
279 let prompt_meta = json!({
280 "prompt_index": index,
281 "num_candidates": prompt.num_candidates.max(1),
282 "best_of": best_of,
283 "final_candidates": final_candidates,
284 "max_generated_tokens": prompt.max_generated_tokens,
285 "text_offset": text_offset,
286 "text_size": text_size,
287 "image_data_offset": image_data_offset,
288 "image_data_size": image_data_size,
289 "image_sizes_offset": image_sizes_offset,
290 "image_count": image_count,
291 "capability_data_offset": capability_data_offset,
292 "capability_data_size": capability_data_size,
293 "capabilities": capability_metadata,
294 "layout_offset": layout_offset,
295 "layout_count": layout_count,
296 "temperature": prompt.temperature,
297 "top_p": prompt.top_p,
298 "top_k": prompt.top_k,
299 "min_p": prompt.min_p,
300 "rng_seed": (prompt.rng_seed & 0xFFFFFFFF) as u32,
301 "top_logprobs": prompt.top_logprobs,
302 "frequency_penalty": prompt.frequency_penalty,
303 "presence_penalty": prompt.presence_penalty,
304 "repetition_context_size": prompt.repetition_context_size,
305 "repetition_penalty": prompt.repetition_penalty,
306 "stop_sequences": prompt.stop_sequences,
307 "tool_schemas_json": prompt.tool_schemas_json,
308 "tool_calling_tokens": {
309 "formats": &prompt.tool_calling_tokens.formats,
310 "section_start": prompt.tool_calling_tokens.section_start,
311 "section_end": prompt.tool_calling_tokens.section_end,
312 },
313 "tool_choice": prompt.tool_choice,
314 "max_tool_calls": prompt.max_tool_calls,
315 "response_format_json": prompt.response_format_json,
316 "logit_bias": logit_bias,
317 "task_name": prompt.task_name,
318 "reasoning_effort": prompt.reasoning_effort,
319 });
320
321 prompt_metadata_list.push(prompt_meta);
322 }
323
324 let metadata = json!({
326 "request_id": request_id,
327 "model_id": model_id,
328 "model_path": model_path,
329 "request_type": request_type as i32,
330 "request_channel_id": 0,
331 "response_channel_id": response_channel_id,
332 "prompts": prompt_metadata_list,
333 });
334
335 let metadata_bytes = serialize_json_compact(&metadata)?;
337
338 if metadata_bytes.len() > u32::MAX as usize {
339 return Err(Error::Serialization(
340 "Metadata exceeds 4-byte length prefix capacity".to_string(),
341 ));
342 }
343
344 let mut payload = vec![0u8; total_size];
346 for (offset, data) in blob_fragments {
347 payload[offset..offset + data.len()].copy_from_slice(&data);
348 }
349
350 let mut frame = Vec::with_capacity(4 + metadata_bytes.len() + payload.len());
352 let length = metadata_bytes.len() as u32;
353 frame.extend_from_slice(&length.to_le_bytes());
354 frame.extend_from_slice(&metadata_bytes);
355 frame.extend_from_slice(&payload);
356
357 Ok(frame)
358}
359
360fn validate_layout(
362 text_size: usize,
363 image_data_size: usize,
364 layout: &[LayoutEntry],
365 prompt_index: usize,
366) -> Result<()> {
367 let mut layout_text_bytes = 0usize;
368 let mut layout_image_bytes = 0usize;
369
370 for entry in layout {
371 match entry.segment_type.as_str() {
372 "text" => layout_text_bytes += entry.length,
373 "image" => layout_image_bytes += entry.length,
374 _ => {} }
376 }
377
378 if layout_text_bytes != text_size {
379 return Err(Error::Serialization(format!(
380 "Prompt {}: Layout text bytes ({}) != actual text size ({})",
381 prompt_index, layout_text_bytes, text_size
382 )));
383 }
384
385 if layout_image_bytes != image_data_size {
386 return Err(Error::Serialization(format!(
387 "Prompt {}: Layout image bytes ({}) != actual image size ({})",
388 prompt_index, layout_image_bytes, image_data_size
389 )));
390 }
391
392 Ok(())
393}
394
/// Flattens image buffers into a span table and one contiguous data blob.
///
/// Returns `(spans, count, data)`:
/// - `spans` holds one little-endian `u64` length per buffer, in input order;
/// - `count` is the number of buffers;
/// - `data` is every buffer's bytes concatenated in the same order.
///
/// An empty input yields two empty vectors and a count of 0.
fn encode_image_buffers(buffers: &[Vec<u8>]) -> (Vec<u8>, usize, Vec<u8>) {
    let spans: Vec<u8> = buffers
        .iter()
        .flat_map(|buf| (buf.len() as u64).to_le_bytes())
        .collect();
    let data: Vec<u8> = buffers.concat();
    (spans, buffers.len(), data)
}
413
414fn encode_capabilities(capabilities: &[CapabilityEntry]) -> (Vec<Value>, Vec<u8>) {
416 if capabilities.is_empty() {
417 return (Vec::new(), Vec::new());
418 }
419
420 let mut metadata_list = Vec::with_capacity(capabilities.len());
421 let mut data_buffer = Vec::new();
422
423 for cap in capabilities {
424 metadata_list.push(json!({
425 "name": cap.name,
426 "position": cap.position,
427 "payload_size": cap.payload.len(),
428 }));
429 data_buffer.extend_from_slice(&cap.payload);
430 }
431
432 (metadata_list, data_buffer)
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LayoutEntry` with the given wire type and length.
    fn segment(kind: &str, length: usize) -> LayoutEntry {
        LayoutEntry {
            segment_type: kind.to_string(),
            length,
        }
    }

    #[test]
    fn test_align() {
        // (input offset, expected aligned offset)
        let cases = [(0usize, 0usize), (1, 16), (16, 16), (17, 32)];
        for &(offset, expected) in cases.iter() {
            assert_eq!(align(offset), expected, "align({})", offset);
        }
    }

    #[test]
    fn test_build_batch_request_payload() {
        let prompt = PromptPayload {
            prompt: "Hello, world!".to_string(),
            max_generated_tokens: 100,
            temperature: 0.7,
            top_p: 0.9,
            ..Default::default()
        };

        let frame = build_batch_request_payload(
            1,
            "test-model",
            "/path/to/model",
            RequestType::Generation,
            12345,
            &[prompt],
        )
        .expect("a single well-formed prompt must serialize");

        // The frame must hold at least the 4-byte metadata length prefix.
        assert!(frame.len() > 4);

        let (prefix, body) = frame.split_at(4);
        let metadata_len =
            u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
        let metadata: Value = serde_json::from_slice(&body[..metadata_len]).unwrap();

        assert_eq!(metadata["request_id"], 1);
        assert_eq!(metadata["model_id"], "test-model");
        assert_eq!(metadata["prompts"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn test_validate_layout_success() {
        let layout = [segment("text", 100), segment("image", 5000)];
        assert!(validate_layout(100, 5000, &layout, 0).is_ok());
    }

    #[test]
    fn test_validate_layout_text_mismatch() {
        let err = validate_layout(100, 0, &[segment("text", 50)], 0)
            .expect_err("50 declared text bytes must not match 100 actual");
        assert!(err.to_string().contains("text bytes"));
    }

    #[test]
    fn test_validate_layout_image_mismatch() {
        let layout = [segment("text", 100), segment("image", 1000)];
        let err = validate_layout(100, 5000, &layout, 0)
            .expect_err("1000 declared image bytes must not match 5000 actual");
        assert!(err.to_string().contains("image"));
    }

    #[test]
    fn test_compact_json_serialization() {
        let value = serde_json::json!({ "key": "value", "number": 42 });

        let text = String::from_utf8(serialize_json_compact(&value).unwrap()).unwrap();

        // Compact output has separators but no whitespace after them.
        assert!(!text.contains(": "));
        assert!(!text.contains(", "));
        assert!(text.contains(':') && text.contains(','));
    }
}