use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use bytes::Bytes;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Method, StatusCode};
use serde_json::{json, Value};
use url::Url;
use crate::context::EffectiveConfig;
use crate::error::CliError;
use crate::http::{print_host_autofix_banner, ClientUi};
use crate::multi_base::{self, BaseCandidate};
/// HTTP client for the server's tRPC endpoints (`api/trpc/<procedure>`).
///
/// Holds one or more base-URL candidates plus the index of the one currently in use; the
/// fallback between candidates is delegated to the `multi_base` module.
#[derive(Debug)]
pub(super) struct TrpcClient {
/// Candidate base URLs, built by `multi_base::build_base_candidates` from the configured URL.
bases: Vec<BaseCandidate>,
/// Index into `bases` of the candidate currently in use (starts at 0).
active_base: AtomicUsize,
/// Flag handed to `multi_base::maybe_warn_host_autofix`; presumably records that the
/// host-autofix banner was already shown so it is printed only once — confirm there.
warned_autofix: AtomicBool,
/// Number of additional attempts for transient failures (up to `retries + 1` requests total).
retries: u32,
/// When true, requests are printed instead of sent.
dry_run: bool,
/// Underlying HTTP client; built with the request timeout passed to `TrpcClient::new`.
client: reqwest::Client,
/// Value of the `Cookie` header attached to every request, if any.
cookie: Option<String>,
/// Output settings for user-facing messages (its `quiet` flag gates the autofix banner).
ui: ClientUi,
}
impl TrpcClient {
pub(super) fn new(
base_url: &str,
timeout: Duration,
retries: u32,
dry_run: bool,
ui: ClientUi,
) -> Result<Self, CliError> {
let bases = multi_base::build_base_candidates(base_url)?;
let client = reqwest::Client::builder().timeout(timeout).build()?;
Ok(Self {
bases,
active_base: AtomicUsize::new(0),
warned_autofix: AtomicBool::new(false),
retries,
dry_run,
client,
cookie: None,
ui,
})
}
pub(super) fn with_cookie(mut self, cookie: Option<String>) -> Self {
self.cookie = cookie;
self
}
pub(super) async fn query(&self, procedure: &str, input: Value) -> Result<Value, CliError> {
let path = format!("api/trpc/{}", procedure.trim());
let mut headers = HeaderMap::new();
headers.insert("accept", HeaderValue::from_static("application/json"));
if let Some(ref cookie) = self.cookie {
headers.insert(
reqwest::header::COOKIE,
HeaderValue::from_str(cookie).map_err(|_| {
CliError::InvalidArgument("cookie contains invalid characters".to_string())
})?,
);
}
let input_param = if input.is_null() {
None
} else {
Some(serde_json::to_string(&json!({ "json": input }))?)
};
if self.dry_run {
let base_idx = self.active_base.load(Ordering::Relaxed);
let mut url = self.build_url_for_base(base_idx, &path)?;
if let Some(input) = input_param.as_deref() {
url.query_pairs_mut().append_pair("input", input);
}
print_dry_run_no_body(&Method::GET, &url, &headers);
return Err(CliError::DryRunPrinted);
}
multi_base::try_with_base_fallback(
&self.bases,
&self.active_base,
&path,
false,
should_try_host_autofix,
|mut url| {
if let Some(ref input) = input_param {
url.query_pairs_mut().append_pair("input", input);
}
self.query_with_url(url, &headers)
},
|idx| self.maybe_warn_host_autofix(idx),
)
.await
}
pub(super) async fn mutation(&self, procedure: &str, input: Value) -> Result<Value, CliError> {
let path = format!("api/trpc/{}?batch=1", procedure.trim());
let body = json!({ "0": { "json": input } });
let body_bytes = Bytes::from(serde_json::to_vec(&body)?);
let mut headers = HeaderMap::new();
headers.insert("accept", HeaderValue::from_static("application/json"));
headers.insert("content-type", HeaderValue::from_static("application/json"));
if let Some(ref cookie) = self.cookie {
headers.insert(
reqwest::header::COOKIE,
HeaderValue::from_str(cookie).map_err(|_| {
CliError::InvalidArgument("cookie contains invalid characters".to_string())
})?,
);
}
if self.dry_run {
let base_idx = self.active_base.load(Ordering::Relaxed);
let url = self.build_url_for_base(base_idx, &path)?;
print_dry_run(&Method::POST, &url, &headers, &body);
return Err(CliError::DryRunPrinted);
}
multi_base::try_with_base_fallback(
&self.bases,
&self.active_base,
&path,
false,
should_try_host_autofix,
|url| self.call_with_url(url, &headers, body_bytes.clone()),
|idx| self.maybe_warn_host_autofix(idx),
)
.await
}
pub(super) async fn call(&self, procedure: &str, input: Value) -> Result<Value, CliError> {
self.mutation(procedure, input).await
}
fn build_url_for_base(&self, base_idx: usize, path: &str) -> Result<Url, CliError> {
multi_base::build_url_for_base(&self.bases, base_idx, path, false)
}
fn maybe_warn_host_autofix(&self, active_idx: usize) {
multi_base::maybe_warn_host_autofix(
self.ui.quiet,
&self.warned_autofix,
&self.bases,
active_idx,
|configured, using| print_host_autofix_banner(&self.ui, configured, using),
);
}
async fn call_with_url(
&self,
url: Url,
headers: &HeaderMap,
body_bytes: Bytes,
) -> Result<Value, CliError> {
let mut backoff = Duration::from_millis(200);
for attempt in 0..=self.retries {
let request = self
.client
.request(Method::POST, url.clone())
.headers(headers.clone())
.body(body_bytes.clone());
match request.send().await {
Ok(resp) => {
let status = resp.status();
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs);
let bytes = resp.bytes().await?;
if should_retry_status(status) && attempt < self.retries {
if status == StatusCode::TOO_MANY_REQUESTS {
tokio::time::sleep(retry_after.unwrap_or(backoff)).await;
} else {
tokio::time::sleep(backoff).await;
}
backoff = (backoff * 2).min(Duration::from_secs(5));
continue;
}
return parse_trpc_http_response(status, bytes.as_ref());
}
Err(err) => {
if attempt < self.retries && should_retry_error(&err) {
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(Duration::from_secs(5));
continue;
}
return Err(CliError::Request(err));
}
}
}
Err(CliError::RateLimited)
}
async fn query_with_url(&self, url: Url, headers: &HeaderMap) -> Result<Value, CliError> {
let mut backoff = Duration::from_millis(200);
for attempt in 0..=self.retries {
let request = self
.client
.request(Method::GET, url.clone())
.headers(headers.clone());
match request.send().await {
Ok(resp) => {
let status = resp.status();
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs);
let bytes = resp.bytes().await?;
if should_retry_status(status) && attempt < self.retries {
if status == StatusCode::TOO_MANY_REQUESTS {
tokio::time::sleep(retry_after.unwrap_or(backoff)).await;
} else {
tokio::time::sleep(backoff).await;
}
backoff = (backoff * 2).min(Duration::from_secs(5));
continue;
}
return parse_trpc_http_response(status, bytes.as_ref());
}
Err(err) => {
if attempt < self.retries && should_retry_error(&err) {
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(Duration::from_secs(5));
continue;
}
return Err(CliError::Request(err));
}
}
}
Err(CliError::RateLimited)
}
}
/// Decides whether a failed request should be retried against the next base-URL candidate.
///
/// Defers to [`multi_base::should_try_host_autofix_basic`] first. It additionally accepts an
/// [`CliError::HttpStatus`] produced for a body that was not JSON, which suggests the request
/// did not reach the tRPC API.
fn should_try_host_autofix(err: &CliError) -> bool {
    multi_base::should_try_host_autofix_basic(err)
        || matches!(
            err,
            CliError::HttpStatus { message, .. } if message == "invalid json response"
        )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A base URL that already carries a path (`/api`) keeps it in front of the tRPC path.
    #[test]
    fn trpc_join_preserves_base_path_prefix() {
        let retries = 0;
        let dry_run = true;
        let client = TrpcClient::new(
            "https://example.com/api",
            Duration::from_secs(1),
            retries,
            dry_run,
            ClientUi::default(),
        )
        .expect("client construction should succeed for a valid base URL");
        let joined = client
            .build_url_for_base(0, "api/trpc/foo?batch=1")
            .expect("base candidate 0 should exist");
        assert_eq!(joined.as_str(), "https://example.com/api/api/trpc/foo?batch=1");
    }
}
/// Builds the `Cookie` header value from the configured session and device tokens.
///
/// The trimmed session token is sent under both `next-auth.session-token` and
/// `__Secure-next-auth.session-token`. A non-blank device token is appended as
/// `next-auth.did-token`. Returns `None` when the session token is missing or blank.
pub(super) fn cookie_from_effective(effective: &EffectiveConfig) -> Option<String> {
    let session = effective
        .session_cookie
        .as_deref()
        .map(str::trim)
        .filter(|token| !token.is_empty())?;
    let device = effective
        .device_cookie
        .as_deref()
        .map(str::trim)
        .filter(|token| !token.is_empty());

    let mut pairs = vec![
        format!("next-auth.session-token={session}"),
        format!("__Secure-next-auth.session-token={session}"),
    ];
    pairs.extend(device.map(|token| format!("next-auth.did-token={token}")));
    Some(pairs.join("; "))
}
/// Like [`cookie_from_effective`], but a missing session is an error.
///
/// # Errors
///
/// [`CliError::SessionRequired`] when no non-blank session token is configured.
pub(super) fn require_cookie_from_effective(effective: &EffectiveConfig) -> Result<String, CliError> {
    match cookie_from_effective(effective) {
        Some(cookie) => Ok(cookie),
        None => Err(CliError::SessionRequired),
    }
}
/// Decodes a raw tRPC HTTP response.
///
/// A 401 status short-circuits to [`CliError::SessionRequired`] without looking at the body.
/// Otherwise the body must be JSON, which is then unwrapped by `parse_trpc_envelope`.
///
/// # Errors
///
/// Returns [`CliError::HttpStatus`] with message `"invalid json response"` when the body is not
/// valid JSON; the body is included, decoded lossily as UTF-8. `should_try_host_autofix` keys
/// on that exact message.
fn parse_trpc_http_response(status: StatusCode, bytes: &[u8]) -> Result<Value, CliError> {
    if status == StatusCode::UNAUTHORIZED {
        return Err(CliError::SessionRequired);
    }
    match serde_json::from_slice::<Value>(bytes) {
        Ok(value) => parse_trpc_envelope(status, value),
        Err(_) => Err(CliError::HttpStatus {
            status,
            message: "invalid json response".to_string(),
            body: Some(String::from_utf8_lossy(bytes).into_owned()),
        }),
    }
}
/// Unwraps a tRPC response envelope and returns the procedure's output.
///
/// `value` is either a single envelope or a batch array. For a batch, only the first element is
/// inspected; this client only sends single-call batches. A successful envelope looks like
/// `{"result": {"data": ...}}`. When `data` has a `json` field (the same wrapper this client
/// uses for inputs), that field is returned; otherwise `data` itself is returned. On a 2xx
/// status, non-object values and objects without `result` are returned unchanged.
///
/// # Errors
///
/// - [`CliError::HttpStatus`] (`"empty tRPC response"`) for an empty batch array.
/// - [`CliError::SessionRequired`] when the `error` envelope reports code `UNAUTHORIZED` or an
///   `httpStatus` of 401.
/// - [`CliError::HttpStatus`] for any other `error` envelope. The message is the envelope's
///   `message` (default `"tRPC error"`). The status is `data.httpStatus` when it is a valid
///   status code, else `http_status`.
/// - [`CliError::HttpStatus`] (`"unexpected tRPC response"`) when `http_status` is not 2xx but
///   the body carries no `error` envelope, so a failed request is never reported as success.
fn parse_trpc_envelope(http_status: StatusCode, value: Value) -> Result<Value, CliError> {
    let item = match value {
        Value::Array(items) => items.into_iter().next().ok_or_else(|| CliError::HttpStatus {
            status: http_status,
            message: "empty tRPC response".to_string(),
            body: None,
        })?,
        other => other,
    };

    if let Some(err) = item.get("error") {
        let message = err
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or("tRPC error")
            .to_string();
        let code = err
            .get("data")
            .and_then(|d| d.get("code"))
            .and_then(|v| v.as_str())
            .unwrap_or("");
        // `try_from` rejects out-of-range numbers instead of silently truncating them.
        let status = err
            .get("data")
            .and_then(|d| d.get("httpStatus"))
            .and_then(|v| v.as_u64())
            .and_then(|n| u16::try_from(n).ok())
            .and_then(|n| StatusCode::from_u16(n).ok())
            .unwrap_or(http_status);
        if code == "UNAUTHORIZED" || status == StatusCode::UNAUTHORIZED {
            return Err(CliError::SessionRequired);
        }
        return Err(CliError::HttpStatus {
            status,
            message,
            body: Some(err.to_string()),
        });
    }

    // A non-2xx status without an error envelope (e.g. a proxy's or framework's own JSON error
    // page) must not be mistaken for a successful result.
    if !http_status.is_success() {
        return Err(CliError::HttpStatus {
            status: http_status,
            message: "unexpected tRPC response".to_string(),
            body: Some(item.to_string()),
        });
    }

    let mut obj = match item {
        Value::Object(obj) => obj,
        other => return Ok(other),
    };
    let Some(result) = obj.remove("result") else {
        return Ok(Value::Object(obj));
    };
    let data = result.get("data").unwrap_or(&Value::Null);
    Ok(data.get("json").unwrap_or(data).clone())
}
/// Returns true for statuses worth retrying: any 5xx, or 429 Too Many Requests.
fn should_retry_status(status: StatusCode) -> bool {
    if status.is_server_error() {
        return true;
    }
    status == StatusCode::TOO_MANY_REQUESTS
}
/// Returns true for transport errors treated as transient: timeouts, connection failures,
/// and errors `reqwest` classifies as request errors.
fn should_retry_error(err: &reqwest::Error) -> bool {
    let timed_out_or_unreachable = err.is_timeout() || err.is_connect();
    timed_out_or_unreachable || err.is_request()
}
fn print_dry_run(method: &Method, url: &Url, headers: &HeaderMap, body: &Value) {
println!("{method} {url}");
for (name, value) in headers.iter() {
if name.as_str().eq_ignore_ascii_case("cookie") {
println!("{name}: REDACTED");
continue;
}
if let Ok(value) = value.to_str() {
println!("{name}: {value}");
}
}
if let Ok(pretty) = serde_json::to_string_pretty(body) {
println!();
println!("{pretty}");
}
}
/// Prints the request line and headers of a request that would have been sent.
///
/// The `cookie` header value is replaced by `REDACTED`. Headers whose value is not printable
/// as a string are skipped.
fn print_dry_run_no_body(method: &Method, url: &Url, headers: &HeaderMap) {
    println!("{method} {url}");
    let printable = headers.iter().filter_map(|(name, value)| {
        if name.as_str().eq_ignore_ascii_case("cookie") {
            Some((name, "REDACTED"))
        } else {
            value.to_str().ok().map(|text| (name, text))
        }
    });
    for (name, text) in printable {
        println!("{name}: {text}");
    }
}