use crate::common::*;
use crate::rest_models::{
AuthenticationPostRequest, AuthenticationPostResponseV3, ValidateRequest, ValidateResponse,
};
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::StatusCode;
use serde::Serialize;
use serde_json::{json, Value};
use std::error::Error;
/// Whether `RestClient::new` logs in before returning when `ApiConfig::auto_login` is unset.
const DEFAULT_AUTO_LOGIN: bool = true;
/// Session API version used when `ApiConfig::session_version` is unset.
const DEFAULT_SESSION_VERSION: usize = 2;
/// REST API client: the configuration it was built from plus per-session authentication state.
#[derive(Debug, Clone)]
pub struct RestClient {
    /// Authentication headers captured by the last successful login; `None` until then.
    /// The v2 login flow (session versions 1 and 2) stores `cst` and `x-security-token`;
    /// the v3 flow stores `Authorization` (a bearer token) and `IG-ACCOUNT-ID`.
    pub auth_headers: Option<HeaderMap>,
    /// Whether `RestClient::new` performs a login before returning.
    pub auto_login: bool,
    /// Base URL of the selected execution environment (demo or live); request paths are
    /// appended to it after a `/`.
    pub base_url: String,
    /// Underlying HTTP client used for every request.
    pub client: reqwest::Client,
    /// Headers sent with every request: `Accept`, `Content-Type` and `X-IG-API-KEY`.
    pub common_headers: HeaderMap,
    /// Configuration this client was built from.
    pub config: ApiConfig,
    /// Lightstreamer endpoint reported by the last login; empty until a login succeeds.
    pub lightstreamer_endpoint: String,
    /// OAuth refresh token returned by a v3 login; the v2 flow does not set it.
    pub refresh_token: Option<String>,
    /// Session API version selecting the login flow: 1 and 2 use the v2 flow, 3 the v3 flow,
    /// and any other value makes `login` fail.
    pub session_version: usize,
}
impl RestClient {
    /// Sends a DELETE request for `method` (a path relative to the base URL).
    ///
    /// The request goes out as an HTTP `POST` carrying an `_method: DELETE` override header
    /// rather than as a native `DELETE`. `body` is always attached as JSON; `None` serializes to
    /// `null`. NOTE(review): the override presumably exists because the server expects DELETE
    /// bodies to be tunnelled this way — confirm against the API documentation before changing.
    ///
    /// # Arguments
    /// * `method` - endpoint path appended to the base URL after a `/`.
    /// * `api_version` - value of the `Version` header; defaults to `1`.
    /// * `body` - optional request body, validated before sending when present.
    ///
    /// # Returns
    /// The response headers and body: an empty JSON object for `204 No Content`, or the parsed
    /// JSON for `200 OK`.
    ///
    /// # Errors
    /// Validation, serialization, transport or JSON-decoding failures, and an [`ApiError`] for
    /// any other status code.
    pub async fn delete(
        &self,
        method: String,
        api_version: Option<usize>,
        body: &Option<impl Serialize + ValidateRequest>,
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let version = api_version.unwrap_or(1).to_string();
        if let Some(body) = body {
            body.validate()?;
        }
        let payload = serde_json::to_value(body)?;
        let response = self
            .build_request(
                reqwest::Method::POST,
                &self.endpoint_url(&method),
                &version,
                Some(&payload),
            )
            .header("_method", "DELETE")
            .send()
            .await?;
        match response.status() {
            StatusCode::NO_CONTENT => Ok((response.headers().clone(), json!({}))),
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            _ => Err(Self::error_from_response(
                format!("DELETE operation using method '{}'", method),
                response,
            )
            .await),
        }
    }

    /// Creates a client from `config`.
    ///
    /// The base URL is chosen by `config.execution_environment`. The session version and the
    /// auto-login flag fall back to `DEFAULT_SESSION_VERSION` and `DEFAULT_AUTO_LOGIN` when
    /// unset. When auto-login is enabled, [`RestClient::login`] runs before the client is
    /// returned.
    ///
    /// # Errors
    /// Fails if the API key is not a valid header value, or if the automatic login fails.
    pub async fn new(config: ApiConfig) -> Result<Self, Box<dyn Error>> {
        let base_url = match config.execution_environment {
            ExecutionEnvironment::Demo => config.base_url_demo.clone(),
            ExecutionEnvironment::Live => config.base_url_live.clone(),
        };
        let session_version = config.session_version.unwrap_or(DEFAULT_SESSION_VERSION);
        let auto_login = config.auto_login.unwrap_or(DEFAULT_AUTO_LOGIN);

        let json_utf8 = HeaderValue::from_static("application/json; charset=UTF-8");
        let mut api_key = HeaderValue::from_str(&config.api_key)?;
        // Credential: marked sensitive so the derived `Debug` output of the client hides it.
        api_key.set_sensitive(true);
        let mut common_headers = HeaderMap::new();
        common_headers.insert("Accept", json_utf8.clone());
        common_headers.insert("Content-Type", json_utf8);
        common_headers.insert("X-IG-API-KEY", api_key);

        let mut rest_client = Self {
            auth_headers: None,
            auto_login,
            base_url,
            client: reqwest::Client::new(),
            common_headers,
            config,
            lightstreamer_endpoint: String::new(),
            refresh_token: None,
            session_version,
        };
        if auto_login {
            rest_client.login().await?;
        }
        Ok(rest_client)
    }

    /// Sends a GET request for `method`, with `params` encoded as the query string.
    ///
    /// # Arguments
    /// * `method` - endpoint path appended to the base URL after a `/`.
    /// * `api_version` - value of the `Version` header; defaults to `1`.
    /// * `params` - optional query parameters, validated before sending when present. The
    ///   `?query` suffix is only added when the encoded string is non-empty.
    ///
    /// # Returns
    /// The response headers and parsed JSON body of a `200 OK` response.
    ///
    /// # Errors
    /// Validation, query-encoding, transport or JSON-decoding failures, and an [`ApiError`] for
    /// any non-`200` status code.
    pub async fn get(
        &self,
        method: String,
        api_version: Option<usize>,
        params: &Option<impl Serialize + ValidateRequest>,
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let api_version = api_version.unwrap_or(1).to_string();
        if let Some(params) = params {
            params.validate()?;
        }
        let query_string = params_to_query_string(params)?;
        let url = if query_string.is_empty() {
            self.endpoint_url(&method)
        } else {
            format!("{}?{}", self.endpoint_url(&method), query_string)
        };
        let response = self
            .build_request(reqwest::Method::GET, &url, &api_version, None)
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            _ => Err(Self::error_from_response(
                format!(
                    "GET operation to url '{}' and query_string '{}'",
                    url, query_string
                ),
                response,
            )
            .await),
        }
    }

    /// Logs in using the flow selected by `session_version`.
    ///
    /// Versions 1 and 2 both use [`RestClient::login_v2`]; version 3 uses
    /// [`RestClient::login_v3`].
    ///
    /// # Errors
    /// Returns an [`ApiError`] for any other session version, without making a request;
    /// otherwise propagates the selected login's error.
    pub async fn login(&mut self) -> Result<Value, Box<dyn Error>> {
        match self.session_version {
            1 | 2 => self.login_v2().await,
            3 => self.login_v3().await,
            _ => Err(Box::new(ApiError {
                message: format!("Invalid session version: {}", self.session_version),
            })),
        }
    }

    /// Logs in with a `Version: 2` session request.
    ///
    /// On success, the `cst` and `x-security-token` response headers become the client's auth
    /// headers. The `lightstreamerEndpoint` string from the response body is stored as well.
    /// Client state is only modified once every check has passed, so a failed login leaves the
    /// previous session untouched.
    ///
    /// # Returns
    /// The parsed JSON body of the login response.
    ///
    /// # Errors
    /// Validation, transport or JSON-decoding failures; an [`ApiError`] for a non-`200` status,
    /// a missing auth header, or a missing or non-string `lightstreamerEndpoint`.
    pub async fn login_v2(&mut self) -> Result<Value, Box<dyn Error>> {
        let response = self.send_login_request("2").await?;
        if response.status() != StatusCode::OK {
            return Err(Self::error_from_response("Login".to_string(), response).await);
        }
        // Copy the token headers out before `json()` consumes the response.
        let (cst, security_token) = match (
            response.headers().get("cst").cloned(),
            response.headers().get("x-security-token").cloned(),
        ) {
            (Some(mut cst), Some(mut security_token)) => {
                // Session tokens: hidden from the derived `Debug` output of the client.
                cst.set_sensitive(true);
                security_token.set_sensitive(true);
                (cst, security_token)
            }
            _ => {
                return Err(Box::new(ApiError {
                    message:
                        "Any of the cst / x-security-token headers not found in login response."
                            .to_string(),
                }))
            }
        };
        let response_json: Value = response.json().await?;
        let lightstreamer_endpoint = response_json
            .get("lightstreamerEndpoint")
            .ok_or_else(|| ApiError {
                message: "Lightstreamer endpoint not found in login response.".to_string(),
            })?
            .as_str()
            .ok_or_else(|| ApiError {
                message: "Lightstreamer endpoint is not a string.".to_string(),
            })?
            .to_string();

        let mut auth_headers = HeaderMap::new();
        auth_headers.insert("cst", cst);
        auth_headers.insert("x-security-token", security_token);
        self.auth_headers = Some(auth_headers);
        self.lightstreamer_endpoint = lightstreamer_endpoint;
        Ok(response_json)
    }

    /// Logs in with a `Version: 3` session request.
    ///
    /// On success, the client's auth headers become `Authorization: Bearer <access token>` plus
    /// `IG-ACCOUNT-ID`. The account number is chosen by the execution environment. The refresh
    /// token and lightstreamer endpoint from the response are stored as well.
    ///
    /// # Returns
    /// The parsed JSON body of the login response.
    ///
    /// # Errors
    /// Validation, transport, JSON-decoding or response-parsing failures; invalid header values;
    /// an [`ApiError`] for a non-`200` status.
    pub async fn login_v3(&mut self) -> Result<Value, Box<dyn Error>> {
        let response = self.send_login_request("3").await?;
        if response.status() != StatusCode::OK {
            return Err(Self::error_from_response("Login".to_string(), response).await);
        }
        let response_body: Value = response.json().await?;
        let login_response = AuthenticationPostResponseV3::from_value(&response_body)?;

        let mut authorization = HeaderValue::from_str(&format!(
            "Bearer {}",
            login_response.oauth_token.access_token
        ))?;
        // Bearer token: hidden from the derived `Debug` output of the client.
        authorization.set_sensitive(true);
        let account_number = match self.config.execution_environment {
            ExecutionEnvironment::Demo => &self.config.account_number_demo,
            ExecutionEnvironment::Live => &self.config.account_number_live,
        };
        let mut auth_headers = HeaderMap::new();
        auth_headers.insert("Authorization", authorization);
        auth_headers.insert("IG-ACCOUNT-ID", HeaderValue::from_str(account_number)?);

        self.auth_headers = Some(auth_headers);
        self.refresh_token = Some(login_response.oauth_token.refresh_token);
        self.lightstreamer_endpoint = login_response.lightstreamer_endpoint;
        Ok(response_body)
    }

    /// Sends a POST request for `method` with `body` as JSON.
    ///
    /// # Arguments
    /// * `method` - endpoint path appended to the base URL after a `/`.
    /// * `api_version` - value of the `Version` header; defaults to `1`.
    /// * `body` - request body, validated before sending.
    ///
    /// # Returns
    /// The response headers and parsed JSON body of a `200 OK` response.
    ///
    /// # Errors
    /// Validation, serialization, transport or JSON-decoding failures, and an [`ApiError`] for
    /// any non-`200` status code.
    pub async fn post(
        &self,
        method: String,
        api_version: Option<usize>,
        body: &(impl Serialize + ValidateRequest),
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let version = api_version.unwrap_or(1).to_string();
        body.validate()?;
        let payload = serde_json::to_value(body)?;
        let response = self
            .build_request(
                reqwest::Method::POST,
                &self.endpoint_url(&method),
                &version,
                Some(&payload),
            )
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            _ => Err(Self::error_from_response(
                format!(
                    "POST operation using method '{}', version '{}' and body '{:?}'",
                    method, version, payload
                ),
                response,
            )
            .await),
        }
    }

    /// Sends a PUT request for `method` with `body` as JSON.
    ///
    /// # Arguments
    /// * `method` - endpoint path appended to the base URL after a `/`.
    /// * `version` - value of the `Version` header; defaults to `1`.
    /// * `body` - request body, validated before sending.
    ///
    /// # Returns
    /// The response headers and parsed JSON body of a `200 OK` response.
    ///
    /// # Errors
    /// Validation, serialization, transport or JSON-decoding failures, and an [`ApiError`] for
    /// any non-`200` status code.
    pub async fn put(
        &self,
        method: String,
        version: Option<usize>,
        body: &(impl Serialize + ValidateRequest),
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let version = version.unwrap_or(1).to_string();
        body.validate()?;
        // Serialize once up front (as `post` does) so the same value is sent and reported.
        let payload = serde_json::to_value(body)?;
        let response = self
            .build_request(
                reqwest::Method::PUT,
                &self.endpoint_url(&method),
                &version,
                Some(&payload),
            )
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            _ => Err(Self::error_from_response(
                format!(
                    "PUT operation using method '{}', version '{}' and body '{:?}'",
                    method, version, payload
                ),
                response,
            )
            .await),
        }
    }

    /// Joins `path` onto the base URL with a single `/`.
    fn endpoint_url(&self, path: &str) -> String {
        format!("{}/{}", self.base_url, path)
    }

    /// Starts an authenticated request carrying the optional JSON body and the shared headers.
    ///
    /// The steps are applied in a fixed order, matching the original per-method builders:
    /// 1. the JSON body;
    /// 2. the session's auth headers, if logged in;
    /// 3. the common headers, which replace same-named headers set earlier — including any
    ///    `Content-Type` added by `json`;
    /// 4. the `Version` header.
    fn build_request(
        &self,
        http_method: reqwest::Method,
        url: &str,
        version: &str,
        body: Option<&Value>,
    ) -> reqwest::RequestBuilder {
        let mut builder = self.client.request(http_method, url);
        if let Some(body) = body {
            builder = builder.json(body);
        }
        if let Some(auth_headers) = &self.auth_headers {
            builder = builder.headers(auth_headers.clone());
        }
        builder
            .headers(self.common_headers.clone())
            .header("Version", version)
    }

    /// Posts the configured credentials to `{base_url}/session` with the given `Version`.
    ///
    /// Only the common headers are sent; auth headers from a previous session are not attached.
    async fn send_login_request(&self, version: &str) -> Result<reqwest::Response, Box<dyn Error>> {
        let login_request_body = AuthenticationPostRequest {
            identifier: self.config.username.clone(),
            password: self.config.password.clone(),
        };
        login_request_body.validate()?;
        let response = self
            .client
            .post(self.endpoint_url("session"))
            .json(&login_request_body)
            .headers(self.common_headers.clone())
            .header("Version", version)
            .send()
            .await?;
        Ok(response)
    }

    /// Builds an [`ApiError`] for an unexpected response.
    ///
    /// The message reads `"<context> failed with status code: <status> - <body>"`. The status is
    /// captured before the body is read. If reading the body fails, a placeholder describing that
    /// failure takes its place, so the status code is never lost behind a transport error.
    async fn error_from_response(context: String, response: reqwest::Response) -> Box<dyn Error> {
        let status = response.status();
        let body = response
            .text()
            .await
            .unwrap_or_else(|e| format!("<failed to read response body: {}>", e));
        Box::new(ApiError {
            message: format!(
                "{} failed with status code: {:?} - {:?}",
                context, status, body
            ),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::{ApiConfig, ExecutionEnvironment};

    /// Builds a demo-environment config with placeholder credentials and the given auto-login
    /// and session-version settings.
    fn test_config(auto_login: Option<bool>, session_version: Option<usize>) -> ApiConfig {
        ApiConfig {
            account_number_demo: "test_account_number_demo".to_string(),
            account_number_live: "test_account_number_live".to_string(),
            account_number_test: None,
            api_key: "test_api_key".to_string(),
            auto_login,
            execution_environment: ExecutionEnvironment::Demo,
            base_url_demo: "https://demo.example.com".to_string(),
            base_url_live: "https://live.example.com".to_string(),
            session_version,
            streaming_api_max_connection_attempts: None,
            password: "test_password".to_string(),
            username: "test_username".to_string(),
        }
    }

    #[tokio::test]
    async fn new_rest_client_works() {
        let rest_client = RestClient::new(test_config(Some(false), Some(2)))
            .await
            .unwrap();
        assert_eq!(rest_client.auth_headers, None);
        assert!(!rest_client.auto_login);
        assert_eq!(rest_client.base_url, "https://demo.example.com");
        assert_eq!(
            rest_client.common_headers.get("X-IG-API-KEY").unwrap(),
            "test_api_key"
        );
        assert_eq!(
            rest_client.config.account_number_demo,
            "test_account_number_demo"
        );
        assert_eq!(
            rest_client.config.account_number_live,
            "test_account_number_live"
        );
        assert_eq!(rest_client.config.account_number_test, None);
        assert_eq!(rest_client.config.api_key, "test_api_key");
        assert_eq!(rest_client.config.auto_login, Some(false));
        assert_eq!(
            rest_client.config.execution_environment,
            ExecutionEnvironment::Demo
        );
        assert_eq!(rest_client.config.base_url_demo, "https://demo.example.com");
        assert_eq!(rest_client.config.base_url_live, "https://live.example.com");
        assert_eq!(rest_client.config.session_version, Some(2));
        assert_eq!(rest_client.config.password, "test_password");
        assert_eq!(rest_client.config.username, "test_username");
        assert_eq!(rest_client.session_version, 2);
    }

    #[tokio::test]
    async fn new_rest_client_rejects_invalid_session_version() {
        // With auto-login enabled, `new` calls `login`, which rejects an unknown session
        // version before any network request is made.
        let result = RestClient::new(test_config(Some(true), Some(4))).await;
        assert!(result.is_err());
    }
}