vidu-cli 0.1.0

client for vidu
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::path::Path;

fn set(items: &[&str]) -> HashSet<String> {
    items.iter().map(|s| s.to_string()).collect()
}

fn valid_task_types() -> HashSet<String> {
    set(&["text2image", "text2video", "img2video", "headtailimg2video", "reference2image", "character2video"])
}

fn valid_model_versions() -> HashSet<String> {
    set(&["3.0", "3.1", "3.2", "3.2_fast_m", "3.2_pro_m"])
}

fn resolution_support() -> HashMap<String, HashSet<String>> {
    let mut m = HashMap::new();
    m.insert("text2image".into(), set(&["1080p", "2k", "4k"]));
    m.insert("reference2image".into(), set(&["1080p", "2k", "4k"]));
    m.insert("text2video".into(), set(&["1080p"]));
    m.insert("img2video".into(), set(&["1080p"]));
    m.insert("headtailimg2video".into(), set(&["1080p"]));
    m.insert("character2video".into(), set(&["1080p"]));
    m
}

fn valid_aspect_ratios() -> HashSet<String> {
    set(&["16:9", "9:16", "1:1"])
}

fn duration_ranges() -> HashMap<String, HashMap<String, (i64, i64)>> {
    let mut m = HashMap::new();
    let mut tv = HashMap::new();
    tv.insert("3.0".into(), (5, 5));
    tv.insert("3.1".into(), (2, 8));
    tv.insert("3.2".into(), (1, 16));
    m.insert("text2video".into(), tv.clone());
    m.insert("img2video".into(), tv.clone());
    m.insert("headtailimg2video".into(), tv);
    let mut cv = HashMap::new();
    cv.insert("3.0".into(), (5, 5));
    cv.insert("3.1".into(), (2, 8));
    cv.insert("3.1_pro".into(), (-1, 8));
    cv.insert("3.2".into(), (1, 16));
    m.insert("character2video".into(), cv);
    let mut ri = HashMap::new();
    ri.insert("3.1".into(), (0, 0));
    ri.insert("3.2_fast_m".into(), (0, 0));
    ri.insert("3.2_pro_m".into(), (0, 0));
    m.insert("reference2image".into(), ri.clone());
    m.insert("text2image".into(), ri);
    m
}

fn model_support() -> HashMap<String, HashSet<String>> {
    let mut m = HashMap::new();
    m.insert("text2image".into(), set(&["3.1", "3.2_fast_m", "3.2_pro_m"]));
    m.insert("text2video".into(), set(&["3.0", "3.1", "3.2"]));
    m.insert("img2video".into(), set(&["3.0", "3.1", "3.2"]));
    m.insert("headtailimg2video".into(), set(&["3.0", "3.1", "3.2"]));
    m.insert("reference2image".into(), set(&["3.1", "3.2_fast_m", "3.2_pro_m"]));
    m.insert("character2video".into(), set(&["3.0", "3.1", "3.1_pro", "3.2"]));
    m
}

fn sorted_join(s: &HashSet<String>) -> String {
    let mut v: Vec<&str> = s.iter().map(|x| x.as_str()).collect();
    v.sort();
    v.join(", ")
}

pub fn validate_task_body(body: &Value) -> String {
    let task_type = body.get("type").and_then(|v| v.as_str()).unwrap_or("");
    if task_type.is_empty() {
        return "Missing required field: type".into();
    }
    if !valid_task_types().contains(task_type) {
        return format!("Invalid type '{}'. Valid: {}", task_type, sorted_join(&valid_task_types()));
    }

    let input_obj = body.get("input");
    let input = match input_obj {
        Some(Value::Object(_)) => input_obj.unwrap(),
        _ => return "input must be an object".into(),
    };
    let prompts = input.get("prompts");
    match prompts {
        Some(Value::Array(arr)) if !arr.is_empty() => {}
        Some(Value::Array(_)) => return "input.prompts is required and must not be empty".into(),
        _ => return "input.prompts is required and must not be empty".into(),
    }

    let settings = match body.get("settings") {
        Some(Value::Object(_)) => body.get("settings").unwrap(),
        _ => return "settings must be an object".into(),
    };

    let model_version = settings.get("model_version").and_then(|v| v.as_str()).unwrap_or("2.0");
    if !valid_model_versions().contains(model_version) {
        return format!("Invalid model_version '{}'. Valid: {}", model_version, sorted_join(&valid_model_versions()));
    }

    let ms = model_support();
    if let Some(supported) = ms.get(task_type) {
        if !supported.contains(model_version) {
            return format!("model_version {} does not support {}", model_version, task_type);
        }
    }

    let dr = duration_ranges();
    if let Some(type_ranges) = dr.get(task_type) {
        let duration = settings.get("duration").and_then(|v| v.as_i64()).unwrap_or(0);
        if let Some(&(min_d, max_d)) = type_ranges.get(model_version) {
            if min_d > 0 && (duration < min_d || duration > max_d) {
                return format!("duration {} out of range [{}, {}] for {} with {}", duration, min_d, max_d, task_type, model_version);
            }
        }
    }

    if settings.get("resolution").is_none() {
        return format!("resolution is required for {}", task_type);
    }
    let res = settings.get("resolution").and_then(|v| v.as_str()).unwrap_or("");
    let rs = resolution_support();
    let valid_res = rs.get(task_type).cloned().unwrap_or_else(|| set(&["1080p"]));
    if !valid_res.contains(res) {
        return format!("Invalid resolution '{}' for {}. Valid: {}", res, task_type, sorted_join(&valid_res));
    }

    let no_ar_types = ["img2video", "headtailimg2video", "text2image"];
    if !no_ar_types.contains(&task_type) {
        if let Some(ar_val) = settings.get("aspect_ratio").and_then(|v| v.as_str()) {
            if !valid_aspect_ratios().contains(ar_val) {
                return format!("Invalid aspect_ratio '{}'. Valid: {}", ar_val, sorted_join(&valid_aspect_ratios()));
            }
        }
    }

    if settings.get("transition").is_some() {
        let no_trans = ["reference2image", "character2video", "text2image"];
        if no_trans.contains(&task_type) {
            return format!("{} should not include transition", task_type);
        }
        if task_type == "text2video" && model_version != "3.2" {
            return format!("text2video with {} should not include transition (only 3.2 supports pro/speed)", model_version);
        }
        if task_type == "img2video" || task_type == "headtailimg2video" {
            let transition = settings.get("transition").and_then(|v| v.as_str()).unwrap_or("");
            let valid_trans = if model_version == "3.0" {
                set(&["creative", "stable"])
            } else {
                set(&["pro", "speed"])
            };
            if !valid_trans.contains(transition) {
                return format!("Invalid transition '{}' for {} {}. Valid: {}", transition, task_type, model_version, sorted_join(&valid_trans));
            }
        }
    }

    if input.get("enhance").is_none() {
        return "input.enhance is required (true or false)".into();
    }

    String::new()
}

pub fn validate_element_preprocess(body: &Value) -> String {
    if !body.is_object() {
        return "body must be an object".into();
    }
    if body.get("components").is_none() {
        return "Missing required field: components".into();
    }
    if body.get("name").is_none() {
        return "Missing required field: name".into();
    }
    if body.get("type").is_none() {
        return "Missing required field: type".into();
    }
    let components = match body.get("components") {
        Some(Value::Array(arr)) if !arr.is_empty() => arr,
        _ => return "components must be a non-empty array".into(),
    };
    if components.len() > 3 {
        return "components must have at most 3 items".into();
    }
    let main_count = components.iter().filter(|c| c.get("type").and_then(|v| v.as_str()) == Some("main")).count();
    if main_count != 1 {
        return "components must have exactly one item with type='main'".into();
    }
    String::new()
}

pub fn validate_element_create(body: &Value) -> String {
    if !body.is_object() {
        return "body must be an object".into();
    }
    if body.get("id").is_none() {
        return "Missing required field: id (from preprocess response)".into();
    }
    if body.get("name").is_none() {
        return "Missing required field: name".into();
    }
    if body.get("modality").is_none() {
        return "Missing required field: modality".into();
    }
    String::new()
}

pub fn validate_image_file(path: &str) -> String {
    let p = Path::new(path);
    if !p.is_file() {
        return format!("Image file not found: {}", path);
    }
    String::new()
}