1use futures_util::{stream::iter, SinkExt, StreamExt};
2use serde_json::json;
3use std::{collections::HashMap, sync::Arc};
4use tokio::net::TcpStream;
5use tokio::sync::Mutex;
6use tokio_tungstenite::{
7 connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
8};
9
10use crate::models::{
11 packet_length, Mode, Request, TextMessage, Tick, TickMessage, TickerMessage,
12};
13
/// Asynchronous client for the Kite ticker WebSocket.
///
/// The WebSocket stream lives behind `Arc<tokio::sync::Mutex<_>>`, so clones of a
/// `KiteTickerAsync` share a single connection and take turns using it.
#[derive(Debug, Clone)]
pub struct KiteTickerAsync {
    // The credentials are stored but not read anywhere in this module after
    // `connect`; `dead_code` is allowed so keeping them does not warn.
    #[allow(dead_code)]
    api_key: String,
    #[allow(dead_code)]
    access_token: String,
    // Shared, lock-protected connection used both for sending commands and for
    // reading incoming messages.
    ws_stream: Arc<Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
}
25
26impl KiteTickerAsync {
27 pub async fn connect(
29 api_key: &str,
30 access_token: &str,
31 ) -> Result<Self, String> {
32 let socket_url = format!(
33 "wss://{}?api_key={}&access_token={}",
34 "ws.kite.trade", api_key, access_token
35 );
36 let url = url::Url::parse(socket_url.as_str()).unwrap();
37
38 let (ws_stream, _) = connect_async(url).await.map_err(|e| e.to_string())?;
39
40 Ok(KiteTickerAsync {
41 api_key: api_key.to_string(),
42 access_token: access_token.to_string(),
43 ws_stream: Arc::new(Mutex::new(ws_stream)),
44 })
45 }
46
47 pub async fn subscribe(
49 mut self,
50 instrument_tokens: &[u32],
51 mode: Option<Mode>,
52 ) -> Result<KiteTickerSubscriber, String> {
53 self
54 .subscribe_cmd(instrument_tokens, mode.clone())
55 .await
56 .expect("failed to subscribe");
57 let st = instrument_tokens
58 .to_vec()
59 .iter()
60 .map(|t| (t.clone(), mode.to_owned().unwrap_or_default()))
61 .collect();
62
63 Ok(KiteTickerSubscriber {
64 ticker: self,
65 subscribed_tokens: st,
66 })
67 }
68
69 pub async fn close(&mut self) -> Result<(), String> {
71 let mut ws_stream = self.ws_stream.lock().await;
72 ws_stream.close(None).await.map_err(|x| x.to_string())?;
73 Ok(())
74 }
75
76 async fn subscribe_cmd(
77 &mut self,
78 instrument_tokens: &[u32],
79 mode: Option<Mode>,
80 ) -> Result<(), String> {
81 let mut msgs = iter(vec![
82 Ok(Message::Text(
83 Request::subscribe(instrument_tokens.to_vec()).to_string(),
84 )),
85 Ok(Message::Text(
86 Request::mode(mode.unwrap_or_default(), instrument_tokens.to_vec())
87 .to_string(),
88 )),
89 ]);
90
91 let mut ws_stream = self.ws_stream.lock().await;
92
93 ws_stream
94 .send_all(msgs.by_ref())
95 .await
96 .expect("failed to send subscription message");
97
98 Ok(())
99 }
100
101 async fn unsubscribe_cmd(
102 &mut self,
103 instrument_tokens: &[u32],
104 ) -> Result<(), String> {
105 let mut ws_stream = self.ws_stream.lock().await;
106 ws_stream
107 .send(Message::Text(
108 Request::unsubscribe(instrument_tokens.to_vec()).to_string(),
109 ))
110 .await
111 .expect("failed to send unsubscribe message");
112 Ok(())
113 }
114
115 async fn set_mode_cmd(
116 &mut self,
117 instrument_tokens: &[u32],
118 mode: Mode,
119 ) -> Result<(), String> {
120 let mut ws_stream = self.ws_stream.lock().await;
121 ws_stream
122 .send(Message::Text(
123 Request::mode(mode, instrument_tokens.to_vec()).to_string(),
124 ))
125 .await
126 .expect("failed to send set mode message");
127 Ok(())
128 }
129}
130
/// A ticker connection together with local bookkeeping of the instrument tokens
/// subscribed through it and the mode recorded for each.
#[derive(Debug, Clone)]
pub struct KiteTickerSubscriber {
    // Connection used both for sending commands and for reading messages.
    ticker: KiteTickerAsync,
    // Local record: instrument token -> mode recorded for that token.
    subscribed_tokens: HashMap<u32, Mode>,
}
139
140impl KiteTickerSubscriber {
141 pub fn get_subscribed(&self) -> Vec<u32> {
143 self
144 .subscribed_tokens
145 .clone()
146 .into_keys()
147 .collect::<Vec<_>>()
148 }
149
150 fn get_subscribed_or(&self, tokens: &[u32]) -> Vec<u32> {
153 if tokens.len() == 0 {
154 self.get_subscribed()
155 } else {
156 tokens
157 .iter()
158 .filter(|t| self.subscribed_tokens.contains_key(t))
159 .map(|t| t.clone())
160 .collect::<Vec<_>>()
161 }
162 }
163
164 pub async fn subscribe(
166 &mut self,
167 tokens: &[u32],
168 mode: Option<Mode>,
169 ) -> Result<(), String> {
170 self.subscribed_tokens.extend(
171 tokens
172 .iter()
173 .map(|t| (t.clone(), mode.clone().unwrap_or_default())),
174 );
175 let tks = self.get_subscribed();
176 self.ticker.subscribe_cmd(tks.as_slice(), None).await?;
177 Ok(())
178 }
179
180 pub async fn set_mode(
182 &mut self,
183 instrument_tokens: &[u32],
184 mode: Mode,
185 ) -> Result<(), String> {
186 let tokens = self.get_subscribed_or(instrument_tokens);
187 self.ticker.set_mode_cmd(tokens.as_slice(), mode).await
188 }
189
190 pub async fn unsubscribe(
194 &mut self,
195 instrument_tokens: &[u32],
196 ) -> Result<(), String> {
197 let tokens = self.get_subscribed_or(instrument_tokens);
198 self.ticker.unsubscribe_cmd(tokens.as_slice()).await
199 }
200
201 pub async fn next_message(&mut self) -> Result<Option<TickerMessage>, String> {
204 let mut ws_stream = self.ticker.ws_stream.lock().await;
205 match ws_stream.next().await {
206 Some(message) => match message {
207 Ok(msg) => Ok(self.process_message(msg)),
208 Err(e) => Err(e.to_string()),
209 },
210 None => Ok(None),
211 }
212 }
213
214 fn process_message(&self, message: Message) -> Option<TickerMessage> {
215 match message {
216 Message::Text(text_message) => self.process_text_message(text_message),
217 Message::Binary(ref binary_message) => {
218 if binary_message.len() < 2 {
219 return Some(TickerMessage::Ticks(vec![]));
220 } else {
221 self.process_binary(binary_message.as_slice())
222 }
223 }
224 Message::Close(closing_message) => closing_message.map(|c| {
225 TickerMessage::ClosingMessage(json!({
226 "code": c.code.to_string(),
227 "reason": c.reason.to_string()
228 }))
229 }),
230 Message::Ping(_) => unimplemented!(),
231 Message::Pong(_) => unimplemented!(),
232 Message::Frame(_) => unimplemented!(),
233 }
234 }
235
236 fn process_binary(&self, binary_message: &[u8]) -> Option<TickerMessage> {
237 let num_packets =
239 i16::from_be_bytes(binary_message[0..=1].try_into().unwrap()) as usize;
240 if num_packets > 0 {
241 Some(TickerMessage::Ticks(
242 (0..num_packets)
243 .into_iter()
244 .fold((vec![], 2), |(mut acc, start), _| {
245 let packet_len = packet_length(&binary_message[start..start + 2]);
247 let next_start = start + 2 + packet_len;
248 let tick = Tick::from(&binary_message[start + 2..next_start]);
249 acc.push(TickMessage::new(tick.instrument_token, tick));
250 (acc, next_start)
251 })
252 .0,
253 ))
254 } else {
255 None
256 }
257 }
258
259 fn process_text_message(
260 &self,
261 text_message: String,
262 ) -> Option<TickerMessage> {
263 serde_json::from_str::<TextMessage>(&text_message)
264 .map(|x| x.into())
265 .ok()
266 }
267}
268
#[cfg(test)]
mod tests {
    //! Integration tests against the live Kite ticker. They need network access and
    //! the `KITE_API_KEY` / `KITE_ACCESS_TOKEN` environment variables.
    use super::*;
    use tokio::select;

    /// Reads messages until a non-empty tick batch arrives, then checks it.
    ///
    /// When `assertions` is `Some`, it receives the batch. Otherwise the first tick
    /// must belong to `token` and be in `mode`. Panics if reading fails.
    async fn check<F>(
        mode: Mode,
        token: u32,
        sb: &mut KiteTickerSubscriber,
        assertions: Option<F>,
    ) where
        F: Fn(Vec<TickMessage>),
    {
        loop {
            match sb.next_message().await {
                Ok(Some(TickerMessage::Ticks(xs))) => {
                    // Empty batches (short binary frames) carry no ticks; keep waiting.
                    if xs.is_empty() {
                        continue;
                    }
                    match &assertions {
                        Some(f) => f(xs),
                        None => {
                            let tick_message = xs.first().unwrap();
                            assert!(tick_message.instrument_token == token);
                            assert_eq!(tick_message.content.mode, mode);
                        }
                    }
                    break;
                }
                Ok(_) => continue,
                Err(e) => panic!("failed to read ticker message: {}", e),
            }
        }
    }

    #[tokio::test]
    async fn test_ticker() {
        let ticker = create_ticker().await;
        let token = 94977;
        let mode = Mode::Full;
        let mut sb = ticker
            .subscribe(&[token], Some(mode.clone()))
            .await
            .expect("failed to subscribe");
        assert_eq!(sb.subscribed_tokens.len(), 1);

        // Read a few messages. Every non-empty batch must be a single tick for
        // `token` in `mode`. A `None` read after the budget is used up fails the test.
        let mut loop_cnt = 0;
        loop {
            loop_cnt += 1;
            select! {
                Ok(n) = sb.next_message() => {
                    match n {
                        Some(TickerMessage::Ticks(xs)) if !xs.is_empty() => {
                            assert_eq!(xs.len(), 1);
                            let tick_message = xs.first().unwrap();
                            assert!(tick_message.instrument_token == token);
                            assert_eq!(tick_message.content.mode, mode);
                        }
                        Some(_) => {}
                        None => {
                            assert!(loop_cnt <= 5, "no ticker message after {} reads", loop_cnt);
                        }
                    }
                    if loop_cnt > 5 {
                        break;
                    }
                },
                else => {
                    panic!("failed to read ticker message")
                }
            }
        }

        sb.ticker.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let ticker = create_ticker().await;
        let token = 94977;
        let mode = Mode::Full;
        let mut sb = ticker
            .subscribe(&[token], Some(mode))
            .await
            .unwrap();

        // After the first tick we unsubscribe from everything. From then on only
        // empty batches should arrive; receiving real ticks too often fails.
        let mut loop_cnt = 0;
        loop {
            match sb.next_message().await {
                Ok(Some(TickerMessage::Ticks(xs))) => {
                    if xs.is_empty() {
                        if loop_cnt > 4 {
                            break;
                        }
                        loop_cnt += 1;
                        continue;
                    }
                    assert_eq!(xs.len(), 1);
                    let tick_message = xs.first().unwrap();
                    assert!(tick_message.instrument_token == token);
                    sb.unsubscribe(&[]).await.unwrap();
                    loop_cnt += 1;
                    assert!(loop_cnt <= 5, "still receiving ticks after unsubscribing");
                }
                Ok(_) => continue,
                Err(e) => panic!("failed to read ticker message: {}", e),
            }
        }
        sb.ticker.close().await.unwrap();
    }

    /// Connects using credentials from `KITE_API_KEY` / `KITE_ACCESS_TOKEN`.
    async fn create_ticker() -> KiteTickerAsync {
        let api_key = std::env::var("KITE_API_KEY").unwrap();
        let access_token = std::env::var("KITE_ACCESS_TOKEN").unwrap();
        KiteTickerAsync::connect(&api_key, &access_token)
            .await
            .expect("failed to create ticker")
    }

    #[tokio::test]
    async fn test_set_mode() {
        let ticker = create_ticker().await;
        let token = 94977;
        let mode = Mode::LTP;
        let new_mode = Mode::Quote;
        let mut sb = ticker
            .subscribe(&[token], Some(mode.clone()))
            .await
            .unwrap();

        let f1: Option<Box<dyn Fn(Vec<TickMessage>)>> = None;
        let f2: Option<Box<dyn Fn(Vec<TickMessage>)>> = None;
        check(mode, token, &mut sb, f1).await;
        sb.set_mode(&[], new_mode.clone()).await.unwrap();
        check(new_mode, token, &mut sb, f2).await;

        sb.ticker.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_new_sub() {
        let ticker = create_ticker().await;
        let token = 94977;
        let mut sb = ticker
            .subscribe(&[token], Some(Mode::LTP))
            .await
            .unwrap();
        // Moving the subscriber into a spawned task checks that it is `Send`.
        tokio::spawn(async move {
            sb.subscribe(&[2953217], None).await.unwrap();
            assert_eq!(sb.get_subscribed().len(), 2);
            sb.ticker.close().await.unwrap();
        })
        .await
        .unwrap();
    }
}