1use crate::models::{Mode, Request, TextMessage, Tick, TickMessage, TickerMessage};
2use crate::parser::packet_length;
3use byteorder::{BigEndian, ByteOrder};
4use futures_util::{SinkExt, StreamExt};
5use serde_json::json;
6use std::collections::HashMap;
7use tokio::sync::{broadcast, mpsc};
8use tokio::task::JoinHandle;
9use tokio_tungstenite::{connect_async, tungstenite::Message};
10
11#[derive(Debug)]
12pub struct KiteTickerAsync {
16 #[allow(dead_code)]
17 api_key: String,
18 #[allow(dead_code)]
19 access_token: String,
20 cmd_tx: Option<mpsc::UnboundedSender<Message>>,
21 msg_tx: broadcast::Sender<TickerMessage>,
22 writer_handle: Option<JoinHandle<()>>,
23 reader_handle: Option<JoinHandle<()>>,
24}
25
26impl KiteTickerAsync {
27 pub async fn connect(
29 api_key: &str,
30 access_token: &str,
31 ) -> Result<Self, String> {
32 let socket_url = format!(
33 "wss://{}?api_key={}&access_token={}",
34 "ws.kite.trade", api_key, access_token
35 );
36 let url = url::Url::parse(socket_url.as_str()).unwrap();
37
38 let (ws_stream, _) = connect_async(url).await.map_err(|e| e.to_string())?;
39
40 let (write_half, mut read_half) = ws_stream.split();
41
42 let (cmd_tx, mut cmd_rx) = mpsc::unbounded_channel::<Message>();
43 let (msg_tx, _) = broadcast::channel(1000);
45 let mut write = write_half;
46 let writer_handle = tokio::spawn(async move {
47 while let Some(msg) = cmd_rx.recv().await {
48 if write.send(msg).await.is_err() {
49 break;
50 }
51 }
52 });
53
54 let msg_sender = msg_tx.clone();
55 let reader_handle = tokio::spawn(async move {
56 while let Some(message) = read_half.next().await {
57 match message {
58 Ok(msg) => {
59 if let Some(processed_msg) = process_message(msg) {
61 let _ = msg_sender.send(processed_msg);
64 }
65 }
66 Err(e) => {
67 let error_msg = TickerMessage::Error(format!("WebSocket error: {}", e));
69 let _ = msg_sender.send(error_msg);
70
71 if matches!(e, tokio_tungstenite::tungstenite::Error::ConnectionClosed |
73 tokio_tungstenite::tungstenite::Error::AlreadyClosed) {
74 break;
75 }
76 }
77 }
78 }
79 });
80
81 Ok(KiteTickerAsync {
82 api_key: api_key.to_string(),
83 access_token: access_token.to_string(),
84 cmd_tx: Some(cmd_tx),
85 msg_tx,
86 writer_handle: Some(writer_handle),
87 reader_handle: Some(reader_handle),
88 })
89 }
90
91 pub async fn subscribe(
93 mut self,
94 instrument_tokens: &[u32],
95 mode: Option<Mode>,
96 ) -> Result<KiteTickerSubscriber, String> {
97 self
98 .subscribe_cmd(instrument_tokens, mode.as_ref())
99 .await
100 .expect("failed to subscribe");
101 let default_mode = mode.unwrap_or_default();
102 let st = instrument_tokens
103 .iter()
104 .map(|&t| (t, default_mode.clone()))
105 .collect();
106
107 let rx = self.msg_tx.subscribe();
108 Ok(KiteTickerSubscriber {
109 ticker: self,
110 subscribed_tokens: st,
111 rx,
112 })
113 }
114
115 pub async fn close(&mut self) -> Result<(), String> {
117 if let Some(tx) = self.cmd_tx.take() {
118 let _ = tx.send(Message::Close(None));
119 }
120 if let Some(handle) = self.writer_handle.take() {
121 let _ = handle.await.map_err(|e| e.to_string())?;
122 }
123 if let Some(handle) = self.reader_handle.take() {
124 let _ = handle.await.map_err(|e| e.to_string())?;
125 }
126 Ok(())
127 }
128
129 async fn subscribe_cmd(
130 &mut self,
131 instrument_tokens: &[u32],
132 mode: Option<&Mode>,
133 ) -> Result<(), String> {
134 let mode_value = mode.cloned().unwrap_or_default();
135 let msgs = vec![
136 Message::Text(Request::subscribe(instrument_tokens.to_vec()).to_string()),
137 Message::Text(
138 Request::mode(mode_value, instrument_tokens.to_vec())
139 .to_string(),
140 ),
141 ];
142
143 for msg in msgs {
144 if let Some(tx) = &self.cmd_tx {
145 tx.send(msg).map_err(|e| e.to_string())?;
146 }
147 }
148
149 Ok(())
150 }
151
152 async fn unsubscribe_cmd(
153 &mut self,
154 instrument_tokens: &[u32],
155 ) -> Result<(), String> {
156 if let Some(tx) = &self.cmd_tx {
157 tx.send(Message::Text(
158 Request::unsubscribe(instrument_tokens.to_vec()).to_string(),
159 ))
160 .map_err(|e| e.to_string())?;
161 }
162 Ok(())
163 }
164
165 async fn set_mode_cmd(
166 &mut self,
167 instrument_tokens: &[u32],
168 mode: Mode,
169 ) -> Result<(), String> {
170 if let Some(tx) = &self.cmd_tx {
171 tx.send(Message::Text(
172 Request::mode(mode, instrument_tokens.to_vec()).to_string(),
173 ))
174 .map_err(|e| e.to_string())?;
175 }
176 Ok(())
177 }
178
179 pub fn is_connected(&self) -> bool {
181 self.cmd_tx.is_some() &&
182 self.writer_handle.as_ref().map_or(false, |h| !h.is_finished()) &&
183 self.reader_handle.as_ref().map_or(false, |h| !h.is_finished())
184 }
185
186 pub async fn ping(&mut self) -> Result<(), String> {
188 if let Some(tx) = &self.cmd_tx {
189 tx.send(Message::Ping(vec![])).map_err(|e| e.to_string())?;
190 Ok(())
191 } else {
192 Err("Connection is closed".to_string())
193 }
194 }
195
196 pub fn receiver_count(&self) -> usize {
198 self.msg_tx.receiver_count()
199 }
200
201 pub fn channel_capacity(&self) -> usize {
203 1000 }
207}
208
209#[derive(Debug)]
210pub struct KiteTickerSubscriber {
214 ticker: KiteTickerAsync,
215 subscribed_tokens: HashMap<u32, Mode>,
216 rx: broadcast::Receiver<TickerMessage>,
217}
218
219impl KiteTickerSubscriber {
220 pub fn get_subscribed(&self) -> Vec<u32> {
222 self
223 .subscribed_tokens
224 .clone()
225 .into_keys()
226 .collect::<Vec<_>>()
227 }
228
229 fn get_subscribed_or(&self, tokens: &[u32]) -> Vec<u32> {
232 if tokens.len() == 0 {
233 self.get_subscribed()
234 } else {
235 tokens
236 .iter()
237 .filter(|t| self.subscribed_tokens.contains_key(t))
238 .map(|t| t.clone())
239 .collect::<Vec<_>>()
240 }
241 }
242
243 pub async fn subscribe(
245 &mut self,
246 tokens: &[u32],
247 mode: Option<Mode>,
248 ) -> Result<(), String> {
249 self.subscribed_tokens.extend(
250 tokens
251 .iter()
252 .map(|t| (t.clone(), mode.clone().unwrap_or_default())),
253 );
254 let tks = self.get_subscribed();
255 self.ticker.subscribe_cmd(tks.as_slice(), None).await?;
256 Ok(())
257 }
258
259 pub async fn set_mode(
261 &mut self,
262 instrument_tokens: &[u32],
263 mode: Mode,
264 ) -> Result<(), String> {
265 let tokens = self.get_subscribed_or(instrument_tokens);
266 self.ticker.set_mode_cmd(tokens.as_slice(), mode).await
267 }
268
269 pub async fn unsubscribe(
273 &mut self,
274 instrument_tokens: &[u32],
275 ) -> Result<(), String> {
276 let tokens = self.get_subscribed_or(instrument_tokens);
277 match self.ticker.unsubscribe_cmd(tokens.as_slice()).await {
278 Ok(_) => {
279 self.subscribed_tokens.retain(|k, _| !tokens.contains(k));
280 Ok(())
281 }
282 Err(e) => Err(e),
283 }
284 }
285
286 pub async fn next_message(
289 &mut self,
290 ) -> Result<Option<TickerMessage>, String> {
291 match self.rx.recv().await {
292 Ok(msg) => Ok(Some(msg)),
293 Err(broadcast::error::RecvError::Closed) => Ok(None),
294 Err(e) => Err(e.to_string()),
295 }
296 }
297
298 pub async fn close(&mut self) -> Result<(), String> {
299 self.ticker.close().await
300 }
301}
302
303fn process_message(message: Message) -> Option<TickerMessage> {
304 match message {
305 Message::Text(text_message) => process_text_message(text_message),
306 Message::Binary(ref binary_message) => {
307 if binary_message.len() < 2 {
308 return Some(TickerMessage::Ticks(vec![]));
309 } else {
310 process_binary(binary_message.as_slice())
311 }
312 }
313 Message::Close(closing_message) => closing_message.map(|c| {
314 TickerMessage::ClosingMessage(json!({
315 "code": c.code.to_string(),
316 "reason": c.reason.to_string()
317 }))
318 }),
319 Message::Ping(_) => None, Message::Pong(_) => None, Message::Frame(_) => None, }
323}
324
325fn process_binary(binary_message: &[u8]) -> Option<TickerMessage> {
326 if binary_message.len() < 2 {
327 return None;
328 }
329 let num_packets = BigEndian::read_u16(&binary_message[0..2]) as usize;
330 if num_packets > 0 {
331 let mut start = 2;
332 let mut ticks = Vec::with_capacity(num_packets);
333 for _ in 0..num_packets {
334 if start + 2 > binary_message.len() {
335 return Some(TickerMessage::Error("Invalid packet structure".to_string()));
336 }
337 let packet_len = packet_length(&binary_message[start..start + 2]);
338 let next_start = start + 2 + packet_len;
339 if next_start > binary_message.len() {
340 return Some(TickerMessage::Error("Packet length exceeds message size".to_string()));
341 }
342 match Tick::try_from(&binary_message[start + 2..next_start]) {
343 Ok(tick) => ticks.push(TickMessage::new(tick.instrument_token, tick)),
344 Err(e) => return Some(TickerMessage::Error(e.to_string())),
345 }
346 start = next_start;
347 }
348 Some(TickerMessage::Ticks(ticks))
349 } else {
350 None
351 }
352}
353
354fn process_text_message(text_message: String) -> Option<TickerMessage> {
355 serde_json::from_str::<TextMessage>(&text_message)
356 .map(|x| x.into())
357 .ok()
358}
359
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use base64::{engine::general_purpose, Engine};

    use crate::{DepthItem, Mode, Tick, OHLC};

    /// Reads `kiteconnect-mocks/<name>.packet` and base64-decodes its trimmed
    /// contents. Panics with the offending path on failure.
    fn load_packet(name: &str) -> Vec<u8> {
        let path = format!("kiteconnect-mocks/{}.packet", name);
        let encoded = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("could not read {}: {}", path, e));
        general_purpose::STANDARD
            .decode(encoded.trim())
            .unwrap_or_else(|e| panic!("could not decode {}: {}", path, e))
    }

    /// Converts an RFC 3339 timestamp into a duration since the Unix epoch.
    fn epoch(rfc3339: &str) -> Duration {
        let parsed = chrono::DateTime::parse_from_rfc3339(rfc3339)
            .expect("fixture timestamp must be valid RFC 3339");
        Duration::from_secs(parsed.timestamp() as u64)
    }

    /// Returns `(description, raw packet, expected tick)` fixtures.
    fn setup() -> Vec<(&'static str, Vec<u8>, Tick)> {
        vec![
            (
                "quote packet",
                load_packet("ticker_quote"),
                Tick {
                    mode: Mode::Quote,
                    exchange: crate::Exchange::NSE,
                    instrument_token: 408065,
                    is_tradable: true,
                    is_index: false,
                    last_traded_timestamp: None,
                    exchange_timestamp: None,
                    last_price: Some(1573.15),
                    avg_traded_price: Some(1570.33),
                    last_traded_qty: Some(1),
                    total_buy_qty: Some(256511),
                    total_sell_qty: Some(360503),
                    volume_traded: Some(1175986),
                    ohlc: Some(OHLC {
                        open: 1569.15,
                        high: 1575.0,
                        low: 1561.05,
                        close: 1567.8,
                    }),
                    oi_day_high: None,
                    oi_day_low: None,
                    oi: None,
                    net_change: None,
                    depth: None,
                },
            ),
            (
                "full packet",
                load_packet("ticker_full"),
                Tick {
                    mode: Mode::Full,
                    exchange: crate::Exchange::NSE,
                    instrument_token: 408065,
                    is_tradable: true,
                    is_index: false,
                    last_traded_timestamp: Some(epoch("2021-07-05T10:41:27+05:30")),
                    exchange_timestamp: Some(epoch("2021-07-05T10:41:27+05:30")),
                    last_price: Some(1573.7),
                    avg_traded_price: Some(1570.37),
                    last_traded_qty: Some(7),
                    total_buy_qty: Some(256443),
                    total_sell_qty: Some(363009),
                    volume_traded: Some(1192471),
                    ohlc: Some(OHLC {
                        open: 1569.15,
                        high: 1575.0,
                        low: 1561.05,
                        close: 1567.8,
                    }),
                    oi_day_high: Some(0),
                    oi_day_low: Some(0),
                    oi: Some(0),
                    net_change: Some(5.900000000000091),
                    depth: Some(crate::Depth {
                        buy: [
                            DepthItem {
                                qty: 5,
                                price: 1573.4,
                                orders: 1,
                            },
                            DepthItem {
                                qty: 140,
                                price: 1573.0,
                                orders: 2,
                            },
                            DepthItem {
                                qty: 2,
                                price: 1572.95,
                                orders: 1,
                            },
                            DepthItem {
                                qty: 219,
                                price: 1572.9,
                                orders: 7,
                            },
                            DepthItem {
                                qty: 50,
                                price: 1572.85,
                                orders: 1,
                            },
                        ],
                        sell: [
                            DepthItem {
                                qty: 172,
                                price: 1573.7,
                                orders: 3,
                            },
                            DepthItem {
                                qty: 44,
                                price: 1573.75,
                                orders: 3,
                            },
                            DepthItem {
                                qty: 302,
                                price: 1573.85,
                                orders: 3,
                            },
                            DepthItem {
                                qty: 141,
                                price: 1573.9,
                                orders: 2,
                            },
                            DepthItem {
                                qty: 724,
                                price: 1573.95,
                                orders: 5,
                            },
                        ],
                    }),
                },
            ),
        ]
    }

    #[test]
    fn test_quotes() {
        for (name, packet, expected) in setup() {
            let tick = Tick::try_from(packet.as_slice())
                .unwrap_or_else(|e| panic!("Testing {}: failed to parse packet: {:?}", name, e));
            assert_eq!(tick, expected, "Testing {}", name);
        }
    }
}