Skip to main content

origin_asset/
client.rs

1use reqwest::Method;
2use serde::Deserialize;
3use serde_json::{json, Map, Value};
4
5use crate::error::Result;
6use crate::transport::HttpTransport;
7use crate::types::{
8    AssetType, AudioOptions, GenerateRequest, GenerateResponse, ImageOptions, JobSummary,
9    Model3dOptions, MusicOptions, ProcessRequest, ProcessResponse, ProviderInfo, SpriteOptions,
10    TtsOptions, VideoOptions,
11};
12
13/// Default Asset Gateway endpoint used by [`AssetClient`].
14pub const DEFAULT_BASE_URL: &str = "https://asset.origingame.dev";
15
16/// Builder for configuring an [`AssetClient`].
17#[derive(Debug, Clone)]
18pub struct AssetClientBuilder {
19    api_key: String,
20    base_url: String,
21    client: Option<reqwest::Client>,
22}
23
24impl AssetClientBuilder {
25    pub fn new(api_key: impl Into<String>) -> Self {
26        Self {
27            api_key: api_key.into(),
28            base_url: DEFAULT_BASE_URL.to_string(),
29            client: None,
30        }
31    }
32
33    /// Override the Asset Gateway base URL.
34    pub fn base_url(mut self, url: impl Into<String>) -> Self {
35        self.base_url = url.into();
36        self
37    }
38
39    /// Use a custom `reqwest::Client` (e.g. custom timeout, proxy, TLS config).
40    pub fn http_client(mut self, client: reqwest::Client) -> Self {
41        self.client = Some(client);
42        self
43    }
44
45    pub fn build(self) -> AssetClient {
46        let transport = HttpTransport::new(self.base_url, self.api_key);
47        let transport = if let Some(client) = self.client {
48            transport.with_client(client)
49        } else {
50            transport
51        };
52
53        AssetClient { transport }
54    }
55}
56
57/// Client for the Origin Asset Gateway service.
58#[derive(Debug, Clone)]
59pub struct AssetClient {
60    pub(crate) transport: HttpTransport,
61}
62
63impl AssetClient {
64    /// Create a client with the default Asset Gateway base URL.
65    pub fn new(api_key: impl Into<String>) -> Self {
66        AssetClientBuilder::new(api_key).build()
67    }
68
69    /// Create a builder for fine-grained configuration.
70    pub fn builder(api_key: impl Into<String>) -> AssetClientBuilder {
71        AssetClientBuilder::new(api_key)
72    }
73
74    /// Return the configured Asset Gateway base URL.
75    pub fn base_url(&self) -> &str {
76        self.transport.base_url()
77    }
78
79    pub async fn generate(&self, request: &GenerateRequest) -> Result<GenerateResponse> {
80        let body = json!({
81            "asset_type": request.asset_type,
82            "prompt": request.prompt,
83            "model": request.model,
84            "input_file": request.input_file,
85            "provider": request.provider,
86            "size": request.size,
87            "transparent": request.transparent,
88            "reference_images": request.reference_images,
89            "edit_mode": request.edit_mode,
90            "session_id": request.session_id,
91            "params": request.params,
92        });
93
94        self.transport.post("/api/generate", &body).await
95    }
96
97    pub async fn generate_image(
98        &self,
99        prompt: impl Into<String>,
100        options: Option<ImageOptions>,
101    ) -> Result<GenerateResponse> {
102        let options = options.unwrap_or_default();
103        let request = GenerateRequest {
104            asset_type: AssetType::Image,
105            prompt: Some(prompt.into()),
106            model: options.model,
107            input_file: options.input,
108            provider: options.provider,
109            size: options.size,
110            transparent: options.transparent,
111            reference_images: options.reference_images,
112            edit_mode: options.edit_mode,
113            session_id: options.session_id,
114            params: options.params,
115        };
116
117        self.generate(&request).await
118    }
119
120    pub async fn generate_video(
121        &self,
122        prompt: impl Into<String>,
123        options: Option<VideoOptions>,
124    ) -> Result<GenerateResponse> {
125        let options = options.unwrap_or_default();
126        let request = GenerateRequest {
127            asset_type: AssetType::Video,
128            prompt: Some(prompt.into()),
129            model: options.model,
130            input_file: options.input,
131            provider: options.provider,
132            size: options.size,
133            transparent: None,
134            reference_images: Vec::new(),
135            edit_mode: None,
136            session_id: None,
137            params: options.params,
138        };
139
140        self.generate(&request).await
141    }
142
143    pub async fn generate_audio(
144        &self,
145        prompt: impl Into<String>,
146        options: Option<AudioOptions>,
147    ) -> Result<GenerateResponse> {
148        let options = options.unwrap_or_default();
149        let request = GenerateRequest {
150            asset_type: AssetType::Audio,
151            prompt: Some(prompt.into()),
152            model: options.model,
153            input_file: None,
154            provider: options.provider,
155            size: None,
156            transparent: None,
157            reference_images: Vec::new(),
158            edit_mode: None,
159            session_id: None,
160            params: merge_params(
161                options.params,
162                [
163                    option_entry("audio_type", options.audio_type),
164                    option_entry("duration_seconds", options.duration.map(Value::from)),
165                ],
166            ),
167        };
168
169        self.generate(&request).await
170    }
171
172    pub async fn generate_tts(
173        &self,
174        prompt: impl Into<String>,
175        options: Option<TtsOptions>,
176    ) -> Result<GenerateResponse> {
177        let options = options.unwrap_or_default();
178        let request = GenerateRequest {
179            asset_type: AssetType::Tts,
180            prompt: Some(prompt.into()),
181            model: options.model,
182            input_file: None,
183            provider: options.provider,
184            size: None,
185            transparent: None,
186            reference_images: Vec::new(),
187            edit_mode: None,
188            session_id: None,
189            params: merge_params(
190                options.params,
191                [
192                    option_entry("voice", options.voice),
193                    option_entry("voice_id", options.voice_id),
194                    option_entry("language_type", options.language),
195                    option_entry("instructions", options.instructions),
196                ],
197            ),
198        };
199
200        self.generate(&request).await
201    }
202
203    pub async fn generate_music(
204        &self,
205        prompt: impl Into<String>,
206        options: Option<MusicOptions>,
207    ) -> Result<GenerateResponse> {
208        let options = options.unwrap_or_default();
209        let request = GenerateRequest {
210            asset_type: AssetType::Music,
211            prompt: Some(prompt.into()),
212            model: options.model,
213            input_file: None,
214            provider: options.provider,
215            size: None,
216            transparent: None,
217            reference_images: Vec::new(),
218            edit_mode: None,
219            session_id: None,
220            params: merge_params(
221                options.params,
222                [
223                    option_entry("duration_seconds", options.duration.map(Value::from)),
224                    bool_entry("force_instrumental", options.force_instrumental),
225                    option_entry("output_format", options.output_format),
226                ],
227            ),
228        };
229
230        self.generate(&request).await
231    }
232
233    pub async fn generate_model3d(
234        &self,
235        prompt: impl Into<String>,
236        options: Option<Model3dOptions>,
237    ) -> Result<GenerateResponse> {
238        let options = options.unwrap_or_default();
239        let request = GenerateRequest {
240            asset_type: AssetType::Model3d,
241            prompt: Some(prompt.into()),
242            model: options.model,
243            input_file: options.input,
244            provider: options.provider,
245            size: None,
246            transparent: None,
247            reference_images: Vec::new(),
248            edit_mode: None,
249            session_id: None,
250            params: merge_params(
251                options.params,
252                [
253                    option_entry("model_version", options.model_version),
254                    option_entry("face_limit", options.face_limit.map(Value::from)),
255                    bool_entry("pbr", options.pbr),
256                    option_entry("texture_quality", options.texture_quality),
257                    bool_entry("auto_size", options.auto_size),
258                    option_entry("negative_prompt", options.negative_prompt),
259                    option_entry(
260                        "multiview",
261                        (!options.multiview.is_empty()).then(|| {
262                            Value::Array(options.multiview.into_iter().map(Value::from).collect())
263                        }),
264                    ),
265                    option_entry("style", options.style),
266                ],
267            ),
268        };
269
270        self.generate(&request).await
271    }
272
273    pub async fn generate_sprite(
274        &self,
275        prompt: impl Into<String>,
276        options: Option<SpriteOptions>,
277    ) -> Result<GenerateResponse> {
278        let options = options.unwrap_or_default();
279        let request = GenerateRequest {
280            asset_type: AssetType::Sprite,
281            prompt: Some(prompt.into()),
282            model: options.model,
283            input_file: options.input,
284            provider: options.provider,
285            size: None,
286            transparent: None,
287            reference_images: Vec::new(),
288            edit_mode: None,
289            session_id: None,
290            params: merge_params(
291                options.params,
292                [
293                    option_entry(
294                        "animation_type",
295                        Some(Value::String(
296                            options.animation_type.unwrap_or_else(|| "walk".to_string()),
297                        )),
298                    ),
299                    option_entry(
300                        "direction",
301                        Some(Value::String(
302                            options.direction.unwrap_or_else(|| "right".to_string()),
303                        )),
304                    ),
305                    option_entry("duration", Some(Value::from(options.duration.unwrap_or(2)))),
306                    option_entry(
307                        "output_format",
308                        Some(Value::String(
309                            options
310                                .output_format
311                                .unwrap_or_else(|| "spritesheet".to_string()),
312                        )),
313                    ),
314                    option_entry("fps", options.fps.map(Value::from)),
315                    option_entry("style", options.style),
316                ],
317            ),
318        };
319
320        self.generate(&request).await
321    }
322
323    pub async fn process(&self, request: &ProcessRequest) -> Result<ProcessResponse> {
324        let body = json!({
325            "input": request.input,
326            "inputs": request.inputs,
327            "operations": request.operations,
328        });
329
330        self.transport.post("/api/process", &body).await
331    }
332
333    pub async fn jobs(&self, status: Option<&str>, limit: Option<u32>) -> Result<Vec<JobSummary>> {
334        let mut path = String::from("/api/jobs");
335        let mut query = Vec::new();
336
337        if let Some(status) = status {
338            query.push(format!("status={status}"));
339        }
340        if let Some(limit) = limit {
341            query.push(format!("limit={limit}"));
342        }
343        if !query.is_empty() {
344            path.push('?');
345            path.push_str(&query.join("&"));
346        }
347
348        let response: JobListResponse = self.transport.get(&path).await?;
349        Ok(response.jobs)
350    }
351
352    pub async fn job_status(&self, job_id: &str) -> Result<JobSummary> {
353        self.transport.get(&format!("/api/jobs/{job_id}")).await
354    }
355
356    pub async fn providers(&self) -> Result<Vec<ProviderInfo>> {
357        let response: ProviderListResponse = self.transport.get("/api/providers").await?;
358        Ok(response.providers)
359    }
360
361    pub async fn health(&self) -> Result<bool> {
362        let response: HealthResponse = self.transport.request(Method::GET, "/health", None).await?;
363        Ok(response.status == "healthy")
364    }
365}
366
367#[derive(Debug, Deserialize)]
368struct JobListResponse {
369    jobs: Vec<JobSummary>,
370}
371
372#[derive(Debug, Deserialize)]
373struct ProviderListResponse {
374    providers: Vec<ProviderInfo>,
375}
376
377#[derive(Debug, Deserialize)]
378struct HealthResponse {
379    status: String,
380}
381
382fn merge_params<const N: usize>(base: Value, entries: [(String, Option<Value>); N]) -> Value {
383    let mut params = match base {
384        Value::Object(map) => map,
385        _ => Map::new(),
386    };
387
388    for (key, value) in entries {
389        if let Some(value) = value {
390            params.insert(key, value);
391        }
392    }
393
394    if params.is_empty() {
395        Value::Null
396    } else {
397        Value::Object(params)
398    }
399}
400
401fn option_entry<T>(key: &str, value: Option<T>) -> (String, Option<Value>)
402where
403    T: Into<Value>,
404{
405    (key.to_string(), value.map(Into::into))
406}
407
408fn bool_entry(key: &str, value: Option<bool>) -> (String, Option<Value>) {
409    (
410        key.to_string(),
411        value.and_then(|enabled| enabled.then_some(Value::Bool(true))),
412    )
413}