use std::time::Duration;
use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderName, HeaderValue, USER_AGENT};
use url::Url;
use crate::error::{Error, Result};
use crate::retry::RetryPolicy;
/// Name of the environment variable [`Client::new`] reads the API key from.
pub const API_KEY_ENV: &str = "TRIPO_API_KEY";
/// Name of the environment variable [`Client::new`] reads the region from.
pub const REGION_ENV: &str = "TRIPO_REGION";
/// Default API base URL used for [`Region::Global`].
pub const BASE_URL_GLOBAL: &str = "https://api.tripo3d.ai/v2/openapi";
/// Default API base URL used for [`Region::Cn`].
pub const BASE_URL_CN: &str = "https://api.tripo3d.com/v2/openapi";
/// API region a client is configured for.
///
/// The region picks the default base URL (see [`Region::default_base_url`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Region {
    /// Global endpoint; this is the default variant.
    #[default]
    Global,
    /// The `cn` endpoint (also spelled `china` / `mainland` when parsing).
    Cn,
}

impl Region {
    /// Parses a region name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `"global"` and the empty string map to [`Region::Global`]; `"cn"`,
    /// `"china"` and `"mainland"` map to [`Region::Cn`]. Any other input
    /// yields `None`.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        const CN_ALIASES: [&str; 3] = ["cn", "china", "mainland"];

        let name = s.trim();
        if name.is_empty() || name.eq_ignore_ascii_case("global") {
            Some(Self::Global)
        } else if CN_ALIASES
            .iter()
            .any(|alias| name.eq_ignore_ascii_case(alias))
        {
            Some(Self::Cn)
        } else {
            None
        }
    }

    /// Returns the default base URL for this region.
    ///
    /// # Panics
    ///
    /// Panics if the matching base-URL constant fails to parse as a URL.
    #[must_use]
    pub fn default_base_url(self) -> Url {
        let raw = match self {
            Self::Global => BASE_URL_GLOBAL,
            Self::Cn => BASE_URL_CN,
        };
        raw.parse().expect("valid const URL")
    }
}
/// Handle for talking to the Tripo API.
///
/// Bundles the configured HTTP client, base URL, region and retry policy.
/// Construct one with [`Client::new`], [`Client::with_api_key`] or
/// [`Client::builder`].
#[derive(Clone)]
pub struct Client {
    /// HTTP client carrying the default headers installed by `build_http`.
    pub(crate) http: reqwest::Client,
    /// Base URL that endpoint path segments are appended to.
    pub(crate) base_url: Url,
    /// Region this client was configured for.
    pub(crate) region: Region,
    /// Policy consulted by `send_with_retry` to decide whether to retry.
    pub(crate) retry: RetryPolicy,
}
/// Hand-written `Debug` that prints only the base URL and the region; the
/// remaining fields are omitted and the output is marked non-exhaustive.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Client");
        out.field("base_url", &self.base_url.as_str());
        out.field("region", &self.region);
        out.finish_non_exhaustive()
    }
}
fn validate_key(key: &str) -> Result<()> {
if key.is_empty() {
return Err(Error::MissingApiKey);
}
if !key.starts_with("tsk_") {
return Err(Error::InvalidApiKey);
}
Ok(())
}
fn build_http(api_key: &str) -> Result<reqwest::Client> {
let mut headers = HeaderMap::new();
let mut auth =
HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|_| Error::InvalidApiKey)?;
auth.set_sensitive(true);
headers.insert(AUTHORIZATION, auth);
headers.insert(
USER_AGENT,
HeaderValue::from_static(concat!(
"tripo-rs/",
env!("CARGO_PKG_VERSION"),
" (+https://github.com/pavlov-net/tripo3d-cli)"
)),
);
reqwest::Client::builder()
.default_headers(headers)
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_mins(1))
.build()
.map_err(Error::from)
}
impl Client {
pub fn new() -> Result<Self> {
let key = std::env::var(API_KEY_ENV).map_err(|_| Error::MissingApiKey)?;
let region = std::env::var(REGION_ENV)
.ok()
.and_then(|r| Region::parse(&r))
.unwrap_or_default();
Self::builder().api_key(key).region(region).build()
}
#[must_use]
pub fn builder() -> ClientBuilder {
ClientBuilder::default()
}
pub fn with_api_key(key: impl Into<String>) -> Result<Self> {
Self::builder().api_key(key).build()
}
#[must_use]
pub fn with_base_url(mut self, url: Url) -> Self {
self.base_url = url;
self
}
#[must_use]
pub fn base_url(&self) -> &Url {
&self.base_url
}
#[must_use]
pub fn region(&self) -> Region {
self.region
}
pub(crate) fn url(&self, segments: &[&str]) -> Url {
let mut u = self.base_url.clone();
{
let mut seg = u.path_segments_mut().expect("http(s) base");
for s in segments {
seg.push(s);
}
}
u
}
pub(crate) fn region_headers(&self) -> HeaderMap {
let mut h = HeaderMap::new();
if self.region == Region::Cn {
h.insert(
HeaderName::from_static("x-tripo-region"),
HeaderValue::from_static("rg2"),
);
}
h
}
#[tracing::instrument(skip(self))]
pub async fn get_balance(&self) -> Result<crate::types::Balance> {
let url = self.url(&["user", "balance"]);
let resp = self
.send_with_retry(|| self.http.get(url.clone()).headers(self.region_headers()))
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if !status.is_success() {
return Err(crate::envelope::map_http_error(status, &bytes));
}
let env: crate::envelope::Envelope<crate::types::Balance> = serde_json::from_slice(&bytes)?;
env.into_result()
}
#[tracing::instrument(skip(self), fields(task_id = %id))]
pub async fn get_task(&self, id: &crate::types::TaskId) -> Result<crate::types::Task> {
let url = self.url(&["task", id.as_str()]);
let resp = self
.send_with_retry(|| self.http.get(url.clone()).headers(self.region_headers()))
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if !status.is_success() {
return Err(crate::envelope::map_http_error(status, &bytes));
}
let env: crate::envelope::Envelope<crate::types::Task> = serde_json::from_slice(&bytes)?;
env.into_result()
}
#[tracing::instrument(skip(self, req))]
pub async fn create_task(
&self,
mut req: crate::tasks::TaskRequest,
) -> Result<crate::types::TaskId> {
req.validate()?;
req.upload_images(self).await?;
self.create_task_raw(&serde_json::to_value(&req)?).await
}
pub async fn create_task_raw(&self, body: &serde_json::Value) -> Result<crate::types::TaskId> {
#[derive(serde::Deserialize)]
struct TaskIdBody {
task_id: String,
}
let url = self.url(&["task"]);
let body = body.clone();
let resp = self
.send_with_retry(|| {
self.http
.post(url.clone())
.headers(self.region_headers())
.json(&body)
})
.await?;
let status = resp.status();
let bytes = resp.bytes().await?;
if !status.is_success() {
return Err(crate::envelope::map_http_error(status, &bytes));
}
let env: crate::envelope::Envelope<TaskIdBody> = serde_json::from_slice(&bytes)?;
Ok(crate::types::TaskId(env.into_result()?.task_id))
}
pub(crate) async fn send_with_retry<F>(&self, build: F) -> Result<reqwest::Response>
where
F: Fn() -> reqwest::RequestBuilder,
{
use crate::retry::{RetryDecision, parse_retry_after};
let mut attempt: u32 = 0;
loop {
let req = build();
match req.send().await {
Ok(resp) => {
let status = resp.status();
if status.is_success() || (status.is_client_error() && status.as_u16() != 429) {
return Ok(resp);
}
let retry_after = resp
.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(parse_retry_after);
match self.retry.decide_status(attempt, status, retry_after) {
RetryDecision::Stop => return Ok(resp),
RetryDecision::Retry(d) => {
tracing::debug!(?status, ?d, attempt, "retrying after status");
tokio::time::sleep(d).await;
}
}
}
Err(err) => match self.retry.decide_transport(attempt, &err) {
RetryDecision::Stop => return Err(Error::from(err)),
RetryDecision::Retry(d) => {
tracing::debug!(error = %err, ?d, attempt, "retrying after transport error");
tokio::time::sleep(d).await;
}
},
}
attempt += 1;
}
}
}
/// Builder for [`Client`]; every field starts unset.
///
/// Only the API key is required. Unset fields take their defaults in
/// [`ClientBuilder::build`].
#[derive(Default)]
pub struct ClientBuilder {
    /// API key; `build` fails with `Error::MissingApiKey` when unset.
    api_key: Option<String>,
    /// Explicit base URL; when unset the region's default base URL is used.
    base_url: Option<Url>,
    /// Region; when unset `Region::default()` is used.
    region: Option<Region>,
    /// Retry policy; when unset `RetryPolicy::default()` is used.
    retry: Option<RetryPolicy>,
}
impl ClientBuilder {
#[must_use]
pub fn api_key(mut self, k: impl Into<String>) -> Self {
self.api_key = Some(k.into());
self
}
#[must_use]
pub fn region(mut self, r: Region) -> Self {
self.region = Some(r);
self
}
#[must_use]
pub fn base_url(mut self, u: Url) -> Self {
self.base_url = Some(u);
self
}
#[must_use]
pub fn retry(mut self, r: RetryPolicy) -> Self {
self.retry = Some(r);
self
}
pub fn build(self) -> Result<Client> {
let key = self.api_key.ok_or(Error::MissingApiKey)?;
validate_key(&key)?;
let region = self.region.unwrap_or_default();
let base_url = self.base_url.unwrap_or_else(|| region.default_base_url());
let http = build_http(&key)?;
Ok(Client {
http,
base_url,
region,
retry: self.retry.unwrap_or_default(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder pre-loaded with a key that passes validation.
    fn keyed() -> ClientBuilder {
        Client::builder().api_key("tsk_abc")
    }

    // A builder without a key must be refused.
    #[test]
    fn rejects_missing_key() {
        let outcome = Client::builder().build();
        assert!(matches!(outcome, Err(Error::MissingApiKey)));
    }

    // A key lacking the expected prefix must be refused.
    #[test]
    fn rejects_bad_prefix() {
        let outcome = Client::builder().api_key("wrong_prefix").build();
        assert!(matches!(outcome, Err(Error::InvalidApiKey)));
    }

    // With no region given, the client targets the global endpoint.
    #[test]
    fn region_defaults_global() {
        let client = keyed().build().unwrap();
        assert_eq!(client.region(), Region::Global);
        assert_eq!(
            client.base_url().as_str(),
            "https://api.tripo3d.ai/v2/openapi"
        );
    }

    // Selecting `Cn` changes the base URL and adds the region header.
    #[test]
    fn region_cn_switches_base_url() {
        let client = keyed().region(Region::Cn).build().unwrap();
        assert_eq!(
            client.base_url().as_str(),
            "https://api.tripo3d.com/v2/openapi"
        );
        let headers = client.region_headers();
        assert!(headers.contains_key("x-tripo-region"));
    }

    // Path segments are appended to the base URL in order.
    #[test]
    fn url_joins_segments() {
        let client = keyed().build().unwrap();
        let joined = client.url(&["task", "abc123"]);
        assert_eq!(
            joined.as_str(),
            "https://api.tripo3d.ai/v2/openapi/task/abc123"
        );
    }
}