1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Selects which input of a join is used as the build side.
///
/// The `Prefer*` variants express a preference and the `Force*` variants a
/// requirement, as their names indicate; the logic that consumes this enum
/// lives outside this file.
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinBuildSide {
    /// Prefer building from the left input.
    PreferLeft,
    /// Always build from the left input.
    ForceLeft,

    /// Prefer building from the right input.
    PreferRight,
    /// Always build from the right input.
    ForceRight,
}
38
39#[derive(Clone, PartialEq, Debug, Hash, Default)]
40#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
41#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
42pub struct JoinArgs {
43 pub how: JoinType,
44 pub validation: JoinValidation,
45 pub suffix: Option<PlSmallStr>,
46 pub slice: Option<(i64, usize)>,
47 pub nulls_equal: bool,
48 pub coalesce: JoinCoalesce,
49 pub maintain_order: MaintainOrderJoin,
50 pub build_side: Option<JoinBuildSide>,
51}
52
53impl JoinArgs {
54 pub fn should_coalesce(&self) -> bool {
55 self.coalesce.coalesce(&self.how)
56 }
57}
58
59#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
60#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
61#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
62pub enum JoinType {
63 #[default]
64 Inner,
65 Left,
66 Right,
67 Full,
68 #[cfg(feature = "asof_join")]
70 AsOf(Box<AsOfOptions>),
71 #[cfg(feature = "semi_anti_join")]
72 Semi,
73 #[cfg(feature = "semi_anti_join")]
74 Anti,
75 #[cfg(feature = "iejoin")]
76 IEJoin,
78 Cross,
80}
81
/// Policy deciding whether a join coalesces columns.
///
/// The policy is resolved against a concrete [`JoinType`] by
/// [`JoinCoalesce::coalesce`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    /// Let the join type decide (the default).
    #[default]
    JoinSpecific,
    /// Coalesce for every join type that supports it.
    CoalesceColumns,
    /// Never coalesce.
    KeepColumns,
}
91
92impl JoinCoalesce {
93 pub fn coalesce(&self, join_type: &JoinType) -> bool {
94 use JoinCoalesce::*;
95 use JoinType::*;
96 match join_type {
97 Left | Inner | Right => {
98 matches!(self, JoinSpecific | CoalesceColumns)
99 },
100 Full => {
101 matches!(self, CoalesceColumns)
102 },
103 #[cfg(feature = "asof_join")]
104 AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
105 #[cfg(feature = "iejoin")]
106 IEJoin => false,
107 Cross => false,
108 #[cfg(feature = "semi_anti_join")]
109 Semi | Anti => false,
110 }
111 }
112}
113
114#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
115#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
116#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
117#[strum(serialize_all = "snake_case")]
118pub enum MaintainOrderJoin {
119 #[default]
120 None,
121 Left,
122 Right,
123 LeftRight,
124 RightLeft,
125}
126
127impl MaintainOrderJoin {
128 pub(super) fn flip(&self) -> Self {
129 match self {
130 MaintainOrderJoin::None => MaintainOrderJoin::None,
131 MaintainOrderJoin::Left => MaintainOrderJoin::Right,
132 MaintainOrderJoin::Right => MaintainOrderJoin::Left,
133 MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
134 MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
135 }
136 }
137}
138
139impl JoinArgs {
140 pub fn new(how: JoinType) -> Self {
141 Self {
142 how,
143 validation: Default::default(),
144 suffix: None,
145 slice: None,
146 nulls_equal: false,
147 coalesce: Default::default(),
148 maintain_order: Default::default(),
149 build_side: None,
150 }
151 }
152
153 pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
154 self.coalesce = coalesce;
155 self
156 }
157
158 pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
159 self.suffix = suffix;
160 self
161 }
162
163 pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {
164 self.build_side = build_side;
165 self
166 }
167
168 pub fn suffix(&self) -> &PlSmallStr {
169 const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
170 self.suffix.as_ref().unwrap_or(DEFAULT)
171 }
172}
173
174impl From<JoinType> for JoinArgs {
175 fn from(value: JoinType) -> Self {
176 JoinArgs::new(value)
177 }
178}
179
180pub trait CrossJoinFilter: Send + Sync {
181 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
182}
183
184impl<T> CrossJoinFilter for T
185where
186 T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
187{
188 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
189 self(df)
190 }
191}
192
193#[derive(Clone)]
194pub struct CrossJoinOptions {
195 pub predicate: Arc<dyn CrossJoinFilter>,
196}
197
198impl CrossJoinOptions {
199 fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
200 Arc::as_ptr(&self.predicate)
201 }
202}
203
204impl Eq for CrossJoinOptions {}
205
206impl PartialEq for CrossJoinOptions {
207 fn eq(&self, other: &Self) -> bool {
208 std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
209 }
210}
211
212impl Hash for CrossJoinOptions {
213 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
214 self.as_ptr_ref().hash(state);
215 }
216}
217
218impl Debug for CrossJoinOptions {
219 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220 write!(f, "CrossJoinOptions",)
221 }
222}
223
224#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
225#[strum(serialize_all = "snake_case")]
226pub enum JoinTypeOptions {
227 #[cfg(feature = "iejoin")]
228 IEJoin(IEJoinOptions),
229 Cross(CrossJoinOptions),
230}
231
232impl Display for JoinType {
233 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
234 use JoinType::*;
235 let val = match self {
236 Left => "LEFT",
237 Right => "RIGHT",
238 Inner => "INNER",
239 Full => "FULL",
240 #[cfg(feature = "asof_join")]
241 AsOf(_) => "ASOF",
242 #[cfg(feature = "iejoin")]
243 IEJoin => "IEJOIN",
244 Cross => "CROSS",
245 #[cfg(feature = "semi_anti_join")]
246 Semi => "SEMI",
247 #[cfg(feature = "semi_anti_join")]
248 Anti => "ANTI",
249 };
250 write!(f, "{val}")
251 }
252}
253
254impl Debug for JoinType {
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 write!(f, "{self}")
257 }
258}
259
260impl JoinType {
261 pub fn is_equi(&self) -> bool {
262 matches!(
263 self,
264 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
265 )
266 }
267
268 pub fn is_semi_anti(&self) -> bool {
269 #[cfg(feature = "semi_anti_join")]
270 {
271 matches!(self, JoinType::Semi | JoinType::Anti)
272 }
273 #[cfg(not(feature = "semi_anti_join"))]
274 {
275 false
276 }
277 }
278
279 pub fn is_semi(&self) -> bool {
280 #[cfg(feature = "semi_anti_join")]
281 {
282 matches!(self, JoinType::Semi)
283 }
284 #[cfg(not(feature = "semi_anti_join"))]
285 {
286 false
287 }
288 }
289
290 pub fn is_anti(&self) -> bool {
291 #[cfg(feature = "semi_anti_join")]
292 {
293 matches!(self, JoinType::Anti)
294 }
295 #[cfg(not(feature = "semi_anti_join"))]
296 {
297 false
298 }
299 }
300
301 pub fn is_asof(&self) -> bool {
302 #[cfg(feature = "asof_join")]
303 {
304 matches!(self, JoinType::AsOf(_))
305 }
306 #[cfg(not(feature = "asof_join"))]
307 {
308 false
309 }
310 }
311
312 pub fn is_cross(&self) -> bool {
313 matches!(self, JoinType::Cross)
314 }
315
316 pub fn is_ie(&self) -> bool {
317 #[cfg(feature = "iejoin")]
318 {
319 matches!(self, JoinType::IEJoin)
320 }
321 #[cfg(not(feature = "iejoin"))]
322 {
323 false
324 }
325 }
326}
327
/// Cardinality constraint to verify on the join keys.
///
/// `Display` renders the variants as `m:m`, `m:1`, `1:m` and `1:1`, in that
/// order.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// Many-to-many (the default); the only variant that needs no checks.
    #[default]
    ManyToMany,
    /// Many-to-one.
    ManyToOne,
    /// One-to-many.
    OneToMany,
    /// One-to-one.
    OneToOne,
}
342
343impl JoinValidation {
344 pub fn needs_checks(&self) -> bool {
345 !matches!(self, JoinValidation::ManyToMany)
346 }
347
348 fn swap(self, swap: bool) -> Self {
349 use JoinValidation::*;
350 if swap {
351 match self {
352 ManyToMany => ManyToMany,
353 ManyToOne => OneToMany,
354 OneToMany => ManyToOne,
355 OneToOne => OneToOne,
356 }
357 } else {
358 self
359 }
360 }
361
362 pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
363 if !self.needs_checks() {
364 return Ok(());
365 }
366 polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
367 ComputeError: "{self} validation on a {join_type} join is not supported");
368 Ok(())
369 }
370
371 pub(super) fn validate_probe(
372 &self,
373 s_left: &Series,
374 s_right: &Series,
375 build_shortest_table: bool,
376 nulls_equal: bool,
377 ) -> PolarsResult<()> {
378 let should_swap = build_shortest_table && s_left.len() <= s_right.len();
386 let probe = if should_swap { s_right } else { s_left };
387
388 use JoinValidation::*;
389 let valid = match self.swap(should_swap) {
390 ManyToMany | ManyToOne => true,
393 OneToMany | OneToOne => {
394 if !nulls_equal && probe.null_count() > 0 {
395 probe.n_unique()? - 1 == probe.len() - probe.null_count()
396 } else {
397 probe.n_unique()? == probe.len()
398 }
399 },
400 };
401 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
402 Ok(())
403 }
404
405 pub(super) fn validate_build(
406 &self,
407 build_size: usize,
408 expected_size: usize,
409 swapped: bool,
410 ) -> PolarsResult<()> {
411 use JoinValidation::*;
412
413 let valid = match self.swap(swapped) {
415 ManyToMany | OneToMany => true,
418 ManyToOne | OneToOne => build_size == expected_size,
419 };
420 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
421 Ok(())
422 }
423}
424
425impl Display for JoinValidation {
426 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
427 let s = match self {
428 JoinValidation::ManyToMany => "m:m",
429 JoinValidation::ManyToOne => "m:1",
430 JoinValidation::OneToMany => "1:m",
431 JoinValidation::OneToOne => "1:1",
432 };
433 write!(f, "{s}")
434 }
435}
436
437impl Debug for JoinValidation {
438 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
439 write!(f, "JoinValidation: {self}")
440 }
441}