datafusion_functions_window/
rank.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions,
19//! which can be evaluated at runtime during query execution.
20
21use crate::define_udwf_and_expr;
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::array::{Float64Array, UInt64Array};
25use datafusion_common::arrow::compute::SortOptions;
26use datafusion_common::arrow::datatypes::DataType;
27use datafusion_common::arrow::datatypes::Field;
28use datafusion_common::utils::get_row_at_idx;
29use datafusion_common::{exec_err, Result, ScalarValue};
30use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING;
31use datafusion_expr::{
32    Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
33};
34use datafusion_functions_window_common::field;
35use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
36use field::WindowUDFFieldArgs;
37use std::any::Any;
38use std::fmt::Debug;
39use std::hash::{DefaultHasher, Hash, Hasher};
40use std::iter;
41use std::ops::Range;
42use std::sync::{Arc, LazyLock};
43
44define_udwf_and_expr!(
45    Rank,
46    rank,
47    "Returns rank of the current row with gaps. Same as `row_number` of its first peer",
48    Rank::basic
49);
50
51define_udwf_and_expr!(
52    DenseRank,
53    dense_rank,
54    "Returns rank of the current row without gaps. This function counts peer groups",
55    Rank::dense_rank
56);
57
58define_udwf_and_expr!(
59    PercentRank,
60    percent_rank,
61    "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)",
62    Rank::percent_rank
63);
64
65/// Rank calculates the rank in the window function with order by
66#[derive(Debug)]
67pub struct Rank {
68    name: String,
69    signature: Signature,
70    rank_type: RankType,
71}
72
73impl Rank {
74    /// Create a new `rank` function with the specified name and rank type
75    pub fn new(name: String, rank_type: RankType) -> Self {
76        Self {
77            name,
78            signature: Signature::nullary(Volatility::Immutable),
79            rank_type,
80        }
81    }
82
83    /// Create a `rank` window function
84    pub fn basic() -> Self {
85        Rank::new("rank".to_string(), RankType::Basic)
86    }
87
88    /// Create a `dense_rank` window function
89    pub fn dense_rank() -> Self {
90        Rank::new("dense_rank".to_string(), RankType::Dense)
91    }
92
93    /// Create a `percent_rank` window function
94    pub fn percent_rank() -> Self {
95        Rank::new("percent_rank".to_string(), RankType::Percent)
96    }
97}
98
/// Selects which ranking function a `Rank` instance computes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RankType {
    /// `rank`: rank of the current row with gaps.
    Basic,
    /// `dense_rank`: rank of the current row without gaps.
    Dense,
    /// `percent_rank`: relative rank, `(rank - 1) / (total rows - 1)`.
    Percent,
}
105
106static RANK_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
107    Documentation::builder(
108        DOC_SECTION_RANKING,
109            "Returns the rank of the current row within its partition, allowing \
110            gaps between ranks. This function provides a ranking similar to `row_number`, but \
111            skips ranks for identical values.",
112
113        "rank()")
114        .with_sql_example(r#"```sql
115    --Example usage of the rank window function:
116    SELECT department,
117           salary,
118           rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank
119    FROM employees;
120```
121
122```sql
123+-------------+--------+------+
124| department  | salary | rank |
125+-------------+--------+------+
126| Sales       | 70000  | 1    |
127| Sales       | 50000  | 2    |
128| Sales       | 50000  | 2    |
129| Sales       | 30000  | 4    |
130| Engineering | 90000  | 1    |
131| Engineering | 80000  | 2    |
132+-------------+--------+------+
133```"#)
134        .build()
135});
136
137fn get_rank_doc() -> &'static Documentation {
138    &RANK_DOCUMENTATION
139}
140
141static DENSE_RANK_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
142    Documentation::builder(DOC_SECTION_RANKING, "Returns the rank of the current row without gaps. This function ranks \
143            rows in a dense manner, meaning consecutive ranks are assigned even for identical \
144            values.", "dense_rank()")
145        .with_sql_example(r#"```sql
146    --Example usage of the dense_rank window function:
147    SELECT department,
148           salary,
149           dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank
150    FROM employees;
151```
152
153```sql
154+-------------+--------+------------+
155| department  | salary | dense_rank |
156+-------------+--------+------------+
157| Sales       | 70000  | 1          |
158| Sales       | 50000  | 2          |
159| Sales       | 50000  | 2          |
160| Sales       | 30000  | 3          |
161| Engineering | 90000  | 1          |
162| Engineering | 80000  | 2          |
163+-------------+--------+------------+
164```"#)
165        .build()
166});
167
168fn get_dense_rank_doc() -> &'static Documentation {
169    &DENSE_RANK_DOCUMENTATION
170}
171
172static PERCENT_RANK_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
173    Documentation::builder(DOC_SECTION_RANKING, "Returns the percentage rank of the current row within its partition. \
174            The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", "percent_rank()")
175        .with_sql_example(r#"```sql
176    --Example usage of the percent_rank window function:
177    SELECT employee_id,
178           salary,
179           percent_rank() OVER (ORDER BY salary) AS percent_rank
180    FROM employees;
181```
182
183```sql
184+-------------+--------+---------------+
185| employee_id | salary | percent_rank  |
186+-------------+--------+---------------+
187| 1           | 30000  | 0.00          |
188| 2           | 50000  | 0.50          |
189| 3           | 70000  | 1.00          |
190+-------------+--------+---------------+
191```"#)
192        .build()
193});
194
195fn get_percent_rank_doc() -> &'static Documentation {
196    &PERCENT_RANK_DOCUMENTATION
197}
198
199impl WindowUDFImpl for Rank {
200    fn as_any(&self) -> &dyn Any {
201        self
202    }
203
204    fn name(&self) -> &str {
205        &self.name
206    }
207
208    fn signature(&self) -> &Signature {
209        &self.signature
210    }
211
212    fn partition_evaluator(
213        &self,
214        _partition_evaluator_args: PartitionEvaluatorArgs,
215    ) -> Result<Box<dyn PartitionEvaluator>> {
216        Ok(Box::new(RankEvaluator {
217            state: RankState::default(),
218            rank_type: self.rank_type,
219        }))
220    }
221
222    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
223        let return_type = match self.rank_type {
224            RankType::Basic | RankType::Dense => DataType::UInt64,
225            RankType::Percent => DataType::Float64,
226        };
227
228        let nullable = false;
229        Ok(Field::new(field_args.name(), return_type, nullable).into())
230    }
231
232    fn sort_options(&self) -> Option<SortOptions> {
233        Some(SortOptions {
234            descending: false,
235            nulls_first: false,
236        })
237    }
238
239    fn documentation(&self) -> Option<&Documentation> {
240        match self.rank_type {
241            RankType::Basic => Some(get_rank_doc()),
242            RankType::Dense => Some(get_dense_rank_doc()),
243            RankType::Percent => Some(get_percent_rank_doc()),
244        }
245    }
246
247    fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
248        let Some(other) = other.as_any().downcast_ref::<Self>() else {
249            return false;
250        };
251        let Self {
252            name,
253            signature,
254            rank_type,
255        } = self;
256        name == &other.name
257            && signature == &other.signature
258            && rank_type == &other.rank_type
259    }
260
261    fn hash_value(&self) -> u64 {
262        let Self {
263            name,
264            signature,
265            rank_type,
266        } = self;
267        let mut hasher = DefaultHasher::new();
268        std::any::type_name::<Self>().hash(&mut hasher);
269        name.hash(&mut hasher);
270        signature.hash(&mut hasher);
271        rank_type.hash(&mut hasher);
272        hasher.finish()
273    }
274}
275
276/// State for the RANK(rank) built-in window function.
277#[derive(Debug, Clone, Default)]
278pub struct RankState {
279    /// The last values for rank as these values change, we increase n_rank
280    pub last_rank_data: Option<Vec<ScalarValue>>,
281    /// The index where last_rank_boundary is started
282    pub last_rank_boundary: usize,
283    /// Keep the number of entries in current rank
284    pub current_group_count: usize,
285    /// Rank number kept from the start
286    pub n_rank: usize,
287}
288
289/// State for the `rank` built-in window function.
290#[derive(Debug)]
291struct RankEvaluator {
292    state: RankState,
293    rank_type: RankType,
294}
295
296impl PartitionEvaluator for RankEvaluator {
297    fn is_causal(&self) -> bool {
298        matches!(self.rank_type, RankType::Basic | RankType::Dense)
299    }
300
301    fn evaluate(
302        &mut self,
303        values: &[ArrayRef],
304        range: &Range<usize>,
305    ) -> Result<ScalarValue> {
306        let row_idx = range.start;
307        // There is no argument, values are order by column values (where rank is calculated)
308        let range_columns = values;
309        let last_rank_data = get_row_at_idx(range_columns, row_idx)?;
310        let new_rank_encountered =
311            if let Some(state_last_rank_data) = &self.state.last_rank_data {
312                // if rank data changes, new rank is encountered
313                state_last_rank_data != &last_rank_data
314            } else {
315                // First rank seen
316                true
317            };
318        if new_rank_encountered {
319            self.state.last_rank_data = Some(last_rank_data);
320            self.state.last_rank_boundary += self.state.current_group_count;
321            self.state.current_group_count = 1;
322            self.state.n_rank += 1;
323        } else {
324            // data is still in the same rank
325            self.state.current_group_count += 1;
326        }
327
328        match self.rank_type {
329            RankType::Basic => Ok(ScalarValue::UInt64(Some(
330                self.state.last_rank_boundary as u64 + 1,
331            ))),
332            RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))),
333            RankType::Percent => {
334                exec_err!("Can not execute PERCENT_RANK in a streaming fashion")
335            }
336        }
337    }
338
339    fn evaluate_all_with_rank(
340        &self,
341        num_rows: usize,
342        ranks_in_partition: &[Range<usize>],
343    ) -> Result<ArrayRef> {
344        let result: ArrayRef = match self.rank_type {
345            RankType::Basic => Arc::new(UInt64Array::from_iter_values(
346                ranks_in_partition
347                    .iter()
348                    .scan(1_u64, |acc, range| {
349                        let len = range.end - range.start;
350                        let result = iter::repeat_n(*acc, len);
351                        *acc += len as u64;
352                        Some(result)
353                    })
354                    .flatten(),
355            )),
356
357            RankType::Dense => Arc::new(UInt64Array::from_iter_values(
358                ranks_in_partition
359                    .iter()
360                    .zip(1u64..)
361                    .flat_map(|(range, rank)| {
362                        let len = range.end - range.start;
363                        iter::repeat_n(rank, len)
364                    }),
365            )),
366
367            RankType::Percent => {
368                let denominator = num_rows as f64;
369
370                Arc::new(Float64Array::from_iter_values(
371                    ranks_in_partition
372                        .iter()
373                        .scan(0_u64, |acc, range| {
374                            let len = range.end - range.start;
375                            let value = (*acc as f64) / (denominator - 1.0).max(1.0);
376                            let result = iter::repeat_n(value, len);
377                            *acc += len as u64;
378                            Some(result)
379                        })
380                        .flatten(),
381                ))
382            }
383        };
384
385        Ok(result)
386    }
387
388    fn supports_bounded_execution(&self) -> bool {
389        matches!(self.rank_type, RankType::Basic | RankType::Dense)
390    }
391
392    fn include_rank(&self) -> bool {
393        true
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::cast::{as_float64_array, as_uint64_array};

    /// Partition size used by the integer-valued rank tests.
    const NUM_ROWS: usize = 8;

    /// Runs `evaluate_all_with_rank` on a fresh evaluator created from `expr`.
    fn evaluate_all(
        expr: &Rank,
        num_rows: usize,
        ranks: &[Range<usize>],
    ) -> Result<ArrayRef> {
        expr.partition_evaluator(PartitionEvaluatorArgs::default())?
            .evaluate_all_with_rank(num_rows, ranks)
    }

    /// Asserts the `u64` ranks produced for `ranks` over an 8-row partition.
    fn assert_u64_ranks(
        expr: &Rank,
        ranks: Vec<Range<usize>>,
        expected: Vec<u64>,
    ) -> Result<()> {
        let array = evaluate_all(expr, NUM_ROWS, &ranks)?;
        let actual = as_uint64_array(&array)?.values();
        assert_eq!(expected, *actual);
        Ok(())
    }

    /// Asserts the `f64` values produced for `ranks` over `num_rows` rows.
    fn assert_f64_ranks(
        expr: &Rank,
        num_rows: usize,
        ranks: Vec<Range<usize>>,
        expected: Vec<f64>,
    ) -> Result<()> {
        let array = evaluate_all(expr, num_rows, &ranks)?;
        let actual = as_float64_array(&array)?.values();
        assert_eq!(expected, *actual);
        Ok(())
    }

    /// Five groups of sizes 2, 1, 3, 1 and 1.
    fn assert_with_groups(expr: &Rank, expected: Vec<u64>) -> Result<()> {
        assert_u64_ranks(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected)
    }

    /// All eight rows in a single group.
    // The one-element vector of a range is intentional, not a typo for a
    // vector of integers.
    #[allow(clippy::single_range_in_vec_init)]
    fn assert_single_group(expr: &Rank, expected: Vec<u64>) -> Result<()> {
        assert_u64_ranks(expr, vec![0..8], expected)
    }

    #[test]
    fn test_rank() -> Result<()> {
        let rank = Rank::basic();
        assert_single_group(&rank, vec![1; 8])?;
        assert_with_groups(&rank, vec![1, 1, 3, 4, 4, 4, 7, 8])
    }

    #[test]
    fn test_dense_rank() -> Result<()> {
        let dense = Rank::dense_rank();
        assert_single_group(&dense, vec![1; 8])?;
        assert_with_groups(&dense, vec![1, 1, 2, 3, 3, 3, 4, 5])
    }

    #[test]
    // Several cases deliberately use a vector holding a single range.
    #[allow(clippy::single_range_in_vec_init)]
    fn test_percent_rank() -> Result<()> {
        let percent = Rank::percent_rank();

        // (num_rows, groups, expected values)
        let cases: [(usize, Vec<Range<usize>>, Vec<f64>); 4] = [
            // empty case
            (0, vec![], vec![]),
            // singleton case
            (1, vec![0..1], vec![0.0]),
            // uniform case
            (7, vec![0..7], vec![0.0; 7]),
            // non-trivial case
            (
                7,
                vec![0..3, 3..7],
                vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5],
            ),
        ];

        for (num_rows, ranks, expected) in cases {
            assert_f64_ranks(&percent, num_rows, ranks, expected)?;
        }

        Ok(())
    }
}