use std::time::{Duration, Instant};
use http::Method;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::error::{Error, Result};
use crate::session::tokens::{AuthTokens, OAuthPayload, SessionState};
use crate::session::{Credentials, SessionHandle, SessionInfo};
#[derive(Debug)]
pub struct SessionApi {
pub(crate) handle: SessionHandle,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct LoginRequest<'a> {
identifier: &'a str,
password: &'a str,
encrypted_password: bool,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct LoginResponseV3 {
account_id: String,
client_id: String,
timezone_offset: Option<i32>,
lightstreamer_endpoint: String,
currency_iso_code: Option<String>,
locale: Option<String>,
oauth_token: OAuthPayload,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct LoginResponseV2 {
#[serde(rename = "currentAccountId", alias = "accountId")]
account_id: String,
client_id: String,
timezone_offset: Option<i32>,
lightstreamer_endpoint: String,
currency_iso_code: Option<String>,
locale: Option<String>,
}
impl SessionApi {
#[instrument(skip_all)]
pub async fn login(&self) -> Result<SessionInfo> {
let creds = self
.handle
.credentials
.as_ref()
.ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
match creds {
Credentials::Password { username, password } => {
self.login_v3(username, password, false).await
}
}
}
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[instrument(skip_all)]
pub async fn login_with_encryption(&self) -> Result<SessionInfo> {
let creds = self
.handle
.credentials
.as_ref()
.ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
let Credentials::Password { username, password } = creds;
let key = self.encryption_key().await?;
let encrypted = crate::session::encryption::encrypt_password(
password,
&key.encryption_key,
key.time_stamp,
)?;
self.login_v3(username, &encrypted, true).await
}
#[instrument(skip_all)]
pub async fn login_v2(&self) -> Result<SessionInfo> {
let creds = self
.handle
.credentials
.as_ref()
.ok_or_else(|| Error::Auth("no credentials configured on the client".into()))?;
let Credentials::Password { username, password } = creds;
let body = LoginRequest {
identifier: username,
password,
encrypted_password: false,
};
let resp = self
.handle
.transport
.request_unauthenticated(Method::POST, "session", Some(2), Some(&body))
.await?;
let cst = resp
.headers
.get("CST")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| Error::Auth("missing CST header in login response".into()))?
.to_owned();
let xst = resp
.headers
.get("X-SECURITY-TOKEN")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| Error::Auth("missing X-SECURITY-TOKEN header".into()))?
.to_owned();
let body: LoginResponseV2 = serde_json::from_slice(&resp.body)?;
let new_state = SessionState {
tokens: Some(AuthTokens::Cst {
cst,
x_security_token: xst,
}),
account_id: Some(body.account_id.clone()),
client_id: Some(body.client_id.clone()),
lightstreamer_endpoint: Some(body.lightstreamer_endpoint.clone()),
};
self.handle.session.replace(new_state).await;
Ok(SessionInfo {
account_id: body.account_id,
client_id: body.client_id,
timezone_offset: body.timezone_offset,
lightstreamer_endpoint: body.lightstreamer_endpoint,
currency_iso_code: body.currency_iso_code,
locale: body.locale,
})
}
async fn login_v3(
&self,
username: &str,
password: &str,
encrypted_password: bool,
) -> Result<SessionInfo> {
let body = LoginRequest {
identifier: username,
password,
encrypted_password,
};
let resp = self
.handle
.transport
.request_unauthenticated(Method::POST, "session", Some(3), Some(&body))
.await?;
let body: LoginResponseV3 = serde_json::from_slice(&resp.body)?;
let expires_in = body.oauth_token.expires_in.parse::<u64>().map_err(|e| {
Error::Auth(format!(
"invalid expires_in '{}': {e}",
body.oauth_token.expires_in
))
})?;
let tokens = AuthTokens::OAuth {
access_token: body.oauth_token.access_token,
refresh_token: body.oauth_token.refresh_token,
token_type: body.oauth_token.token_type,
expires_at: Instant::now() + Duration::from_secs(expires_in),
};
let new_state = SessionState {
tokens: Some(tokens),
account_id: Some(body.account_id.clone()),
client_id: Some(body.client_id.clone()),
lightstreamer_endpoint: Some(body.lightstreamer_endpoint.clone()),
};
self.handle.session.replace(new_state).await;
Ok(SessionInfo {
account_id: body.account_id,
client_id: body.client_id,
timezone_offset: body.timezone_offset,
lightstreamer_endpoint: body.lightstreamer_endpoint,
currency_iso_code: body.currency_iso_code,
locale: body.locale,
})
}
#[instrument(skip_all)]
pub async fn refresh(&self) -> Result<()> {
let state = self.handle.session.snapshot().await;
let Some(AuthTokens::OAuth { refresh_token, .. }) = state.tokens else {
return Err(Error::Auth("no refresh token available".into()));
};
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
struct Req<'a> {
refresh_token: &'a str,
}
let resp = self
.handle
.transport
.request_unauthenticated(
Method::POST,
"session/refresh-token",
Some(1),
Some(&Req {
refresh_token: &refresh_token,
}),
)
.await?;
let payload: OAuthPayload = serde_json::from_slice(&resp.body)?;
let expires_in = payload
.expires_in
.parse::<u64>()
.map_err(|e| Error::Auth(format!("invalid expires_in: {e}")))?;
let new_tokens = AuthTokens::OAuth {
access_token: payload.access_token,
refresh_token: payload.refresh_token,
token_type: payload.token_type,
expires_at: Instant::now() + Duration::from_secs(expires_in),
};
self.handle
.session
.modify(|s| s.tokens = Some(new_tokens))
.await;
Ok(())
}
#[instrument(skip_all)]
pub async fn logout(&self) -> Result<()> {
let _ = self
.handle
.transport
.request::<(), serde_json::Value>(
Method::DELETE,
"session",
Some(1),
None::<&()>,
&self.handle.session,
)
.await;
self.handle.session.replace(SessionState::default()).await;
Ok(())
}
#[instrument(skip_all, fields(fetch_tokens = fetch_tokens))]
pub async fn read(&self, fetch_tokens: bool) -> Result<SessionDetails> {
let path = if fetch_tokens {
"session?fetchSessionTokens=true"
} else {
"session"
};
let raw = self
.handle
.transport
.request_authenticated_raw::<()>(
Method::GET,
path,
Some(1),
None::<&()>,
&self.handle.session,
)
.await?;
let details: SessionDetails = serde_json::from_slice(&raw.body)?;
if fetch_tokens {
let cst = raw
.headers
.get("CST")
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
let xst = raw
.headers
.get("X-SECURITY-TOKEN")
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
if let (Some(cst), Some(x_security_token)) = (cst, xst) {
self.handle
.session
.modify(|s| {
s.tokens = Some(AuthTokens::Cst {
cst,
x_security_token,
});
})
.await;
}
}
Ok(details)
}
#[instrument(skip_all, fields(account_id = %account_id))]
pub async fn switch_account(
&self,
account_id: &str,
default_account: bool,
) -> Result<SwitchAccountResponse> {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Req<'a> {
account_id: &'a str,
default_account: bool,
}
let resp: SwitchAccountResponse = self
.handle
.transport
.request(
Method::PUT,
"session",
Some(1),
Some(&Req {
account_id,
default_account,
}),
&self.handle.session,
)
.await?;
let new_id = account_id.to_owned();
self.handle
.session
.modify(|s| s.account_id = Some(new_id))
.await;
Ok(resp)
}
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[instrument(skip_all)]
pub async fn encryption_key(&self) -> Result<EncryptionKey> {
let resp = self
.handle
.transport
.request_unauthenticated::<()>(Method::GET, "session/encryptionKey", None, None)
.await?;
Ok(serde_json::from_slice(&resp.body)?)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionDetails {
pub account_id: String,
pub client_id: String,
pub account_type: Option<String>,
pub currency: Option<String>,
pub locale: Option<String>,
pub timezone_offset: Option<i32>,
pub lightstreamer_endpoint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase", default)]
pub struct SwitchAccountResponse {
pub trailing_stops_enabled: bool,
pub dealing_enabled: bool,
pub has_active_demo_accounts: Option<bool>,
pub has_active_live_accounts: Option<bool>,
}
#[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptionKey {
pub encryption_key: String,
pub time_stamp: i64,
}