Skip to main content

reifydb_function/duration/scale.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::{container::temporal::TemporalContainer, r#type::Type};
6
7use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
8
/// Stateless scalar function that multiplies a `Duration` column by an
/// integer column, row by row.
#[derive(Debug, Default)]
pub struct DurationScale;

impl DurationScale {
	/// Creates the function. Equivalent to `DurationScale::default()`.
	pub fn new() -> Self {
		Self
	}
}
16
17fn extract_i64(data: &ColumnData, i: usize) -> Option<i64> {
18	match data {
19		ColumnData::Int1(c) => c.get(i).map(|&v| v as i64),
20		ColumnData::Int2(c) => c.get(i).map(|&v| v as i64),
21		ColumnData::Int4(c) => c.get(i).map(|&v| v as i64),
22		ColumnData::Int8(c) => c.get(i).copied(),
23		ColumnData::Int16(c) => c.get(i).map(|&v| v as i64),
24		ColumnData::Uint1(c) => c.get(i).map(|&v| v as i64),
25		ColumnData::Uint2(c) => c.get(i).map(|&v| v as i64),
26		ColumnData::Uint4(c) => c.get(i).map(|&v| v as i64),
27		ColumnData::Uint8(c) => c.get(i).map(|&v| v as i64),
28		ColumnData::Uint16(c) => c.get(i).map(|&v| v as i64),
29		_ => None,
30	}
31}
32
33fn is_integer_type(data: &ColumnData) -> bool {
34	matches!(
35		data,
36		ColumnData::Int1(_)
37			| ColumnData::Int2(_) | ColumnData::Int4(_)
38			| ColumnData::Int8(_) | ColumnData::Int16(_)
39			| ColumnData::Uint1(_)
40			| ColumnData::Uint2(_)
41			| ColumnData::Uint4(_)
42			| ColumnData::Uint8(_)
43			| ColumnData::Uint16(_)
44	)
45}
46
47impl ScalarFunction for DurationScale {
48	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
49		if let Some(result) = propagate_options(self, &ctx) {
50			return result;
51		}
52		let columns = ctx.columns;
53		let row_count = ctx.row_count;
54
55		if columns.len() != 2 {
56			return Err(ScalarFunctionError::ArityMismatch {
57				function: ctx.fragment.clone(),
58				expected: 2,
59				actual: columns.len(),
60			});
61		}
62
63		let dur_col = columns.get(0).unwrap();
64		let scalar_col = columns.get(1).unwrap();
65
66		match dur_col.data() {
67			ColumnData::Duration(dur_container) => {
68				if !is_integer_type(scalar_col.data()) {
69					return Err(ScalarFunctionError::InvalidArgumentType {
70						function: ctx.fragment.clone(),
71						argument_index: 1,
72						expected: vec![
73							Type::Int1,
74							Type::Int2,
75							Type::Int4,
76							Type::Int8,
77							Type::Int16,
78							Type::Uint1,
79							Type::Uint2,
80							Type::Uint4,
81							Type::Uint8,
82							Type::Uint16,
83						],
84						actual: scalar_col.data().get_type(),
85					});
86				}
87
88				let mut container = TemporalContainer::with_capacity(row_count);
89
90				for i in 0..row_count {
91					match (dur_container.get(i), extract_i64(scalar_col.data(), i)) {
92						(Some(dur), Some(scalar)) => {
93							container.push(*dur * scalar);
94						}
95						_ => container.push_default(),
96					}
97				}
98
99				Ok(ColumnData::Duration(container))
100			}
101			other => Err(ScalarFunctionError::InvalidArgumentType {
102				function: ctx.fragment.clone(),
103				argument_index: 0,
104				expected: vec![Type::Duration],
105				actual: other.get_type(),
106			}),
107		}
108	}
109
110	fn return_type(&self, _input_types: &[Type]) -> Type {
111		Type::Duration
112	}
113}