Skip to main content

trs_dataframe/dataframe/
join.rs

1use super::Key;
2#[cfg(feature = "python")]
3use pyo3::prelude::*;
4use serde::{Deserialize, Serialize};
5#[cfg(feature = "utoipa")]
6use utoipa::ToSchema;
7
8#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
9#[cfg_attr(feature = "python", pyo3::pyclass)]
10#[cfg_attr(feature = "utoipa", derive(ToSchema))]
11/// Specifies the key columns used for an id-based join.
12pub struct JoinById {
13    pub keys: Vec<Key>,
14}
15
16impl JoinById {
17    /// Creates a new [`JoinById`] for the given key columns.
18    pub fn new(keys: Vec<Key>) -> Self {
19        Self { keys }
20    }
21}
22
#[cfg(feature = "python")]
#[pymethods]
impl JoinById {
    /// Python constructor: `JoinById(keys)`.
    ///
    /// Delegates to [`JoinById::new`] so the Rust and Python entry points
    /// always produce identical values.
    #[new]
    pub fn init(keys: Vec<Key>) -> Self {
        Self::new(keys)
    }
}
31
32#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
33#[cfg_attr(feature = "utoipa", derive(ToSchema))]
34/// Enum representing different strategies for combining or joining data structures.
35pub enum JoinBy {
36    /// Adds only non-existing columns to the existing structure.
37    /// This is the default behavior.
38    #[default]
39    AddColumns,
40
41    /// Replaces existing data with the new data.
42    Replace,
43
44    /// Extends the existing data by appending new elements.
45    Extend,
46
47    /// Performs a broadcast operation, replicating smaller data structures
48    /// to match the size of larger ones.
49    Broadcast,
50
51    /// Computes the Cartesian product of the input structures,
52    /// resulting in all possible combinations of elements.
53    CartesianProduct,
54
55    /// Joins two structures using a specific identifier or key.
56    ///
57    /// The behavior is determined by the provided `JoinById` variant.
58    JoinById(JoinById),
59}
60
#[cfg(feature = "python")]
pub mod python {
    //! Python-facing representation of [`JoinBy`].
    //!
    //! Because `JoinBy` has a data-carrying variant, it is exposed to Python
    //! as a [`PythonJoin`]: a fieldless [`PythonJoinBy`] tag plus an optional
    //! [`JoinById`] payload. The conversions below map between the two forms.
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fieldless mirror of [`JoinBy`] that can be exposed as a simple
    /// Python enum.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass(eq, eq_int)]
    pub enum PythonJoinBy {
        /// Adds only the columns that do not already exist in the existing structure.
        /// This is the default behavior.
        AddColumns,

        /// Replaces existing data with the new data.
        Replace,

        /// Extends the existing data by appending new elements.
        Extend,

        /// Performs a broadcast operation, replicating smaller data structures
        /// to match the size of larger ones.
        Broadcast,

        /// Computes the Cartesian product of the input structures,
        /// resulting in all possible combinations of elements.
        CartesianProduct,

        /// Joins two structures using specific identifier/key columns.
        ///
        /// The key columns must be supplied through `PythonJoin::join_by_id`.
        JoinById,
    }

    /// Python representation of [`JoinBy`]: a join tag plus the key columns
    /// required by [`PythonJoinBy::JoinById`].
    ///
    /// Converting to [`JoinBy`] fails with `MissingField("join_by_id")` when
    /// `join_type` is `JoinById` and `join_by_id` is `None`. For every other
    /// `join_type`, `join_by_id` is ignored during conversion.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass]
    pub struct PythonJoin {
        /// Which join strategy to use.
        #[pyo3(get)]
        pub join_type: PythonJoinBy,
        /// Key columns; only read when `join_type` is `JoinById`.
        #[pyo3(get)]
        pub join_by_id: Option<JoinById>,
    }

    #[pymethods]
    impl PythonJoin {
        /// Python constructor: `PythonJoin(join_type, join_by_id=None)`.
        ///
        /// No validation happens here. A missing `join_by_id` for
        /// `PythonJoinBy.JoinById` is reported (as `ValueError`) when the
        /// value is converted into a Rust `JoinBy`.
        #[new]
        #[pyo3(signature = (join_type, join_by_id = None))]
        pub fn init(join_type: PythonJoinBy, join_by_id: Option<JoinById>) -> Self {
            Self {
                join_type,
                join_by_id,
            }
        }
    }

    impl TryFrom<PythonJoin> for JoinBy {
        type Error = crate::error::Error;

        /// Rebuilds a [`JoinBy`] from its Python form.
        ///
        /// # Errors
        /// Returns `MissingField("join_by_id")` when `join_type` is
        /// `JoinById` but no key columns were supplied.
        fn try_from(py_join: PythonJoin) -> Result<Self, Self::Error> {
            Ok(match py_join.join_type {
                PythonJoinBy::AddColumns => JoinBy::AddColumns,
                PythonJoinBy::Replace => JoinBy::Replace,
                PythonJoinBy::Extend => JoinBy::Extend,
                PythonJoinBy::Broadcast => JoinBy::Broadcast,
                PythonJoinBy::CartesianProduct => JoinBy::CartesianProduct,
                PythonJoinBy::JoinById => {
                    let join_by_id = py_join
                        .join_by_id
                        .ok_or_else(|| crate::error::Error::MissingField("join_by_id".into()))?;
                    JoinBy::JoinById(join_by_id)
                }
            })
        }
    }

    impl TryFrom<JoinBy> for PythonJoin {
        type Error = crate::error::Error;

        /// Splits a [`JoinBy`] into its tag and optional key payload.
        ///
        /// This conversion currently never returns an error.
        fn try_from(join: JoinBy) -> Result<Self, Self::Error> {
            let (join_type, join_by_id) = match join {
                JoinBy::AddColumns => (PythonJoinBy::AddColumns, None),
                JoinBy::Replace => (PythonJoinBy::Replace, None),
                JoinBy::Extend => (PythonJoinBy::Extend, None),
                JoinBy::Broadcast => (PythonJoinBy::Broadcast, None),
                JoinBy::CartesianProduct => (PythonJoinBy::CartesianProduct, None),
                JoinBy::JoinById(keys) => (PythonJoinBy::JoinById, Some(keys)),
            };
            Ok(PythonJoin {
                join_type,
                join_by_id,
            })
        }
    }

    impl FromPyObject<'_> for JoinBy {
        /// Extracts a `PythonJoin` from `ob` and converts it to [`JoinBy`].
        ///
        /// Conversion failures are raised as Python `ValueError`.
        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
            let py_join: PythonJoin = ob.extract()?;
            Self::try_from(py_join).map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(e.to_string())
            })
        }
    }

    impl<'py> IntoPyObject<'py> for JoinBy {
        type Error = PyErr;
        type Target = PythonJoin;
        type Output = Bound<'py, Self::Target>;

        /// Converts to a `PythonJoin` and wraps it in a new Python object.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_join: PythonJoin = self.try_into().map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(format!("Error converting: {e}"))
            })?;
            py_join.into_pyobject(py)
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use rstest::*;

        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_join_by(#[case] join_by: JoinBy) {
            let py_join = PythonJoin::try_from(join_by.clone()).unwrap();
            let join_by2 = JoinBy::try_from(py_join).unwrap();
            assert_eq!(join_by, join_by2);
        }

        #[rstest]
        fn test_join_by_id_requires_keys() {
            let py_join = PythonJoin::init(PythonJoinBy::JoinById, None);
            let err = JoinBy::try_from(py_join).unwrap_err();
            assert!(matches!(err, crate::error::Error::MissingField(_)));
        }

        #[rstest]
        fn test_join_by_id_payload_ignored_for_other_types() {
            let py_join = PythonJoin::init(
                PythonJoinBy::Replace,
                Some(JoinById::new(vec!["a".into()])),
            );
            assert_eq!(JoinBy::try_from(py_join).unwrap(), JoinBy::Replace);
        }

        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_into_py(#[case] join_by: JoinBy) {
            pyo3::Python::attach(|py| {
                let py_join = join_by.clone().into_pyobject(py);
                assert!(py_join.is_ok());
                let py_join = py_join.unwrap();
                let from_py = JoinBy::extract_bound(&py_join);
                assert!(from_py.is_ok());
                let join_by2 = from_py.unwrap();
                assert_eq!(join_by, join_by2);
            });
        }
    }
}
216
217#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
218#[cfg_attr(feature = "python", pyclass)]
219#[cfg_attr(feature = "utoipa", derive(ToSchema))]
220pub struct JoinRelation {
221    pub join_type: JoinBy,
222}
223
#[cfg(feature = "python")]
#[pymethods]
impl JoinRelation {
    /// Python constructor: `JoinRelation(join_type)`.
    ///
    /// `join_type` is extracted as a `PythonJoin`; a `JoinById` join type
    /// without key columns is rejected with `ValueError`.
    #[new]
    pub fn init(join_type: JoinBy) -> Self {
        Self::new(join_type)
    }

    /// Performs a broadcast operation, replicating smaller data structures
    /// to match the size of larger ones.
    #[pyo3(name = "broadcast")]
    #[staticmethod]
    pub fn py_broadcast() -> Self {
        Self::broadcast()
    }

    /// Adds only non-existing columns to the existing structure.
    /// This is the default behavior.
    #[pyo3(name = "add_columns")]
    #[staticmethod]
    pub fn py_add_columns() -> Self {
        Self::add_columns()
    }

    /// Replaces existing data with the new data.
    #[pyo3(name = "replace")]
    #[staticmethod]
    pub fn py_replace() -> Self {
        Self::replace()
    }

    /// Extends the existing data by appending new elements.
    #[pyo3(name = "extend")]
    #[staticmethod]
    pub fn py_extend() -> Self {
        Self::extend()
    }

    /// Computes the Cartesian product of the input structures,
    /// resulting in all possible combinations of elements.
    #[pyo3(name = "cartesian_product")]
    #[staticmethod]
    pub fn py_cartesian_product() -> Self {
        Self::cartesian_product()
    }

    /// Joins two structures using specific identifier/key columns.
    ///
    /// `keys` lists the key columns to join on.
    #[pyo3(name = "join_by_id")]
    #[staticmethod]
    pub fn py_join_by_id(keys: Vec<Key>) -> Self {
        Self::join_by_id(keys)
    }
}
290
291impl JoinRelation {
292    pub fn new(join_type: JoinBy) -> Self {
293        Self { join_type }
294    }
295
296    pub fn broadcast() -> Self {
297        Self {
298            join_type: JoinBy::Broadcast,
299        }
300    }
301
302    pub fn add_columns() -> Self {
303        Self {
304            join_type: JoinBy::AddColumns,
305        }
306    }
307
308    pub fn replace() -> Self {
309        Self {
310            join_type: JoinBy::Replace,
311        }
312    }
313
314    pub fn extend() -> Self {
315        Self {
316            join_type: JoinBy::Extend,
317        }
318    }
319
320    pub fn cartesian_product() -> Self {
321        Self {
322            join_type: JoinBy::CartesianProduct,
323        }
324    }
325
326    pub fn join_by_id(keys: Vec<Key>) -> Self {
327        Self {
328            join_type: JoinBy::JoinById(JoinById::new(keys)),
329        }
330    }
331}
332
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Serializes `relation` to JSON and back, asserting the value survives unchanged.
    fn assert_serde_roundtrip(relation: &JoinRelation) {
        let json = serde_json::to_string(relation).expect("BUG: Cannot serialize");
        let deserialized: JoinRelation =
            serde_json::from_str(&json).expect("BUG: cannot deserialize");
        assert_eq!(&deserialized, relation);
    }

    #[cfg(feature = "utoipa")]
    #[rstest]
    fn test_join_relation_to_schema() {
        let _name = JoinRelation::name();
        let mut schemas = vec![];

        JoinRelation::schemas(&mut schemas);

        assert!(!schemas.is_empty());
    }

    #[rstest]
    fn test_join_relation_default_is_add_columns() {
        assert_eq!(JoinBy::default(), JoinBy::AddColumns);
        assert_eq!(JoinRelation::default(), JoinRelation::add_columns());
    }

    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_new(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_serde_roundtrip(&join_relation);
    }

    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::add_columns())]
    #[case(JoinBy::Replace, JoinRelation::replace())]
    #[case(JoinBy::Extend, JoinRelation::extend())]
    #[case(JoinBy::Broadcast, JoinRelation::broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::join_by_id(vec!["a".into()]))]
    fn test_join_relation_constructors(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_serde_roundtrip(&join_relation);
    }

    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_py(#[case] join_type: JoinBy) {
        pyo3::Python::attach(|_py| {
            let join_relation = JoinRelation::new(join_type.clone());
            let py_join_relation = JoinRelation::init(join_type.clone());
            assert_eq!(join_relation.join_type, join_type);
            assert_eq!(join_relation, py_join_relation);
        });
    }

    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::py_add_columns())]
    #[case(JoinBy::Replace, JoinRelation::py_replace())]
    #[case(JoinBy::Extend, JoinRelation::py_extend())]
    #[case(JoinBy::Broadcast, JoinRelation::py_broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::py_cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::py_join_by_id(vec!["a".into()]))]
    fn test_py_join_relation_constructors(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_serde_roundtrip(&join_relation);
    }
}