Skip to main content

trading_ig/session/
auth.rs

1//! Login / refresh / switch / logout flows.
2
3use std::time::{Duration, Instant};
4
5use http::Method;
6use serde::{Deserialize, Serialize};
7use tracing::instrument;
8
9use crate::error::{Error, Result};
10use crate::session::tokens::{AuthTokens, OAuthPayload, SessionState};
11use crate::session::{Credentials, SessionHandle, SessionInfo};
12
13/// Public entry point for session management. Obtain via
14/// [`crate::IgClient::session`].
15#[derive(Debug)]
16pub struct SessionApi {
17    pub(crate) handle: SessionHandle,
18}
19
20#[derive(Debug, Serialize)]
21#[serde(rename_all = "camelCase")]
22struct LoginRequest<'a> {
23    identifier: &'a str,
24    password: &'a str,
25    encrypted_password: bool,
26}
27
28#[derive(Debug, Deserialize)]
29#[serde(rename_all = "camelCase")]
30struct LoginResponseV3 {
31    account_id: String,
32    client_id: String,
33    timezone_offset: Option<i32>,
34    lightstreamer_endpoint: String,
35    currency_iso_code: Option<String>,
36    locale: Option<String>,
37    oauth_token: OAuthPayload,
38}
39
40#[derive(Debug, Deserialize)]
41#[serde(rename_all = "camelCase")]
42struct LoginResponseV2 {
43    /// IG returns this as `currentAccountId` in real responses; the
44    /// `accountId` alias preserves compatibility with older fixtures and
45    /// any non-IG mocks that follow the v3 naming.
46    #[serde(rename = "currentAccountId", alias = "accountId")]
47    account_id: String,
48    client_id: String,
49    timezone_offset: Option<i32>,
50    lightstreamer_endpoint: String,
51    currency_iso_code: Option<String>,
52    locale: Option<String>,
53}
54
55impl SessionApi {
56    /// Log in using the canonical v3 flow (OAuth bearer tokens).
57    #[instrument(skip_all)]
58    pub async fn login(&self) -> Result<SessionInfo> {
59        let creds = self
60            .handle
61            .credentials
62            .as_ref()
63            .ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
64
65        match creds {
66            Credentials::Password { username, password } => {
67                self.login_v3(username, password, false).await
68            }
69        }
70    }
71
72    /// Log in v3 with an **RSA-encrypted password** instead of plaintext.
73    ///
74    /// **Recommended for accounts that hold real funds** (live or funded
75    /// demo). The password is encrypted client-side with IG's published
76    /// RSA public key (PKCS#1 v1.5) before being sent over the wire, so
77    /// it never appears in plaintext in any intermediate proxy or
78    /// server-side log.
79    ///
80    /// Workflow (handled internally):
81    /// 1. `GET /session/encryptionKey` to fetch the public key + timestamp.
82    /// 2. `encrypt_password(password, key, timestamp)` (RSA PKCS#1v15).
83    /// 3. `POST /session` v3 with `encryptedPassword=true`.
84    ///
85    /// Behind the optional `encryption` cargo feature.
86    ///
87    /// # Errors
88    ///
89    /// - `Error::Api` if either the key fetch or the login itself returns
90    ///   a non-2xx response.
91    /// - `Error::Auth` if the encryption step fails (malformed key, bad
92    ///   key/timestamp combination, etc.).
93    #[cfg(feature = "encryption")]
94    #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
95    #[instrument(skip_all)]
96    pub async fn login_with_encryption(&self) -> Result<SessionInfo> {
97        let creds = self
98            .handle
99            .credentials
100            .as_ref()
101            .ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
102        let Credentials::Password { username, password } = creds;
103
104        let key = self.encryption_key().await?;
105        let encrypted = crate::session::encryption::encrypt_password(
106            password,
107            &key.encryption_key,
108            key.time_stamp,
109        )?;
110        self.login_v3(username, &encrypted, true).await
111    }
112
113    /// Log in using the legacy v2 flow (CST + X-SECURITY-TOKEN response headers).
114    /// Mainly used by the streaming client which still wants CST/XST.
115    #[instrument(skip_all)]
116    pub async fn login_v2(&self) -> Result<SessionInfo> {
117        let creds = self
118            .handle
119            .credentials
120            .as_ref()
121            .ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
122
123        let Credentials::Password { username, password } = creds;
124        let body = LoginRequest {
125            identifier: username,
126            password,
127            encrypted_password: false,
128        };
129
130        let resp = self
131            .handle
132            .transport
133            .request_unauthenticated(Method::POST, "session", Some(2), Some(&body))
134            .await?;
135
136        let cst = resp
137            .headers
138            .get("CST")
139            .and_then(|v| v.to_str().ok())
140            .ok_or_else(|| Error::Auth("missing CST header in login response".into()))?
141            .to_owned();
142        let xst = resp
143            .headers
144            .get("X-SECURITY-TOKEN")
145            .and_then(|v| v.to_str().ok())
146            .ok_or_else(|| Error::Auth("missing X-SECURITY-TOKEN header".into()))?
147            .to_owned();
148
149        let body: LoginResponseV2 = serde_json::from_slice(&resp.body)?;
150        let new_state = SessionState {
151            tokens: Some(AuthTokens::Cst {
152                cst,
153                x_security_token: xst,
154            }),
155            account_id: Some(body.account_id.clone()),
156            client_id: Some(body.client_id.clone()),
157            lightstreamer_endpoint: Some(body.lightstreamer_endpoint.clone()),
158        };
159        self.handle.session.replace(new_state).await;
160
161        Ok(SessionInfo {
162            account_id: body.account_id,
163            client_id: body.client_id,
164            timezone_offset: body.timezone_offset,
165            lightstreamer_endpoint: body.lightstreamer_endpoint,
166            currency_iso_code: body.currency_iso_code,
167            locale: body.locale,
168        })
169    }
170
171    async fn login_v3(
172        &self,
173        username: &str,
174        password: &str,
175        encrypted_password: bool,
176    ) -> Result<SessionInfo> {
177        let body = LoginRequest {
178            identifier: username,
179            password,
180            encrypted_password,
181        };
182
183        let resp = self
184            .handle
185            .transport
186            .request_unauthenticated(Method::POST, "session", Some(3), Some(&body))
187            .await?;
188
189        let body: LoginResponseV3 = serde_json::from_slice(&resp.body)?;
190
191        let expires_in = body.oauth_token.expires_in.parse::<u64>().map_err(|e| {
192            Error::Auth(format!(
193                "invalid expires_in '{}': {e}",
194                body.oauth_token.expires_in
195            ))
196        })?;
197
198        let tokens = AuthTokens::OAuth {
199            access_token: body.oauth_token.access_token,
200            refresh_token: body.oauth_token.refresh_token,
201            token_type: body.oauth_token.token_type,
202            expires_at: Instant::now() + Duration::from_secs(expires_in),
203        };
204        let new_state = SessionState {
205            tokens: Some(tokens),
206            account_id: Some(body.account_id.clone()),
207            client_id: Some(body.client_id.clone()),
208            lightstreamer_endpoint: Some(body.lightstreamer_endpoint.clone()),
209        };
210        self.handle.session.replace(new_state).await;
211
212        Ok(SessionInfo {
213            account_id: body.account_id,
214            client_id: body.client_id,
215            timezone_offset: body.timezone_offset,
216            lightstreamer_endpoint: body.lightstreamer_endpoint,
217            currency_iso_code: body.currency_iso_code,
218            locale: body.locale,
219        })
220    }
221
222    /// Refresh the v3 access token using the stored refresh token.
223    #[instrument(skip_all)]
224    pub async fn refresh(&self) -> Result<()> {
225        let state = self.handle.session.snapshot().await;
226        let Some(AuthTokens::OAuth { refresh_token, .. }) = state.tokens else {
227            return Err(Error::Auth("no refresh token available".into()));
228        };
229
230        #[derive(Serialize)]
231        #[serde(rename_all = "snake_case")]
232        struct Req<'a> {
233            refresh_token: &'a str,
234        }
235
236        let resp = self
237            .handle
238            .transport
239            .request_unauthenticated(
240                Method::POST,
241                "session/refresh-token",
242                Some(1),
243                Some(&Req {
244                    refresh_token: &refresh_token,
245                }),
246            )
247            .await?;
248
249        let payload: OAuthPayload = serde_json::from_slice(&resp.body)?;
250        let expires_in = payload
251            .expires_in
252            .parse::<u64>()
253            .map_err(|e| Error::Auth(format!("invalid expires_in: {e}")))?;
254
255        let new_tokens = AuthTokens::OAuth {
256            access_token: payload.access_token,
257            refresh_token: payload.refresh_token,
258            token_type: payload.token_type,
259            expires_at: Instant::now() + Duration::from_secs(expires_in),
260        };
261        self.handle
262            .session
263            .modify(|s| s.tokens = Some(new_tokens))
264            .await;
265        Ok(())
266    }
267
268    /// Tear down the current session on the server side and locally.
269    #[instrument(skip_all)]
270    pub async fn logout(&self) -> Result<()> {
271        // Best-effort: even if the server call fails (e.g. tokens already
272        // expired) we still clear local state.
273        let _ = self
274            .handle
275            .transport
276            .request::<(), serde_json::Value>(
277                Method::DELETE,
278                "session",
279                Some(1),
280                None::<&()>,
281                &self.handle.session,
282            )
283            .await;
284        self.handle.session.replace(SessionState::default()).await;
285        Ok(())
286    }
287
288    /// Read details about the current session.
289    ///
290    /// When `fetch_tokens` is `true`, the server responds with `CST` and
291    /// `X-SECURITY-TOKEN` headers. These are written into the local session
292    /// state — necessary when an OAuth (v3) session needs CST/XST tokens
293    /// for the Lightstreamer streaming endpoint.
294    #[instrument(skip_all, fields(fetch_tokens = fetch_tokens))]
295    pub async fn read(&self, fetch_tokens: bool) -> Result<SessionDetails> {
296        let path = if fetch_tokens {
297            "session?fetchSessionTokens=true"
298        } else {
299            "session"
300        };
301        let raw = self
302            .handle
303            .transport
304            .request_authenticated_raw::<()>(
305                Method::GET,
306                path,
307                Some(1),
308                None::<&()>,
309                &self.handle.session,
310            )
311            .await?;
312
313        let details: SessionDetails = serde_json::from_slice(&raw.body)?;
314
315        if fetch_tokens {
316            let cst = raw
317                .headers
318                .get("CST")
319                .and_then(|v| v.to_str().ok())
320                .map(str::to_owned);
321            let xst = raw
322                .headers
323                .get("X-SECURITY-TOKEN")
324                .and_then(|v| v.to_str().ok())
325                .map(str::to_owned);
326            if let (Some(cst), Some(x_security_token)) = (cst, xst) {
327                self.handle
328                    .session
329                    .modify(|s| {
330                        // For v3 sessions we keep the OAuth tokens too — but
331                        // some callers (notably Lightstreamer) need CST/XST.
332                        // We replace the token bag entirely; OAuth holders
333                        // who still want refresh capability should call
334                        // `read(false)` only after they're done streaming.
335                        s.tokens = Some(AuthTokens::Cst {
336                            cst,
337                            x_security_token,
338                        });
339                    })
340                    .await;
341            }
342        }
343
344        Ok(details)
345    }
346
347    /// Switch the active trading account.
348    ///
349    /// Updates the local session state so that subsequent v3 requests carry
350    /// the new `IG-ACCOUNT-ID` header.
351    #[instrument(skip_all, fields(account_id = %account_id))]
352    pub async fn switch_account(
353        &self,
354        account_id: &str,
355        default_account: bool,
356    ) -> Result<SwitchAccountResponse> {
357        #[derive(Serialize)]
358        #[serde(rename_all = "camelCase")]
359        struct Req<'a> {
360            account_id: &'a str,
361            default_account: bool,
362        }
363
364        let resp: SwitchAccountResponse = self
365            .handle
366            .transport
367            .request(
368                Method::PUT,
369                "session",
370                Some(1),
371                Some(&Req {
372                    account_id,
373                    default_account,
374                }),
375                &self.handle.session,
376            )
377            .await?;
378
379        let new_id = account_id.to_owned();
380        self.handle
381            .session
382            .modify(|s| s.account_id = Some(new_id))
383            .await;
384        Ok(resp)
385    }
386
387    /// Fetch the encryption key + timestamp used for encrypted-password login.
388    ///
389    /// Combine with [`crate::session::encryption::encrypt_password`] (behind
390    /// the `encryption` feature) to build the `password` field expected by
391    /// `POST /session` when `encryptedPassword=true`.
392    #[cfg(feature = "encryption")]
393    #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
394    #[instrument(skip_all)]
395    pub async fn encryption_key(&self) -> Result<EncryptionKey> {
396        // No Version header for this endpoint.
397        let resp = self
398            .handle
399            .transport
400            .request_unauthenticated::<()>(Method::GET, "session/encryptionKey", None, None)
401            .await?;
402        Ok(serde_json::from_slice(&resp.body)?)
403    }
404}
405
406/// Details returned by `GET /session`.
407#[derive(Debug, Clone, Deserialize, Serialize)]
408#[serde(rename_all = "camelCase")]
409pub struct SessionDetails {
410    pub account_id: String,
411    pub client_id: String,
412    pub account_type: Option<String>,
413    pub currency: Option<String>,
414    pub locale: Option<String>,
415    pub timezone_offset: Option<i32>,
416    pub lightstreamer_endpoint: Option<String>,
417}
418
419/// Body returned by `PUT /session` (switch account).
420///
421/// Most useful field is `dealing_enabled`. `has_active_demo_accounts` and
422/// `has_active_live_accounts` are present in some IG responses; modelled
423/// as `Option` for forward compatibility.
424#[derive(Debug, Clone, Deserialize, Serialize, Default)]
425#[serde(rename_all = "camelCase", default)]
426pub struct SwitchAccountResponse {
427    pub trailing_stops_enabled: bool,
428    pub dealing_enabled: bool,
429    pub has_active_demo_accounts: Option<bool>,
430    pub has_active_live_accounts: Option<bool>,
431}
432
/// Decoded body of `GET /session/encryptionKey`.
///
/// Feeds the encrypted-password login flow (`encryption` feature).
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionKey {
    /// Base64-encoded RSA public key (DER-encoded SPKI); wire name
    /// `encryptionKey`.
    pub encryption_key: String,
    /// Server-supplied timestamp in milliseconds (wire name `timeStamp`).
    /// Concatenate it to the password before encryption:
    /// `format!("{password}|{time_stamp}")`.
    pub time_stamp: i64,
}