lance_context_core/
serde.rs1use arrow_array::RecordBatch;
2use arrow_ipc::writer::StreamWriter;
3use arrow_schema::ArrowError;
4use serde::{Deserialize, Serialize};
5
/// Content type given to text payloads when the caller does not specify one.
pub const CONTENT_TYPE_TEXT: &str = "text/plain";

/// Content type for payloads that hold an Arrow IPC *stream* (as produced by
/// `arrow_ipc::writer::StreamWriter`), not the Arrow IPC file format.
pub const CONTENT_TYPE_ARROW_STREAM: &str = "application/vnd.apache.arrow.stream";
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub struct SerializedContent {
11 pub content_type: String,
12 pub text_payload: Option<String>,
13 pub binary_payload: Option<Vec<u8>>,
14}
15
16impl SerializedContent {
17 pub fn text(value: impl Into<String>, content_type: Option<&str>) -> Self {
18 Self {
19 content_type: content_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(),
20 text_payload: Some(value.into()),
21 binary_payload: None,
22 }
23 }
24
25 pub fn image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> Self {
26 Self {
27 content_type: mime.into(),
28 text_payload: None,
29 binary_payload: Some(bytes.into()),
30 }
31 }
32
33 pub fn dataframe_batches(batches: &[RecordBatch]) -> Result<Self, ArrowError> {
34 let ipc_bytes = record_batches_to_ipc(batches)?;
35 Ok(Self::dataframe_ipc_bytes(ipc_bytes))
36 }
37
38 pub fn dataframe_ipc_bytes(bytes: impl Into<Vec<u8>>) -> Self {
39 Self {
40 content_type: CONTENT_TYPE_ARROW_STREAM.to_string(),
41 text_payload: None,
42 binary_payload: Some(bytes.into()),
43 }
44 }
45}
46
47pub fn serialize_image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> SerializedContent {
48 SerializedContent::image(bytes, mime)
49}
50
51pub fn serialize_dataframe(batches: &[RecordBatch]) -> Result<SerializedContent, ArrowError> {
52 SerializedContent::dataframe_batches(batches)
53}
54
55pub fn serialize_dataframe_ipc(bytes: impl Into<Vec<u8>>) -> SerializedContent {
56 SerializedContent::dataframe_ipc_bytes(bytes)
57}
58
59fn record_batches_to_ipc(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> {
60 if batches.is_empty() {
61 return Err(ArrowError::InvalidArgumentError(
62 "no record batches provided".to_string(),
63 ));
64 }
65
66 let schema = batches[0].schema();
67 let mut buffer = Vec::new();
68 {
69 let mut writer = StreamWriter::try_new(&mut buffer, &schema)?;
70 for batch in batches {
71 if batch.schema() != schema {
72 return Err(ArrowError::SchemaError(
73 "record batch schema mismatch".to_string(),
74 ));
75 }
76 writer.write(batch)?;
77 }
78 writer.finish()?;
79 }
80 Ok(buffer)
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_ipc::reader::StreamReader;
    use arrow_schema::{DataType, Field, Schema};
    use std::io::Cursor;
    use std::sync::Arc;

    /// Two-row batch with non-nullable `id: Int32` and `name: Utf8` columns.
    fn make_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let id_array = Arc::new(Int32Array::from(vec![1, 2]));
        let name_array = Arc::new(StringArray::from(vec!["alpha", "beta"]));
        RecordBatch::try_new(schema, vec![id_array, name_array]).unwrap()
    }

    /// Decodes an Arrow IPC stream payload back into its record batches.
    fn read_ipc(bytes: Vec<u8>) -> Vec<RecordBatch> {
        let reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
        reader.map(|item| item.unwrap()).collect()
    }

    #[test]
    fn text_serialization_defaults_to_plain_text() {
        let content = SerializedContent::text("hello", None);
        assert_eq!(content.content_type, CONTENT_TYPE_TEXT);
        assert_eq!(content.text_payload.as_deref(), Some("hello"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn text_serialization_uses_explicit_content_type() {
        let content = SerializedContent::text("# title", Some("text/markdown"));
        assert_eq!(content.content_type, "text/markdown");
        assert_eq!(content.text_payload.as_deref(), Some("# title"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn image_serialization_sets_payloads() {
        let content = serialize_image(vec![1, 2, 3], "image/png");
        assert_eq!(content.content_type, "image/png");
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![1, 2, 3]));
    }

    #[test]
    fn dataframe_serialization_writes_ipc_stream() {
        let batch = make_batch();
        let content = serialize_dataframe(std::slice::from_ref(&batch)).unwrap();
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);

        let batches = read_ipc(content.binary_payload.expect("expected IPC payload"));
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].schema(), batch.schema());
        assert_eq!(batches[0].num_rows(), batch.num_rows());
        // Round-trip must preserve the column data, not just the shape.
        assert_eq!(batches[0], batch);
    }

    #[test]
    fn dataframe_serialization_preserves_multiple_batches() {
        let batch = make_batch();
        let content = serialize_dataframe(&[batch.clone(), batch.clone()]).unwrap();

        let batches = read_ipc(content.binary_payload.expect("expected IPC payload"));
        assert_eq!(batches.len(), 2);
        assert!(batches.iter().all(|decoded| *decoded == batch));
    }

    #[test]
    fn dataframe_ipc_bytes_are_passed_through_unchanged() {
        let content = serialize_dataframe_ipc(vec![9, 8, 7]);
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![9, 8, 7]));
    }

    #[test]
    fn dataframe_serialization_rejects_empty_batches() {
        let err = serialize_dataframe(&[]).unwrap_err();
        assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
    }

    #[test]
    fn dataframe_serialization_rejects_mismatched_schema() {
        let batch = make_batch();
        let other_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let other_batch =
            RecordBatch::try_new(other_schema, vec![Arc::new(Int32Array::from(vec![1, 2]))])
                .unwrap();

        let err = serialize_dataframe(&[batch, other_batch]).unwrap_err();
        assert!(matches!(err, ArrowError::SchemaError(_)));
    }
}