kalshi_rust/websocket/connection.rs
1use crate::kalshi_error::KalshiError;
2use crate::TradingEnvironment;
3use futures_util::{stream::SplitSink, stream::SplitStream, SinkExt, StreamExt};
4use openssl::pkey::{PKey, Private};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio::sync::{oneshot, Mutex};
10use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
11
12type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
13type WsSink = SplitSink<WsStream, Message>;
14type WsReader = SplitStream<WsStream>;
15
/// Outcome the server reports for a command sent over the WebSocket.
///
/// After a command (subscribe, unsubscribe, ...) is sent, the server answers
/// with a control message. Those control messages are converted into this enum
/// before being handed to whoever is waiting on the command.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can match a
/// received response against an expected one directly.
///
/// # Variants
///
/// - `Ok`: the command was acknowledged
/// - `Error`: the command was rejected (carries error code and message)
/// - `Subscribed`: a subscription was confirmed (carries subscription ID and channel)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandResponse {
    /// Successful acknowledgment from the server.
    Ok {
        /// Identifier of the command that was acknowledged.
        id: i32,
    },

    /// Error response from the server.
    Error {
        /// Numeric error code.
        code: i32,
        /// Human-readable error message.
        msg: String,
    },

    /// Subscription confirmation with the assigned subscription ID.
    Subscribed {
        /// Subscription ID assigned by the server.
        sid: i32,
        /// Name of the channel that was subscribed to.
        channel: String,
    },
}
48
/// Number of seconds to wait for the server to answer a command before the
/// wait is abandoned with a timeout error.
const DEFAULT_COMMAND_TIMEOUT_SECS: u64 = 10;
51
52/// WebSocket client for real-time Kalshi market data and trading events.
53///
54/// `KalshiWebSocket` provides a persistent, authenticated connection to the Kalshi
55/// WebSocket API for streaming market data and portfolio updates. The client handles
56/// authentication, subscription management, and message routing automatically.
57///
58/// # Features
59///
60/// - **Automatic authentication** using RSA-PSS signing
61/// - **Subscription management** with support for multiple simultaneous channels
62/// - **Async streaming** interface compatible with Tokio and futures
63/// - **Connection lifecycle** management (connect, disconnect, reconnect)
64/// - **Type-safe messages** via the [`WebSocketMessage`](super::WebSocketMessage) enum
65///
66/// # Creating a Client
67///
68/// The WebSocket client is typically created from an existing [`Kalshi`](crate::Kalshi)
69/// instance using the [`websocket()`](crate::Kalshi::websocket) method, which automatically
70/// transfers the authentication credentials.
71///
72/// ```rust,ignore
73/// use kalshi::{Kalshi, TradingEnvironment};
74///
75/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
76/// let kalshi = Kalshi::new(
77/// TradingEnvironment::DemoMode,
78/// "your-key-id",
79/// "path/to/private.pem"
80/// ).await?;
81///
82/// let mut ws = kalshi.websocket();
83/// # Ok(())
84/// # }
85/// ```
86///
87/// # Connection Flow
88///
89/// 1. **Create** the client (does not connect automatically)
90/// 2. **Connect** with [`connect()`](KalshiWebSocket::connect)
91/// 3. **Subscribe** to channels using subscription methods
92/// 4. **Stream** messages using the [`messages()`](KalshiWebSocket::messages) stream
93/// 5. **Disconnect** with [`disconnect()`](KalshiWebSocket::disconnect) when done
94///
95/// # Example Usage
96///
97/// ```rust,ignore
98/// use kalshi::{Kalshi, TradingEnvironment, WebSocketMessage};
99/// use futures_util::StreamExt;
100///
101/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
102/// let kalshi = Kalshi::new(TradingEnvironment::DemoMode, "key", "key.pem").await?;
103/// let mut ws = kalshi.websocket();
104///
105/// // Connect to WebSocket
106/// ws.connect().await?;
107///
108/// // Subscribe to channels
109/// ws.subscribe_to_ticker("HIGHNY-24JAN15-T50").await?;
110/// ws.subscribe_to_fills().await?;
111///
112/// // Process messages
113/// let mut stream = ws.messages();
114/// while let Some(msg) = stream.next().await {
115/// match msg {
116/// WebSocketMessage::Ticker(ticker) => {
117/// println!("Ticker update: {} @ {}", ticker.ticker, ticker.last_price);
118/// }
119/// WebSocketMessage::Fill(fill) => {
120/// println!("Fill: {} contracts on {}", fill.count, fill.ticker);
121/// }
122/// _ => {}
123/// }
124/// }
125///
126/// // Clean disconnect
127/// ws.disconnect().await?;
128/// # Ok(())
129/// # }
130/// ```
131///
132/// # Thread Safety
133///
134/// The WebSocket client is not `Send` or `Sync` and must be used from a single async task.
135/// The internal writer is wrapped in an `Arc<Mutex<>>` to allow sharing across message
136/// processing, but the overall client should not be shared across threads.
137pub struct KalshiWebSocket {
138 url: String,
139 key_id: String,
140 private_key: PKey<Private>,
141 writer: Option<Arc<Mutex<WsSink>>>,
142 reader: Option<WsReader>,
143 next_id: i32,
144 pub(crate) subscriptions: HashMap<i32, super::Subscription>,
145 /// Pending command response channels, keyed by command ID.
146 pending_commands: HashMap<i32, oneshot::Sender<CommandResponse>>,
147}
148
149impl KalshiWebSocket {
150 /// Creates a new WebSocket client without establishing a connection.
151 ///
152 /// This method initializes the WebSocket client with the necessary credentials
153 /// but does not open a network connection. Call [`connect()`](KalshiWebSocket::connect)
154 /// to establish the connection.
155 ///
156 /// # Arguments
157 ///
158 /// * `trading_env` - The trading environment (DemoMode or ProdMode)
159 /// * `key_id` - Your Kalshi API key ID
160 /// * `private_key` - Your RSA private key for signing authentication requests
161 ///
162 /// # Returns
163 ///
164 /// A new `KalshiWebSocket` instance ready to connect.
165 ///
166 /// # Example
167 ///
168 /// ```rust,ignore
169 /// use kalshi::{TradingEnvironment, KalshiWebSocket};
170 /// use openssl::pkey::PKey;
171 /// use std::fs;
172 ///
173 /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
174 /// let pem = fs::read("path/to/private.pem")?;
175 /// let private_key = PKey::private_key_from_pem(&pem)?;
176 ///
177 /// let ws = KalshiWebSocket::new(
178 /// TradingEnvironment::DemoMode,
179 /// "your-key-id",
180 /// private_key
181 /// );
182 /// # Ok(())
183 /// # }
184 /// ```
185 ///
186 /// # Note
187 ///
188 /// Most users should create the WebSocket client via [`Kalshi::websocket()`](crate::Kalshi::websocket)
189 /// which handles credential transfer automatically.
190 pub fn new(trading_env: TradingEnvironment, key_id: &str, private_key: PKey<Private>) -> Self {
191 let url = match trading_env {
192 TradingEnvironment::DemoMode => "wss://demo-api.kalshi.co/trade-api/ws/v2",
193 TradingEnvironment::ProdMode => "wss://api.elections.kalshi.com/trade-api/ws/v2",
194 };
195
196 Self {
197 url: url.to_string(),
198 key_id: key_id.to_string(),
199 private_key,
200 writer: None,
201 reader: None,
202 next_id: 1,
203 subscriptions: HashMap::new(),
204 pending_commands: HashMap::new(),
205 }
206 }
207
208 /// Connects to the WebSocket server with automatic authentication.
209 ///
210 /// This method establishes a WebSocket connection to the Kalshi exchange and
211 /// performs RSA-PSS authentication using the provided credentials. The connection
212 /// is authenticated at connection time via query parameters.
213 ///
214 /// # Returns
215 ///
216 /// - `Ok(())`: Connection established successfully
217 /// - `Err(KalshiError)`: Connection or authentication failed
218 ///
219 /// # Errors
220 ///
221 /// This method can return errors for:
222 /// - Network connectivity issues
223 /// - Invalid credentials (authentication failure)
224 /// - Server unavailability
225 /// - SSL/TLS errors
226 ///
227 /// # Example
228 ///
229 /// ```rust,ignore
230 /// # use kalshi::KalshiWebSocket;
231 /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
232 /// ws.connect().await?;
233 /// println!("Connected to WebSocket!");
234 /// # Ok(())
235 /// # }
236 /// ```
237 ///
238 /// # Connection Process
239 ///
240 /// 1. Generates a timestamp and authentication signature
241 /// 2. Constructs the WebSocket URL with authentication parameters
242 /// 3. Establishes the WebSocket connection
243 /// 4. Splits the connection into reader and writer halves for async processing
244 pub async fn connect(&mut self) -> Result<(), KalshiError> {
245 let timestamp = chrono::Utc::now().timestamp_millis();
246 let method = "GET";
247 let path = "/trade-api/ws/v2";
248
249 let message = format!("{}{}{}", timestamp, method, path);
250 let signature = self.sign_message(&message)?;
251
252 // Build URL with properly encoded query parameters
253 let mut url = reqwest::Url::parse(&self.url)
254 .map_err(|e| KalshiError::InternalError(format!("Invalid WebSocket URL: {}", e)))?;
255 url.query_pairs_mut()
256 .append_pair("api-key", &self.key_id)
257 .append_pair("timestamp", ×tamp.to_string())
258 .append_pair("signature", &signature);
259
260 let auth_url = url.to_string();
261
262 let (ws_stream, _response) = connect_async(&auth_url)
263 .await
264 .map_err(|e| KalshiError::InternalError(format!("WebSocket connect failed: {}", e)))?;
265
266 let (write, read) = ws_stream.split();
267 self.writer = Some(Arc::new(Mutex::new(write)));
268 self.reader = Some(read);
269
270 Ok(())
271 }
272
273 /// Disconnects from the WebSocket server gracefully.
274 ///
275 /// This method closes the WebSocket connection, clears all subscriptions,
276 /// and resets the client state. After disconnecting, you can call
277 /// [`connect()`](KalshiWebSocket::connect) again to re-establish the connection.
278 ///
279 /// # Returns
280 ///
281 /// - `Ok(())`: Disconnected successfully
282 /// - `Err(KalshiError)`: Error during disconnection
283 ///
284 /// # Example
285 ///
286 /// ```rust,ignore
287 /// # use kalshi::KalshiWebSocket;
288 /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
289 /// // Use the connection...
290 /// ws.connect().await?;
291 /// // Do work...
292 ///
293 /// // Clean disconnect when done
294 /// ws.disconnect().await?;
295 /// # Ok(())
296 /// # }
297 /// ```
298 ///
299 /// # Note
300 ///
301 /// All active subscriptions are removed when disconnecting. You will need to
302 /// re-subscribe after reconnecting.
303 pub async fn disconnect(&mut self) -> Result<(), KalshiError> {
304 if let Some(writer) = &self.writer {
305 let mut w = writer.lock().await;
306 w.close()
307 .await
308 .map_err(|e| KalshiError::InternalError(format!("Close failed: {}", e)))?;
309 }
310 self.writer = None;
311 self.reader = None;
312 self.subscriptions.clear();
313 self.pending_commands.clear();
314 Ok(())
315 }
316
317 /// Returns `true` if the WebSocket connection is currently active.
318 ///
319 /// This checks whether the internal writer stream is initialized, which
320 /// indicates an active connection.
321 ///
322 /// # Returns
323 ///
324 /// - `true`: Connected to the WebSocket server
325 /// - `false`: Not connected (either never connected or disconnected)
326 ///
327 /// # Example
328 ///
329 /// ```rust,ignore
330 /// # use kalshi::KalshiWebSocket;
331 /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
332 /// assert!(!ws.is_connected());
333 ///
334 /// ws.connect().await?;
335 /// assert!(ws.is_connected());
336 ///
337 /// ws.disconnect().await?;
338 /// assert!(!ws.is_connected());
339 /// # Ok(())
340 /// # }
341 /// ```
342 pub fn is_connected(&self) -> bool {
343 self.writer.is_some()
344 }
345
346 fn sign_message(&self, message: &str) -> Result<String, KalshiError> {
347 use openssl::hash::MessageDigest;
348 use openssl::rsa::Padding;
349 use openssl::sign::Signer;
350
351 let mut signer = Signer::new(MessageDigest::sha256(), &self.private_key)?;
352 signer.set_rsa_padding(Padding::PKCS1_PSS)?;
353 signer.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::DIGEST_LENGTH)?;
354 signer.update(message.as_bytes())?;
355 let signature = signer.sign_to_vec()?;
356 Ok(base64::Engine::encode(
357 &base64::engine::general_purpose::STANDARD,
358 &signature,
359 ))
360 }
361
362 pub(crate) fn get_next_id(&mut self) -> i32 {
363 let id = self.next_id;
364 self.next_id += 1;
365 id
366 }
367
368 /// Sends a command to the WebSocket server.
369 pub(crate) async fn send_command(&mut self, cmd: serde_json::Value) -> Result<(), KalshiError> {
370 let writer = self
371 .writer
372 .as_ref()
373 .ok_or_else(|| KalshiError::InternalError("Not connected".to_string()))?;
374
375 let msg = Message::Text(serde_json::to_string(&cmd)?);
376 let mut w = writer.lock().await;
377 w.send(msg)
378 .await
379 .map_err(|e| KalshiError::InternalError(format!("Send failed: {}", e)))?;
380 Ok(())
381 }
382
383 /// Registers a pending command to receive its response.
384 pub(crate) fn register_pending_command(
385 &mut self,
386 id: i32,
387 ) -> oneshot::Receiver<CommandResponse> {
388 let (tx, rx) = oneshot::channel();
389 self.pending_commands.insert(id, tx);
390 rx
391 }
392
393 /// Routes a command response to the appropriate pending command.
394 /// Returns true if the response was routed, false if no pending command was found.
395 pub(crate) fn route_response(&mut self, id: i32, response: CommandResponse) -> bool {
396 if let Some(sender) = self.pending_commands.remove(&id) {
397 // Ignore send error - receiver may have been dropped
398 let _ = sender.send(response);
399 true
400 } else {
401 false
402 }
403 }
404
405 /// Waits for a single command response with timeout.
406 pub(crate) async fn wait_for_response(
407 &mut self,
408 rx: oneshot::Receiver<CommandResponse>,
409 ) -> Result<CommandResponse, KalshiError> {
410 match tokio::time::timeout(Duration::from_secs(DEFAULT_COMMAND_TIMEOUT_SECS), rx).await {
411 Ok(Ok(response)) => Ok(response),
412 Ok(Err(_)) => Err(KalshiError::InternalError(
413 "Response channel closed unexpectedly".to_string(),
414 )),
415 Err(_) => Err(KalshiError::InternalError(
416 "Timeout waiting for command response".to_string(),
417 )),
418 }
419 }
420
421 /// Waits for multiple command responses (e.g., multiple `subscribed` messages).
422 /// Returns responses in the order they are received.
423 pub(crate) async fn wait_for_responses(
424 &mut self,
425 mut receivers: Vec<(i32, oneshot::Receiver<CommandResponse>)>,
426 expected_count: usize,
427 ) -> Result<Vec<CommandResponse>, KalshiError> {
428 let mut responses = Vec::with_capacity(expected_count);
429 let deadline =
430 tokio::time::Instant::now() + Duration::from_secs(DEFAULT_COMMAND_TIMEOUT_SECS);
431
432 while responses.len() < expected_count && !receivers.is_empty() {
433 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
434 if remaining.is_zero() {
435 return Err(KalshiError::InternalError(
436 "Timeout waiting for all command responses".to_string(),
437 ));
438 }
439
440 // Try to read more messages to route responses
441 if let Some(reader) = self.reader.as_mut() {
442 match tokio::time::timeout(Duration::from_millis(100), reader.next()).await {
443 Ok(Some(Ok(Message::Text(text)))) => {
444 if let Ok(msg) = super::WebSocketMessage::parse(&text) {
445 self.handle_control_message(&msg);
446 }
447 }
448 Ok(Some(Ok(_))) => {
449 // Non-text message, ignore
450 }
451 Ok(Some(Err(_))) | Ok(None) => {
452 return Err(KalshiError::InternalError(
453 "WebSocket connection closed".to_string(),
454 ));
455 }
456 Err(_) => {
457 // Timeout on read, continue checking receivers
458 }
459 }
460 }
461
462 // Check which receivers have responses ready
463 let mut i = 0;
464 while i < receivers.len() {
465 match receivers[i].1.try_recv() {
466 Ok(response) => {
467 responses.push(response);
468 receivers.remove(i);
469 }
470 Err(oneshot::error::TryRecvError::Empty) => {
471 i += 1;
472 }
473 Err(oneshot::error::TryRecvError::Closed) => {
474 // Channel closed without response
475 receivers.remove(i);
476 }
477 }
478 }
479 }
480
481 if responses.len() < expected_count {
482 return Err(KalshiError::InternalError(format!(
483 "Expected {} responses, got {}",
484 expected_count,
485 responses.len()
486 )));
487 }
488
489 Ok(responses)
490 }
491
492 /// Handles control messages (subscribed, ok, error) and routes them to pending commands.
493 pub(crate) fn handle_control_message(&mut self, msg: &super::WebSocketMessage) {
494 match msg {
495 super::WebSocketMessage::Subscribed(sub_msg) => {
496 // For subscribed messages, we need to find the pending command by iterating
497 // since the server response doesn't include the original command ID directly.
498 // Instead, we route based on channel matching for the most recently registered command.
499 // Note: This is a simplification. In practice, we track by command ID.
500 let response = CommandResponse::Subscribed {
501 sid: sub_msg.sid,
502 channel: sub_msg.channel.clone(),
503 };
504 // Try to route to any pending command (they should be waiting for subscribed responses)
505 if let Some((&id, _)) = self.pending_commands.iter().next() {
506 self.route_response(id, response);
507 }
508 }
509 super::WebSocketMessage::Ok(ok_msg) => {
510 let response = CommandResponse::Ok { id: ok_msg.sid };
511 self.route_response(ok_msg.sid, response);
512 }
513 super::WebSocketMessage::Error(err_msg) => {
514 let response = CommandResponse::Error {
515 code: err_msg.code,
516 msg: err_msg.msg.clone(),
517 };
518 // Route to the first pending command since errors don't have command IDs
519 if let Some((&id, _)) = self.pending_commands.iter().next() {
520 self.route_response(id, response);
521 }
522 }
523 _ => {
524 // Non-control message, ignore
525 }
526 }
527 }
528}
529
530// Stream interface (Task 4.7)
531use futures_util::Stream;
532use std::pin::Pin;
533use std::task::{Context, Poll};
534
535impl KalshiWebSocket {
536 /// Returns an asynchronous stream of WebSocket messages.
537 ///
538 /// This method provides a [`Stream`](futures_util::Stream) interface for receiving
539 /// messages from the WebSocket connection. The stream yields
540 /// [`WebSocketMessage`](super::WebSocketMessage) items as they arrive.
541 ///
542 /// # Returns
543 ///
544 /// A stream that yields `WebSocketMessage` items. The stream ends when the
545 /// connection is closed.
546 ///
547 /// # Example
548 ///
549 /// ```rust,ignore
550 /// use kalshi::{KalshiWebSocket, WebSocketMessage};
551 /// use futures_util::StreamExt;
552 ///
553 /// # async fn example(mut ws: KalshiWebSocket) -> Result<(), Box<dyn std::error::Error>> {
554 /// ws.connect().await?;
555 /// ws.subscribe_to_ticker("HIGHNY-24JAN15-T50").await?;
556 ///
557 /// let mut stream = ws.messages();
558 /// while let Some(msg) = stream.next().await {
559 /// match msg {
560 /// WebSocketMessage::Ticker(ticker) => {
561 /// println!("Price update: {}", ticker.last_price);
562 /// }
563 /// WebSocketMessage::Heartbeat(_) => {
564 /// println!("Keepalive heartbeat");
565 /// }
566 /// _ => {}
567 /// }
568 /// }
569 /// # Ok(())
570 /// # }
571 /// ```
572 ///
573 /// # Message Types
574 ///
575 /// The stream can yield any of these message types:
576 /// - `OrderbookDelta` - Incremental orderbook updates
577 /// - `OrderbookSnapshot` - Full orderbook snapshots
578 /// - `Ticker` - Best bid/ask and last price updates
579 /// - `Trade` / `Trades` - Trade executions
580 /// - `Fill` - Your order fills (authenticated)
581 /// - `Order` - Your order updates (authenticated)
582 /// - `Heartbeat` - Keepalive messages
583 /// - `Subscribed` / `Ok` / `Error` - Control messages
584 ///
585 /// # Performance
586 ///
587 /// The stream processes messages as they arrive. Control messages (subscribed, ok, error)
588 /// are automatically routed to pending command handlers and also yielded to the stream.
589 pub fn messages(&mut self) -> impl Stream<Item = super::WebSocketMessage> + '_ {
590 MessageStream { ws: self }
591 }
592}
593
594struct MessageStream<'a> {
595 ws: &'a mut KalshiWebSocket,
596}
597
598impl<'a> Stream for MessageStream<'a> {
599 type Item = super::WebSocketMessage;
600
601 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
602 let reader = match self.ws.reader.as_mut() {
603 Some(r) => r,
604 None => return Poll::Ready(None),
605 };
606
607 match Pin::new(reader).poll_next(cx) {
608 Poll::Ready(Some(Ok(Message::Text(text)))) => {
609 match super::WebSocketMessage::parse(&text) {
610 Ok(msg) => {
611 // Route control messages to pending commands
612 self.ws.handle_control_message(&msg);
613 Poll::Ready(Some(msg))
614 }
615 Err(_) => {
616 cx.waker().wake_by_ref();
617 Poll::Pending
618 }
619 }
620 }
621 Poll::Ready(Some(Ok(Message::Ping(_)))) => {
622 cx.waker().wake_by_ref();
623 Poll::Pending
624 }
625 Poll::Ready(Some(Ok(_))) => {
626 cx.waker().wake_by_ref();
627 Poll::Pending
628 }
629 Poll::Ready(Some(Err(_))) => Poll::Ready(None),
630 Poll::Ready(None) => Poll::Ready(None),
631 Poll::Pending => Poll::Pending,
632 }
633 }
634}