Skip to main content

tripo_api/
client.rs

//! `Client`: entry point for the library. Builds a configured `reqwest::Client`
//! and carries the API key (as a default `Authorization` header), base URL,
//! region, and retry policy.
3
4use std::time::Duration;
5
6use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderName, HeaderValue, USER_AGENT};
7use url::Url;
8
9use crate::error::{Error, Result};
10use crate::retry::RetryPolicy;
11
/// Name of the environment variable that holds the Tripo API key.
pub const API_KEY_ENV: &str = "TRIPO_API_KEY";

/// Name of the environment variable that selects the region (`global` | `cn`).
pub const REGION_ENV: &str = "TRIPO_REGION";

/// Base URL of the global v2 OpenAPI endpoint.
pub const BASE_URL_GLOBAL: &str = "https://api.tripo3d.ai/v2/openapi";
/// Base URL of the China-mainland v2 OpenAPI endpoint.
pub const BASE_URL_CN: &str = "https://api.tripo3d.com/v2/openapi";
22
/// Region selector.
///
/// Chooses the default base URL and whether the `X-Tripo-Region` header is sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Region {
    /// Global endpoint (default).
    #[default]
    Global,
    /// China mainland endpoint. The `Client` API calls (GET and POST alike)
    /// attach an `X-Tripo-Region: rg2` header.
    Cn,
}
32
33impl Region {
34    /// Parse the `TRIPO_REGION` env form: `global` | `cn`.
35    #[must_use]
36    pub fn parse(s: &str) -> Option<Self> {
37        match s.trim().to_ascii_lowercase().as_str() {
38            "global" | "" => Some(Self::Global),
39            "cn" | "china" | "mainland" => Some(Self::Cn),
40            _ => None,
41        }
42    }
43
44    /// Default base URL for this region.
45    #[must_use]
46    pub fn default_base_url(self) -> Url {
47        match self {
48            Self::Global => BASE_URL_GLOBAL.parse().expect("valid const URL"),
49            Self::Cn => BASE_URL_CN.parse().expect("valid const URL"),
50        }
51    }
52}
53
54/// Async client for the Tripo 3D Generation API.
55#[derive(Clone)]
56pub struct Client {
57    pub(crate) http: reqwest::Client,
58    pub(crate) base_url: Url,
59    pub(crate) region: Region,
60    pub(crate) retry: RetryPolicy,
61}
62
63impl std::fmt::Debug for Client {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("Client")
66            .field("base_url", &self.base_url.as_str())
67            .field("region", &self.region)
68            .finish_non_exhaustive()
69    }
70}
71
72fn validate_key(key: &str) -> Result<()> {
73    if key.is_empty() {
74        return Err(Error::MissingApiKey);
75    }
76    if !key.starts_with("tsk_") {
77        return Err(Error::InvalidApiKey);
78    }
79    Ok(())
80}
81
82fn build_http(api_key: &str) -> Result<reqwest::Client> {
83    let mut headers = HeaderMap::new();
84    let mut auth =
85        HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|_| Error::InvalidApiKey)?;
86    auth.set_sensitive(true);
87    headers.insert(AUTHORIZATION, auth);
88    headers.insert(
89        USER_AGENT,
90        HeaderValue::from_static(concat!(
91            "tripo-rs/",
92            env!("CARGO_PKG_VERSION"),
93            " (+https://github.com/pavlov-net/tripo3d-cli)"
94        )),
95    );
96    reqwest::Client::builder()
97        .default_headers(headers)
98        .connect_timeout(Duration::from_secs(10))
99        .timeout(Duration::from_mins(1))
100        .build()
101        .map_err(Error::from)
102}
103
104impl Client {
105    /// Read `TRIPO_API_KEY` (and optionally `TRIPO_REGION`) from the environment.
106    pub fn new() -> Result<Self> {
107        let key = std::env::var(API_KEY_ENV).map_err(|_| Error::MissingApiKey)?;
108        let region = std::env::var(REGION_ENV)
109            .ok()
110            .and_then(|r| Region::parse(&r))
111            .unwrap_or_default();
112        Self::builder().api_key(key).region(region).build()
113    }
114
115    /// Start a [`ClientBuilder`].
116    #[must_use]
117    pub fn builder() -> ClientBuilder {
118        ClientBuilder::default()
119    }
120
121    /// Construct with an explicit key, using defaults for everything else.
122    pub fn with_api_key(key: impl Into<String>) -> Result<Self> {
123        Self::builder().api_key(key).build()
124    }
125
126    /// Override the base URL (testing or staging).
127    #[must_use]
128    pub fn with_base_url(mut self, url: Url) -> Self {
129        self.base_url = url;
130        self
131    }
132
133    /// Current base URL.
134    #[must_use]
135    pub fn base_url(&self) -> &Url {
136        &self.base_url
137    }
138
139    /// Current region.
140    #[must_use]
141    pub fn region(&self) -> Region {
142        self.region
143    }
144
145    /// Join `segments` onto the base URL, preserving the existing path.
146    pub(crate) fn url(&self, segments: &[&str]) -> Url {
147        let mut u = self.base_url.clone();
148        {
149            let mut seg = u.path_segments_mut().expect("http(s) base");
150            for s in segments {
151                seg.push(s);
152            }
153        }
154        u
155    }
156
157    /// Extra headers attached to every request. `X-Tripo-Region: rg2` for CN.
158    pub(crate) fn region_headers(&self) -> HeaderMap {
159        let mut h = HeaderMap::new();
160        if self.region == Region::Cn {
161            h.insert(
162                HeaderName::from_static("x-tripo-region"),
163                HeaderValue::from_static("rg2"),
164            );
165        }
166        h
167    }
168
169    /// `GET /user/balance` — current account balance.
170    #[tracing::instrument(skip(self))]
171    pub async fn get_balance(&self) -> Result<crate::types::Balance> {
172        let url = self.url(&["user", "balance"]);
173        let resp = self
174            .send_with_retry(|| self.http.get(url.clone()).headers(self.region_headers()))
175            .await?;
176        let status = resp.status();
177        let bytes = resp.bytes().await?;
178        if !status.is_success() {
179            return Err(crate::envelope::map_http_error(status, &bytes));
180        }
181        let env: crate::envelope::Envelope<crate::types::Balance> = serde_json::from_slice(&bytes)?;
182        env.into_result()
183    }
184
185    /// `GET /task/{id}` — current state of an existing task.
186    #[tracing::instrument(skip(self), fields(task_id = %id))]
187    pub async fn get_task(&self, id: &crate::types::TaskId) -> Result<crate::types::Task> {
188        let url = self.url(&["task", id.as_str()]);
189        let resp = self
190            .send_with_retry(|| self.http.get(url.clone()).headers(self.region_headers()))
191            .await?;
192        let status = resp.status();
193        let bytes = resp.bytes().await?;
194        if !status.is_success() {
195            return Err(crate::envelope::map_http_error(status, &bytes));
196        }
197        let env: crate::envelope::Envelope<crate::types::Task> = serde_json::from_slice(&bytes)?;
198        env.into_result()
199    }
200
201    /// `POST /task` — submit a task. Any `ImageInput::Path` in the request is
202    /// uploaded first and replaced with a `FileToken`.
203    #[tracing::instrument(skip(self, req))]
204    pub async fn create_task(
205        &self,
206        mut req: crate::tasks::TaskRequest,
207    ) -> Result<crate::types::TaskId> {
208        req.validate()?;
209        req.upload_images(self).await?;
210        self.create_task_raw(&serde_json::to_value(&req)?).await
211    }
212
213    /// Submit an already-built JSON body to `/task`. Used by `create_task` and
214    /// the CLI's `task create --json <FILE>` escape hatch.
215    pub async fn create_task_raw(&self, body: &serde_json::Value) -> Result<crate::types::TaskId> {
216        #[derive(serde::Deserialize)]
217        struct TaskIdBody {
218            task_id: String,
219        }
220        let url = self.url(&["task"]);
221        let body = body.clone();
222        let resp = self
223            .send_with_retry(|| {
224                self.http
225                    .post(url.clone())
226                    .headers(self.region_headers())
227                    .json(&body)
228            })
229            .await?;
230        let status = resp.status();
231        let bytes = resp.bytes().await?;
232        if !status.is_success() {
233            return Err(crate::envelope::map_http_error(status, &bytes));
234        }
235        let env: crate::envelope::Envelope<TaskIdBody> = serde_json::from_slice(&bytes)?;
236        Ok(crate::types::TaskId(env.into_result()?.task_id))
237    }
238
239    pub(crate) async fn send_with_retry<F>(&self, build: F) -> Result<reqwest::Response>
240    where
241        F: Fn() -> reqwest::RequestBuilder,
242    {
243        use crate::retry::{RetryDecision, parse_retry_after};
244
245        let mut attempt: u32 = 0;
246        loop {
247            let req = build();
248            match req.send().await {
249                Ok(resp) => {
250                    let status = resp.status();
251                    if status.is_success() || (status.is_client_error() && status.as_u16() != 429) {
252                        return Ok(resp);
253                    }
254                    let retry_after = resp
255                        .headers()
256                        .get(reqwest::header::RETRY_AFTER)
257                        .and_then(parse_retry_after);
258                    match self.retry.decide_status(attempt, status, retry_after) {
259                        RetryDecision::Stop => return Ok(resp),
260                        RetryDecision::Retry(d) => {
261                            tracing::debug!(?status, ?d, attempt, "retrying after status");
262                            tokio::time::sleep(d).await;
263                        }
264                    }
265                }
266                Err(err) => match self.retry.decide_transport(attempt, &err) {
267                    RetryDecision::Stop => return Err(Error::from(err)),
268                    RetryDecision::Retry(d) => {
269                        tracing::debug!(error = %err, ?d, attempt, "retrying after transport error");
270                        tokio::time::sleep(d).await;
271                    }
272                },
273            }
274            attempt += 1;
275        }
276    }
277}
278
279/// Builder for [`Client`].
280#[derive(Default)]
281pub struct ClientBuilder {
282    api_key: Option<String>,
283    base_url: Option<Url>,
284    region: Option<Region>,
285    retry: Option<RetryPolicy>,
286}
287
288impl ClientBuilder {
289    /// Set the API key.
290    #[must_use]
291    pub fn api_key(mut self, k: impl Into<String>) -> Self {
292        self.api_key = Some(k.into());
293        self
294    }
295    /// Set the region (determines default base URL and `X-Tripo-Region` header).
296    #[must_use]
297    pub fn region(mut self, r: Region) -> Self {
298        self.region = Some(r);
299        self
300    }
301    /// Override the base URL (ignores region's default).
302    #[must_use]
303    pub fn base_url(mut self, u: Url) -> Self {
304        self.base_url = Some(u);
305        self
306    }
307    /// Override the retry policy.
308    #[must_use]
309    pub fn retry(mut self, r: RetryPolicy) -> Self {
310        self.retry = Some(r);
311        self
312    }
313    /// Build, validating the API key.
314    pub fn build(self) -> Result<Client> {
315        let key = self.api_key.ok_or(Error::MissingApiKey)?;
316        validate_key(&key)?;
317        let region = self.region.unwrap_or_default();
318        let base_url = self.base_url.unwrap_or_else(|| region.default_base_url());
319        let http = build_http(&key)?;
320        Ok(Client {
321            http,
322            base_url,
323            region,
324            retry: self.retry.unwrap_or_default(),
325        })
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_missing_key() {
        let err = Client::builder().build().unwrap_err();
        assert!(matches!(err, Error::MissingApiKey));
    }

    #[test]
    fn rejects_empty_key() {
        let err = Client::builder().api_key("").build().unwrap_err();
        assert!(matches!(err, Error::MissingApiKey));
    }

    #[test]
    fn rejects_bad_prefix() {
        let err = Client::builder()
            .api_key("wrong_prefix")
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::InvalidApiKey));
    }

    #[test]
    fn region_parse_accepts_aliases_and_rejects_unknown() {
        assert_eq!(Region::parse("global"), Some(Region::Global));
        assert_eq!(Region::parse(""), Some(Region::Global));
        assert_eq!(Region::parse("  CN \n"), Some(Region::Cn));
        assert_eq!(Region::parse("China"), Some(Region::Cn));
        assert_eq!(Region::parse("mainland"), Some(Region::Cn));
        assert_eq!(Region::parse("eu"), None);
    }

    #[test]
    fn region_defaults_global() {
        let c = Client::builder().api_key("tsk_abc").build().unwrap();
        assert_eq!(c.region(), Region::Global);
        assert_eq!(c.base_url().as_str(), "https://api.tripo3d.ai/v2/openapi");
    }

    #[test]
    fn region_global_sends_no_region_header() {
        let c = Client::builder().api_key("tsk_abc").build().unwrap();
        assert!(c.region_headers().is_empty());
    }

    #[test]
    fn region_cn_switches_base_url() {
        let c = Client::builder()
            .api_key("tsk_abc")
            .region(Region::Cn)
            .build()
            .unwrap();
        assert_eq!(c.base_url().as_str(), "https://api.tripo3d.com/v2/openapi");
        assert!(c.region_headers().contains_key("x-tripo-region"));
    }

    #[test]
    fn explicit_base_url_overrides_region_default() {
        let custom: Url = "http://127.0.0.1:8080/v2/openapi".parse().unwrap();
        let c = Client::builder()
            .api_key("tsk_abc")
            .region(Region::Cn)
            .base_url(custom.clone())
            .build()
            .unwrap();
        assert_eq!(c.base_url(), &custom);
        assert_eq!(c.region(), Region::Cn);
    }

    #[test]
    fn url_joins_segments() {
        let c = Client::builder().api_key("tsk_abc").build().unwrap();
        let u = c.url(&["task", "abc123"]);
        assert_eq!(u.as_str(), "https://api.tripo3d.ai/v2/openapi/task/abc123");
    }
}