1use std::collections::{BTreeSet, HashMap, VecDeque};
10
11use futures_util::{Sink, Stream};
12use tokio_tungstenite::tungstenite::Message;
13
14use crate::client::FyersClient;
15use crate::error::{FyersError, Result};
16use crate::models::ws::{
17 DataControlEvent, DataSocketConfig, DataSocketEvent, DataSubscribeRequest,
18 DataUnsubscribeRequest,
19};
20use crate::ws::data_protocol::{
21 self, ScripFeed, ack_count_from_auth_envelope, build_ack_message, build_auth_message,
22 build_channel_bitmap_message, build_channel_bitmap_message_with_marker,
23 build_subscribe_message, build_unsubscribe_message, data_type, datafeed_message_num,
24 depth_update_from_feed, extract_hsm_key, index_update_from_feed, mode, parse_datafeed,
25 parse_envelope, req_type, symbol_update_from_feed,
26};
27use crate::ws::data_symbols;
28use crate::ws::manager::{
29 LiveWebSocket, ManagedSocket, ReconnectPolicy, connect_live_socket_no_auth_header,
30};
31use crate::ws::protocol::SocketKind;
32
33const MAX_DATA_SOCKET_SYMBOLS: usize = 5000;
34const DEFAULT_CHANNEL: u8 = 11;
35const DEFAULT_SOURCE_ID: &str = concat!("fyers-rs/", env!("CARGO_PKG_VERSION"));
36
/// A [`DataSocketConnection`] running over the crate's live WebSocket transport
/// ([`LiveWebSocket`]).
///
/// This is the connection type returned by `DataSocketService::connect` and
/// `DataSocketService::connect_with_config`.
pub type LiveDataSocketConnection = DataSocketConnection<LiveWebSocket>;
39
40#[derive(Debug, Clone, Copy)]
42pub struct DataSocketService<'a> {
43 client: &'a FyersClient,
44}
45
46impl<'a> DataSocketService<'a> {
47 pub(crate) const fn new(client: &'a FyersClient) -> Self {
49 Self { client }
50 }
51
52 pub const fn client(&self) -> &'a FyersClient {
54 self.client
55 }
56
57 pub async fn connect(&self) -> Result<LiveDataSocketConnection> {
59 self.connect_with_config(DataSocketConfig::default()).await
60 }
61
62 pub async fn connect_with_config(
64 &self,
65 config: DataSocketConfig,
66 ) -> Result<LiveDataSocketConnection> {
67 let stream = connect_live_socket_no_auth_header(self.client.config(), SocketKind::Data)
68 .await?;
69 let mut connection =
70 DataSocketConnection::from_stream(stream, self.client.clone(), config)?;
71 connection.handshake().await?;
72 Ok(connection)
73 }
74
75 pub fn connect_with_stream<S>(&self, stream: S) -> Result<DataSocketConnection<S>>
81 where
82 S: Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
83 + Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
84 + Unpin,
85 {
86 DataSocketConnection::from_stream(stream, self.client.clone(), DataSocketConfig::default())
87 }
88}
89
/// A market-data socket connection: the managed transport plus the state needed to
/// decode binary HSM frames and to track subscriptions.
///
/// Get a live one from `DataSocketService::connect` / `DataSocketService::connect_with_config`,
/// or wrap any compatible stream with `DataSocketConnection::from_stream`.
#[derive(Debug)]
pub struct DataSocketConnection<S = LiveWebSocket> {
    /// Managed transport. Its parser never yields events; frames are decoded by this type.
    socket: ManagedSocket<S, DataSocketEvent>,
    /// Configuration the connection was created with (`lite_mode`, reconnect settings, ...).
    config: DataSocketConfig,
    /// Owned client, used to resolve symbols to HSM topics on (un)subscribe.
    client: FyersClient,
    /// HSM key extracted from the access token; sent in the auth frame.
    hsm_key: String,
    /// Raw access token; embedded in subscribe and unsubscribe frames.
    access_token: String,
    /// Channel number used in every frame this connection builds.
    channel_num: u8,
    /// Source identifier sent in auth, subscribe and unsubscribe frames.
    source_id: String,
    /// Subscribe requests accepted so far; feeds the symbol limit and `resubscribe_frames`.
    subscriptions: Vec<DataSubscribeRequest>,
    /// HSM topic -> symbol exactly as the caller wrote it, used to label emitted events.
    topic_to_input: HashMap<String, String>,
    /// Decoded events not yet handed out by `next_event`.
    pending_events: VecDeque<DataSocketEvent>,
    /// Market-data frames per acknowledgement, taken from the auth response; 0 disables acks.
    ack_count: u32,
    /// Market-data frames seen since the last acknowledgement was scheduled.
    update_count: u32,
    /// Message number of the most recent datafeed frame that carried one.
    last_message_num: u32,
    /// Acknowledgement (message number) to send at the start of the next `next_event` call.
    pending_ack: Option<u32>,
}
142
143impl<S> DataSocketConnection<S>
144where
145 S: Stream<Item = std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>
146 + Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
147 + Unpin,
148{
149 pub fn from_stream(stream: S, client: FyersClient, config: DataSocketConfig) -> Result<Self> {
155 let access_token = client
156 .config()
157 .access_token()
158 .ok_or(FyersError::MissingConfig {
159 field: "access_token",
160 })?
161 .expose_secret()
162 .to_owned();
163 let hsm_key = extract_hsm_key(&access_token)?;
164 let reconnect_policy = ReconnectPolicy::new(
165 config.reconnect,
166 config.reconnect_retry,
167 config.queue_process_interval.as_duration(),
168 );
169 Ok(Self {
170 socket: ManagedSocket::from_stream(
171 SocketKind::Data,
172 stream,
173 noop_parser,
174 reconnect_policy,
175 ),
176 config,
177 client,
178 hsm_key,
179 access_token,
180 channel_num: DEFAULT_CHANNEL,
181 source_id: DEFAULT_SOURCE_ID.to_owned(),
182 subscriptions: Vec::new(),
183 topic_to_input: HashMap::new(),
184 pending_events: VecDeque::new(),
185 ack_count: 0,
186 update_count: 0,
187 last_message_num: 0,
188 pending_ack: None,
189 })
190 }
191
192 pub const fn config(&self) -> &DataSocketConfig {
194 &self.config
195 }
196
197 pub const fn socket(&self) -> &ManagedSocket<S, DataSocketEvent> {
199 &self.socket
200 }
201
202 pub const fn socket_mut(&mut self) -> &mut ManagedSocket<S, DataSocketEvent> {
204 &mut self.socket
205 }
206
207 pub async fn handshake(&mut self) -> Result<()> {
214 let channel_mode = if self.config.lite_mode {
215 mode::LITE
216 } else {
217 mode::FULL
218 };
219 let auth_msg = build_auth_message(&self.hsm_key, channel_mode, &self.source_id);
220 self.send_binary(auth_msg).await?;
221
222 let mode_marker = if self.config.lite_mode {
223 mode::LITE_HEADER
224 } else {
225 mode::FULL_HEADER
226 };
227 let mode_msg = build_channel_bitmap_message_with_marker(
228 req_type::FULL_MODE,
229 self.channel_num,
230 mode_marker,
231 );
232 self.send_binary(mode_msg).await?;
233
234 let resume_msg =
235 build_channel_bitmap_message(req_type::CHANNEL_RESUME, self.channel_num);
236 self.send_binary(resume_msg).await?;
237 Ok(())
238 }
239
240 pub async fn subscribe(&mut self, request: &DataSubscribeRequest) -> Result<()> {
246 if request.symbols.is_empty() {
247 return Ok(());
248 }
249 if request.symbols.len() > MAX_DATA_SOCKET_SYMBOLS
250 || active_symbol_count_after(&self.subscriptions, request) > MAX_DATA_SOCKET_SYMBOLS
251 {
252 return Err(FyersError::Validation(
253 "data WebSocket subscriptions cannot exceed 5000 symbols".to_owned(),
254 ));
255 }
256
257 let resolved =
258 data_symbols::resolve_hsm_symbols(&self.client, &request.symbols, request.data_type)
259 .await?;
260 if !resolved.invalid.is_empty() {
261 return Err(FyersError::Validation(format!(
262 "data-socket subscribe: invalid symbols {:?}",
263 resolved.invalid
264 )));
265 }
266 if resolved.resolved.is_empty() {
267 return Err(FyersError::Validation(
268 "data-socket subscribe: symbol-token API returned no usable HSM tokens"
269 .to_owned(),
270 ));
271 }
272
273 for r in &resolved.resolved {
274 self.topic_to_input
275 .insert(r.hsm_topic.clone(), r.input_symbol.clone());
276 }
277
278 let topics = resolved.hsm_topics();
279 let frame = build_subscribe_message(
280 &topics,
281 self.channel_num,
282 &self.access_token,
283 &self.source_id,
284 );
285 self.send_binary(frame).await?;
286
287 if !self.subscriptions.contains(request) {
288 self.subscriptions.push(request.clone());
289 }
290 Ok(())
291 }
292
293 pub async fn unsubscribe(&mut self, request: &DataUnsubscribeRequest) -> Result<()> {
295 if request.symbols.is_empty() {
296 return Ok(());
297 }
298 let resolved =
299 data_symbols::resolve_hsm_symbols(&self.client, &request.symbols, request.data_type)
300 .await?;
301 let topics = resolved.hsm_topics();
302 if topics.is_empty() {
303 return Ok(());
304 }
305 let frame = build_unsubscribe_message(
306 &topics,
307 self.channel_num,
308 &self.access_token,
309 &self.source_id,
310 );
311 self.send_binary(frame).await?;
312
313 for topic in &topics {
314 self.topic_to_input.remove(topic);
315 }
316 self.subscriptions.retain(|existing| existing != request);
317 Ok(())
318 }
319
320 pub fn resubscribe_frames(&self) -> Result<Vec<String>> {
327 self.subscriptions
328 .iter()
329 .map(serde_json::to_string)
330 .collect::<std::result::Result<Vec<_>, _>>()
331 .map_err(FyersError::from)
332 }
333
334 pub async fn next_event(&mut self) -> Result<Option<DataSocketEvent>> {
336 loop {
337 if let Some(message_num) = self.pending_ack.take() {
338 let ack = build_ack_message(message_num);
339 self.send_binary(ack).await?;
340 }
341 if let Some(event) = self.pending_events.pop_front() {
342 return Ok(Some(event));
343 }
344 let Some(message) = self.socket.next_raw_frame().await? else {
345 return Ok(None);
346 };
347 match message {
348 Message::Binary(bytes) => {
349 self.handle_binary_frame(&bytes)?;
350 }
351 Message::Text(text) => {
352 return Err(FyersError::Validation(format!(
353 "data socket received unexpected text frame ({} bytes)",
354 text.len()
355 )));
356 }
357 _ => {}
358 }
359 }
360 }
361
362 pub const fn ack_count(&self) -> u32 {
365 self.ack_count
366 }
367
368 pub async fn close(&mut self) -> Result<()> {
370 self.socket.close().await
371 }
372
373 async fn send_binary(&mut self, bytes: Vec<u8>) -> Result<()> {
374 self.socket.send_binary(bytes).await
375 }
376
377 fn handle_binary_frame(&mut self, bytes: &[u8]) -> Result<()> {
378 if bytes.len() < 4 {
379 return Ok(());
380 }
381 let req = bytes[2];
382 match req {
383 req_type::DATAFEED => {
384 if let Some(num) = datafeed_message_num(bytes) {
385 self.last_message_num = num;
386 }
387 let feeds = parse_datafeed(bytes)?;
388 let saw_market_payload = feeds.iter().any(|f| {
389 matches!(
390 f.data_type,
391 data_type::SNAPSHOT | data_type::UPDATE | data_type::LITE
392 )
393 });
394 for feed in &feeds {
395 if let Some(event) = self.feed_to_event(feed) {
396 self.pending_events.push_back(event);
397 }
398 }
399 if saw_market_payload && self.ack_count > 0 {
400 self.update_count = self.update_count.saturating_add(1);
401 if self.update_count >= self.ack_count {
402 self.pending_ack = Some(self.last_message_num);
403 self.update_count = 0;
404 }
405 }
406 }
407 req_type::CHANNEL_BUFFER => {
408 }
410 req_type::AUTH => {
411 let env = parse_envelope(bytes)?;
412 if let Some(count) = ack_count_from_auth_envelope(&env) {
413 self.ack_count = count;
414 }
415 self.pending_events
416 .push_back(DataSocketEvent::Connected(envelope_to_control(
417 &env, "cn", "Authentication done",
418 )));
419 }
420 req_type::SUBSCRIBE => {
421 let env = parse_envelope(bytes)?;
422 self.pending_events
423 .push_back(DataSocketEvent::Subscribed(envelope_to_control(
424 &env, "sub", "Subscribed",
425 )));
426 }
427 req_type::UNSUBSCRIBE => {
428 let env = parse_envelope(bytes)?;
429 self.pending_events
430 .push_back(DataSocketEvent::Unsubscribed(envelope_to_control(
431 &env, "unsub", "Unsubscribed",
432 )));
433 }
434 req_type::FULL_MODE | req_type::CHANNEL_RESUME | req_type::CHANNEL_PAUSE => {
435 let env = parse_envelope(bytes)?;
436 let event_type = match req {
437 req_type::FULL_MODE => "ful",
438 req_type::CHANNEL_RESUME => "cr",
439 req_type::CHANNEL_PAUSE => "cp",
440 _ => unreachable!(),
441 };
442 self.pending_events
443 .push_back(DataSocketEvent::Mode(envelope_to_control(
444 &env, event_type, "Mode",
445 )));
446 }
447 _ => {
448 let env = parse_envelope(bytes).unwrap_or(data_protocol::Envelope {
449 req_type: req,
450 fields: Vec::new(),
451 });
452 self.pending_events
453 .push_back(DataSocketEvent::Error(envelope_to_control(
454 &env,
455 "error",
456 &format!("unhandled data-socket frame type 0x{req:02x}"),
457 )));
458 }
459 }
460 Ok(())
461 }
462
463 fn feed_to_event(&self, feed: &ScripFeed<'_>) -> Option<DataSocketEvent> {
464 if !matches!(
465 feed.data_type,
466 data_type::SNAPSHOT | data_type::UPDATE | data_type::LITE
467 ) {
468 return None;
469 }
470 let user_symbol = self
471 .topic_to_input
472 .get(feed.topic_name)
473 .cloned()
474 .unwrap_or_else(|| feed.topic_name.to_owned());
475
476 match topic_kind(feed.topic_name) {
477 TopicKind::Index => {
478 let mut event = index_update_from_feed(feed);
479 event.symbol = user_symbol;
480 Some(DataSocketEvent::IndexUpdate(event))
481 }
482 TopicKind::Depth => {
483 let mut event = depth_update_from_feed(feed);
484 event.symbol = user_symbol;
485 Some(DataSocketEvent::DepthUpdate(event))
486 }
487 TopicKind::Symbol => {
488 let mut event = symbol_update_from_feed(feed);
489 event.symbol = user_symbol;
490 Some(DataSocketEvent::SymbolUpdate(event))
491 }
492 TopicKind::Other => None,
493 }
494 }
495}
496
/// Feed family of an HSM topic, determined by the text before its first `|`.
#[derive(Debug, Clone, Copy)]
enum TopicKind {
    /// `sf|…` topics, decoded as symbol updates.
    Symbol,
    /// `if|…` topics, decoded as index updates.
    Index,
    /// `dp|…` topics, decoded as depth updates.
    Depth,
    /// Any other prefix; such feeds produce no event.
    Other,
}

/// Classifies `topic` by its prefix; a topic without `|` is its own prefix.
fn topic_kind(topic: &str) -> TopicKind {
    let prefix = topic.split_once('|').map_or(topic, |(head, _)| head);
    match prefix {
        "sf" => TopicKind::Symbol,
        "if" => TopicKind::Index,
        "dp" => TopicKind::Depth,
        _ => TopicKind::Other,
    }
}
513
514fn envelope_to_control(env: &data_protocol::Envelope<'_>, kind: &str, default_msg: &str) -> DataControlEvent {
515 let s = env.status_text().unwrap_or("ok").to_owned();
516 let code = if s == "K" { 200 } else { -1 };
517 DataControlEvent {
518 event_type: kind.to_owned(),
519 code,
520 message: if s == "K" { default_msg.to_owned() } else { s.clone() },
521 s: if s == "K" { "ok".to_owned() } else { "error".to_owned() },
522 }
523}
524
525fn noop_parser(_message: Message) -> Result<Option<DataSocketEvent>> {
526 Ok(None)
527}
528
529fn active_symbol_count_after(
530 subscriptions: &[DataSubscribeRequest],
531 request: &DataSubscribeRequest,
532) -> usize {
533 subscriptions
534 .iter()
535 .flat_map(|subscription| subscription.symbols.iter())
536 .chain(request.symbols.iter())
537 .map(String::as_str)
538 .collect::<BTreeSet<_>>()
539 .len()
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FyersClient;

    /// Builds a client whose access token is an unsigned JWT whose payload carries
    /// `hsm_key = "deadbeef"`.
    fn dummy_client() -> FyersClient {
        // base64url of `{"alg":"none"}`
        const HEADER_B64: &str = "eyJhbGciOiJub25lIn0";
        // base64url of the claims object holding `sub` and `hsm_key`.
        const PAYLOAD_B64: &str = "eyJzdWIiOiJhY2Nlc3NfdG9rZW4iLCJoc21fa2V5IjoiZGVhZGJlZWYifQ";
        let jwt = [HEADER_B64, PAYLOAD_B64, "sig"].join(".");
        FyersClient::builder()
            .client_id("APPID-100")
            .access_token(jwt)
            .build()
            .unwrap()
    }

    #[test]
    fn topic_kind_classifier_matches_documented_prefixes() {
        let symbol = topic_kind("sf|nse_cm|3045");
        let index = topic_kind("if|nse_cm|26000");
        let depth = topic_kind("dp|nse_cm|3045");
        let other = topic_kind("nope");
        assert!(matches!(symbol, TopicKind::Symbol));
        assert!(matches!(index, TopicKind::Index));
        assert!(matches!(depth, TopicKind::Depth));
        assert!(matches!(other, TopicKind::Other));
    }

    #[test]
    fn envelope_to_control_maps_k_status_to_ok() {
        let status_ok = data_protocol::EnvelopeField {
            id: 1,
            value: b"K",
        };
        let envelope = data_protocol::Envelope {
            req_type: req_type::SUBSCRIBE,
            fields: vec![status_ok],
        };
        let event = envelope_to_control(&envelope, "sub", "Subscribed");
        assert_eq!(event.code, 200);
        assert_eq!(event.s, "ok");
        assert_eq!(event.message, "Subscribed");
    }

    #[test]
    fn from_stream_extracts_hsm_key_from_client_token() {
        let client = dummy_client();
        let raw_token = client
            .config()
            .access_token()
            .unwrap()
            .expose_secret()
            .to_owned();
        assert_eq!(extract_hsm_key(&raw_token).unwrap(), "deadbeef");
    }
}