1use std::{collections::HashSet, ops::Add};
2
3use num_traits::Zero;
4use time::{Date, Duration, OffsetDateTime, PrimitiveDateTime, Time, Weekday, macros::time};
5use time_tz::{OffsetDateTimeExt, PrimitiveDateTimeExt, Tz};
6
7use crate::{
8 CandlestickComponents, CandlestickType, Period, QuoteType, TradeType, UpdateFields,
9 find_session::{FindSession, FindSessionResult},
10};
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq)]
13pub struct TradeSession {
14 pub start: Time,
15 pub end: Time,
16 pub inclusive: bool,
17 pub timeout: Duration,
18}
19
20impl TradeSession {
21 #[inline]
22 pub const fn new(start: Time, end: Time) -> Self {
23 Self {
24 start,
25 end,
26 inclusive: false,
27 timeout: Duration::ZERO,
28 }
29 }
30
31 #[inline]
32 pub const fn with_timeout(self, timeout: Duration) -> Self {
33 Self { timeout, ..self }
34 }
35
36 #[inline]
37 pub const fn with_inclusive(self) -> Self {
38 Self {
39 inclusive: true,
40 ..self
41 }
42 }
43}
44
/// Identifies a kind of trade session.
///
/// The wrapped value is the index into `Market::trade_sessions` /
/// `Market::half_trade_sessions` for that session kind.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct TradeSessionKind(usize);

/// Implemented by anything that can report which [`TradeSessionKind`] it belongs to.
pub trait TradeSessionType: Copy {
    /// The session kind of this value.
    fn kind(&self) -> TradeSessionKind;

    /// `true` when this is the regular intraday session.
    #[inline]
    fn is_intraday(&self) -> bool {
        matches!(self.kind(), TRADE_SESSION_INTRADAY)
    }

    /// Short lowercase name of the session kind.
    ///
    /// # Panics
    ///
    /// Panics if the kind is not one of the four `TRADE_SESSION_*` constants.
    #[inline]
    fn as_str(&self) -> &'static str {
        match self.kind() {
            TRADE_SESSION_INTRADAY => "intraday",
            TRADE_SESSION_PRE => "pre",
            TRADE_SESSION_POST => "post",
            TRADE_SESSION_OVERNIGHT => "overnight",
            _ => unreachable!(),
        }
    }
}

impl TradeSessionType for TradeSessionKind {
    #[inline]
    fn kind(&self) -> TradeSessionKind {
        Self(self.0)
    }
}

/// Regular trading hours.
pub const TRADE_SESSION_INTRADAY: TradeSessionKind = TradeSessionKind(0);
/// Pre-market session.
pub const TRADE_SESSION_PRE: TradeSessionKind = TradeSessionKind(1);
/// Post-market session.
pub const TRADE_SESSION_POST: TradeSessionKind = TradeSessionKind(2);
/// Overnight session.
pub const TRADE_SESSION_OVERNIGHT: TradeSessionKind = TradeSessionKind(3);
84
85#[derive(Debug, Copy, Clone, Eq, PartialEq)]
86pub struct Market {
87 pub timezone: &'static Tz,
88 pub trade_sessions: &'static [&'static [TradeSession]],
89 pub half_trade_sessions: &'static [&'static [TradeSession]],
90 pub lot_size: i32,
91}
92
/// What a merge operation asks the caller to do with its candlestick series.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum UpdateAction<T> {
    /// Overwrite the last candlestick with this value.
    UpdateLast(T),
    /// Push `new` onto the series; `confirmed`, when present, is the candlestick
    /// that was last before `new`.
    AppendNew { confirmed: Option<T>, new: T },
    /// Leave the series unchanged.
    None,
}
99
100pub trait Days: std::fmt::Debug + Copy {
101 fn contains(&self, date: Date) -> bool;
102}
103
104impl Days for bool {
105 #[inline]
106 fn contains(&self, _date: Date) -> bool {
107 *self
108 }
109}
110
111impl Days for &HashSet<Date> {
112 #[inline]
113 fn contains(&self, date: Date) -> bool {
114 HashSet::contains(self, &date)
115 }
116}
117
118impl Market {
119 pub fn candlestick_time<H, TS>(
120 &self,
121 ts: TS,
122 half_days: H,
123 period: Period,
124 t: OffsetDateTime,
125 ) -> Option<OffsetDateTime>
126 where
127 H: Days,
128 TS: TradeSessionType,
129 {
130 use Period::*;
131
132 if !ts.is_intraday() && !period.is_minute() {
133 return None;
134 }
135
136 let ts = ts.kind();
137
138 let t = t.to_timezone(self.timezone);
139 let time = t.time();
140 let trade_sessions = if !half_days.contains(t.date()) {
141 self.trade_sessions.get(ts.0)?
142 } else {
143 self.half_trade_sessions.get(ts.0)?
144 };
145 let res = trade_sessions.find_session(time);
146 let (time, n) = match res {
147 FindSessionResult::BeforeFirst => return None,
148 FindSessionResult::Between(n) => Some((time, n)),
149 FindSessionResult::After(n) => {
150 if time >= trade_sessions[n].end + trade_sessions[n].timeout {
151 return None;
152 } else {
153 Some((trade_sessions[n].end, n))
154 }
155 }
156 }?;
157
158 Some(match period {
159 Min_1 => t.replace_time(Time::from_hms(time.hour(), time.minute(), 0).ok()?),
160 Min_2 | Min_3 | Min_5 | Min_10 | Min_15 | Min_20 | Min_30 | Min_45 | Min_60
161 | Min_120 | Min_180 | Min_240 => {
162 let minutes = period.minutes() as i64;
163 let TradeSession { start, .. } = &trade_sessions[n];
164 let start_minutes = start.hour() as i64 * 60 + start.minute() as i64;
165 let current_minutes = time.hour() as i64 * 60 + time.minute() as i64;
166 let offset_minutes = ((current_minutes - start_minutes) / minutes) * minutes;
167 t.replace_time(*start + Duration::minutes(offset_minutes))
168 }
169 Day => t.replace_time(time!(00:00:00)),
170 Week => {
171 let week = t.iso_week();
172 Date::from_iso_week_date(t.year(), week, Weekday::Monday)
173 .ok()?
174 .with_hms(0, 0, 0)
175 .ok()?
176 .assume_timezone(self.timezone)
177 .take_first()?
178 }
179 Month => PrimitiveDateTime::new(
180 Date::from_calendar_date(t.year(), t.month(), 1).ok()?,
181 time!(00:00:00),
182 )
183 .assume_timezone(self.timezone)
184 .take_first()?,
185 Quarter => {
186 let month = t.month();
187 let quarter = (month as u8 - 1) / 3;
188 let date = Date::from_calendar_date(
189 t.year(),
190 time::Month::try_from(quarter * 3 + 1).ok()?,
191 1,
192 )
193 .ok()?;
194 PrimitiveDateTime::new(date, time!(00:00:00))
195 .assume_timezone(self.timezone)
196 .take_first()?
197 }
198 Year => PrimitiveDateTime::new(
199 Date::from_calendar_date(t.year(), time::Month::January, 1).ok()?,
200 time!(00:00:00),
201 )
202 .assume_timezone(self.timezone)
203 .take_first()?,
204 })
205 }
206
207 #[must_use]
208 pub fn merge_trade<H, TS, C, T, P, V, R>(
209 &self,
210 half_days: H,
211 period: Period,
212 input: Option<C>,
213 trade: &T,
214 update_fields: UpdateFields,
215 ) -> UpdateAction<C>
216 where
217 H: Days,
218 TS: TradeSessionType + Eq,
219 C: CandlestickType<PriceType = P, VolumeType = V, TurnoverType = R, TradeSessionType = TS>,
220 T: TradeType<PriceType = P, VolumeType = V, TurnoverType = R, TradeSessionType = TS>,
221 P: PartialOrd + Add<Output = P>,
222 V: Add<Output = V> + Zero,
223 R: Add<Output = R> + Zero,
224 {
225 let trade_session = trade.trade_session();
226
227 if let Some(input_trade_session) = input.as_ref().map(|c| c.trade_session()) {
228 debug_assert!(input_trade_session == trade_session);
229 }
230
231 let Some(time) = self.candlestick_time(
232 trade_session,
233 half_days,
234 period,
235 trade.time().to_timezone(self.timezone),
236 ) else {
237 return UpdateAction::None;
238 };
239
240 match input {
241 Some(prev) if time == prev.time() => {
242 let mut candlestick = prev;
243
244 if update_fields.contains(UpdateFields::PRICE) {
245 if !candlestick.open_updated() {
246 candlestick.set_open(trade.price());
247 candlestick.set_open_updated(true);
248 }
249
250 candlestick.set_high(if trade.price() > candlestick.high() {
251 trade.price()
252 } else {
253 candlestick.high()
254 });
255
256 candlestick.set_low(if trade.price() < candlestick.low() {
257 trade.price()
258 } else {
259 candlestick.low()
260 });
261
262 candlestick.set_close(trade.price());
263 }
264
265 if update_fields.contains(UpdateFields::VOLUME) {
266 candlestick.set_volume(candlestick.volume() + trade.volume());
267 candlestick
268 .set_turnover(candlestick.turnover() + trade.turnover(self.lot_size));
269 }
270
271 UpdateAction::UpdateLast(candlestick)
272 }
273 None => {
274 if update_fields.contains(UpdateFields::PRICE) {
275 let new_candlestick = C::new(CandlestickComponents {
276 time: time.to_timezone(time_tz::timezones::db::UTC),
277 open: trade.price(),
278 high: trade.price(),
279 low: trade.price(),
280 close: trade.price(),
281 volume: trade.volume(),
282 turnover: trade.turnover(self.lot_size),
283 trade_session,
284 open_updated: true,
285 });
286 UpdateAction::AppendNew {
287 confirmed: None,
288 new: new_candlestick,
289 }
290 } else {
291 UpdateAction::None
292 }
293 }
294 Some(prev) if time > prev.time() => {
295 let mut new_candlestick = C::new(CandlestickComponents {
296 time: time.to_timezone(time_tz::timezones::db::UTC),
297 open: prev.close(),
298 high: prev.close(),
299 low: prev.close(),
300 close: prev.close(),
301 volume: V::zero(),
302 turnover: R::zero(),
303 trade_session,
304 open_updated: false,
305 });
306
307 if update_fields.contains(UpdateFields::PRICE) {
308 new_candlestick.set_open(trade.price());
309 new_candlestick.set_high(trade.price());
310 new_candlestick.set_low(trade.price());
311 new_candlestick.set_close(trade.price());
312 new_candlestick.set_open_updated(true);
313 }
314
315 if update_fields.contains(UpdateFields::VOLUME) {
316 new_candlestick.set_volume(trade.volume());
317 new_candlestick.set_turnover(trade.turnover(self.lot_size));
318 }
319
320 UpdateAction::AppendNew {
321 confirmed: Some(prev),
322 new: new_candlestick,
323 }
324 }
325 _ => UpdateAction::None,
326 }
327 }
328
329 #[must_use]
330 pub fn merge_quote_day<TS, C, Q, P, V, R>(&self, input: Option<C>, quote: &Q) -> UpdateAction<C>
331 where
332 TS: TradeSessionType + Eq,
333 C: CandlestickType<PriceType = P, VolumeType = V, TurnoverType = R, TradeSessionType = TS>,
334 Q: QuoteType<PriceType = P, VolumeType = V, TurnoverType = R, TradeSessionType = TS>,
335 {
336 let trade_session = quote.trade_session();
337
338 if !trade_session.is_intraday() {
339 return UpdateAction::None;
340 }
341
342 if let Some(input_trade_session) = input.as_ref().map(|c| c.trade_session()) {
343 debug_assert!(input_trade_session == trade_session);
344 }
345
346 let tz = self.timezone;
347 let time = quote.time().to_timezone(tz).replace_time(Time::MIDNIGHT);
348
349 match input {
350 Some(prev) if time == prev.time() => {
351 UpdateAction::UpdateLast(C::new(CandlestickComponents {
352 time: time.to_timezone(time_tz::timezones::db::UTC),
353 open: quote.open(),
354 high: quote.high(),
355 low: quote.low(),
356 close: quote.last_done(),
357 volume: quote.volume(),
358 turnover: quote.turnover(),
359 trade_session,
360 open_updated: true,
361 }))
362 }
363 None => UpdateAction::AppendNew {
364 confirmed: None,
365 new: C::new(CandlestickComponents {
366 time: time.to_timezone(time_tz::timezones::db::UTC),
367 open: quote.open(),
368 high: quote.high(),
369 low: quote.low(),
370 close: quote.last_done(),
371 volume: quote.volume(),
372 turnover: quote.turnover(),
373 trade_session,
374 open_updated: true,
375 }),
376 },
377 Some(prev) if time > prev.time() => UpdateAction::AppendNew {
378 confirmed: Some(prev),
379 new: C::new(CandlestickComponents {
380 time: time.to_timezone(time_tz::timezones::db::UTC),
381 open: quote.open(),
382 high: quote.high(),
383 low: quote.low(),
384 close: quote.last_done(),
385 volume: quote.volume(),
386 turnover: quote.turnover(),
387 trade_session,
388 open_updated: true,
389 }),
390 },
391 _ => UpdateAction::None,
392 }
393 }
394
395 pub fn trade_session(&self, candlestick_time: OffsetDateTime) -> Option<TradeSessionKind> {
396 let candlestick_time = candlestick_time.to_timezone(self.timezone);
397 for (idx, trade_sessions) in self.trade_sessions.iter().enumerate() {
398 for TradeSession {
399 start,
400 end,
401 inclusive,
402 timeout,
403 ..
404 } in trade_sessions.iter()
405 {
406 let time = candlestick_time.time();
407 if !*inclusive && timeout.is_zero() {
408 if time >= *start && time < *end {
409 return Some(TradeSessionKind(idx));
410 }
411 } else if time >= *start && time <= *end {
412 return Some(TradeSessionKind(idx));
413 }
414 }
415 }
416 None
417 }
418
419 pub fn is_first<H, TS>(
420 &self,
421 half_days: H,
422 period: Period,
423 ts: TS,
424 candlestick_time: OffsetDateTime,
425 ) -> bool
426 where
427 H: Days,
428 TS: TradeSessionType,
429 {
430 assert!(period.is_minute());
431 if !half_days.contains(candlestick_time.date()) {
432 self.trade_sessions
433 } else {
434 self.half_trade_sessions
435 }
436 .get(ts.kind().0)
437 .and_then(|sessions| sessions.first())
438 .map(|session| session.start)
439 == Some(candlestick_time.to_timezone(self.timezone).time())
440 }
441
442 pub fn is_last<H, TS>(
443 &self,
444 half_days: H,
445 period: Period,
446 ts: TS,
447 candlestick_time: OffsetDateTime,
448 ) -> bool
449 where
450 H: Days,
451 TS: TradeSessionType,
452 {
453 assert!(period.is_minute());
454
455 let Some(mut end) = if !half_days.contains(candlestick_time.date()) {
456 self.trade_sessions
457 } else {
458 self.half_trade_sessions
459 }
460 .get(ts.kind().0)
461 .and_then(|sessions| sessions.last())
462 .map(|session| session.end) else {
463 return false;
464 };
465 end -= Duration::seconds(1);
466 let Some(last_time) = PrimitiveDateTime::new(candlestick_time.date(), end)
467 .assume_timezone(self.timezone)
468 .take_first()
469 else {
470 return false;
471 };
472 let Some(last_time) = self.candlestick_time(ts, half_days, period, last_time) else {
473 return false;
474 };
475 last_time.time() == candlestick_time.to_timezone(self.timezone).time()
476 }
477}