polars_plan/dsl/expr/
serde_expr.rs1use polars_utils::pl_serialize::deserialize_map_bytes;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use super::named_serde::ExprRegistry;
5use super::*;
6
7impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
8 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
9 where
10 S: Serializer,
11 {
12 use serde::ser::Error;
13 let mut buf = vec![];
14 self.as_ref()
15 .try_serialize(&mut buf)
16 .map_err(|e| S::Error::custom(format!("{e}")))?;
17 serializer.serialize_bytes(&buf)
18 }
19}
20
/// Prefix identifying a byte blob produced by `serialize_named`.
const NAMED_SERDE_MAGIC_BYTE_MARK: &[u8] = b"PLNAMEDFN";
/// Separator written right after the function name in a named-serde blob.
const NAMED_SERDE_MAGIC_BYTE_END: u8 = b'!';
23
24fn serialize_named<S: Serializer>(
25 serializer: S,
26 name: &str,
27 payload: Option<&[u8]>,
28) -> Result<S::Ok, S::Error> {
29 let mut buf = vec![];
30 buf.extend_from_slice(NAMED_SERDE_MAGIC_BYTE_MARK);
31 buf.extend_from_slice(name.as_bytes());
32 buf.push(NAMED_SERDE_MAGIC_BYTE_END);
33 if let Some(payload) = payload {
34 buf.extend_from_slice(payload);
35 }
36 serializer.serialize_bytes(&buf)
37}
38
39fn deserialize_named_registry(buf: &[u8]) -> PolarsResult<(Arc<dyn ExprRegistry>, &str, &[u8])> {
40 let bytes = &buf[NAMED_SERDE_MAGIC_BYTE_MARK.len()..];
41 let Some(pos) = bytes.iter().position(|b| *b == NAMED_SERDE_MAGIC_BYTE_END) else {
42 polars_bail!(ComputeError: "named-serde expected magic byte end")
43 };
44
45 let Ok(name) = std::str::from_utf8(&bytes[..pos]) else {
46 polars_bail!(ComputeError: "named-serde name should be valid utf8")
47 };
48 let payload = &bytes[pos + 1..];
49
50 let registry = named_serde::NAMED_SERDE_REGISTRY_EXPR.read().unwrap();
51 match &*registry {
52 Some(reg) => Ok((reg.clone(), name, payload)),
53 None => polars_bail!(ComputeError: "named serde registry not set"),
54 }
55}
56
57impl<T: Serialize + Clone> Serialize for LazySerde<T> {
58 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59 where
60 S: Serializer,
61 {
62 match self {
63 Self::Named {
64 name,
65 payload,
66 value: _,
67 } => serialize_named(serializer, name, payload.as_deref()),
68 Self::Deserialized(t) => t.serialize(serializer),
69 Self::Bytes(b) => b.serialize(serializer),
70 }
71 }
72}
73
74impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
75 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
76 where
77 D: Deserializer<'a>,
78 {
79 let buf = bytes::Bytes::deserialize(deserializer)?;
80 Ok(Self::Bytes(buf))
81 }
82}
83
84pub(super) fn deserialize_column_udf(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
85 #[cfg(feature = "python")]
86 if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
87 return crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(buf);
88 };
89
90 if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
91 let (reg, name, payload) = deserialize_named_registry(buf)?;
92
93 if let Some(func) = reg.get_function(name, payload) {
94 Ok(func)
95 } else {
96 let msg = "name not found in named serde registry";
97 polars_bail!(ComputeError: msg)
98 }
99 } else {
100 polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function"
101 )
102 }
103}
104impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
106 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107 where
108 D: Deserializer<'a>,
109 {
110 use serde::de::Error;
111 deserialize_map_bytes(deserializer, |buf| {
112 deserialize_column_udf(&buf)
113 .map_err(|e| D::Error::custom(format!("{e}")))
114 .map(SpecialEq::new)
115 })?
116 }
117}
118
119impl Serialize for GetOutput {
120 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
121 where
122 S: Serializer,
123 {
124 use serde::ser::Error;
125 let mut buf = vec![];
126
127 match self {
128 LazySerde::Bytes(b) => serializer.serialize_bytes(b),
129 LazySerde::Named {
130 name,
131 payload,
132 value: _,
133 } => serialize_named(serializer, name, payload.as_deref()),
134 LazySerde::Deserialized(s) => {
135 s.as_ref()
136 .try_serialize(&mut buf)
137 .map_err(|e| S::Error::custom(format!("{e}")))?;
138 serializer.serialize_bytes(&buf)
139 },
140 }
141 }
142}
143
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Deserialize a [`GetOutput`] from an opaque byte blob.
    ///
    /// Supported encodings, checked in this order:
    /// * (feature `python`) blobs starting with `PYTHON_SERDE_MAGIC_BYTE_MARK`, decoded by
    ///   `PythonGetOutput::try_deserialize`;
    /// * named-serde blobs starting with `NAMED_SERDE_MAGIC_BYTE_MARK`, resolved by name via the
    ///   registered [`ExprRegistry`]'s `get_output` (the payload is not used).
    ///
    /// Any other blob yields a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;

        deserialize_map_bytes(deserializer, |buf| {
            #[cfg(feature = "python")]
            if buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                let get_output = self::python_dsl::PythonGetOutput::try_deserialize(&buf)
                    .map_err(|e| D::Error::custom(format!("{e}")))?;
                return Ok(LazySerde::Deserialized(SpecialEq::new(get_output)));
            }

            // Named serde does not depend on Python, so it is handled in every build
            // (previously it was only reachable with `feature = "python"`), matching
            // `deserialize_column_udf`.
            if !buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
                return Err(D::Error::custom(
                    "deserialization not supported for this output field",
                ));
            }

            let (reg, name, _payload) =
                deserialize_named_registry(&buf).map_err(|e| D::Error::custom(format!("{e}")))?;
            match reg.get_output(name) {
                Some(func) => Ok(LazySerde::Deserialized(SpecialEq::new(func))),
                None => Err(D::Error::custom("name not found in named serde registry")),
            }
        })?
    }
}
184
185impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
186 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
187 where
188 S: Serializer,
189 {
190 use serde::ser::Error;
191 let mut buf = vec![];
192 self.as_ref()
193 .try_serialize(&mut buf)
194 .map_err(|e| S::Error::custom(format!("{e}")))?;
195 serializer.serialize_bytes(&buf)
196 }
197}
198
199impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
200 fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
201 where
202 D: Deserializer<'a>,
203 {
204 use serde::de::Error;
205 Err(D::Error::custom(
206 "deserialization not supported for this renaming function",
207 ))
208 }
209}
210
211impl Serialize for SpecialEq<Series> {
212 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
213 where
214 S: Serializer,
215 {
216 let s: &Series = self;
217 s.serialize(serializer)
218 }
219}
220
221impl<'a> Deserialize<'a> for SpecialEq<Series> {
222 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
223 where
224 D: Deserializer<'a>,
225 {
226 let t = Series::deserialize(deserializer)?;
227 Ok(SpecialEq::new(t))
228 }
229}
230
231impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {
232 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
233 where
234 S: Serializer,
235 {
236 self.as_ref().serialize(serializer)
237 }
238}
239
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {
    /// Deserialize a `T`, then wrap it in `Arc` and [`SpecialEq`].
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        T::deserialize(deserializer).map(|value| SpecialEq::new(Arc::new(value)))
    }
}