use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
use crate::error::{Error, Result};
use crate::jar::NameKeyedJar;
/// Default base URL of the Gateway's REST API (`https://localhost:5000/v1/api`).
///
/// Used by [`ClientBuilder::default`].
pub const DEFAULT_BASE_URL: &str = "https://localhost:5000/v1/api";
/// Handle to the Gateway REST API.
///
/// Cheap to clone: all state lives behind a shared [`Arc`], so every clone uses the
/// same HTTP client, API client, and cookie jar. Construct one with [`Client::new`] or
/// [`Client::builder`].
#[derive(Clone)]
pub struct Client {
inner: Arc<ClientInner>,
}
impl std::fmt::Debug for Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Client")
.field("base_url", &self.inner.base_url.as_str())
.field("gateway_root", &self.inner.gateway_root.as_str())
.finish_non_exhaustive()
}
}
/// Shared state behind every [`Client`] clone.
struct ClientInner {
/// Typed REST API client, constructed on top of a clone of `http`.
api: bezant_api::IbRestApiClient,
/// Underlying HTTP client configured by [`ClientBuilder::build`] (cookies, TLS, timeouts).
http: reqwest::Client,
/// The base URL the client was built with, parsed.
base_url: url::Url,
/// `base_url` with trailing slashes and a trailing `/v1/api` removed, ending in `/`.
gateway_root: url::Url,
/// Cookie store installed into `http` as its cookie provider.
cookie_jar: Arc<NameKeyedJar>,
}
impl Client {
pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
ClientBuilder::new(base_url).build()
}
pub fn builder(base_url: impl AsRef<str>) -> ClientBuilder {
ClientBuilder::new(base_url)
}
#[must_use]
pub fn api(&self) -> &bezant_api::IbRestApiClient {
&self.inner.api
}
#[must_use]
pub fn http(&self) -> &reqwest::Client {
&self.inner.http
}
#[must_use]
pub fn base_url(&self) -> &url::Url {
&self.inner.base_url
}
#[must_use]
pub fn gateway_root_url(&self) -> &url::Url {
&self.inner.gateway_root
}
#[must_use]
pub fn cookie_jar(&self) -> Arc<NameKeyedJar> {
Arc::clone(&self.inner.cookie_jar)
}
}
/// Configures and builds a [`Client`]; obtain one via [`Client::builder`] or
/// [`ClientBuilder::new`].
#[must_use]
#[derive(Debug, Clone)]
pub struct ClientBuilder {
/// Base URL, kept as given; parsed and validated in `build`.
base_url: String,
/// Skip TLS certificate validation (default `true`).
accept_invalid_certs: bool,
/// Request timeout passed to reqwest (default 30 s).
timeout: Duration,
/// `User-Agent` header value (default `bezant/<crate version>`).
user_agent: String,
/// Follow redirects with reqwest's default policy (default `true`).
follow_redirects: bool,
/// Restrict the HTTP client to HTTP/1.x (default `true`).
http1_only: bool,
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new(DEFAULT_BASE_URL)
}
}
impl ClientBuilder {
pub fn new(base_url: impl AsRef<str>) -> Self {
Self {
base_url: base_url.as_ref().to_owned(),
accept_invalid_certs: true,
timeout: Duration::from_secs(30),
user_agent: format!("bezant/{}", env!("CARGO_PKG_VERSION")),
follow_redirects: true,
http1_only: true,
}
}
pub fn accept_invalid_certs(mut self, accept: bool) -> Self {
self.accept_invalid_certs = accept;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
self.user_agent = ua.into();
self
}
pub fn follow_redirects(mut self, follow: bool) -> Self {
self.follow_redirects = follow;
self
}
pub fn http1_only(mut self, http1_only: bool) -> Self {
self.http1_only = http1_only;
self
}
pub fn build(self) -> Result<Client> {
let redirect_policy = if self.follow_redirects {
reqwest::redirect::Policy::default()
} else {
reqwest::redirect::Policy::none()
};
let cookie_jar = Arc::new(NameKeyedJar::new());
let mut http_builder = reqwest::Client::builder()
.cookie_provider(Arc::clone(&cookie_jar))
.danger_accept_invalid_certs(self.accept_invalid_certs)
.timeout(self.timeout)
.connect_timeout(Duration::from_secs(5))
.pool_max_idle_per_host(32)
.pool_idle_timeout(Duration::from_secs(90))
.tcp_keepalive(Duration::from_secs(30))
.user_agent(&self.user_agent)
.redirect(redirect_policy);
if self.http1_only {
http_builder = http_builder.http1_only();
}
let http = http_builder.build().map_err(Error::Http)?;
if self.accept_invalid_certs {
warn!(
"bezant: accepting invalid TLS certs (Gateway default self-signed cert). \
Set ClientBuilder::accept_invalid_certs(false) once you install a trusted cert."
);
}
let api = bezant_api::IbRestApiClient::with_client(&self.base_url, http.clone())
.map_err(|e| Error::InvalidBaseUrl(e.to_string()))?;
let base_url: url::Url = self
.base_url
.parse()
.map_err(|e: url::ParseError| Error::InvalidBaseUrl(e.to_string()))?;
let gateway_root = derive_gateway_root(&base_url);
Ok(Client {
inner: Arc::new(ClientInner {
api,
http,
base_url,
gateway_root,
cookie_jar,
}),
})
}
}
/// Derives the Gateway root URL from the REST base URL.
///
/// The steps are:
///
/// * ignore trailing slashes,
/// * remove a single trailing `/v1/api` segment,
/// * make sure the path ends with `/`,
/// * drop any query string and fragment.
///
/// Scheme, host, port and any path prefix are preserved. For example,
/// `https://localhost:5000/v1/api` becomes `https://localhost:5000/`, and
/// `https://gw.example.com/ibkr/v1/api` becomes `https://gw.example.com/ibkr/`.
fn derive_gateway_root(base_url: &url::Url) -> url::Url {
    let path = base_url.path().trim_end_matches('/');
    // A base URL without the `/v1/api` suffix is taken to already point at the root.
    let prefix = path.strip_suffix("/v1/api").unwrap_or(path);
    let root_path = if prefix.ends_with('/') {
        prefix.to_owned()
    } else {
        format!("{}/", prefix)
    };

    let mut root = base_url.clone();
    root.set_path(&root_path);
    root.set_query(None);
    root.set_fragment(None);
    root
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `base` and returns its derived gateway root as a string.
    fn root_of(base: &str) -> String {
        let base: url::Url = base.parse().expect("test base URL must parse");
        derive_gateway_root(&base).as_str().to_owned()
    }

    #[test]
    fn gateway_root_strips_v1_api() {
        assert_eq!(
            root_of("https://localhost:5000/v1/api"),
            "https://localhost:5000/"
        );
    }

    #[test]
    fn gateway_root_strips_trailing_slash() {
        assert_eq!(
            root_of("https://localhost:5000/v1/api/"),
            "https://localhost:5000/"
        );
    }

    #[test]
    fn gateway_root_preserves_custom_prefix() {
        assert_eq!(
            root_of("https://gw.example.com/ibkr/v1/api"),
            "https://gw.example.com/ibkr/"
        );
    }

    #[test]
    fn gateway_root_of_bare_origin_is_slash() {
        assert_eq!(root_of("https://localhost:5000"), "https://localhost:5000/");
    }

    #[test]
    fn gateway_root_drops_query_and_fragment() {
        assert_eq!(
            root_of("https://localhost:5000/v1/api?x=1#frag"),
            "https://localhost:5000/"
        );
    }

    #[test]
    fn gateway_root_keeps_path_without_v1_api() {
        assert_eq!(
            root_of("https://gw.example.com/ibkr"),
            "https://gw.example.com/ibkr/"
        );
    }

    #[test]
    fn gateway_root_strips_repeated_trailing_slashes() {
        assert_eq!(
            root_of("https://gw.example.com/ibkr/v1/api///"),
            "https://gw.example.com/ibkr/"
        );
    }

    #[test]
    fn gateway_root_only_strips_exact_suffix() {
        assert_eq!(
            root_of("https://localhost:5000/v1/apix"),
            "https://localhost:5000/v1/apix/"
        );
    }
}