polars_plan/dsl/expr/serde_expr.rs

1use polars_utils::pl_serialize::deserialize_map_bytes;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
4use super::named_serde::ExprRegistry;
5use super::*;
6
7impl Serialize for SpecialEq<Arc<dyn ColumnsUdf>> {
8    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
9    where
10        S: Serializer,
11    {
12        use serde::ser::Error;
13        let mut buf = vec![];
14        self.as_ref()
15            .try_serialize(&mut buf)
16            .map_err(|e| S::Error::custom(format!("{e}")))?;
17        serializer.serialize_bytes(&buf)
18    }
19}
20
/// Prefix identifying a byte blob as a named-serde (registry-backed) function.
const NAMED_SERDE_MAGIC_BYTE_MARK: &[u8] = b"PLNAMEDFN";
/// Byte written directly after the function name, separating it from the payload.
const NAMED_SERDE_MAGIC_BYTE_END: u8 = b'!';
23
24fn serialize_named<S: Serializer>(
25    serializer: S,
26    name: &str,
27    payload: Option<&[u8]>,
28) -> Result<S::Ok, S::Error> {
29    let mut buf = vec![];
30    buf.extend_from_slice(NAMED_SERDE_MAGIC_BYTE_MARK);
31    buf.extend_from_slice(name.as_bytes());
32    buf.push(NAMED_SERDE_MAGIC_BYTE_END);
33    if let Some(payload) = payload {
34        buf.extend_from_slice(payload);
35    }
36    serializer.serialize_bytes(&buf)
37}
38
39fn deserialize_named_registry(buf: &[u8]) -> PolarsResult<(Arc<dyn ExprRegistry>, &str, &[u8])> {
40    let bytes = &buf[NAMED_SERDE_MAGIC_BYTE_MARK.len()..];
41    let Some(pos) = bytes.iter().position(|b| *b == NAMED_SERDE_MAGIC_BYTE_END) else {
42        polars_bail!(ComputeError: "named-serde expected magic byte end")
43    };
44
45    let Ok(name) = std::str::from_utf8(&bytes[..pos]) else {
46        polars_bail!(ComputeError: "named-serde name should be valid utf8")
47    };
48    let payload = &bytes[pos + 1..];
49
50    let registry = named_serde::NAMED_SERDE_REGISTRY_EXPR.read().unwrap();
51    match &*registry {
52        Some(reg) => Ok((reg.clone(), name, payload)),
53        None => polars_bail!(ComputeError: "named serde registry not set"),
54    }
55}
56
57impl<T: Serialize + Clone> Serialize for LazySerde<T> {
58    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
59    where
60        S: Serializer,
61    {
62        match self {
63            Self::Named {
64                name,
65                payload,
66                value: _,
67            } => serialize_named(serializer, name, payload.as_deref()),
68            Self::Deserialized(t) => t.serialize(serializer),
69            Self::Bytes(b) => b.serialize(serializer),
70        }
71    }
72}
73
74impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
75    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
76    where
77        D: Deserializer<'a>,
78    {
79        let buf = bytes::Bytes::deserialize(deserializer)?;
80        Ok(Self::Bytes(buf))
81    }
82}
83
84pub(super) fn deserialize_column_udf(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
85    #[cfg(feature = "python")]
86    if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
87        return crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(buf);
88    };
89
90    if buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
91        let (reg, name, payload) = deserialize_named_registry(buf)?;
92
93        if let Some(func) = reg.get_function(name, payload) {
94            Ok(func)
95        } else {
96            let msg = "name not found in named serde registry";
97            polars_bail!(ComputeError: msg)
98        }
99    } else {
100        polars_bail!(ComputeError: "deserialization not supported for this 'opaque' function"
101        )
102    }
103}
104// impl<T: Deserialize> Deserialize for crate::dsl::expr::LazySerde<T> {
105impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
106    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107    where
108        D: Deserializer<'a>,
109    {
110        use serde::de::Error;
111        deserialize_map_bytes(deserializer, |buf| {
112            deserialize_column_udf(&buf)
113                .map_err(|e| D::Error::custom(format!("{e}")))
114                .map(SpecialEq::new)
115        })?
116    }
117}
118
119impl Serialize for GetOutput {
120    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
121    where
122        S: Serializer,
123    {
124        use serde::ser::Error;
125        let mut buf = vec![];
126
127        match self {
128            LazySerde::Bytes(b) => serializer.serialize_bytes(b),
129            LazySerde::Named {
130                name,
131                payload,
132                value: _,
133            } => serialize_named(serializer, name, payload.as_deref()),
134            LazySerde::Deserialized(s) => {
135                s.as_ref()
136                    .try_serialize(&mut buf)
137                    .map_err(|e| S::Error::custom(format!("{e}")))?;
138                serializer.serialize_bytes(&buf)
139            },
140        }
141    }
142}
143
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
    /// Rebuilds an output-field function from its serialized bytes.
    ///
    /// Recognised encodings, checked in order:
    /// - Python (only with the `python` feature), identified by the Python serde magic prefix;
    /// - named-serde, resolved through the registry's `get_output` by name (the payload is not
    ///   used for output fields).
    ///
    /// Named-serde decoding does not depend on Python: it used to be compiled only with the
    /// `python` feature, which made every named output field fail to deserialize in non-Python
    /// builds even though they serialize fine. Both encodings produce
    /// `LazySerde::Deserialized`, so the original name/payload are not retained.
    ///
    /// # Errors
    /// A custom deserializer error if the bytes match neither encoding, decoding fails, or the
    /// name is not present in the registry (the missing name is included in the message).
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        use serde::de::Error;
        deserialize_map_bytes(deserializer, |buf| {
            #[cfg(feature = "python")]
            if buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
                let get_output = self::python_dsl::PythonGetOutput::try_deserialize(&buf)
                    .map_err(|e| D::Error::custom(format!("{e}")))?;
                return Ok(LazySerde::Deserialized(SpecialEq::new(get_output)));
            };

            if !buf.starts_with(NAMED_SERDE_MAGIC_BYTE_MARK) {
                return Err(D::Error::custom(
                    "deserialization not supported for this output field",
                ));
            }

            let (reg, name, _payload) =
                deserialize_named_registry(&buf).map_err(|e| D::Error::custom(format!("{e}")))?;
            match reg.get_output(name) {
                Some(func) => Ok(LazySerde::Deserialized(SpecialEq::new(func))),
                None => Err(D::Error::custom(format!(
                    "name not found in named serde registry: '{name}'"
                ))),
            }
        })?
    }
}
184
185impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
186    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
187    where
188        S: Serializer,
189    {
190        use serde::ser::Error;
191        let mut buf = vec![];
192        self.as_ref()
193            .try_serialize(&mut buf)
194            .map_err(|e| S::Error::custom(format!("{e}")))?;
195        serializer.serialize_bytes(&buf)
196    }
197}
198
199impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
200    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
201    where
202        D: Deserializer<'a>,
203    {
204        use serde::de::Error;
205        Err(D::Error::custom(
206            "deserialization not supported for this renaming function",
207        ))
208    }
209}
210
211impl Serialize for SpecialEq<Series> {
212    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
213    where
214        S: Serializer,
215    {
216        let s: &Series = self;
217        s.serialize(serializer)
218    }
219}
220
221impl<'a> Deserialize<'a> for SpecialEq<Series> {
222    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
223    where
224        D: Deserializer<'a>,
225    {
226        let t = Series::deserialize(deserializer)?;
227        Ok(SpecialEq::new(t))
228    }
229}
230
231impl<T: Serialize> Serialize for SpecialEq<Arc<T>> {
232    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
233    where
234        S: Serializer,
235    {
236        self.as_ref().serialize(serializer)
237    }
238}
239
#[cfg(feature = "serde")]
impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq<Arc<T>> {
    /// Deserializes a `T`, places it in a fresh `Arc`, and wraps that in `SpecialEq`.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        let shared = Arc::new(T::deserialize(deserializer)?);
        Ok(SpecialEq::new(shared))
    }
}
249}