Skip to main content

polars_ops/frame/join/
args.rs

1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Parameters for which side to use as the build side in a join. Currently only
/// respected by the streaming engine.
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinBuildSide {
    /// Unless there's a very good reason to believe that the right side is
    /// smaller, use the left side.
    PreferLeft,
    /// Regardless of other heuristics, use the left side as build side.
    ForceLeft,

    // Mirror images of the two variants above with the sides exchanged:
    // `PreferRight` is the soft preference for the right side, `ForceRight`
    // the unconditional one.
    PreferRight,
    ForceRight,
}
38
// Options describing a join. `JoinArgs::new` fills every field except `how`
// with the same values the derived `Default` produces.
#[derive(Clone, PartialEq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct JoinArgs {
    // Kind of join to perform.
    pub how: JoinType,
    // Uniqueness requirement on the join keys (see `JoinValidation`).
    pub validation: JoinValidation,
    // Optional suffix override; `JoinArgs::suffix` falls back to "_right" when
    // this is `None`. Presumably appended to clashing right-hand column
    // names -- confirm against the join implementation.
    pub suffix: Option<PlSmallStr>,
    // NOTE(review): looks like an (offset, length) slice of the join result --
    // confirm against the code that consumes it.
    pub slice: Option<(i64, usize)>,
    // When `false`, `JoinValidation::validate_probe` leaves null keys out of
    // the uniqueness check. Presumably also controls whether null keys match
    // each other during the join itself -- confirm.
    pub nulls_equal: bool,
    // Column coalescing policy; resolved against `how` by `should_coalesce`.
    pub coalesce: JoinCoalesce,
    // Row-order policy for the join (see `MaintainOrderJoin`).
    pub maintain_order: MaintainOrderJoin,
    // Optional build-side hint (see `JoinBuildSide`).
    pub build_side: Option<JoinBuildSide>,
}
52
53impl JoinArgs {
54    pub fn should_coalesce(&self) -> bool {
55        self.coalesce.coalesce(&self.how)
56    }
57}
58
// The kind of join to perform; `Inner` is the default. Several variants only
// exist when their cargo feature is enabled. `Display` and `Debug` are
// implemented by hand and print an upper-case keyword.
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinType {
    #[default]
    Inner,
    Left,
    Right,
    Full,
    // Box is okay because this is inside a `Arc<JoinOptionsIR>`
    #[cfg(feature = "asof_join")]
    AsOf(Box<AsOfOptions>),
    #[cfg(feature = "semi_anti_join")]
    Semi,
    #[cfg(feature = "semi_anti_join")]
    Anti,
    #[cfg(feature = "iejoin")]
    // Options are set by optimizer/planner in Options
    IEJoin,
    // Options are set by optimizer/planner in Options
    Cross,
}
81
// Policy for coalescing in a join; resolved against a concrete join type by
// `JoinCoalesce::coalesce`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    // Default: coalesce for inner, left, right and as-of joins, but not for
    // full joins.
    #[default]
    JoinSpecific,
    // Coalesce for inner, left, right, full and as-of joins. Cross, IE, semi
    // and anti joins still never coalesce.
    CoalesceColumns,
    // Never coalesce.
    KeepColumns,
}
91
92impl JoinCoalesce {
93    pub fn coalesce(&self, join_type: &JoinType) -> bool {
94        use JoinCoalesce::*;
95        use JoinType::*;
96        match join_type {
97            Left | Inner | Right => {
98                matches!(self, JoinSpecific | CoalesceColumns)
99            },
100            Full => {
101                matches!(self, CoalesceColumns)
102            },
103            #[cfg(feature = "asof_join")]
104            AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
105            #[cfg(feature = "iejoin")]
106            IEJoin => false,
107            Cross => false,
108            #[cfg(feature = "semi_anti_join")]
109            Semi | Anti => false,
110        }
111    }
112}
113
// Row-order policy for a join; `None` is the default. The strum
// `IntoStaticStr` conversion yields snake_case variant names.
// NOTE(review): the variant names suggest which input's row order is preserved
// (and which takes precedence for the two-sided variants) -- confirm against
// the join implementation.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    #[default]
    None,
    Left,
    Right,
    LeftRight,
    RightLeft,
}
126
127impl MaintainOrderJoin {
128    pub(super) fn flip(&self) -> Self {
129        match self {
130            MaintainOrderJoin::None => MaintainOrderJoin::None,
131            MaintainOrderJoin::Left => MaintainOrderJoin::Right,
132            MaintainOrderJoin::Right => MaintainOrderJoin::Left,
133            MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
134            MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
135        }
136    }
137}
138
139impl JoinArgs {
140    pub fn new(how: JoinType) -> Self {
141        Self {
142            how,
143            validation: Default::default(),
144            suffix: None,
145            slice: None,
146            nulls_equal: false,
147            coalesce: Default::default(),
148            maintain_order: Default::default(),
149            build_side: None,
150        }
151    }
152
153    pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
154        self.coalesce = coalesce;
155        self
156    }
157
158    pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
159        self.suffix = suffix;
160        self
161    }
162
163    pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {
164        self.build_side = build_side;
165        self
166    }
167
168    pub fn suffix(&self) -> &PlSmallStr {
169        const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
170        self.suffix.as_ref().unwrap_or(DEFAULT)
171    }
172}
173
174impl From<JoinType> for JoinArgs {
175    fn from(value: JoinType) -> Self {
176        JoinArgs::new(value)
177    }
178}
179
/// A thread-safe `DataFrame -> PolarsResult<DataFrame>` transformation, stored
/// as the predicate of [`CrossJoinOptions`].
///
/// Any `Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync` closure
/// implements this trait through the blanket impl in this file.
pub trait CrossJoinFilter: Send + Sync {
    /// Consumes `df` and returns the resulting frame, or an error.
    // NOTE(review): presumably applied to (parts of) the cross-join output to
    // filter rows -- confirm against the call sites.
    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
}
183
184impl<T> CrossJoinFilter for T
185where
186    T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
187{
188    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
189        self(df)
190    }
191}
192
/// Options attached to a cross join.
///
/// Cloning shares the same predicate allocation. The hand-written
/// `PartialEq`/`Hash` impls in this file work on the address of that
/// allocation, so only values sharing the same `Arc` compare equal.
#[derive(Clone)]
pub struct CrossJoinOptions {
    /// Shared, type-erased predicate.
    pub predicate: Arc<dyn CrossJoinFilter>,
}
197
198impl CrossJoinOptions {
199    fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
200        Arc::as_ptr(&self.predicate)
201    }
202}
203
204impl Eq for CrossJoinOptions {}
205
206impl PartialEq for CrossJoinOptions {
207    fn eq(&self, other: &Self) -> bool {
208        std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
209    }
210}
211
212impl Hash for CrossJoinOptions {
213    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
214        self.as_ptr_ref().hash(state);
215    }
216}
217
218impl Debug for CrossJoinOptions {
219    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220        write!(f, "CrossJoinOptions",)
221    }
222}
223
/// Extra options carried by specific join types. The strum `IntoStaticStr`
/// conversion yields snake_case variant names.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options for an IE join; only available with the `iejoin` feature.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for a cross join (the predicate).
    Cross(CrossJoinOptions),
}
231
232impl Display for JoinType {
233    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234        use JoinType::*;
235        let val = match self {
236            Left => "LEFT",
237            Right => "RIGHT",
238            Inner => "INNER",
239            Full => "FULL",
240            #[cfg(feature = "asof_join")]
241            AsOf(_) => "ASOF",
242            #[cfg(feature = "iejoin")]
243            IEJoin => "IEJOIN",
244            Cross => "CROSS",
245            #[cfg(feature = "semi_anti_join")]
246            Semi => "SEMI",
247            #[cfg(feature = "semi_anti_join")]
248            Anti => "ANTI",
249        };
250        write!(f, "{val}")
251    }
252}
253
254impl Debug for JoinType {
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        write!(f, "{self}")
257    }
258}
259
260impl JoinType {
261    pub fn is_equi(&self) -> bool {
262        matches!(
263            self,
264            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
265        )
266    }
267
268    pub fn is_semi_anti(&self) -> bool {
269        #[cfg(feature = "semi_anti_join")]
270        {
271            matches!(self, JoinType::Semi | JoinType::Anti)
272        }
273        #[cfg(not(feature = "semi_anti_join"))]
274        {
275            false
276        }
277    }
278
279    pub fn is_semi(&self) -> bool {
280        #[cfg(feature = "semi_anti_join")]
281        {
282            matches!(self, JoinType::Semi)
283        }
284        #[cfg(not(feature = "semi_anti_join"))]
285        {
286            false
287        }
288    }
289
290    pub fn is_anti(&self) -> bool {
291        #[cfg(feature = "semi_anti_join")]
292        {
293            matches!(self, JoinType::Anti)
294        }
295        #[cfg(not(feature = "semi_anti_join"))]
296        {
297            false
298        }
299    }
300
301    pub fn is_asof(&self) -> bool {
302        #[cfg(feature = "asof_join")]
303        {
304            matches!(self, JoinType::AsOf(_))
305        }
306        #[cfg(not(feature = "asof_join"))]
307        {
308            false
309        }
310    }
311
312    pub fn is_cross(&self) -> bool {
313        matches!(self, JoinType::Cross)
314    }
315
316    pub fn is_ie(&self) -> bool {
317        #[cfg(feature = "iejoin")]
318        {
319            matches!(self, JoinType::IEJoin)
320        }
321        #[cfg(not(feature = "iejoin"))]
322        {
323            false
324        }
325    }
326}
327
// Uniqueness requirement on the join keys of each side. `Display` prints the
// short forms "m:m", "m:1", "1:m" and "1:1"; the first symbol refers to the
// left dataset and the second to the right.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
342
343impl JoinValidation {
344    pub fn needs_checks(&self) -> bool {
345        !matches!(self, JoinValidation::ManyToMany)
346    }
347
348    fn swap(self, swap: bool) -> Self {
349        use JoinValidation::*;
350        if swap {
351            match self {
352                ManyToMany => ManyToMany,
353                ManyToOne => OneToMany,
354                OneToMany => ManyToOne,
355                OneToOne => OneToOne,
356            }
357        } else {
358            self
359        }
360    }
361
362    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
363        if !self.needs_checks() {
364            return Ok(());
365        }
366        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
367                      ComputeError: "{self} validation on a {join_type} join is not supported");
368        Ok(())
369    }
370
371    pub(super) fn validate_probe(
372        &self,
373        s_left: &Series,
374        s_right: &Series,
375        build_shortest_table: bool,
376        nulls_equal: bool,
377    ) -> PolarsResult<()> {
378        // In default, probe is the left series.
379        //
380        // In inner join and outer join, the shortest relation will be used to create a hash table.
381        // In left join, always use the right side to create.
382        //
383        // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
384        // If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
385        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
386        let probe = if should_swap { s_right } else { s_left };
387
388        use JoinValidation::*;
389        let valid = match self.swap(should_swap) {
390            // Only check the `build` side.
391            // The other side use `validate_build` to check
392            ManyToMany | ManyToOne => true,
393            OneToMany | OneToOne => {
394                if !nulls_equal && probe.null_count() > 0 {
395                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
396                } else {
397                    probe.n_unique()? == probe.len()
398                }
399            },
400        };
401        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
402        Ok(())
403    }
404
405    pub(super) fn validate_build(
406        &self,
407        build_size: usize,
408        expected_size: usize,
409        swapped: bool,
410    ) -> PolarsResult<()> {
411        use JoinValidation::*;
412
413        // In default, build is in rhs.
414        let valid = match self.swap(swapped) {
415            // Only check the `build` side.
416            // The other side use `validate_prone` to check
417            ManyToMany | OneToMany => true,
418            ManyToOne | OneToOne => build_size == expected_size,
419        };
420        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
421        Ok(())
422    }
423}
424
425impl Display for JoinValidation {
426    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
427        let s = match self {
428            JoinValidation::ManyToMany => "m:m",
429            JoinValidation::ManyToOne => "m:1",
430            JoinValidation::OneToMany => "1:m",
431            JoinValidation::OneToOne => "1:1",
432        };
433        write!(f, "{s}")
434    }
435}
436
437impl Debug for JoinValidation {
438    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
439        write!(f, "JoinValidation: {self}")
440    }
441}