use std::{borrow::Cow, collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use nautilus_core::collections::into_ustr_vec;
use nautilus_cryptography::providers::install_cryptographic_provider;
use reqwest::{
Method, Response, Url,
header::{HeaderMap, HeaderName, HeaderValue},
};
use ustr::Ustr;
use super::{HttpClientError, HttpResponse, HttpStatus};
use crate::ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota};
const DEFAULT_POOL_MAX_IDLE_PER_HOST: usize = 32;
const DEFAULT_POOL_IDLE_TIMEOUT_SECS: u64 = 60;
const DEFAULT_HTTP2_KEEP_ALIVE_SECS: u64 = 30;
/// Rate-limited HTTP client.
///
/// Each request first awaits the keyed rate limiter and is then dispatched through the
/// inner client. The rate limiter is held behind an `Arc`, so clones of this client share
/// the same rate-limiting state.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct HttpClient {
    /// Wrapped reqwest client that builds and sends requests and converts responses.
    pub(crate) client: InnerHttpClient,
    /// Keyed rate limiter awaited before every request is sent.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
}
impl HttpClient {
pub fn new(
headers: HashMap<String, String>,
header_keys: Vec<String>,
keyed_quotas: Vec<(String, Quota)>,
default_quota: Option<Quota>,
timeout_secs: Option<u64>,
proxy_url: Option<String>,
) -> Result<Self, HttpClientError> {
install_cryptographic_provider();
let mut header_map = HeaderMap::new();
for (key, value) in headers {
let header_name = HeaderName::from_str(&key)
.map_err(|e| HttpClientError::Error(format!("Invalid header name '{key}': {e}")))?;
let header_value = HeaderValue::from_str(&value).map_err(|e| {
HttpClientError::Error(format!("Invalid header value '{value}': {e}"))
})?;
header_map.insert(header_name, header_value);
}
let mut client_builder = reqwest::Client::builder()
.default_headers(header_map)
.tcp_nodelay(true)
.pool_max_idle_per_host(DEFAULT_POOL_MAX_IDLE_PER_HOST)
.pool_idle_timeout(Duration::from_secs(DEFAULT_POOL_IDLE_TIMEOUT_SECS))
.http2_keep_alive_interval(Duration::from_secs(DEFAULT_HTTP2_KEEP_ALIVE_SECS))
.http2_keep_alive_while_idle(true)
.http2_adaptive_window(true);
if let Some(timeout_secs) = timeout_secs {
client_builder = client_builder.timeout(Duration::from_secs(timeout_secs));
}
if let Some(proxy_url) = proxy_url {
let proxy = reqwest::Proxy::all(&proxy_url)
.map_err(|e| HttpClientError::InvalidProxy(format!("{proxy_url}: {e}")))?;
client_builder = client_builder.proxy(proxy);
}
let client = client_builder
.build()
.map_err(|e| HttpClientError::ClientBuildError(e.to_string()))?;
let (valid_keys, header_names): (Vec<String>, Vec<HeaderName>) = header_keys
.into_iter()
.filter_map(|k| HeaderName::from_str(&k).ok().map(|name| (k, name)))
.unzip();
let client = InnerHttpClient {
client,
header_keys: Arc::new(valid_keys),
header_names: Arc::new(header_names),
};
let keyed_quotas = keyed_quotas
.into_iter()
.map(|(key, quota)| (Ustr::from(&key), quota))
.collect();
let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
Ok(Self {
client,
rate_limiter,
})
}
#[allow(clippy::too_many_arguments)]
pub async fn request(
&self,
method: Method,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
let keys = keys.map(into_ustr_vec);
self.request_with_ustr_keys(method, url, params, headers, body, timeout_secs, keys)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn request_with_params<P: serde::Serialize>(
&self,
method: Method,
url: String,
params: Option<&P>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
let keys = keys.map(into_ustr_vec);
let rate_limiter = self.rate_limiter.clone();
rate_limiter.await_keys_ready(keys.as_deref()).await;
self.client
.send_request_with_query(method, url, params, headers, body, timeout_secs)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn request_with_ustr_keys(
&self,
method: Method,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
keys: Option<Vec<Ustr>>,
) -> Result<HttpResponse, HttpClientError> {
let rate_limiter = self.rate_limiter.clone();
rate_limiter.await_keys_ready(keys.as_deref()).await;
self.client
.send_request(method, url, params, headers, body, timeout_secs)
.await
}
pub async fn get(
&self,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
self.request(Method::GET, url, params, headers, None, timeout_secs, keys)
.await
}
pub async fn post(
&self,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
self.request(Method::POST, url, params, headers, body, timeout_secs, keys)
.await
}
pub async fn patch(
&self,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
self.request(
Method::PATCH,
url,
params,
headers,
body,
timeout_secs,
keys,
)
.await
}
pub async fn delete(
&self,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
timeout_secs: Option<u64>,
keys: Option<Vec<String>>,
) -> Result<HttpResponse, HttpClientError> {
self.request(
Method::DELETE,
url,
params,
headers,
None,
timeout_secs,
keys,
)
.await
}
}
/// Wrapper around a [`reqwest::Client`] that sends requests and converts the responses
/// into `HttpResponse` values.
#[derive(Clone, Debug)]
pub struct InnerHttpClient {
    /// Underlying reqwest client that executes the requests.
    pub(crate) client: reqwest::Client,
    /// Response header names, spelled as the caller supplied them. These are used as the
    /// keys of `HttpResponse::headers` and are index-aligned with `header_names`.
    pub(crate) header_keys: Arc<Vec<String>>,
    /// Parsed forms of `header_keys`, used to look headers up in a response.
    pub(crate) header_names: Arc<Vec<HeaderName>>,
}
impl InnerHttpClient {
pub async fn send_request(
&self,
method: Method,
url: String,
params: Option<&HashMap<String, Vec<String>>>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
) -> Result<HttpResponse, HttpClientError> {
let full_url = encode_url_params(&url, params)?;
self.send_request_internal(
method,
full_url.as_ref(),
None::<&()>,
headers,
body,
timeout_secs,
)
.await
}
pub async fn send_request_with_query<Q: serde::Serialize>(
&self,
method: Method,
url: String,
query: Option<&Q>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
) -> Result<HttpResponse, HttpClientError> {
self.send_request_internal(method, &url, query, headers, body, timeout_secs)
.await
}
async fn send_request_internal<Q: serde::Serialize>(
&self,
method: Method,
url: &str,
query: Option<&Q>,
headers: Option<HashMap<String, String>>,
body: Option<Vec<u8>>,
timeout_secs: Option<u64>,
) -> Result<HttpResponse, HttpClientError> {
let reqwest_url =
Url::parse(url).map_err(|e| HttpClientError::from(format!("URL parse error: {e}")))?;
let mut request_builder = self.client.request(method, reqwest_url);
if let Some(headers) = headers {
let mut header_map = HeaderMap::with_capacity(headers.len());
for (header_key, header_value) in &headers {
let key = HeaderName::from_bytes(header_key.as_bytes())
.map_err(|e| HttpClientError::from(format!("Invalid header name: {e}")))?;
if let Some(old_value) = header_map.insert(
key.clone(),
header_value
.parse()
.map_err(|e| HttpClientError::from(format!("Invalid header value: {e}")))?,
) {
log::trace!("Replaced header '{key}': old={old_value:?}, new={header_value}");
}
}
request_builder = request_builder.headers(header_map);
}
if let Some(q) = query {
request_builder = request_builder.query(q);
}
if let Some(timeout_secs) = timeout_secs {
request_builder = request_builder.timeout(Duration::new(timeout_secs, 0));
}
let request = match body {
Some(b) => request_builder
.body(b)
.build()
.map_err(HttpClientError::from)?,
None => request_builder.build().map_err(HttpClientError::from)?,
};
log::trace!("{} {}", request.method(), request.url());
let response = self
.client
.execute(request)
.await
.map_err(HttpClientError::from)?;
self.to_response(response).await
}
pub async fn to_response(&self, response: Response) -> Result<HttpResponse, HttpClientError> {
log::trace!("{response:?}");
let resp_headers = response.headers();
let mut headers =
HashMap::with_capacity(std::cmp::min(self.header_names.len(), resp_headers.len()));
for (name, key_str) in self.header_names.iter().zip(self.header_keys.iter()) {
if let Some(val) = resp_headers.get(name)
&& let Ok(v) = val.to_str()
{
headers.insert(key_str.clone(), v.to_owned());
}
}
let status = HttpStatus::new(response.status());
let body = response.bytes().await.map_err(HttpClientError::from)?;
Ok(HttpResponse {
status,
headers,
body,
})
}
}
impl Default for InnerHttpClient {
    /// Creates a client with reqwest's default settings and no captured response headers.
    ///
    /// The cryptographic provider is installed before the reqwest client is constructed.
    fn default() -> Self {
        install_cryptographic_provider();

        Self {
            client: reqwest::Client::new(),
            header_keys: Arc::new(Vec::new()),
            header_names: Arc::new(Vec::new()),
        }
    }
}
/// Appends `params` to `url` as a URL-encoded query string.
///
/// Each value of a key is emitted as its own `key=value` pair. Pairs are ordered by key,
/// because `HashMap` iteration order is unspecified, so the same params always produce the
/// same URL. Values that share a key keep their given order.
///
/// The query is placed before any `#fragment`, since fragments are never sent to the server.
/// It is joined with `&` when `url` already carries a query. No extra separator is added
/// when `url` already ends in `?` or `&`.
///
/// Returns `url` unchanged (borrowed) when `params` is `None` or contains no values.
///
/// # Errors
/// Returns [`HttpClientError::Error`] if the pairs cannot be URL-encoded.
fn encode_url_params<'a>(
    url: &'a str,
    params: Option<&HashMap<String, Vec<String>>>,
) -> Result<Cow<'a, str>, HttpClientError> {
    let Some(params) = params else {
        return Ok(Cow::Borrowed(url));
    };

    let mut pairs: Vec<(&str, &str)> = params
        .iter()
        .flat_map(|(key, values)| {
            values
                .iter()
                .map(move |value| (key.as_str(), value.as_str()))
        })
        .collect();

    if pairs.is_empty() {
        return Ok(Cow::Borrowed(url));
    }

    // Stable sort: values sharing a key keep their original relative order.
    pairs.sort_by_key(|&(key, _)| key);

    let query_string = serde_urlencoded::to_string(pairs)
        .map_err(|e| HttpClientError::Error(format!("Failed to encode params: {e}")))?;

    // Split off the fragment (if any) so the query lands before it.
    let (base, fragment) = url.find('#').map_or((url, ""), |idx| url.split_at(idx));
    let separator = match base.find('?') {
        None => "?",
        Some(_) if base.ends_with(['?', '&']) => "",
        Some(_) => "&",
    };

    Ok(Cow::Owned(format!("{base}{separator}{query_string}{fragment}")))
}
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use std::net::SocketAddr;

    use axum::{
        Router,
        routing::{delete, get, patch, post},
        serve,
    };
    use http::status::StatusCode;
    use rstest::rstest;

    use super::*;

    /// Router exposing one trivial handler per HTTP verb, plus 404 and slow endpoints.
    fn create_router() -> Router {
        Router::new()
            .route("/get", get(|| async { "hello-world!" }))
            .route("/post", post(|| async { StatusCode::OK }))
            .route("/patch", patch(|| async { StatusCode::OK }))
            .route("/delete", delete(|| async { StatusCode::OK }))
            .route("/notfound", get(|| async { StatusCode::NOT_FOUND }))
            .route(
                "/slow",
                get(|| async {
                    tokio::time::sleep(Duration::from_secs(2)).await;
                    "Eventually responded"
                }),
            )
    }

    /// Binds an ephemeral local port, serves [`create_router`] on it in the background,
    /// and returns the bound address.
    async fn start_test_server() -> Result<SocketAddr, Box<dyn std::error::Error + Send + Sync>> {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            serve(listener, create_router()).await.unwrap();
        });

        Ok(addr)
    }

    #[tokio::test]
    async fn test_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::GET,
                format!("{url}/get"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_post_with_body() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();

        let mut body = HashMap::new();
        body.insert(
            "key1".to_string(),
            serde_json::Value::String("value1".to_string()),
        );
        body.insert(
            "key2".to_string(),
            serde_json::Value::String("value2".to_string()),
        );

        let body_string = serde_json::to_string(&body).unwrap();
        let body_bytes = body_string.into_bytes();

        let response = client
            .send_request(
                reqwest::Method::POST,
                format!("{url}/post"),
                None,
                None,
                Some(body_bytes),
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::PATCH,
                format!("{url}/patch"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}");

        let client = InnerHttpClient::default();
        let response = client
            .send_request(
                reqwest::Method::DELETE,
                format!("{url}/delete"),
                None,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_not_found() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/notfound");
        let client = InnerHttpClient::default();

        let response = client
            .send_request(reqwest::Method::GET, url, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_client_error());
        assert_eq!(response.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn test_timeout() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/slow");
        let client = InnerHttpClient::default();

        // The /slow handler sleeps 2s, so a 1s per-request timeout must trip.
        let result = client
            .send_request(reqwest::Method::GET, url, None, None, None, Some(1))
            .await;

        match result {
            Err(HttpClientError::TimeoutError(msg)) => {
                println!("Got expected timeout error: {msg}");
            }
            Err(e) => panic!("Expected a timeout error, was: {e:?}"),
            Ok(resp) => panic!("Expected a timeout error, but was a successful response: {resp:?}"),
        }
    }

    #[rstest]
    fn test_encode_url_params_none_returns_url_unchanged() {
        let url = "http://localhost/path";
        let encoded = encode_url_params(url, None).unwrap();
        assert!(matches!(encoded, Cow::Borrowed(_)));
        assert_eq!(encoded, url);
    }

    #[rstest]
    fn test_encode_url_params_empty_values_returns_url_unchanged() {
        let url = "http://localhost/path";
        let params = HashMap::from([("key".to_string(), Vec::<String>::new())]);
        let encoded = encode_url_params(url, Some(&params)).unwrap();
        assert!(matches!(encoded, Cow::Borrowed(_)));
        assert_eq!(encoded, url);
    }

    #[rstest]
    #[case("http://localhost/path", "http://localhost/path?key=a+b%26c&key=2")]
    #[case(
        "http://localhost/path?x=1",
        "http://localhost/path?x=1&key=a+b%26c&key=2"
    )]
    fn test_encode_url_params_appends_encoded_query(#[case] url: &str, #[case] expected: &str) {
        let params = HashMap::from([(
            "key".to_string(),
            vec!["a b&c".to_string(), "2".to_string()],
        )]);
        let encoded = encode_url_params(url, Some(&params)).unwrap();
        assert_eq!(encoded, expected);
    }

    #[rstest]
    fn test_http_client_without_proxy() {
        let result = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None);

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_valid_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("http://proxy.example.com:8080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_socks5_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("socks5://127.0.0.1:1080".to_string()),
        );

        assert!(result.is_ok());
    }

    #[rstest]
    fn test_http_client_with_malformed_proxy() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some("://invalid".to_string()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[rstest]
    fn test_http_client_with_empty_proxy_string() {
        let result = HttpClient::new(
            HashMap::new(),
            vec![],
            vec![],
            None,
            None,
            Some(String::new()),
        );

        assert!(result.is_err());
        assert!(matches!(result, Err(HttpClientError::InvalidProxy(_))));
    }

    #[tokio::test]
    async fn test_http_client_get() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/get");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.get(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
        assert_eq!(String::from_utf8_lossy(&response.body), "hello-world!");
    }

    #[tokio::test]
    async fn test_http_client_post() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/post");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .post(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_patch() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/patch");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client
            .patch(url, None, None, None, None, None)
            .await
            .unwrap();

        assert!(response.status.is_success());
    }

    #[tokio::test]
    async fn test_http_client_delete() {
        let addr = start_test_server().await.unwrap();
        let url = format!("http://{addr}/delete");

        let client = HttpClient::new(HashMap::new(), vec![], vec![], None, None, None).unwrap();
        let response = client.delete(url, None, None, None, None).await.unwrap();

        assert!(response.status.is_success());
    }
}