// kalshi_rust/websocket/connection.rs

1use crate::kalshi_error::KalshiError;
2use crate::TradingEnvironment;
3use futures_util::{stream::SplitSink, stream::SplitStream, SinkExt, StreamExt};
4use openssl::pkey::{PKey, Private};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio::sync::{oneshot, Mutex};
10use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
11
12type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
13type WsSink = SplitSink<WsStream, Message>;
14type WsReader = SplitStream<WsStream>;
15
/// Outcome reported by the server for a command sent over the WebSocket.
///
/// Commands such as subscribe or unsubscribe are answered with a control
/// message; this enum is the client-side representation of that answer and is
/// what pending-command receivers are resolved with.
///
/// # Variants
///
/// - `Ok`: the command succeeded
/// - `Error`: the command was rejected (carries an error code and message)
/// - `Subscribed`: a subscription was confirmed (carries the subscription ID
///   and the channel name)
///
/// The type derives `PartialEq`/`Eq` so responses can be compared directly,
/// e.g. in `assert_eq!` or when matching an expected acknowledgment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandResponse {
    /// Successful acknowledgment from the server.
    ///
    /// # Fields
    /// - `id`: The command ID that was acknowledged
    Ok { id: i32 },

    /// Error response from the server.
    ///
    /// # Fields
    /// - `code`: Numeric error code
    /// - `msg`: Human-readable error message
    Error { code: i32, msg: String },

    /// Subscription confirmation with the server-assigned subscription ID.
    ///
    /// # Fields
    /// - `sid`: Subscription ID assigned by the server
    /// - `channel`: Name of the channel that was subscribed to
    Subscribed { sid: i32, channel: String },
}
48
/// Number of seconds a caller waits for the server to answer a command before
/// the wait is abandoned with a timeout error.
const DEFAULT_COMMAND_TIMEOUT_SECS: u64 = 10;
51
52/// WebSocket client for real-time Kalshi market data and trading events.
53///
54/// `KalshiWebSocket` provides a persistent, authenticated connection to the Kalshi
55/// WebSocket API for streaming market data and portfolio updates. The client handles
56/// authentication, subscription management, and message routing automatically.
57///
58/// # Features
59///
60/// - **Automatic authentication** using RSA-PSS signing
61/// - **Subscription management** with support for multiple simultaneous channels
62/// - **Async streaming** interface compatible with Tokio and futures
63/// - **Connection lifecycle** management (connect, disconnect, reconnect)
64/// - **Type-safe messages** via the [`WebSocketMessage`](super::WebSocketMessage) enum
65///
66/// # Creating a Client
67///
68/// The WebSocket client is typically created from an existing [`Kalshi`](crate::Kalshi)
69/// instance using the [`websocket()`](crate::Kalshi::websocket) method, which automatically
70/// transfers the authentication credentials.
71///
72/// ```rust,ignore
73/// use kalshi::{Kalshi, TradingEnvironment};
74///
75/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
76/// let kalshi = Kalshi::new(
77///     TradingEnvironment::DemoMode,
78///     "your-key-id",
79///     "path/to/private.pem"
80/// ).await?;
81///
82/// let mut ws = kalshi.websocket();
83/// # Ok(())
84/// # }
85/// ```
86///
87/// # Connection Flow
88///
89/// 1. **Create** the client (does not connect automatically)
90/// 2. **Connect** with [`connect()`](KalshiWebSocket::connect)
91/// 3. **Subscribe** to channels using subscription methods
92/// 4. **Stream** messages using the [`messages()`](KalshiWebSocket::messages) stream
93/// 5. **Disconnect** with [`disconnect()`](KalshiWebSocket::disconnect) when done
94///
95/// # Example Usage
96///
97/// ```rust,ignore
98/// use kalshi::{Kalshi, TradingEnvironment, WebSocketMessage};
99/// use futures_util::StreamExt;
100///
101/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
102/// let kalshi = Kalshi::new(TradingEnvironment::DemoMode, "key", "key.pem").await?;
103/// let mut ws = kalshi.websocket();
104///
105/// // Connect to WebSocket
106/// ws.connect().await?;
107///
108/// // Subscribe to channels
109/// ws.subscribe_to_ticker("HIGHNY-24JAN15-T50").await?;
110/// ws.subscribe_to_fills().await?;
111///
112/// // Process messages
113/// let mut stream = ws.messages();
114/// while let Some(msg) = stream.next().await {
115///     match msg {
116///         WebSocketMessage::Ticker(ticker) => {
117///             println!("Ticker update: {} @ {}", ticker.ticker, ticker.last_price);
118///         }
119///         WebSocketMessage::Fill(fill) => {
120///             println!("Fill: {} contracts on {}", fill.count, fill.ticker);
121///         }
122///         _ => {}
123///     }
124/// }
125///
126/// // Clean disconnect
127/// ws.disconnect().await?;
128/// # Ok(())
129/// # }
130/// ```
131///
132/// # Thread Safety
133///
134/// The WebSocket client is not `Send` or `Sync` and must be used from a single async task.
135/// The internal writer is wrapped in an `Arc<Mutex<>>` to allow sharing across message
136/// processing, but the overall client should not be shared across threads.
137pub struct KalshiWebSocket {
138    url: String,
139    key_id: String,
140    private_key: PKey<Private>,
141    writer: Option<Arc<Mutex<WsSink>>>,
142    reader: Option<WsReader>,
143    next_id: i32,
144    pub(crate) subscriptions: HashMap<i32, super::Subscription>,
145    /// Pending command response channels, keyed by command ID.
146    pending_commands: HashMap<i32, oneshot::Sender<CommandResponse>>,
147}
148
149impl KalshiWebSocket {
150    /// Creates a new WebSocket client without establishing a connection.
151    ///
152    /// This method initializes the WebSocket client with the necessary credentials
153    /// but does not open a network connection. Call [`connect()`](KalshiWebSocket::connect)
154    /// to establish the connection.
155    ///
156    /// # Arguments
157    ///
158    /// * `trading_env` - The trading environment (DemoMode or ProdMode)
159    /// * `key_id` - Your Kalshi API key ID
160    /// * `private_key` - Your RSA private key for signing authentication requests
161    ///
162    /// # Returns
163    ///
164    /// A new `KalshiWebSocket` instance ready to connect.
165    ///
166    /// # Example
167    ///
168    /// ```rust,ignore
169    /// use kalshi::{TradingEnvironment, KalshiWebSocket};
170    /// use openssl::pkey::PKey;
171    /// use std::fs;
172    ///
173    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
174    /// let pem = fs::read("path/to/private.pem")?;
175    /// let private_key = PKey::private_key_from_pem(&pem)?;
176    ///
177    /// let ws = KalshiWebSocket::new(
178    ///     TradingEnvironment::DemoMode,
179    ///     "your-key-id",
180    ///     private_key
181    /// );
182    /// # Ok(())
183    /// # }
184    /// ```
185    ///
186    /// # Note
187    ///
188    /// Most users should create the WebSocket client via [`Kalshi::websocket()`](crate::Kalshi::websocket)
189    /// which handles credential transfer automatically.
190    pub fn new(trading_env: TradingEnvironment, key_id: &str, private_key: PKey<Private>) -> Self {
191        let url = match trading_env {
192            TradingEnvironment::DemoMode => "wss://demo-api.kalshi.co/trade-api/ws/v2",
193            TradingEnvironment::ProdMode => "wss://api.elections.kalshi.com/trade-api/ws/v2",
194        };
195
196        Self {
197            url: url.to_string(),
198            key_id: key_id.to_string(),
199            private_key,
200            writer: None,
201            reader: None,
202            next_id: 1,
203            subscriptions: HashMap::new(),
204            pending_commands: HashMap::new(),
205        }
206    }
207
208    /// Connects to the WebSocket server with automatic authentication.
209    ///
210    /// This method establishes a WebSocket connection to the Kalshi exchange and
211    /// performs RSA-PSS authentication using the provided credentials. The connection
212    /// is authenticated at connection time via query parameters.
213    ///
214    /// # Returns
215    ///
216    /// - `Ok(())`: Connection established successfully
217    /// - `Err(KalshiError)`: Connection or authentication failed
218    ///
219    /// # Errors
220    ///
221    /// This method can return errors for:
222    /// - Network connectivity issues
223    /// - Invalid credentials (authentication failure)
224    /// - Server unavailability
225    /// - SSL/TLS errors
226    ///
227    /// # Example
228    ///
229    /// ```rust,ignore
230    /// # use kalshi::KalshiWebSocket;
231    /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
232    /// ws.connect().await?;
233    /// println!("Connected to WebSocket!");
234    /// # Ok(())
235    /// # }
236    /// ```
237    ///
238    /// # Connection Process
239    ///
240    /// 1. Generates a timestamp and authentication signature
241    /// 2. Constructs the WebSocket URL with authentication parameters
242    /// 3. Establishes the WebSocket connection
243    /// 4. Splits the connection into reader and writer halves for async processing
244    pub async fn connect(&mut self) -> Result<(), KalshiError> {
245        let timestamp = chrono::Utc::now().timestamp_millis();
246        let method = "GET";
247        let path = "/trade-api/ws/v2";
248
249        let message = format!("{}{}{}", timestamp, method, path);
250        let signature = self.sign_message(&message)?;
251
252        // Build URL with properly encoded query parameters
253        let mut url = reqwest::Url::parse(&self.url)
254            .map_err(|e| KalshiError::InternalError(format!("Invalid WebSocket URL: {}", e)))?;
255        url.query_pairs_mut()
256            .append_pair("api-key", &self.key_id)
257            .append_pair("timestamp", &timestamp.to_string())
258            .append_pair("signature", &signature);
259
260        let auth_url = url.to_string();
261
262        let (ws_stream, _response) = connect_async(&auth_url)
263            .await
264            .map_err(|e| KalshiError::InternalError(format!("WebSocket connect failed: {}", e)))?;
265
266        let (write, read) = ws_stream.split();
267        self.writer = Some(Arc::new(Mutex::new(write)));
268        self.reader = Some(read);
269
270        Ok(())
271    }
272
273    /// Disconnects from the WebSocket server gracefully.
274    ///
275    /// This method closes the WebSocket connection, clears all subscriptions,
276    /// and resets the client state. After disconnecting, you can call
277    /// [`connect()`](KalshiWebSocket::connect) again to re-establish the connection.
278    ///
279    /// # Returns
280    ///
281    /// - `Ok(())`: Disconnected successfully
282    /// - `Err(KalshiError)`: Error during disconnection
283    ///
284    /// # Example
285    ///
286    /// ```rust,ignore
287    /// # use kalshi::KalshiWebSocket;
288    /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
289    /// // Use the connection...
290    /// ws.connect().await?;
291    /// // Do work...
292    ///
293    /// // Clean disconnect when done
294    /// ws.disconnect().await?;
295    /// # Ok(())
296    /// # }
297    /// ```
298    ///
299    /// # Note
300    ///
301    /// All active subscriptions are removed when disconnecting. You will need to
302    /// re-subscribe after reconnecting.
303    pub async fn disconnect(&mut self) -> Result<(), KalshiError> {
304        if let Some(writer) = &self.writer {
305            let mut w = writer.lock().await;
306            w.close()
307                .await
308                .map_err(|e| KalshiError::InternalError(format!("Close failed: {}", e)))?;
309        }
310        self.writer = None;
311        self.reader = None;
312        self.subscriptions.clear();
313        self.pending_commands.clear();
314        Ok(())
315    }
316
317    /// Returns `true` if the WebSocket connection is currently active.
318    ///
319    /// This checks whether the internal writer stream is initialized, which
320    /// indicates an active connection.
321    ///
322    /// # Returns
323    ///
324    /// - `true`: Connected to the WebSocket server
325    /// - `false`: Not connected (either never connected or disconnected)
326    ///
327    /// # Example
328    ///
329    /// ```rust,ignore
330    /// # use kalshi::KalshiWebSocket;
331    /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
332    /// assert!(!ws.is_connected());
333    ///
334    /// ws.connect().await?;
335    /// assert!(ws.is_connected());
336    ///
337    /// ws.disconnect().await?;
338    /// assert!(!ws.is_connected());
339    /// # Ok(())
340    /// # }
341    /// ```
342    pub fn is_connected(&self) -> bool {
343        self.writer.is_some()
344    }
345
346    fn sign_message(&self, message: &str) -> Result<String, KalshiError> {
347        use openssl::hash::MessageDigest;
348        use openssl::rsa::Padding;
349        use openssl::sign::Signer;
350
351        let mut signer = Signer::new(MessageDigest::sha256(), &self.private_key)?;
352        signer.set_rsa_padding(Padding::PKCS1_PSS)?;
353        signer.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::DIGEST_LENGTH)?;
354        signer.update(message.as_bytes())?;
355        let signature = signer.sign_to_vec()?;
356        Ok(base64::Engine::encode(
357            &base64::engine::general_purpose::STANDARD,
358            &signature,
359        ))
360    }
361
362    pub(crate) fn get_next_id(&mut self) -> i32 {
363        let id = self.next_id;
364        self.next_id += 1;
365        id
366    }
367
368    /// Sends a command to the WebSocket server.
369    pub(crate) async fn send_command(&mut self, cmd: serde_json::Value) -> Result<(), KalshiError> {
370        let writer = self
371            .writer
372            .as_ref()
373            .ok_or_else(|| KalshiError::InternalError("Not connected".to_string()))?;
374
375        let msg = Message::Text(serde_json::to_string(&cmd)?);
376        let mut w = writer.lock().await;
377        w.send(msg)
378            .await
379            .map_err(|e| KalshiError::InternalError(format!("Send failed: {}", e)))?;
380        Ok(())
381    }
382
383    /// Registers a pending command to receive its response.
384    pub(crate) fn register_pending_command(
385        &mut self,
386        id: i32,
387    ) -> oneshot::Receiver<CommandResponse> {
388        let (tx, rx) = oneshot::channel();
389        self.pending_commands.insert(id, tx);
390        rx
391    }
392
393    /// Routes a command response to the appropriate pending command.
394    /// Returns true if the response was routed, false if no pending command was found.
395    pub(crate) fn route_response(&mut self, id: i32, response: CommandResponse) -> bool {
396        if let Some(sender) = self.pending_commands.remove(&id) {
397            // Ignore send error - receiver may have been dropped
398            let _ = sender.send(response);
399            true
400        } else {
401            false
402        }
403    }
404
405    /// Waits for a single command response with timeout.
406    pub(crate) async fn wait_for_response(
407        &mut self,
408        rx: oneshot::Receiver<CommandResponse>,
409    ) -> Result<CommandResponse, KalshiError> {
410        match tokio::time::timeout(Duration::from_secs(DEFAULT_COMMAND_TIMEOUT_SECS), rx).await {
411            Ok(Ok(response)) => Ok(response),
412            Ok(Err(_)) => Err(KalshiError::InternalError(
413                "Response channel closed unexpectedly".to_string(),
414            )),
415            Err(_) => Err(KalshiError::InternalError(
416                "Timeout waiting for command response".to_string(),
417            )),
418        }
419    }
420
421    /// Waits for multiple command responses (e.g., multiple `subscribed` messages).
422    /// Returns responses in the order they are received.
423    pub(crate) async fn wait_for_responses(
424        &mut self,
425        mut receivers: Vec<(i32, oneshot::Receiver<CommandResponse>)>,
426        expected_count: usize,
427    ) -> Result<Vec<CommandResponse>, KalshiError> {
428        let mut responses = Vec::with_capacity(expected_count);
429        let deadline =
430            tokio::time::Instant::now() + Duration::from_secs(DEFAULT_COMMAND_TIMEOUT_SECS);
431
432        while responses.len() < expected_count && !receivers.is_empty() {
433            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
434            if remaining.is_zero() {
435                return Err(KalshiError::InternalError(
436                    "Timeout waiting for all command responses".to_string(),
437                ));
438            }
439
440            // Try to read more messages to route responses
441            if let Some(reader) = self.reader.as_mut() {
442                match tokio::time::timeout(Duration::from_millis(100), reader.next()).await {
443                    Ok(Some(Ok(Message::Text(text)))) => {
444                        if let Ok(msg) = super::WebSocketMessage::parse(&text) {
445                            self.handle_control_message(&msg);
446                        }
447                    }
448                    Ok(Some(Ok(_))) => {
449                        // Non-text message, ignore
450                    }
451                    Ok(Some(Err(_))) | Ok(None) => {
452                        return Err(KalshiError::InternalError(
453                            "WebSocket connection closed".to_string(),
454                        ));
455                    }
456                    Err(_) => {
457                        // Timeout on read, continue checking receivers
458                    }
459                }
460            }
461
462            // Check which receivers have responses ready
463            let mut i = 0;
464            while i < receivers.len() {
465                match receivers[i].1.try_recv() {
466                    Ok(response) => {
467                        responses.push(response);
468                        receivers.remove(i);
469                    }
470                    Err(oneshot::error::TryRecvError::Empty) => {
471                        i += 1;
472                    }
473                    Err(oneshot::error::TryRecvError::Closed) => {
474                        // Channel closed without response
475                        receivers.remove(i);
476                    }
477                }
478            }
479        }
480
481        if responses.len() < expected_count {
482            return Err(KalshiError::InternalError(format!(
483                "Expected {} responses, got {}",
484                expected_count,
485                responses.len()
486            )));
487        }
488
489        Ok(responses)
490    }
491
492    /// Handles control messages (subscribed, ok, error) and routes them to pending commands.
493    pub(crate) fn handle_control_message(&mut self, msg: &super::WebSocketMessage) {
494        match msg {
495            super::WebSocketMessage::Subscribed(sub_msg) => {
496                // For subscribed messages, we need to find the pending command by iterating
497                // since the server response doesn't include the original command ID directly.
498                // Instead, we route based on channel matching for the most recently registered command.
499                // Note: This is a simplification. In practice, we track by command ID.
500                let response = CommandResponse::Subscribed {
501                    sid: sub_msg.sid,
502                    channel: sub_msg.channel.clone(),
503                };
504                // Try to route to any pending command (they should be waiting for subscribed responses)
505                if let Some((&id, _)) = self.pending_commands.iter().next() {
506                    self.route_response(id, response);
507                }
508            }
509            super::WebSocketMessage::Ok(ok_msg) => {
510                let response = CommandResponse::Ok { id: ok_msg.sid };
511                self.route_response(ok_msg.sid, response);
512            }
513            super::WebSocketMessage::Error(err_msg) => {
514                let response = CommandResponse::Error {
515                    code: err_msg.code,
516                    msg: err_msg.msg.clone(),
517                };
518                // Route to the first pending command since errors don't have command IDs
519                if let Some((&id, _)) = self.pending_commands.iter().next() {
520                    self.route_response(id, response);
521                }
522            }
523            _ => {
524                // Non-control message, ignore
525            }
526        }
527    }
528}
529
530// Stream interface (Task 4.7)
531use futures_util::Stream;
532use std::pin::Pin;
533use std::task::{Context, Poll};
534
535impl KalshiWebSocket {
536    /// Returns an asynchronous stream of WebSocket messages.
537    ///
538    /// This method provides a [`Stream`](futures_util::Stream) interface for receiving
539    /// messages from the WebSocket connection. The stream yields
540    /// [`WebSocketMessage`](super::WebSocketMessage) items as they arrive.
541    ///
542    /// # Returns
543    ///
544    /// A stream that yields `WebSocketMessage` items. The stream ends when the
545    /// connection is closed.
546    ///
547    /// # Example
548    ///
549    /// ```rust,ignore
550    /// use kalshi::{KalshiWebSocket, WebSocketMessage};
551    /// use futures_util::StreamExt;
552    ///
553    /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
554    /// ws.connect().await?;
555    /// ws.subscribe_to_ticker("HIGHNY-24JAN15-T50").await?;
556    ///
557    /// let mut stream = ws.messages();
558    /// while let Some(msg) = stream.next().await {
559    ///     match msg {
560    ///         WebSocketMessage::Ticker(ticker) => {
561    ///             println!("Price update: {}", ticker.last_price);
562    ///         }
563    ///         WebSocketMessage::Heartbeat(_) => {
564    ///             println!("Keepalive heartbeat");
565    ///         }
566    ///         _ => {}
567    ///     }
568    /// }
569    /// # Ok(())
570    /// # }
571    /// ```
572    ///
573    /// # Message Types
574    ///
575    /// The stream can yield any of these message types:
576    /// - `OrderbookDelta` - Incremental orderbook updates
577    /// - `OrderbookSnapshot` - Full orderbook snapshots
578    /// - `Ticker` - Best bid/ask and last price updates
579    /// - `Trade` / `Trades` - Trade executions
580    /// - `Fill` - Your order fills (authenticated)
581    /// - `Order` - Your order updates (authenticated)
582    /// - `Heartbeat` - Keepalive messages
583    /// - `Subscribed` / `Ok` / `Error` - Control messages
584    ///
585    /// # Performance
586    ///
587    /// The stream processes messages as they arrive. Control messages (subscribed, ok, error)
588    /// are automatically routed to pending command handlers and also yielded to the stream.
589    pub fn messages(&mut self) -> impl Stream<Item = super::WebSocketMessage> + '_ {
590        MessageStream { ws: self }
591    }
592}
593
594struct MessageStream<'a> {
595    ws: &'a mut KalshiWebSocket,
596}
597
598impl<'a> Stream for MessageStream<'a> {
599    type Item = super::WebSocketMessage;
600
601    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
602        let reader = match self.ws.reader.as_mut() {
603            Some(r) => r,
604            None => return Poll::Ready(None),
605        };
606
607        match Pin::new(reader).poll_next(cx) {
608            Poll::Ready(Some(Ok(Message::Text(text)))) => {
609                match super::WebSocketMessage::parse(&text) {
610                    Ok(msg) => {
611                        // Route control messages to pending commands
612                        self.ws.handle_control_message(&msg);
613                        Poll::Ready(Some(msg))
614                    }
615                    Err(_) => {
616                        cx.waker().wake_by_ref();
617                        Poll::Pending
618                    }
619                }
620            }
621            Poll::Ready(Some(Ok(Message::Ping(_)))) => {
622                cx.waker().wake_by_ref();
623                Poll::Pending
624            }
625            Poll::Ready(Some(Ok(_))) => {
626                cx.waker().wake_by_ref();
627                Poll::Pending
628            }
629            Poll::Ready(Some(Err(_))) => Poll::Ready(None),
630            Poll::Ready(None) => Poll::Ready(None),
631            Poll::Pending => Poll::Pending,
632        }
633    }
634}