// orchard/ipc/serialization.rs

//! Binary serialization for PIE IPC protocol.
//!
//! Wire format: [4 bytes: metadata length][JSON metadata][16-byte aligned binary blobs]

use crate::defaults;
use crate::error::{Error, Result};
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
use serde_json::{json, Value};

/// A single prompt payload for batched requests.
///
/// Every field except `prompt` has a serde default, so partial JSON objects
/// deserialize cleanly; sampling defaults come from the `defaults` module.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PromptPayload {
    /// Prompt text; serialized on the wire as a raw UTF-8 byte blob.
    pub prompt: String,
    /// Encoded image bytes, one inner `Vec<u8>` per image.
    #[serde(default)]
    pub image_buffers: Vec<Vec<u8>>,
    /// Multimodal capability entries (name + position + opaque payload).
    #[serde(default)]
    pub capabilities: Vec<CapabilityEntry>,
    /// Explicit content ordering; when empty, a default layout of
    /// text-followed-by-images is generated at serialization time.
    #[serde(default)]
    pub layout: Vec<LayoutEntry>,
    /// Generation budget in tokens (0 when unset).
    #[serde(default)]
    pub max_generated_tokens: i32,
    /// Sampling temperature; defaults to `defaults::temperature()`.
    #[serde(default = "defaults::temperature")]
    pub temperature: f64,
    /// Nucleus sampling threshold; defaults to `defaults::top_p()`.
    #[serde(default = "defaults::top_p")]
    pub top_p: f64,
    /// Top-k cutoff; defaults to `defaults::top_k()`.
    #[serde(default = "defaults::top_k")]
    pub top_k: i32,
    /// Min-p cutoff (0.0 when unset).
    #[serde(default)]
    pub min_p: f64,
    /// RNG seed; only the low 32 bits are sent on the wire.
    #[serde(default)]
    pub rng_seed: u64,
    /// Strings that terminate generation.
    #[serde(default)]
    pub stop_sequences: Vec<String>,
    /// Number of candidate completions; clamped to >= 1 at serialization.
    #[serde(default = "defaults::num_candidates")]
    pub num_candidates: i32,
    /// Candidates to sample; defaults to `max(num_candidates, 1)` when `None`.
    #[serde(default)]
    pub best_of: Option<i32>,
    /// Candidates kept at the end; defaults to `best_of` when `None`.
    #[serde(default)]
    pub final_candidates: Option<i32>,
    /// Frequency penalty (0.0 when unset).
    #[serde(default)]
    pub frequency_penalty: f64,
    /// Presence penalty (0.0 when unset).
    #[serde(default)]
    pub presence_penalty: f64,
    /// Repetition penalty; defaults to `defaults::repetition_penalty()`.
    #[serde(default = "defaults::repetition_penalty")]
    pub repetition_penalty: f64,
    /// Window for the repetition penalty; defaults to
    /// `defaults::repetition_context_size()`.
    #[serde(default = "defaults::repetition_context_size")]
    pub repetition_context_size: i32,
    /// Number of top logprobs to report per token (0 = none).
    #[serde(default)]
    pub top_logprobs: i32,
    /// Per-token logit biases; serialized as `[{token, bias}, ...]` objects.
    #[serde(default)]
    pub logit_bias: std::collections::HashMap<i32, f64>,
    /// JSON-encoded tool schemas, passed through verbatim.
    #[serde(default)]
    pub tool_schemas_json: String,
    /// Delimiter tokens used for tool calling.
    #[serde(default)]
    pub tool_calling_tokens: ToolCallingTokens,
    /// Tool-choice mode; defaults to "auto".
    #[serde(default = "default_tool_choice")]
    pub tool_choice: String,
    /// Maximum number of tool calls (0 when unset).
    #[serde(default)]
    pub max_tool_calls: i32,
    /// JSON-encoded response-format constraint, passed through verbatim.
    #[serde(default)]
    pub response_format_json: String,
    /// Optional task name, forwarded as-is in the metadata.
    #[serde(default)]
    pub task_name: Option<String>,
    /// Optional reasoning-effort hint, forwarded as-is in the metadata.
    #[serde(default)]
    pub reasoning_effort: Option<String>,
}

/// One named tool-call delimiter format.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallFormat {
    /// Format identifier.
    #[serde(default)]
    pub name: String,
    /// Opening marker for a tool call.
    #[serde(default)]
    pub call_start: String,
    /// Closing marker for a tool call.
    #[serde(default)]
    pub call_end: String,
}

/// Delimiter configuration for tool calling, forwarded in prompt metadata.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallingTokens {
    /// Accepted tool-call formats.
    #[serde(default)]
    pub formats: Vec<ToolCallFormat>,
    /// Marker opening the tool-call section.
    #[serde(default)]
    pub section_start: String,
    /// Marker closing the tool-call section.
    #[serde(default)]
    pub section_end: String,
}

/// Serde default for `PromptPayload::tool_choice`: the "auto" mode.
fn default_tool_choice() -> String {
    String::from("auto")
}

/// Capability entry for multimodal content.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilityEntry {
    /// Capability name, sent in the per-capability JSON metadata.
    pub name: String,
    /// Position within the prompt — presumably a byte/token offset; the
    /// engine defines the exact unit (TODO confirm).
    pub position: usize,
    /// Opaque payload bytes, concatenated into the capability data blob.
    pub payload: Vec<u8>,
}

/// Layout entry describing content ordering.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LayoutEntry {
    /// Segment kind: "text", "image", or "capability"; any other string is
    /// treated as text at encoding time.
    #[serde(rename = "type")]
    pub segment_type: String,
    /// Segment length in bytes.
    pub length: usize,
}

/// Payload alignment boundary (16 bytes); every blob written by
/// `build_batch_request_payload` starts on a multiple of this.
const PAYLOAD_ALIGNMENT: usize = 16;

/// Layout segment size: 1 byte type + 7 padding + 8 bytes length
const LAYOUT_SEGMENT_SIZE: usize = 16;

/// Segment types matching C++ SerializedSegmentType
///
/// Discriminants are written as the first byte of each layout record and
/// must stay in sync with the engine side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SegmentType {
    /// UTF-8 prompt text bytes.
    Text = 0,
    /// Encoded image bytes.
    Image = 1,
    /// Opaque capability payload bytes.
    Capability = 2,
}

/// Request type codes matching PIE
///
/// Sent as the integer `request_type` field in request metadata; values must
/// stay in sync with the engine's definitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum RequestType {
    Generation = 0,
    Embedding = 1,
    Query = 2,
    Point = 3,
    Detect = 4,
    Agent = 5,
    Omni = 6,
}

136/// Align offset to payload alignment boundary
137fn align(offset: usize) -> usize {
138    let remainder = offset % PAYLOAD_ALIGNMENT;
139    if remainder == 0 {
140        offset
141    } else {
142        offset + (PAYLOAD_ALIGNMENT - remainder)
143    }
144}
145
146/// Serialize JSON to compact format (no spaces after separators, matching Python).
147fn serialize_json_compact(value: &Value) -> Result<Vec<u8>> {
148    let mut buffer = Vec::new();
149    let formatter = serde_json::ser::CompactFormatter;
150    let mut serializer = serde_json::Serializer::with_formatter(&mut buffer, formatter);
151    SerializeTrait::serialize(value, &mut serializer)?;
152    Ok(buffer)
153}
154
155/// Encode layout segments.
156fn encode_layout(segments: &[(SegmentType, usize)]) -> Vec<u8> {
157    let mut buffer = Vec::with_capacity(segments.len() * LAYOUT_SEGMENT_SIZE);
158
159    for (segment_type, length) in segments {
160        // 1 byte type
161        buffer.push(*segment_type as u8);
162        // 7 bytes padding
163        buffer.extend_from_slice(&[0u8; 7]);
164        // 8 bytes length (little-endian)
165        buffer.extend_from_slice(&(*length as u64).to_le_bytes());
166    }
167
168    buffer
169}
170
171/// Parse a response delta from JSON.
172pub fn parse_response_delta(data: &[u8]) -> Result<Value> {
173    serde_json::from_slice(data).map_err(Error::from)
174}
175
/// Build a batched request payload with multiple prompts.
///
/// This is the correct implementation that sends all prompts in ONE IPC message,
/// allowing the engine to schedule them together efficiently.
///
/// Frame layout: `[4-byte LE metadata length][compact JSON metadata][blobs]`.
/// Each blob (text, image span table, image data, capability data, layout
/// records) is placed at a 16-byte-aligned offset, and those offsets/sizes
/// are recorded in the per-prompt JSON metadata.
///
/// # Errors
/// Returns `Error::Serialization` when `prompts` is empty, when an explicitly
/// provided layout disagrees with the actual text/image byte counts, or when
/// the serialized metadata exceeds the 4-byte length prefix capacity.
#[allow(clippy::too_many_arguments)]
pub fn build_batch_request_payload(
    request_id: u64,
    model_id: &str,
    model_path: &str,
    request_type: RequestType,
    response_channel_id: u64,
    prompts: &[PromptPayload],
) -> Result<Vec<u8>> {
    if prompts.is_empty() {
        return Err(Error::Serialization(
            "At least one prompt is required".to_string(),
        ));
    }

    // Track blob fragments: (offset, data)
    let mut blob_fragments: Vec<(usize, Vec<u8>)> = Vec::new();
    let mut total_size = 0usize;

    // Reserve blob space with alignment. Empty blobs report (offset 0, size 0);
    // readers must check the size before using the offset.
    let mut reserve_blob = |data: Vec<u8>| -> (usize, usize) {
        if data.is_empty() {
            return (0, 0);
        }
        total_size = align(total_size);
        let offset = total_size;
        let size = data.len();
        blob_fragments.push((offset, data));
        total_size += size;
        (offset, size)
    };

    // Build metadata for each prompt
    let mut prompt_metadata_list: Vec<Value> = Vec::with_capacity(prompts.len());

    for (index, prompt) in prompts.iter().enumerate() {
        let text_bytes = prompt.prompt.as_bytes().to_vec();

        // Encode image buffers
        let (image_span_bytes, image_count, image_data_bytes) =
            encode_image_buffers(&prompt.image_buffers);

        // Encode capabilities
        let (capability_metadata, capability_data_bytes) =
            encode_capabilities(&prompt.capabilities);

        // Build layout
        let layout_data = if prompt.layout.is_empty() {
            // Default layout: text followed by images
            let mut segments = vec![(SegmentType::Text, text_bytes.len())];
            for img in &prompt.image_buffers {
                segments.push((SegmentType::Image, img.len()));
            }
            encode_layout(&segments)
        } else {
            let segments: Vec<(SegmentType, usize)> = prompt
                .layout
                .iter()
                .map(|e| {
                    // Unrecognized segment type strings fall back to Text.
                    let seg_type = match e.segment_type.as_str() {
                        "image" => SegmentType::Image,
                        "capability" => SegmentType::Capability,
                        _ => SegmentType::Text,
                    };
                    (seg_type, e.length)
                })
                .collect();
            encode_layout(&segments)
        };
        // Segment count mirrors the branch above: 1 text + N images by default.
        let layout_count = if prompt.layout.is_empty() {
            1 + prompt.image_buffers.len()
        } else {
            prompt.layout.len()
        };

        // Validate layout if explicitly provided
        if !prompt.layout.is_empty() {
            let total_image_size: usize = prompt.image_buffers.iter().map(|b| b.len()).sum();
            validate_layout(text_bytes.len(), total_image_size, &prompt.layout, index)?;
        }

        // Reserve space for all blob data
        let (text_offset, text_size) = reserve_blob(text_bytes);
        let (image_sizes_offset, _) = reserve_blob(image_span_bytes);
        let (image_data_offset, image_data_size) = reserve_blob(image_data_bytes);
        let (capability_data_offset, capability_data_size) = reserve_blob(capability_data_bytes);
        let (layout_offset, _) = reserve_blob(layout_data);

        // Compute best_of and final_candidates with proper defaults
        let best_of = prompt.best_of.unwrap_or(prompt.num_candidates.max(1));
        let final_candidates = prompt.final_candidates.unwrap_or(best_of);

        // Convert logit_bias HashMap to array of {token, bias} objects (matching Python)
        let logit_bias: Vec<Value> = prompt
            .logit_bias
            .iter()
            .map(|(&k, &v)| json!({"token": k, "bias": v}))
            .collect();

        let prompt_meta = json!({
            "prompt_index": index,
            "num_candidates": prompt.num_candidates.max(1),
            "best_of": best_of,
            "final_candidates": final_candidates,
            "max_generated_tokens": prompt.max_generated_tokens,
            "text_offset": text_offset,
            "text_size": text_size,
            "image_data_offset": image_data_offset,
            "image_data_size": image_data_size,
            "image_sizes_offset": image_sizes_offset,
            "image_count": image_count,
            "capability_data_offset": capability_data_offset,
            "capability_data_size": capability_data_size,
            "capabilities": capability_metadata,
            "layout_offset": layout_offset,
            "layout_count": layout_count,
            "temperature": prompt.temperature,
            "top_p": prompt.top_p,
            "top_k": prompt.top_k,
            "min_p": prompt.min_p,
            // Only a 32-bit seed goes on the wire; the high half is dropped.
            "rng_seed": (prompt.rng_seed & 0xFFFFFFFF) as u32,
            "top_logprobs": prompt.top_logprobs,
            "frequency_penalty": prompt.frequency_penalty,
            "presence_penalty": prompt.presence_penalty,
            "repetition_context_size": prompt.repetition_context_size,
            "repetition_penalty": prompt.repetition_penalty,
            "stop_sequences": prompt.stop_sequences,
            "tool_schemas_json": prompt.tool_schemas_json,
            "tool_calling_tokens": {
                "formats": &prompt.tool_calling_tokens.formats,
                "section_start": prompt.tool_calling_tokens.section_start,
                "section_end": prompt.tool_calling_tokens.section_end,
            },
            "tool_choice": prompt.tool_choice,
            "max_tool_calls": prompt.max_tool_calls,
            "response_format_json": prompt.response_format_json,
            "logit_bias": logit_bias,
            "task_name": prompt.task_name,
            "reasoning_effort": prompt.reasoning_effort,
        });

        prompt_metadata_list.push(prompt_meta);
    }

    // Build full metadata
    let metadata = json!({
        "request_id": request_id,
        "model_id": model_id,
        "model_path": model_path,
        "request_type": request_type as i32,
        // NOTE(review): always 0 here — presumably unused for batched
        // requests; confirm against the engine.
        "request_channel_id": 0,
        "response_channel_id": response_channel_id,
        "prompts": prompt_metadata_list,
    });

    // Use compact JSON serialization (no spaces, matching Python)
    let metadata_bytes = serialize_json_compact(&metadata)?;

    if metadata_bytes.len() > u32::MAX as usize {
        return Err(Error::Serialization(
            "Metadata exceeds 4-byte length prefix capacity".to_string(),
        ));
    }

    // Build payload buffer; zero-filled so alignment gaps are zero bytes.
    let mut payload = vec![0u8; total_size];
    for (offset, data) in blob_fragments {
        payload[offset..offset + data.len()].copy_from_slice(&data);
    }

    // Build frame: [4 bytes length][metadata][payload]
    let mut frame = Vec::with_capacity(4 + metadata_bytes.len() + payload.len());
    let length = metadata_bytes.len() as u32;
    frame.extend_from_slice(&length.to_le_bytes());
    frame.extend_from_slice(&metadata_bytes);
    frame.extend_from_slice(&payload);

    Ok(frame)
}

360/// Validate that layout bytes match actual content sizes.
361fn validate_layout(
362    text_size: usize,
363    image_data_size: usize,
364    layout: &[LayoutEntry],
365    prompt_index: usize,
366) -> Result<()> {
367    let mut layout_text_bytes = 0usize;
368    let mut layout_image_bytes = 0usize;
369
370    for entry in layout {
371        match entry.segment_type.as_str() {
372            "text" => layout_text_bytes += entry.length,
373            "image" => layout_image_bytes += entry.length,
374            _ => {} // capabilities are handled separately
375        }
376    }
377
378    if layout_text_bytes != text_size {
379        return Err(Error::Serialization(format!(
380            "Prompt {}: Layout text bytes ({}) != actual text size ({})",
381            prompt_index, layout_text_bytes, text_size
382        )));
383    }
384
385    if layout_image_bytes != image_data_size {
386        return Err(Error::Serialization(format!(
387            "Prompt {}: Layout image bytes ({}) != actual image size ({})",
388            prompt_index, layout_image_bytes, image_data_size
389        )));
390    }
391
392    Ok(())
393}
394
/// Encode image buffers into span array and concatenated data.
///
/// Returns `(span_bytes, image_count, data_bytes)`: the span array holds one
/// little-endian u64 length per image, and the data is every buffer joined
/// in order. Empty input yields two empty buffers and a count of 0.
fn encode_image_buffers(buffers: &[Vec<u8>]) -> (Vec<u8>, usize, Vec<u8>) {
    if buffers.is_empty() {
        return (Vec::new(), 0, Vec::new());
    }

    let mut spans = Vec::with_capacity(buffers.len() * 8);
    for buf in buffers {
        spans.extend_from_slice(&(buf.len() as u64).to_le_bytes());
    }

    // `concat` sizes the output once from the summed lengths.
    let data = buffers.concat();

    (spans, buffers.len(), data)
}

414/// Encode capability entries.
415fn encode_capabilities(capabilities: &[CapabilityEntry]) -> (Vec<Value>, Vec<u8>) {
416    if capabilities.is_empty() {
417        return (Vec::new(), Vec::new());
418    }
419
420    let mut metadata_list = Vec::with_capacity(capabilities.len());
421    let mut data_buffer = Vec::new();
422
423    for cap in capabilities {
424        metadata_list.push(json!({
425            "name": cap.name,
426            "position": cap.position,
427            "payload_size": cap.payload.len(),
428        }));
429        data_buffer.extend_from_slice(&cap.payload);
430    }
431
432    (metadata_list, data_buffer)
433}
434
#[cfg(test)]
mod tests {
    // Unit tests cover the pure helpers (align, layout validation, compact
    // JSON) plus one end-to-end request frame round-trip.
    use super::*;

    #[test]
    fn test_align() {
        assert_eq!(align(0), 0);
        assert_eq!(align(1), 16);
        assert_eq!(align(16), 16);
        assert_eq!(align(17), 32);
    }

    #[test]
    fn test_build_batch_request_payload() {
        let prompt = PromptPayload {
            prompt: "Hello, world!".to_string(),
            max_generated_tokens: 100,
            temperature: 0.7,
            top_p: 0.9,
            ..Default::default()
        };

        let payload = build_batch_request_payload(
            1,
            "test-model",
            "/path/to/model",
            RequestType::Generation,
            12345,
            &[prompt],
        )
        .unwrap();

        // Should have 4-byte length prefix
        assert!(payload.len() > 4);

        // Read length prefix
        let length = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]) as usize;

        // Metadata should be valid JSON
        let metadata: Value = serde_json::from_slice(&payload[4..4 + length]).unwrap();
        assert_eq!(metadata["request_id"], 1);
        assert_eq!(metadata["model_id"], "test-model");
        assert_eq!(metadata["prompts"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn test_validate_layout_success() {
        let layout = vec![
            LayoutEntry {
                segment_type: "text".to_string(),
                length: 100,
            },
            LayoutEntry {
                segment_type: "image".to_string(),
                length: 5000,
            },
        ];

        let result = validate_layout(100, 5000, &layout, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_layout_text_mismatch() {
        let layout = vec![LayoutEntry {
            segment_type: "text".to_string(),
            length: 50, // Mismatch!
        }];

        let result = validate_layout(100, 0, &layout, 0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("text bytes"));
    }

    #[test]
    fn test_validate_layout_image_mismatch() {
        let layout = vec![
            LayoutEntry {
                segment_type: "text".to_string(),
                length: 100,
            },
            LayoutEntry {
                segment_type: "image".to_string(),
                length: 1000, // Mismatch!
            },
        ];

        let result = validate_layout(100, 5000, &layout, 0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("image"));
    }

    #[test]
    fn test_compact_json_serialization() {
        let value = serde_json::json!({
            "key": "value",
            "number": 42
        });

        let compact = serialize_json_compact(&value).unwrap();
        let compact_str = String::from_utf8(compact).unwrap();

        // Should not have spaces after colons or commas
        assert!(!compact_str.contains(": "));
        assert!(!compact_str.contains(", "));
        assert!(compact_str.contains(":") && compact_str.contains(","));
    }
}