use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{header, HeaderMap, HeaderValue, Response, StatusCode};
use axum::response::IntoResponse;
use axum::routing::{delete, get};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::AppError;
use crate::state::AppState;
/// Builds the service's HTTP router.
///
/// Explicit routes cover health, the two `/debug/*` diagnostics, account data (summary,
/// positions, ledger, order list/submit/cancel), contract search and market snapshots.
/// Any other request falls through to [`passthrough_any`], which relays it to the gateway.
pub fn router(state: AppState) -> Router {
    let diagnostics: Router<AppState> = Router::new()
        .route("/health", get(health))
        .route("/debug/jar", get(debug_jar))
        .route("/debug/probe", get(debug_probe));

    let account_routes: Router<AppState> = Router::new()
        .route("/accounts", get(accounts))
        .route("/accounts/{account_id}/summary", get(account_summary))
        .route("/accounts/{account_id}/positions", get(account_positions))
        .route("/accounts/{account_id}/ledger", get(account_ledger))
        .route(
            "/accounts/{account_id}/orders",
            get(account_orders).post(submit_order),
        )
        .route(
            "/accounts/{account_id}/orders/{order_id}",
            delete(cancel_order),
        );

    let market_routes: Router<AppState> = Router::new()
        .route("/contracts/search", get(contract_search))
        .route("/market/snapshot", get(market_snapshot));

    // None of the sub-routers overlap or define a fallback, so merging them is equivalent
    // to registering every route on a single router.
    diagnostics
        .merge(account_routes)
        .merge(market_routes)
        .fallback(passthrough_any)
        .with_state(state)
}
/// Request headers that `passthrough_any` never copies onto the upstream request.
///
/// Despite the name, this is broader than the RFC 7230 hop-by-hop set:
/// - `host` and `content-length` describe the inbound hop, not the upstream one;
/// - `cookie` is replayed through the shared client cookie jar instead of being forwarded;
/// - `authorization` and the proxy-identity headers (`x-forwarded-*`, `x-real-ip`,
///   `forwarded`) are deliberately kept away from the gateway.
const HOP_BY_HOP: &[&str] = &[
    "host",
    "content-length",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    "cookie",
    "authorization",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-proto",
    "x-real-ip",
    "forwarded",
];

/// Returns `true` when `name` (compared ASCII case-insensitively) is listed in
/// [`HOP_BY_HOP`] and must therefore be dropped from a forwarded request.
fn is_hop_by_hop(name: &str) -> bool {
    for blocked in HOP_BY_HOP {
        if blocked.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
#[tracing::instrument(skip_all, fields(method, path))]
async fn passthrough_any(
State(state): State<AppState>,
req: axum::extract::Request,
) -> Result<Response<Body>, AppError> {
use axum::http::Method;
let method = req.method().clone();
let path_and_query = req
.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/")
.to_string();
let path_only = req.uri().path().to_string();
tracing::Span::current().record("method", tracing::field::display(&method));
tracing::Span::current().record("path", tracing::field::display(&path_only));
let gateway_root = state.client().gateway_root_url().as_str();
let target = format!("{}{}", gateway_root.trim_end_matches('/'), path_and_query);
let target_url: reqwest::Url = target
.parse()
.map_err(|e| bezant::Error::BadRequest(format!("target url: {e}")))?;
let headers = req.headers().clone();
let jar = state.client().cookie_jar();
let mut pairs: Vec<&str> = Vec::new();
for cookie_header in headers.get_all(axum::http::header::COOKIE) {
if let Ok(raw) = cookie_header.to_str() {
for pair in raw.split(';') {
let trimmed = pair.trim();
if trimmed.is_empty() {
continue;
}
let name = trimmed.split('=').next().unwrap_or("").trim();
if is_edge_auth_cookie(name) {
continue;
}
pairs.push(trimmed);
}
}
}
let injected = pairs.len();
if injected > 0 {
jar.set_pairs(&pairs);
}
tracing::debug!(
path = %path_only,
cookies = injected,
"passthrough cookie replay"
);
let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024)
.await
.map_err(|e| bezant::Error::BadRequest(format!("read body: {e}")))?;
let method_reqwest = reqwest::Method::from_bytes(method.as_str().as_bytes())
.map_err(|e| bezant::Error::BadRequest(format!("method: {e}")))?;
let rewrite_origin = path_only.starts_with("/v1/api/") || path_only == "/v1/api";
let gateway_origin = if rewrite_origin {
let scheme = target_url.scheme();
target_url.host_str().map(|h| match target_url.port() {
Some(p) => format!("{scheme}://{h}:{p}"),
None => format!("{scheme}://{h}"),
})
} else {
None
};
let mut builder = state.client().http().request(method_reqwest, &target);
for (name, value) in headers.iter() {
if is_hop_by_hop(name.as_str()) {
continue;
}
let lower = name.as_str().to_ascii_lowercase();
if let Some(ref origin) = gateway_origin {
if lower == "origin" {
if let Ok(v) = reqwest::header::HeaderValue::from_str(origin) {
builder = builder.header(reqwest::header::ORIGIN, v);
}
continue;
}
if lower == "referer" {
if let Ok(orig) = value.to_str() {
let rewritten = rewrite_referer_origin(orig, origin);
if let Ok(v) = reqwest::header::HeaderValue::from_str(&rewritten) {
builder = builder.header(reqwest::header::REFERER, v);
}
}
continue;
}
}
if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) {
if let Ok(name) = reqwest::header::HeaderName::from_bytes(name.as_str().as_bytes()) {
builder = builder.header(name, v);
}
}
}
if method != Method::GET && method != Method::HEAD {
let len = body_bytes.len();
builder = builder
.header(reqwest::header::CONTENT_LENGTH, len.to_string())
.body(body_bytes.to_vec());
}
let resp = builder.send().await.map_err(bezant::Error::Http)?;
forward(resp).await
}
/// JSON body returned by `GET /health`.
///
/// Each field is copied one-to-one from the client's `auth_status()` result in `health`.
#[derive(Serialize)]
struct HealthBody {
    /// `authenticated` flag from the gateway auth status.
    authenticated: bool,
    /// `connected` flag from the gateway auth status.
    connected: bool,
    /// `competing` flag from the gateway auth status.
    competing: bool,
    /// Optional status message from the gateway; serialized as `null` when absent.
    message: Option<String>,
}
/// `GET /debug/jar`: lists the cookies currently held in the shared client cookie jar.
///
/// Guarded by [`debug_auth`]. For each cookie, only its name and the *length* of its value
/// are reported, never the value itself. The gateway root URL and the entry count are
/// included alongside.
async fn debug_jar(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(q): Query<HashMap<String, String>>,
) -> Response<Body> {
    match debug_auth(&state, &headers, &q) {
        Ok(()) => {}
        Err(denied) => return denied,
    }
    let client = state.client();
    let mut entries: Vec<serde_json::Value> = Vec::new();
    for (name, value) in client.cookie_jar().snapshot() {
        entries.push(serde_json::json!({
            "name": name,
            "value_length": value.len(),
        }));
    }
    Json(serde_json::json!({
        "gateway_root": client.gateway_root_url().as_str(),
        "size": entries.len(),
        "entries": entries,
    }))
    .into_response()
}
/// Gatekeeper for the `/debug/*` endpoints.
///
/// Returns `Ok(())` when the caller presents the configured debug token. The token is read
/// from the `x-bezant-debug-token` header, falling back to the `token` query parameter when
/// the header is absent or not valid UTF-8. Otherwise returns the response to send back:
///
/// - `404 Not Found` with an empty body when no debug token is configured, or when the
///   configured token is empty, so the debug surface stays invisible unless explicitly
///   enabled;
/// - `401 Unauthorized` with a JSON error body when the presented token is missing or wrong.
///
/// The comparison uses [`constant_time_eq`].
// The Err variant is a complete HTTP response by design (callers return it verbatim);
// boxing it would only add noise at the call sites.
#[allow(clippy::result_large_err)]
fn debug_auth(
    state: &AppState,
    headers: &HeaderMap,
    query: &HashMap<String, String>,
) -> Result<(), Response<Body>> {
    // An empty configured token would be satisfied by a request that presents no token at
    // all (`presented` below defaults to ""), so treat it exactly like "not configured".
    let Some(expected) = state.debug_token().filter(|token| !token.is_empty()) else {
        return Err(Response::builder()
            .status(StatusCode::NOT_FOUND)
            .body(Body::empty())
            .unwrap_or_default());
    };
    let presented = headers
        .get("x-bezant-debug-token")
        .and_then(|v| v.to_str().ok())
        .or_else(|| query.get("token").map(String::as_str))
        .unwrap_or("");
    if constant_time_eq(presented.as_bytes(), expected.as_bytes()) {
        return Ok(());
    }
    Err(Response::builder()
        .status(StatusCode::UNAUTHORIZED)
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(
            r#"{"code":"debug_unauthorized","message":"missing or invalid debug token"}"#,
        ))
        .unwrap_or_default())
}
/// Compares two byte strings without exiting early at the first differing byte.
///
/// Every byte pair is XOR-ed and OR-folded into one accumulator, so the loop over
/// equal-length inputs has no data-dependent early exit. Inputs of different length return
/// `false` immediately, so the *length* of the secret is not hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// `GET /debug/probe`: walks the gateway session handshake and reports every step.
///
/// Guarded by [`debug_auth`]. Each step goes through [`probe_step`], in this order:
/// 1. `POST iserver/auth/status`;
/// 2. `POST iserver/auth/ssodh/init` — replaced by [`skipped_step`] when step 1 already
///    reports an authenticated session;
/// 3. `POST tickle`;
/// 4. `GET portfolio/accounts`.
///
/// The response carries the per-step results, a verdict from [`compute_verdict`], the total
/// elapsed time, and the cookie-jar size before and after.
///
/// NOTE(review): step 2 is a real `POST` with `compete: true`, so this endpoint presumably
/// changes gateway session state. It is not a read-only diagnostic; confirm.
async fn debug_probe(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(q): Query<HashMap<String, String>>,
) -> Response<Body> {
    match debug_auth(&state, &headers, &q) {
        Ok(()) => {}
        Err(denied) => return denied,
    }
    let client = state.client();
    let clock = std::time::Instant::now();
    let jar_size_before = client.cookie_jar().snapshot().len();

    let auth_status = probe_step(
        client,
        "auth_status",
        reqwest::Method::POST,
        &["iserver", "auth", "status"],
        None,
    )
    .await;

    let bridged = is_authenticated(&auth_status);
    let ssodh_init = if bridged {
        skipped_step(
            "ssodh_init",
            "POST",
            "/v1/api/iserver/auth/ssodh/init",
            "session already bridged (auth_status authenticated)",
        )
    } else {
        let init_body = serde_json::json!({ "publish": true, "compete": true });
        probe_step(
            client,
            "ssodh_init",
            reqwest::Method::POST,
            &["iserver", "auth", "ssodh", "init"],
            Some(init_body),
        )
        .await
    };

    let tickle = probe_step(client, "tickle", reqwest::Method::POST, &["tickle"], None).await;
    let accounts = probe_step(
        client,
        "accounts",
        reqwest::Method::GET,
        &["portfolio", "accounts"],
        None,
    )
    .await;

    let verdict = compute_verdict(&auth_status, &ssodh_init, &tickle, &accounts);
    let jar_size_after = client.cookie_jar().snapshot().len();
    let report = serde_json::json!({
        "gateway_root": client.gateway_root_url().as_str(),
        "elapsed_ms": clock.elapsed().as_millis() as u64,
        "jar_size_before": jar_size_before,
        "jar_size_after": jar_size_after,
        "verdict": verdict,
        "steps": [auth_status, ssodh_init, tickle, accounts],
    });
    Json(report).into_response()
}
/// Sends one diagnostic request to the gateway and summarises the outcome as JSON.
///
/// **Request.**
/// - `path_segments` are appended (percent-encoded) to the client's API base URL.
/// - The request carries `Origin: <gateway root>` and `Referer: <gateway root>/`.
/// - For methods other than `GET`/`HEAD`, a `Content-Length` is always sent. When the
///   serialized `body` is non-empty it is sent too, with `Content-Type: application/json`.
///
/// **Timeout.** One 5-second budget covers both receiving the response headers and reading
/// the body (capped at 1 MiB), so a stalled body stream cannot hang the probe.
///
/// **Result.** The returned object always has `name`, `method`, `path`, `status`,
/// `latency_ms` (time until response headers), `body_bytes`, `body_preview`,
/// `set_cookie_names` and `error`.
/// - After a response it also has `_authenticated`: the top-level `authenticated` boolean
///   of a JSON body, or `null`.
/// - `body_preview` is the body redacted by [`redact_tokens`], then truncated to 512 bytes.
async fn probe_step(
    client: &bezant::Client,
    name: &'static str,
    method: reqwest::Method,
    path_segments: &[&str],
    body: Option<serde_json::Value>,
) -> serde_json::Value {
    const STEP_BUDGET: std::time::Duration = std::time::Duration::from_secs(5);
    const PREVIEW_BYTES: usize = 512;

    let mut url = client.base_url().clone();
    if let Ok(mut segs) = url.path_segments_mut() {
        // A base URL ending in `/` has an empty last segment; without dropping it the first
        // pushed segment would produce `//` in the path.
        segs.pop_if_empty();
        for seg in path_segments {
            segs.push(seg);
        }
    }
    let path_for_log = url.path().to_owned();
    let gateway_origin = client
        .gateway_root_url()
        .as_str()
        .trim_end_matches('/')
        .to_owned();
    let mut builder = client
        .http()
        .request(method.clone(), url)
        .header(reqwest::header::ORIGIN, &gateway_origin)
        .header(reqwest::header::REFERER, format!("{gateway_origin}/"));
    if method != reqwest::Method::GET && method != reqwest::Method::HEAD {
        let payload: Vec<u8> = body
            .map(|json| serde_json::to_vec(&json).unwrap_or_default())
            .unwrap_or_default();
        if !payload.is_empty() {
            builder = builder.header(reqwest::header::CONTENT_TYPE, "application/json");
        }
        builder = builder
            .header(reqwest::header::CONTENT_LENGTH, payload.len().to_string())
            .body(payload);
    }

    let deadline = tokio::time::Instant::now() + STEP_BUDGET;
    let started = std::time::Instant::now();
    let result = match tokio::time::timeout_at(deadline, builder.send()).await {
        Ok(send_result) => send_result.map_err(|e| e.to_string()),
        Err(_) => Err("step timed out after 5s".to_string()),
    };
    let latency_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
    match result {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let set_cookie_names: Vec<String> = resp
                .headers()
                .get_all(reqwest::header::SET_COOKIE)
                .iter()
                .filter_map(|v| v.to_str().ok())
                .map(|raw| {
                    raw.split(';')
                        .next()
                        .and_then(|s| s.split('=').next())
                        .map(|s| s.trim().to_owned())
                        .unwrap_or_default()
                })
                .filter(|s| !s.is_empty())
                .collect();
            // The same deadline bounds the body read. A failed or timed-out read degrades
            // to an empty body.
            let bytes = tokio::time::timeout_at(deadline, read_capped(resp, 1024 * 1024))
                .await
                .ok()
                .and_then(Result::ok)
                .unwrap_or_default();
            let parsed_authenticated = serde_json::from_slice::<serde_json::Value>(&bytes)
                .ok()
                .and_then(|v| v["authenticated"].as_bool());
            // Redact the *whole* body before truncating. A truncated JSON document no longer
            // parses, and redact_tokens passes unparseable input through verbatim, which
            // would leak session tokens from any body longer than the preview.
            let body_preview = truncate_to_char_boundary(
                redact_tokens(&String::from_utf8_lossy(&bytes)),
                PREVIEW_BYTES,
            );
            serde_json::json!({
                "name": name,
                "method": method.as_str(),
                "path": path_for_log,
                "status": status,
                "latency_ms": latency_ms,
                "body_bytes": bytes.len(),
                "body_preview": body_preview,
                "set_cookie_names": set_cookie_names,
                "error": serde_json::Value::Null,
                "_authenticated": parsed_authenticated,
            })
        }
        Err(e) => serde_json::json!({
            "name": name,
            "method": method.as_str(),
            "path": path_for_log,
            "status": serde_json::Value::Null,
            "latency_ms": latency_ms,
            "body_bytes": 0,
            "body_preview": "",
            "set_cookie_names": [],
            "error": e,
        }),
    }
}

/// Shortens `text` to at most `max` bytes without splitting a UTF-8 character.
fn truncate_to_char_boundary(mut text: String, max: usize) -> String {
    if text.len() > max {
        let mut cut = max;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        text.truncate(cut);
    }
    text
}
/// Masks secret-looking fields in a JSON text (see [`redact_in_place`]) and re-serializes it.
///
/// Input that does not parse as JSON is returned unchanged.
fn redact_tokens(preview: &str) -> String {
    match serde_json::from_str::<serde_json::Value>(preview) {
        Ok(mut parsed) => {
            redact_in_place(&mut parsed);
            parsed.to_string()
        }
        Err(_) => preview.to_owned(),
    }
}
/// Recursively replaces the values of sensitive object keys with the string `"<redacted>"`.
///
/// A key is sensitive when, ASCII-lowercased, it equals `session` or `ssoconclusion`, or
/// contains `token` or `secret`. A sensitive key's whole value is replaced, whatever its
/// type. Other values are descended into (objects and arrays); scalars are left alone.
fn redact_in_place(value: &mut serde_json::Value) {
    fn is_sensitive_key(key: &str) -> bool {
        let key = key.to_ascii_lowercase();
        key == "session"
            || key == "ssoconclusion"
            || key.contains("token")
            || key.contains("secret")
    }

    match value {
        serde_json::Value::Object(map) => {
            for (key, child) in map.iter_mut() {
                if is_sensitive_key(key) {
                    *child = serde_json::Value::String("<redacted>".to_owned());
                } else {
                    redact_in_place(child);
                }
            }
        }
        serde_json::Value::Array(items) => items.iter_mut().for_each(redact_in_place),
        _ => {}
    }
}
/// Reduces the four probe steps to a single verdict string. Checks run in this order:
///
/// - `"auth_status_failed"`: the auth-status call did not return 2xx;
/// - `"ssodh_failed"`: ssodh/init ran (was not skipped) and did not return 2xx;
/// - `"needs_login"`: neither auth-status nor ssodh/init reported `authenticated: true`
///   with a 2xx status;
/// - `"tickle_failed"` / `"accounts_failed"`: that step neither returned 2xx nor was
///   skipped;
/// - `"ok"` otherwise.
fn compute_verdict(
    auth_status: &serde_json::Value,
    ssodh_init: &serde_json::Value,
    tickle: &serde_json::Value,
    accounts: &serde_json::Value,
) -> &'static str {
    if !is_2xx(auth_status) {
        return "auth_status_failed";
    }
    let ssodh_ran_and_failed = !is_skipped(ssodh_init) && !is_2xx(ssodh_init);
    if ssodh_ran_and_failed {
        return "ssodh_failed";
    }
    // ssodh/init only runs when auth_status was unauthenticated, and its job is to bridge
    // the session. Its response body is assumed to carry the same `authenticated` flag
    // (TODO confirm against gateway docs). If it does, the session is now authenticated
    // even though the earlier auth_status said otherwise.
    let authenticated = is_authenticated(auth_status) || is_authenticated(ssodh_init);
    if !authenticated {
        return "needs_login";
    }
    if !is_2xx_or_skipped(tickle) {
        return "tickle_failed";
    }
    if !is_2xx_or_skipped(accounts) {
        return "accounts_failed";
    }
    "ok"
}
/// True when the step was deliberately not executed (its `skipped` field is `true`; see
/// `skipped_step`).
fn is_skipped(step: &serde_json::Value) -> bool {
    step["skipped"].as_bool().unwrap_or(false)
}

/// True when the step received an HTTP status in `200..=299`.
///
/// Failed and skipped steps carry a `null` status and are therefore never 2xx.
fn is_2xx(step: &serde_json::Value) -> bool {
    matches!(step["status"].as_u64(), Some(200..=299))
}

/// True when the step either succeeded with 2xx or was skipped.
fn is_2xx_or_skipped(step: &serde_json::Value) -> bool {
    // Reuse is_skipped rather than re-reading the field, so both agree on what "skipped"
    // means.
    is_2xx(step) || is_skipped(step)
}

/// True when the step returned 2xx *and* its JSON body had a top-level
/// `authenticated: true`. `probe_step` records that flag under `_authenticated`.
fn is_authenticated(step: &serde_json::Value) -> bool {
    is_2xx(step) && step["_authenticated"].as_bool().unwrap_or(false)
}
fn skipped_step(
name: &'static str,
method: &'static str,
path: &'static str,
reason: &'static str,
) -> serde_json::Value {
serde_json::json!({
"name": name,
"method": method,
"path": path,
"status": serde_json::Value::Null,
"latency_ms": 0,
"body_bytes": 0,
"body_preview": "",
"set_cookie_names": [],
"error": serde_json::Value::Null,
"skipped": true,
"skipped_reason": reason,
})
}
/// `GET /health`: reports the gateway's session state.
///
/// Calls `auth_status()` on the shared client and reshapes the result into [`HealthBody`].
///
/// # Errors
/// Propagates any error from `auth_status()`, converted into [`AppError`].
#[tracing::instrument(skip_all)]
async fn health(State(state): State<AppState>) -> Result<Json<HealthBody>, AppError> {
    let session = state.client().auth_status().await?;
    let body = HealthBody {
        message: session.message,
        competing: session.competing,
        connected: session.connected,
        authenticated: session.authenticated,
    };
    Ok(Json(body))
}
/// `GET /accounts`: relays the gateway's `portfolio/accounts` listing.
#[tracing::instrument(skip_all)]
async fn accounts(State(state): State<AppState>) -> Result<Response<Body>, AppError> {
    const SEGMENTS: [&str; 2] = ["portfolio", "accounts"];
    passthrough_get(&state, &SEGMENTS, &[]).await
}
/// `GET /accounts/{account_id}/summary`: relays `portfolio/{account_id}/summary`.
#[tracing::instrument(skip(state), fields(account_id = %account_id))]
async fn account_summary(
    State(state): State<AppState>,
    Path(account_id): Path<String>,
) -> Result<Response<Body>, AppError> {
    let segments = ["portfolio", account_id.as_str(), "summary"];
    passthrough_get(&state, &segments, &[]).await
}
/// Query string accepted by `GET /accounts/{account_id}/positions`.
#[derive(Deserialize, Debug)]
struct PositionsQuery {
    /// Page number, forwarded as the final path segment of the gateway request; `0` when
    /// `page` is omitted.
    #[serde(default)]
    page: u32,
}
/// `GET /accounts/{account_id}/positions?page=N`: relays
/// `portfolio/{account_id}/positions/{N}` (page defaults to 0).
#[tracing::instrument(skip(state), fields(account_id = %account_id, page = q.page))]
async fn account_positions(
    State(state): State<AppState>,
    Path(account_id): Path<String>,
    Query(q): Query<PositionsQuery>,
) -> Result<Response<Body>, AppError> {
    let page_segment = q.page.to_string();
    let segments = [
        "portfolio",
        account_id.as_str(),
        "positions",
        page_segment.as_str(),
    ];
    passthrough_get(&state, &segments, &[]).await
}
/// `GET /accounts/{account_id}/ledger`: relays `portfolio/{account_id}/ledger`.
#[tracing::instrument(skip(state), fields(account_id = %account_id))]
async fn account_ledger(
    State(state): State<AppState>,
    Path(account_id): Path<String>,
) -> Result<Response<Body>, AppError> {
    let segments = ["portfolio", account_id.as_str(), "ledger"];
    passthrough_get(&state, &segments, &[]).await
}
/// `GET /accounts/{account_id}/orders`: relays `iserver/account/orders` with the account
/// passed as the `accountId` query parameter.
#[tracing::instrument(skip(state), fields(account_id = %account_id))]
async fn account_orders(
    State(state): State<AppState>,
    Path(account_id): Path<String>,
) -> Result<Response<Body>, AppError> {
    let filter = [("accountId", account_id.as_str())];
    passthrough_get(&state, &["iserver", "account", "orders"], &filter).await
}
/// `POST /accounts/{account_id}/orders`: forwards the JSON request body unchanged to the
/// gateway's `iserver/account/{account_id}/orders` and relays the response via [`forward`].
///
/// `account_id` is pushed as a single percent-encoded path segment.
///
/// # Errors
/// `UrlNotABase` if the client base URL cannot take path segments; `Http` when the
/// upstream call fails; plus anything [`forward`] returns.
#[tracing::instrument(skip(state, body), fields(account_id = %account_id))]
async fn submit_order(
    State(state): State<AppState>,
    Path(account_id): Path<String>,
    axum::extract::Json(body): axum::extract::Json<serde_json::Value>,
) -> Result<Response<Body>, AppError> {
    let mut url = state.client().base_url().clone();
    {
        let mut segs = url
            .path_segments_mut()
            .map_err(|()| bezant::Error::UrlNotABase {
                url: state.client().base_url().to_string(),
            })?;
        // Drop a trailing empty segment so a base URL ending in `/` doesn't yield `//`.
        segs.pop_if_empty()
            .push("iserver")
            .push("account")
            .push(account_id.as_str())
            .push("orders");
    }
    let resp = state
        .client()
        .http()
        .post(url)
        .json(&body)
        .send()
        .await
        .map_err(bezant::Error::Http)?;
    forward(resp).await
}
/// `DELETE /accounts/{account_id}/orders/{order_id}`: issues `DELETE` against the
/// gateway's `iserver/account/{account_id}/order/{order_id}` (note: singular `order`
/// upstream) and relays the response via [`forward`].
///
/// Both ids are pushed as single percent-encoded path segments.
///
/// # Errors
/// `UrlNotABase` if the client base URL cannot take path segments; `Http` when the
/// upstream call fails; plus anything [`forward`] returns.
#[tracing::instrument(skip(state), fields(account_id = %account_id, order_id = %order_id))]
async fn cancel_order(
    State(state): State<AppState>,
    Path((account_id, order_id)): Path<(String, String)>,
) -> Result<Response<Body>, AppError> {
    let mut url = state.client().base_url().clone();
    {
        let mut segs = url
            .path_segments_mut()
            .map_err(|()| bezant::Error::UrlNotABase {
                url: state.client().base_url().to_string(),
            })?;
        // Drop a trailing empty segment so a base URL ending in `/` doesn't yield `//`.
        segs.pop_if_empty()
            .push("iserver")
            .push("account")
            .push(account_id.as_str())
            .push("order")
            .push(order_id.as_str());
    }
    let resp = state
        .client()
        .http()
        .delete(url)
        .send()
        .await
        .map_err(bezant::Error::Http)?;
    forward(resp).await
}
/// Query string accepted by `GET /contracts/search`.
///
/// `contract_search` re-encodes it as the JSON body of the gateway's
/// `iserver/secdef/search` request.
#[derive(Deserialize)]
struct ContractSearchQuery {
    /// Search text (required; a missing `symbol` is rejected by the `Query` extractor).
    symbol: String,
    /// Forwarded as `name`; `false` when omitted.
    /// NOTE(review): presumably tells the gateway to match on company name; confirm.
    #[serde(default)]
    name: bool,
    /// Security type, read from the `secType` parameter; defaults to `"STK"` via
    /// `default_sec_type`.
    #[serde(rename = "secType", default = "default_sec_type")]
    sec_type: String,
}
/// Default for `ContractSearchQuery::sec_type` when `secType` is absent: `"STK"`.
fn default_sec_type() -> String {
    String::from("STK")
}
/// `GET /contracts/search`: runs a gateway contract search.
///
/// The query parameters are re-encoded as the JSON body
/// `{"symbol", "name", "secType"}` and `POST`ed to `iserver/secdef/search`. The response
/// is relayed via [`forward`].
///
/// # Errors
/// `UrlNotABase` if the client base URL cannot take path segments; `Http` when the
/// upstream call fails; plus anything [`forward`] returns.
async fn contract_search(
    State(state): State<AppState>,
    Query(q): Query<ContractSearchQuery>,
) -> Result<Response<Body>, AppError> {
    let mut url = state.client().base_url().clone();
    {
        let mut segs = url
            .path_segments_mut()
            .map_err(|()| bezant::Error::UrlNotABase {
                url: state.client().base_url().to_string(),
            })?;
        // Drop a trailing empty segment so a base URL ending in `/` doesn't yield `//`.
        segs.pop_if_empty()
            .push("iserver")
            .push("secdef")
            .push("search");
    }
    let body = serde_json::json!({
        "symbol": q.symbol,
        "name": q.name,
        "secType": q.sec_type,
    });
    let resp = state
        .client()
        .http()
        .post(url)
        .json(&body)
        .send()
        .await
        .map_err(bezant::Error::Http)?;
    forward(resp).await
}
/// `GET /market/snapshot?conids=…&fields=…`: relays `iserver/marketdata/snapshot`.
///
/// `conids` is required. `fields` defaults to `31,84,86,87`.
/// NOTE(review): these are presumably gateway market-data field ids; confirm their
/// meaning.
///
/// # Errors
/// `MissingQuery { name: "conids" }` when `conids` is absent or blank; plus anything
/// `passthrough_get` returns.
async fn market_snapshot(
    State(state): State<AppState>,
    Query(q): Query<HashMap<String, String>>,
) -> Result<Response<Body>, AppError> {
    // A present-but-blank `conids` (e.g. `?conids=`) is as useless upstream as a missing
    // one, so reject both the same way instead of forwarding an empty id list.
    let conids = q
        .get("conids")
        .filter(|ids| !ids.trim().is_empty())
        .ok_or(bezant::Error::MissingQuery { name: "conids" })?;
    let fields = q
        .get("fields")
        .cloned()
        .unwrap_or_else(|| "31,84,86,87".into());
    passthrough_get(
        &state,
        &["iserver", "marketdata", "snapshot"],
        &[("conids", conids.as_str()), ("fields", fields.as_str())],
    )
    .await
}
/// Issues a `GET` against the gateway API and relays the response via [`forward`].
///
/// - `path_segments` are appended to the client's base URL. Each is percent-encoded as one
///   segment, so a `/` inside a value cannot add path components.
/// - `query` pairs are form-encoded onto the query string. No `?` is added when `query` is
///   empty.
///
/// # Errors
/// `UrlNotABase` if the client base URL cannot take path segments; `Http` when the
/// upstream call fails; plus anything [`forward`] returns.
async fn passthrough_get(
    state: &AppState,
    path_segments: &[&str],
    query: &[(&str, &str)],
) -> Result<Response<Body>, AppError> {
    let mut url = state.client().base_url().clone();
    {
        let mut segs = url
            .path_segments_mut()
            .map_err(|()| bezant::Error::UrlNotABase {
                url: state.client().base_url().to_string(),
            })?;
        // A base URL ending in `/` has an empty last segment; drop it so the first pushed
        // segment does not produce `//` in the path.
        segs.pop_if_empty().extend(path_segments);
    }
    if !query.is_empty() {
        let mut pairs = url.query_pairs_mut();
        for (k, v) in query {
            pairs.append_pair(k, v);
        }
    }
    let resp = state
        .client()
        .http()
        .get(url)
        .send()
        .await
        .map_err(bezant::Error::Http)?;
    forward(resp).await
}
/// Largest upstream response body [`forward`] will buffer (25 MiB).
const MAX_UPSTREAM_BODY_BYTES: usize = 25 * 1024 * 1024;

/// Converts a gateway response into an axum response.
///
/// The body is buffered in full, capped at [`MAX_UPSTREAM_BODY_BYTES`]. Headers are copied
/// with these adjustments:
/// - headers matched by [`is_hop_by_hop_response`] are dropped;
/// - every `Set-Cookie` has its `Domain` attribute removed ([`strip_cookie_domain`]);
/// - when the body is non-empty and the status allows one,
///   `Content-Type: application/octet-stream` becomes `text/html; charset=UTF-8`, and that
///   same value is added if upstream sent no `Content-Type` at all.
///
/// Statuses that carry no body (1xx, 204, 304 and every 3xx) tolerate a failed body read:
/// the failure is logged and the body is treated as empty. A status that cannot be
/// represented is mapped to `502 Bad Gateway`.
///
/// # Errors
/// `UpstreamStatus` (endpoint `"passthrough"`) when the body cannot be read or exceeds the
/// cap for any other status; `ResponseBuild` if the response cannot be assembled.
#[tracing::instrument(skip_all, fields(upstream_status = %resp.status()))]
async fn forward(resp: reqwest::Response) -> Result<Response<Body>, AppError> {
    let upstream_status = resp.status();
    let upstream_headers = resp.headers().clone();
    let no_body_status = upstream_status.is_redirection()
        || matches!(upstream_status.as_u16(), 100..=199 | 204 | 304);

    let bytes: Vec<u8> = match read_capped(resp, MAX_UPSTREAM_BODY_BYTES).await {
        Ok(body) => body,
        Err(e) if no_body_status => {
            tracing::debug!(
                status = %upstream_status,
                error = %e,
                "forward: empty-body fallback on no-body status"
            );
            Vec::new()
        }
        Err(e) => {
            return Err(bezant::Error::UpstreamStatus {
                endpoint: "passthrough",
                status: upstream_status.as_u16(),
                body_preview: Some(e),
            }
            .into());
        }
    };

    // Content-Type rewriting/defaulting only applies when there is a body to describe.
    let has_body = !bytes.is_empty() && !no_body_status;
    let mut headers = HeaderMap::new();
    let mut saw_content_type = false;
    for (name, value) in upstream_headers.iter() {
        let lower = name.as_str().to_ascii_lowercase();
        if is_hop_by_hop_response(&lower) {
            continue;
        }
        let Ok(out_name) = header::HeaderName::from_bytes(name.as_str().as_bytes()) else {
            continue;
        };
        match lower.as_str() {
            "set-cookie" => {
                let rewritten = match value.to_str() {
                    Ok(raw) => strip_cookie_domain(raw).into_bytes(),
                    Err(_) => value.as_bytes().to_vec(),
                };
                if let Ok(out_value) = HeaderValue::from_bytes(&rewritten) {
                    headers.append(out_name, out_value);
                }
            }
            "content-type" => {
                let is_octet_stream = value
                    .to_str()
                    .unwrap_or("")
                    .eq_ignore_ascii_case("application/octet-stream");
                let emitted: &[u8] = if has_body && is_octet_stream {
                    b"text/html; charset=UTF-8"
                } else {
                    value.as_bytes()
                };
                if let Ok(out_value) = HeaderValue::from_bytes(emitted) {
                    headers.insert(out_name, out_value);
                    saw_content_type = true;
                }
            }
            _ => {
                if let Ok(out_value) = HeaderValue::from_bytes(value.as_bytes()) {
                    headers.append(out_name, out_value);
                }
            }
        }
    }
    if has_body && !saw_content_type {
        headers.insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html; charset=UTF-8"),
        );
    }

    let status =
        StatusCode::from_u16(upstream_status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let mut response = Response::builder()
        .status(status)
        .body(Body::from(bytes))
        .map_err(|e| bezant::Error::ResponseBuild(e.to_string()))?;
    *response.headers_mut() = headers;
    Ok(response)
}
/// Returns `true` for response headers [`forward`] must not copy to the client.
///
/// The comparison is exact (case-sensitive), so callers pass an already-lowercased name.
fn is_hop_by_hop_response(name: &str) -> bool {
    const RESPONSE_HOP_BY_HOP: [&str; 9] = [
        "content-length",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];
    RESPONSE_HOP_BY_HOP.contains(&name)
}
/// Re-points a `Referer` URL at `new_origin`, keeping only its path and query.
///
/// The original scheme, host, port, userinfo and fragment are discarded, and trailing `/`
/// on `new_origin` is trimmed before joining. If `original` is not an absolute URL,
/// `new_origin` is returned verbatim (untrimmed).
fn rewrite_referer_origin(original: &str, new_origin: &str) -> String {
    let Ok(parsed) = url::Url::parse(original) else {
        return new_origin.to_owned();
    };
    let origin = new_origin.trim_end_matches('/');
    let path = parsed.path();
    match parsed.query() {
        Some(query) => format!("{origin}{path}?{query}"),
        None => format!("{origin}{path}"),
    }
}
/// Reports whether `name` belongs to an edge authentication proxy (Cloudflare Access, AWS
/// ALB auth, oauth2-proxy, Vercel, Pomerium) rather than to the gateway.
///
/// `passthrough_any` drops such cookies instead of replaying them into the gateway cookie
/// jar, so edge credentials are not handed upstream.
///
/// A name matches when it starts with one of the built-in prefixes, or with any
/// comma-separated prefix in the `BEZANT_EDGE_COOKIE_PREFIXES` environment variable. That
/// variable is read on every call. Matching is ASCII case-insensitive for both sources.
/// Previously an exact match was case-insensitive but a longer name such as
/// `cf_authorization_x` was compared case-sensitively and slipped through.
fn is_edge_auth_cookie(name: &str) -> bool {
    const BUILTIN_PREFIXES: &[&str] = &[
        "CF_Authorization",
        "CF_AppSession",
        "AWSELBAuthSessionCookie",
        "_oauth2_proxy",
        "_vercel_jwt",
        "_vercel_sso_nonce",
        "_pomerium",
    ];
    if BUILTIN_PREFIXES
        .iter()
        .any(|prefix| starts_with_ignore_ascii_case(name, prefix))
    {
        return true;
    }
    std::env::var("BEZANT_EDGE_COOKIE_PREFIXES").is_ok_and(|extra| {
        extra
            .split(',')
            .map(str::trim)
            .filter(|prefix| !prefix.is_empty())
            .any(|prefix| starts_with_ignore_ascii_case(name, prefix))
    })
}

/// ASCII case-insensitive `str::starts_with`. It compares bytes, so it never panics on a
/// non-ASCII `haystack`.
fn starts_with_ignore_ascii_case(haystack: &str, prefix: &str) -> bool {
    haystack
        .as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}
/// Buffers the whole body of `resp`, failing once it would exceed `max` bytes.
///
/// The body is consumed chunk by chunk, so an oversized body is rejected as soon as the
/// running total passes the cap rather than after a full download.
///
/// # Errors
/// A human-readable message when a chunk fails to arrive (`read chunk: …`) or when the cap
/// would be exceeded.
async fn read_capped(
    resp: reqwest::Response,
    max: usize,
) -> std::result::Result<Vec<u8>, String> {
    use futures_util::StreamExt;
    let mut stream = resp.bytes_stream();
    let mut collected: Vec<u8> = Vec::new();
    while let Some(next) = stream.next().await {
        let piece = next.map_err(|e| format!("read chunk: {e}"))?;
        let total = collected.len() + piece.len();
        if total > max {
            return Err(format!("upstream body exceeded {max} byte cap (>{total}B)"));
        }
        collected.extend_from_slice(&piece);
    }
    Ok(collected)
}
/// Removes every `Domain=` attribute (matched case-insensitively after trimming) from a
/// `Set-Cookie` value.
///
/// The remaining `;`-separated parts are re-joined with `;`, keeping their original
/// spacing. Without a `Domain` attribute the cookie becomes host-only for this proxy.
fn strip_cookie_domain(value: &str) -> String {
    let mut kept = String::with_capacity(value.len());
    let mut first = true;
    for attr in value.split(';') {
        if attr.trim().to_ascii_lowercase().starts_with("domain=") {
            continue;
        }
        if !first {
            kept.push(';');
        }
        kept.push_str(attr);
        first = false;
    }
    kept
}
// Unit tests for `redact_tokens` / `redact_in_place`: every sensitive-key rule, recursion
// through objects and arrays, and the non-JSON passthrough.
#[cfg(test)]
mod redact_tests {
    use super::redact_tokens;
    #[test]
    fn redacts_session_field() {
        let input = r#"{"session":"AAAA-real-token","other":1}"#;
        let out = redact_tokens(input);
        assert!(out.contains("<redacted>"), "got: {out}");
        assert!(!out.contains("AAAA-real-token"), "raw token leaked: {out}");
        assert!(out.contains("\"other\":1"), "got: {out}");
    }
    #[test]
    fn redacts_token_substring_keys() {
        let input = r#"{"accessToken":"x","refresh_token":"y","tokenExpiry":99}"#;
        let out = redact_tokens(input);
        assert!(!out.contains(r#""x""#), "accessToken leaked: {out}");
        assert!(!out.contains(r#""y""#), "refresh_token leaked: {out}");
        assert!(!out.contains("99"), "tokenExpiry value leaked: {out}");
    }
    #[test]
    fn redacts_secret_and_ssoconclusion_keys_case_insensitively() {
        let input = r#"{"clientSecret":"s3cr3t","SSOConclusion":"done","keep":"me"}"#;
        let out = redact_tokens(input);
        assert!(!out.contains("s3cr3t"), "secret leaked: {out}");
        assert!(!out.contains(r#""done""#), "ssoconclusion leaked: {out}");
        assert!(out.contains(r#""keep":"me""#), "got: {out}");
    }
    #[test]
    fn passes_non_json_through_verbatim() {
        let input = "Not a JSON body, just text.";
        let out = redact_tokens(input);
        assert_eq!(out, input);
    }
    #[test]
    fn redacts_inside_arrays_and_nested_objects() {
        let input = r#"{"sessions":[{"session":"a"},{"session":"b"}]}"#;
        let out = redact_tokens(input);
        assert!(!out.contains(r#""a""#), "got: {out}");
        assert!(!out.contains(r#""b""#), "got: {out}");
    }
    #[test]
    fn redacts_inside_top_level_array() {
        let input = r#"[{"token":"zzz"},{"ok":1}]"#;
        let out = redact_tokens(input);
        assert!(!out.contains("zzz"), "got: {out}");
        assert!(out.contains("\"ok\":1"), "got: {out}");
    }
    #[test]
    fn leaves_unrelated_keys_untouched() {
        let input = r#"{"authenticated":true,"count":3}"#;
        let out = redact_tokens(input);
        assert!(!out.contains("<redacted>"), "got: {out}");
        assert!(out.contains("\"authenticated\":true"), "got: {out}");
        assert!(out.contains("\"count\":3"), "got: {out}");
    }
}
// Unit tests for `strip_cookie_domain`, the Set-Cookie rewrite applied by `forward`.
#[cfg(test)]
mod forward_tests {
    use super::strip_cookie_domain;
    #[test]
    fn drops_ibkr_domain_attribute() {
        let input = "SID=abc; Domain=.ibkr.com; Path=/; Secure";
        assert_eq!(strip_cookie_domain(input), "SID=abc; Path=/; Secure");
    }
    #[test]
    fn leaves_cookie_without_domain_untouched() {
        let input = "SID=abc; Path=/; Secure";
        assert_eq!(strip_cookie_domain(input), input);
    }
    #[test]
    fn case_insensitive() {
        let input = "SID=abc; DOMAIN=ibkr.com; Path=/";
        assert_eq!(strip_cookie_domain(input), "SID=abc; Path=/");
    }
    #[test]
    fn strips_every_domain_attribute() {
        let input = "A=1; Domain=a.com; domain=b.com; Secure";
        assert_eq!(strip_cookie_domain(input), "A=1; Secure");
    }
    #[test]
    fn keeps_cookie_value_that_mentions_domain() {
        let input = "SID=domain=abc; Path=/";
        assert_eq!(strip_cookie_domain(input), input);
    }
}