kalshi_rust/websocket/subscription.rs

1use super::{Channel, CommandResponse, KalshiWebSocket};
2use crate::kalshi_error::KalshiError;
3use serde::{Deserialize, Serialize};
4
5/// Represents an active subscription.
6#[derive(Debug, Clone)]
7pub struct Subscription {
8    pub sid: i32,
9    pub channel: Channel,
10    pub market_tickers: Vec<String>,
11}
12
13#[derive(Debug, Clone, Deserialize, Serialize)]
14pub struct SubscribeResponse {
15    pub sid: i32,
16    pub channel: String,
17}
18
19#[derive(Debug, Serialize)]
20#[serde(rename_all = "snake_case")]
21pub enum UpdateAction {
22    AddMarkets,
23    DeleteMarkets,
24}
25
26/// Parses a channel string back to a Channel enum.
27fn parse_channel(channel_str: &str) -> Option<Channel> {
28    match channel_str {
29        "orderbook_delta" => Some(Channel::OrderbookDelta),
30        "ticker" => Some(Channel::Ticker),
31        "trade" => Some(Channel::Trade),
32        "fill" => Some(Channel::Fill),
33        "market_position" => Some(Channel::MarketPosition),
34        "market_lifecycle_v2" => Some(Channel::MarketLifecycleV2),
35        "event_lifecycle" => Some(Channel::EventLifecycle),
36        "multivariate" => Some(Channel::Multivariate),
37        "communications" => Some(Channel::Communications),
38        _ => None,
39    }
40}
41
42impl KalshiWebSocket {
43    /// Subscribe to one or more channels for specified markets.
44    ///
45    /// This method sends a subscribe command and waits for the server to confirm
46    /// each channel subscription. The returned `SubscribeResponse` contains the
47    /// assigned subscription IDs (SIDs) which can be used for unsubscribing.
48    ///
49    /// # Arguments
50    ///
51    /// * `channels` - List of channels to subscribe to
52    /// * `market_ticker` - Optional single market ticker to subscribe to
53    /// * `market_tickers` - Optional list of market tickers to subscribe to
54    ///
55    /// # Returns
56    ///
57    /// A vector of `SubscribeResponse` containing the SID and channel name for each subscription.
58    pub async fn subscribe(
59        &mut self,
60        channels: Vec<Channel>,
61        market_ticker: Option<String>,
62        market_tickers: Option<Vec<String>>,
63    ) -> Result<Vec<SubscribeResponse>, KalshiError> {
64        if channels.is_empty() {
65            return Ok(vec![]);
66        }
67
68        let id = self.get_next_id();
69        let expected_responses = channels.len();
70
71        // Collect the market tickers for storing in subscriptions
72        let tickers: Vec<String> = match (&market_ticker, &market_tickers) {
73            (Some(ticker), _) => vec![ticker.clone()],
74            (_, Some(tickers)) => tickers.clone(),
75            (None, None) => vec![],
76        };
77
78        // Register pending commands for each expected response
79        let mut receivers = Vec::with_capacity(expected_responses);
80        for i in 0..expected_responses {
81            let response_id = id + i as i32;
82            let rx = self.register_pending_command(response_id);
83            receivers.push((response_id, rx));
84        }
85
86        // Build and send the subscribe command
87        let mut cmd = serde_json::json!({
88            "id": id,
89            "cmd": "subscribe",
90            "params": {
91                "channels": channels.iter().map(|c| c.to_string()).collect::<Vec<_>>()
92            }
93        });
94
95        if let Some(ticker) = market_ticker {
96            cmd["params"]["market_ticker"] = serde_json::Value::String(ticker);
97        }
98        if let Some(tickers_list) = market_tickers {
99            cmd["params"]["market_tickers"] = serde_json::Value::Array(
100                tickers_list
101                    .into_iter()
102                    .map(serde_json::Value::String)
103                    .collect(),
104            );
105        }
106
107        self.send_command(cmd).await?;
108
109        // Wait for all subscription confirmations
110        let responses = self
111            .wait_for_responses(receivers, expected_responses)
112            .await?;
113
114        // Process responses and build result
115        let mut result = Vec::with_capacity(responses.len());
116        for response in responses {
117            match response {
118                CommandResponse::Subscribed { sid, channel } => {
119                    // Store the subscription
120                    if let Some(channel_enum) = parse_channel(&channel) {
121                        self.subscriptions.insert(
122                            sid,
123                            Subscription {
124                                sid,
125                                channel: channel_enum,
126                                market_tickers: tickers.clone(),
127                            },
128                        );
129                    }
130                    result.push(SubscribeResponse { sid, channel });
131                }
132                CommandResponse::Error { code, msg } => {
133                    return Err(KalshiError::InternalError(format!(
134                        "Subscribe failed with code {}: {}",
135                        code, msg
136                    )));
137                }
138                CommandResponse::Ok { .. } => {
139                    // Unexpected Ok response for subscribe, but not an error
140                }
141            }
142        }
143
144        Ok(result)
145    }
146
147    /// Unsubscribe from one or more subscriptions.
148    ///
149    /// Removes the specified subscriptions and waits for server confirmation.
150    /// Also removes the subscriptions from the local tracking.
151    ///
152    /// # Arguments
153    ///
154    /// * `sids` - List of subscription IDs to unsubscribe from
155    pub async fn unsubscribe(&mut self, sids: Vec<i32>) -> Result<(), KalshiError> {
156        if sids.is_empty() {
157            return Ok(());
158        }
159
160        let id = self.get_next_id();
161
162        // Register for response
163        let rx = self.register_pending_command(id);
164
165        let cmd = serde_json::json!({
166            "id": id,
167            "cmd": "unsubscribe",
168            "params": {
169                "sids": sids
170            }
171        });
172
173        self.send_command(cmd).await?;
174
175        // Wait for confirmation
176        let response = self.wait_for_response(rx).await?;
177
178        match response {
179            CommandResponse::Ok { .. } => {
180                // Remove subscriptions from local tracking
181                for sid in &sids {
182                    self.subscriptions.remove(sid);
183                }
184                Ok(())
185            }
186            CommandResponse::Error { code, msg } => Err(KalshiError::InternalError(format!(
187                "Unsubscribe failed with code {}: {}",
188                code, msg
189            ))),
190            CommandResponse::Subscribed { .. } => {
191                // Unexpected subscribed response, but not an error
192                for sid in &sids {
193                    self.subscriptions.remove(sid);
194                }
195                Ok(())
196            }
197        }
198    }
199
200    /// List all active subscriptions.
201    ///
202    /// Returns the locally tracked subscriptions. Note that this returns
203    /// the subscriptions that have been tracked by this client instance.
204    pub fn list_subscriptions(&self) -> Vec<Subscription> {
205        self.subscriptions.values().cloned().collect()
206    }
207
208    /// Get a subscription by its SID.
209    pub fn get_subscription(&self, sid: i32) -> Option<&Subscription> {
210        self.subscriptions.get(&sid)
211    }
212
213    /// Update an existing subscription by adding or removing markets.
214    ///
215    /// # Arguments
216    ///
217    /// * `sids` - List of subscription IDs to update
218    /// * `market_tickers` - List of market tickers to add or remove
219    /// * `action` - Whether to add or remove the market tickers
220    pub async fn update_subscription(
221        &mut self,
222        sids: Vec<i32>,
223        market_tickers: Vec<String>,
224        action: UpdateAction,
225    ) -> Result<(), KalshiError> {
226        if sids.is_empty() || market_tickers.is_empty() {
227            return Ok(());
228        }
229
230        let id = self.get_next_id();
231
232        // Register for response
233        let rx = self.register_pending_command(id);
234
235        let action_str = match action {
236            UpdateAction::AddMarkets => "add_markets",
237            UpdateAction::DeleteMarkets => "delete_markets",
238        };
239
240        let cmd = serde_json::json!({
241            "id": id,
242            "cmd": "update_subscription",
243            "params": {
244                "sids": sids,
245                "market_tickers": market_tickers,
246                "action": action_str
247            }
248        });
249
250        self.send_command(cmd).await?;
251
252        // Wait for confirmation
253        let response = self.wait_for_response(rx).await?;
254
255        match response {
256            CommandResponse::Ok { .. } => {
257                // Update local subscription tracking
258                for sid in &sids {
259                    if let Some(sub) = self.subscriptions.get_mut(sid) {
260                        match action {
261                            UpdateAction::AddMarkets => {
262                                for ticker in &market_tickers {
263                                    if !sub.market_tickers.contains(ticker) {
264                                        sub.market_tickers.push(ticker.clone());
265                                    }
266                                }
267                            }
268                            UpdateAction::DeleteMarkets => {
269                                sub.market_tickers.retain(|t| !market_tickers.contains(t));
270                            }
271                        }
272                    }
273                }
274                Ok(())
275            }
276            CommandResponse::Error { code, msg } => Err(KalshiError::InternalError(format!(
277                "Update subscription failed with code {}: {}",
278                code, msg
279            ))),
280            CommandResponse::Subscribed { .. } => {
281                // Unexpected subscribed response, but not an error
282                Ok(())
283            }
284        }
285    }
286}