datafusion_functions_window/
row_number.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `row_number` window function implementation
19
20use arrow::datatypes::FieldRef;
21use datafusion_common::arrow::array::ArrayRef;
22use datafusion_common::arrow::array::UInt64Array;
23use datafusion_common::arrow::compute::SortOptions;
24use datafusion_common::arrow::datatypes::DataType;
25use datafusion_common::arrow::datatypes::Field;
26use datafusion_common::{Result, ScalarValue};
27use datafusion_expr::{
28    Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
29};
30use datafusion_functions_window_common::field;
31use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
32use datafusion_macros::user_doc;
33use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::fmt::Debug;
37use std::ops::Range;
38use std::sync::Arc;
39
// Invokes `define_udwf_and_expr!` with the `RowNumber` type, the `row_number`
// identifier and a one-line description of the function.
// NOTE(review): the macro is defined outside this file; presumably it generates
// the UDWF accessor and expression helper for `row_number` — confirm against
// the macro definition.
define_udwf_and_expr!(
    RowNumber,
    row_number,
    "Returns a unique row number for each row in window partition beginning at 1."
);
45
46/// row_number expression
47#[user_doc(
48    doc_section(label = "Ranking Functions"),
49    description = "Number of the current row within its partition, counting from 1.",
50    syntax_example = "row_number()",
51    sql_example = r#"
52```sql
53-- Example usage of the row_number window function:
54SELECT department,
55  salary,
56  row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num
57FROM employees;
58
59+-------------+--------+---------+
60| department  | salary | row_num |
61+-------------+--------+---------+
62| Sales       | 70000  | 1       |
63| Sales       | 50000  | 2       |
64| Sales       | 50000  | 3       |
65| Sales       | 30000  | 4       |
66| Engineering | 90000  | 1       |
67| Engineering | 80000  | 2       |
68+-------------+--------+---------+
69```
70"#
71)]
72#[derive(Debug, PartialEq, Eq, Hash)]
73pub struct RowNumber {
74    signature: Signature,
75}
76
77impl RowNumber {
78    /// Create a new `row_number` function
79    pub fn new() -> Self {
80        Self {
81            signature: Signature::nullary(Volatility::Immutable),
82        }
83    }
84}
85
86impl Default for RowNumber {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl WindowUDFImpl for RowNumber {
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn name(&self) -> &str {
98        "row_number"
99    }
100
101    fn signature(&self) -> &Signature {
102        &self.signature
103    }
104
105    fn partition_evaluator(
106        &self,
107        _partition_evaluator_args: PartitionEvaluatorArgs,
108    ) -> Result<Box<dyn PartitionEvaluator>> {
109        Ok(Box::<NumRowsEvaluator>::default())
110    }
111
112    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
113        Ok(Field::new(field_args.name(), DataType::UInt64, false).into())
114    }
115
116    fn sort_options(&self) -> Option<SortOptions> {
117        Some(SortOptions {
118            descending: false,
119            nulls_first: false,
120        })
121    }
122
123    fn documentation(&self) -> Option<&Documentation> {
124        self.doc()
125    }
126
127    fn limit_effect(&self, _args: &[Arc<dyn PhysicalExpr>]) -> LimitEffect {
128        LimitEffect::None
129    }
130}
131
/// Per-partition state backing the `row_number` window function.
#[derive(Default, Debug)]
struct NumRowsEvaluator {
    /// Running row counter; starts at zero via `Default`.
    n_rows: usize,
}
137
138impl PartitionEvaluator for NumRowsEvaluator {
139    fn is_causal(&self) -> bool {
140        // The row_number function doesn't need "future" values to emit results:
141        true
142    }
143
144    fn evaluate_all(
145        &mut self,
146        _values: &[ArrayRef],
147        num_rows: usize,
148    ) -> Result<ArrayRef> {
149        Ok(Arc::new(UInt64Array::from_iter_values(
150            1..(num_rows as u64) + 1,
151        )))
152    }
153
154    fn evaluate(
155        &mut self,
156        _values: &[ArrayRef],
157        _range: &Range<usize>,
158    ) -> Result<ScalarValue> {
159        self.n_rows += 1;
160        Ok(ScalarValue::UInt64(Some(self.n_rows as u64)))
161    }
162
163    fn supports_bounded_execution(&self) -> bool {
164        true
165    }
166}
167
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Array, BooleanArray};
    use datafusion_common::cast::as_uint64_array;

    use super::*;

    /// Runs `evaluate_all` of a default `RowNumber` evaluator over `input`
    /// and returns the generated row numbers.
    fn row_numbers_for(input: BooleanArray) -> Result<Vec<u64>> {
        let values: ArrayRef = Arc::new(input);
        let num_rows = values.len();

        let mut evaluator = RowNumber::default()
            .partition_evaluator(PartitionEvaluatorArgs::default())?;
        let output = evaluator.evaluate_all(&[values], num_rows)?;
        let output = as_uint64_array(&output)?;

        Ok(output.values().to_vec())
    }

    /// Row numbers do not depend on the input values, even when all are null.
    #[test]
    fn row_number_all_null() -> Result<()> {
        let nulls: Vec<Option<bool>> = vec![None; 8];

        let actual = row_numbers_for(BooleanArray::from(nulls))?;

        assert_eq!(vec![1u64, 2, 3, 4, 5, 6, 7, 8], actual);
        Ok(())
    }

    /// Row numbers are consecutive from 1 for a fully populated input.
    #[test]
    fn row_number_all_values() -> Result<()> {
        let flags = vec![true, false, true, false, false, true, false, true];

        let actual = row_numbers_for(BooleanArray::from(flags))?;

        assert_eq!(vec![1u64, 2, 3, 4, 5, 6, 7, 8], actual);
        Ok(())
    }
}