use crate::common::{
deserialize, params_to_json, params_to_query_string, ApiConfig, ApiError, ExecutionEnvironment,
};
use crate::rest_models::{LoginRequest, LoginResponseV3};
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::StatusCode;
use serde::Serialize;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::error::Error;
/// Whether `RestClient::new` logs in immediately when the config leaves `auto_login` unset.
const DEFAULT_AUTO_LOGIN: bool = true;
/// Session API version used when the config leaves `session_version` unset.
const DEFAULT_SESSION_VERSION: usize = 2;
/// Async REST client: connection settings, the headers sent with every request,
/// and the auth headers obtained from the most recent successful login.
///
/// NOTE(review): the derived `Debug` prints `common_headers` (contains the
/// `X-IG-API-KEY` value), `auth_headers` (session tokens) and `config` (holds the
/// password — if `ApiConfig`'s `Debug` is derived it will be printed). None of the
/// header values are marked sensitive here, so avoid logging this value or consider
/// a redacting `Debug` impl.
#[derive(Clone, Debug)]
pub struct RestClient {
/// Headers set by a successful login (`cst` / `x-security-token` for session
/// versions 1-2, `Authorization` / `IG-ACCOUNT-ID` for version 3); `None` until then.
pub auth_headers: Option<HeaderMap>,
/// Whether `RestClient::new` logs in before returning (config value, default `true`).
pub auto_login: bool,
/// Base URL for the configured execution environment (demo or live).
pub base_url: String,
/// Underlying HTTP client used for every request.
pub client: reqwest::Client,
/// Headers sent with every request: `Accept`, `Content-Type` and `X-IG-API-KEY`.
pub common_headers: HeaderMap,
/// The configuration this client was built from.
pub config: ApiConfig,
/// Session API version used by `login` (config value, default 2).
pub session_version: usize,
}
impl RestClient {
    /// Sends a `DELETE` request to `{base_url}/{method}` using API version 1.
    ///
    /// Returns the response headers on success.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent, or an [`ApiError`] if the
    /// response status is anything other than `204 No Content`.
    pub async fn delete(&self, method: String) -> Result<(HeaderMap, ()), Box<dyn Error>> {
        let url = format!("{}/{}", self.base_url, method);
        let response = self
            .apply_request_headers(self.client.delete(&url), "1")
            .send()
            .await?;
        match response.status() {
            StatusCode::NO_CONTENT => Ok((response.headers().clone(), ())),
            status => Err(Box::new(ApiError {
                message: format!(
                    "DELETE operation using method '{}' failed with status code: {}",
                    method, status
                ),
            })),
        }
    }

    /// Builds a client for the configured execution environment.
    ///
    /// Unset `session_version` / `auto_login` config values fall back to
    /// `DEFAULT_SESSION_VERSION` / `DEFAULT_AUTO_LOGIN`. When auto-login is enabled
    /// the client logs in (see [`RestClient::login`]) before being returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the API key is not a valid header value, or if the
    /// automatic login fails.
    pub async fn new(config: ApiConfig) -> Result<Self, Box<dyn Error>> {
        let base_url = match config.execution_environment {
            ExecutionEnvironment::Demo => config.base_url_demo.clone(),
            ExecutionEnvironment::Live => config.base_url_live.clone(),
        };
        let session_version = config.session_version.unwrap_or(DEFAULT_SESSION_VERSION);
        let auto_login = config.auto_login.unwrap_or(DEFAULT_AUTO_LOGIN);

        let mut common_headers = HeaderMap::new();
        // Constant values are known to be valid header values, so no fallible parse.
        common_headers.insert(
            "Accept",
            HeaderValue::from_static("application/json; charset=UTF-8"),
        );
        common_headers.insert(
            "Content-Type",
            HeaderValue::from_static("application/json; charset=UTF-8"),
        );
        common_headers.insert("X-IG-API-KEY", HeaderValue::from_str(&config.api_key)?);

        let mut rest_client = Self {
            auth_headers: None,
            auto_login,
            base_url,
            client: reqwest::Client::new(),
            common_headers,
            config,
            session_version,
        };
        if auto_login {
            rest_client.login().await?;
        }
        Ok(rest_client)
    }

    /// Sends a `GET` request to `{base_url}/{method}`, with `params` encoded as the
    /// query string, using `api_version` (default 1) as the `Version` header.
    ///
    /// Returns the response headers and the JSON-decoded body.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the body is not valid JSON,
    /// or an [`ApiError`] if the response status is not `200 OK`.
    pub async fn get(
        &self,
        method: String,
        api_version: Option<usize>,
        params: Option<HashMap<String, String>>,
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let api_version = api_version.unwrap_or(1).to_string();
        let query_string = params_to_query_string(params);
        // Only append `?` when there is a query string to follow it.
        let url = if query_string.is_empty() {
            format!("{}/{}", self.base_url, method)
        } else {
            format!("{}/{}?{}", self.base_url, method, query_string)
        };
        let response = self
            .apply_request_headers(self.client.get(&url), &api_version)
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            status => Err(Box::new(ApiError {
                message: format!(
                    "GET operation using method '{}' and query_string '{}' failed with status code: {}",
                    method, query_string, status
                ),
            })),
        }
    }

    /// Logs in using the configured session version and stores the resulting
    /// auth headers on the client.
    ///
    /// Session versions 1 and 2 both use [`RestClient::login_v2`]; version 3 uses
    /// [`RestClient::login_v3`]. Returns the login response body.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] for any other session version, or whatever error
    /// the selected login method returns.
    pub async fn login(&mut self) -> Result<Value, Box<dyn Error>> {
        println!("Logging in with session version: {}", self.session_version);
        match self.session_version {
            1 | 2 => self.login_v2().await,
            3 => self.login_v3().await,
            other => Err(Box::new(ApiError {
                message: format!("Invalid session version: {}", other),
            })),
        }
    }

    /// Logs in with `POST {base_url}/session` (`Version: 2`) and stores the `cst`
    /// and `x-security-token` response headers as the client's auth headers.
    ///
    /// Returns the JSON-decoded response body.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] if the status is not `200 OK` or either header is
    /// missing, or an error if the request fails or the body is not valid JSON.
    pub async fn login_v2(&mut self) -> Result<Value, Box<dyn Error>> {
        let response = self.send_login_request("2").await?;
        let headers = response.headers();
        // Clone the header values as-is; both must be present for a usable session.
        let (cst, security_token) = match (headers.get("cst"), headers.get("x-security-token")) {
            (Some(cst), Some(security_token)) => (cst.clone(), security_token.clone()),
            _ => {
                return Err(Box::new(ApiError {
                    message:
                        "Any of the cst / x-security-token headers not found in login response."
                            .to_string(),
                }))
            }
        };
        let mut auth_headers = HeaderMap::new();
        auth_headers.insert("cst", cst);
        auth_headers.insert("x-security-token", security_token);
        self.auth_headers = Some(auth_headers);
        Ok(response.json().await?)
    }

    /// Logs in with `POST {base_url}/session` (`Version: 3`) and stores
    /// `Authorization: Bearer <access_token>` plus `IG-ACCOUNT-ID` (the account
    /// number for the configured environment) as the client's auth headers.
    ///
    /// Returns the JSON-decoded response body.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] if the status is not `200 OK`, or an error if the
    /// request fails, the body cannot be decoded as `LoginResponseV3`, or a header
    /// value is invalid.
    pub async fn login_v3(&mut self) -> Result<Value, Box<dyn Error>> {
        let response = self.send_login_request("3").await?;
        let response_body: Value = response.json().await?;
        let login_response: LoginResponseV3 = deserialize(&response_body)?;
        let account_number = match self.config.execution_environment {
            ExecutionEnvironment::Demo => &self.config.account_number_demo,
            ExecutionEnvironment::Live => &self.config.account_number_live,
        };
        let mut auth_headers = HeaderMap::new();
        auth_headers.insert(
            "Authorization",
            HeaderValue::from_str(&format!("Bearer {}", login_response.oauth_token.access_token))?,
        );
        auth_headers.insert("IG-ACCOUNT-ID", HeaderValue::from_str(account_number)?);
        self.auth_headers = Some(auth_headers);
        Ok(response_body)
    }

    /// Sends a `POST` request to `{base_url}/{method}` with `params` as the JSON
    /// body, using `version` (default 1) as the `Version` header.
    ///
    /// Returns the response headers and the JSON-decoded body.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the body is not valid JSON,
    /// or an [`ApiError`] if the response status is not `200 OK`.
    pub async fn post(
        &self,
        method: String,
        version: Option<usize>,
        params: Option<HashMap<String, String>>,
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let version = version.unwrap_or(1).to_string();
        // NOTE(review): assumes `params_to_json` yields a serializable JSON value
        // (e.g. `serde_json::Value`) — it is now actually sent as the request body.
        let body = params_to_json(params);
        let url = format!("{}/{}", self.base_url, method);
        let response = self
            .apply_request_headers(self.client.post(&url).json(&body), &version)
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            status => Err(Box::new(ApiError {
                message: format!(
                    "POST operation using method '{}', version '{}' and body '{:?}' failed with status code: {}",
                    method, version, body, status
                ),
            })),
        }
    }

    /// Sends a `PUT` request to `{base_url}/{method}` with `body` serialized as
    /// JSON, using `version` (default 1) as the `Version` header.
    ///
    /// Returns the response headers and the JSON-decoded body.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent or the body is not valid JSON,
    /// or an [`ApiError`] if the response status is not `200 OK`.
    pub async fn put(
        &self,
        method: String,
        body: impl Serialize,
        version: Option<usize>,
    ) -> Result<(HeaderMap, Value), Box<dyn Error>> {
        let version = version.unwrap_or(1).to_string();
        let url = format!("{}/{}", self.base_url, method);
        let response = self
            .apply_request_headers(self.client.put(&url).json(&body), &version)
            .send()
            .await?;
        match response.status() {
            StatusCode::OK => Ok((response.headers().clone(), response.json().await?)),
            status => Err(Box::new(ApiError {
                message: format!(
                    "PUT operation using method '{}', version '{}' and body '{:?}' failed with status code: {}",
                    method,
                    version,
                    serde_json::to_string(&body)?,
                    status
                ),
            })),
        }
    }

    /// Adds the headers shared by every authenticated request: the login auth
    /// headers (if any), then the common headers, then `Version: {version}`.
    fn apply_request_headers(
        &self,
        builder: reqwest::RequestBuilder,
        version: &str,
    ) -> reqwest::RequestBuilder {
        let builder = match &self.auth_headers {
            Some(auth_headers) => builder.headers(auth_headers.clone()),
            None => builder,
        };
        builder
            .headers(self.common_headers.clone())
            .header("Version", version)
    }

    /// Posts the configured credentials to `{base_url}/session` with the given
    /// `Version` header and returns the response if its status is `200 OK`.
    async fn send_login_request(
        &self,
        version: &str,
    ) -> Result<reqwest::Response, Box<dyn Error>> {
        let login_request_body = LoginRequest {
            identifier: self.config.username.clone(),
            password: self.config.password.clone(),
        };
        // Login requests carry only the common headers — no existing auth headers.
        let response = self
            .client
            .post(&format!("{}/session", self.base_url))
            .json(&login_request_body)
            .headers(self.common_headers.clone())
            .header("Version", version)
            .send()
            .await?;
        if response.status() != StatusCode::OK {
            return Err(Box::new(ApiError {
                message: format!("Login failed with status code: {}", response.status()),
            }));
        }
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::{ApiConfig, ExecutionEnvironment};

    /// Demo-environment config with auto-login disabled, so constructing a
    /// client does not attempt a login.
    fn demo_config_without_login() -> ApiConfig {
        ApiConfig {
            account_number_demo: "test_account_number_demo".to_string(),
            account_number_live: "test_account_number_live".to_string(),
            api_key: "test_api_key".to_string(),
            auto_login: Some(false),
            base_url_demo: "https://demo.example.com".to_string(),
            base_url_live: "https://live.example.com".to_string(),
            execution_environment: ExecutionEnvironment::Demo,
            password: "test_password".to_string(),
            session_version: Some(2),
            username: "test_username".to_string(),
        }
    }

    #[tokio::test]
    async fn new_rest_client_works() {
        let client = RestClient::new(demo_config_without_login()).await.unwrap();

        // State derived by `new` from the config.
        assert!(client.auth_headers.is_none());
        assert!(!client.auto_login);
        assert_eq!(client.base_url, "https://demo.example.com");
        assert_eq!(client.session_version, 2);
        assert_eq!(
            client.common_headers.get("X-IG-API-KEY").unwrap(),
            "test_api_key"
        );

        // The config itself is stored unchanged.
        let stored = &client.config;
        assert_eq!(stored.account_number_demo, "test_account_number_demo");
        assert_eq!(stored.account_number_live, "test_account_number_live");
        assert_eq!(stored.api_key, "test_api_key");
        assert_eq!(stored.auto_login, Some(false));
        assert_eq!(stored.execution_environment, ExecutionEnvironment::Demo);
        assert_eq!(stored.base_url_demo, "https://demo.example.com");
        assert_eq!(stored.base_url_live, "https://live.example.com");
        assert_eq!(stored.session_version, Some(2));
        assert_eq!(stored.password, "test_password");
        assert_eq!(stored.username, "test_username");
    }
}