// polars_ops/frame/join/args.rs

1use super::*;
2
/// Flat list of `IdxSize` values.
// NOTE(review): presumably row indices selected by a join — confirm against callers.
pub(super) type JoinIds = Vec<IdxSize>;
/// Pair made of one `ChunkJoinIds` and one `ChunkJoinOptIds`.
// NOTE(review): presumably (left ids, optional right ids) of a left join — confirm.
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
/// Pair of two `JoinIds` lists.
pub type InnerJoinIds = (JoinIds, JoinIds);

/// With the `chunked_ids` feature: either plain `IdxSize` values or `ChunkId`s.
// NOTE(review): this alias is `pub(super)` here but `pub` without the feature — confirm intended.
#[cfg(feature = "chunked_ids")]
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
/// With the `chunked_ids` feature: either `NullableIdxSize` values or `ChunkId`s.
#[cfg(feature = "chunked_ids")]
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;

/// Without the `chunked_ids` feature: a plain list of `NullableIdxSize` values.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;

/// Without the `chunked_ids` feature: a plain list of `IdxSize` values.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Arguments describing how a join should be carried out.
///
/// Every field has a default, so the struct can be built with `Default::default()`
/// or with `JoinArgs::new`, which only requires the join type.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinArgs {
    /// The kind of join.
    pub how: JoinType,
    /// Uniqueness validation requested for the join keys.
    pub validation: JoinValidation,
    /// Optional suffix; when `None`, `JoinArgs::suffix` falls back to `"_right"`.
    // NOTE(review): presumably appended to clashing right-hand column names — confirm.
    pub suffix: Option<PlSmallStr>,
    /// Optional `(i64, usize)` pair.
    // NOTE(review): presumably (offset, length) of a slice applied to the result — TODO confirm.
    pub slice: Option<(i64, usize)>,
    /// Defaults to `false`.
    // NOTE(review): presumably whether null join keys compare equal to each other — confirm.
    pub nulls_equal: bool,
    /// Coalescing preference; resolved against `how` by `JoinArgs::should_coalesce`.
    pub coalesce: JoinCoalesce,
    /// Requested row-order preservation.
    // NOTE(review): exact ordering guarantees are implemented elsewhere — confirm semantics.
    pub maintain_order: MaintainOrderJoin,
}
33
34impl JoinArgs {
35    pub fn should_coalesce(&self) -> bool {
36        self.coalesce.coalesce(&self.how)
37    }
38}
39
/// The kind of join to perform.
///
/// Several variants only exist when the matching cargo feature is enabled.
/// `Display` renders each variant as an upper-case keyword (for example `INNER`),
/// and the hand-written `Debug` impl forwards to `Display`.
#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinType {
    /// The default join type.
    #[default]
    Inner,
    Left,
    Right,
    Full,
    /// Requires the `asof_join` feature; carries its `AsOfOptions`.
    #[cfg(feature = "asof_join")]
    AsOf(AsOfOptions),
    /// Requires the `semi_anti_join` feature.
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Requires the `semi_anti_join` feature.
    #[cfg(feature = "semi_anti_join")]
    Anti,
    /// Requires the `iejoin` feature.
    #[cfg(feature = "iejoin")]
    // Options are set by optimizer/planner in Options
    // NOTE(review): "Options" presumably refers to `JoinTypeOptions` — confirm.
    IEJoin,
    // Options are set by optimizer/planner in Options
    // NOTE(review): "Options" presumably refers to `JoinTypeOptions` — confirm.
    Cross,
}
60
/// Preference for coalescing columns after a join.
///
/// The preference is turned into a yes/no answer for a concrete join type by
/// `JoinCoalesce::coalesce`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
    /// Default. Coalesces for inner, left, right and as-of joins; does not for
    /// any other join type.
    #[default]
    JoinSpecific,
    /// Coalesces for inner, left, right, full and as-of joins; still does not
    /// for cross, IE, semi and anti joins.
    CoalesceColumns,
    /// Never coalesces, whatever the join type.
    KeepColumns,
}
69
70impl JoinCoalesce {
71    pub fn coalesce(&self, join_type: &JoinType) -> bool {
72        use JoinCoalesce::*;
73        use JoinType::*;
74        match join_type {
75            Left | Inner | Right => {
76                matches!(self, JoinSpecific | CoalesceColumns)
77            },
78            Full => {
79                matches!(self, CoalesceColumns)
80            },
81            #[cfg(feature = "asof_join")]
82            AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
83            #[cfg(feature = "iejoin")]
84            IEJoin => false,
85            Cross => false,
86            #[cfg(feature = "semi_anti_join")]
87            Semi | Anti => false,
88        }
89    }
90}
91
/// Which input's row order a join is asked to maintain.
///
/// The `IntoStaticStr` derive is configured with `serialize_all = "snake_case"`.
// NOTE(review): the precise ordering guarantee of each variant is implemented
// elsewhere; the variant notes below are inferred from the names — confirm.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    /// Default.
    #[default]
    None,
    // presumably: keep the order of the left input — confirm
    Left,
    // presumably: keep the order of the right input — confirm
    Right,
    // presumably: left order first, then right — confirm
    LeftRight,
    // presumably: right order first, then left — confirm
    RightLeft,
}
103
104impl MaintainOrderJoin {
105    pub(super) fn flip(&self) -> Self {
106        match self {
107            MaintainOrderJoin::None => MaintainOrderJoin::None,
108            MaintainOrderJoin::Left => MaintainOrderJoin::Right,
109            MaintainOrderJoin::Right => MaintainOrderJoin::Left,
110            MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
111            MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
112        }
113    }
114}
115
116impl JoinArgs {
117    pub fn new(how: JoinType) -> Self {
118        Self {
119            how,
120            validation: Default::default(),
121            suffix: None,
122            slice: None,
123            nulls_equal: false,
124            coalesce: Default::default(),
125            maintain_order: Default::default(),
126        }
127    }
128
129    pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
130        self.coalesce = coalesce;
131        self
132    }
133
134    pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
135        self.suffix = suffix;
136        self
137    }
138
139    pub fn suffix(&self) -> &PlSmallStr {
140        const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
141        self.suffix.as_ref().unwrap_or(DEFAULT)
142    }
143}
144
145impl From<JoinType> for JoinArgs {
146    fn from(value: JoinType) -> Self {
147        JoinArgs::new(value)
148    }
149}
150
/// A thread-safe (`Send + Sync`) filter over a `DataFrame`.
// NOTE(review): presumably applied to the output of a cross join — confirm against callers.
pub trait CrossJoinFilter: Send + Sync {
    /// Consumes `df` and returns the filtered frame, or an error.
    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
}
154
155impl<T> CrossJoinFilter for T
156where
157    T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
158{
159    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
160        self(df)
161    }
162}
163
/// Options attached to a cross join.
///
/// Cloning shares the same predicate through the `Arc`. Equality and hashing
/// are implemented by hand on the predicate's pointer, not on its behavior.
#[derive(Clone)]
pub struct CrossJoinOptions {
    /// Shared filter object.
    pub predicate: Arc<dyn CrossJoinFilter>,
}
168
169impl CrossJoinOptions {
170    fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
171        Arc::as_ptr(&self.predicate)
172    }
173}
174
175impl Eq for CrossJoinOptions {}
176
177impl PartialEq for CrossJoinOptions {
178    fn eq(&self, other: &Self) -> bool {
179        std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
180    }
181}
182
183impl Hash for CrossJoinOptions {
184    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185        self.as_ptr_ref().hash(state);
186    }
187}
188
189impl Debug for CrossJoinOptions {
190    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191        write!(f, "CrossJoinOptions",)
192    }
193}
194
/// Extra, join-type-specific options.
///
/// The `IntoStaticStr` derive is configured with `serialize_all = "snake_case"`.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options for an IE join; requires the `iejoin` feature.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for a cross join. Marked `serde(skip)` when the `serde` feature
    /// is enabled.
    #[cfg_attr(feature = "serde", serde(skip))]
    Cross(CrossJoinOptions),
}
204
205impl Display for JoinType {
206    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207        use JoinType::*;
208        let val = match self {
209            Left => "LEFT",
210            Right => "RIGHT",
211            Inner => "INNER",
212            Full => "FULL",
213            #[cfg(feature = "asof_join")]
214            AsOf(_) => "ASOF",
215            #[cfg(feature = "iejoin")]
216            IEJoin => "IEJOIN",
217            Cross => "CROSS",
218            #[cfg(feature = "semi_anti_join")]
219            Semi => "SEMI",
220            #[cfg(feature = "semi_anti_join")]
221            Anti => "ANTI",
222        };
223        write!(f, "{val}")
224    }
225}
226
227impl Debug for JoinType {
228    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229        write!(f, "{self}")
230    }
231}
232
233impl JoinType {
234    pub fn is_equi(&self) -> bool {
235        matches!(
236            self,
237            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
238        )
239    }
240
241    pub fn is_semi_anti(&self) -> bool {
242        #[cfg(feature = "semi_anti_join")]
243        {
244            matches!(self, JoinType::Semi | JoinType::Anti)
245        }
246        #[cfg(not(feature = "semi_anti_join"))]
247        {
248            false
249        }
250    }
251
252    pub fn is_semi(&self) -> bool {
253        #[cfg(feature = "semi_anti_join")]
254        {
255            matches!(self, JoinType::Semi)
256        }
257        #[cfg(not(feature = "semi_anti_join"))]
258        {
259            false
260        }
261    }
262
263    pub fn is_anti(&self) -> bool {
264        #[cfg(feature = "semi_anti_join")]
265        {
266            matches!(self, JoinType::Anti)
267        }
268        #[cfg(not(feature = "semi_anti_join"))]
269        {
270            false
271        }
272    }
273
274    pub fn is_asof(&self) -> bool {
275        #[cfg(feature = "asof_join")]
276        {
277            matches!(self, JoinType::AsOf(_))
278        }
279        #[cfg(not(feature = "asof_join"))]
280        {
281            false
282        }
283    }
284
285    pub fn is_cross(&self) -> bool {
286        matches!(self, JoinType::Cross)
287    }
288
289    pub fn is_ie(&self) -> bool {
290        #[cfg(feature = "iejoin")]
291        {
292            matches!(self, JoinType::IEJoin)
293        }
294        #[cfg(not(feature = "iejoin"))]
295        {
296            false
297        }
298    }
299}
300
/// Uniqueness requirement placed on the join keys.
///
/// `Display` renders the variants as `m:m`, `m:1`, `1:m` and `1:1`.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
314
315impl JoinValidation {
316    pub fn needs_checks(&self) -> bool {
317        !matches!(self, JoinValidation::ManyToMany)
318    }
319
320    fn swap(self, swap: bool) -> Self {
321        use JoinValidation::*;
322        if swap {
323            match self {
324                ManyToMany => ManyToMany,
325                ManyToOne => OneToMany,
326                OneToMany => ManyToOne,
327                OneToOne => OneToOne,
328            }
329        } else {
330            self
331        }
332    }
333
334    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
335        if !self.needs_checks() {
336            return Ok(());
337        }
338        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
339                      ComputeError: "{self} validation on a {join_type} join is not supported");
340        Ok(())
341    }
342
343    pub(super) fn validate_probe(
344        &self,
345        s_left: &Series,
346        s_right: &Series,
347        build_shortest_table: bool,
348        nulls_equal: bool,
349    ) -> PolarsResult<()> {
350        // In default, probe is the left series.
351        //
352        // In inner join and outer join, the shortest relation will be used to create a hash table.
353        // In left join, always use the right side to create.
354        //
355        // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
356        // If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
357        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
358        let probe = if should_swap { s_right } else { s_left };
359
360        use JoinValidation::*;
361        let valid = match self.swap(should_swap) {
362            // Only check the `build` side.
363            // The other side use `validate_build` to check
364            ManyToMany | ManyToOne => true,
365            OneToMany | OneToOne => {
366                if !nulls_equal && probe.null_count() > 0 {
367                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
368                } else {
369                    probe.n_unique()? == probe.len()
370                }
371            },
372        };
373        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
374        Ok(())
375    }
376
377    pub(super) fn validate_build(
378        &self,
379        build_size: usize,
380        expected_size: usize,
381        swapped: bool,
382    ) -> PolarsResult<()> {
383        use JoinValidation::*;
384
385        // In default, build is in rhs.
386        let valid = match self.swap(swapped) {
387            // Only check the `build` side.
388            // The other side use `validate_prone` to check
389            ManyToMany | OneToMany => true,
390            ManyToOne | OneToOne => build_size == expected_size,
391        };
392        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
393        Ok(())
394    }
395}
396
397impl Display for JoinValidation {
398    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
399        let s = match self {
400            JoinValidation::ManyToMany => "m:m",
401            JoinValidation::ManyToOne => "m:1",
402            JoinValidation::OneToMany => "1:m",
403            JoinValidation::OneToOne => "1:1",
404        };
405        write!(f, "{s}")
406    }
407}
408
409impl Debug for JoinValidation {
410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411        write!(f, "JoinValidation: {self}")
412    }
413}