#[cfg(feature = "dtype-date")]
use chrono::DateTime;
use polars_core::prelude::arity::{binary_elementwise_values, try_binary_elementwise};
use polars_core::prelude::*;
#[cfg(feature = "dtype-date")]
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
use polars_utils::binary_search::{find_first_ge_index, find_first_gt_index};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "timezones")]
use crate::prelude::replace_time_zone;
/// What to do with a start date that is not a business day before business
/// days are added to it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Roll {
    /// Move forward, one day at a time, to the next business day.
    Forward,
    /// Move backward, one day at a time, to the previous business day.
    Backward,
    /// Return an error instead of adjusting the date.
    Raise,
}
/// Count the business days between `start` and `end`, element-wise.
///
/// For each pair the half-open interval `[start, end)` is counted; when
/// `start > end` the interval `(end, start]` is counted instead and the result
/// is negated. A day is a business day when its weekday is enabled in
/// `week_mask` (index 0 = Monday) and it is not listed in `holidays` (days since
/// the Unix epoch, like the physical values of `Date`).
///
/// Either input may have length 1, in which case it is broadcast against the
/// other; a null broadcast value produces an all-null result.
///
/// # Errors
/// - `ComputeError` if `week_mask` has no business day.
/// - An error if `start` or `end` is not of `Date` dtype.
/// - `ShapeMismatch` if neither input has length 1 and their lengths differ.
pub fn business_day_count(
    start: &Series,
    end: &Series,
    week_mask: [bool; 7],
    holidays: &[i32],
) -> PolarsResult<Series> {
    if !week_mask.iter().any(|&x| x) {
        polars_bail!(ComputeError:"`week_mask` must have at least one business day");
    }
    // Sorted, deduplicated, and restricted to business weekdays; the impl relies on this.
    let holidays = normalise_holidays(holidays, &week_mask);
    let start_dates = start.date()?;
    let end_dates = end.date()?;
    let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;

    let out = match (start_dates.len(), end_dates.len()) {
        // Broadcast a scalar `end` over every `start`.
        (_, 1) => {
            if let Some(end_date) = end_dates.get(0) {
                start_dates.apply_values(|start_date| {
                    business_day_count_impl(
                        start_date,
                        end_date,
                        &week_mask,
                        n_business_days_in_week_mask,
                        &holidays,
                    )
                })
            } else {
                Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())
            }
        },
        // Broadcast a scalar `start` over every `end`.
        (1, _) => {
            if let Some(start_date) = start_dates.get(0) {
                end_dates.apply_values(|end_date| {
                    business_day_count_impl(
                        start_date,
                        end_date,
                        &week_mask,
                        n_business_days_in_week_mask,
                        &holidays,
                    )
                })
            } else {
                Int32Chunked::full_null(start_dates.name().clone(), end_dates.len())
            }
        },
        (start_len, end_len) => {
            // The element-wise kernel asserts equal lengths; report a proper
            // error instead of panicking.
            if start_len != end_len {
                polars_bail!(
                    ShapeMismatch: "`start` and `end` must have the same length or length 1, got {} and {}",
                    start_len,
                    end_len
                );
            }
            binary_elementwise_values(start_dates, end_dates, |start_date, end_date| {
                business_day_count_impl(
                    start_date,
                    end_date,
                    &week_mask,
                    n_business_days_in_week_mask,
                    &holidays,
                )
            })
        },
    };
    Ok(out.into_series())
}
/// Number of business days in `[start_date, end_date)`, or the negated number
/// of business days in `(end_date, start_date]` when `start_date > end_date`.
///
/// `holidays` must be sorted, deduplicated and contain only days whose weekday
/// is enabled in `week_mask` (as produced by `normalise_holidays`), so that
/// each holiday in range cancels exactly one counted day.
/// `n_business_days_in_week_mask` is the number of `true` entries in `week_mask`.
fn business_day_count_impl(
    mut start_date: i32,
    mut end_date: i32,
    week_mask: &[bool; 7],
    n_business_days_in_week_mask: i32,
    holidays: &[i32],
) -> i32 {
    let reversed = start_date > end_date;
    if reversed {
        // Count (end, start] by swapping and shifting both bounds one day forward.
        std::mem::swap(&mut start_date, &mut end_date);
        start_date += 1;
        end_date += 1;
    }

    // Holidays that fall inside [start_date, end_date).
    let lo = find_first_ge_index(holidays, start_date);
    let hi = lo + find_first_ge_index(&holidays[lo..], end_date);
    let holidays_in_range = (hi - lo) as i32;

    // Each complete week contributes exactly one copy of the week mask.
    let full_weeks = (end_date - start_date) / 7;
    let mut count = full_weeks * n_business_days_in_week_mask - holidays_in_range;

    // Skipping whole weeks keeps the weekday unchanged, so walk the < 7
    // leftover days starting from the weekday of `start_date`.
    let mut weekday = get_day_of_week(start_date);
    for _ in (start_date + full_weeks * 7)..end_date {
        if week_mask[weekday] {
            count += 1;
        }
        weekday = increment_day_of_week(weekday);
    }

    if reversed {
        -count
    } else {
        count
    }
}
/// Add `n` business days to each value of `start`.
///
/// A day is a business day when its weekday is enabled in `week_mask`
/// (index 0 = Monday) and it is not listed in `holidays` (days since the Unix
/// epoch). Start dates that are not business days are first adjusted
/// according to `roll`. Either `start` or `n` may have length 1, in which case
/// it is broadcast against the other; a null broadcast value produces an
/// all-null result.
///
/// `Datetime` inputs are split into a date part, which is shifted, and a
/// time-of-day part, which is added back unchanged. Time-zone-aware inputs are
/// shifted in naive local time and then re-localised to `time_zone`, passing
/// `"raise"` / `NonExistent::Raise` to `replace_time_zone`.
///
/// `n` may be `Int32`, `Int64`, `UInt32` or `UInt64`; the non-`Int32` types are
/// cast to `Int32`.
///
/// # Errors
/// - `ComputeError` if `week_mask` has no business day, or if `roll` is
///   `Roll::Raise` and a start date is not a business day.
/// - `InvalidOperation` for an unsupported `start` or `n` dtype.
/// - `ShapeMismatch` if neither input has length 1 and their lengths differ.
pub fn add_business_days(
    start: &Series,
    n: &Series,
    week_mask: [bool; 7],
    holidays: &[i32],
    roll: Roll,
) -> PolarsResult<Series> {
    if !week_mask.iter().any(|&x| x) {
        polars_bail!(ComputeError:"`week_mask` must have at least one business day");
    }
    match start.dtype() {
        DataType::Date => {},
        #[cfg(feature = "dtype-datetime")]
        DataType::Datetime(time_unit, None) => {
            // Shift the date part, then add the original time of day back on.
            let result_date =
                add_business_days(&start.cast(&DataType::Date)?, n, week_mask, holidays, roll)?;
            let start_time = start
                .cast(&DataType::Time)?
                .cast(&DataType::Duration(*time_unit))?;
            return std::ops::Add::add(
                result_date.cast(&DataType::Datetime(*time_unit, None))?,
                start_time,
            );
        },
        #[cfg(feature = "timezones")]
        DataType::Datetime(time_unit, Some(time_zone)) => {
            // Work in naive local time so that business days follow the
            // wall-clock calendar of `time_zone`.
            let start_naive = replace_time_zone(
                start.datetime()?,
                None,
                &StringChunked::from_iter(std::iter::once("raise")),
                NonExistent::Raise,
            )?;
            let result_date = add_business_days(
                &start_naive.cast(&DataType::Date)?,
                n,
                week_mask,
                holidays,
                roll,
            )?;
            let start_time = start_naive
                .cast(&DataType::Time)?
                .cast(&DataType::Duration(*time_unit))?;
            let result_naive = std::ops::Add::add(
                result_date.cast(&DataType::Datetime(*time_unit, None))?,
                start_time,
            )?;
            let result_tz_aware = replace_time_zone(
                result_naive.datetime()?,
                Some(time_zone),
                &StringChunked::from_iter(std::iter::once("raise")),
                NonExistent::Raise,
            )?;
            return Ok(result_tz_aware.into_series());
        },
        _ => polars_bail!(InvalidOperation: "expected date or datetime, got {}", start.dtype()),
    }

    // From here on `start` is a `Date` series.
    let holidays = normalise_holidays(holidays, &week_mask);
    let start_dates = start.date()?;
    let n = match &n.dtype() {
        DataType::Int64 | DataType::UInt64 | DataType::UInt32 => n.cast(&DataType::Int32)?,
        DataType::Int32 => n.clone(),
        _ => {
            polars_bail!(InvalidOperation: "expected Int64, Int32, UInt64, or UInt32, got {}", n.dtype())
        },
    };
    let n = n.i32()?;
    let n_business_days_in_week_mask = week_mask.iter().filter(|&x| *x).count() as i32;

    let out: Int32Chunked = match (start_dates.len(), n.len()) {
        // Broadcast a scalar `n` over every start date.
        (_, 1) => {
            if let Some(n) = n.get(0) {
                start_dates.try_apply_nonnull_values_generic(|start_date| {
                    let (start_date, day_of_week) =
                        roll_start_date(start_date, roll, &week_mask, &holidays)?;
                    Ok::<i32, PolarsError>(add_business_days_impl(
                        start_date,
                        day_of_week,
                        n,
                        &week_mask,
                        n_business_days_in_week_mask,
                        &holidays,
                    ))
                })?
            } else {
                Int32Chunked::full_null(start_dates.name().clone(), start_dates.len())
            }
        },
        // Broadcast a scalar start date; it only needs rolling once.
        (1, _) => {
            if let Some(start_date) = start_dates.get(0) {
                let (start_date, day_of_week) =
                    roll_start_date(start_date, roll, &week_mask, &holidays)?;
                n.apply_values(|n| {
                    add_business_days_impl(
                        start_date,
                        day_of_week,
                        n,
                        &week_mask,
                        n_business_days_in_week_mask,
                        &holidays,
                    )
                })
            } else {
                Int32Chunked::full_null(start_dates.name().clone(), n.len())
            }
        },
        (start_len, n_len) => {
            // The element-wise kernel asserts equal lengths; report a proper
            // error instead of panicking.
            if start_len != n_len {
                polars_bail!(
                    ShapeMismatch: "`start` and `n` must have the same length or length 1, got {} and {}",
                    start_len,
                    n_len
                );
            }
            try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| {
                match (opt_start_date, opt_n) {
                    (Some(start_date), Some(n)) => {
                        let (start_date, day_of_week) =
                            roll_start_date(start_date, roll, &week_mask, &holidays)?;
                        Ok::<Option<i32>, PolarsError>(Some(add_business_days_impl(
                            start_date,
                            day_of_week,
                            n,
                            &week_mask,
                            n_business_days_in_week_mask,
                            &holidays,
                        )))
                    },
                    _ => Ok(None),
                }
            })?
        },
    };
    Ok(out.into_date().into_series())
}
/// Move `n` business days away from `date` (forward for `n > 0`, backward for
/// `n < 0`) and return the resulting date.
///
/// `date` is expected to already be a business day whose weekday index is
/// `day_of_week` (callers obtain both from `roll_start_date`). `holidays` must
/// be sorted, deduplicated and restricted to business weekdays (see
/// `normalise_holidays`); `n_business_days_in_week_mask` is the number of
/// `true` entries in `week_mask`.
///
/// Whole weeks are skipped in one jump. Each holiday jumped over adds one more
/// day to the remainder, and the remainder is then walked day by day.
fn add_business_days_impl(
    mut date: i32,
    mut day_of_week: usize,
    mut n: i32,
    week_mask: &[bool; 7],
    n_business_days_in_week_mask: i32,
    holidays: &[i32],
) -> i32 {
    if n > 0 {
        let holidays_begin = find_first_ge_index(holidays, date);
        // Jumping whole weeks keeps `day_of_week` unchanged.
        date += (n / n_business_days_in_week_mask) * 7;
        n %= n_business_days_in_week_mask;
        // Holidays in (original date, new date] were jumped over: each costs a day.
        let holidays_temp = find_first_gt_index(&holidays[holidays_begin..], date) + holidays_begin;
        n += (holidays_temp - holidays_begin) as i32;
        // Only holidays strictly after the current `date` can still be hit.
        let upcoming_holidays = &holidays[holidays_temp..];
        while n > 0 {
            date += 1;
            day_of_week = increment_day_of_week(day_of_week);
            // `upcoming_holidays` is sorted, so a binary search replaces a linear scan.
            if week_mask[day_of_week] && upcoming_holidays.binary_search(&date).is_err() {
                n -= 1;
            }
        }
        date
    } else {
        let holidays_end = find_first_gt_index(holidays, date);
        // Jumping whole weeks keeps `day_of_week` unchanged (`n / k` truncates toward zero).
        date += (n / n_business_days_in_week_mask) * 7;
        n %= n_business_days_in_week_mask;
        // Holidays in [new date, original date] were jumped over: each costs a day.
        let holidays_temp = find_first_ge_index(&holidays[..holidays_end], date);
        n -= (holidays_end - holidays_temp) as i32;
        // Only holidays strictly before the current `date` can still be hit.
        let earlier_holidays = &holidays[..holidays_temp];
        while n < 0 {
            date -= 1;
            day_of_week = decrement_day_of_week(day_of_week);
            if week_mask[day_of_week] && earlier_holidays.binary_search(&date).is_err() {
                n += 1;
            }
        }
        date
    }
}
/// Adjust `date` to a business day according to `roll` and return it together
/// with its weekday index (0 = Monday).
///
/// `holidays` must be sorted (callers pass the output of `normalise_holidays`),
/// because membership is tested with a binary search.
///
/// # Errors
/// `ComputeError` if `roll` is `Roll::Raise` and `date` is not a business day.
fn roll_start_date(
    mut date: i32,
    roll: Roll,
    week_mask: &[bool; 7],
    holidays: &[i32],
) -> PolarsResult<(i32, usize)> {
    let mut day_of_week = get_day_of_week(date);
    let is_non_business_day = |date: i32, day_of_week: usize| {
        !week_mask[day_of_week] || holidays.binary_search(&date).is_ok()
    };
    match roll {
        Roll::Raise => {
            if is_non_business_day(date, day_of_week) {
                // Format as YYYY-MM-DD. Fall back to the raw day number when chrono
                // cannot represent the date, instead of panicking while building
                // the error.
                let date = DateTime::from_timestamp(date as i64 * SECONDS_IN_DAY, 0)
                    .map(|dt| dt.format("%Y-%m-%d").to_string())
                    .unwrap_or_else(|| format!("{date} (days since 1970-01-01)"));
                polars_bail!(ComputeError:
                    "date {} is not a business date; use `roll` to roll forwards (or backwards) to the next (or previous) valid date.", date
                )
            }
        },
        Roll::Forward => {
            // Terminates: `week_mask` has at least one business day and
            // `holidays` is finite.
            while is_non_business_day(date, day_of_week) {
                date += 1;
                day_of_week = increment_day_of_week(day_of_week);
            }
        },
        Roll::Backward => {
            while is_non_business_day(date, day_of_week) {
                date -= 1;
                day_of_week = decrement_day_of_week(day_of_week);
            }
        },
    }
    Ok((date, day_of_week))
}
/// Return a sorted, deduplicated copy of `holidays` that keeps only dates
/// whose weekday is enabled in `week_mask`.
///
/// Holidays on non-business weekdays are dropped because the week mask already
/// excludes those days. Every remaining holiday then removes exactly one day
/// the mask would otherwise count.
fn normalise_holidays(holidays: &[i32], week_mask: &[bool; 7]) -> Vec<i32> {
    let mut normalised = holidays.to_vec();
    normalised.sort_unstable();
    normalised.dedup();
    normalised.retain(|&day| week_mask[get_day_of_week(day)]);
    normalised
}
/// Weekday index (0 = Monday, ..., 6 = Sunday) of a date stored as days since
/// 1970-01-01, which was a Thursday (index 3).
///
/// The day number is reduced with `rem_euclid` before any addition, so no
/// intermediate can overflow, even for `i32::MIN` or `i32::MAX`.
fn get_day_of_week(x: i32) -> usize {
    // (x - 4) mod 7 == (x + 3) mod 7; reducing `x` first keeps the sum in 3..=9.
    ((x.rem_euclid(7) + 3) % 7) as usize
}
/// Weekday index of the following day, wrapping Sunday (6) to Monday (0).
fn increment_day_of_week(x: usize) -> usize {
    match x {
        6 => 0,
        weekday => weekday + 1,
    }
}
/// Weekday index of the preceding day, wrapping Monday (0) to Sunday (6).
fn decrement_day_of_week(x: usize) -> usize {
    match x {
        0 => 6,
        weekday => weekday - 1,
    }
}