Skip to main content

lance_context_core/
serde.rs

1use arrow_array::RecordBatch;
2use arrow_ipc::writer::StreamWriter;
3use arrow_schema::ArrowError;
4use serde::{Deserialize, Serialize};
5
/// MIME type for plain-text payloads; used by [`SerializedContent::text`] when
/// the caller does not supply a content type.
pub const CONTENT_TYPE_TEXT: &str = "text/plain";
/// MIME type attached to payloads that hold an Arrow IPC stream.
pub const CONTENT_TYPE_ARROW_STREAM: &str = "application/vnd.apache.arrow.stream";
/// Content type for tombstone entries.
///
/// NOTE(review): not referenced anywhere in this module; presumably marks a
/// deleted record — confirm against the callers that consume it.
pub const CONTENT_TYPE_TOMBSTONE: &str = "application/vnd.lance-context.tombstone";
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct SerializedContent {
12    pub content_type: String,
13    pub text_payload: Option<String>,
14    pub binary_payload: Option<Vec<u8>>,
15}
16
17impl SerializedContent {
18    pub fn text(value: impl Into<String>, content_type: Option<&str>) -> Self {
19        Self {
20            content_type: content_type.unwrap_or(CONTENT_TYPE_TEXT).to_string(),
21            text_payload: Some(value.into()),
22            binary_payload: None,
23        }
24    }
25
26    pub fn image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> Self {
27        Self {
28            content_type: mime.into(),
29            text_payload: None,
30            binary_payload: Some(bytes.into()),
31        }
32    }
33
34    pub fn dataframe_batches(batches: &[RecordBatch]) -> Result<Self, ArrowError> {
35        let ipc_bytes = record_batches_to_ipc(batches)?;
36        Ok(Self::dataframe_ipc_bytes(ipc_bytes))
37    }
38
39    pub fn dataframe_ipc_bytes(bytes: impl Into<Vec<u8>>) -> Self {
40        Self {
41            content_type: CONTENT_TYPE_ARROW_STREAM.to_string(),
42            text_payload: None,
43            binary_payload: Some(bytes.into()),
44        }
45    }
46}
47
48pub fn serialize_image(bytes: impl Into<Vec<u8>>, mime: impl Into<String>) -> SerializedContent {
49    SerializedContent::image(bytes, mime)
50}
51
52pub fn serialize_dataframe(batches: &[RecordBatch]) -> Result<SerializedContent, ArrowError> {
53    SerializedContent::dataframe_batches(batches)
54}
55
56pub fn serialize_dataframe_ipc(bytes: impl Into<Vec<u8>>) -> SerializedContent {
57    SerializedContent::dataframe_ipc_bytes(bytes)
58}
59
60fn record_batches_to_ipc(batches: &[RecordBatch]) -> Result<Vec<u8>, ArrowError> {
61    if batches.is_empty() {
62        return Err(ArrowError::InvalidArgumentError(
63            "no record batches provided".to_string(),
64        ));
65    }
66
67    let schema = batches[0].schema();
68    let mut buffer = Vec::new();
69    {
70        let mut writer = StreamWriter::try_new(&mut buffer, &schema)?;
71        for batch in batches {
72            if batch.schema() != schema {
73                return Err(ArrowError::SchemaError(
74                    "record batch schema mismatch".to_string(),
75                ));
76            }
77            writer.write(batch)?;
78        }
79        writer.finish()?;
80    }
81    Ok(buffer)
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_ipc::reader::StreamReader;
    use arrow_schema::{DataType, Field, Schema};
    use std::io::Cursor;
    use std::sync::Arc;

    /// Builds a two-row batch with non-nullable `id: Int32` and `name: Utf8`
    /// columns.
    fn make_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let id_array = Arc::new(Int32Array::from(vec![1, 2]));
        let name_array = Arc::new(StringArray::from(vec!["alpha", "beta"]));
        RecordBatch::try_new(schema, vec![id_array, name_array]).unwrap()
    }

    #[test]
    fn text_serialization_defaults_to_plain_text() {
        let content = SerializedContent::text("hello", None);
        assert_eq!(content.content_type, CONTENT_TYPE_TEXT);
        assert_eq!(content.text_payload.as_deref(), Some("hello"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn text_serialization_honors_explicit_content_type() {
        let content = SerializedContent::text("# title", Some("text/markdown"));
        assert_eq!(content.content_type, "text/markdown");
        assert_eq!(content.text_payload.as_deref(), Some("# title"));
        assert_eq!(content.binary_payload, None);
    }

    #[test]
    fn image_serialization_sets_payloads() {
        let content = serialize_image(vec![1, 2, 3], "image/png");
        assert_eq!(content.content_type, "image/png");
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![1, 2, 3]));
    }

    #[test]
    fn dataframe_ipc_serialization_wraps_bytes_verbatim() {
        // The bytes are not a real IPC stream on purpose: the constructor only
        // labels and stores them.
        let content = serialize_dataframe_ipc(vec![0xFF_u8, 0x00, 0x7F]);
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);
        assert_eq!(content.binary_payload, Some(vec![0xFF_u8, 0x00, 0x7F]));
    }

    #[test]
    fn dataframe_serialization_writes_ipc_stream() {
        let batch = make_batch();
        let content = serialize_dataframe(std::slice::from_ref(&batch)).unwrap();
        assert_eq!(content.content_type, CONTENT_TYPE_ARROW_STREAM);
        assert_eq!(content.text_payload, None);
        let bytes = content.binary_payload.expect("expected IPC payload");

        let reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
        let batches: Vec<RecordBatch> = reader.map(|item| item.unwrap()).collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].schema(), batch.schema());
        assert_eq!(batches[0].num_rows(), batch.num_rows());
        // Compare the column data too, so a corrupted round trip cannot pass
        // on schema and row count alone.
        assert_eq!(batches[0], batch);
    }

    #[test]
    fn dataframe_serialization_rejects_empty_batches() {
        let err = serialize_dataframe(&[]).unwrap_err();
        assert!(matches!(err, ArrowError::InvalidArgumentError(_)));
    }

    #[test]
    fn dataframe_serialization_rejects_mismatched_schema() {
        let batch = make_batch();
        let other_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let other_batch =
            RecordBatch::try_new(other_schema, vec![Arc::new(Int32Array::from(vec![1, 2]))])
                .unwrap();

        let err = serialize_dataframe(&[batch, other_batch]).unwrap_err();
        assert!(matches!(err, ArrowError::SchemaError(_)));
    }
}