Skip to main content

reifydb_routine/function/datetime/
subtract.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData};
5use reifydb_type::value::{container::temporal::TemporalContainer, date::Date, datetime::DateTime, r#type::Type};
6
7use crate::function::{Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
8
9pub struct DateTimeSubtract {
10	info: FunctionInfo,
11}
12
13impl Default for DateTimeSubtract {
14	fn default() -> Self {
15		Self::new()
16	}
17}
18
19impl DateTimeSubtract {
20	pub fn new() -> Self {
21		Self {
22			info: FunctionInfo::new("datetime::subtract"),
23		}
24	}
25}
26
27impl Function for DateTimeSubtract {
28	fn info(&self) -> &FunctionInfo {
29		&self.info
30	}
31
32	fn capabilities(&self) -> &[FunctionCapability] {
33		&[FunctionCapability::Scalar]
34	}
35
36	fn return_type(&self, _input_types: &[Type]) -> Type {
37		Type::DateTime
38	}
39
40	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
41		if args.len() != 2 {
42			return Err(FunctionError::ArityMismatch {
43				function: ctx.fragment.clone(),
44				expected: 2,
45				actual: args.len(),
46			});
47		}
48
49		let dt_col = &args[0];
50		let dur_col = &args[1];
51		let (dt_data, dt_bitvec) = dt_col.data().unwrap_option();
52		let (dur_data, dur_bitvec) = dur_col.data().unwrap_option();
53		let row_count = dt_data.len();
54
55		let result_data = match (dt_data, dur_data) {
56			(ColumnData::DateTime(dt_container), ColumnData::Duration(dur_container)) => {
57				let mut container = TemporalContainer::with_capacity(row_count);
58
59				for i in 0..row_count {
60					match (dt_container.get(i), dur_container.get(i)) {
61						(Some(dt), Some(dur)) => {
62							let date = dt.date();
63							let time = dt.time();
64							let mut year = date.year();
65							let mut month = date.month() as i32;
66							let mut day = date.day();
67
68							// Subtract months component
69							let total_months = month - dur.get_months();
70							year += (total_months - 1).div_euclid(12);
71							month = (total_months - 1).rem_euclid(12) + 1;
72
73							// Clamp day to valid range for the new month
74							let max_day = days_in_month(year, month as u32);
75							if day > max_day {
76								day = max_day;
77							}
78
79							// Convert to seconds since epoch and subtract day/nanos
80							// components
81							if let Some(base_date) = Date::new(year, month as u32, day) {
82								let base_days = base_date.to_days_since_epoch() as i64
83									- dur.get_days() as i64;
84								let time_nanos = time.to_nanos_since_midnight() as i64
85									- dur.get_nanos();
86
87								let total_nanos = base_days as i128
88									* 86_400_000_000_000i128 + time_nanos
89									as i128;
90
91								if total_nanos >= 0 && total_nanos <= u64::MAX as i128 {
92									container.push(DateTime::from_nanos(
93										total_nanos as u64,
94									));
95								} else {
96									return Err(FunctionError::ExecutionFailed {
97										function: ctx.fragment.clone(),
98										reason: "datetime cannot be before Unix epoch".to_string(),
99									});
100								}
101							} else {
102								return Err(FunctionError::ExecutionFailed {
103									function: ctx.fragment.clone(),
104									reason: "datetime cannot be before Unix epoch"
105										.to_string(),
106								});
107							}
108						}
109						_ => container.push_default(),
110					}
111				}
112
113				ColumnData::DateTime(container)
114			}
115			(ColumnData::DateTime(_), other) => {
116				return Err(FunctionError::InvalidArgumentType {
117					function: ctx.fragment.clone(),
118					argument_index: 1,
119					expected: vec![Type::Duration],
120					actual: other.get_type(),
121				});
122			}
123			(other, _) => {
124				return Err(FunctionError::InvalidArgumentType {
125					function: ctx.fragment.clone(),
126					argument_index: 0,
127					expected: vec![Type::DateTime],
128					actual: other.get_type(),
129				});
130			}
131		};
132
133		let final_data = match (dt_bitvec, dur_bitvec) {
134			(Some(bv), _) | (_, Some(bv)) => ColumnData::Option {
135				inner: Box::new(result_data),
136				bitvec: bv.clone(),
137			},
138			_ => result_data,
139		};
140
141		Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
142	}
143}
144
/// Number of days in `month` (1 = January … 12 = December) of `year`, using the
/// Gregorian leap-year rule for every year.
///
/// Returns 0 when `month` is outside `1..=12`.
fn days_in_month(year: i32, month: u32) -> u32 {
	let is_leap = year % 400 == 0 || (year % 100 != 0 && year % 4 == 0);
	match (month, is_leap) {
		(2, true) => 29,
		(2, false) => 28,
		(4 | 6 | 9 | 11, _) => 30,
		(1 | 3 | 5 | 7 | 8 | 10 | 12, _) => 31,
		_ => 0,
	}
}