Skip to main content

reifydb_function/datetime/
subtract.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::{container::temporal::TemporalContainer, date::Date, datetime::DateTime, r#type::Type};
6
7use crate::{ScalarFunction, ScalarFunctionContext, error::ScalarFunctionError, propagate_options};
8
/// Scalar function that subtracts a `Duration` column from a `DateTime` column.
///
/// This is a stateless unit struct, so `DateTimeSubtract::new()` and
/// `DateTimeSubtract::default()` produce the same value.
#[derive(Debug, Default, Clone, Copy)]
pub struct DateTimeSubtract;

impl DateTimeSubtract {
	/// Creates a new `DateTimeSubtract` function instance.
	pub fn new() -> Self {
		Self
	}
}
16
17impl ScalarFunction for DateTimeSubtract {
18	fn scalar(&self, ctx: ScalarFunctionContext) -> crate::error::ScalarFunctionResult<ColumnData> {
19		if let Some(result) = propagate_options(self, &ctx) {
20			return result;
21		}
22		let columns = ctx.columns;
23		let row_count = ctx.row_count;
24
25		if columns.len() != 2 {
26			return Err(ScalarFunctionError::ArityMismatch {
27				function: ctx.fragment.clone(),
28				expected: 2,
29				actual: columns.len(),
30			});
31		}
32
33		let dt_col = columns.get(0).unwrap();
34		let dur_col = columns.get(1).unwrap();
35
36		match (dt_col.data(), dur_col.data()) {
37			(ColumnData::DateTime(dt_container), ColumnData::Duration(dur_container)) => {
38				let mut container = TemporalContainer::with_capacity(row_count);
39
40				for i in 0..row_count {
41					match (dt_container.get(i), dur_container.get(i)) {
42						(Some(dt), Some(dur)) => {
43							let date = dt.date();
44							let time = dt.time();
45							let mut year = date.year();
46							let mut month = date.month() as i32;
47							let mut day = date.day();
48
49							// Subtract months component
50							let total_months = month - dur.get_months();
51							year += (total_months - 1).div_euclid(12);
52							month = (total_months - 1).rem_euclid(12) + 1;
53
54							// Clamp day to valid range for the new month
55							let max_day = days_in_month(year, month as u32);
56							if day > max_day {
57								day = max_day;
58							}
59
60							// Convert to seconds since epoch and subtract day/nanos
61							// components
62							if let Some(base_date) = Date::new(year, month as u32, day) {
63								let base_days = base_date.to_days_since_epoch() as i64
64									- dur.get_days() as i64;
65								let time_nanos = time.to_nanos_since_midnight() as i64
66									- dur.get_nanos();
67
68								let total_seconds =
69									base_days * 86400 + time_nanos / 1_000_000_000;
70								let nano_rem = time_nanos % 1_000_000_000;
71								let (total_seconds, nano_part) = if nano_rem < 0 {
72									(
73										total_seconds - 1,
74										(1_000_000_000 + nano_rem) as u32,
75									)
76								} else {
77									(total_seconds, nano_rem as u32)
78								};
79
80								match DateTime::from_parts(total_seconds, nano_part) {
81									Ok(result) => container.push(result),
82									Err(_) => container.push_default(),
83								}
84							} else {
85								container.push_default();
86							}
87						}
88						_ => container.push_default(),
89					}
90				}
91
92				Ok(ColumnData::DateTime(container))
93			}
94			(ColumnData::DateTime(_), other) => Err(ScalarFunctionError::InvalidArgumentType {
95				function: ctx.fragment.clone(),
96				argument_index: 1,
97				expected: vec![Type::Duration],
98				actual: other.get_type(),
99			}),
100			(other, _) => Err(ScalarFunctionError::InvalidArgumentType {
101				function: ctx.fragment.clone(),
102				argument_index: 0,
103				expected: vec![Type::DateTime],
104				actual: other.get_type(),
105			}),
106		}
107	}
108
109	fn return_type(&self, _input_types: &[Type]) -> Type {
110		Type::DateTime
111	}
112}
113
/// Returns the number of days in the 1-based `month` of `year`.
///
/// February uses the Gregorian leap-year rule: divisible by 4, except centuries not
/// divisible by 400. Any `month` outside `1..=12` yields `0`.
fn days_in_month(year: i32, month: u32) -> u32 {
	let is_leap = year % 400 == 0 || (year % 100 != 0 && year % 4 == 0);
	match month {
		2 if is_leap => 29,
		2 => 28,
		4 | 6 | 9 | 11 => 30,
		1..=12 => 31,
		_ => 0,
	}
}