longport/quote/
core.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use comfy_table::Table;
7use itertools::Itertools;
8use longport_candlesticks::UpdateAction;
9use longport_httpcli::HttpClient;
10use longport_proto::quote::{
11    self, AdjustType, MarketTradeDayRequest, MarketTradeDayResponse, MultiSecurityRequest, Period,
12    PushQuoteTag, SecurityCandlestickRequest, SecurityCandlestickResponse,
13    SecurityStaticInfoResponse, SubscribeRequest, TradeSession, UnsubscribeRequest,
14};
15use longport_wscli::{
16    CodecType, Platform, ProtocolVersion, RateLimit, WsClient, WsClientError, WsEvent, WsSession,
17};
18use time::{Date, OffsetDateTime};
19use tokio::{
20    sync::{mpsc, oneshot},
21    time::{Duration, Instant},
22};
23
24use crate::{
25    config::PushCandlestickMode,
26    quote::{
27        cmd_code,
28        store::{Candlesticks, Store},
29        sub_flags::SubFlags,
30        types::QuotePackageDetail,
31        utils::{format_date, parse_date},
32        Candlestick, PushCandlestick, PushEvent, PushEventDetail, PushQuote, PushTrades,
33        RealtimeQuote, SecurityBoard, SecurityBrokers, SecurityDepth, Subscription, Trade,
34    },
35    Config, Error, Market, Result,
36};
37
/// Pause inserted before every attempt to reconnect to the quote server.
const RECONNECT_DELAY: Duration = Duration::from_secs(2);
/// Candlestick count limit.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the maximum number of candlesticks kept in memory per period - confirm.
const MAX_CANDLESTICKS: usize = 500;
40
41pub(crate) enum Command {
42    Request {
43        command_code: u8,
44        body: Vec<u8>,
45        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
46    },
47    Subscribe {
48        symbols: Vec<String>,
49        sub_types: SubFlags,
50        is_first_push: bool,
51        reply_tx: oneshot::Sender<Result<()>>,
52    },
53    Unsubscribe {
54        symbols: Vec<String>,
55        sub_types: SubFlags,
56        reply_tx: oneshot::Sender<Result<()>>,
57    },
58    SubscribeCandlesticks {
59        symbol: String,
60        period: Period,
61        reply_tx: oneshot::Sender<Result<Vec<Candlestick>>>,
62    },
63    UnsubscribeCandlesticks {
64        symbol: String,
65        period: Period,
66        reply_tx: oneshot::Sender<Result<()>>,
67    },
68    Subscriptions {
69        reply_tx: oneshot::Sender<Vec<Subscription>>,
70    },
71    GetRealtimeQuote {
72        symbols: Vec<String>,
73        reply_tx: oneshot::Sender<Vec<RealtimeQuote>>,
74    },
75    GetRealtimeDepth {
76        symbol: String,
77        reply_tx: oneshot::Sender<SecurityDepth>,
78    },
79    GetRealtimeTrade {
80        symbol: String,
81        count: usize,
82        reply_tx: oneshot::Sender<Vec<Trade>>,
83    },
84    GetRealtimeBrokers {
85        symbol: String,
86        reply_tx: oneshot::Sender<SecurityBrokers>,
87    },
88    GetRealtimeCandlesticks {
89        symbol: String,
90        period: Period,
91        count: usize,
92        reply_tx: oneshot::Sender<Vec<Candlestick>>,
93    },
94}
95
96#[derive(Debug, Default)]
97struct TradingDays {
98    normal_days: HashMap<Market, HashSet<Date>>,
99    half_days: HashMap<Market, HashSet<Date>>,
100}
101
102impl TradingDays {
103    #[inline]
104    fn normal_days(&self, market: Market) -> Days {
105        Days(self.normal_days.get(&market))
106    }
107
108    #[inline]
109    fn half_days(&self, market: Market) -> Days {
110        Days(self.half_days.get(&market))
111    }
112}
113
114#[derive(Debug, Copy, Clone)]
115struct Days<'a>(Option<&'a HashSet<Date>>);
116
117impl longport_candlesticks::Days for Days<'_> {
118    #[inline]
119    fn contains(&self, date: Date) -> bool {
120        match self.0 {
121            Some(days) => days.contains(&date),
122            None => false,
123        }
124    }
125}
126
127#[derive(Debug)]
128pub(crate) struct MarketPackageDetail {
129    pub(crate) market: String,
130    pub(crate) packages: Vec<QuotePackageDetail>,
131    pub(crate) warning: String,
132}
133
134pub(crate) struct Core {
135    config: Arc<Config>,
136    rate_limit: Vec<(u8, RateLimit)>,
137    command_rx: mpsc::UnboundedReceiver<Command>,
138    push_tx: mpsc::UnboundedSender<PushEvent>,
139    event_tx: mpsc::UnboundedSender<WsEvent>,
140    event_rx: mpsc::UnboundedReceiver<WsEvent>,
141    http_cli: HttpClient,
142    ws_cli: WsClient,
143    session: Option<WsSession>,
144    close: bool,
145    subscriptions: HashMap<String, SubFlags>,
146    trading_days: TradingDays,
147    store: Store,
148    member_id: i64,
149    quote_level: String,
150    quote_package_details: Vec<QuotePackageDetail>,
151    push_candlestick_mode: PushCandlestickMode,
152}
153
154impl Core {
155    pub(crate) async fn try_new(
156        config: Arc<Config>,
157        command_rx: mpsc::UnboundedReceiver<Command>,
158        push_tx: mpsc::UnboundedSender<PushEvent>,
159    ) -> Result<Self> {
160        let http_cli = config.create_http_client();
161        let otp = http_cli.get_otp_v2().await?;
162
163        let (event_tx, event_rx) = mpsc::unbounded_channel();
164
165        tracing::info!("connecting to quote server");
166        let (url, res) = config.create_quote_ws_request().await;
167        let request = res.map_err(WsClientError::from)?;
168
169        let mut ws_cli = WsClient::open(
170            request,
171            ProtocolVersion::Version1,
172            CodecType::Protobuf,
173            Platform::OpenAPI,
174            event_tx.clone(),
175            vec![],
176        )
177        .await?;
178
179        tracing::info!(url = url, "quote server connected");
180
181        let session = ws_cli.request_auth(otp, config.create_metadata()).await?;
182
183        // fetch user profile
184        let resp = ws_cli
185            .request::<_, quote::UserQuoteProfileResponse>(
186                cmd_code::QUERY_USER_QUOTE_PROFILE,
187                None,
188                quote::UserQuoteProfileRequest {
189                    language: config.language.unwrap_or_default().to_string(),
190                },
191            )
192            .await?;
193        let member_id = resp.member_id;
194        let quote_level = resp.quote_level;
195        let (quote_package_details, quote_package_details_by_market) = resp
196            .quote_level_detail
197            .map(|details| {
198                Ok::<_, Error>((
199                    details
200                        .by_package_key
201                        .into_values()
202                        .map(TryInto::try_into)
203                        .collect::<Result<Vec<_>>>()?,
204                    details
205                        .by_market_code
206                        .into_iter()
207                        .map(|(market, market_packages)| {
208                            Ok(MarketPackageDetail {
209                                market,
210                                packages: market_packages
211                                    .packages
212                                    .into_iter()
213                                    .map(TryInto::try_into)
214                                    .collect::<Result<Vec<_>>>()?,
215                                warning: market_packages.warning_msg,
216                            })
217                        })
218                        .collect::<Result<Vec<_>>>()?,
219                ))
220            })
221            .transpose()?
222            .unwrap_or_default();
223        let rate_limit: Vec<(u8, RateLimit)> = resp
224            .rate_limit
225            .iter()
226            .map(|config| {
227                (
228                    config.command as u8,
229                    RateLimit {
230                        interval: Duration::from_secs(1),
231                        initial: config.burst as usize,
232                        max: config.burst as usize,
233                        refill: config.limit as usize,
234                    },
235                )
236            })
237            .collect();
238        ws_cli.set_rate_limit(rate_limit.clone());
239
240        let current_trade_days = fetch_trading_days(&ws_cli).await?;
241        let push_candlestick_mode = config.push_candlestick_mode.unwrap_or_default();
242
243        let mut table = Table::new();
244        for market_packages in quote_package_details_by_market {
245            if market_packages.warning.is_empty() {
246                table.add_row(vec![
247                    market_packages.market,
248                    market_packages
249                        .packages
250                        .into_iter()
251                        .map(|package| package.name)
252                        .join(", "),
253                ]);
254            } else {
255                table.add_row(vec![market_packages.market, market_packages.warning]);
256            }
257        }
258
259        if config.enable_print_quote_packages {
260            println!("{}", table);
261        }
262
263        tracing::info!(
264            member_id = member_id,
265            quote_level = quote_level,
266            quote_package_details = ?quote_package_details,
267            "quote context initialized",
268        );
269
270        Ok(Self {
271            config,
272            rate_limit,
273            command_rx,
274            push_tx,
275            event_tx,
276            event_rx,
277            http_cli,
278            ws_cli,
279            session: Some(session),
280            close: false,
281            subscriptions: HashMap::new(),
282            trading_days: current_trade_days,
283            store: Store::default(),
284            member_id,
285            quote_level,
286            quote_package_details,
287            push_candlestick_mode,
288        })
289    }
290
291    #[inline]
292    pub(crate) fn member_id(&self) -> i64 {
293        self.member_id
294    }
295
296    #[inline]
297    pub(crate) fn quote_level(&self) -> &str {
298        &self.quote_level
299    }
300
301    #[inline]
302    pub(crate) fn quote_package_details(&self) -> &[QuotePackageDetail] {
303        &self.quote_package_details
304    }
305
306    pub(crate) async fn run(mut self) {
307        while !self.close {
308            match self.main_loop().await {
309                Ok(()) => return,
310                Err(err) => tracing::error!(error = %err, "quote disconnected"),
311            }
312
313            loop {
314                // reconnect
315                tokio::time::sleep(RECONNECT_DELAY).await;
316
317                tracing::info!("connecting to quote server");
318                let (url, res) = self.config.create_quote_ws_request().await;
319                let request = res.expect("BUG: failed to create quote ws request");
320
321                match WsClient::open(
322                    request,
323                    ProtocolVersion::Version1,
324                    CodecType::Protobuf,
325                    Platform::OpenAPI,
326                    self.event_tx.clone(),
327                    self.rate_limit.clone(),
328                )
329                .await
330                {
331                    Ok(ws_cli) => self.ws_cli = ws_cli,
332                    Err(err) => {
333                        tracing::error!(error = %err, "failed to connect quote server");
334                        continue;
335                    }
336                }
337
338                tracing::info!(url = url, "quote server connected");
339
340                // request new session
341                match &self.session {
342                    Some(session) if !session.is_expired() => {
343                        match self
344                            .ws_cli
345                            .request_reconnect(&session.session_id, self.config.create_metadata())
346                            .await
347                        {
348                            Ok(new_session) => self.session = Some(new_session),
349                            Err(err) => {
350                                self.session = None; // invalid session
351                                tracing::error!(error = %err, "failed to request session id");
352                                continue;
353                            }
354                        }
355                    }
356                    _ => {
357                        let otp = match self.http_cli.get_otp_v2().await {
358                            Ok(otp) => otp,
359                            Err(err) => {
360                                tracing::error!(error = %err, "failed to request otp");
361                                continue;
362                            }
363                        };
364
365                        match self
366                            .ws_cli
367                            .request_auth(otp, self.config.create_metadata())
368                            .await
369                        {
370                            Ok(new_session) => self.session = Some(new_session),
371                            Err(err) => {
372                                tracing::error!(error = %err, "failed to request session id");
373                                continue;
374                            }
375                        }
376                    }
377                }
378
379                // handle reconnect
380                match self.resubscribe().await {
381                    Ok(()) => break,
382                    Err(err) => {
383                        tracing::error!(error = %err, "failed to subscribe topics");
384                        continue;
385                    }
386                }
387            }
388        }
389    }
390
391    async fn main_loop(&mut self) -> Result<()> {
392        let mut update_trading_days_interval = tokio::time::interval_at(
393            Instant::now() + Duration::from_secs(60 * 60 * 24),
394            Duration::from_secs(60 * 60 * 24),
395        );
396        let mut ticker = tokio::time::interval(Duration::from_secs(1));
397
398        loop {
399            tokio::select! {
400                item = self.event_rx.recv() => {
401                    match item {
402                        Some(event) => self.handle_ws_event(event).await?,
403                        None => unreachable!(),
404                    }
405                }
406                item = self.command_rx.recv() => {
407                    match item {
408                        Some(command) => self.handle_command(command).await?,
409                        None => {
410                            self.close = true;
411                            return Ok(());
412                        }
413                    }
414                }
415                _ = ticker.tick() => self.tick(),
416                _ = update_trading_days_interval.tick() => {
417                    if let Ok(days) = fetch_trading_days(&self.ws_cli).await {
418                        self.trading_days = days;
419                    }
420                }
421            }
422        }
423    }
424
425    async fn handle_command(&mut self, command: Command) -> Result<()> {
426        match command {
427            Command::Request {
428                command_code,
429                body,
430                reply_tx,
431            } => self.handle_request(command_code, body, reply_tx).await,
432            Command::Subscribe {
433                symbols,
434                sub_types,
435                is_first_push,
436                reply_tx,
437            } => {
438                let res = self
439                    .handle_subscribe(symbols, sub_types, is_first_push)
440                    .await;
441                let _ = reply_tx.send(res);
442                Ok(())
443            }
444            Command::Unsubscribe {
445                symbols,
446                sub_types,
447                reply_tx,
448            } => {
449                let _ = reply_tx.send(self.handle_unsubscribe(symbols, sub_types).await);
450                Ok(())
451            }
452            Command::SubscribeCandlesticks {
453                symbol,
454                period,
455                reply_tx,
456            } => {
457                let _ = reply_tx.send(self.handle_subscribe_candlesticks(symbol, period).await);
458                Ok(())
459            }
460            Command::UnsubscribeCandlesticks {
461                symbol,
462                period,
463                reply_tx,
464            } => {
465                let _ = reply_tx.send(self.handle_unsubscribe_candlesticks(symbol, period).await);
466                Ok(())
467            }
468            Command::Subscriptions { reply_tx } => {
469                let res = self.handle_subscriptions().await;
470                let _ = reply_tx.send(res);
471                Ok(())
472            }
473            Command::GetRealtimeQuote { symbols, reply_tx } => {
474                let _ = reply_tx.send(self.handle_get_realtime_quote(symbols));
475                Ok(())
476            }
477            Command::GetRealtimeDepth { symbol, reply_tx } => {
478                let _ = reply_tx.send(self.handle_get_realtime_depth(symbol));
479                Ok(())
480            }
481            Command::GetRealtimeTrade {
482                symbol,
483                count,
484                reply_tx,
485            } => {
486                let _ = reply_tx.send(self.handle_get_realtime_trades(symbol, count));
487                Ok(())
488            }
489            Command::GetRealtimeBrokers { symbol, reply_tx } => {
490                let _ = reply_tx.send(self.handle_get_realtime_brokers(symbol));
491                Ok(())
492            }
493            Command::GetRealtimeCandlesticks {
494                symbol,
495                period,
496                count,
497                reply_tx,
498            } => {
499                let _ = reply_tx.send(self.handle_get_realtime_candlesticks(symbol, period, count));
500                Ok(())
501            }
502        }
503    }
504
505    async fn handle_request(
506        &mut self,
507        command_code: u8,
508        body: Vec<u8>,
509        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
510    ) -> Result<()> {
511        let res = self.ws_cli.request_raw(command_code, None, body).await;
512        let _ = reply_tx.send(res.map_err(Into::into));
513        Ok(())
514    }
515
516    async fn handle_subscribe(
517        &mut self,
518        symbols: Vec<String>,
519        sub_types: SubFlags,
520        is_first_push: bool,
521    ) -> Result<()> {
522        // send request
523        let req = SubscribeRequest {
524            symbol: symbols.clone(),
525            sub_type: sub_types.into(),
526            is_first_push,
527        };
528        self.ws_cli
529            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
530            .await?;
531
532        // update subscriptions
533        for symbol in symbols {
534            self.subscriptions
535                .entry(symbol)
536                .and_modify(|flags| *flags |= sub_types)
537                .or_insert(sub_types);
538        }
539
540        Ok(())
541    }
542
543    async fn handle_unsubscribe(
544        &mut self,
545        symbols: Vec<String>,
546        sub_types: SubFlags,
547    ) -> Result<()> {
548        tracing::info!(symbols = ?symbols, sub_types = ?sub_types, "unsubscribe");
549
550        // send requests
551        let mut st_group: HashMap<SubFlags, Vec<&str>> = HashMap::new();
552
553        for symbol in &symbols {
554            let mut st = sub_types;
555
556            if let Some(candlesticks) = self
557                .store
558                .securities
559                .get(symbol)
560                .map(|data| &data.candlesticks)
561            {
562                for period in candlesticks.keys() {
563                    if period == &Period::Day {
564                        st.remove(SubFlags::QUOTE);
565                    } else {
566                        st.remove(SubFlags::TRADE);
567                    }
568                }
569            }
570
571            if !st.is_empty() {
572                st_group.entry(st).or_default().push(symbol.as_ref());
573            }
574        }
575
576        let requests = st_group
577            .iter()
578            .map(|(st, symbols)| UnsubscribeRequest {
579                symbol: symbols.iter().map(ToString::to_string).collect(),
580                sub_type: (*st).into(),
581                unsub_all: false,
582            })
583            .collect::<Vec<_>>();
584
585        for req in requests {
586            self.ws_cli
587                .request::<_, ()>(cmd_code::UNSUBSCRIBE, None, req)
588                .await?;
589        }
590
591        // update subscriptions
592        let mut remove_symbols = Vec::new();
593        for symbol in &symbols {
594            if let Some(cur_flags) = self.subscriptions.get_mut(symbol) {
595                *cur_flags &= !sub_types;
596                if cur_flags.is_empty() {
597                    remove_symbols.push(symbol);
598                }
599            }
600        }
601
602        for symbol in remove_symbols {
603            self.subscriptions.remove(symbol);
604        }
605        Ok(())
606    }
607
608    async fn handle_subscribe_candlesticks(
609        &mut self,
610        symbol: String,
611        period: Period,
612    ) -> Result<Vec<Candlestick>> {
613        tracing::info!(symbol = symbol, period = ?period, "subscribe candlesticks");
614
615        if let Some(candlesticks) = self
616            .store
617            .securities
618            .get(&symbol)
619            .and_then(|data| data.candlesticks.get(&period))
620        {
621            tracing::info!(symbol = symbol, period = ?period, "subscribed, returns candlesticks in memory");
622            return Ok(candlesticks.candlesticks.clone());
623        }
624
625        tracing::info!(symbol = symbol, "fetch symbol board");
626
627        let security_data = self.store.securities.entry(symbol.clone()).or_default();
628        if security_data.board != SecurityBoard::Unknown {
629            // update board
630            let resp: SecurityStaticInfoResponse = self
631                .ws_cli
632                .request(
633                    cmd_code::GET_BASIC_INFO,
634                    None,
635                    MultiSecurityRequest {
636                        symbol: vec![symbol.clone()],
637                    },
638                )
639                .await?;
640            if resp.secu_static_info.is_empty() {
641                return Err(Error::InvalidSecuritySymbol {
642                    symbol: symbol.clone(),
643                });
644            }
645            security_data.board = resp.secu_static_info[0].board.parse().unwrap_or_default();
646        }
647
648        tracing::info!(symbol = symbol, board = ?security_data.board, "got the symbol board");
649
650        // pull candlesticks
651        tracing::info!(symbol = symbol, period = ?period, "pull history candlesticks");
652        let resp: SecurityCandlestickResponse = self
653            .ws_cli
654            .request(
655                cmd_code::GET_SECURITY_CANDLESTICKS,
656                None,
657                SecurityCandlestickRequest {
658                    symbol: symbol.clone(),
659                    period: period.into(),
660                    count: 1000,
661                    adjust_type: AdjustType::NoAdjust.into(),
662                },
663            )
664            .await?;
665        tracing::info!(symbol = symbol, period = ?period, len = resp.candlesticks.len(), "got history candlesticks");
666
667        let candlesticks = resp
668            .candlesticks
669            .into_iter()
670            .map(TryInto::try_into)
671            .collect::<Result<Vec<_>>>()?;
672        security_data
673            .candlesticks
674            .entry(period)
675            .or_insert_with(|| Candlesticks {
676                candlesticks: candlesticks.clone(),
677                confirmed: false,
678            });
679
680        let sub_flags = if period == Period::Day {
681            SubFlags::QUOTE
682        } else {
683            SubFlags::TRADE
684        };
685
686        // subscribe
687        if self
688            .subscriptions
689            .get(&symbol)
690            .copied()
691            .unwrap_or_else(SubFlags::empty)
692            .contains(sub_flags)
693        {
694            return Ok(candlesticks);
695        }
696
697        tracing::info!(symbol = symbol, period = ?period, sub_flags = ?sub_flags, "subscribe for candlesticks");
698
699        let req = SubscribeRequest {
700            symbol: vec![symbol.clone()],
701            sub_type: sub_flags.into(),
702            is_first_push: true,
703        };
704        self.ws_cli
705            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
706            .await?;
707
708        tracing::info!(symbol = symbol, period = ?period, sub_flags = ?sub_flags, "subscribed for candlesticks");
709        Ok(candlesticks)
710    }
711
712    async fn handle_unsubscribe_candlesticks(
713        &mut self,
714        symbol: String,
715        period: Period,
716    ) -> Result<()> {
717        let mut unsubscribe_sub_flags = if period == Period::Day {
718            SubFlags::QUOTE
719        } else {
720            SubFlags::TRADE
721        };
722
723        if let Some(periods) = self
724            .store
725            .securities
726            .get_mut(&symbol)
727            .map(|data| &mut data.candlesticks)
728        {
729            periods.remove(&period);
730
731            for period in periods.keys() {
732                if period == &Period::Day {
733                    unsubscribe_sub_flags.remove(SubFlags::QUOTE);
734                } else {
735                    unsubscribe_sub_flags.remove(SubFlags::TRADE);
736                }
737            }
738
739            if !unsubscribe_sub_flags.is_empty()
740                && !self
741                    .subscriptions
742                    .get(&symbol)
743                    .copied()
744                    .unwrap_or_else(SubFlags::empty)
745                    .contains(unsubscribe_sub_flags)
746            {
747                self.ws_cli
748                    .request::<_, ()>(
749                        cmd_code::UNSUBSCRIBE,
750                        None,
751                        UnsubscribeRequest {
752                            symbol: vec![symbol],
753                            sub_type: unsubscribe_sub_flags.into(),
754                            unsub_all: false,
755                        },
756                    )
757                    .await?;
758            }
759        }
760
761        Ok(())
762    }
763
764    async fn handle_subscriptions(&mut self) -> Vec<Subscription> {
765        self.subscriptions
766            .iter()
767            .map(|(symbol, sub_flags)| Subscription {
768                symbol: symbol.clone(),
769                sub_types: *sub_flags,
770                candlesticks: self
771                    .store
772                    .securities
773                    .get(symbol)
774                    .map(|data| &data.candlesticks)
775                    .map(|periods| periods.keys().copied().collect())
776                    .unwrap_or_default(),
777            })
778            .collect()
779    }
780
781    async fn handle_ws_event(&mut self, event: WsEvent) -> Result<()> {
782        match event {
783            WsEvent::Error(err) => Err(err.into()),
784            WsEvent::Push { command_code, body } => self.handle_push(command_code, body),
785        }
786    }
787
788    async fn resubscribe(&mut self) -> Result<()> {
789        let mut subscriptions: HashMap<SubFlags, HashSet<String>> = HashMap::new();
790
791        for (symbol, flags) in &self.subscriptions {
792            subscriptions
793                .entry(*flags)
794                .or_default()
795                .insert(symbol.clone());
796        }
797
798        for (symbol, data) in &self.store.securities {
799            for period in data.candlesticks.keys() {
800                subscriptions
801                    .entry(if *period == Period::Day {
802                        SubFlags::QUOTE
803                    } else {
804                        SubFlags::TRADE
805                    })
806                    .or_default()
807                    .insert(symbol.clone());
808            }
809        }
810
811        tracing::info!(subscriptions = ?subscriptions, "resubscribe");
812
813        for (flags, symbols) in subscriptions {
814            self.ws_cli
815                .request::<_, ()>(
816                    cmd_code::SUBSCRIBE,
817                    None,
818                    SubscribeRequest {
819                        symbol: symbols.into_iter().collect(),
820                        sub_type: flags.into(),
821                        is_first_push: false,
822                    },
823                )
824                .await?;
825        }
826        Ok(())
827    }
828
829    fn tick(&mut self) {
830        let now = OffsetDateTime::now_utc();
831
832        for (symbol, security_data) in &mut self.store.securities {
833            let Some(market_type) = parse_market_from_symbol(symbol) else {
834                continue;
835            };
836            let normal_days = self.trading_days.normal_days(market_type);
837            let half_days = self.trading_days.half_days(market_type);
838
839            for (period, candlesticks) in &mut security_data.candlesticks {
840                let action = candlesticks.tick(
841                    market_type,
842                    normal_days,
843                    half_days,
844                    security_data.board,
845                    *period,
846                    now,
847                );
848                update_and_push_candlestick(
849                    candlesticks,
850                    symbol,
851                    *period,
852                    action,
853                    self.push_candlestick_mode,
854                    &mut self.push_tx,
855                );
856            }
857        }
858    }
859
860    fn merge_trades(&mut self, symbol: &str, trades: &PushTrades) {
861        let Some(market_type) = parse_market_from_symbol(symbol) else {
862            return;
863        };
864        let Some(security_data) = self.store.securities.get_mut(symbol) else {
865            return;
866        };
867        let half_days = self.trading_days.half_days(market_type);
868
869        for (period, candlesticks) in &mut security_data.candlesticks {
870            if period == &Period::Day {
871                continue;
872            }
873
874            for trade in &trades.trades {
875                if trade.trade_session != TradeSession::NormalTrade {
876                    continue;
877                }
878
879                let action = candlesticks.merge_trade(
880                    market_type,
881                    half_days,
882                    security_data.board,
883                    *period,
884                    trade,
885                );
886                update_and_push_candlestick(
887                    candlesticks,
888                    symbol,
889                    *period,
890                    action,
891                    self.push_candlestick_mode,
892                    &mut self.push_tx,
893                );
894            }
895        }
896    }
897
898    fn merge_quote(&mut self, symbol: &str, push_quote: &PushQuote) {
899        if push_quote.trade_session != TradeSession::NormalTrade {
900            return;
901        }
902
903        let Some(market_type) = parse_market_from_symbol(symbol) else {
904            return;
905        };
906        let Some(security_data) = self.store.securities.get_mut(symbol) else {
907            return;
908        };
909        let half_days = self.trading_days.half_days(market_type);
910        let Some(candlesticks) = security_data.candlesticks.get_mut(&Period::Day) else {
911            return;
912        };
913
914        let action = candlesticks.merge_quote(
915            market_type,
916            half_days,
917            security_data.board,
918            Period::Day,
919            push_quote,
920        );
921        update_and_push_candlestick(
922            candlesticks,
923            symbol,
924            Period::Day,
925            action,
926            self.push_candlestick_mode,
927            &mut self.push_tx,
928        )
929    }
930
931    fn handle_push(&mut self, command_code: u8, body: Vec<u8>) -> Result<()> {
932        match PushEvent::parse(command_code, &body) {
933            Ok((mut event, tag)) => {
934                tracing::info!(event = ?event, tag = ?tag, "push event");
935
936                if tag != Some(PushQuoteTag::Eod) {
937                    self.store.handle_push(&mut event);
938                }
939
940                if let PushEventDetail::Trade(trades) = &event.detail {
941                    // merge candlesticks
942                    self.merge_trades(&event.symbol, trades);
943
944                    if !self
945                        .subscriptions
946                        .get(&event.symbol)
947                        .map(|sub_flags| sub_flags.contains(SubFlags::TRADE))
948                        .unwrap_or_default()
949                    {
950                        return Ok(());
951                    }
952                } else if let PushEventDetail::Quote(push_quote) = &event.detail {
953                    self.merge_quote(&event.symbol, push_quote);
954
955                    if !self
956                        .subscriptions
957                        .get(&event.symbol)
958                        .map(|sub_flags| sub_flags.contains(SubFlags::QUOTE))
959                        .unwrap_or_default()
960                    {
961                        return Ok(());
962                    }
963                }
964
965                if tag == Some(PushQuoteTag::Eod) {
966                    return Ok(());
967                }
968
969                let _ = self.push_tx.send(event);
970            }
971            Err(err) => {
972                tracing::error!(error = %err, "failed to parse push message");
973            }
974        }
975        Ok(())
976    }
977
978    fn handle_get_realtime_quote(&self, symbols: Vec<String>) -> Vec<RealtimeQuote> {
979        let mut result = Vec::new();
980
981        for symbol in symbols {
982            if let Some(data) = self.store.securities.get(&symbol) {
983                result.push(RealtimeQuote {
984                    symbol,
985                    last_done: data.quote.last_done,
986                    open: data.quote.open,
987                    high: data.quote.high,
988                    low: data.quote.low,
989                    timestamp: data.quote.timestamp,
990                    volume: data.quote.volume,
991                    turnover: data.quote.turnover,
992                    trade_status: data.quote.trade_status,
993                });
994            }
995        }
996
997        result
998    }
999
1000    fn handle_get_realtime_depth(&self, symbol: String) -> SecurityDepth {
1001        let mut result = SecurityDepth::default();
1002        if let Some(data) = self.store.securities.get(&symbol) {
1003            result.asks.clone_from(&data.asks);
1004            result.bids.clone_from(&data.bids);
1005        }
1006        result
1007    }
1008
1009    fn handle_get_realtime_trades(&self, symbol: String, count: usize) -> Vec<Trade> {
1010        let mut res = Vec::new();
1011
1012        if let Some(data) = self.store.securities.get(&symbol) {
1013            let trades = if data.trades.len() >= count {
1014                &data.trades[data.trades.len() - count..]
1015            } else {
1016                &data.trades
1017            };
1018            res = trades.to_vec();
1019        }
1020        res
1021    }
1022
1023    fn handle_get_realtime_brokers(&self, symbol: String) -> SecurityBrokers {
1024        let mut result = SecurityBrokers::default();
1025        if let Some(data) = self.store.securities.get(&symbol) {
1026            result.ask_brokers.clone_from(&data.ask_brokers);
1027            result.bid_brokers.clone_from(&data.bid_brokers);
1028        }
1029        result
1030    }
1031
1032    fn handle_get_realtime_candlesticks(
1033        &self,
1034        symbol: String,
1035        period: Period,
1036        count: usize,
1037    ) -> Vec<Candlestick> {
1038        self.store
1039            .securities
1040            .get(&symbol)
1041            .map(|data| &data.candlesticks)
1042            .and_then(|periods| periods.get(&period))
1043            .map(|candlesticks| {
1044                let candlesticks = if candlesticks.candlesticks.len() >= count {
1045                    &candlesticks.candlesticks[candlesticks.candlesticks.len() - count..]
1046                } else {
1047                    &candlesticks.candlesticks
1048                };
1049                candlesticks.to_vec()
1050            })
1051            .unwrap_or_default()
1052    }
1053}
1054
1055async fn fetch_trading_days(cli: &WsClient) -> Result<TradingDays> {
1056    let mut days = TradingDays::default();
1057    let begin_day = OffsetDateTime::now_utc().date() - time::Duration::days(5);
1058    let end_day = begin_day + time::Duration::days(30);
1059
1060    for market in [Market::HK, Market::US, Market::SG, Market::CN] {
1061        let resp = cli
1062            .request::<_, MarketTradeDayResponse>(
1063                cmd_code::GET_TRADING_DAYS,
1064                None,
1065                MarketTradeDayRequest {
1066                    market: market.to_string(),
1067                    beg_day: format_date(begin_day),
1068                    end_day: format_date(end_day),
1069                },
1070            )
1071            .await?;
1072
1073        days.normal_days.insert(
1074            market,
1075            resp.trade_day
1076                .iter()
1077                .map(|value| {
1078                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1079                })
1080                .collect::<Result<HashSet<_>>>()?,
1081        );
1082
1083        days.half_days.insert(
1084            market,
1085            resp.half_trade_day
1086                .iter()
1087                .map(|value| {
1088                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1089                })
1090                .collect::<Result<HashSet<_>>>()?,
1091        );
1092    }
1093
1094    Ok(days)
1095}
1096
1097fn update_and_push_candlestick(
1098    candlesticks: &mut Candlesticks,
1099    symbol: &str,
1100    period: Period,
1101    action: UpdateAction,
1102    push_candlestick_mode: PushCandlestickMode,
1103    tx: &mut mpsc::UnboundedSender<PushEvent>,
1104) {
1105    let mut push_candlesticks = Vec::new();
1106
1107    match action {
1108        UpdateAction::UpdateLast(candlestick) => {
1109            *candlesticks.candlesticks.last_mut().unwrap() = candlestick.into();
1110            if push_candlestick_mode == PushCandlestickMode::Realtime {
1111                push_candlesticks.push((candlestick.into(), false));
1112            }
1113        }
1114        UpdateAction::AppendNew { confirmed, new } => {
1115            candlesticks.candlesticks.push(new.into());
1116            candlesticks.confirmed = false;
1117            if candlesticks.candlesticks.len() > MAX_CANDLESTICKS * 2 {
1118                candlesticks.candlesticks.drain(..MAX_CANDLESTICKS);
1119            }
1120
1121            match push_candlestick_mode {
1122                PushCandlestickMode::Realtime => {
1123                    if let Some(confirmed) = confirmed {
1124                        push_candlesticks.push((confirmed.into(), true));
1125                    }
1126                    push_candlesticks.push((new.into(), false));
1127                }
1128                PushCandlestickMode::Confirmed => {
1129                    if let Some(confirmed) = confirmed {
1130                        push_candlesticks.push((confirmed.into(), true));
1131                    }
1132                }
1133            }
1134        }
1135        UpdateAction::Confirm(candlestick) => {
1136            candlesticks.confirmed = true;
1137            if push_candlestick_mode == PushCandlestickMode::Confirmed {
1138                push_candlesticks.push((candlestick.into(), true));
1139            }
1140        }
1141        UpdateAction::None => {}
1142    };
1143
1144    for (candlestick, is_confirmed) in push_candlesticks {
1145        tracing::info!(
1146            symbol = symbol,
1147            period = ?period,
1148            is_confirmed = is_confirmed,
1149            candlestick = ?candlestick,
1150            "push candlestick"
1151        );
1152        let _ = tx.send(PushEvent {
1153            sequence: 0,
1154            symbol: symbol.to_string(),
1155            detail: PushEventDetail::Candlestick(PushCandlestick {
1156                period,
1157                candlestick,
1158                is_confirmed,
1159            }),
1160        });
1161    }
1162}
1163
1164fn parse_market_from_symbol(symbol: &str) -> Option<Market> {
1165    let market = symbol.rfind('.').map(|idx| &symbol[idx + 1..])?;
1166    Some(match market {
1167        "US" => Market::US,
1168        "HK" => Market::HK,
1169        "SG" => Market::SG,
1170        "SH" | "SZ" => Market::CN,
1171        _ => return None,
1172    })
1173}
1174
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported suffix maps to its market; only the text after the
    /// last `.` is considered.
    #[test]
    fn test_parse_market_from_symbol() {
        assert_eq!(parse_market_from_symbol("AAPL.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("BRK.A.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("700.HK"), Some(Market::HK));
        assert_eq!(parse_market_from_symbol("D05.SG"), Some(Market::SG));
        assert_eq!(parse_market_from_symbol("600519.SH"), Some(Market::CN));
        assert_eq!(parse_market_from_symbol("000001.SZ"), Some(Market::CN));
    }

    /// Symbols without a `.` or with an unknown, empty or differently-cased
    /// suffix are rejected.
    #[test]
    fn test_parse_market_from_symbol_rejects_unknown() {
        assert_eq!(parse_market_from_symbol("AAPL"), None);
        assert_eq!(parse_market_from_symbol("AAPL.XX"), None);
        assert_eq!(parse_market_from_symbol("AAPL.us"), None);
        assert_eq!(parse_market_from_symbol("AAPL."), None);
        assert_eq!(parse_market_from_symbol(""), None);
    }
}