use async_trait::async_trait;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::core::types::{
AccountType, ExchangeError, ExchangeId, ExchangeResult, Kline, OrderBook, Price,
Ticker, SymbolInput,
};
use crate::core::traits::{ExchangeIdentity, MarketData};
use super::auth::IBAuth;
use super::endpoints::{IBEndpoint, IBEndpoints};
use super::parser::{IBAccountSummary, IBParser, IBPosition};
/// Connector for Interactive Brokers that talks to a Gateway REST API
/// (see `from_gateway` / `paper`).
pub struct IBConnector {
    /// HTTP client used for every request. The async constructors build it with a 30s timeout
    /// and certificate validation disabled.
    client: Client,
    /// Holds the account id substituted into the portfolio endpoints.
    auth: IBAuth,
    /// Endpoint configuration; `rest_base` is prefixed to every endpoint path.
    endpoints: IBEndpoints,
    /// Reported by `is_testnet()`; `paper()` sets it to `true`, `from_gateway()` to `false`.
    testnet: bool,
    /// Symbol string -> contract id (conid) resolved via contract search, filled lazily by
    /// `resolve_symbol` and never evicted.
    symbol_cache: Arc<RwLock<HashMap<String, i64>>>,
}
impl IBConnector {
pub async fn from_gateway(
base_url: impl Into<String>,
account_id: impl Into<String>,
) -> ExchangeResult<Self> {
let base_url = base_url.into();
let auth = IBAuth::new(account_id);
let client = Client::builder()
.danger_accept_invalid_certs(true) .timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| ExchangeError::Network(format!("Failed to create HTTP client: {}", e)))?;
let connector = Self {
client,
auth,
endpoints: IBEndpoints::custom(base_url, None::<String>),
testnet: false,
symbol_cache: Arc::new(RwLock::new(HashMap::new())),
};
connector.check_auth().await?;
Ok(connector)
}
pub async fn paper(
account_id: impl Into<String>,
base_url: Option<impl Into<String>>,
) -> ExchangeResult<Self> {
let url = base_url
.map(|u| u.into())
.unwrap_or_else(|| "https://localhost:4004/v1/api".to_string());
let auth = IBAuth::new(account_id);
let client = Client::builder()
.danger_accept_invalid_certs(true) .timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| ExchangeError::Network(format!("Failed to create HTTP client: {}", e)))?;
let connector = Self {
client,
auth,
endpoints: IBEndpoints::custom(url, None::<String>),
testnet: true,
symbol_cache: Arc::new(RwLock::new(HashMap::new())),
};
connector.check_auth().await?;
Ok(connector)
}
pub fn with_testnet(mut self, testnet: bool) -> Self {
self.testnet = testnet;
self
}
pub async fn from_oauth(_account_id: impl Into<String>) -> ExchangeResult<Self> {
Err(ExchangeError::UnsupportedOperation(
"OAuth 2.0 authentication not yet implemented. Use from_gateway() instead.".to_string(),
))
}
async fn check_auth(&self) -> ExchangeResult<()> {
let url = format!("{}{}", self.endpoints.rest_base, IBEndpoint::AuthStatus.path());
let response = self
.client
.get(&url)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Auth check failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Auth(format!(
"Authentication check failed: HTTP {}",
response.status()
)));
}
let auth_status: serde_json::Value = response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse auth status: {}", e)))?;
let authenticated = auth_status
.get("authenticated")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if !authenticated {
return Err(ExchangeError::Auth(
"Not authenticated. Please login via browser to Gateway.".to_string(),
));
}
Ok(())
}
async fn resolve_symbol(&self, symbol: &str) -> ExchangeResult<i64> {
let symbol_key = symbol.to_string();
{
let cache = self.symbol_cache.read().await;
if let Some(&conid) = cache.get(&symbol_key) {
return Ok(conid);
}
}
let contracts = self.search_contract(symbol, "STK").await?;
if contracts.is_empty() {
return Err(ExchangeError::NotFound(format!(
"Symbol {} not found",
symbol_key
)));
}
let conid = contracts[0].0;
{
let mut cache = self.symbol_cache.write().await;
cache.insert(symbol_key, conid);
}
Ok(conid)
}
async fn search_contract(
&self,
symbol: &str,
sec_type: &str,
) -> ExchangeResult<Vec<(i64, String, String)>> {
let url = format!("{}{}", self.endpoints.rest_base, IBEndpoint::ContractSearch.path());
let body = serde_json::json!({
"symbol": symbol,
"name": false,
"secType": sec_type
});
let response = self
.client
.post(&url)
.json(&body)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Contract search failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Http(format!(
"Contract search failed: HTTP {}",
response.status()
)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse search results: {}", e)))?;
IBParser::parse_contract_search(&json)
}
async fn get_market_data_snapshot(
&self,
conid: i64,
fields: &[&str],
) -> ExchangeResult<serde_json::Value> {
let url = format!(
"{}{}",
self.endpoints.rest_base,
IBEndpoint::MarketDataSnapshot.path()
);
let mut params = HashMap::new();
params.insert("conids".to_string(), conid.to_string());
params.insert("fields".to_string(), fields.join(","));
let response = self
.client
.get(&url)
.query(¶ms)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Market data request failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Http(format!(
"Market data request failed: HTTP {}",
response.status()
)));
}
response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse market data: {}", e)))
}
async fn get_historical_data(
&self,
conid: i64,
period: &str,
bar_size: &str,
) -> ExchangeResult<serde_json::Value> {
let url = format!(
"{}{}",
self.endpoints.rest_base,
IBEndpoint::MarketDataHistory.path()
);
let mut params = HashMap::new();
params.insert("conid".to_string(), conid.to_string());
params.insert("period".to_string(), period.to_string());
params.insert("bar".to_string(), bar_size.to_string());
let response = self
.client
.get(&url)
.query(¶ms)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Historical data request failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Http(format!(
"Historical data request failed: HTTP {}",
response.status()
)));
}
response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse historical data: {}", e)))
}
pub async fn get_positions(&self) -> ExchangeResult<Vec<IBPosition>> {
let url = format!(
"{}{}",
self.endpoints.rest_base,
IBEndpoint::PortfolioPositions {
account_id: self.auth.account_id().to_string(),
page: 0
}
.path()
);
let response = self
.client
.get(&url)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Positions request failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Http(format!(
"Positions request failed: HTTP {}",
response.status()
)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse positions: {}", e)))?;
IBParser::parse_positions(&json)
}
pub async fn get_account_summary(&self) -> ExchangeResult<IBAccountSummary> {
let url = format!(
"{}{}",
self.endpoints.rest_base,
IBEndpoint::PortfolioSummary {
account_id: self.auth.account_id().to_string()
}
.path()
);
let response = self
.client
.get(&url)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Account summary request failed: {}", e)))?;
if !response.status().is_success() {
return Err(ExchangeError::Http(format!(
"Account summary request failed: HTTP {}",
response.status()
)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| ExchangeError::Parse(format!("Failed to parse account summary: {}", e)))?;
IBParser::parse_account_summary(&json)
}
}
impl ExchangeIdentity for IBConnector {
    /// Display name of the venue.
    fn exchange_name(&self) -> &'static str {
        "Interactive Brokers"
    }

    /// Identifier of this venue in the shared `ExchangeId` enum.
    fn exchange_id(&self) -> ExchangeId {
        ExchangeId::Ib
    }

    /// Reflects the `testnet` flag chosen at construction (or via `with_testnet`).
    fn is_testnet(&self) -> bool {
        let Self { testnet, .. } = self;
        *testnet
    }

    /// Only the spot account type is exposed by this connector.
    fn supported_account_types(&self) -> Vec<AccountType> {
        Vec::from([AccountType::Spot])
    }
}
#[async_trait]
impl MarketData for IBConnector {
async fn get_price(&self, symbol: SymbolInput<'_>, _account_type: AccountType) -> ExchangeResult<Price> {
let sym_str: String = match symbol { SymbolInput::Raw(s) => s.to_string(), SymbolInput::Canonical(c) => c.to_concat() };
let conid = self.resolve_symbol(&sym_str).await?;
let snapshot = self.get_market_data_snapshot(conid, &["31"]).await?;
IBParser::parse_price(&snapshot)
}
async fn get_orderbook(
&self,
_symbol: SymbolInput<'_>,
_depth: Option<u16>,
_account_type: AccountType,
) -> ExchangeResult<OrderBook> {
Err(ExchangeError::UnsupportedOperation(
"IB does not provide orderbook via REST API. Use TWS API for Level 2 data.".to_string(),
))
}
async fn get_klines(
&self,
symbol: SymbolInput<'_>,
interval: &str,
limit: Option<u16>,
_account_type: AccountType,
_end_time: Option<i64>,
) -> ExchangeResult<Vec<Kline>> {
let sym_str: String = match symbol { SymbolInput::Raw(s) => s.to_string(), SymbolInput::Canonical(c) => c.to_concat() };
let conid = self.resolve_symbol(&sym_str).await?;
let (period, bar_size) = self.map_interval(interval, limit)?;
let historical = self.get_historical_data(conid, &period, &bar_size).await?;
IBParser::parse_klines(&historical)
}
async fn get_ticker(
&self,
symbol: SymbolInput<'_>,
_account_type: AccountType,
) -> ExchangeResult<Ticker> {
let sym_str: String = match symbol { SymbolInput::Raw(s) => s.to_string(), SymbolInput::Canonical(c) => c.to_concat() };
let conid = self.resolve_symbol(&sym_str).await?;
let fields = &["31", "84", "86", "70", "71", "87", "7219"];
let snapshot = self.get_market_data_snapshot(conid, fields).await?;
IBParser::parse_ticker(&snapshot, &sym_str)
}
async fn ping(&self) -> ExchangeResult<()> {
let url = format!("{}{}", self.endpoints.rest_base, IBEndpoint::Tickle.path());
let response = self
.client
.get(&url)
.send()
.await
.map_err(|e| ExchangeError::Network(format!("Ping failed: {}", e)))?;
if response.status().is_success() {
Ok(())
} else {
Err(ExchangeError::Network(format!(
"Ping failed: HTTP {}",
response.status()
)))
}
}
}
impl IBConnector {
    /// Maps a generic kline interval (`"1m"`, `"5m"`, …, `"1w"`) to IB's `(period, bar)` pair
    /// for the history endpoint.
    ///
    /// `period` is sized to span `limit` bars (default 100) of calendar time. It is rounded
    /// **up** to whole days (< 1 week), weeks (< 30 days) or months (capped at 12), so the
    /// window is never shorter than `limit` bars of wall-clock time. Spans under a day
    /// request `"1d"`.
    ///
    /// NOTE(review): sizing is calendar-based, so for intraday and daily bars non-trading
    /// hours and days still mean fewer than `limit` bars may come back.
    ///
    /// # Errors
    /// Returns `ExchangeError::InvalidRequest` for an interval not listed below.
    fn map_interval(&self, interval: &str, limit: Option<u16>) -> ExchangeResult<(String, String)> {
        const MINS_PER_DAY: u64 = 1_440;
        const MINS_PER_WEEK: u64 = 10_080;
        // 30-day month approximation.
        const MINS_PER_MONTH: u64 = 43_200;

        // Ceiling division; floor would size the window shorter than requested.
        let ceil_div = |n: u64, d: u64| n / d + u64::from(n % d != 0);

        let (bar_size, bar_duration_mins): (&str, u64) = match interval {
            "1m" => ("1min", 1),
            "5m" => ("5min", 5),
            "15m" => ("15min", 15),
            "30m" => ("30min", 30),
            "1h" => ("1h", 60),
            "2h" => ("2h", 120),
            "4h" => ("4h", 240),
            "1d" => ("1d", MINS_PER_DAY),
            "1w" => ("1w", MINS_PER_WEEK),
            _ => {
                return Err(ExchangeError::InvalidRequest(format!(
                    "Unsupported interval: {}",
                    interval
                )))
            }
        };

        // u16 * u64 cannot overflow: 65_535 * 10_080 is far below u64::MAX.
        let total_mins = u64::from(limit.unwrap_or(100)) * bar_duration_mins;
        let period = if total_mins < MINS_PER_DAY {
            "1d".to_string()
        } else if total_mins < MINS_PER_WEEK {
            format!("{}d", ceil_div(total_mins, MINS_PER_DAY))
        } else if total_mins < MINS_PER_MONTH {
            format!("{}w", ceil_div(total_mins, MINS_PER_WEEK))
        } else {
            format!("{}m", ceil_div(total_mins, MINS_PER_MONTH).clamp(1, 12))
        };
        Ok((period, bar_size.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a connector directly from its fields, bypassing the async constructors
    /// (which perform a network auth check).
    fn create_test_connector() -> IBConnector {
        let symbol_cache = Arc::new(RwLock::new(HashMap::new()));
        IBConnector {
            client: Client::new(),
            auth: IBAuth::new("TEST"),
            endpoints: IBEndpoints::default(),
            testnet: false,
            symbol_cache,
        }
    }

    #[test]
    fn test_map_interval() {
        let connector = create_test_connector();
        // (generic interval, limit, expected IB bar size)
        let cases = [("1m", 100, "1min"), ("1h", 24, "1h"), ("1d", 30, "1d")];
        for &(interval, limit, expected_bar) in &cases {
            let (_period, bar) = connector.map_interval(interval, Some(limit)).unwrap();
            assert_eq!(bar, expected_bar);
        }
    }
}