// kiteticker_async/ticker.rs

1use futures_util::{stream::iter, SinkExt, StreamExt};
2use serde_json::json;
3use std::{collections::HashMap, sync::Arc};
4use tokio::net::TcpStream;
5use tokio::sync::Mutex;
6use tokio_tungstenite::{
7  connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
8};
9
10use crate::models::{
11  packet_length, Mode, Request, TextMessage, Tick, TickMessage, TickerMessage,
12};
13
14#[derive(Debug, Clone)]
15///
16/// The WebSocket client for connecting to Kite Connect's streaming quotes service.
17///
18pub struct KiteTickerAsync {
19  #[allow(dead_code)]
20  api_key: String,
21  #[allow(dead_code)]
22  access_token: String,
23  ws_stream: Arc<Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
24}
25
26impl KiteTickerAsync {
27  /// Establish a connection with the Kite WebSocket server
28  pub async fn connect(
29    api_key: &str,
30    access_token: &str,
31  ) -> Result<Self, String> {
32    let socket_url = format!(
33      "wss://{}?api_key={}&access_token={}",
34      "ws.kite.trade", api_key, access_token
35    );
36    let url = url::Url::parse(socket_url.as_str()).unwrap();
37
38    let (ws_stream, _) = connect_async(url).await.map_err(|e| e.to_string())?;
39
40    Ok(KiteTickerAsync {
41      api_key: api_key.to_string(),
42      access_token: access_token.to_string(),
43      ws_stream: Arc::new(Mutex::new(ws_stream)),
44    })
45  }
46
47  /// Subscribes the client to a list of instruments
48  pub async fn subscribe(
49    mut self,
50    instrument_tokens: &[u32],
51    mode: Option<Mode>,
52  ) -> Result<KiteTickerSubscriber, String> {
53    self
54      .subscribe_cmd(instrument_tokens, mode.clone())
55      .await
56      .expect("failed to subscribe");
57    let st = instrument_tokens
58      .to_vec()
59      .iter()
60      .map(|t| (t.clone(), mode.to_owned().unwrap_or_default()))
61      .collect();
62
63    Ok(KiteTickerSubscriber {
64      ticker: self,
65      subscribed_tokens: st,
66    })
67  }
68
69  /// Close the websocket connection
70  pub async fn close(&mut self) -> Result<(), String> {
71    let mut ws_stream = self.ws_stream.lock().await;
72    ws_stream.close(None).await.map_err(|x| x.to_string())?;
73    Ok(())
74  }
75
76  async fn subscribe_cmd(
77    &mut self,
78    instrument_tokens: &[u32],
79    mode: Option<Mode>,
80  ) -> Result<(), String> {
81    let mut msgs = iter(vec![
82      Ok(Message::Text(
83        Request::subscribe(instrument_tokens.to_vec()).to_string(),
84      )),
85      Ok(Message::Text(
86        Request::mode(mode.unwrap_or_default(), instrument_tokens.to_vec())
87          .to_string(),
88      )),
89    ]);
90
91    let mut ws_stream = self.ws_stream.lock().await;
92
93    ws_stream
94      .send_all(msgs.by_ref())
95      .await
96      .expect("failed to send subscription message");
97
98    Ok(())
99  }
100
101  async fn unsubscribe_cmd(
102    &mut self,
103    instrument_tokens: &[u32],
104  ) -> Result<(), String> {
105    let mut ws_stream = self.ws_stream.lock().await;
106    ws_stream
107      .send(Message::Text(
108        Request::unsubscribe(instrument_tokens.to_vec()).to_string(),
109      ))
110      .await
111      .expect("failed to send unsubscribe message");
112    Ok(())
113  }
114
115  async fn set_mode_cmd(
116    &mut self,
117    instrument_tokens: &[u32],
118    mode: Mode,
119  ) -> Result<(), String> {
120    let mut ws_stream = self.ws_stream.lock().await;
121    ws_stream
122      .send(Message::Text(
123        Request::mode(mode, instrument_tokens.to_vec()).to_string(),
124      ))
125      .await
126      .expect("failed to send set mode message");
127    Ok(())
128  }
129}
130
131#[derive(Debug, Clone)]
132///
133/// The Websocket client that entered in a pub/sub mode once the client subscribed to a list of instruments
134///
135pub struct KiteTickerSubscriber {
136  ticker: KiteTickerAsync,
137  subscribed_tokens: HashMap<u32, Mode>,
138}
139
140impl KiteTickerSubscriber {
141  /// Get the list of subscribed instruments
142  pub fn get_subscribed(&self) -> Vec<u32> {
143    self
144      .subscribed_tokens
145      .clone()
146      .into_keys()
147      .collect::<Vec<_>>()
148  }
149
150  /// get all tokens common between subscribed tokens and input tokens
151  /// and if the input is empty then all subscribed tokens will be unsubscribed
152  fn get_subscribed_or(&self, tokens: &[u32]) -> Vec<u32> {
153    if tokens.len() == 0 {
154      self.get_subscribed()
155    } else {
156      tokens
157        .iter()
158        .filter(|t| self.subscribed_tokens.contains_key(t))
159        .map(|t| t.clone())
160        .collect::<Vec<_>>()
161    }
162  }
163
164  /// Subscribe to new tokens
165  pub async fn subscribe(
166    &mut self,
167    tokens: &[u32],
168    mode: Option<Mode>,
169  ) -> Result<(), String> {
170    self.subscribed_tokens.extend(
171      tokens
172        .iter()
173        .map(|t| (t.clone(), mode.clone().unwrap_or_default())),
174    );
175    let tks = self.get_subscribed();
176    self.ticker.subscribe_cmd(tks.as_slice(), None).await?;
177    Ok(())
178  }
179
180  /// Change the mode of the subscribed instrument tokens
181  pub async fn set_mode(
182    &mut self,
183    instrument_tokens: &[u32],
184    mode: Mode,
185  ) -> Result<(), String> {
186    let tokens = self.get_subscribed_or(instrument_tokens);
187    self.ticker.set_mode_cmd(tokens.as_slice(), mode).await
188  }
189
190  /// Unsubscribe provided subscribed tokens, if input is empty then all subscribed tokens will unsubscribed
191  ///
192  /// Tokens in the input which are not part of the subscribed tokens will be ignored.
193  pub async fn unsubscribe(
194    &mut self,
195    instrument_tokens: &[u32],
196  ) -> Result<(), String> {
197    let tokens = self.get_subscribed_or(instrument_tokens);
198    self.ticker.unsubscribe_cmd(tokens.as_slice()).await
199  }
200
201  /// Get the next message from the server, waiting if necessary.
202  /// If the result is None then server is terminated
203  pub async fn next_message(&mut self) -> Result<Option<TickerMessage>, String> {
204    let mut ws_stream = self.ticker.ws_stream.lock().await;
205    match ws_stream.next().await {
206      Some(message) => match message {
207        Ok(msg) => Ok(self.process_message(msg)),
208        Err(e) => Err(e.to_string()),
209      },
210      None => Ok(None),
211    }
212  }
213
214  fn process_message(&self, message: Message) -> Option<TickerMessage> {
215    match message {
216      Message::Text(text_message) => self.process_text_message(text_message),
217      Message::Binary(ref binary_message) => {
218        if binary_message.len() < 2 {
219          return Some(TickerMessage::Ticks(vec![]));
220        } else {
221          self.process_binary(binary_message.as_slice())
222        }
223      }
224      Message::Close(closing_message) => closing_message.map(|c| {
225        TickerMessage::ClosingMessage(json!({
226          "code": c.code.to_string(),
227          "reason": c.reason.to_string()
228        }))
229      }),
230      Message::Ping(_) => unimplemented!(),
231      Message::Pong(_) => unimplemented!(),
232      Message::Frame(_) => unimplemented!(),
233    }
234  }
235
236  fn process_binary(&self, binary_message: &[u8]) -> Option<TickerMessage> {
237    // 0 - 2 : number of packets in the message
238    let num_packets =
239      i16::from_be_bytes(binary_message[0..=1].try_into().unwrap()) as usize;
240    if num_packets > 0 {
241      Some(TickerMessage::Ticks(
242        (0..num_packets)
243          .into_iter()
244          .fold((vec![], 2), |(mut acc, start), _| {
245            // start - start + 2 : length of the packet
246            let packet_len = packet_length(&binary_message[start..start + 2]);
247            let next_start = start + 2 + packet_len;
248            let tick = Tick::from(&binary_message[start + 2..next_start]);
249            acc.push(TickMessage::new(tick.instrument_token, tick));
250            (acc, next_start)
251          })
252          .0,
253      ))
254    } else {
255      None
256    }
257  }
258
259  fn process_text_message(
260    &self,
261    text_message: String,
262  ) -> Option<TickerMessage> {
263    serde_json::from_str::<TextMessage>(&text_message)
264      .map(|x| x.into())
265      .ok()
266  }
267}
268
#[cfg(test)]
mod tests {
  use super::*;

  /// Instrument token for BATA, used by the live tests below.
  const BATA_TOKEN: u32 = 94977;

  /// Connects a ticker using the `KITE_API_KEY` / `KITE_ACCESS_TOKEN`
  /// environment variables.
  ///
  /// These tests talk to the live Kite servers, so both variables must hold
  /// valid credentials.
  async fn create_ticker() -> KiteTickerAsync {
    let api_key =
      std::env::var("KITE_API_KEY").expect("KITE_API_KEY is not set");
    let access_token =
      std::env::var("KITE_ACCESS_TOKEN").expect("KITE_ACCESS_TOKEN is not set");
    KiteTickerAsync::connect(&api_key, &access_token)
      .await
      .expect("failed to create ticker")
  }

  /// Waits for the first non-empty tick batch and checks it.
  ///
  /// With `assertions == None`, the first tick must belong to `token` and be
  /// in `mode`; otherwise the custom assertions receive the whole batch.
  /// Panics if reading from the socket fails.
  async fn check<F>(
    mode: Mode,
    token: u32,
    sb: &mut KiteTickerSubscriber,
    assertions: Option<F>,
  ) where
    F: Fn(Vec<TickMessage>),
  {
    loop {
      match sb.next_message().await {
        Ok(Some(TickerMessage::Ticks(xs))) if !xs.is_empty() => {
          match assertions.as_ref() {
            Some(f) => f(xs),
            None => {
              let tick_message = &xs[0];
              assert_eq!(tick_message.instrument_token, token);
              assert_eq!(tick_message.content.mode, mode);
            }
          }
          break;
        }
        Ok(_) => continue,
        Err(e) => panic!("failed to read from the ticker: {}", e),
      }
    }
  }

  #[tokio::test]
  async fn test_ticker() {
    let ticker = create_ticker().await;
    let token = BATA_TOKEN;
    let mode = Mode::Full;
    let mut sb = ticker
      .subscribe(&[token], Some(mode.clone()))
      .await
      .expect("failed to subscribe");
    assert_eq!(sb.subscribed_tokens.len(), 1);

    // Read a fixed number of messages, validating every non-empty tick batch.
    const MESSAGES_TO_READ: u32 = 6;
    for i in 1..=MESSAGES_TO_READ {
      let message = sb
        .next_message()
        .await
        .expect("failed to read from the ticker");
      match message {
        Some(TickerMessage::Ticks(xs)) if !xs.is_empty() => {
          assert_eq!(xs.len(), 1);
          assert_eq!(xs[0].instrument_token, token);
          assert_eq!(xs[0].content.mode, mode);
        }
        Some(_) => {}
        None => assert!(
          i < MESSAGES_TO_READ,
          "ticker returned no message on the final read"
        ),
      }
    }

    sb.ticker.close().await.unwrap();
  }

  #[tokio::test]
  async fn test_unsubscribe() {
    let ticker = create_ticker().await;
    let token = BATA_TOKEN;
    let mut sb = ticker
      .subscribe(&[token], Some(Mode::Full))
      .await
      .unwrap();

    // Unsubscribe on the first tick batch. The test passes once more than four
    // empty tick batches have been counted, and fails if non-empty batches keep
    // arriving for too long.
    let mut loop_cnt = 0;
    loop {
      match sb.next_message().await {
        Ok(Some(TickerMessage::Ticks(xs))) if xs.is_empty() => {
          if loop_cnt > 4 {
            break;
          }
          loop_cnt += 1;
        }
        Ok(Some(TickerMessage::Ticks(xs))) => {
          assert_eq!(xs.len(), 1);
          assert_eq!(xs[0].instrument_token, token);
          sb.unsubscribe(&[]).await.unwrap();
          loop_cnt += 1;
          assert!(loop_cnt <= 5, "still receiving ticks after unsubscribing");
        }
        Ok(_) => continue,
        Err(e) => panic!("failed to read from the ticker: {}", e),
      }
    }
    sb.ticker.close().await.unwrap();
  }

  #[tokio::test]
  async fn test_set_mode() {
    let ticker = create_ticker().await;
    let token = BATA_TOKEN;
    let mode = Mode::LTP;
    let new_mode = Mode::Quote;
    let mut sb = ticker
      .subscribe(&[token], Some(mode.clone()))
      .await
      .unwrap();

    check(mode, token, &mut sb, None::<fn(Vec<TickMessage>)>).await;
    sb.set_mode(&[], new_mode.clone()).await.unwrap();
    check(new_mode, token, &mut sb, None::<fn(Vec<TickMessage>)>).await;

    sb.ticker.close().await.unwrap();
  }

  #[tokio::test]
  async fn test_new_sub() {
    let ticker = create_ticker().await;
    let new_token = 2953217;
    let mut sb = ticker
      .subscribe(&[BATA_TOKEN], Some(Mode::LTP))
      .await
      .unwrap();
    // Running inside `tokio::spawn` also checks that the subscriber can be
    // moved to another task.
    tokio::spawn(async move {
      sb.subscribe(&[new_token], None).await.unwrap();
      let subscribed = sb.get_subscribed();
      assert!(subscribed.contains(&BATA_TOKEN));
      assert!(subscribed.contains(&new_token));
      sb.ticker.close().await.unwrap();
    })
    .await
    .unwrap();
  }
}
469}