1use crate::manager::{
16 ApiCredentials, ApiKeyId, ApiKeyStats, ChannelId,
17 DistributionStrategy, KiteManagerConfig, ManagedConnection, MessageProcessor,
18 MultiApiConfig, MultiApiStats,
19};
20use crate::models::{Mode, TickerMessage};
21use std::collections::HashMap;
22use std::time::Instant;
23use tokio::sync::{broadcast, mpsc};
24
/// Per-API-key state: the credentials, the connections opened for that key,
/// and which connection each subscribed symbol was placed on.
#[derive(Debug)]
struct ApiConnectionGroup {
    /// Identifier of the API key this group belongs to.
    api_key_id: ApiKeyId,
    /// API key and access token passed to each connection's `connect` call.
    credentials: ApiCredentials,
    /// Connections for this key; index `i` is paired with `processors[i]`
    /// (both vectors are pushed together when the manager starts).
    connections: Vec<ManagedConnection>,
    /// Message processors, one per entry in `connections`, same index.
    processors: Vec<MessageProcessor>,
    /// Symbol id -> (index into `connections`, subscription mode).
    subscribed_symbols: HashMap<u32, (usize, Mode)>,
    /// Round-robin cursor: the connection index tried first by the next
    /// `find_available_connection` call.
    next_connection_index: usize,
}
35
36impl ApiConnectionGroup {
37 fn new(api_key_id: ApiKeyId, credentials: ApiCredentials) -> Self {
38 Self {
39 api_key_id,
40 credentials,
41 connections: Vec::new(),
42 processors: Vec::new(),
43 subscribed_symbols: HashMap::new(),
44 next_connection_index: 0,
45 }
46 }
47
48 fn find_available_connection(
50 &mut self,
51 max_symbols_per_connection: usize,
52 ) -> Option<usize> {
53 let start_index = self.next_connection_index;
54
55 for _ in 0..self.connections.len() {
56 let connection = &self.connections[self.next_connection_index];
57
58 if connection.can_accept_symbols(1, max_symbols_per_connection) {
59 let result = self.next_connection_index;
60 self.next_connection_index =
61 (self.next_connection_index + 1) % self.connections.len();
62 return Some(result);
63 }
64
65 self.next_connection_index =
66 (self.next_connection_index + 1) % self.connections.len();
67 }
68
69 self.next_connection_index = start_index;
71 None
72 }
73
74 fn total_symbols(&self) -> usize {
76 self.subscribed_symbols.len()
77 }
78
79 async fn get_stats(&self) -> ApiKeyStats {
81 let mut stats = ApiKeyStats {
82 api_key_id: self.api_key_id.0.clone(),
83 active_connections: 0,
84 total_symbols: self.total_symbols(),
85 total_messages_received: 0,
86 total_messages_parsed: 0,
87 total_errors: 0,
88 connection_stats: Vec::new(),
89 };
90
91 for connection in &self.connections {
92 let conn_stats = connection.stats.read().await;
93 stats.connection_stats.push(conn_stats.clone());
94
95 if conn_stats.is_connected {
96 stats.active_connections += 1;
97 }
98
99 stats.total_messages_received += conn_stats.messages_received;
100 stats.total_messages_parsed += conn_stats.messages_parsed;
101 stats.total_errors += conn_stats.errors_count;
102 }
103
104 stats
105 }
106}
107
/// Manages Kite ticker WebSocket connections spread across several API keys
/// and merges their parsed output into one broadcast channel.
#[derive(Debug)]
pub struct MultiApiKiteTickerManager {
    /// Manager-wide settings (connections per key, distribution strategy, base config).
    config: MultiApiConfig,
    /// Connection state for each configured API key.
    api_groups: HashMap<ApiKeyId, ApiConnectionGroup>,

    /// Sender side of the unified channel; each message is tagged with the API
    /// key whose connection produced it.
    unified_output_tx: broadcast::Sender<(ApiKeyId, TickerMessage)>,

    /// Which API key each subscribed symbol is routed to.
    symbol_to_api: HashMap<u32, ApiKeyId>,

    /// Round-robin cursor into `api_key_order` used by automatic distribution.
    next_api_index: usize,

    /// API keys in the order round-robin distribution visits them.
    api_key_order: Vec<ApiKeyId>,
    /// Construction time of the manager; reported as `uptime` in stats.
    start_time: Instant,
}
157
/// Builder for [`MultiApiKiteTickerManager`]: collects API credentials and
/// configuration, then constructs the manager with `build`.
#[derive(Debug, Clone)]
pub struct MultiApiKiteTickerManagerBuilder {
    /// Credentials keyed by API key id; adding the same id twice keeps the later entry.
    api_credentials: HashMap<ApiKeyId, ApiCredentials>,
    /// Configuration handed to the manager on `build`.
    config: MultiApiConfig,
}
164
165impl MultiApiKiteTickerManagerBuilder {
166 pub fn new() -> Self {
168 Self {
169 api_credentials: HashMap::new(),
170 config: MultiApiConfig::default(),
171 }
172 }
173
174 pub fn add_api_key(
176 mut self,
177 id: impl Into<ApiKeyId>,
178 api_key: impl Into<String>,
179 access_token: impl Into<String>,
180 ) -> Self {
181 let api_key_id = id.into();
182 let credentials = ApiCredentials::new(api_key, access_token);
183 self.api_credentials.insert(api_key_id, credentials);
184 self
185 }
186
187 pub fn max_connections_per_api(mut self, n: usize) -> Self {
189 self.config.max_connections_per_api = n.min(3); self
191 }
192
193 pub fn distribution_strategy(mut self, strategy: DistributionStrategy) -> Self {
195 self.config.distribution_strategy = strategy;
196 self
197 }
198
199 pub fn base_config(mut self, config: KiteManagerConfig) -> Self {
201 self.config.base_config = config;
202 self
203 }
204
205 pub fn max_symbols_per_connection(mut self, n: usize) -> Self {
207 self.config.base_config.max_symbols_per_connection = n;
208 self
209 }
210
211 pub fn connection_timeout(mut self, d: std::time::Duration) -> Self {
213 self.config.base_config.connection_timeout = d;
214 self
215 }
216
217 pub fn health_check_interval(mut self, d: std::time::Duration) -> Self {
219 self.config.base_config.health_check_interval = d;
220 self
221 }
222
223 pub fn enable_health_monitoring(mut self, enable: bool) -> Self {
225 self.config.enable_health_monitoring = enable;
226 self
227 }
228
229 pub fn default_mode(mut self, mode: Mode) -> Self {
231 self.config.base_config.default_mode = mode;
232 self
233 }
234
235 pub fn build(self) -> MultiApiKiteTickerManager {
237 MultiApiKiteTickerManager::new(self.api_credentials, self.config)
238 }
239}
240
241impl Default for MultiApiKiteTickerManagerBuilder {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
247impl MultiApiKiteTickerManager {
248 pub fn builder() -> MultiApiKiteTickerManagerBuilder {
250 MultiApiKiteTickerManagerBuilder::new()
251 }
252
253 fn new(
255 api_credentials: HashMap<ApiKeyId, ApiCredentials>,
256 config: MultiApiConfig,
257 ) -> Self {
258 let (unified_output_tx, _) =
259 broadcast::channel(config.base_config.parser_buffer_size);
260
261 let api_key_order: Vec<ApiKeyId> = api_credentials.keys().cloned().collect();
262
263 let mut api_groups = HashMap::new();
264 for (id, creds) in api_credentials {
265 api_groups.insert(id.clone(), ApiConnectionGroup::new(id, creds));
266 }
267
268 Self {
269 config,
270 api_groups,
271 unified_output_tx,
272 symbol_to_api: HashMap::new(),
273 next_api_index: 0,
274 api_key_order,
275 start_time: Instant::now(),
276 }
277 }
278
279 pub async fn start(&mut self) -> Result<(), String> {
281 if self.api_groups.is_empty() {
282 return Err("No API keys configured".to_string());
283 }
284
285 log::info!(
286 "Starting MultiApiKiteTickerManager with {} API keys",
287 self.api_groups.len()
288 );
289
290 let unified_tx = self.unified_output_tx.clone();
292
293 for (api_key_id, group) in &mut self.api_groups {
294 log::info!(
295 "Initializing {} connections for API key: {}",
296 self.config.max_connections_per_api,
297 api_key_id.0
298 );
299
300 for i in 0..self.config.max_connections_per_api {
301 let channel_id = ChannelId::from_index(i)
302 .ok_or_else(|| format!("Invalid connection index: {}", i))?;
303
304 let (connection_sender, processor_receiver) = mpsc::unbounded_channel();
306
307 let mut connection =
309 ManagedConnection::new(channel_id, connection_sender);
310
311 connection
313 .connect(
314 &group.credentials.api_key,
315 &group.credentials.access_token,
316 &self.config.base_config,
317 )
318 .await
319 .map_err(|e| {
320 format!(
321 "Failed to connect WebSocket {} for API key {}: {}",
322 i, api_key_id.0, e
323 )
324 })?;
325
326 let (mut processor, output_receiver) = MessageProcessor::new(
328 channel_id,
329 processor_receiver,
330 self.config.base_config.parser_buffer_size,
331 );
332
333 if self.config.base_config.enable_dedicated_parsers {
335 processor.start();
336 log::info!(
337 "Started dedicated parser for API key {} connection {}",
338 api_key_id.0,
339 i
340 );
341 }
342
343 Self::spawn_message_forwarder_static(
345 unified_tx.clone(),
346 api_key_id.clone(),
347 output_receiver,
348 );
349
350 group.connections.push(connection);
351 group.processors.push(processor);
352 }
353
354 log::info!(
355 "Initialized {} connections for API key: {}",
356 group.connections.len(),
357 api_key_id.0
358 );
359 }
360
361 log::info!(
362 "MultiApiKiteTickerManager started successfully with {} API keys",
363 self.api_groups.len()
364 );
365
366 Ok(())
367 }
368
369 fn spawn_message_forwarder_static(
371 tx: broadcast::Sender<(ApiKeyId, TickerMessage)>,
372 api_key_id: ApiKeyId,
373 mut receiver: broadcast::Receiver<TickerMessage>,
374 ) {
375 tokio::spawn(async move {
376 loop {
377 match receiver.recv().await {
378 Ok(msg) => {
379 let _ = tx.send((api_key_id.clone(), msg));
381 }
382 Err(broadcast::error::RecvError::Closed) => {
383 log::debug!(
384 "Message forwarder closed for API key: {}",
385 api_key_id.0
386 );
387 break;
388 }
389 Err(broadcast::error::RecvError::Lagged(n)) => {
390 log::warn!(
391 "Message forwarder lagged by {} messages for API key: {}",
392 n,
393 api_key_id.0
394 );
395 continue;
396 }
397 }
398 }
399 });
400 }
401
402 pub async fn subscribe_symbols(
404 &mut self,
405 symbols: &[u32],
406 mode: Option<Mode>,
407 ) -> Result<(), String> {
408 if self.config.distribution_strategy == DistributionStrategy::Manual {
409 return Err(
410 "Cannot use auto-subscribe with Manual distribution strategy. Use subscribe_symbols_to_api instead.".to_string()
411 );
412 }
413
414 let mode = mode.unwrap_or(self.config.base_config.default_mode);
415
416 log::info!(
417 "Subscribing to {} symbols with mode: {:?} using round-robin distribution",
418 symbols.len(),
419 mode
420 );
421
422 for &symbol in symbols {
424 if self.symbol_to_api.contains_key(&symbol) {
426 log::debug!("Symbol {} already subscribed", symbol);
427 continue;
428 }
429
430 let api_key_id = self.find_available_api_key()?;
432
433 self
435 .subscribe_symbol_to_api(&api_key_id, symbol, mode)
436 .await?;
437 }
438
439 log::info!("Successfully subscribed to {} symbols", symbols.len());
440 Ok(())
441 }
442
443 pub async fn subscribe_symbols_to_api(
445 &mut self,
446 api_key_id: impl Into<ApiKeyId>,
447 symbols: &[u32],
448 mode: Option<Mode>,
449 ) -> Result<(), String> {
450 let api_key_id = api_key_id.into();
451 let mode = mode.unwrap_or(self.config.base_config.default_mode);
452
453 log::info!(
454 "Subscribing {} symbols to API key: {} with mode: {:?}",
455 symbols.len(),
456 api_key_id.0,
457 mode
458 );
459
460 for &symbol in symbols {
461 self
462 .subscribe_symbol_to_api(&api_key_id, symbol, mode)
463 .await?;
464 }
465
466 log::info!(
467 "Successfully subscribed {} symbols to API key: {}",
468 symbols.len(),
469 api_key_id.0
470 );
471 Ok(())
472 }
473
474 async fn subscribe_symbol_to_api(
476 &mut self,
477 api_key_id: &ApiKeyId,
478 symbol: u32,
479 mode: Mode,
480 ) -> Result<(), String> {
481 let group = self
482 .api_groups
483 .get_mut(api_key_id)
484 .ok_or_else(|| format!("API key not found: {}", api_key_id.0))?;
485
486 let connection_index = group
488 .find_available_connection(
489 self.config.base_config.max_symbols_per_connection,
490 )
491 .ok_or_else(|| {
492 format!(
493 "All connections at capacity for API key: {}",
494 api_key_id.0
495 )
496 })?;
497
498 let connection = &mut group.connections[connection_index];
499
500 if connection.subscribed_symbols.is_empty() {
502 connection
504 .subscribe_symbols(&[symbol], mode)
505 .await
506 .map_err(|e| {
507 format!(
508 "Failed to subscribe symbol {} on API key {}: {}",
509 symbol, api_key_id.0, e
510 )
511 })?;
512
513 connection.start_message_processing().await.map_err(|e| {
514 format!(
515 "Failed to start message processing on API key {}: {}",
516 api_key_id.0, e
517 )
518 })?;
519 } else {
520 connection.add_symbols(&[symbol], mode).await.map_err(|e| {
522 format!(
523 "Failed to add symbol {} on API key {}: {}",
524 symbol, api_key_id.0, e
525 )
526 })?;
527 }
528
529 group
531 .subscribed_symbols
532 .insert(symbol, (connection_index, mode));
533 self.symbol_to_api.insert(symbol, api_key_id.clone());
534
535 Ok(())
536 }
537
538 fn find_available_api_key(&mut self) -> Result<ApiKeyId, String> {
540 if self.api_key_order.is_empty() {
541 return Err("No API keys configured".to_string());
542 }
543
544 let start_index = self.next_api_index;
545
546 for _ in 0..self.api_key_order.len() {
547 let api_key_id = &self.api_key_order[self.next_api_index];
548
549 if let Some(group) = self.api_groups.get_mut(api_key_id) {
550 let has_capacity = group
552 .find_available_connection(
553 self.config.base_config.max_symbols_per_connection,
554 )
555 .is_some();
556
557 if has_capacity {
558 let result = api_key_id.clone();
559 self.next_api_index =
560 (self.next_api_index + 1) % self.api_key_order.len();
561 return Ok(result);
562 }
563 }
564
565 self.next_api_index =
566 (self.next_api_index + 1) % self.api_key_order.len();
567 }
568
569 self.next_api_index = start_index;
571 Err("All API keys are at capacity".to_string())
572 }
573
574 pub async fn unsubscribe_symbols(
576 &mut self,
577 symbols: &[u32],
578 ) -> Result<(), String> {
579 log::info!("Unsubscribing from {} symbols", symbols.len());
580
581 let mut api_symbols: HashMap<ApiKeyId, Vec<u32>> = HashMap::new();
583
584 for &symbol in symbols {
585 if let Some(api_key_id) = self.symbol_to_api.get(&symbol) {
586 api_symbols
587 .entry(api_key_id.clone())
588 .or_default()
589 .push(symbol);
590 }
591 }
592
593 for (api_key_id, symbols) in api_symbols {
595 if let Some(group) = self.api_groups.get_mut(&api_key_id) {
596 let mut conn_symbols: HashMap<usize, Vec<u32>> = HashMap::new();
598
599 for symbol in symbols {
600 if let Some((conn_idx, _)) = group.subscribed_symbols.get(&symbol) {
601 conn_symbols.entry(*conn_idx).or_default().push(symbol);
602 }
603 }
604
605 for (conn_idx, symbols) in conn_symbols {
607 if let Some(connection) = group.connections.get_mut(conn_idx) {
608 connection.remove_symbols(&symbols).await.map_err(|e| {
609 format!(
610 "Failed to unsubscribe from API key {}: {}",
611 api_key_id.0, e
612 )
613 })?;
614 }
615
616 for symbol in symbols {
618 group.subscribed_symbols.remove(&symbol);
619 self.symbol_to_api.remove(&symbol);
620 }
621 }
622 }
623 }
624
625 log::info!("Successfully unsubscribed from symbols");
626 Ok(())
627 }
628
629 pub async fn change_mode(
631 &mut self,
632 symbols: &[u32],
633 mode: Mode,
634 ) -> Result<(), String> {
635 log::info!("Changing mode for {} symbols to {:?}", symbols.len(), mode);
636
637 let mut api_symbols: HashMap<ApiKeyId, HashMap<usize, Vec<u32>>> =
639 HashMap::new();
640
641 for &symbol in symbols {
642 if let Some(api_key_id) = self.symbol_to_api.get(&symbol) {
643 if let Some(group) = self.api_groups.get(api_key_id) {
644 if let Some((conn_idx, _)) = group.subscribed_symbols.get(&symbol) {
645 api_symbols
646 .entry(api_key_id.clone())
647 .or_default()
648 .entry(*conn_idx)
649 .or_default()
650 .push(symbol);
651 }
652 }
653 }
654 }
655
656 for (api_key_id, conn_symbols) in api_symbols {
658 if let Some(group) = self.api_groups.get_mut(&api_key_id) {
659 for (conn_idx, symbols) in conn_symbols {
660 if let Some(connection) = group.connections.get_mut(conn_idx) {
661 if let Some(ref cmd) = connection.cmd_tx {
662 let mode_req =
663 crate::models::Request::mode(mode, &symbols).to_string();
664 let _ = cmd.send(
665 tokio_tungstenite::tungstenite::Message::Text(mode_req.into()),
666 );
667
668 for &symbol in &symbols {
670 connection.subscribed_symbols.insert(symbol, mode);
671 group.subscribed_symbols.insert(symbol, (conn_idx, mode));
672 }
673 }
674 }
675 }
676 }
677 }
678
679 log::info!("Successfully changed mode for symbols");
680 Ok(())
681 }
682
683 pub fn get_unified_channel(
687 &self,
688 ) -> broadcast::Receiver<(ApiKeyId, TickerMessage)> {
689 self.unified_output_tx.subscribe()
690 }
691
692 pub fn get_channel(
694 &mut self,
695 api_key_id: impl Into<ApiKeyId>,
696 channel_id: ChannelId,
697 ) -> Option<broadcast::Receiver<TickerMessage>> {
698 let api_key_id = api_key_id.into();
699 self.api_groups.get_mut(&api_key_id).and_then(|group| {
700 group
701 .processors
702 .get_mut(channel_id.to_index())
703 .map(|p| p.output_sender.subscribe())
704 })
705 }
706
707 pub async fn get_stats(&self) -> MultiApiStats {
709 let mut stats = MultiApiStats {
710 total_api_keys: self.api_groups.len(),
711 total_connections: 0,
712 total_symbols: self.symbol_to_api.len(),
713 total_messages_received: 0,
714 total_messages_parsed: 0,
715 total_errors: 0,
716 uptime: self.start_time.elapsed(),
717 per_api_stats: Vec::new(),
718 };
719
720 for group in self.api_groups.values() {
721 let api_stats = group.get_stats().await;
722
723 stats.total_connections += api_stats.active_connections;
724 stats.total_messages_received += api_stats.total_messages_received;
725 stats.total_messages_parsed += api_stats.total_messages_parsed;
726 stats.total_errors += api_stats.total_errors;
727
728 stats.per_api_stats.push(api_stats);
729 }
730
731 stats
732 }
733
734 pub async fn get_api_stats(
736 &self,
737 api_key_id: impl Into<ApiKeyId>,
738 ) -> Result<ApiKeyStats, String> {
739 let api_key_id = api_key_id.into();
740 self
741 .api_groups
742 .get(&api_key_id)
743 .ok_or_else(|| format!("API key not found: {}", api_key_id.0))?
744 .get_stats()
745 .await
746 .pipe(Ok)
747 }
748
749 pub fn get_symbol_distribution(
751 &self,
752 ) -> HashMap<ApiKeyId, HashMap<usize, Vec<u32>>> {
753 let mut distribution: HashMap<ApiKeyId, HashMap<usize, Vec<u32>>> =
754 HashMap::new();
755
756 for (api_key_id, group) in &self.api_groups {
757 let mut api_dist: HashMap<usize, Vec<u32>> = HashMap::new();
758
759 for (&symbol, &(conn_idx, _)) in &group.subscribed_symbols {
760 api_dist.entry(conn_idx).or_default().push(symbol);
761 }
762
763 distribution.insert(api_key_id.clone(), api_dist);
764 }
765
766 distribution
767 }
768
769 pub fn get_api_keys(&self) -> Vec<ApiKeyId> {
771 self.api_key_order.clone()
772 }
773
774 pub async fn stop(&mut self) -> Result<(), String> {
776 log::info!("Stopping MultiApiKiteTickerManager");
777
778 for (api_key_id, group) in &mut self.api_groups {
779 log::info!("Stopping connections for API key: {}", api_key_id.0);
780
781 for processor in &mut group.processors {
783 processor.stop().await;
784 }
785
786 for connection in &mut group.connections {
788 if let Some(h) = connection.heartbeat_handle.take() {
789 h.abort();
790 let _ = h.await;
791 }
792 if let Some(handle) = connection.task_handle.take() {
793 handle.abort();
794 let _ = handle.await;
795 }
796 }
797 }
798
799 log::info!("MultiApiKiteTickerManager stopped");
800 Ok(())
801 }
802}
803
/// Postfix function application: `value.pipe(f)` evaluates to `f(value)`.
///
/// Handy for finishing a method chain with a free function or constructor,
/// e.g. `some_future.await.pipe(Ok)`.
trait Pipe: Sized {
    /// Consumes `self`, hands it to `f`, and returns whatever `f` returns.
    fn pipe<F, R>(self, f: F) -> R
    where
        F: FnOnce(Self) -> R,
    {
        let value = self;
        f(value)
    }
}

/// Blanket implementation: every sized type can be piped.
impl<T: Sized> Pipe for T {}