1use serde::{Deserialize, Serialize};
4
5use crate::client::Client;
6use crate::enums::GeometryQuality;
7use crate::error::{Error, Result};
8use crate::image::ImageInput;
9use crate::versions;
10
11pub mod check_riggable;
12pub mod convert_model;
13pub mod image_to_model;
14pub mod mesh_completion;
15pub mod mesh_segmentation;
16pub mod multiview_to_model;
17pub mod refine_model;
18pub mod retarget_animation;
19pub mod rig_model;
20pub mod smart_lowpoly;
21pub mod stylize_model;
22pub mod text_to_model;
23pub mod texture_model;
24
25pub use check_riggable::CheckRiggableRequest;
26pub use convert_model::ConvertModelRequest;
27pub use image_to_model::ImageToModelRequest;
28pub use mesh_completion::MeshCompletionRequest;
29pub use mesh_segmentation::MeshSegmentationRequest;
30pub use multiview_to_model::MultiviewToModelRequest;
31pub use refine_model::RefineModelRequest;
32pub use retarget_animation::{AnimationInput, RetargetAnimationRequest};
33pub use rig_model::RigModelRequest;
34pub use smart_lowpoly::SmartLowpolyRequest;
35pub use stylize_model::StylizeModelRequest;
36pub use text_to_model::TextToModelRequest;
37pub use texture_model::{TextureModelRequest, TexturePrompt};
38
/// A task request of any supported kind, wrapping one concrete request type per variant.
///
/// The enum is internally tagged for serde: the `type` field of the serialized form
/// selects the variant, using the string given in each variant's `rename` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(tag = "type")]
#[allow(
    clippy::unsafe_derive_deserialize,
    reason = "transitive lint via nested types; this enum itself has no unsafe methods"
)]
pub enum TaskRequest {
    /// Carries a [`TextToModelRequest`]; tagged `"text_to_model"`.
    #[serde(rename = "text_to_model")]
    TextToModel(TextToModelRequest),
    /// Carries an [`ImageToModelRequest`]; tagged `"image_to_model"`.
    #[serde(rename = "image_to_model")]
    ImageToModel(ImageToModelRequest),
    /// Carries a [`MultiviewToModelRequest`]; tagged `"multiview_to_model"`.
    #[serde(rename = "multiview_to_model")]
    MultiviewToModel(MultiviewToModelRequest),
    /// Carries a [`ConvertModelRequest`]; tagged `"convert_model"`.
    #[serde(rename = "convert_model")]
    ConvertModel(ConvertModelRequest),
    /// Carries a [`StylizeModelRequest`]; tagged `"stylize_model"`.
    #[serde(rename = "stylize_model")]
    Stylize(StylizeModelRequest),
    /// Carries a [`TextureModelRequest`]; tagged `"texture_model"`.
    #[serde(rename = "texture_model")]
    TextureModel(TextureModelRequest),
    /// Carries a [`RefineModelRequest`]; tagged `"refine_model"`.
    #[serde(rename = "refine_model")]
    Refine(RefineModelRequest),
    /// Carries a [`CheckRiggableRequest`]; tagged `"animate_prerigcheck"`.
    #[serde(rename = "animate_prerigcheck")]
    CheckRiggable(CheckRiggableRequest),
    /// Carries a [`RigModelRequest`]; tagged `"animate_rig"`.
    #[serde(rename = "animate_rig")]
    Rig(RigModelRequest),
    /// Carries a [`RetargetAnimationRequest`]; tagged `"animate_retarget"`.
    #[serde(rename = "animate_retarget")]
    Retarget(RetargetAnimationRequest),
    /// Carries a [`MeshSegmentationRequest`]; tagged `"mesh_segmentation"`.
    #[serde(rename = "mesh_segmentation")]
    MeshSegmentation(MeshSegmentationRequest),
    /// Carries a [`MeshCompletionRequest`]; tagged `"mesh_completion"`.
    #[serde(rename = "mesh_completion")]
    MeshCompletion(MeshCompletionRequest),
    /// Carries a [`SmartLowpolyRequest`]; tagged `"highpoly_to_lowpoly"`.
    #[serde(rename = "highpoly_to_lowpoly")]
    SmartLowpoly(SmartLowpolyRequest),
}
91
92impl TaskRequest {
93 pub fn validate(&self) -> Result<()> {
97 match self {
98 Self::Rig(r) => r.validate(),
99 Self::TextToModel(r) => r.validate(),
100 Self::ImageToModel(r) => r.validate(),
101 Self::MultiviewToModel(r) => r.validate(),
102 _ => Ok(()),
103 }
104 }
105
106 pub async fn upload_images(&mut self, client: &Client) -> Result<()> {
109 match self {
110 Self::ImageToModel(r) => upload_image_if_path(client, &mut r.image).await,
111 Self::MultiviewToModel(r) => {
112 let futs = r
113 .images
114 .iter_mut()
115 .flatten()
116 .map(|img| upload_image_if_path(client, img));
117 futures::future::try_join_all(futs).await?;
118 Ok(())
119 }
120 Self::TextureModel(r) => {
121 let image = &mut r.texture_prompt.image;
122 let style = &mut r.texture_prompt.style_image;
123 match (image.as_mut(), style.as_mut()) {
124 (Some(a), Some(b)) => {
125 tokio::try_join!(
126 upload_image_if_path(client, a),
127 upload_image_if_path(client, b)
128 )?;
129 }
130 (Some(a), None) => upload_image_if_path(client, a).await?,
131 (None, Some(b)) => upload_image_if_path(client, b).await?,
132 (None, None) => {}
133 }
134 Ok(())
135 }
136 Self::TextToModel(_)
137 | Self::ConvertModel(_)
138 | Self::Stylize(_)
139 | Self::Refine(_)
140 | Self::CheckRiggable(_)
141 | Self::Rig(_)
142 | Self::Retarget(_)
143 | Self::MeshSegmentation(_)
144 | Self::MeshCompletion(_)
145 | Self::SmartLowpoly(_) => Ok(()),
146 }
147 }
148}
149
150pub(crate) fn validate_p1_params(
155 model_version: Option<&str>,
156 quad: Option<bool>,
157 smart_low_poly: Option<bool>,
158 generate_parts: Option<bool>,
159 geometry_quality: Option<&GeometryQuality>,
160) -> Result<()> {
161 if model_version != Some(versions::text_image::P1) {
162 return Ok(());
163 }
164 let mut unsupported: Vec<&str> = Vec::new();
165 if quad == Some(true) {
166 unsupported.push("quad");
167 }
168 if smart_low_poly == Some(true) {
169 unsupported.push("smart_low_poly");
170 }
171 if generate_parts == Some(true) {
172 unsupported.push("generate_parts");
173 }
174 if geometry_quality.is_some() {
175 unsupported.push("geometry_quality");
176 }
177 if unsupported.is_empty() {
178 Ok(())
179 } else {
180 Err(Error::InvalidRequest(format!(
181 "model_version {} does not support: {}",
182 versions::text_image::P1,
183 unsupported.join(", "),
184 )))
185 }
186}
187
188pub(crate) async fn upload_image_if_path(client: &Client, img: &mut ImageInput) -> Result<()> {
191 if let ImageInput::Path(p) = img {
192 let up = client.upload_file(&*p).await?;
193 *img = ImageInput::FileToken(up.file_token);
194 }
195 Ok(())
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// The model-version selector that switches the P1 restrictions on.
    fn p1() -> Option<&'static str> {
        Some(versions::text_image::P1)
    }

    /// Returns the message of an `InvalidRequest` rejection; any other outcome fails
    /// the calling test.
    fn rejection_message(outcome: Result<()>) -> String {
        match outcome {
            Err(Error::InvalidRequest(msg)) => msg,
            Err(_) => panic!("wrong variant"),
            Ok(()) => panic!("expected validation to fail"),
        }
    }

    /// With no version, or a non-P1 version, every option is allowed.
    #[test]
    fn non_p1_version_skips_p1_checks() {
        let detailed = GeometryQuality::Detailed;
        for version in [None, Some(versions::text_image::V3_1)] {
            let outcome = validate_p1_params(
                version,
                Some(true),
                Some(true),
                Some(true),
                Some(&detailed),
            );
            assert!(outcome.is_ok(), "unexpected rejection for {version:?}");
        }
    }

    /// P1 accepts requests whose restricted options are unset or explicitly `false`.
    #[test]
    fn p1_with_no_unsupported_fields_ok() {
        assert!(validate_p1_params(p1(), None, None, None, None).is_ok());

        let off = Some(false);
        assert!(validate_p1_params(p1(), off, off, off, None).is_ok());
    }

    /// A single offending option is named in the rejection.
    #[test]
    fn p1_rejects_quad() {
        let msg = rejection_message(validate_p1_params(p1(), Some(true), None, None, None));
        assert!(msg.contains("quad"));
    }

    /// Every offending option is named when several are set at once.
    #[test]
    fn p1_rejects_all_unsupported_together() {
        let detailed = GeometryQuality::Detailed;
        let msg = rejection_message(validate_p1_params(
            p1(),
            Some(true),
            Some(true),
            Some(true),
            Some(&detailed),
        ));

        let expected = [
            "quad",
            "smart_low_poly",
            "generate_parts",
            "geometry_quality",
        ];
        for field in expected {
            assert!(msg.contains(field), "missing {field} in {msg}");
        }
    }
}