Skip to main content

bezant/
auth.rs

1//! Session keepalive + health helpers layered over the generated client.
2
3use std::time::Duration;
4
5use tokio::sync::oneshot;
6
7use crate::client::Client;
8use crate::error::{Error, Result};
9
/// Simplified view of the Gateway's brokerage session status, projected from
/// the generated [`bezant_api::BrokerageSessionStatus`] type.
///
/// Values are produced by the `From<bezant_api::BrokerageSessionStatus>`
/// impl in this module, which maps any flag the Gateway left out of its
/// response to `false`.
///
/// `#[non_exhaustive]` so adding a field in a point release isn't a SemVer
/// break — match with `AuthStatus { authenticated, connected, .. }` rather
/// than an exhaustive field list.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthStatus {
    /// Whether the Gateway is authenticated (`false` if not reported).
    pub authenticated: bool,
    /// Whether the Gateway is connected to IBKR servers (`false` if not
    /// reported).
    pub connected: bool,
    /// Whether this Gateway session is the "competing" primary (`false` if
    /// not reported).
    pub competing: bool,
    /// Server-reported status message, passed through unchanged.
    pub message: Option<String>,
}
28
29impl From<bezant_api::BrokerageSessionStatus> for AuthStatus {
30    fn from(s: bezant_api::BrokerageSessionStatus) -> Self {
31        Self {
32            authenticated: s.authenticated.unwrap_or(false),
33            connected: s.connected.unwrap_or(false),
34            competing: s.competing.unwrap_or(false),
35            message: s.message,
36        }
37    }
38}
39
/// Response from a tickle call, projected from the generated
/// [`bezant_api::TickleResponse`] type and returned by [`Client::tickle`].
///
/// `#[non_exhaustive]` so adding a field in a point release isn't a SemVer
/// break.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TickleResponse {
    /// Session identifier the Gateway returned, if any. Always `None` when
    /// the raw payload is the `Failed` variant.
    pub session: Option<String>,
    /// Raw response — exposed for callers that need the full tickle payload.
    pub raw: bezant_api::TickleResponse,
}
52
53impl Client {
54    /// Query current auth + connection status via
55    /// `POST /iserver/auth/status`.
56    ///
57    /// # Errors
58    /// Transport + decode errors; [`Error::Api`] for any underlying error.
59    #[tracing::instrument(skip(self), level = "debug")]
60    pub async fn auth_status(&self) -> Result<AuthStatus> {
61        // We drive a raw POST instead of the generated client so we
62        // can distinguish "no session" (Gateway redirects to login)
63        // from genuine protocol errors. CPAPI returns 302 for
64        // unauthenticated callers — not the 401 the OpenAPI spec
65        // implies — so the auto-generated `Unauthorized` branch never
66        // fires in practice.
67        let mut url = self.base_url().clone();
68        url.path_segments_mut()
69            .map_err(|()| Error::UrlNotABase {
70                url: self.base_url().to_string(),
71            })?
72            .push("iserver")
73            .push("auth")
74            .push("status");
75        url.set_query(None);
76        // Akamai (in front of CPAPI) refuses POSTs that reach it
77        // without a Content-Length header — even an empty `Vec<u8>`
78        // body isn't enough if reqwest decides to serialize it with
79        // Transfer-Encoding: chunked. Setting the header explicitly
80        // forces a `Content-Length: 0` wire representation.
81        //
82        // We also pin Origin/Referer to match the Gateway's own
83        // origin: CPGateway's CPAPI handler treats requests with
84        // missing/mismatched Origin as a CSRF attempt and 401s. Local
85        // browser flows succeed because the browser supplies an
86        // origin that matches the Gateway's expectations; from a
87        // typed server-side caller we need to mint one ourselves.
88        let gateway_origin = self
89            .gateway_root_url()
90            .as_str()
91            .trim_end_matches('/')
92            .to_owned();
93        let resp = self
94            .http()
95            .post(url.clone())
96            .header(reqwest::header::CONTENT_LENGTH, "0")
97            .header(reqwest::header::ORIGIN, &gateway_origin)
98            .header(reqwest::header::REFERER, format!("{gateway_origin}/"))
99            .send()
100            .await
101            .map_err(Error::Http)?;
102        let status = resp.status();
103        if status == reqwest::StatusCode::UNAUTHORIZED {
104            return Err(Error::NotAuthenticated);
105        }
106        if status.is_redirection() {
107            // Tighten: only treat redirects that actually aim at the SSO
108            // login flow as "not authenticated". Any other 3xx is a
109            // genuine protocol surprise we'd rather surface verbatim.
110            let location = resp
111                .headers()
112                .get(reqwest::header::LOCATION)
113                .and_then(|v| v.to_str().ok())
114                .unwrap_or_default()
115                .to_ascii_lowercase();
116            if location.contains("/sso/login") || location.contains("/sso/dispatcher") {
117                return Err(Error::NotAuthenticated);
118            }
119            return Err(Error::UpstreamStatus {
120                endpoint: "iserver/auth/status",
121                status: status.as_u16(),
122                body_preview: Some(format!("redirect to: {location}")),
123            });
124        }
125        if !status.is_success() {
126            return Err(Error::UpstreamStatus {
127                endpoint: "iserver/auth/status",
128                status: status.as_u16(),
129                body_preview: None,
130            });
131        }
132        let parsed: bezant_api::BrokerageSessionStatus = resp.json().await.map_err(|e| {
133            Error::Decode {
134                endpoint: format!("POST {}/iserver/auth/status", self.base_url()),
135                status: status.as_u16(),
136                message: e.to_string(),
137            }
138        })?;
139        Ok(AuthStatus::from(parsed))
140    }
141
142    /// Tickle the Gateway (POST `/tickle`) to keep the session alive.
143    ///
144    /// CPAPI sessions expire after ~5 minutes of inactivity; call this at
145    /// least once a minute from a background task, or use
146    /// [`Client::spawn_keepalive`].
147    ///
148    /// # Errors
149    /// Transport + decode errors.
150    #[tracing::instrument(skip(self), level = "debug")]
151    pub async fn tickle(&self) -> Result<TickleResponse> {
152        let resp = self
153            .api()
154            .get_session_token(bezant_api::GetSessionTokenRequest::default())
155            .await?;
156        match resp {
157            bezant_api::GetSessionTokenResponse::Ok(payload) => {
158                let session = match &payload {
159                    bezant_api::TickleResponse::Successful(s) => s.session.clone(),
160                    bezant_api::TickleResponse::Failed(_) => None,
161                };
162                Ok(TickleResponse {
163                    session,
164                    raw: payload,
165                })
166            }
167            bezant_api::GetSessionTokenResponse::Unauthorized => Err(Error::NotAuthenticated),
168            bezant_api::GetSessionTokenResponse::Unknown => Err(Error::Unknown {
169                endpoint: "iserver/auth/tickle",
170            }),
171        }
172    }
173
174    /// Convenience: return the status if authenticated, else a typed error
175    /// ([`Error::NotAuthenticated`] / [`Error::NoSession`]).
176    ///
177    /// # Errors
178    /// See variants.
179    #[tracing::instrument(skip(self), level = "debug")]
180    pub async fn health(&self) -> Result<AuthStatus> {
181        let status = self.auth_status().await?;
182        if !status.connected {
183            return Err(Error::NoSession);
184        }
185        if !status.authenticated {
186            return Err(Error::NotAuthenticated);
187        }
188        Ok(status)
189    }
190
191    /// Spawn a tokio task that calls [`Client::tickle`] on `interval` until
192    /// the returned [`KeepaliveHandle`] is dropped. The CPAPI session times
193    /// out around 5 minutes; 60s is a sensible default.
194    ///
195    /// Tickle failures are logged via `tracing::warn!` and don't abort the
196    /// task — a transient outage is recoverable once the Gateway is back.
197    #[must_use]
198    pub fn spawn_keepalive(&self, interval: Duration) -> KeepaliveHandle {
199        use tracing::Instrument;
200
201        let client = self.clone();
202        let (tx, mut rx) = oneshot::channel::<()>();
203        let span = tracing::info_span!("bezant_keepalive", interval_secs = interval.as_secs());
204        let join = tokio::spawn(
205            async move {
206                let mut ticker = tokio::time::interval(interval);
207                // First tick fires immediately; skip it so we don't hit the
208                // Gateway the microsecond after the caller created the client.
209                ticker.tick().await;
210                loop {
211                    tokio::select! {
212                        _ = &mut rx => break,
213                        _ = ticker.tick() => {
214                            if let Err(e) = client.tickle().await {
215                                tracing::warn!(error = %e, "bezant keepalive tickle failed");
216                            }
217                        }
218                    }
219                }
220            }
221            .instrument(span),
222        );
223        KeepaliveHandle {
224            shutdown: Some(tx),
225            join: Some(join),
226        }
227    }
228}
229
/// Drop-to-stop handle for a background keepalive task.
///
/// Dropping the handle sends the shutdown signal via the [`Drop`] impl,
/// so a forgotten `_handle` won't keep tickling forever. The spawned
/// task waits on the signal and its ticker at the same time, so an idle
/// task exits as soon as the signal arrives; if a tickle request is in
/// flight at that moment it is allowed to finish first, and may therefore
/// still be running when the runtime shuts down.
/// Prefer [`KeepaliveHandle::stop`] when you need a clean, awaited
/// exit (e.g. on SIGTERM in a long-running binary).
#[derive(Debug)]
pub struct KeepaliveHandle {
    // Sender half of the shutdown channel; `None` once the signal was sent.
    shutdown: Option<oneshot::Sender<()>>,
    // Join handle of the spawned task; `None` once `stop` has taken it.
    join: Option<tokio::task::JoinHandle<()>>,
}
243
244impl KeepaliveHandle {
245    /// Stop the keepalive task and await its exit.
246    ///
247    /// # Errors
248    /// Returns [`Error::Other`] wrapping the JoinError if the task panicked.
249    pub async fn stop(mut self) -> Result<()> {
250        // Explicitly send the shutdown signal rather than relying on
251        // drop semantics — makes intent obvious and tolerates the
252        // receiver having already dropped (which is fine, nothing to
253        // do there).
254        if let Some(tx) = self.shutdown.take() {
255            let _ = tx.send(());
256        }
257        if let Some(join) = self.join.take() {
258            join.await
259                .map_err(|e| Error::Other(format!("keepalive task panicked: {e}")))?;
260        }
261        Ok(())
262    }
263}
264
265impl Drop for KeepaliveHandle {
266    /// Send the shutdown signal so a forgotten handle doesn't keep
267    /// tickling forever. The spawned task observes the signal on its
268    /// next tick — for a clean awaited exit use [`KeepaliveHandle::stop`].
269    /// We can't `.await` here (sync Drop), so any in-flight tickle
270    /// finishes detached.
271    fn drop(&mut self) {
272        if let Some(tx) = self.shutdown.take() {
273            let _ = tx.send(());
274        }
275        // Don't abort the join handle — we want any in-flight tickle
276        // to finish naturally rather than getting cut mid-write.
277    }
278}
279
#[cfg(test)]
mod keepalive_tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a shutdown may take. Far below the 60s tick
    /// interval used in these tests, so a task that only reacted "on the next
    /// tick" would fail instead of passing slowly.
    const SHUTDOWN_DEADLINE: Duration = Duration::from_secs(5);

    /// Build a `Client` against an unreachable URL — never actually
    /// hits the network because we cancel the keepalive before its
    /// first tick fires.
    fn dummy_client() -> Client {
        Client::builder("http://127.0.0.1:1/v1/api")
            .build()
            .expect("client")
    }

    #[tokio::test]
    async fn stop_sends_shutdown_and_joins_cleanly() {
        let client = dummy_client();
        let handle = client.spawn_keepalive(Duration::from_secs(60));
        // Give the spawned task a beat to enter its select loop.
        tokio::time::sleep(Duration::from_millis(50)).await;
        tokio::time::timeout(SHUTDOWN_DEADLINE, handle.stop())
            .await
            .expect("stop should not wait for the next tick")
            .expect("clean stop");
    }

    #[tokio::test]
    async fn drop_sends_shutdown_signal() {
        let client = dummy_client();
        let mut handle = client.spawn_keepalive(Duration::from_secs(60));
        // Take the JoinHandle out of the (module-private) field so the task
        // can still be observed after the KeepaliveHandle itself is gone.
        let join = handle.join.take().expect("join handle present");
        // Give the spawned task a beat to enter its select loop.
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Drop sends the shutdown signal.
        drop(handle);
        // The task must notice the signal and exit well before its next
        // tick; a leaked task would trip the timeout.
        tokio::time::timeout(SHUTDOWN_DEADLINE, join)
            .await
            .expect("task should exit promptly after the handle is dropped")
            .expect("keepalive task should not panic");
    }
}