1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
22#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
23#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24pub struct JoinArgs {
25 pub how: JoinType,
26 pub validation: JoinValidation,
27 pub suffix: Option<PlSmallStr>,
28 pub slice: Option<(i64, usize)>,
29 pub nulls_equal: bool,
30 pub coalesce: JoinCoalesce,
31 pub maintain_order: MaintainOrderJoin,
32}
33
34impl JoinArgs {
35 pub fn should_coalesce(&self) -> bool {
36 self.coalesce.coalesce(&self.how)
37 }
38}
39
40#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)]
41#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
42pub enum JoinType {
43 #[default]
44 Inner,
45 Left,
46 Right,
47 Full,
48 #[cfg(feature = "asof_join")]
49 AsOf(AsOfOptions),
50 #[cfg(feature = "semi_anti_join")]
51 Semi,
52 #[cfg(feature = "semi_anti_join")]
53 Anti,
54 #[cfg(feature = "iejoin")]
55 IEJoin,
57 Cross,
59}
60
/// Controls whether the key columns of a join are coalesced.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum JoinCoalesce {
    /// Let the join type decide.
    #[default]
    JoinSpecific,
    /// Coalesce for every join type that supports coalescing.
    CoalesceColumns,
    /// Never coalesce.
    KeepColumns,
}
69
70impl JoinCoalesce {
71 pub fn coalesce(&self, join_type: &JoinType) -> bool {
72 use JoinCoalesce::*;
73 use JoinType::*;
74 match join_type {
75 Left | Inner | Right => {
76 matches!(self, JoinSpecific | CoalesceColumns)
77 },
78 Full => {
79 matches!(self, CoalesceColumns)
80 },
81 #[cfg(feature = "asof_join")]
82 AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
83 #[cfg(feature = "iejoin")]
84 IEJoin => false,
85 Cross => false,
86 #[cfg(feature = "semi_anti_join")]
87 Semi | Anti => false,
88 }
89 }
90}
91
92#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
93#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
94#[strum(serialize_all = "snake_case")]
95pub enum MaintainOrderJoin {
96 #[default]
97 None,
98 Left,
99 Right,
100 LeftRight,
101 RightLeft,
102}
103
104impl MaintainOrderJoin {
105 pub(super) fn flip(&self) -> Self {
106 match self {
107 MaintainOrderJoin::None => MaintainOrderJoin::None,
108 MaintainOrderJoin::Left => MaintainOrderJoin::Right,
109 MaintainOrderJoin::Right => MaintainOrderJoin::Left,
110 MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
111 MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
112 }
113 }
114}
115
116impl JoinArgs {
117 pub fn new(how: JoinType) -> Self {
118 Self {
119 how,
120 validation: Default::default(),
121 suffix: None,
122 slice: None,
123 nulls_equal: false,
124 coalesce: Default::default(),
125 maintain_order: Default::default(),
126 }
127 }
128
129 pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
130 self.coalesce = coalesce;
131 self
132 }
133
134 pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
135 self.suffix = suffix;
136 self
137 }
138
139 pub fn suffix(&self) -> &PlSmallStr {
140 const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
141 self.suffix.as_ref().unwrap_or(DEFAULT)
142 }
143}
144
145impl From<JoinType> for JoinArgs {
146 fn from(value: JoinType) -> Self {
147 JoinArgs::new(value)
148 }
149}
150
151pub trait CrossJoinFilter: Send + Sync {
152 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
153}
154
155impl<T> CrossJoinFilter for T
156where
157 T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
158{
159 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
160 self(df)
161 }
162}
163
164#[derive(Clone)]
165pub struct CrossJoinOptions {
166 pub predicate: Arc<dyn CrossJoinFilter>,
167}
168
169impl CrossJoinOptions {
170 fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
171 Arc::as_ptr(&self.predicate)
172 }
173}
174
175impl Eq for CrossJoinOptions {}
176
177impl PartialEq for CrossJoinOptions {
178 fn eq(&self, other: &Self) -> bool {
179 std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
180 }
181}
182
183impl Hash for CrossJoinOptions {
184 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185 self.as_ptr_ref().hash(state);
186 }
187}
188
189impl Debug for CrossJoinOptions {
190 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191 write!(f, "CrossJoinOptions",)
192 }
193}
194
195#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
196#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
197#[strum(serialize_all = "snake_case")]
198pub enum JoinTypeOptions {
199 #[cfg(feature = "iejoin")]
200 IEJoin(IEJoinOptions),
201 #[cfg_attr(feature = "serde", serde(skip))]
202 Cross(CrossJoinOptions),
203}
204
205impl Display for JoinType {
206 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207 use JoinType::*;
208 let val = match self {
209 Left => "LEFT",
210 Right => "RIGHT",
211 Inner => "INNER",
212 Full => "FULL",
213 #[cfg(feature = "asof_join")]
214 AsOf(_) => "ASOF",
215 #[cfg(feature = "iejoin")]
216 IEJoin => "IEJOIN",
217 Cross => "CROSS",
218 #[cfg(feature = "semi_anti_join")]
219 Semi => "SEMI",
220 #[cfg(feature = "semi_anti_join")]
221 Anti => "ANTI",
222 };
223 write!(f, "{val}")
224 }
225}
226
227impl Debug for JoinType {
228 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229 write!(f, "{self}")
230 }
231}
232
233impl JoinType {
234 pub fn is_equi(&self) -> bool {
235 matches!(
236 self,
237 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
238 )
239 }
240
241 pub fn is_semi_anti(&self) -> bool {
242 #[cfg(feature = "semi_anti_join")]
243 {
244 matches!(self, JoinType::Semi | JoinType::Anti)
245 }
246 #[cfg(not(feature = "semi_anti_join"))]
247 {
248 false
249 }
250 }
251
252 pub fn is_semi(&self) -> bool {
253 #[cfg(feature = "semi_anti_join")]
254 {
255 matches!(self, JoinType::Semi)
256 }
257 #[cfg(not(feature = "semi_anti_join"))]
258 {
259 false
260 }
261 }
262
263 pub fn is_anti(&self) -> bool {
264 #[cfg(feature = "semi_anti_join")]
265 {
266 matches!(self, JoinType::Anti)
267 }
268 #[cfg(not(feature = "semi_anti_join"))]
269 {
270 false
271 }
272 }
273
274 pub fn is_asof(&self) -> bool {
275 #[cfg(feature = "asof_join")]
276 {
277 matches!(self, JoinType::AsOf(_))
278 }
279 #[cfg(not(feature = "asof_join"))]
280 {
281 false
282 }
283 }
284
285 pub fn is_cross(&self) -> bool {
286 matches!(self, JoinType::Cross)
287 }
288
289 pub fn is_ie(&self) -> bool {
290 #[cfg(feature = "iejoin")]
291 {
292 matches!(self, JoinType::IEJoin)
293 }
294 #[cfg(not(feature = "iejoin"))]
295 {
296 false
297 }
298 }
299}
300
/// Uniqueness constraint on the join keys, written `left:right` (e.g. `1:m`).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
pub enum JoinValidation {
    /// No constraint (`m:m`); no checks are run.
    #[default]
    ManyToMany,
    /// Right-hand keys must be unique (`m:1`).
    ManyToOne,
    /// Left-hand keys must be unique (`1:m`).
    OneToMany,
    /// Keys must be unique on both sides (`1:1`).
    OneToOne,
}
314
315impl JoinValidation {
316 pub fn needs_checks(&self) -> bool {
317 !matches!(self, JoinValidation::ManyToMany)
318 }
319
320 fn swap(self, swap: bool) -> Self {
321 use JoinValidation::*;
322 if swap {
323 match self {
324 ManyToMany => ManyToMany,
325 ManyToOne => OneToMany,
326 OneToMany => ManyToOne,
327 OneToOne => OneToOne,
328 }
329 } else {
330 self
331 }
332 }
333
334 pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
335 if !self.needs_checks() {
336 return Ok(());
337 }
338 polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
339 ComputeError: "{self} validation on a {join_type} join is not supported");
340 Ok(())
341 }
342
343 pub(super) fn validate_probe(
344 &self,
345 s_left: &Series,
346 s_right: &Series,
347 build_shortest_table: bool,
348 nulls_equal: bool,
349 ) -> PolarsResult<()> {
350 let should_swap = build_shortest_table && s_left.len() <= s_right.len();
358 let probe = if should_swap { s_right } else { s_left };
359
360 use JoinValidation::*;
361 let valid = match self.swap(should_swap) {
362 ManyToMany | ManyToOne => true,
365 OneToMany | OneToOne => {
366 if !nulls_equal && probe.null_count() > 0 {
367 probe.n_unique()? - 1 == probe.len() - probe.null_count()
368 } else {
369 probe.n_unique()? == probe.len()
370 }
371 },
372 };
373 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
374 Ok(())
375 }
376
377 pub(super) fn validate_build(
378 &self,
379 build_size: usize,
380 expected_size: usize,
381 swapped: bool,
382 ) -> PolarsResult<()> {
383 use JoinValidation::*;
384
385 let valid = match self.swap(swapped) {
387 ManyToMany | OneToMany => true,
390 ManyToOne | OneToOne => build_size == expected_size,
391 };
392 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
393 Ok(())
394 }
395}
396
397impl Display for JoinValidation {
398 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
399 let s = match self {
400 JoinValidation::ManyToMany => "m:m",
401 JoinValidation::ManyToOne => "m:1",
402 JoinValidation::OneToMany => "1:m",
403 JoinValidation::OneToOne => "1:1",
404 };
405 write!(f, "{s}")
406 }
407}
408
409impl Debug for JoinValidation {
410 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
411 write!(f, "JoinValidation: {self}")
412 }
413}