llm_chain/
serialization.rs1use serde::de::{DeserializeOwned, Deserializer, MapAccess, Visitor};
48use serde::ser::{SerializeStruct, Serializer};
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use thiserror::Error;
/// A wrapper that pairs a payload with free-form string metadata.
///
/// The `Serialize` impl in this module writes it as a two-field struct:
/// `metadata` first, then `data`.
#[derive(Clone, Debug)]
pub struct Envelope<T> {
    /// String key/value annotations stored alongside the payload.
    pub metadata: HashMap<String, String>,
    /// The wrapped payload.
    pub data: T,
}

58impl<T: Serialize> Serialize for Envelope<T> {
59 fn serialize<SER>(&self, serializer: SER) -> Result<SER::Ok, SER::Error>
60 where
61 SER: Serializer,
62 {
63 let mut envelope = serializer.serialize_struct("Envelope", 2)?;
64 envelope.serialize_field("metadata", &self.metadata)?;
65 envelope.serialize_field("data", &self.data)?;
66 envelope.end()
67 }
68}
69
70impl<'de, T: Serialize + Deserialize<'de>> Deserialize<'de> for Envelope<T> {
71 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72 where
73 D: Deserializer<'de>,
74 {
75 struct EnvelopeVisitor<T>(std::marker::PhantomData<T>);
76
77 impl<'de, T> Visitor<'de> for EnvelopeVisitor<T>
78 where
79 T: Serialize + Deserialize<'de>,
80 {
81 type Value = Envelope<T>;
82
83 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
84 formatter.write_str("struct Envelope")
85 }
86
87 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
88 where
89 A: MapAccess<'de>,
90 {
91 let mut metadata = None;
92 let mut data: Option<T> = None;
93
94 while let Some(key) = map.next_key::<String>()? {
95 match key.as_str() {
96 "metadata" => {
97 let hm = map.next_value()?;
98 metadata = Some(hm);
99 }
100 "data" => {
101 data = Some(map.next_value()?);
102 }
103 _ => (),
104 }
105 }
106
107 let metadata = metadata.unwrap_or_default();
108 let data = data.ok_or_else(|| serde::de::Error::missing_field("data"))?;
109
110 Ok(Envelope { metadata, data })
111 }
112 }
113 deserializer.deserialize_map(EnvelopeVisitor(std::marker::PhantomData))
114 }
115}
116
117impl<T> Envelope<T> {
118 pub fn new(data: T) -> Self {
119 Envelope {
120 metadata: HashMap::new(),
121 data,
122 }
123 }
124}
125
126#[derive(Error, Debug)]
127pub enum EnvelopeError {
128 #[error("YAML parsing error: {0}")]
130 YamlParsingError(#[from] serde_json::Error),
131 #[error("IO error: {0}")]
132 IOError(#[from] std::io::Error),
133}
134
135impl<T> Envelope<T>
136where
137 T: Serialize + DeserializeOwned,
138{
139 pub fn read_file_sync(path: &str) -> Result<Self, EnvelopeError> {
140 let file = std::fs::File::open(path)?;
141 let reader = std::io::BufReader::new(file);
142 let envelope = serde_json::from_reader(reader)?;
143 Ok(envelope)
144 }
145 #[cfg(feature = "async")]
146 pub async fn read_file_async(path: &str) -> Result<Self, EnvelopeError> {
147 use tokio::io::AsyncReadExt;
148 let mut file = tokio::fs::File::open(path).await?;
149 let mut contents: Vec<u8> = vec![];
150 file.read_to_end(&mut contents).await?;
151 let envelope = serde_json::from_slice(&contents)?;
152 Ok(envelope)
153 }
154 pub fn write_file_sync(&self, path: &str) -> Result<(), EnvelopeError> {
155 let file = std::fs::File::create(path)?;
156 let writer = std::io::BufWriter::new(file);
157 serde_json::to_writer(writer, &self)?;
158 Ok(())
159 }
160 #[cfg(feature = "async")]
161 pub async fn write_file_async(&self, path: &str) -> Result<(), EnvelopeError> {
162 use tokio::io::AsyncWriteExt;
163 let data = serde_json::to_string(&self)?;
164 let mut file = tokio::fs::File::create(path).await?;
165 file.write_all(data.as_bytes()).await?;
166 Ok(())
167 }
168}
169
170pub trait StorableEntity: Serialize + DeserializeOwned {
172 fn get_metadata() -> Vec<(String, String)>;
173 fn to_envelope(self) -> Envelope<Self>
174 where
175 Self: Sized,
176 {
177 let mut envelope = Envelope {
178 metadata: HashMap::new(),
179 data: self,
180 };
181 for (key, value) in Self::get_metadata() {
182 envelope.metadata.insert(key, value);
183 }
184 envelope
185 }
186 fn from_envelope(envelope: Envelope<Self>) -> Self {
187 envelope.data
188 }
189 fn read_file_sync(path: &str) -> Result<Self, EnvelopeError> {
190 Envelope::<Self>::read_file_sync(path).map(|envelope| Self::from_envelope(envelope))
191 }
192 fn write_file_sync(self, path: &str) -> Result<(), EnvelopeError> {
193 Envelope::new(self).write_file_sync(path)
194 }
195}