bezant/auth.rs
1//! Session keepalive + health helpers layered over the generated client.
2
3use std::time::Duration;
4
5use tokio::sync::oneshot;
6
7use crate::client::Client;
8use crate::error::{Error, Result};
9
/// Condensed view of the Gateway's brokerage session status.
///
/// Produced from the generated [`bezant_api::BrokerageSessionStatus`] by the
/// `From` conversion in this module, which reports any flag the Gateway
/// omitted as `false`.
///
/// The type is `#[non_exhaustive]`: new fields may be added in a point
/// release without a SemVer break. Destructure it with a rest pattern, e.g.
/// `AuthStatus { authenticated, connected, .. }`, rather than naming every
/// field.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthStatus {
    /// `true` when the Gateway reports an authenticated session.
    pub authenticated: bool,
    /// `true` when the Gateway reports a live connection to IBKR's servers.
    pub connected: bool,
    /// The Gateway's `competing` flag for this session.
    pub competing: bool,
    /// Free-form status text from the server, if it sent any.
    pub message: Option<String>,
}
28
29impl From<bezant_api::BrokerageSessionStatus> for AuthStatus {
30 fn from(s: bezant_api::BrokerageSessionStatus) -> Self {
31 Self {
32 authenticated: s.authenticated.unwrap_or(false),
33 connected: s.connected.unwrap_or(false),
34 competing: s.competing.unwrap_or(false),
35 message: s.message,
36 }
37 }
38}
39
40/// Response from a tickle call, projected from the generated type.
41///
42/// `#[non_exhaustive]` so adding a field in a point release isn't a SemVer
43/// break.
44#[derive(Debug, Clone)]
45#[non_exhaustive]
46pub struct TickleResponse {
47 /// Session identifier the Gateway returned, if any.
48 pub session: Option<String>,
49 /// Raw response — exposed for callers that need the full tickle payload.
50 pub raw: bezant_api::TickleResponse,
51}
52
53impl Client {
54 /// Query current auth + connection status via
55 /// `POST /iserver/auth/status`.
56 ///
57 /// # Errors
58 /// Transport + decode errors; [`Error::Api`] for any underlying error.
59 #[tracing::instrument(skip(self), level = "debug")]
60 pub async fn auth_status(&self) -> Result<AuthStatus> {
61 // We drive a raw POST instead of the generated client so we
62 // can distinguish "no session" (Gateway redirects to login)
63 // from genuine protocol errors. CPAPI returns 302 for
64 // unauthenticated callers — not the 401 the OpenAPI spec
65 // implies — so the auto-generated `Unauthorized` branch never
66 // fires in practice.
67 let mut url = self.base_url().clone();
68 url.path_segments_mut()
69 .map_err(|()| Error::UrlNotABase {
70 url: self.base_url().to_string(),
71 })?
72 .push("iserver")
73 .push("auth")
74 .push("status");
75 url.set_query(None);
76 // Akamai (in front of CPAPI) refuses POSTs that reach it
77 // without a Content-Length header — even an empty `Vec<u8>`
78 // body isn't enough if reqwest decides to serialize it with
79 // Transfer-Encoding: chunked. Setting the header explicitly
80 // forces a `Content-Length: 0` wire representation.
81 //
82 // We also pin Origin/Referer to match the Gateway's own
83 // origin: CPGateway's CPAPI handler treats requests with
84 // missing/mismatched Origin as a CSRF attempt and 401s. Local
85 // browser flows succeed because the browser supplies an
86 // origin that matches the Gateway's expectations; from a
87 // typed server-side caller we need to mint one ourselves.
88 let gateway_origin = self
89 .gateway_root_url()
90 .as_str()
91 .trim_end_matches('/')
92 .to_owned();
93 let resp = self
94 .http()
95 .post(url.clone())
96 .header(reqwest::header::CONTENT_LENGTH, "0")
97 .header(reqwest::header::ORIGIN, &gateway_origin)
98 .header(reqwest::header::REFERER, format!("{gateway_origin}/"))
99 .send()
100 .await
101 .map_err(Error::Http)?;
102 let status = resp.status();
103 if status == reqwest::StatusCode::UNAUTHORIZED {
104 return Err(Error::NotAuthenticated);
105 }
106 if status.is_redirection() {
107 // Tighten: only treat redirects that actually aim at the SSO
108 // login flow as "not authenticated". Any other 3xx is a
109 // genuine protocol surprise we'd rather surface verbatim.
110 let location = resp
111 .headers()
112 .get(reqwest::header::LOCATION)
113 .and_then(|v| v.to_str().ok())
114 .unwrap_or_default()
115 .to_ascii_lowercase();
116 if location.contains("/sso/login") || location.contains("/sso/dispatcher") {
117 return Err(Error::NotAuthenticated);
118 }
119 return Err(Error::UpstreamStatus {
120 endpoint: "iserver/auth/status",
121 status: status.as_u16(),
122 body_preview: Some(format!("redirect to: {location}")),
123 });
124 }
125 if !status.is_success() {
126 return Err(Error::UpstreamStatus {
127 endpoint: "iserver/auth/status",
128 status: status.as_u16(),
129 body_preview: None,
130 });
131 }
132 let parsed: bezant_api::BrokerageSessionStatus = resp.json().await.map_err(|e| {
133 Error::Decode {
134 endpoint: format!("POST {}/iserver/auth/status", self.base_url()),
135 status: status.as_u16(),
136 message: e.to_string(),
137 }
138 })?;
139 Ok(AuthStatus::from(parsed))
140 }
141
142 /// Tickle the Gateway (POST `/tickle`) to keep the session alive.
143 ///
144 /// CPAPI sessions expire after ~5 minutes of inactivity; call this at
145 /// least once a minute from a background task, or use
146 /// [`Client::spawn_keepalive`].
147 ///
148 /// # Errors
149 /// Transport + decode errors.
150 #[tracing::instrument(skip(self), level = "debug")]
151 pub async fn tickle(&self) -> Result<TickleResponse> {
152 let resp = self
153 .api()
154 .get_session_token(bezant_api::GetSessionTokenRequest::default())
155 .await?;
156 match resp {
157 bezant_api::GetSessionTokenResponse::Ok(payload) => {
158 let session = match &payload {
159 bezant_api::TickleResponse::Successful(s) => s.session.clone(),
160 bezant_api::TickleResponse::Failed(_) => None,
161 };
162 Ok(TickleResponse {
163 session,
164 raw: payload,
165 })
166 }
167 bezant_api::GetSessionTokenResponse::Unauthorized => Err(Error::NotAuthenticated),
168 bezant_api::GetSessionTokenResponse::Unknown => Err(Error::Unknown {
169 endpoint: "iserver/auth/tickle",
170 }),
171 }
172 }
173
174 /// Convenience: return the status if authenticated, else a typed error
175 /// ([`Error::NotAuthenticated`] / [`Error::NoSession`]).
176 ///
177 /// # Errors
178 /// See variants.
179 #[tracing::instrument(skip(self), level = "debug")]
180 pub async fn health(&self) -> Result<AuthStatus> {
181 let status = self.auth_status().await?;
182 if !status.connected {
183 return Err(Error::NoSession);
184 }
185 if !status.authenticated {
186 return Err(Error::NotAuthenticated);
187 }
188 Ok(status)
189 }
190
191 /// Spawn a tokio task that calls [`Client::tickle`] on `interval` until
192 /// the returned [`KeepaliveHandle`] is dropped. The CPAPI session times
193 /// out around 5 minutes; 60s is a sensible default.
194 ///
195 /// Tickle failures are logged via `tracing::warn!` and don't abort the
196 /// task — a transient outage is recoverable once the Gateway is back.
197 #[must_use]
198 pub fn spawn_keepalive(&self, interval: Duration) -> KeepaliveHandle {
199 use tracing::Instrument;
200
201 let client = self.clone();
202 let (tx, mut rx) = oneshot::channel::<()>();
203 let span = tracing::info_span!("bezant_keepalive", interval_secs = interval.as_secs());
204 let join = tokio::spawn(
205 async move {
206 let mut ticker = tokio::time::interval(interval);
207 // First tick fires immediately; skip it so we don't hit the
208 // Gateway the microsecond after the caller created the client.
209 ticker.tick().await;
210 loop {
211 tokio::select! {
212 _ = &mut rx => break,
213 _ = ticker.tick() => {
214 if let Err(e) = client.tickle().await {
215 tracing::warn!(error = %e, "bezant keepalive tickle failed");
216 }
217 }
218 }
219 }
220 }
221 .instrument(span),
222 );
223 KeepaliveHandle {
224 shutdown: Some(tx),
225 join: Some(join),
226 }
227 }
228}
229
230/// Drop-to-stop handle for a background keepalive task.
231///
232/// Dropping the handle sends the shutdown signal via the [`Drop`] impl,
233/// so a forgotten `_handle` won't keep tickling forever. The spawned
234/// task observes the signal on its next tick — a pending tickle
235/// request can still be in flight when the runtime shuts down.
236/// Prefer [`KeepaliveHandle::stop`] when you need a clean, awaited
237/// exit (e.g. on SIGTERM in a long-running binary).
238#[derive(Debug)]
239pub struct KeepaliveHandle {
240 shutdown: Option<oneshot::Sender<()>>,
241 join: Option<tokio::task::JoinHandle<()>>,
242}
243
244impl KeepaliveHandle {
245 /// Stop the keepalive task and await its exit.
246 ///
247 /// # Errors
248 /// Returns [`Error::Other`] wrapping the JoinError if the task panicked.
249 pub async fn stop(mut self) -> Result<()> {
250 // Explicitly send the shutdown signal rather than relying on
251 // drop semantics — makes intent obvious and tolerates the
252 // receiver having already dropped (which is fine, nothing to
253 // do there).
254 if let Some(tx) = self.shutdown.take() {
255 let _ = tx.send(());
256 }
257 if let Some(join) = self.join.take() {
258 join.await
259 .map_err(|e| Error::Other(format!("keepalive task panicked: {e}")))?;
260 }
261 Ok(())
262 }
263}
264
265impl Drop for KeepaliveHandle {
266 /// Send the shutdown signal so a forgotten handle doesn't keep
267 /// tickling forever. The spawned task observes the signal on its
268 /// next tick — for a clean awaited exit use [`KeepaliveHandle::stop`].
269 /// We can't `.await` here (sync Drop), so any in-flight tickle
270 /// finishes detached.
271 fn drop(&mut self) {
272 if let Some(tx) = self.shutdown.take() {
273 let _ = tx.send(());
274 }
275 // Don't abort the join handle — we want any in-flight tickle
276 // to finish naturally rather than getting cut mid-write.
277 }
278}
279
#[cfg(test)]
mod keepalive_tests {
    use super::*;
    use std::time::Duration;

    /// Build a `Client` against an unreachable URL. The tests use a 60s
    /// interval and shut the keepalive down long before its first real tick,
    /// so no request is ever sent.
    fn dummy_client() -> Client {
        Client::builder("http://127.0.0.1:1/v1/api")
            .build()
            .expect("client")
    }

    #[tokio::test]
    async fn stop_sends_shutdown_and_joins_cleanly() {
        let client = dummy_client();
        let handle = client.spawn_keepalive(Duration::from_secs(60));
        // Give the spawned task a beat to enter its select loop.
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Bound the wait so a broken shutdown path fails the test instead
        // of hanging until the 60s tick.
        tokio::time::timeout(Duration::from_secs(5), handle.stop())
            .await
            .expect("stop() returned promptly")
            .expect("clean stop");
    }

    #[tokio::test]
    async fn drop_sends_shutdown_signal() {
        let client = dummy_client();
        let mut handle = client.spawn_keepalive(Duration::from_secs(60));
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Private fields are visible from this child module: take the
        // JoinHandle out so we can observe the task after the handle drops.
        let join = handle.join.take().expect("join handle present");
        // Dropping the handle runs `Drop`, which sends the shutdown signal.
        drop(handle);
        // The task's select! wakes on the signal immediately, far sooner
        // than the 60s tick; a missing signal would trip the timeout.
        tokio::time::timeout(Duration::from_secs(5), join)
            .await
            .expect("keepalive task exited after drop")
            .expect("keepalive task did not panic");
    }
}