use serde::{Deserialize, Serialize};
use crate::client::Client;
use crate::enums::GeometryQuality;
use crate::error::{Error, Result};
use crate::image::ImageInput;
use crate::versions;
pub mod check_riggable;
pub mod convert_model;
pub mod image_to_model;
pub mod mesh_completion;
pub mod mesh_segmentation;
pub mod multiview_to_model;
pub mod refine_model;
pub mod retarget_animation;
pub mod rig_model;
pub mod smart_lowpoly;
pub mod stylize_model;
pub mod text_to_model;
pub mod texture_model;
pub use check_riggable::CheckRiggableRequest;
pub use convert_model::ConvertModelRequest;
pub use image_to_model::ImageToModelRequest;
pub use mesh_completion::MeshCompletionRequest;
pub use mesh_segmentation::MeshSegmentationRequest;
pub use multiview_to_model::MultiviewToModelRequest;
pub use refine_model::RefineModelRequest;
pub use retarget_animation::{AnimationInput, RetargetAnimationRequest};
pub use rig_model::RigModelRequest;
pub use smart_lowpoly::SmartLowpolyRequest;
pub use stylize_model::StylizeModelRequest;
pub use text_to_model::TextToModelRequest;
pub use texture_model::{TextureModelRequest, TexturePrompt};
/// A request for one of the supported task kinds; each variant wraps the
/// request struct for that task.
///
/// Serialized as an internally tagged object: the `type` field holds the
/// variant's wire name (the `rename` value on each variant) and the remaining
/// fields come from the wrapped request struct.
///
/// Several wire names differ from the Rust variant names (for example `Rig`
/// is sent as `animate_rig` and `SmartLowpoly` as `highpoly_to_lowpoly`), so
/// always check the `rename` attribute rather than deriving the name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(tag = "type")]
#[allow(
clippy::unsafe_derive_deserialize,
reason = "transitive lint via nested types; this enum itself has no unsafe methods"
)]
pub enum TaskRequest {
/// Wraps a [`TextToModelRequest`]; wire name `text_to_model`.
#[serde(rename = "text_to_model")]
TextToModel(TextToModelRequest),
/// Wraps an [`ImageToModelRequest`]; wire name `image_to_model`.
#[serde(rename = "image_to_model")]
ImageToModel(ImageToModelRequest),
/// Wraps a [`MultiviewToModelRequest`]; wire name `multiview_to_model`.
#[serde(rename = "multiview_to_model")]
MultiviewToModel(MultiviewToModelRequest),
/// Wraps a [`ConvertModelRequest`]; wire name `convert_model`.
#[serde(rename = "convert_model")]
ConvertModel(ConvertModelRequest),
/// Wraps a [`StylizeModelRequest`]; wire name `stylize_model`.
#[serde(rename = "stylize_model")]
Stylize(StylizeModelRequest),
/// Wraps a [`TextureModelRequest`]; wire name `texture_model`.
#[serde(rename = "texture_model")]
TextureModel(TextureModelRequest),
/// Wraps a [`RefineModelRequest`]; wire name `refine_model`.
#[serde(rename = "refine_model")]
Refine(RefineModelRequest),
/// Wraps a [`CheckRiggableRequest`]; wire name `animate_prerigcheck`.
#[serde(rename = "animate_prerigcheck")]
CheckRiggable(CheckRiggableRequest),
/// Wraps a [`RigModelRequest`]; wire name `animate_rig`.
#[serde(rename = "animate_rig")]
Rig(RigModelRequest),
/// Wraps a [`RetargetAnimationRequest`]; wire name `animate_retarget`.
#[serde(rename = "animate_retarget")]
Retarget(RetargetAnimationRequest),
/// Wraps a [`MeshSegmentationRequest`]; wire name `mesh_segmentation`.
#[serde(rename = "mesh_segmentation")]
MeshSegmentation(MeshSegmentationRequest),
/// Wraps a [`MeshCompletionRequest`]; wire name `mesh_completion`.
#[serde(rename = "mesh_completion")]
MeshCompletion(MeshCompletionRequest),
/// Wraps a [`SmartLowpolyRequest`]; wire name `highpoly_to_lowpoly`.
#[serde(rename = "highpoly_to_lowpoly")]
SmartLowpoly(SmartLowpolyRequest),
}
impl TaskRequest {
    /// Runs client-side validation on the wrapped request.
    ///
    /// Delegates to the wrapped request's `validate` for the rig,
    /// text-to-model, image-to-model and multiview-to-model variants. Every
    /// other variant is accepted without local checks.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped request's `validate` reports.
    pub fn validate(&self) -> Result<()> {
        match self {
            Self::Rig(r) => r.validate(),
            Self::TextToModel(r) => r.validate(),
            Self::ImageToModel(r) => r.validate(),
            Self::MultiviewToModel(r) => r.validate(),
            // Enumerated explicitly (no `_` arm) so that adding a variant is a
            // compile error here until someone decides whether it needs checks.
            Self::ConvertModel(_)
            | Self::Stylize(_)
            | Self::TextureModel(_)
            | Self::Refine(_)
            | Self::CheckRiggable(_)
            | Self::Retarget(_)
            | Self::MeshSegmentation(_)
            | Self::MeshCompletion(_)
            | Self::SmartLowpoly(_) => Ok(()),
        }
    }

    /// Uploads every local-path image in the request through `client`, replacing
    /// each one in place with the returned file token.
    ///
    /// Images that are not local paths are left untouched. When a request holds
    /// several images, they are uploaded concurrently. Variants that carry no
    /// images are a no-op.
    ///
    /// # Errors
    ///
    /// Returns the first upload error encountered. Images whose upload had
    /// already finished before the failure may have been replaced.
    pub async fn upload_images(&mut self, client: &Client) -> Result<()> {
        match self {
            Self::ImageToModel(r) => upload_image_if_path(client, &mut r.image).await,
            Self::MultiviewToModel(r) => {
                // Empty view slots are skipped by `flatten`.
                let uploads = r
                    .images
                    .iter_mut()
                    .flatten()
                    .map(|img| upload_image_if_path(client, img));
                futures::future::try_join_all(uploads).await?;
                Ok(())
            }
            Self::TextureModel(r) => {
                // Both the content image and the style image are optional.
                // Chaining the two `Option`s uploads whichever are present,
                // concurrently, without enumerating every Some/None combination.
                let uploads = r
                    .texture_prompt
                    .image
                    .iter_mut()
                    .chain(r.texture_prompt.style_image.iter_mut())
                    .map(|img| upload_image_if_path(client, img));
                futures::future::try_join_all(uploads).await?;
                Ok(())
            }
            Self::TextToModel(_)
            | Self::ConvertModel(_)
            | Self::Stylize(_)
            | Self::Refine(_)
            | Self::CheckRiggable(_)
            | Self::Rig(_)
            | Self::Retarget(_)
            | Self::MeshSegmentation(_)
            | Self::MeshCompletion(_)
            | Self::SmartLowpoly(_) => Ok(()),
        }
    }
}
/// Rejects options that the `versions::text_image::P1` model version cannot honour.
///
/// The check only applies when `model_version` is exactly P1. Any other
/// version, or `None`, passes unconditionally. Under P1:
/// - `quad`, `smart_low_poly` and `generate_parts` are rejected only when
///   explicitly `Some(true)`;
/// - `geometry_quality` is rejected whenever it is present at all.
///
/// # Errors
///
/// Returns [`Error::InvalidRequest`] naming every offending option, in the
/// fixed order `quad, smart_low_poly, generate_parts, geometry_quality`.
pub(crate) fn validate_p1_params(
    model_version: Option<&str>,
    quad: Option<bool>,
    smart_low_poly: Option<bool>,
    generate_parts: Option<bool>,
    geometry_quality: Option<&GeometryQuality>,
) -> Result<()> {
    if !matches!(model_version, Some(v) if v == versions::text_image::P1) {
        return Ok(());
    }

    // (option name, is it set in a way P1 does not support?)
    let checks = [
        ("quad", quad.unwrap_or(false)),
        ("smart_low_poly", smart_low_poly.unwrap_or(false)),
        ("generate_parts", generate_parts.unwrap_or(false)),
        ("geometry_quality", geometry_quality.is_some()),
    ];
    let offending: Vec<&str> = checks
        .iter()
        .filter_map(|&(name, rejected)| rejected.then_some(name))
        .collect();

    if offending.is_empty() {
        return Ok(());
    }
    Err(Error::InvalidRequest(format!(
        "model_version {} does not support: {}",
        versions::text_image::P1,
        offending.join(", "),
    )))
}
/// Uploads `img` through `client` when it refers to a local path, then replaces
/// it in place with the resulting file token.
///
/// Any other kind of `ImageInput` is left untouched and nothing is uploaded.
///
/// # Errors
///
/// Propagates the error from `Client::upload_file`. In that case `img` is not
/// modified.
pub(crate) async fn upload_image_if_path(client: &Client, img: &mut ImageInput) -> Result<()> {
    let ImageInput::Path(local) = img else {
        return Ok(());
    };
    let uploaded = client.upload_file(&*local).await?;
    *img = ImageInput::FileToken(uploaded.file_token);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the message from an `Error::InvalidRequest`, failing the test on
    /// `Ok` or on any other error variant.
    fn invalid_request_message(result: Result<()>) -> String {
        match result {
            Err(Error::InvalidRequest(msg)) => msg,
            Err(other) => panic!("expected Error::InvalidRequest, got {other:?}"),
            Ok(()) => panic!("expected Error::InvalidRequest, got Ok(())"),
        }
    }

    /// Builds the exact rejection message expected for the given
    /// comma-separated list of unsupported fields.
    fn expected_p1_message(fields: &str) -> String {
        format!(
            "model_version {} does not support: {}",
            versions::text_image::P1,
            fields
        )
    }

    #[test]
    fn non_p1_version_skips_p1_checks() {
        validate_p1_params(
            None,
            Some(true),
            Some(true),
            Some(true),
            Some(&GeometryQuality::Detailed),
        )
        .unwrap();
        validate_p1_params(
            Some(versions::text_image::V3_1),
            Some(true),
            Some(true),
            Some(true),
            Some(&GeometryQuality::Detailed),
        )
        .unwrap();
    }

    #[test]
    fn p1_with_no_unsupported_fields_ok() {
        validate_p1_params(Some(versions::text_image::P1), None, None, None, None).unwrap();
        validate_p1_params(
            Some(versions::text_image::P1),
            Some(false),
            Some(false),
            Some(false),
            None,
        )
        .unwrap();
    }

    #[test]
    fn p1_rejects_quad() {
        let msg = invalid_request_message(validate_p1_params(
            Some(versions::text_image::P1),
            Some(true),
            None,
            None,
            None,
        ));
        assert_eq!(msg, expected_p1_message("quad"));
    }

    #[test]
    fn p1_rejects_smart_low_poly() {
        let msg = invalid_request_message(validate_p1_params(
            Some(versions::text_image::P1),
            None,
            Some(true),
            None,
            None,
        ));
        assert_eq!(msg, expected_p1_message("smart_low_poly"));
    }

    #[test]
    fn p1_rejects_generate_parts() {
        let msg = invalid_request_message(validate_p1_params(
            Some(versions::text_image::P1),
            None,
            None,
            Some(true),
            None,
        ));
        assert_eq!(msg, expected_p1_message("generate_parts"));
    }

    #[test]
    fn p1_rejects_geometry_quality_even_with_other_flags_false() {
        // Explicit `false` flags must not be reported; only the quality is.
        let msg = invalid_request_message(validate_p1_params(
            Some(versions::text_image::P1),
            Some(false),
            Some(false),
            Some(false),
            Some(&GeometryQuality::Detailed),
        ));
        assert_eq!(msg, expected_p1_message("geometry_quality"));
    }

    #[test]
    fn p1_rejects_all_unsupported_together() {
        let msg = invalid_request_message(validate_p1_params(
            Some(versions::text_image::P1),
            Some(true),
            Some(true),
            Some(true),
            Some(&GeometryQuality::Detailed),
        ));
        // Pins both the set of fields and their reporting order.
        assert_eq!(
            msg,
            expected_p1_message("quad, smart_low_poly, generate_parts, geometry_quality")
        );
    }
}