Skip to main content

bezant/
ws.rs

1//! WebSocket streaming client for the IBKR Client Portal Web API.
2//!
3//! The CPAPI exposes a WebSocket endpoint at `/v1/api/ws` that multiplexes
4//! real-time market data, order updates, PnL snapshots, and more. Bezant
5//! wraps the raw socket with:
6//!
7//! - A single-call [`WsClient::connect`] that derives the WS URL from a
8//!   REST [`crate::Client`], reuses its session cookie, and returns a
9//!   duplex handle.
10//! - Typed subscribe/unsubscribe helpers for the common topics
11//!   ([`WsClient::subscribe_market_data`],
12//!   [`WsClient::subscribe_orders`], [`WsClient::subscribe_pnl`]).
13//! - A [`WsClient::raw_stream`] escape hatch returning every decoded JSON
14//!   message so you can handle message types we haven't modelled yet.
15//!
16//! # Topic format
17//!
18//! CPAPI's wire format is `TOPIC+{json}`. The first letter selects the
19//! action: `s` subscribe, `u` unsubscribe. Examples:
20//!
21//! ```text
22//! smd+265598+{"fields":["31","84","86"]}   // subscribe to AAPL L1 quote
23//! umd+265598+{}                             // unsubscribe
24//! sor+{}                                    // subscribe to order updates
25//! spl+{}                                    // subscribe to PnL updates
26//! ```
27//!
28//! See [IBKR's WebSocket lesson][ibkr-ws] for the full catalogue.
29//!
30//! [ibkr-ws]: https://www.interactivebrokers.com/campus/trading-lessons/websockets/
31
32use std::time::Duration;
33
34use futures_util::{SinkExt, Stream, StreamExt};
35use serde::Serialize;
36use serde_json::Value;
37use tokio::net::TcpStream;
38use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message};
39use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
40use tracing::{debug, warn};
41use url::Url;
42
43use crate::client::Client;
44use crate::error::{Error, Result};
45
46/// Raw WebSocket stream type the client multiplexes.
47pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
48
49/// A live Bezant WebSocket connection. Clone cheaply to share across tasks —
50/// actually no, WebSocket sinks aren't cheap to split arbitrarily. Keep one
51/// owner per connection and [`WsClient::split`] if you need a read/write
52/// halving.
53#[derive(Debug)]
54pub struct WsClient {
55    stream: WsStream,
56}
57
58/// Concrete name for the sink half of [`WsClient::split`] — what
59/// `futures_util` `SplitSink` resolves to over a TLS-wrapped
60/// tungstenite stream. Re-exposed so callers can store it in a
61/// struct field without naming an `impl Trait`-only type.
62pub type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
63
64/// Concrete name for the stream half of [`WsClient::split`].
65pub type WsRecv = futures_util::stream::SplitStream<WsStream>;
66
67/// A decoded CPAPI frame. Most messages fall into one of the variants below,
68/// but the CPAPI occasionally emits payloads we haven't modelled — those end
69/// up in [`WsMessage::Other`].
70#[derive(Debug, Clone)]
71#[non_exhaustive]
72pub enum WsMessage {
73    /// Heartbeat pings sent by the server periodically.
74    Heartbeat,
75    /// System / session status messages (e.g. `"topic": "system"`).
76    System(Value),
77    /// Market data tick for a subscribed contract.
78    MarketData {
79        /// The contract id this tick is for.
80        conid: i64,
81        /// The full decoded payload (field codes are string-keyed).
82        payload: Value,
83    },
84    /// Order update (working → filled, cancellations, etc).
85    Order(Value),
86    /// PnL / account summary update.
87    Pnl(Value),
88    /// Any message whose `topic` we didn't recognise.
89    Other(Value),
90    /// The socket emitted a frame we couldn't decode — a text body that
91    /// wasn't valid JSON, or a binary frame converted lossily to UTF-8.
92    /// The decoder's error is captured alongside the original text so
93    /// callers can telemeter parse rates.
94    Malformed {
95        /// The raw (possibly lossy) text the socket delivered.
96        text: String,
97        /// Human-readable reason the decoder gave up — serde JSON error
98        /// for malformed payloads, a "non-UTF-8 binary frame" marker for
99        /// binary decoder losses.
100        error: String,
101    },
102}
103
104impl WsMessage {
105    /// Return a static label for the message variant. Useful for
106    /// `tracing::Span::record("topic", ...)` and metrics labels.
107    #[must_use]
108    pub fn topic(&self) -> &'static str {
109        match self {
110            Self::Heartbeat => "heartbeat",
111            Self::System(_) => "system",
112            Self::MarketData { .. } => "market_data",
113            Self::Order(_) => "order",
114            Self::Pnl(_) => "pnl",
115            Self::Other(_) => "other",
116            Self::Malformed { .. } => "malformed",
117        }
118    }
119
120    /// Borrow the inner [`Value`] for variants that carry one. `None`
121    /// for [`WsMessage::Heartbeat`] and [`WsMessage::Malformed`] (which
122    /// has no parsed value to lend).
123    #[must_use]
124    pub fn as_value(&self) -> Option<&Value> {
125        match self {
126            Self::System(v) | Self::Order(v) | Self::Pnl(v) | Self::Other(v) => Some(v),
127            Self::MarketData { payload, .. } => Some(payload),
128            Self::Heartbeat | Self::Malformed { .. } => None,
129        }
130    }
131}
132
/// Handle to a single live WebSocket subscription, returned by the
/// `WsClient::subscribe_*` calls so callers can cancel an individual feed
/// without remembering the (topic, conid) pair themselves.
///
/// [`Subscription::cancel`] takes the handle by value, so any one handle
/// can be cancelled at most once. The type is also `Clone` (and
/// `Send + Sync`, since it only holds `String`s) so it can be stashed in a
/// registry shared between tasks. Cancelling two clones of the same handle
/// simply sends the unsubscribe frame twice, which is treated as harmless.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscription {
    /// The unsubscribe frame that cancels this feed, e.g. `umd+265598+{}`.
    cancel_payload: String,
    /// Human-readable label for telemetry — `market_data:265598`,
    /// `orders`, `pnl`.
    pub name: String,
}
149
150impl Subscription {
151    /// Cancel this subscription by sending the matching `umd`/`uor`/`upl`
152    /// frame on `ws`. Consumes the handle. Errors propagate as
153    /// [`Error::WsTransport`] — but if the socket is already closed
154    /// the upstream cancellation is implicit, callers can usually
155    /// ignore the error.
156    pub async fn cancel(self, ws: &mut WsClient) -> Result<()> {
157        ws.send_text(self.cancel_payload).await
158    }
159
160    /// Get the cancel payload bytes — exposed for callers that want
161    /// to send the cancellation through a different sink (e.g. the
162    /// returned half of [`WsClient::split`]) rather than the original
163    /// `WsClient`.
164    #[must_use]
165    pub fn cancel_payload(&self) -> &str {
166        &self.cancel_payload
167    }
168}
169
/// Market-data field codes used when subscribing. See
/// [`bezant_api::GetMdSnapshotRequestQuery`] for the documented set on the
/// REST side — every code listed there works on the WebSocket too.
///
/// Kept as an opaque newtype so we can change the internal representation
/// in a point release without breaking downstream callers. Codes are sent
/// to the server in the order they were supplied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MarketDataFields(Vec<String>);

impl MarketDataFields {
    /// Reasonable level-1 default. The codes, in order, per IBKR's Client
    /// Portal market-data field reference:
    ///
    /// | code | field      |
    /// |------|------------|
    /// | `31` | last price |
    /// | `84` | bid price  |
    /// | `86` | ask price  |
    /// | `85` | ask size   |
    /// | `88` | bid size   |
    /// | `87` | day volume |
    ///
    /// Last *size* is a separate code (`7059`) and is not included. Build a
    /// custom set with [`MarketDataFields::from_codes`] if you need it.
    #[must_use]
    pub fn default_l1() -> Self {
        Self::from_codes(["31", "84", "86", "85", "88", "87"])
    }

    /// Build a new [`MarketDataFields`] from any iterator of code strings.
    /// Order is preserved and duplicates are kept as given.
    #[must_use]
    pub fn from_codes<I, S>(codes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self(codes.into_iter().map(Into::into).collect())
    }

    /// Borrow the underlying field codes — handy when forwarding the same
    /// set to multiple subscribes, or serialising for logging.
    #[must_use]
    pub fn as_slice(&self) -> &[String] {
        &self.0
    }
}

/// Collect field codes directly, e.g. `["31", "84"].into_iter().collect()`.
/// Equivalent to [`MarketDataFields::from_codes`].
impl<S> FromIterator<S> for MarketDataFields
where
    S: Into<String>,
{
    fn from_iter<I: IntoIterator<Item = S>>(iter: I) -> Self {
        Self::from_codes(iter)
    }
}
211
212impl WsClient {
213    /// Open a WebSocket connection to the Gateway that `client` is pointed at.
214    ///
215    /// Internally:
216    /// 1. Issues a `/tickle` HTTP call to mint a session cookie.
217    /// 2. Derives the `wss://…/ws` URL from the REST base URL.
218    /// 3. Attaches `Cookie: api={"session":"…"}` to the WS handshake.
219    /// 4. Returns a connected [`WsClient`].
220    ///
221    /// # Errors
222    /// Any tickle / handshake / TLS failure surfaces as [`Error`].
223    #[tracing::instrument(skip(client), level = "debug")]
224    pub async fn connect(client: &Client) -> Result<Self> {
225        let tickle = client.tickle().await?;
226        let session = tickle.session.ok_or(Error::NoSession)?;
227        let ws_url = ws_url_from_base(client.base_url())?;
228        let cookie = format!(r#"api={{"session":"{session}"}}"#);
229
230        debug!(%ws_url, "bezant: opening websocket");
231        let mut request = ws_url.as_str().into_client_request().map_err(|source| {
232            Error::WsHandshake {
233                url: ws_url.to_string(),
234                source,
235            }
236        })?;
237        request.headers_mut().insert(
238            "Cookie",
239            cookie.parse().map_err(|source| Error::Header {
240                name: "cookie",
241                source,
242            })?,
243        );
244        request.headers_mut().insert(
245            "User-Agent",
246            format!("bezant/{}", env!("CARGO_PKG_VERSION"))
247                .parse()
248                .map_err(|source| Error::Header {
249                    name: "user-agent",
250                    source,
251                })?,
252        );
253
254        let (stream, _) =
255            tokio_tungstenite::connect_async(request)
256                .await
257                .map_err(|source| Error::WsHandshake {
258                    url: ws_url.to_string(),
259                    source,
260                })?;
261
262        Ok(Self { stream })
263    }
264
265    /// Subscribe to level-1 market data for a single contract id. Use
266    /// [`MarketDataFields::default_l1`] if you just want the common fields.
267    ///
268    /// Returns a [`Subscription`] handle — call [`Subscription::cancel`]
269    /// when you're done with the feed instead of tracking the conid
270    /// yourself.
271    ///
272    /// # Errors
273    /// Any send failure surfaces as [`Error::WsTransport`] /
274    /// [`Error::WsProtocol`].
275    #[tracing::instrument(skip(self, fields), fields(conid = conid), level = "debug")]
276    pub async fn subscribe_market_data(
277        &mut self,
278        conid: i64,
279        fields: &MarketDataFields,
280    ) -> Result<Subscription> {
281        #[derive(Serialize)]
282        struct Body<'a> {
283            fields: &'a [String],
284        }
285        let body = Body {
286            fields: fields.as_slice(),
287        };
288        let payload = format!(
289            "smd+{conid}+{}",
290            serde_json::to_string(&body)
291                .map_err(|e| Error::WsProtocol(format!("serialise fields: {e}")))?
292        );
293        self.send_text(payload).await?;
294        Ok(Subscription {
295            cancel_payload: format!("umd+{conid}+{{}}"),
296            name: format!("market_data:{conid}"),
297        })
298    }
299
300    /// Unsubscribe from a previously-subscribed market data feed by
301    /// raw conid. Prefer [`Subscription::cancel`] on the handle returned
302    /// by [`Self::subscribe_market_data`] — this raw form remains for
303    /// callers that already track conids themselves.
304    ///
305    /// # Errors
306    /// Any send failure surfaces as [`Error::WsTransport`].
307    pub async fn unsubscribe_market_data(&mut self, conid: i64) -> Result<()> {
308        self.send_text(format!("umd+{conid}+{{}}")).await
309    }
310
311    /// Subscribe to order status updates. Returns a [`Subscription`]
312    /// you can cancel later.
313    ///
314    /// # Errors
315    /// Any send failure surfaces as [`Error::WsTransport`].
316    pub async fn subscribe_orders(&mut self) -> Result<Subscription> {
317        self.send_text("sor+{}".into()).await?;
318        Ok(Subscription {
319            cancel_payload: "uor+{}".into(),
320            name: "orders".into(),
321        })
322    }
323
324    /// Subscribe to PnL updates. Returns a [`Subscription`] you can
325    /// cancel later.
326    ///
327    /// # Errors
328    /// Any send failure surfaces as [`Error::WsTransport`].
329    pub async fn subscribe_pnl(&mut self) -> Result<Subscription> {
330        self.send_text("spl+{}".into()).await?;
331        Ok(Subscription {
332            cancel_payload: "upl+{}".into(),
333            name: "pnl".into(),
334        })
335    }
336
337    /// Send a raw text frame. Useful for subscribing to topics Bezant doesn't
338    /// yet model — follow the `topic+{json}` format.
339    ///
340    /// # Errors
341    /// Any send failure surfaces as [`Error::other`].
342    pub async fn send_text(&mut self, payload: String) -> Result<()> {
343        self.stream
344            .send(Message::text(payload))
345            .await
346            .map_err(|source| Error::WsTransport { source })
347    }
348
349    /// Pull the next decoded message. `None` means the socket closed.
350    ///
351    /// # Errors
352    /// Any read failure surfaces as [`Error::other`].
353    pub async fn next_message(&mut self) -> Result<Option<WsMessage>> {
354        while let Some(raw) = self.stream.next().await {
355            let frame = raw.map_err(|source| Error::WsTransport { source })?;
356            match frame {
357                Message::Text(text) => return Ok(Some(classify(text.as_str()))),
358                Message::Binary(bytes) => {
359                    // CPAPI occasionally sends binary frames for heartbeats.
360                    // Convert lossily — invalid UTF-8 becomes U+FFFD, which
361                    // will either parse (empty `{}` survives) or be reported
362                    // as [`WsMessage::Malformed`] with the JSON error. We
363                    // deliberately don't surface a separate "BinaryLost"
364                    // variant: the underlying socket is documented as text
365                    // JSON and any non-UTF-8 payload is upstream weirdness.
366                    let s = String::from_utf8_lossy(&bytes).to_string();
367                    return Ok(Some(classify(&s)));
368                }
369                Message::Ping(data) => {
370                    // Be a well-behaved client: echo the ping.
371                    if let Err(e) = self.stream.send(Message::Pong(data)).await {
372                        warn!(error = %e, "bezant: pong send failed");
373                    }
374                }
375                Message::Pong(_) => {}
376                Message::Frame(_) => {}
377                Message::Close(_) => return Ok(None),
378            }
379        }
380        Ok(None)
381    }
382
383    /// Return a [`Stream`] of [`WsMessage`]s that yields until the socket
384    /// closes. Consuming this yields exclusive access to the reader; use
385    /// [`WsClient::next_message`] on the client itself if you also need to
386    /// send frames on the same task.
387    pub fn raw_stream(self) -> impl Stream<Item = Result<WsMessage>> + Unpin {
388        Box::pin(futures_util::stream::unfold(
389            self.stream,
390            |mut s| async move {
391                loop {
392                    match s.next().await {
393                        None => return None,
394                        Some(Err(source)) => {
395                            return Some((Err(Error::WsTransport { source }), s))
396                        }
397                        Some(Ok(Message::Text(t))) => {
398                            return Some((Ok(classify(t.as_str())), s));
399                        }
400                        Some(Ok(Message::Binary(b))) => {
401                            let text = String::from_utf8_lossy(&b).to_string();
402                            return Some((Ok(classify(&text)), s));
403                        }
404                        Some(Ok(Message::Ping(p))) => {
405                            let _ = s.send(Message::Pong(p)).await;
406                        }
407                        Some(Ok(Message::Close(_))) => return None,
408                        Some(Ok(_)) => {}
409                    }
410                }
411            },
412        ))
413    }
414
415    /// Split the client into independent sink + stream halves so one task can
416    /// send and another can receive concurrently.
417    ///
418    /// Returns concrete `SplitSink`/`SplitStream` types from
419    /// `futures_util` so callers can name them in struct fields
420    /// without resorting to `Box<dyn …>` or `impl Trait`-only
421    /// associated types.
422    pub fn split(self) -> (WsSink, WsRecv) {
423        let (sink, stream) = self.stream.split();
424        (sink, stream)
425    }
426
427    /// How long to wait between application-level pings if you implement a
428    /// ticker task on top. Chosen to match CPAPI's 5-minute session timeout
429    /// with a safety margin.
430    #[must_use]
431    pub const fn recommended_keepalive() -> Duration {
432        Duration::from_secs(60)
433    }
434}
435
436/// Derive the WebSocket URL from a REST base URL.
437///
438/// `https://host:port/v1/api`       →  `wss://host:port/v1/api/ws`
439/// `http://host:port/v1/api`        →  `ws://host:port/v1/api/ws`
440fn ws_url_from_base(base: &Url) -> Result<Url> {
441    let mut ws = base.clone();
442    match ws.scheme() {
443        "https" => ws.set_scheme("wss").map_err(|()| Error::WsProtocol(
444            "failed to upgrade base URL scheme to wss".into(),
445        ))?,
446        "http" => ws
447            .set_scheme("ws")
448            .map_err(|()| Error::WsProtocol("failed to upgrade base URL scheme to ws".into()))?,
449        s => {
450            return Err(Error::BadRequest(format!(
451                "unsupported WebSocket base scheme '{s}' (expected http/https)"
452            )))
453        }
454    }
455    {
456        let mut segs = ws.path_segments_mut().map_err(|()| Error::UrlNotABase {
457            url: base.to_string(),
458        })?;
459        segs.push("ws");
460    }
461    Ok(ws)
462}
463
464/// Decode a text frame into a typed [`WsMessage`].
465fn classify(text: &str) -> WsMessage {
466    if text == "{}" || text.is_empty() {
467        return WsMessage::Heartbeat;
468    }
469    let value: Value = match serde_json::from_str(text) {
470        Ok(v) => v,
471        Err(e) => {
472            return WsMessage::Malformed {
473                text: text.to_owned(),
474                error: e.to_string(),
475            }
476        }
477    };
478
479    let topic = value
480        .get("topic")
481        .and_then(Value::as_str)
482        .unwrap_or_default();
483
484    // Market data topics are `smd+<conid>` (subscribe ack / tick).
485    if let Some(rest) = topic.strip_prefix("smd+") {
486        if let Ok(conid) = rest.parse::<i64>() {
487            return WsMessage::MarketData {
488                conid,
489                payload: value,
490            };
491        }
492    }
493
494    match topic {
495        "system" => WsMessage::System(value),
496        "sor" | "ortd" | "ord" => WsMessage::Order(value),
497        "spl" | "pnl" | "ssd" | "ssl" => WsMessage::Pnl(value),
498        _ => WsMessage::Other(value),
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ws_url_flips_https_to_wss_and_appends_ws() {
        let base = Url::parse("https://localhost:5000/v1/api").unwrap();
        let ws = ws_url_from_base(&base).unwrap();
        assert_eq!(ws.as_str(), "wss://localhost:5000/v1/api/ws");
    }

    #[test]
    fn ws_url_flips_http_to_ws() {
        let base = Url::parse("http://localhost:8080/v1/api").unwrap();
        let ws = ws_url_from_base(&base).unwrap();
        assert_eq!(ws.as_str(), "ws://localhost:8080/v1/api/ws");
    }

    #[test]
    fn ws_url_rejects_non_http_schemes() {
        let base = Url::parse("ftp://localhost/v1/api").unwrap();
        assert!(ws_url_from_base(&base).is_err());
    }

    #[test]
    fn classify_identifies_market_data_by_topic() {
        let raw = r#"{"topic":"smd+265598","31":"150.25","_updated":1700000000}"#;
        match classify(raw) {
            WsMessage::MarketData { conid, .. } => assert_eq!(conid, 265_598),
            other => panic!("expected MarketData, got {other:?}"),
        }
    }

    #[test]
    fn classify_empty_brace_is_heartbeat() {
        assert!(matches!(classify("{}"), WsMessage::Heartbeat));
    }

    #[test]
    fn classify_empty_text_is_heartbeat() {
        assert!(matches!(classify(""), WsMessage::Heartbeat));
    }

    #[test]
    fn classify_system_topic() {
        let raw = r#"{"topic":"system","msg":"ready"}"#;
        assert!(matches!(classify(raw), WsMessage::System(_)));
    }

    #[test]
    fn classify_routes_order_and_pnl_topics() {
        assert!(matches!(
            classify(r#"{"topic":"sor","args":[]}"#),
            WsMessage::Order(_)
        ));
        assert!(matches!(classify(r#"{"topic":"spl"}"#), WsMessage::Pnl(_)));
    }

    #[test]
    fn classify_unknown_or_missing_topic_is_other() {
        assert!(matches!(classify(r#"{"topic":"tic"}"#), WsMessage::Other(_)));
        assert!(matches!(classify(r#"{"hello":"world"}"#), WsMessage::Other(_)));
        // A non-numeric conid does not match the market-data shape.
        assert!(matches!(classify(r#"{"topic":"smd+abc"}"#), WsMessage::Other(_)));
    }

    #[test]
    fn classify_malformed_text() {
        assert!(matches!(classify("not-json"), WsMessage::Malformed { .. }));
    }

    #[test]
    fn ws_message_topic_is_static_label() {
        assert_eq!(WsMessage::Heartbeat.topic(), "heartbeat");
        assert_eq!(
            WsMessage::MarketData {
                conid: 1,
                payload: serde_json::json!({})
            }
            .topic(),
            "market_data"
        );
        assert_eq!(WsMessage::Order(serde_json::json!({})).topic(), "order");
        assert_eq!(WsMessage::Pnl(serde_json::json!({})).topic(), "pnl");
        assert_eq!(WsMessage::System(serde_json::json!({})).topic(), "system");
        assert_eq!(WsMessage::Other(serde_json::json!({})).topic(), "other");
        assert_eq!(
            WsMessage::Malformed {
                text: "x".into(),
                error: "y".into()
            }
            .topic(),
            "malformed"
        );
    }

    #[test]
    fn ws_message_as_value_returns_payload_for_data_variants() {
        let v = serde_json::json!({"hello": "world"});
        assert_eq!(WsMessage::Order(v.clone()).as_value(), Some(&v));
        assert_eq!(
            WsMessage::MarketData {
                conid: 1,
                payload: v.clone()
            }
            .as_value(),
            Some(&v)
        );
        assert_eq!(WsMessage::Heartbeat.as_value(), None);
        assert_eq!(
            WsMessage::Malformed {
                text: "x".into(),
                error: "y".into()
            }
            .as_value(),
            None
        );
    }

    #[test]
    fn subscription_cancel_payload_round_trips_topic() {
        // Construct a Subscription directly: the public constructors need a
        // live WsClient, but `cancel_payload` is a private field that this
        // child test module can still initialise, and the accessor is public.
        let sub = Subscription {
            cancel_payload: "umd+265598+{}".into(),
            name: "market_data:265598".into(),
        };
        assert_eq!(sub.cancel_payload(), "umd+265598+{}");
        assert_eq!(sub.name, "market_data:265598");
    }

    #[test]
    fn market_data_fields_default_l1_is_stable() {
        assert_eq!(
            MarketDataFields::default_l1().as_slice(),
            &["31", "84", "86", "85", "88", "87"]
        );
    }
}