mod config;
mod http;
pub(crate) mod rate_limiter;
pub(crate) mod rest;
pub mod stream;
pub use config::{ClientBuilder, Credentials};
pub use http::HttpTransport;
pub use rest::{LiquidateBuilder, OrderBuilder, OrderUpdate, RiskBuilder, SymbolSearchParams};
use std::sync::atomic::{AtomicBool, Ordering};
use bytes::Bytes;
use hyper::Method;
use hyper::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use serde::de::DeserializeOwned;
use crate::error::{Error, Result, parse_api_error};
use http::HttpClient;
use rate_limiter::RateLimiter;
/// Low-level request plumbing shared by [`Client`] and the background logout task.
///
/// Holds everything needed to issue a request: the transport, the URL prefix,
/// the headers attached to every call, and an optional rate limiter.
#[derive(Clone)]
pub(crate) struct RequestHelper<H: HttpTransport> {
    /// Transport that actually performs the HTTP exchange.
    pub(crate) http: H,
    /// Prefix that request paths are appended to verbatim.
    pub(crate) base_url: String,
    /// Headers sent with every request; `Client`'s `Debug` impl redacts them.
    pub(crate) auth_headers: HeaderMap,
    /// When present, awaited before each request is sent.
    pub(crate) rate_limiter: Option<RateLimiter>,
}
impl<H: HttpTransport> RequestHelper<H> {
async fn send<T: DeserializeOwned>(
&self,
method: Method,
path: &str,
body: Option<Bytes>,
) -> Result<T> {
if let Some(limiter) = &self.rate_limiter {
limiter.wait().await;
}
let uri = format!("{}{path}", self.base_url).parse()?;
let (status, resp_body) = if body.is_some() {
let mut headers = self.auth_headers.clone();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
self.http.send(method, uri, body, &headers).await?
} else {
self.http
.send(method, uri, body, &self.auth_headers)
.await?
};
if !status.is_success() {
return Err(Error::Api {
status: status.as_u16(),
message: parse_api_error(&resp_body),
});
}
Ok(serde_json::from_slice(&resp_body)?)
}
pub(crate) async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
self.send(Method::GET, path, None).await
}
pub(crate) async fn post<B: serde::Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: &B,
) -> Result<T> {
let bytes = Bytes::from(serde_json::to_vec(body)?);
self.send(Method::POST, path, Some(bytes)).await
}
pub(crate) async fn put<B: serde::Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: &B,
) -> Result<T> {
let bytes = Bytes::from(serde_json::to_vec(body)?);
self.send(Method::PUT, path, Some(bytes)).await
}
pub(crate) async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
self.send(Method::DELETE, path, None).await
}
pub(crate) async fn delete_with_body<B: serde::Serialize, T: DeserializeOwned>(
&self,
path: &str,
body: &B,
) -> Result<T> {
let bytes = Bytes::from(serde_json::to_vec(body)?);
self.send(Method::DELETE, path, Some(bytes)).await
}
}
/// API client.
///
/// `H` is the HTTP transport and defaults to [`HttpClient`]. Dropping a client
/// whose `is_logged_out` flag is still `false` spawns a best-effort logout on
/// the current Tokio runtime, if one is available.
pub struct Client<H: HttpTransport = HttpClient> {
    /// Request plumbing; cloned into the logout task spawned on drop.
    pub(crate) request: RequestHelper<H>,
    /// When `true`, dropping the client does not attempt a logout.
    pub(crate) is_logged_out: AtomicBool,
}
impl<H: HttpTransport> std::fmt::Debug for Client<H> {
    /// Formats the client without exposing credentials.
    ///
    /// The auth headers are shown as `"[redacted]"`. The transport and rate
    /// limiter are deliberately left out; `finish_non_exhaustive` renders a
    /// trailing `..` so the output does not claim to list every field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.request.base_url)
            .field("auth_headers", &"[redacted]")
            .field("is_logged_out", &self.is_logged_out)
            .finish_non_exhaustive()
    }
}
impl Client {
#[must_use]
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
}
impl<H: HttpTransport> Client<H> {
    /// Sends a body-less GET to `path` with the stored auth headers and decodes
    /// the JSON reply.
    pub(crate) async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
        self.request.get(path).await
    }

    /// Sends a body-less DELETE to `path` and decodes the JSON reply.
    pub(crate) async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
        self.request.delete(path).await
    }

    /// GETs `path` with a URL-encoded, comma-separated `symbols` query parameter.
    ///
    /// If `path` already contains a query string, `symbols` is appended with `&`
    /// instead of starting a second query with `?`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] without sending anything when `symbols` is empty
    /// or has more than 10 entries. Otherwise propagates errors from [`Self::get`].
    pub(crate) async fn symbol_query<T: DeserializeOwned>(
        &self,
        path: &str,
        symbols: &[&str],
    ) -> Result<T> {
        if symbols.is_empty() {
            return Err(Error::Other("symbols must not be empty".into()));
        }
        // Upper bound enforced client-side; keep the message in sync if it changes.
        if symbols.len() > 10 {
            return Err(Error::Other("symbols is limited to 10".into()));
        }
        let joined = symbols.join(",");
        let encoded = urlencoding::encode(&joined);
        // Never emit a second `?` when the caller already supplied query parameters.
        let separator = if path.contains('?') { '&' } else { '?' };
        self.get(&format!("{path}{separator}symbols={encoded}")).await
    }

    /// Sends `body` as JSON with POST to `path` and decodes the JSON reply.
    pub(crate) async fn post<B: serde::Serialize, T: DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T> {
        self.request.post(path, body).await
    }

    /// Sends `body` as JSON with PUT to `path` and decodes the JSON reply.
    pub(crate) async fn put<B: serde::Serialize, T: DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T> {
        self.request.put(path, body).await
    }

    /// Sends `body` as JSON with DELETE to `path` and decodes the JSON reply.
    pub(crate) async fn delete_with_body<B: serde::Serialize, T: DeserializeOwned>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T> {
        self.request.delete_with_body(path, body).await
    }
}
impl<H: HttpTransport> Drop for Client<H> {
    /// Best-effort logout when the client goes out of scope.
    ///
    /// Does nothing when `is_logged_out` is already set. Because `drop` cannot
    /// await, the logout is spawned as a detached task on the current Tokio
    /// runtime; failures are logged at `warn`. Without a runtime the logout is
    /// skipped, and that is now logged at `debug` rather than passing silently.
    fn drop(&mut self) {
        if self.is_logged_out.load(Ordering::Acquire) {
            return;
        }
        let Ok(handle) = tokio::runtime::Handle::try_current() else {
            tracing::debug!("no tokio runtime available; skipping logout on drop");
            return;
        };
        // The task outlives `self`, so it needs its own copy of the request plumbing.
        let req = self.request.clone();
        handle.spawn(async move {
            if let Err(e) = rest::auth::logout(&req).await {
                tracing::warn!("logout on drop failed: {e}");
            }
        });
    }
}
#[cfg(test)]
pub(crate) mod test_support {
    //! Constructors for [`Client`]s backed by [`MockHttp`], pointed at
    //! `http://test`, with no rate limiting.

    use std::sync::atomic::AtomicBool;

    use hyper::header::{AUTHORIZATION, HeaderMap, HeaderValue};

    use super::http::mock::MockHttp;
    use super::{Client, RequestHelper};

    /// Wraps `mock` in a not-yet-logged-out client that sends `auth_headers`.
    fn client_with_headers(mock: MockHttp, auth_headers: HeaderMap) -> Client<MockHttp> {
        let request = RequestHelper {
            http: mock,
            base_url: "http://test".into(),
            auth_headers,
            rate_limiter: None,
        };
        Client {
            request,
            is_logged_out: AtomicBool::new(false),
        }
    }

    /// Mock-backed client that sends no auth headers.
    pub fn test_client(mock: MockHttp) -> Client<MockHttp> {
        client_with_headers(mock, HeaderMap::new())
    }

    /// Mock-backed client whose requests carry `Authorization: Bearer tok_test`.
    pub fn test_client_with_auth(mock: MockHttp) -> Client<MockHttp> {
        let mut auth = HeaderMap::new();
        auth.insert(AUTHORIZATION, HeaderValue::from_static("Bearer tok_test"));
        client_with_headers(mock, auth)
    }

    // Compile-time check that the default `Client` can be shared across threads.
    const _: () = {
        fn _assert_send_sync<T: Send + Sync>() {}
        fn _check() {
            _assert_send_sync::<super::Client>();
        }
    };
}
#[cfg(test)]
mod tests {
    use super::http::mock::{MockHttp, MockResponse};
    use super::rate_limiter::RateLimiter;
    use super::test_support::test_client_with_auth;
    use super::*;

    /// `Debug` output shows the endpoint but never the bearer token.
    #[test]
    fn client_debug_redacts_auth() {
        let mock = MockHttp::new(vec![]);
        let client = test_client_with_auth(mock);
        let debug = format!("{client:?}");
        assert!(debug.contains("[redacted]"));
        assert!(debug.contains("http://test"));
        assert!(!debug.contains("tok_test"));
    }

    /// More than 10 symbols is rejected before any request is sent.
    #[tokio::test]
    async fn symbol_query_rejects_more_than_10() {
        let mock = MockHttp::new(vec![]);
        let client = test_client_with_auth(mock);
        let syms: Vec<&str> = (0..11).map(|_| "SYM").collect();
        let err = client
            .symbol_query::<serde_json::Value>("/test", &syms)
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Other(msg) if msg.contains("10")));
        assert!(client.request.http.recorded_requests().is_empty());
    }

    /// An empty symbol list is rejected before any request is sent.
    #[tokio::test]
    async fn symbol_query_rejects_empty() {
        let mock = MockHttp::new(vec![]);
        let client = test_client_with_auth(mock);
        let err = client
            .symbol_query::<serde_json::Value>("/test", &[])
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Other(msg) if msg.contains("empty")));
        assert!(client.request.http.recorded_requests().is_empty());
    }

    /// Body-less requests must not advertise a JSON content type.
    #[tokio::test]
    async fn get_omits_content_type() {
        let mock = MockHttp::new(vec![MockResponse::ok(r#"{"status":"OK"}"#)]);
        let client = test_client_with_auth(mock);
        let _: crate::types::SuccessResponse = client.get("/test").await.unwrap();
        let reqs = client.request.http.recorded_requests();
        assert_eq!(reqs[0].method, hyper::Method::GET);
        assert!(reqs[0].headers.get(hyper::header::CONTENT_TYPE).is_none());
    }

    #[tokio::test]
    async fn post_sends_json_body() {
        let mock = MockHttp::new(vec![MockResponse::ok(r#"{"status":"OK"}"#)]);
        let client = test_client_with_auth(mock);
        let body = serde_json::json!({"key": "val"});
        let _: crate::types::SuccessResponse = client.post("/test", &body).await.unwrap();
        let reqs = client.request.http.recorded_requests();
        assert_eq!(reqs[0].method, hyper::Method::POST);
        assert!(!reqs[0].body.is_empty());
        let ct = reqs[0]
            .headers
            .get(hyper::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(ct, "application/json");
    }

    #[tokio::test]
    async fn put_sends_json_body() {
        let mock = MockHttp::new(vec![MockResponse::ok(r#"{"status":"OK"}"#)]);
        let client = test_client_with_auth(mock);
        let body = serde_json::json!({"key": "val"});
        let _: crate::types::SuccessResponse = client.put("/test", &body).await.unwrap();
        let reqs = client.request.http.recorded_requests();
        assert_eq!(reqs[0].method, hyper::Method::PUT);
        assert!(!reqs[0].body.is_empty());
        let ct = reqs[0]
            .headers
            .get(hyper::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(ct, "application/json");
    }

    #[tokio::test]
    async fn delete_with_body_sends_json() {
        let mock = MockHttp::new(vec![MockResponse::ok(r#"{"status":"OK"}"#)]);
        let client = test_client_with_auth(mock);
        let body = serde_json::json!({"id": 1});
        let _: crate::types::SuccessResponse =
            client.delete_with_body("/test", &body).await.unwrap();
        let reqs = client.request.http.recorded_requests();
        assert_eq!(reqs[0].method, hyper::Method::DELETE);
        assert!(!reqs[0].body.is_empty());
        let ct = reqs[0]
            .headers
            .get(hyper::header::CONTENT_TYPE)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(ct, "application/json");
    }

    #[tokio::test]
    async fn drop_skips_logout_when_already_logged_out() {
        let mock = MockHttp::new(vec![]);
        let client = test_client_with_auth(mock.clone());
        client
            .is_logged_out
            .store(true, std::sync::atomic::Ordering::Release);
        drop(client);
        // On the current-thread test runtime a spawned task only runs once this
        // task yields; without the yield a wrongly spawned logout would never be
        // observed and the assertion below could not fail.
        tokio::task::yield_now().await;
        assert!(mock.recorded_requests().is_empty());
    }

    #[tokio::test]
    async fn request_helper_applies_rate_limiter() {
        let mock = MockHttp::new(vec![
            MockResponse::ok(r#"{"status":"OK"}"#),
            MockResponse::ok(r#"{"status":"OK"}"#),
        ]);
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(
            hyper::header::AUTHORIZATION,
            hyper::header::HeaderValue::from_static("Bearer tok"),
        );
        let helper = RequestHelper {
            http: mock,
            base_url: "http://test".into(),
            auth_headers: headers,
            rate_limiter: Some(RateLimiter::new(10)),
        };
        let _: crate::types::SuccessResponse = helper.get("/test1").await.unwrap();
        let _: crate::types::SuccessResponse = helper.get("/test2").await.unwrap();
        let reqs = helper.http.recorded_requests();
        assert_eq!(reqs.len(), 2);
    }
}