Skip to main content

lance_context_core/
serde.rs

1use arrow_array::RecordBatch;
2use arrow_ipc::writer::StreamWriter;
3use arrow_schema::ArrowError;
4use serde::{Deserialize, Serialize};
5
/// Content type used for text payloads; [`SerializedContent::text`] falls back
/// to this value when the caller passes `None` as the content type.
pub const CONTENT_TYPE_TEXT: &str = "text/plain";
/// Content type assigned by [`SerializedContent::dataframe_ipc_bytes`] to
/// payloads carrying Arrow IPC stream bytes.
pub const CONTENT_TYPE_ARROW_STREAM: &str = "application/vnd.apache.arrow.stream";
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub struct SerializedContent {
11    pub content_type: String,
12    pub text_payload: Option<String>,
13    pub binary_payload: Option<Vec<u8>>,
14}
15
16impl SerializedContent {
17    pub fn text(value: impl Into<String>, content_type: Option<&str>) -> Self {
18        Self {
19            content_type: content_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(),
20            text_payload: Some(value.into()),
21            binary_payload: None,
22        }
23    }
24
25    pub fn image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> Self {
26        Self {
27            content_type: mime.into(),
28            text_payload: None,
29            binary_payload: Some(bytes.into()),
30        }
31    }
32
33    pub fn dataframe_batches(batches: &[RecordBatch]) -> Result<Self, ArrowError> {
34        let ipc_bytes = record_batches_to_ipc(batches)?;
35        Ok(Self::dataframe_ipc_bytes(ipc_bytes))
36    }
37
38    pub fn dataframe_ipc_bytes(bytes: impl Into<Vec<u8>>) -> Self {
39        Self {
40            content_type: CONTENT_TYPE_ARROW_STREAM.to_string(),
41            text_payload: None,
42            binary_payload: Some(bytes.into()),
43        }
44    }
45}
46
/// Free-function form of [`SerializedContent::image`]: wraps `bytes` as a
/// binary payload labelled with the content type `mime`.
pub fn serialize_image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> SerializedContent {
    SerializedContent::image(bytes, mime)
}
50
/// Free-function form of [`SerializedContent::dataframe_batches`]: encodes
/// `batches` and wraps the bytes as a [`CONTENT_TYPE_ARROW_STREAM`] payload.
///
/// # Errors
///
/// Propagates the [`ArrowError`] returned by
/// [`SerializedContent::dataframe_batches`].
pub fn serialize_dataframe(batches: &[RecordBatch]) -> Result<SerializedContent, ArrowError> {
    SerializedContent::dataframe_batches(batches)
}
54
/// Free-function form of [`SerializedContent::dataframe_ipc_bytes`]: wraps
/// pre-encoded `bytes` as a [`CONTENT_TYPE_ARROW_STREAM`] payload.
pub fn serialize_dataframe_ipc(bytes: impl Into<Vec<u8>>) -> SerializedContent {
    SerializedContent::dataframe_ipc_bytes(bytes)
}
58
59fn record_batches_to_ipc(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> {
60    if batches.is_empty() {
61        return Err(ArrowError::InvalidArgumentError(
62            "no record batches provided".to_string(),
63        ));
64    }
65
66    let schema = batches[0].schema();
67    let mut buffer = Vec::new();
68    {
69        let mut writer = StreamWriter::try_new(&mut buffer, &schema)?;
70        for batch in batches {
71            if batch.schema() != schema {
72                return Err(ArrowError::SchemaError(
73                    "record batch schema mismatch".to_string(),
74                ));
75            }
76            writer.write(batch)?;
77        }
78        writer.finish()?;
79    }
80    Ok(buffer)
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_ipc::reader::StreamReader;
    use arrow_schema::{DataType, Field, Schema};
    use std::io::Cursor;
    use std::sync::Arc;

    /// Builds a two-row batch with non-nullable `id: Int32` and `name: Utf8`
    /// columns.
    fn make_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let id_array = Arc::new(Int32Array::from(vec![1, 2]));
        let name_array = Arc::new(StringArray::from(vec!["alpha", "beta"]));
        RecordBatch::try_new(schema, vec![id_array, name_array]).unwrap()
    }

    /// Decodes an Arrow IPC stream back into its record batches, panicking on
    /// any reader error.
    fn read_ipc(bytes: Vec<u8>) -> Vec<RecordBatch> {
        StreamReader::try_new(Cursor::new(bytes), None)
            .unwrap()
            .map(|item| item.unwrap())
            .collect()
    }

    #[test]
    fn text_serialization_defaults_to_plain_text() {
        let content = SerializedContent::text("hello", None);
        assert_eq!(content.content_type, CONTENT_TYPE_TEXT);
        assert_eq!(content.text_payload.as_deref(), Some("hello"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn text_serialization_honors_explicit_content_type() {
        let content = SerializedContent::text(String::from("# title"), Some("text/markdown"));
        assert_eq!(content.content_type, "text/markdown");
        assert_eq!(content.text_payload.as_deref(), Some("# title"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn image_serialization_sets_payloads() {
        let content = serialize_image(vec![1, 2, 3], "image/png");
        assert_eq!(content.content_type, "image/png");
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![1, 2, 3]));
    }

    #[test]
    fn dataframe_ipc_serialization_keeps_bytes_verbatim() {
        let content = serialize_dataframe_ipc(vec![9, 8, 7]);
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![9, 8, 7]));
    }

    #[test]
    fn dataframe_serialization_writes_ipc_stream() {
        let batch = make_batch();
        let content = serialize_dataframe(std::slice::from_ref(&batch)).unwrap();
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);
        let bytes = content.binary_payload.expect("expected IPC payload");

        let batches = read_ipc(bytes);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].schema(), batch.schema());
        assert_eq!(batches[0].num_rows(), batch.num_rows());
        // Compare the column data too, not just the shape of the batch.
        assert_eq!(batches[0], batch);
    }

    #[test]
    fn dataframe_serialization_preserves_multiple_batches() {
        let expected = make_batch();
        let content = serialize_dataframe(&[make_batch(), make_batch()]).unwrap();
        let bytes = content.binary_payload.expect("expected IPC payload");

        let batches = read_ipc(bytes);
        assert_eq!(batches.len(), 2);
        for decoded in &batches {
            assert_eq!(decoded, &expected);
        }
    }

    #[test]
    fn dataframe_serialization_rejects_empty_batches() {
        let err = serialize_dataframe(&[]).unwrap_err();
        assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
    }

    #[test]
    fn dataframe_serialization_rejects_mismatched_schema() {
        let batch = make_batch();
        let other_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let other_batch =
            RecordBatch::try_new(other_schema, vec![Arc::new(Int32Array::from(vec![1, 2]))])
                .unwrap();

        let err = serialize_dataframe(&[batch, other_batch]).unwrap_err();
        assert!(matches!(err, ArrowError::SchemaError(_)));
    }
}