// kiteticker_async_manager/manager/connection_pool.rs

use crate::manager::{ChannelId, ConnectionStats, KiteManagerConfig};
use crate::models::{Mode, TickerMessage};
use crate::ticker::KiteTickerAsync;
use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, RwLock};
use tokio::task::JoinHandle;
use tokio::time::timeout;

12/// Represents a single WebSocket connection with its metadata
13#[derive(Debug)]
14pub struct ManagedConnection {
15  pub id: ChannelId,
16  pub ticker: Option<KiteTickerAsync>,
17  pub subscriber: Option<crate::ticker::KiteTickerSubscriber>,
18  pub subscribed_symbols: HashMap<u32, Mode>,
19  pub stats: Arc<RwLock<ConnectionStats>>,
20  pub is_healthy: Arc<AtomicBool>,
21  pub last_ping: Arc<AtomicU64>, // Unix timestamp
22  pub task_handle: Option<JoinHandle<()>>,
23  // Background watcher to update last_ping on any inbound frame (including heartbeats)
24  pub heartbeat_handle: Option<JoinHandle<()>>,
25  pub message_sender: mpsc::UnboundedSender<TickerMessage>,
26  // Store credentials for dynamic operations
27  api_key: String,
28  access_token: String,
29  pub(crate) cmd_tx:
30    Option<mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Message>>,
31  // Liveness threshold for heartbeats/frames
32  heartbeat_liveness_threshold: Duration,
33}
34
35impl ManagedConnection {
36  pub fn new(
37    id: ChannelId,
38    message_sender: mpsc::UnboundedSender<TickerMessage>,
39  ) -> Self {
40    let stats = ConnectionStats {
41      connection_id: id.to_index(),
42      ..Default::default()
43    };
44
45    Self {
46      id,
47      ticker: None,
48      subscriber: None,
49      subscribed_symbols: HashMap::new(),
50      stats: Arc::new(RwLock::new(stats)),
51      is_healthy: Arc::new(AtomicBool::new(false)),
52      last_ping: Arc::new(AtomicU64::new(0)),
53      task_handle: None,
54      heartbeat_handle: None,
55      message_sender,
56      api_key: String::new(),
57      access_token: String::new(),
58      cmd_tx: None,
59      heartbeat_liveness_threshold: Duration::from_secs(10),
60    }
61  }
62
63  /// Connect to WebSocket and start message processing
64  pub async fn connect(
65    &mut self,
66    api_key: &str,
67    access_token: &str,
68    config: &KiteManagerConfig,
69  ) -> Result<(), String> {
70    // Store credentials for dynamic operations
71    self.api_key = api_key.to_string();
72    self.access_token = access_token.to_string();
73
74    // Connect to WebSocket
75    let ticker = timeout(
76      config.connection_timeout,
77      KiteTickerAsync::connect(api_key, access_token),
78    )
79    .await
80    .map_err(|_| "Connection timeout".to_string())?
81    .map_err(|e| format!("Connection failed: {}", e))?;
82
83    self.cmd_tx = ticker.command_sender();
84    // Initialize last_ping to now and start heartbeat watcher
85    let now_sec = std::time::SystemTime::now()
86      .duration_since(std::time::UNIX_EPOCH)
87      .unwrap_or_default()
88      .as_secs();
89    self.last_ping.store(now_sec, Ordering::Relaxed);
90    self.ticker = Some(ticker);
91    self.start_heartbeat_watcher();
92    // Set configured liveness threshold
93    self.heartbeat_liveness_threshold = config.heartbeat_liveness_threshold;
94    self.is_healthy.store(true, Ordering::Relaxed);
95
96    // Update stats
97    {
98      let mut stats = self.stats.write().await;
99      stats.is_connected = true;
100      stats.connection_uptime = Duration::ZERO;
101    }
102
103    Ok(())
104  }
105
106  /// Start a background watcher that listens to raw frames and updates `last_ping`.
107  fn start_heartbeat_watcher(&mut self) {
108    // Drop existing watcher if any
109    if let Some(h) = self.heartbeat_handle.take() {
110      h.abort();
111    }
112    let Some(ticker) = self.ticker.as_ref() else {
113      return;
114    };
115    let mut rx = ticker.subscribe_raw_frames();
116    let last_ping = Arc::clone(&self.last_ping);
117    let id = self.id;
118    let handle = tokio::spawn(async move {
119      loop {
120        match rx.recv().await {
121          Ok(_frame) => {
122            let now = std::time::SystemTime::now()
123              .duration_since(std::time::UNIX_EPOCH)
124              .unwrap_or_default()
125              .as_secs();
126            last_ping.store(now, Ordering::Relaxed);
127          }
128          Err(tokio::sync::broadcast::error::RecvError::Closed) => {
129            log::debug!(
130              "Heartbeat watcher closed for connection {}",
131              id.to_index()
132            );
133            break;
134          }
135          Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
136            // If lagging, just continue; we'll get a fresher frame soon
137            continue;
138          }
139        }
140      }
141    });
142    self.heartbeat_handle = Some(handle);
143  }
144
145  /// Connect with explicit raw_only flag
146  pub async fn connect_with_raw(
147    &mut self,
148    api_key: &str,
149    access_token: &str,
150    config: &KiteManagerConfig,
151    raw_only: bool,
152  ) -> Result<(), String> {
153    self.api_key = api_key.to_string();
154    self.access_token = access_token.to_string();
155    let ticker = tokio::time::timeout(
156      config.connection_timeout,
157      crate::ticker::KiteTickerAsync::connect_with_options(
158        api_key,
159        access_token,
160        raw_only,
161      ),
162    )
163    .await
164    .map_err(|_| "Connection timeout".to_string())?
165    .map_err(|e| format!("Connection failed: {}", e))?;
166
167    self.cmd_tx = ticker.command_sender();
168    // Initialize last_ping to now and start heartbeat watcher
169    let now_sec = std::time::SystemTime::now()
170      .duration_since(std::time::UNIX_EPOCH)
171      .unwrap_or_default()
172      .as_secs();
173    self
174      .last_ping
175      .store(now_sec, std::sync::atomic::Ordering::Relaxed);
176    self.ticker = Some(ticker);
177    self.start_heartbeat_watcher();
178    self
179      .is_healthy
180      .store(true, std::sync::atomic::Ordering::Relaxed);
181    // Set configured liveness threshold
182    self.heartbeat_liveness_threshold = config.heartbeat_liveness_threshold;
183    {
184      let mut stats = self.stats.write().await;
185      stats.is_connected = true;
186      stats.connection_uptime = Duration::ZERO;
187    }
188    Ok(())
189  }
190
191  /// Subscribe to symbols on this connection
192  pub async fn subscribe_symbols(
193    &mut self,
194    symbols: &[u32],
195    mode: Mode,
196  ) -> Result<(), String> {
197    if let Some(ticker) = self.ticker.as_mut() {
198      // Use the existing ticker directly
199      let subscriber = ticker.subscribe(symbols, Some(mode)).await?;
200      // Track symbols
201      for &symbol in symbols {
202        self.subscribed_symbols.insert(symbol, mode);
203      }
204      self.subscriber = Some(subscriber);
205
206      // Update stats
207      {
208        let mut stats = self.stats.write().await;
209        stats.symbol_count = self.subscribed_symbols.len();
210      }
211
212      Ok(())
213    } else {
214      Err("Connection not established".to_string())
215    }
216  }
217
218  /// Dynamically add new symbols to existing subscription
219  pub async fn add_symbols(
220    &mut self,
221    symbols: &[u32],
222    mode: Mode,
223  ) -> Result<(), String> {
224    if self.subscriber.is_some() {
225      // Filter to truly new symbols
226      let new: Vec<u32> = symbols
227        .iter()
228        .copied()
229        .filter(|s| !self.subscribed_symbols.contains_key(s))
230        .collect();
231      if new.is_empty() {
232        return Ok(());
233      }
234      if let Some(tx) = &self.cmd_tx {
235        // send subscribe + mode
236        let sub = crate::models::Request::subscribe(&new).to_string();
237        let mode_msg = crate::models::Request::mode(mode, &new).to_string();
238        let _ =
239          tx.send(tokio_tungstenite::tungstenite::Message::Text(sub.into()));
240        let _ = tx.send(tokio_tungstenite::tungstenite::Message::Text(
241          mode_msg.into(),
242        ));
243      }
244      for &s in &new {
245        self.subscribed_symbols.insert(s, mode);
246      }
247      let mut stats = self.stats.write().await;
248      stats.symbol_count = self.subscribed_symbols.len();
249      log::info!(
250        "Incrementally subscribed {} symbols on connection {}",
251        new.len(),
252        self.id.to_index()
253      );
254      Ok(())
255    } else {
256      self.subscribe_symbols(symbols, mode).await
257    }
258  }
259
260  /// Dynamically remove symbols from existing subscription
261  pub async fn remove_symbols(
262    &mut self,
263    symbols: &[u32],
264  ) -> Result<(), String> {
265    if self.subscriber.is_some() {
266      // Only symbols currently subscribed
267      let existing: Vec<u32> = symbols
268        .iter()
269        .copied()
270        .filter(|s| self.subscribed_symbols.contains_key(s))
271        .collect();
272      if existing.is_empty() {
273        return Ok(());
274      }
275      if let Some(tx) = &self.cmd_tx {
276        let unsub = crate::models::Request::unsubscribe(&existing).to_string();
277        let _ =
278          tx.send(tokio_tungstenite::tungstenite::Message::Text(unsub.into()));
279      }
280      for s in &existing {
281        self.subscribed_symbols.remove(s);
282      }
283      let mut stats = self.stats.write().await;
284      stats.symbol_count = self.subscribed_symbols.len();
285      log::info!(
286        "Incrementally unsubscribed {} symbols on connection {}",
287        existing.len(),
288        self.id.to_index()
289      );
290      Ok(())
291    } else {
292      Err("No active subscription to remove symbols from".to_string())
293    }
294  }
295
296  /// Start message processing for the subscriber
297  pub async fn start_message_processing(&mut self) -> Result<(), String> {
298    if let Some(subscriber) = self.subscriber.take() {
299      let message_sender = self.message_sender.clone();
300      let stats = Arc::clone(&self.stats);
301      let is_healthy = Arc::clone(&self.is_healthy);
302      let last_ping = Arc::clone(&self.last_ping);
303      let connection_id = self.id;
304      let threshold = self.heartbeat_liveness_threshold;
305
306      let handle = tokio::spawn(async move {
307        Self::message_processing_loop(
308          subscriber,
309          message_sender,
310          stats,
311          is_healthy,
312          connection_id,
313          last_ping,
314          threshold,
315        )
316        .await;
317      });
318
319      self.task_handle = Some(handle);
320      Ok(())
321    } else {
322      Err("No subscriber available for message processing".to_string())
323    }
324  }
325
326  /// Message processing loop for this connection
327  async fn message_processing_loop(
328    mut subscriber: crate::ticker::KiteTickerSubscriber,
329    message_sender: mpsc::UnboundedSender<TickerMessage>,
330    stats: Arc<RwLock<ConnectionStats>>,
331    is_healthy: Arc<AtomicBool>,
332    connection_id: ChannelId,
333    last_ping: Arc<AtomicU64>,
334    heartbeat_threshold: Duration,
335  ) {
336    let mut last_message_time = Instant::now();
337    let mut last_stats_flush = Instant::now();
338    let mut pending_messages: u64 = 0;
339
340    log::info!(
341      "Starting message processing loop for connection {}",
342      connection_id.to_index()
343    );
344
345    loop {
346      match timeout(Duration::from_secs(30), subscriber.next_message()).await {
347        Ok(Ok(Some(message))) => {
348          last_message_time = Instant::now();
349
350          // Debug: Print incoming message
351          if log::log_enabled!(log::Level::Debug) {
352            match &message {
353              TickerMessage::Ticks(ticks) => {
354                log::debug!(
355                  "Connection {}: Received {} ticks",
356                  connection_id.to_index(),
357                  ticks.len()
358                );
359                for (i, tick) in ticks.iter().take(3).enumerate() {
360                  log::debug!(
361                    "  Tick {}: Symbol {}, Mode {:?}, LTP {:?}",
362                    i + 1,
363                    tick.instrument_token,
364                    tick.content.mode,
365                    tick.content.last_price
366                  );
367                }
368              }
369              TickerMessage::Error(err) => {
370                log::debug!(
371                  "Connection {}: Received error message: {}",
372                  connection_id.to_index(),
373                  err
374                );
375              }
376              _ => {
377                log::debug!(
378                  "Connection {}: Received other message: {:?}",
379                  connection_id.to_index(),
380                  message
381                );
382              }
383            }
384          }
385
386          // Update stats
387          pending_messages += 1;
388          if last_stats_flush.elapsed() >= Duration::from_millis(1000) {
389            let mut stats = stats.write().await;
390            stats.messages_received += pending_messages;
391            stats.last_message_time = Some(last_message_time);
392            pending_messages = 0;
393            last_stats_flush = Instant::now();
394          }
395
396          // Forward message to parser (non-blocking)
397          if message_sender.send(message).is_err() {
398            log::warn!(
399              "Connection {}: Parser channel full, dropping message",
400              connection_id.to_index()
401            );
402
403            // Update error stats
404            let mut stats = stats.write().await;
405            stats.errors_count += 1;
406          }
407        }
408        Ok(Ok(None)) => {
409          log::info!("Connection {} closed", connection_id.to_index());
410          is_healthy.store(false, Ordering::Relaxed);
411          break;
412        }
413        Ok(Err(e)) => {
414          log::error!("Connection {} error: {}", connection_id.to_index(), e);
415
416          // Update error stats
417          if last_stats_flush.elapsed() >= Duration::from_millis(250) {
418            let mut stats = stats.write().await;
419            stats.errors_count += 1;
420            last_stats_flush = Instant::now();
421          }
422
423          // Continue trying to receive messages
424        }
425        Err(_) => {
426          // Timeout waiting for parsed messages; consult heartbeat/frames
427          let now_sec = std::time::SystemTime::now()
428            .duration_since(std::time::UNIX_EPOCH)
429            .unwrap_or_default()
430            .as_secs();
431          let last = last_ping.load(Ordering::Relaxed);
432          // If we've seen any frame within threshold, consider connection alive
433          if last > 0
434            && now_sec.saturating_sub(last) <= heartbeat_threshold.as_secs()
435          {
436            continue;
437          }
438          // Fallback to parsed message timer if heartbeat missed
439          if last_message_time.elapsed() > heartbeat_threshold {
440            log::warn!(
441              "Connection {} timeout - no frames/heartbeats within {:?}",
442              connection_id.to_index(),
443              heartbeat_threshold,
444            );
445            is_healthy.store(false, Ordering::Relaxed);
446            break;
447          }
448        }
449      }
450    }
451
452    // Update connection status
453    {
454      let mut stats = stats.write().await;
455      stats.is_connected = false;
456    }
457    is_healthy.store(false, Ordering::Relaxed);
458  }
459
460  /// Check if connection can accept more symbols
461  pub fn can_accept_symbols(
462    &self,
463    count: usize,
464    max_per_connection: usize,
465  ) -> bool {
466    self.subscribed_symbols.len() + count <= max_per_connection
467  }
468
469  /// Get current symbol count
470  pub fn symbol_count(&self) -> usize {
471    self.subscribed_symbols.len()
472  }
473
474  /// Check if connection is healthy
475  pub fn is_healthy(&self) -> bool {
476    self.is_healthy.load(Ordering::Relaxed)
477  }
478}