// llm_chain/serialization.rs

//! # Envelope Serialization
//!
//! This module contains the [`Envelope`] struct and its related functionality. It allows you to serialize and deserialize LLM chains and other data structures to and from JSON files. The [`Envelope`] struct wraps your data together with a map of string metadata, which can be useful for managing and organizing your serialized data.
//!
//! This module is mostly intended for internal use, but you can also use it to serialize your own data structures.
//!
//! ## Usage
//!
//! First, implement the [`StorableEntity`] trait for your custom data type; its only required item is the associated function `get_metadata()`. The trait's provided methods (`read_file_sync`, `write_file_sync`) then let you read and write your data to and from files.
//!
//! ## Example
//!
//! ```rust
//! use serde::{Deserialize, Serialize};
//! use llm_chain::serialization::StorableEntity;
//!
//! #[derive(Debug, Clone, Serialize, Deserialize)]
//! struct MyData {
//!     value: i32,
//! }
//!
//! impl StorableEntity for MyData {
//!     fn get_metadata() -> Vec<(String, String)> {
//!         vec![("author".to_string(), "John Doe".to_string())]
//!     }
//! }
//!
//! let data = MyData { value: 42 };
//! let path = "mydata.json";
//!
//! // Wrap the data in an envelope and serialize it to a JSON file
//! data.clone().write_file_sync(path).unwrap();
//! // Deserialize the envelope from the file and unwrap the data
//! let read_data = MyData::read_file_sync(path).unwrap();
//! assert_eq!(data.value, read_data.value);
//! ```
//! ## Features
//!
//! This module provides synchronous and asynchronous methods for reading and writing envelopes to and from files. The asynchronous methods are only available when the `async` feature flag is enabled.
//!
//! ## Errors
//!
//! The module also provides the [`EnvelopeError`] enum, which represents errors that can occur during serialization, deserialization, and file I/O operations.
use std::collections::HashMap;

use serde::de::{DeserializeOwned, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeStruct, Serializer};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// A wrapper that stores a value together with free-form string metadata.
///
/// It is serialized as a struct named `Envelope` with two fields, `metadata` and `data`.
#[derive(Clone, Debug)]
pub struct Envelope<T> {
    /// Arbitrary key/value annotations stored alongside the payload.
    pub metadata: HashMap<String, String>,
    /// The wrapped payload.
    pub data: T,
}

58impl<T: Serialize> Serialize for Envelope<T> {
59    fn serialize<SER>(&self, serializer: SER) -> Result<SER::Ok, SER::Error>
60    where
61        SER: Serializer,
62    {
63        let mut envelope = serializer.serialize_struct("Envelope", 2)?;
64        envelope.serialize_field("metadata", &self.metadata)?;
65        envelope.serialize_field("data", &self.data)?;
66        envelope.end()
67    }
68}
69
70impl<'de, T: Serialize + Deserialize<'de>> Deserialize<'de> for Envelope<T> {
71    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72    where
73        D: Deserializer<'de>,
74    {
75        struct EnvelopeVisitor<T>(std::marker::PhantomData<T>);
76
77        impl<'de, T> Visitor<'de> for EnvelopeVisitor<T>
78        where
79            T: Serialize + Deserialize<'de>,
80        {
81            type Value = Envelope<T>;
82
83            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
84                formatter.write_str("struct Envelope")
85            }
86
87            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
88            where
89                A: MapAccess<'de>,
90            {
91                let mut metadata = None;
92                let mut data: Option<T> = None;
93
94                while let Some(key) = map.next_key::<String>()? {
95                    match key.as_str() {
96                        "metadata" => {
97                            let hm = map.next_value()?;
98                            metadata = Some(hm);
99                        }
100                        "data" => {
101                            data = Some(map.next_value()?);
102                        }
103                        _ => (),
104                    }
105                }
106
107                let metadata = metadata.unwrap_or_default();
108                let data = data.ok_or_else(|| serde::de::Error::missing_field("data"))?;
109
110                Ok(Envelope { metadata, data })
111            }
112        }
113        deserializer.deserialize_map(EnvelopeVisitor(std::marker::PhantomData))
114    }
115}
116
117impl<T> Envelope<T> {
118    pub fn new(data: T) -> Self {
119        Envelope {
120            metadata: HashMap::new(),
121            data,
122        }
123    }
124}
125
126#[derive(Error, Debug)]
127pub enum EnvelopeError {
128    // YAML parsing
129    #[error("YAML parsing error: {0}")]
130    YamlParsingError(#[from] serde_json::Error),
131    #[error("IO error: {0}")]
132    IOError(#[from] std::io::Error),
133}
134
135impl<T> Envelope<T>
136where
137    T: Serialize + DeserializeOwned,
138{
139    pub fn read_file_sync(path: &str) -> Result<Self, EnvelopeError> {
140        let file = std::fs::File::open(path)?;
141        let reader = std::io::BufReader::new(file);
142        let envelope = serde_json::from_reader(reader)?;
143        Ok(envelope)
144    }
145    #[cfg(feature = "async")]
146    pub async fn read_file_async(path: &str) -> Result<Self, EnvelopeError> {
147        use tokio::io::AsyncReadExt;
148        let mut file = tokio::fs::File::open(path).await?;
149        let mut contents: Vec<u8> = vec![];
150        file.read_to_end(&mut contents).await?;
151        let envelope = serde_json::from_slice(&contents)?;
152        Ok(envelope)
153    }
154    pub fn write_file_sync(&self, path: &str) -> Result<(), EnvelopeError> {
155        let file = std::fs::File::create(path)?;
156        let writer = std::io::BufWriter::new(file);
157        serde_json::to_writer(writer, &self)?;
158        Ok(())
159    }
160    #[cfg(feature = "async")]
161    pub async fn write_file_async(&self, path: &str) -> Result<(), EnvelopeError> {
162        use tokio::io::AsyncWriteExt;
163        let data = serde_json::to_string(&self)?;
164        let mut file = tokio::fs::File::create(path).await?;
165        file.write_all(data.as_bytes()).await?;
166        Ok(())
167    }
168}
169
170/// An entity that can be stored in an envelope.
171pub trait StorableEntity: Serialize + DeserializeOwned {
172    fn get_metadata() -> Vec<(String, String)>;
173    fn to_envelope(self) -> Envelope<Self>
174    where
175        Self: Sized,
176    {
177        let mut envelope = Envelope {
178            metadata: HashMap::new(),
179            data: self,
180        };
181        for (key, value) in Self::get_metadata() {
182            envelope.metadata.insert(key, value);
183        }
184        envelope
185    }
186    fn from_envelope(envelope: Envelope<Self>) -> Self {
187        envelope.data
188    }
189    fn read_file_sync(path: &str) -> Result<Self, EnvelopeError> {
190        Envelope::<Self>::read_file_sync(path).map(|envelope| Self::from_envelope(envelope))
191    }
192    fn write_file_sync(self, path: &str) -> Result<(), EnvelopeError> {
193        Envelope::new(self).write_file_sync(path)
194    }
195}