//! Subscription management for the Kalshi WebSocket client
//! (`kalshi_rust/websocket/subscription.rs`).

use super::{Channel, CommandResponse, KalshiWebSocket};
use crate::kalshi_error::KalshiError;
use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone)]
7pub struct Subscription {
8 pub sid: i32,
9 pub channel: Channel,
10 pub market_tickers: Vec<String>,
11}
12
13#[derive(Debug, Clone, Deserialize, Serialize)]
14pub struct SubscribeResponse {
15 pub sid: i32,
16 pub channel: String,
17}
18
19#[derive(Debug, Serialize)]
20#[serde(rename_all = "snake_case")]
21pub enum UpdateAction {
22 AddMarkets,
23 DeleteMarkets,
24}
25
26fn parse_channel(channel_str: &str) -> Option<Channel> {
28 match channel_str {
29 "orderbook_delta" => Some(Channel::OrderbookDelta),
30 "ticker" => Some(Channel::Ticker),
31 "trade" => Some(Channel::Trade),
32 "fill" => Some(Channel::Fill),
33 "market_position" => Some(Channel::MarketPosition),
34 "market_lifecycle_v2" => Some(Channel::MarketLifecycleV2),
35 "event_lifecycle" => Some(Channel::EventLifecycle),
36 "multivariate" => Some(Channel::Multivariate),
37 "communications" => Some(Channel::Communications),
38 _ => None,
39 }
40}
41
42impl KalshiWebSocket {
43 pub async fn subscribe(
59 &mut self,
60 channels: Vec<Channel>,
61 market_ticker: Option<String>,
62 market_tickers: Option<Vec<String>>,
63 ) -> Result<Vec<SubscribeResponse>, KalshiError> {
64 if channels.is_empty() {
65 return Ok(vec![]);
66 }
67
68 let id = self.get_next_id();
69 let expected_responses = channels.len();
70
71 let tickers: Vec<String> = match (&market_ticker, &market_tickers) {
73 (Some(ticker), _) => vec![ticker.clone()],
74 (_, Some(tickers)) => tickers.clone(),
75 (None, None) => vec![],
76 };
77
78 let mut receivers = Vec::with_capacity(expected_responses);
80 for i in 0..expected_responses {
81 let response_id = id + i as i32;
82 let rx = self.register_pending_command(response_id);
83 receivers.push((response_id, rx));
84 }
85
86 let mut cmd = serde_json::json!({
88 "id": id,
89 "cmd": "subscribe",
90 "params": {
91 "channels": channels.iter().map(|c| c.to_string()).collect::<Vec<_>>()
92 }
93 });
94
95 if let Some(ticker) = market_ticker {
96 cmd["params"]["market_ticker"] = serde_json::Value::String(ticker);
97 }
98 if let Some(tickers_list) = market_tickers {
99 cmd["params"]["market_tickers"] = serde_json::Value::Array(
100 tickers_list
101 .into_iter()
102 .map(serde_json::Value::String)
103 .collect(),
104 );
105 }
106
107 self.send_command(cmd).await?;
108
109 let responses = self
111 .wait_for_responses(receivers, expected_responses)
112 .await?;
113
114 let mut result = Vec::with_capacity(responses.len());
116 for response in responses {
117 match response {
118 CommandResponse::Subscribed { sid, channel } => {
119 if let Some(channel_enum) = parse_channel(&channel) {
121 self.subscriptions.insert(
122 sid,
123 Subscription {
124 sid,
125 channel: channel_enum,
126 market_tickers: tickers.clone(),
127 },
128 );
129 }
130 result.push(SubscribeResponse { sid, channel });
131 }
132 CommandResponse::Error { code, msg } => {
133 return Err(KalshiError::InternalError(format!(
134 "Subscribe failed with code {}: {}",
135 code, msg
136 )));
137 }
138 CommandResponse::Ok { .. } => {
139 }
141 }
142 }
143
144 Ok(result)
145 }
146
147 pub async fn unsubscribe(&mut self, sids: Vec<i32>) -> Result<(), KalshiError> {
156 if sids.is_empty() {
157 return Ok(());
158 }
159
160 let id = self.get_next_id();
161
162 let rx = self.register_pending_command(id);
164
165 let cmd = serde_json::json!({
166 "id": id,
167 "cmd": "unsubscribe",
168 "params": {
169 "sids": sids
170 }
171 });
172
173 self.send_command(cmd).await?;
174
175 let response = self.wait_for_response(rx).await?;
177
178 match response {
179 CommandResponse::Ok { .. } => {
180 for sid in &sids {
182 self.subscriptions.remove(sid);
183 }
184 Ok(())
185 }
186 CommandResponse::Error { code, msg } => Err(KalshiError::InternalError(format!(
187 "Unsubscribe failed with code {}: {}",
188 code, msg
189 ))),
190 CommandResponse::Subscribed { .. } => {
191 for sid in &sids {
193 self.subscriptions.remove(sid);
194 }
195 Ok(())
196 }
197 }
198 }
199
200 pub fn list_subscriptions(&self) -> Vec<Subscription> {
205 self.subscriptions.values().cloned().collect()
206 }
207
208 pub fn get_subscription(&self, sid: i32) -> Option<&Subscription> {
210 self.subscriptions.get(&sid)
211 }
212
213 pub async fn update_subscription(
221 &mut self,
222 sids: Vec<i32>,
223 market_tickers: Vec<String>,
224 action: UpdateAction,
225 ) -> Result<(), KalshiError> {
226 if sids.is_empty() || market_tickers.is_empty() {
227 return Ok(());
228 }
229
230 let id = self.get_next_id();
231
232 let rx = self.register_pending_command(id);
234
235 let action_str = match action {
236 UpdateAction::AddMarkets => "add_markets",
237 UpdateAction::DeleteMarkets => "delete_markets",
238 };
239
240 let cmd = serde_json::json!({
241 "id": id,
242 "cmd": "update_subscription",
243 "params": {
244 "sids": sids,
245 "market_tickers": market_tickers,
246 "action": action_str
247 }
248 });
249
250 self.send_command(cmd).await?;
251
252 let response = self.wait_for_response(rx).await?;
254
255 match response {
256 CommandResponse::Ok { .. } => {
257 for sid in &sids {
259 if let Some(sub) = self.subscriptions.get_mut(sid) {
260 match action {
261 UpdateAction::AddMarkets => {
262 for ticker in &market_tickers {
263 if !sub.market_tickers.contains(ticker) {
264 sub.market_tickers.push(ticker.clone());
265 }
266 }
267 }
268 UpdateAction::DeleteMarkets => {
269 sub.market_tickers.retain(|t| !market_tickers.contains(t));
270 }
271 }
272 }
273 }
274 Ok(())
275 }
276 CommandResponse::Error { code, msg } => Err(KalshiError::InternalError(format!(
277 "Update subscription failed with code {}: {}",
278 code, msg
279 ))),
280 CommandResponse::Subscribed { .. } => {
281 Ok(())
283 }
284 }
285 }
286}