use crate::manager::{
ChannelId, ConnectionStats, HealthMonitor, HealthSummary, KiteManagerConfig,
ManagedConnection, ManagerStats, MessageProcessor, ProcessorStats,
};
use crate::models::{Mode, TickerMessage};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, mpsc, RwLock};
/// Coordinates several WebSocket ticker connections and distributes symbol
/// subscriptions across them.
///
/// Connections, processors and output channels are created by `start` and are
/// stored in parallel vectors indexed by `ChannelId::to_index()`.
#[derive(Debug)]
pub struct KiteTickerManager {
/// Manager-wide settings (connection count, per-connection symbol limit, ...).
config: KiteManagerConfig,
/// API key passed to every connection when it connects.
api_key: String,
/// Access token passed to every connection when it connects.
access_token: String,
/// One managed connection per channel, pushed in index order by `start`.
connections: Vec<ManagedConnection>,
/// One message processor per connection, in the same order as `connections`.
processors: Vec<MessageProcessor>,
/// Receiver returned by each `MessageProcessor::new`; consumers obtain their
/// own receiver via `resubscribe()` on these.
output_channels: Vec<broadcast::Receiver<TickerMessage>>,
/// Maps each subscribed symbol to the channel it was assigned to.
symbol_mapping: HashMap<u32, ChannelId>,
/// Present only after `start` has launched health monitoring.
health_monitor: Option<HealthMonitor>,
/// Round-robin cursor: index of the connection to try first for the next symbol.
next_connection_index: usize,
// Set at construction; not read anywhere in this file, hence the lint allowance.
#[allow(dead_code)]
start_time: Instant,
/// When true, `start` connects through `connect_with_raw(.., true)` instead of `connect`.
raw_only: bool,
}
/// Fluent builder that collects credentials and configuration and then
/// produces a `KiteTickerManager` via `build`.
#[derive(Debug, Clone)]
pub struct KiteTickerManagerBuilder {
/// API key handed to the manager on `build`.
api_key: String,
/// Access token handed to the manager on `build`.
access_token: String,
/// Configuration being assembled; starts as `KiteManagerConfig::default()`.
config: KiteManagerConfig,
/// Forwarded to `KiteTickerManager::with_raw_only` on `build`.
raw_only: bool,
}
impl KiteTickerManagerBuilder {
    /// Creates a builder with the given credentials, the default
    /// `KiteManagerConfig`, and raw-only mode disabled.
    #[must_use]
    pub fn new(
        api_key: impl Into<String>,
        access_token: impl Into<String>,
    ) -> Self {
        Self {
            api_key: api_key.into(),
            access_token: access_token.into(),
            config: KiteManagerConfig::default(),
            raw_only: false,
        }
    }

    // Every setter consumes the builder and returns it, so ignoring the return
    // value silently discards the setting; `#[must_use]` turns that into a
    // compiler warning.

    /// Sets `config.max_connections`, the number of connections `start` opens.
    #[must_use]
    pub fn max_connections(mut self, n: usize) -> Self {
        self.config.max_connections = n;
        self
    }

    /// Sets `config.max_symbols_per_connection`, the per-connection symbol limit
    /// used when assigning symbols.
    #[must_use]
    pub fn max_symbols_per_connection(mut self, n: usize) -> Self {
        self.config.max_symbols_per_connection = n;
        self
    }

    /// Sets `config.connection_timeout`.
    #[must_use]
    pub fn connection_timeout(mut self, d: std::time::Duration) -> Self {
        self.config.connection_timeout = d;
        self
    }

    /// Sets `config.health_check_interval`, the interval handed to the health
    /// monitor by `start`.
    #[must_use]
    pub fn health_check_interval(mut self, d: std::time::Duration) -> Self {
        self.config.health_check_interval = d;
        self
    }

    /// Sets `config.max_reconnect_attempts`.
    #[must_use]
    pub fn reconnect_attempts(mut self, attempts: usize) -> Self {
        self.config.max_reconnect_attempts = attempts;
        self
    }

    /// Sets `config.reconnect_delay`.
    #[must_use]
    pub fn reconnect_delay(mut self, d: std::time::Duration) -> Self {
        self.config.reconnect_delay = d;
        self
    }

    /// Sets `config.enable_dedicated_parsers`; when true, `start` starts each
    /// connection's message processor.
    #[must_use]
    pub fn enable_dedicated_parsers(mut self, enable: bool) -> Self {
        self.config.enable_dedicated_parsers = enable;
        self
    }

    /// Sets `config.default_mode`, used by `subscribe_symbols` when no mode is given.
    #[must_use]
    pub fn default_mode(mut self, mode: Mode) -> Self {
        self.config.default_mode = mode;
        self
    }

    /// Sets `config.heartbeat_liveness_threshold`.
    #[must_use]
    pub fn heartbeat_liveness_threshold(
        mut self,
        d: std::time::Duration,
    ) -> Self {
        self.config.heartbeat_liveness_threshold = d;
        self
    }

    /// Sets `config.connection_buffer_size`.
    #[must_use]
    pub fn connection_buffer_size(mut self, sz: usize) -> Self {
        self.config.connection_buffer_size = sz;
        self
    }

    /// Sets `config.parser_buffer_size`, passed to each `MessageProcessor::new`.
    #[must_use]
    pub fn parser_buffer_size(mut self, sz: usize) -> Self {
        self.config.parser_buffer_size = sz;
        self
    }

    /// Enables or disables raw-only mode on the manager that `build` produces.
    #[must_use]
    pub fn raw_only(mut self, raw: bool) -> Self {
        self.raw_only = raw;
        self
    }

    /// Replaces the whole configuration.
    ///
    /// This overwrites every value set by earlier builder calls, so call it
    /// before any of the individual setters.
    #[must_use]
    pub fn config(mut self, config: KiteManagerConfig) -> Self {
        self.config = config;
        self
    }

    /// Consumes the builder and creates the manager. No connection is opened
    /// here; that happens in `KiteTickerManager::start`.
    #[must_use]
    pub fn build(self) -> KiteTickerManager {
        KiteTickerManager::new(self.api_key, self.access_token, self.config)
            .with_raw_only(self.raw_only)
    }
}
impl KiteTickerManager {
/// Creates an idle manager from credentials and a configuration.
///
/// All collections start empty, there is no health monitor, the round-robin
/// cursor is at 0 and raw-only mode is off. Nothing connects until `start`.
pub fn new(
    api_key: String,
    access_token: String,
    config: KiteManagerConfig,
) -> Self {
    Self {
        // Caller-supplied values.
        api_key,
        access_token,
        config,
        // Runtime state, populated by `start` / `subscribe_symbols`.
        connections: Vec::new(),
        processors: Vec::new(),
        output_channels: Vec::new(),
        symbol_mapping: HashMap::new(),
        health_monitor: None,
        // Bookkeeping.
        next_connection_index: 0,
        raw_only: false,
        start_time: Instant::now(),
    }
}
/// Consumes the manager and returns it with the raw-only flag set to `raw`.
///
/// The flag is read by `start`, which uses `connect_with_raw(.., true)` when
/// it is set and `connect` otherwise.
pub fn with_raw_only(mut self, raw: bool) -> Self {
self.raw_only = raw;
self
}
/// Opens `config.max_connections` connections, creates one message processor
/// per connection and, if `config.health_check_interval` is non-zero, starts
/// the health monitor.
///
/// # Errors
///
/// * Returns an error if the manager already holds connections. A second call
///   would append connections at indices that no longer match their
///   `ChannelId`.
/// * Returns an error if a connection index has no `ChannelId` or a connection
///   fails to connect. In that case everything brought up so far is torn down
///   again, so the manager is left empty and `start` can be retried.
pub async fn start(&mut self) -> Result<(), String> {
    if !self.connections.is_empty() {
        return Err(
            "KiteTickerManager already has connections; start() may only be called once"
                .to_string(),
        );
    }
    log::info!(
        "Starting KiteTickerManager with {} connections",
        self.config.max_connections
    );

    if let Err(err) = self.open_connections().await {
        // Roll back the partial start: stop any processors/tasks that were
        // launched and forget the half-built state. `stop` has no failure
        // path in this file, so its result carries no information here.
        let _ = self.stop().await;
        self.connections.clear();
        self.processors.clear();
        self.output_channels.clear();
        return Err(err);
    }

    // `is_zero` (rather than `as_secs() > 0`) keeps sub-second intervals from
    // silently disabling health monitoring.
    if !self.config.health_check_interval.is_zero() {
        let connection_stats: Vec<Arc<RwLock<ConnectionStats>>> = self
            .connections
            .iter()
            .map(|c| Arc::clone(&c.stats))
            .collect();
        let mut health_monitor =
            HealthMonitor::new(connection_stats, self.config.health_check_interval);
        health_monitor.start();
        self.health_monitor = Some(health_monitor);
        log::info!("Started health monitor");
    }

    log::info!(
        "KiteTickerManager started successfully with {} connections",
        self.connections.len()
    );
    Ok(())
}

/// Connects each configured connection in index order and registers it,
/// together with its processor and output receiver, on `self`.
/// Stops at the first failure, leaving earlier connections registered.
async fn open_connections(&mut self) -> Result<(), String> {
    for i in 0..self.config.max_connections {
        let channel_id = ChannelId::from_index(i)
            .ok_or_else(|| format!("Invalid connection index: {}", i))?;
        // The connection writes into this channel; the processor reads from it.
        let (connection_sender, processor_receiver) = mpsc::unbounded_channel();
        let mut connection =
            ManagedConnection::new(channel_id, connection_sender);

        let connected = if self.raw_only {
            connection
                .connect_with_raw(
                    &self.api_key,
                    &self.access_token,
                    &self.config,
                    true,
                )
                .await
        } else {
            connection
                .connect(&self.api_key, &self.access_token, &self.config)
                .await
        };
        connected
            .map_err(|e| format!("Failed to connect WebSocket {}: {}", i, e))?;

        let (mut processor, output_receiver) = MessageProcessor::new(
            channel_id,
            processor_receiver,
            self.config.parser_buffer_size,
        );
        if self.config.enable_dedicated_parsers {
            processor.start();
            log::info!("Started dedicated parser for connection {}", i);
        }

        self.connections.push(connection);
        self.processors.push(processor);
        self.output_channels.push(output_receiver);
    }
    Ok(())
}
/// Subscribes to `symbols`, spreading them over the connections round-robin.
///
/// `mode` falls back to `config.default_mode` when `None`. Symbols that are
/// already mapped, or that appear more than once in `symbols`, are skipped.
///
/// A connection with no subscriptions yet receives `subscribe_symbols`
/// followed by `start_message_processing`; otherwise `add_symbols` is used.
///
/// # Errors
///
/// * No connections exist (i.e. `start` has not succeeded).
/// * The connections cannot hold all new symbols. Nothing is subscribed in
///   that case.
/// * A connection rejects the request. Symbols on connections that already
///   succeeded stay subscribed and mapped; symbols of the failing and
///   remaining connections are not recorded in the mapping, so the call can
///   be retried for them.
pub async fn subscribe_symbols(
    &mut self,
    symbols: &[u32],
    mode: Option<Mode>,
) -> Result<(), String> {
    let mode = mode.unwrap_or(self.config.default_mode);
    log::info!(
        "Subscribing to {} symbols with mode: {:?}",
        symbols.len(),
        mode
    );

    let max_per_connection = self.config.max_symbols_per_connection;
    let mut connection_symbols: HashMap<ChannelId, Vec<u32>> = HashMap::new();
    // Symbols queued during this call; guards against duplicates in `symbols`.
    let mut queued_symbols = std::collections::HashSet::new();

    // Phase 1: plan the assignment without touching `symbol_mapping`, so an
    // error here leaves no stale entries behind.
    for &symbol in symbols {
        if self.symbol_mapping.contains_key(&symbol) || !queued_symbols.insert(symbol) {
            log::debug!("Symbol {} already subscribed", symbol);
            continue;
        }
        if self.connections.is_empty() {
            return Err(
                "No connections available; call start() before subscribing".to_string(),
            );
        }

        // `find_available_connection` only knows about symbols already on the
        // connection, so additionally account for what this call has queued.
        // NOTE(review): assumes `can_accept_symbols(n, max)` answers whether
        // `n` more symbols fit under `max` — confirm against its definition.
        let mut target = None;
        for _ in 0..self.connections.len() {
            let candidate = self.find_available_connection()?;
            let queued = connection_symbols.get(&candidate).map_or(0, Vec::len);
            let fits = self
                .connections
                .get(candidate.to_index())
                .map_or(false, |c| c.can_accept_symbols(queued + 1, max_per_connection));
            if fits {
                target = Some(candidate);
                break;
            }
        }
        let connection_id =
            target.ok_or_else(|| "All connections are at capacity".to_string())?;
        connection_symbols
            .entry(connection_id)
            .or_default()
            .push(symbol);
    }

    // Phase 2: send the requests and record each batch once it is accepted.
    let mut subscribed = 0usize;
    for (connection_id, batch) in connection_symbols {
        let connection = self
            .connections
            .get_mut(connection_id.to_index())
            .ok_or_else(|| format!("No connection for channel {:?}", connection_id))?;
        let first_subscription = connection.subscribed_symbols.is_empty();

        if first_subscription {
            connection
                .subscribe_symbols(&batch, mode)
                .await
                .map_err(|e| {
                    format!(
                        "Failed to subscribe on connection {:?}: {}",
                        connection_id, e
                    )
                })?;
        } else {
            connection.add_symbols(&batch, mode).await.map_err(|e| {
                format!(
                    "Failed to add symbols on connection {:?}: {}",
                    connection_id, e
                )
            })?;
        }

        // The connection accepted the batch: only now is it safe to map it.
        for &symbol in &batch {
            self.symbol_mapping.insert(symbol, connection_id);
        }
        subscribed += batch.len();

        if first_subscription {
            connection.start_message_processing().await.map_err(|e| {
                format!(
                    "Failed to start message processing on connection {:?}: {}",
                    connection_id, e
                )
            })?;
        }
        log::info!(
            "Subscribed {} symbols on connection {:?}",
            batch.len(),
            connection_id
        );
    }

    log::info!("Successfully subscribed to {} new symbols", subscribed);
    Ok(())
}
/// Picks, round-robin, the next connection that can take one more symbol.
///
/// The scan starts at `next_connection_index` and visits each existing
/// connection once. On success the cursor moves to the connection after the
/// chosen one.
///
/// # Errors
///
/// Returns an error when there are no connections (e.g. `start` has not run)
/// or when no connection reports room for another symbol.
fn find_available_connection(&mut self) -> Result<ChannelId, String> {
    // Iterate over the connections that actually exist rather than over
    // `config.max_connections`; the two differ before `start` has run, and
    // indexing by the configured count would then panic.
    let total = self.connections.len();
    if total == 0 {
        return Err("No connections available; call start() first".to_string());
    }

    let max_symbols = self.config.max_symbols_per_connection;
    let first = self.next_connection_index % total;
    for offset in 0..total {
        let index = (first + offset) % total;
        let connection = &self.connections[index];
        if connection.can_accept_symbols(1, max_symbols) {
            let channel_id = connection.id;
            self.next_connection_index = (index + 1) % total;
            return Ok(channel_id);
        }
    }
    Err("All connections are at capacity".to_string())
}
/// Returns a fresh receiver for the parsed-message output of `channel_id`,
/// created by resubscribing to the stored receiver.
///
/// Returns `None` when no output channel exists at that index.
pub fn get_channel(
    &mut self,
    channel_id: ChannelId,
) -> Option<broadcast::Receiver<TickerMessage>> {
    self.output_channels
        .get(channel_id.to_index())
        .map(|stored| stored.resubscribe())
}
/// Returns a fresh receiver for every output channel, paired with its
/// `ChannelId`, in index order.
///
/// Indices for which `ChannelId::from_index` yields `None` are skipped.
pub fn get_all_channels(
    &mut self,
) -> Vec<(ChannelId, broadcast::Receiver<TickerMessage>)> {
    self.output_channels
        .iter()
        .enumerate()
        .filter_map(|(index, stored)| {
            ChannelId::from_index(index).map(|id| (id, stored.resubscribe()))
        })
        .collect()
}
/// Returns a raw-frame receiver from the ticker of the connection at
/// `channel_id`, obtained through `subscribe_raw_frames`.
///
/// Returns `None` when there is no such connection or it has no ticker.
pub fn get_raw_frame_channel(
    &self,
    channel_id: ChannelId,
) -> Option<tokio::sync::broadcast::Receiver<bytes::Bytes>> {
    let managed = self.connections.get(channel_id.to_index())?;
    let ticker = managed.ticker.as_ref()?;
    Some(ticker.subscribe_raw_frames())
}
/// Returns the subscriber produced by `subscribe_full_raw` on the ticker of
/// the connection at `channel_id`.
///
/// Returns `None` when there is no such connection or it has no ticker.
pub fn get_full_raw_subscriber(
    &self,
    channel_id: ChannelId,
) -> Option<crate::KiteTickerRawSubscriber184> {
    let managed = self.connections.get(channel_id.to_index())?;
    let ticker = managed.ticker.as_ref()?;
    Some(ticker.subscribe_full_raw())
}
/// Returns a raw-frame receiver for every connection that has a ticker,
/// paired with its `ChannelId`, in index order.
///
/// Connections without a ticker, and indices without a `ChannelId`, are skipped.
pub fn get_all_raw_frame_channels(
    &self,
) -> Vec<(ChannelId, tokio::sync::broadcast::Receiver<bytes::Bytes>)> {
    let mut receivers = Vec::with_capacity(self.connections.len());
    receivers.extend(self.connections.iter().enumerate().filter_map(
        |(index, managed)| {
            let id = ChannelId::from_index(index)?;
            let ticker = managed.ticker.as_ref()?;
            Some((id, ticker.subscribe_raw_frames()))
        },
    ));
    receivers
}
/// Returns the manager statistics reported by the health monitor.
///
/// # Errors
///
/// Returns `"Health monitor not available"` when no health monitor is set.
pub async fn get_stats(&self) -> Result<ManagerStats, String> {
    match self.health_monitor.as_ref() {
        Some(monitor) => Ok(monitor.get_manager_stats().await),
        None => Err("Health monitor not available".to_string()),
    }
}
/// Returns the health summary reported by the health monitor.
///
/// # Errors
///
/// Returns `"Health monitor not available"` when no health monitor is set.
pub async fn get_health(&self) -> Result<HealthSummary, String> {
    match self.health_monitor.as_ref() {
        Some(monitor) => Ok(monitor.get_health_summary().await),
        None => Err("Health monitor not available".to_string()),
    }
}
/// Collects the statistics of every message processor, paired with the
/// processor's channel id, in processor order.
///
/// Processors are queried one after another, each `get_stats` call being
/// awaited before the next one starts.
pub async fn get_processor_stats(&self) -> Vec<(ChannelId, ProcessorStats)> {
    let mut collected = Vec::new();
    for worker in self.processors.iter() {
        let snapshot = worker.get_stats().await;
        collected.push((worker.channel_id, snapshot));
    }
    collected
}
/// Groups the currently mapped symbols by the channel they are assigned to.
///
/// Channels without any mapped symbol do not appear in the result; the order
/// of symbols within a channel follows the map's iteration order.
pub fn get_symbol_distribution(&self) -> HashMap<ChannelId, Vec<u32>> {
    self.symbol_mapping.iter().fold(
        HashMap::new(),
        |mut by_channel: HashMap<ChannelId, Vec<u32>>, (&symbol, &channel_id)| {
            by_channel.entry(channel_id).or_default().push(symbol);
            by_channel
        },
    )
}
/// Unsubscribes `symbols` from whichever connection each one is mapped to.
///
/// Symbols that are not currently mapped are skipped with a debug log;
/// duplicates in `symbols` are sent only once.
///
/// # Errors
///
/// Returns an error when a connection fails to remove its batch. The mapping
/// for that batch (and for batches not yet processed) is kept, because those
/// symbols are presumably still subscribed on the connection; batches that
/// already succeeded stay removed.
pub async fn unsubscribe_symbols(
    &mut self,
    symbols: &[u32],
) -> Result<(), String> {
    log::info!("Unsubscribing from {} symbols", symbols.len());

    let mut connection_symbols: HashMap<ChannelId, Vec<u32>> = HashMap::new();
    // The mapping is no longer mutated while grouping, so duplicates in the
    // input have to be filtered explicitly.
    let mut seen = std::collections::HashSet::new();
    for &symbol in symbols {
        match self.symbol_mapping.get(&symbol) {
            Some(&channel_id) => {
                if seen.insert(symbol) {
                    connection_symbols
                        .entry(channel_id)
                        .or_default()
                        .push(symbol);
                }
            }
            None => log::debug!("Symbol {} not found in subscriptions", symbol),
        }
    }

    let mut removed = 0usize;
    for (channel_id, batch) in connection_symbols {
        let connection = self
            .connections
            .get_mut(channel_id.to_index())
            .ok_or_else(|| format!("No connection for channel {:?}", channel_id))?;
        connection.remove_symbols(&batch).await.map_err(|e| {
            format!(
                "Failed to unsubscribe from connection {:?}: {}",
                channel_id, e
            )
        })?;

        // Forget the symbols only after the connection confirmed the removal,
        // so a failure never leaves a live subscription untracked.
        for symbol in &batch {
            self.symbol_mapping.remove(symbol);
        }
        removed += batch.len();
        log::info!(
            "Unsubscribed {} symbols from connection {:?}",
            batch.len(),
            channel_id
        );
    }

    log::info!("Successfully unsubscribed from {} symbols", removed);
    Ok(())
}
pub async fn change_mode(
&mut self,
symbols: &[u32],
mode: Mode,
) -> Result<(), String> {
log::info!("Changing mode for {} symbols to {:?}", symbols.len(), mode);
let mut connection_symbols: HashMap<ChannelId, Vec<u32>> = HashMap::new();
for &symbol in symbols {
if let Some(&channel_id) = self.symbol_mapping.get(&symbol) {
connection_symbols
.entry(channel_id)
.or_default()
.push(symbol);
} else {
log::debug!("Symbol {} not found in subscriptions", symbol);
}
}
for (channel_id, symbols) in connection_symbols {
let connection = &mut self.connections[channel_id.to_index()];
if symbols.is_empty() {
continue;
}
if let Some(ref cmd) = connection.cmd_tx {
let mode_req = crate::models::Request::mode(mode, &symbols).to_string();
let _ = cmd.send(tokio_tungstenite::tungstenite::Message::Text(
mode_req.into(),
));
for &s in &symbols {
connection.subscribed_symbols.insert(s, mode);
}
log::info!(
"Changed mode for {} symbols on connection {:?}",
symbols.len(),
channel_id
);
} else if let Some(subscriber) = &mut connection.subscriber {
subscriber.set_mode(&symbols, mode).await.map_err(|e| {
format!(
"Failed to change mode on connection {:?}: {}",
channel_id, e
)
})?;
for &s in &symbols {
connection.subscribed_symbols.insert(s, mode);
}
}
}
log::info!("Successfully changed mode for {} symbols", symbols.len());
Ok(())
}
/// Shuts the manager down: stops the health monitor (if any), stops every
/// message processor, then aborts and awaits each connection's heartbeat and
/// message tasks.
///
/// Task handles are taken out of the connections, so calling this again finds
/// nothing left to abort. Always returns `Ok(())`.
pub async fn stop(&mut self) -> Result<(), String> {
    log::info!("Stopping KiteTickerManager");

    if let Some(monitor) = self.health_monitor.as_mut() {
        monitor.stop().await;
    }

    for worker in self.processors.iter_mut() {
        worker.stop().await;
    }

    for managed in self.connections.iter_mut() {
        if let Some(heartbeat) = managed.heartbeat_handle.take() {
            heartbeat.abort();
            // Await the aborted task; its outcome is deliberately ignored.
            let _ = heartbeat.await;
        }
        if let Some(reader) = managed.task_handle.take() {
            reader.abort();
            let _ = reader.await;
        }
    }

    log::info!("KiteTickerManager stopped");
    Ok(())
}
}