Skip to main content

mold_core/
validation.rs

1use crate::{
2    GenerateRequest, KeyframeCondition, LoraWeight, Ltx2PipelineMode, OutputFormat, UpscaleRequest,
3};
4
/// Maximum total pixels allowed (~1.8 megapixels). Qwen-Image trains at ~1.76MP
/// (1328x1328); other models train at ≤1MP. The extra headroom allows
/// non-square aspect ratios at comparable pixel counts.
pub const MAX_PIXELS: u64 = 1_800_000;
/// Maximum size, in bytes, of an inline `audio_file` payload in a request (64 MiB).
pub const MAX_INLINE_AUDIO_BYTES: usize = 64 * 1024 * 1024;
/// Maximum size, in bytes, of an inline `source_video` payload in a request (64 MiB).
pub const MAX_INLINE_SOURCE_VIDEO_BYTES: usize = 64 * 1024 * 1024;
10
11fn megapixel_limit_label() -> String {
12    format!("{:.1}MP", MAX_PIXELS as f64 / 1_000_000.0)
13}
14
/// Render a byte count as whole mebibytes (rounded, no decimals), e.g. `"64 MiB"`.
fn mib_label(bytes: usize) -> String {
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
    let mebibytes = bytes as f64 / BYTES_PER_MIB;
    format!("{mebibytes:.0} MiB")
}
18
19/// Clamp dimensions to fit within the megapixel limit, preserving aspect ratio.
20/// Both dimensions are rounded down to multiples of 16.
21/// Returns the original dimensions unchanged if already within limits.
22pub fn clamp_to_megapixel_limit(w: u32, h: u32) -> (u32, u32) {
23    let pixels = w as u64 * h as u64;
24    if pixels <= MAX_PIXELS {
25        return (w, h);
26    }
27    let scale = (MAX_PIXELS as f64 / pixels as f64).sqrt();
28    let new_w = ((w as f64 * scale) as u32 / 16) * 16;
29    let new_h = ((h as f64 * scale) as u32 / 16) * 16;
30    // Ensure we don't produce zero dimensions
31    (new_w.max(16), new_h.max(16))
32}
33
34/// Fit source image dimensions into a model's native resolution bounding box,
35/// preserving aspect ratio.
36///
37/// The model's default width/height define the bounding box. The source image's
38/// aspect ratio is preserved:
39/// - If the source is wider than the model bounds, width is set to `model_w` and
40///   height is scaled proportionally.
41/// - If the source is taller, height is set to `model_h` and width is scaled.
42/// - If the source fits entirely within model bounds (same aspect ratio as the
43///   model), the model's native dimensions are used as the output. For sources
44///   with a different aspect ratio, the output fills the limiting axis at model
45///   scale while keeping the other axis within bounds.
46///
47/// Output is rounded to 16px alignment and clamped to the megapixel limit.
48pub fn fit_to_model_dimensions(src_w: u32, src_h: u32, model_w: u32, model_h: u32) -> (u32, u32) {
49    let src_ratio = src_w as f64 / src_h as f64;
50    let model_ratio = model_w as f64 / model_h as f64;
51
52    let (w, h) = if src_ratio > model_ratio {
53        // Source is wider: width-limited
54        (model_w as f64, model_w as f64 / src_ratio)
55    } else {
56        // Source is taller or same: height-limited
57        (model_h as f64 * src_ratio, model_h as f64)
58    };
59
60    let w = ((w as u32) / 16 * 16).max(16);
61    let h = ((h as u32) / 16 * 16).max(16);
62    clamp_to_megapixel_limit(w, h)
63}
64
65/// Resize dimensions toward a target pixel area while preserving aspect ratio.
66///
67/// The result is rounded to the requested alignment and clamped to the shared
68/// megapixel safety limit.
69pub fn fit_to_target_area(src_w: u32, src_h: u32, target_area: u32, align: u32) -> (u32, u32) {
70    let src_w = src_w.max(1);
71    let src_h = src_h.max(1);
72    let align = align.max(1);
73    let scale = (f64::from(target_area) / (f64::from(src_w) * f64::from(src_h))).sqrt();
74    let width = ((f64::from(src_w) * scale) / f64::from(align)).round() as u32 * align;
75    let height = ((f64::from(src_h) * scale) / f64::from(align)).round() as u32 * align;
76    clamp_to_megapixel_limit(width.max(align), height.max(align))
77}
78
/// Return `true` when `data` begins with a PNG (`89 50 4E 47`) or JPEG
/// (`FF D8`) signature.
///
/// Only the leading magic bytes are inspected; the rest of the payload is not
/// decoded, so this is a cheap format sniff rather than full validation.
fn is_valid_image_format(data: &[u8]) -> bool {
    const PNG_MAGIC: [u8; 4] = [0x89, 0x50, 0x4E, 0x47];
    const JPEG_MAGIC: [u8; 2] = [0xFF, 0xD8];
    data.starts_with(&PNG_MAGIC) || data.starts_with(&JPEG_MAGIC)
}
85
86fn model_family(model_name: &str) -> Option<&str> {
87    crate::manifest::find_manifest(model_name)
88        .map(|m| m.family.as_str())
89        .or_else(|| {
90            if model_name.starts_with("qwen-image-edit") {
91                Some("qwen-image-edit")
92            } else if model_name.starts_with("qwen-image") {
93                Some("qwen-image")
94            } else {
95                None
96            }
97        })
98}
99
100fn validate_lora_weight(lora: &LoraWeight, field_name: &str) -> Result<(), String> {
101    if lora.scale < 0.0 || lora.scale > 2.0 {
102        return Err(format!(
103            "{field_name} scale ({}) must be in range [0.0, 2.0]",
104            lora.scale
105        ));
106    }
107    if !lora.path.ends_with(".safetensors") && !lora.path.starts_with("camera-control:") {
108        return Err(format!(
109            "{field_name} file must be a .safetensors file or camera-control preset"
110        ));
111    }
112    Ok(())
113}
114
115fn validate_keyframes(
116    keyframes: &[KeyframeCondition],
117    frames: Option<u32>,
118    family: Option<&str>,
119) -> Result<(), String> {
120    match family {
121        Some("ltx2") => {}
122        None => {
123            return Err(
124                "unknown model family; keyframes are only supported for LTX-2 / LTX-2.3 models"
125                    .to_string(),
126            );
127        }
128        _ => {
129            return Err("keyframes are only supported for LTX-2 / LTX-2.3 models".to_string());
130        }
131    }
132    if keyframes.is_empty() {
133        return Err("keyframes must not be empty".to_string());
134    }
135
136    let mut seen = std::collections::BTreeSet::new();
137    for keyframe in keyframes {
138        if !is_valid_image_format(&keyframe.image) {
139            return Err("keyframes must contain only PNG or JPEG images".to_string());
140        }
141        if let Some(total_frames) = frames {
142            if keyframe.frame >= total_frames {
143                return Err(format!(
144                    "keyframe frame ({}) must be less than frames ({total_frames})",
145                    keyframe.frame
146                ));
147            }
148        }
149        if !seen.insert(keyframe.frame) {
150            return Err(format!("duplicate keyframe frame: {}", keyframe.frame));
151        }
152    }
153
154    Ok(())
155}
156
/// Ensure `family` is `ltx2` before accepting the LTX-2-only `feature_name`.
///
/// An unresolved family (`None`) gets a message that says so explicitly;
/// any other known family gets the plain "only supported" message.
fn require_ltx2_family(family: Option<&str>, feature_name: &str) -> Result<(), String> {
    let Some(name) = family else {
        return Err(format!(
            "unknown model family; {feature_name} is only supported for LTX-2 / LTX-2.3 models"
        ));
    };
    if name == "ltx2" {
        Ok(())
    } else {
        Err(format!(
            "{feature_name} is only supported for LTX-2 / LTX-2.3 models"
        ))
    }
}
168
169fn validate_inline_media_size(
170    bytes: &[u8],
171    field_name: &str,
172    max_bytes: usize,
173) -> Result<(), String> {
174    if bytes.len() > max_bytes {
175        return Err(format!(
176            "{field_name} exceeds the {} inline request limit (got {:.1} MiB)",
177            mib_label(max_bytes),
178            bytes.len() as f64 / (1024.0 * 1024.0)
179        ));
180    }
181    Ok(())
182}
183
184/// Validate a generate request. Returns `Ok(())` if valid, or an error message.
185/// Shared between the HTTP server and local CLI inference paths.
186pub fn validate_generate_request(req: &GenerateRequest) -> Result<(), String> {
187    let family = model_family(&req.model);
188
189    if req.prompt.trim().is_empty() {
190        return Err("prompt must not be empty".to_string());
191    }
192    if req.width == 0 || req.height == 0 {
193        return Err("width and height must be > 0".to_string());
194    }
195    if !req.width.is_multiple_of(16) || !req.height.is_multiple_of(16) {
196        return Err(format!(
197            "width ({}) and height ({}) must be multiples of 16 (FLUX patchification requirement)",
198            req.width, req.height
199        ));
200    }
201    // Cap by total pixel count rather than per-dimension to allow portrait/landscape.
202    // 896x1152 = 1.03M, 1024x1024 = 1.05M, 1280x768 = 0.98M — all fine.
203    // 1408x1408 = 1.98M — too large, OOMs on VAE decode.
204    let pixels = req.width as u64 * req.height as u64;
205    if pixels > MAX_PIXELS {
206        return Err(format!(
207            "{}x{} = {:.2} megapixels exceeds the {} limit (VAE VRAM constraint)",
208            req.width,
209            req.height,
210            pixels as f64 / 1_000_000.0,
211            megapixel_limit_label()
212        ));
213    }
214    if req.steps == 0 {
215        return Err("steps must be >= 1".to_string());
216    }
217    if req.steps > 100 {
218        return Err(format!("steps ({}) must be <= 100", req.steps));
219    }
220    if req.batch_size == 0 {
221        return Err("batch_size must be >= 1".to_string());
222    }
223    // No upper limit on batch_size — users can batch as many as they want.
224    if req.guidance < 0.0 {
225        return Err(format!("guidance ({}) must be >= 0.0", req.guidance));
226    }
227    if req.guidance > 100.0 {
228        return Err(format!("guidance ({}) must be <= 100.0", req.guidance));
229    }
230    if req.prompt.len() > 77_000 {
231        return Err(format!(
232            "prompt length ({} bytes) exceeds the 77,000-byte limit",
233            req.prompt.len()
234        ));
235    }
236    if let Some(ref neg) = req.negative_prompt {
237        if neg.len() > 77_000 {
238            return Err(format!(
239                "negative_prompt length ({} bytes) exceeds the 77,000-byte limit",
240                neg.len()
241            ));
242        }
243    }
244    if family == Some("qwen-image-edit") {
245        if req.edit_images.as_ref().is_none_or(Vec::is_empty) {
246            return Err("qwen-image-edit requires edit_images to be provided".to_string());
247        }
248        if req.batch_size != 1 {
249            return Err("qwen-image-edit only supports batch_size = 1".to_string());
250        }
251        if req.source_image.is_some() {
252            return Err("qwen-image-edit uses edit_images instead of source_image".to_string());
253        }
254        if req.mask_image.is_some() {
255            return Err("qwen-image-edit does not support mask_image".to_string());
256        }
257        if req.control_image.is_some() || req.control_model.is_some() {
258            return Err("qwen-image-edit does not support ControlNet inputs".to_string());
259        }
260        if let Some(ref images) = req.edit_images {
261            for image in images {
262                if !is_valid_image_format(image) {
263                    return Err("edit_images must contain only PNG or JPEG images".to_string());
264                }
265            }
266        }
267    } else if req.edit_images.is_some() {
268        return Err("edit_images are only supported for qwen-image-edit models".to_string());
269    }
270    // img2img validation
271    if let Some(ref img) = req.source_image {
272        if req.strength < 0.0 || req.strength > 1.0 {
273            return Err(format!(
274                "strength ({}) must be in range [0.0, 1.0] when source_image is provided",
275                req.strength
276            ));
277        }
278        if !is_valid_image_format(img) {
279            return Err("source_image must be a PNG or JPEG image".to_string());
280        }
281    }
282    // ControlNet validation
283    if let Some(ref ctrl) = req.control_image {
284        if req.control_model.is_none() {
285            return Err("control_image requires control_model to also be provided".to_string());
286        }
287        if !is_valid_image_format(ctrl) {
288            return Err("control_image must be a PNG or JPEG image".to_string());
289        }
290        if req.control_scale < 0.0 {
291            return Err(format!(
292                "control_scale ({}) must be >= 0.0",
293                req.control_scale
294            ));
295        }
296    }
297    if req.control_model.is_some() && req.control_image.is_none() {
298        return Err("control_model requires control_image to also be provided".to_string());
299    }
300    // Inpainting validation
301    if let Some(ref mask) = req.mask_image {
302        if req.source_image.is_none() {
303            return Err("mask_image requires source_image to also be provided".to_string());
304        }
305        if !is_valid_image_format(mask) {
306            return Err("mask_image must be a PNG or JPEG image".to_string());
307        }
308    }
309    // LoRA validation (format checks only — path existence is checked at the
310    // inference layer, since in remote mode the path refers to the server filesystem).
311    if let Some(ref lora) = req.lora {
312        validate_lora_weight(lora, "lora")?;
313    }
314    if let Some(ref loras) = req.loras {
315        if loras.is_empty() {
316            return Err("loras must not be empty when provided".to_string());
317        }
318        for lora in loras {
319            validate_lora_weight(lora, "loras")?;
320        }
321    }
322    // Video frame validation
323    if let Some(frames) = req.frames {
324        if frames == 0 {
325            return Err("frames must be >= 1".to_string());
326        }
327        if matches!(family, Some("ltx-video" | "ltx2")) && frames > 1 && (frames - 1) % 8 != 0 {
328            return Err(format!(
329                "frames ({frames}) must be 8n+1 for current LTX-Video / LTX-2 models (e.g. 9, 17, 25, 33, 41, 49, …)"
330            ));
331        }
332        if frames > 257 {
333            return Err(format!("frames ({frames}) must be <= 257"));
334        }
335    }
336    if let Some(fps) = req.fps {
337        if fps == 0 {
338            return Err("fps must be >= 1".to_string());
339        }
340        if fps > 120 {
341            return Err(format!("fps ({fps}) must be <= 120"));
342        }
343    }
344    if let Some(keyframes) = &req.keyframes {
345        validate_keyframes(keyframes, req.frames, family)?;
346    }
347    if let Some(audio) = &req.audio_file {
348        require_ltx2_family(family, "audio_file")?;
349        if audio.is_empty() {
350            return Err("audio_file must not be empty".to_string());
351        }
352        validate_inline_media_size(audio, "audio_file", MAX_INLINE_AUDIO_BYTES)?;
353    }
354    if let Some(video) = &req.source_video {
355        require_ltx2_family(family, "source_video")?;
356        if video.is_empty() {
357            return Err("source_video must not be empty".to_string());
358        }
359        validate_inline_media_size(video, "source_video", MAX_INLINE_SOURCE_VIDEO_BYTES)?;
360    }
361    if req.enable_audio.is_some() {
362        require_ltx2_family(family, "enable_audio")?;
363    }
364    if req.retake_range.is_some() {
365        require_ltx2_family(family, "retake_range")?;
366    }
367    if req.spatial_upscale.is_some() {
368        require_ltx2_family(family, "spatial_upscale")?;
369    }
370    if req.temporal_upscale.is_some() {
371        require_ltx2_family(family, "temporal_upscale")?;
372    }
373    if req.pipeline.is_some() {
374        require_ltx2_family(family, "pipeline")?;
375    }
376
377    if family == Some("ltx2") {
378        match req.output_format {
379            OutputFormat::Gif | OutputFormat::Apng | OutputFormat::Webp | OutputFormat::Mp4 => {}
380            _ => return Err("LTX-2 outputs must use mp4, gif, apng, or webp".to_string()),
381        }
382
383        if req.enable_audio == Some(true) && req.output_format != OutputFormat::Mp4 {
384            return Err("audio-enabled LTX-2 outputs must use mp4 format".to_string());
385        }
386
387        if req.retake_range.is_some() && req.source_video.is_none() {
388            return Err("retake_range requires source_video to also be provided".to_string());
389        }
390
391        if let Some(range) = &req.retake_range {
392            if !(range.start_seconds.is_finite() && range.end_seconds.is_finite()) {
393                return Err("retake_range values must be finite numbers".to_string());
394            }
395            if range.start_seconds < 0.0 {
396                return Err("retake_range start_seconds must be >= 0.0".to_string());
397            }
398            if range.end_seconds <= range.start_seconds {
399                return Err(
400                    "retake_range end_seconds must be greater than start_seconds".to_string(),
401                );
402            }
403        }
404
405        if let Some(pipeline) = req.pipeline {
406            match pipeline {
407                Ltx2PipelineMode::A2Vid => {
408                    if req.audio_file.is_none() {
409                        return Err("pipeline=a2vid requires audio_file".to_string());
410                    }
411                }
412                Ltx2PipelineMode::Retake => {
413                    if req.source_video.is_none() {
414                        return Err("pipeline=retake requires source_video".to_string());
415                    }
416                    if req.retake_range.is_none() {
417                        return Err("pipeline=retake requires retake_range".to_string());
418                    }
419                }
420                Ltx2PipelineMode::Keyframe => {
421                    let keyframe_count = req.keyframes.as_ref().map_or(0, Vec::len);
422                    if keyframe_count < 2 {
423                        return Err("pipeline=keyframe requires at least 2 keyframes".to_string());
424                    }
425                }
426                Ltx2PipelineMode::IcLora => {
427                    if req.source_video.is_none() {
428                        return Err("pipeline=ic-lora requires source_video".to_string());
429                    }
430                    if req.lora.is_none() && req.loras.as_ref().is_none_or(Vec::is_empty) {
431                        return Err("pipeline=ic-lora requires at least one LoRA".to_string());
432                    }
433                }
434                Ltx2PipelineMode::OneStage
435                | Ltx2PipelineMode::TwoStage
436                | Ltx2PipelineMode::TwoStageHq
437                | Ltx2PipelineMode::Distilled => {}
438            }
439        }
440    }
441
442    Ok(())
443}
444
445/// Validate an upscale request. Returns `Ok(())` if valid, or an error message.
446pub fn validate_upscale_request(req: &UpscaleRequest) -> Result<(), String> {
447    if req.model.trim().is_empty() {
448        return Err("upscale model must not be empty".to_string());
449    }
450    if req.image.is_empty() {
451        return Err("upscale image must not be empty".to_string());
452    }
453    if !is_valid_image_format(&req.image) {
454        return Err("upscale image must be a PNG or JPEG image".to_string());
455    }
456    if let Some(tile_size) = req.tile_size {
457        if tile_size != 0 && tile_size < 64 {
458            return Err(format!(
459                "tile_size ({tile_size}) must be 0 (disabled) or >= 64"
460            ));
461        }
462    }
463    Ok(())
464}
465
466// ── Dimension recommendations ───────────────────────────────────────────────
467
/// Recommended (width, height) pairs for SD1.5 models (native 512x512).
const SD15_DIMS: &[(u32, u32)] = &[
    (512, 512),
    (512, 768),
    (768, 512),
    (384, 512),
    (512, 384),
];

/// Official SDXL training buckets from Stability AI (native 1024x1024).
const SDXL_DIMS: &[(u32, u32)] = &[
    (1024, 1024),
    (1152, 896),
    (896, 1152),
    (1216, 832),
    (832, 1216),
    (1344, 768),
    (768, 1344),
    (1536, 640),
    (640, 1536),
];

/// Recommended dimensions for SD3.5 models (native 1024x1024).
const SD3_DIMS: &[(u32, u32)] = &[
    (1024, 1024),
    (1152, 896),
    (896, 1152),
    (1216, 832),
    (832, 1216),
    (1344, 768),
    (768, 1344),
];

/// Recommended dimensions for FLUX and FLUX.2 models (native 1024x1024).
const FLUX_DIMS: &[(u32, u32)] = &[
    (1024, 1024),
    (1024, 768),
    (768, 1024),
    (1024, 576),
    (576, 1024),
    (768, 768),
];

/// Recommended dimensions for Z-Image models (native 1024x1024).
const ZIMAGE_DIMS: &[(u32, u32)] = &[(1024, 1024), (1024, 768), (768, 1024)];

/// Recommended dimensions for Qwen-Image and Qwen-Image-Edit models
/// (native 1328x1328, ~1.76MP max).
///
/// Any dimensions divisible by 16 within the megapixel budget work (dynamic
/// resolution); these are the standard aspect-ratio buckets.
const QWEN_IMAGE_DIMS: &[(u32, u32)] = &[
    (1328, 1328), // 1:1 (native)
    (1024, 1024), // 1:1
    (1152, 896),  // 9:7
    (896, 1152),  // 7:9
    (1216, 832),  // 19:13
    (832, 1216),  // 13:19
    (1344, 768),  // 7:4
    (768, 1344),  // 4:7
    (1664, 928),  // ~16:9
    (928, 1664),  // ~9:16
    (768, 768),   // 1:1 (small)
    (512, 512),   // 1:1 (small, fast)
];

/// Recommended dimensions for Wuerstchen models (native 1024x1024).
const WUERSTCHEN_DIMS: &[(u32, u32)] = &[(1024, 1024)];

/// Recommended dimensions for LTX Video models (native 768x512).
/// Every entry is divisible by 32, matching LTX Video's patchification.
const LTX_VIDEO_DIMS: &[(u32, u32)] = &[
    (768, 512),  // 3:2 (native)
    (512, 512),  // 1:1
    (1024, 576), // 16:9
    (576, 1024), // 9:16
    (768, 768),  // 1:1
    (512, 768),  // 2:3
];

/// Look up the recommended (width, height) pairs for a model family.
///
/// Families that share an architecture share a table (`flux`/`flux2`,
/// `qwen-image`/`qwen-image-edit`). Unknown families, utility models
/// (e.g. `qwen3-expand`), and conditioning models (e.g. ControlNet) get an
/// empty slice.
pub fn recommended_dimensions(family: &str) -> &'static [(u32, u32)] {
    match family {
        "sd15" => SD15_DIMS,
        "sdxl" => SDXL_DIMS,
        "sd3" => SD3_DIMS,
        "flux" | "flux2" => FLUX_DIMS,
        "z-image" => ZIMAGE_DIMS,
        "qwen-image" | "qwen-image-edit" => QWEN_IMAGE_DIMS,
        "wuerstchen" => WUERSTCHEN_DIMS,
        "ltx-video" => LTX_VIDEO_DIMS,
        _ => &[],
    }
}
559
560/// Check if the requested dimensions match any recommended resolution for the model family.
561///
562/// Returns `None` if the dimensions are recommended or the family has no recommendation list.
563/// Returns `Some(warning_message)` with suggested alternatives otherwise.
564pub fn dimension_warning(width: u32, height: u32, family: &str) -> Option<String> {
565    let dims = recommended_dimensions(family);
566    if dims.is_empty() {
567        return None;
568    }
569    if dims.contains(&(width, height)) {
570        return None;
571    }
572    // Build a compact list of suggested alternatives (show up to 4)
573    let suggestions: Vec<String> = dims
574        .iter()
575        .take(4)
576        .map(|(w, h)| format!("{w}x{h}"))
577        .collect();
578    let more = if dims.len() > 4 {
579        format!(", ... ({} total)", dims.len())
580    } else {
581        String::new()
582    };
583    Some(format!(
584        "{width}x{height} is not a recommended resolution for {family} models. \
585         Suggested: {}{}",
586        suggestions.join(", "),
587        more,
588    ))
589}
590
591#[cfg(test)]
592mod tests {
593    use super::*;
594    use crate::OutputFormat;
595
596    fn valid_req() -> GenerateRequest {
597        GenerateRequest {
598            prompt: "a red apple".to_string(),
599            negative_prompt: None,
600            model: "test-model".to_string(),
601            width: 1024,
602            height: 1024,
603            steps: 4,
604            guidance: 0.0,
605            seed: Some(42),
606            batch_size: 1,
607            output_format: OutputFormat::Png,
608            embed_metadata: None,
609            scheduler: None,
610            source_image: None,
611            edit_images: None,
612            strength: 0.75,
613            mask_image: None,
614            control_image: None,
615            control_model: None,
616            control_scale: 1.0,
617            expand: None,
618            original_prompt: None,
619            lora: None,
620            frames: None,
621            fps: None,
622            upscale_model: None,
623            gif_preview: false,
624            enable_audio: None,
625            audio_file: None,
626            source_video: None,
627            keyframes: None,
628            pipeline: None,
629            loras: None,
630            retake_range: None,
631            spatial_upscale: None,
632            temporal_upscale: None,
633        }
634    }
635
636    /// Minimal valid PNG header bytes for testing.
637    fn png_bytes() -> Vec<u8> {
638        vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]
639    }
640
641    /// Minimal valid JPEG header bytes for testing.
642    fn jpeg_bytes() -> Vec<u8> {
643        vec![0xFF, 0xD8, 0xFF, 0xE0]
644    }
645
646    // ── clamp_to_megapixel_limit tests ──────────────────────────────────────
647
648    #[test]
649    fn clamp_noop_within_limit() {
650        assert_eq!(super::clamp_to_megapixel_limit(1024, 1024), (1024, 1024));
651    }
652
653    #[test]
654    fn clamp_noop_qwen_image_native_resolution() {
655        // Qwen-Image trains at 1328x1328 (~1.76MP), must fit within MAX_PIXELS
656        assert_eq!(super::clamp_to_megapixel_limit(1328, 1328), (1328, 1328));
657    }
658
659    #[test]
660    fn clamp_noop_qwen_image_landscape() {
661        // Qwen-Image 16:9 training resolution (1664x928 = ~1.54MP)
662        assert_eq!(super::clamp_to_megapixel_limit(1664, 928), (1664, 928));
663    }
664
665    #[test]
666    fn clamp_downscales_oversized() {
667        let (w, h) = super::clamp_to_megapixel_limit(1888, 1168);
668        assert!(w % 16 == 0 && h % 16 == 0, "must be multiples of 16");
669        let pixels = w as u64 * h as u64;
670        assert!(
671            pixels <= super::MAX_PIXELS,
672            "must be within limit: {pixels}"
673        );
674        // Aspect ratio roughly preserved
675        let orig_ratio = 1888.0 / 1168.0;
676        let new_ratio = w as f64 / h as f64;
677        assert!(
678            (orig_ratio - new_ratio).abs() < 0.05,
679            "aspect ratio drift too large"
680        );
681    }
682
683    #[test]
684    fn clamp_large_square() {
685        let (w, h) = super::clamp_to_megapixel_limit(2048, 2048);
686        assert!(w % 16 == 0 && h % 16 == 0);
687        assert!(w as u64 * h as u64 <= super::MAX_PIXELS);
688    }
689
690    #[test]
691    fn clamp_extreme_aspect_ratio() {
692        let (w, h) = super::clamp_to_megapixel_limit(4096, 256);
693        assert!(w % 16 == 0 && h % 16 == 0);
694        assert!(w as u64 * h as u64 <= super::MAX_PIXELS);
695        assert!(w > h, "should remain landscape");
696    }
697
698    // ── validate_generate_request tests ──────────────────────────────────────
699
700    #[test]
701    fn valid_request_passes() {
702        assert!(validate_generate_request(&valid_req()).is_ok());
703    }
704
705    #[test]
706    fn ltx2_audio_requires_mp4() {
707        let mut req = valid_req();
708        req.model = "ltx-2-19b-distilled:fp8".to_string();
709        req.output_format = OutputFormat::Gif;
710        req.enable_audio = Some(true);
711        assert!(validate_generate_request(&req).unwrap_err().contains("mp4"));
712    }
713
714    #[test]
715    fn ltx2_retake_requires_source_video() {
716        let mut req = valid_req();
717        req.model = "ltx-2-19b-distilled:fp8".to_string();
718        req.output_format = OutputFormat::Mp4;
719        req.retake_range = Some(crate::TimeRange {
720            start_seconds: 0.0,
721            end_seconds: 1.0,
722        });
723        assert!(validate_generate_request(&req)
724            .unwrap_err()
725            .contains("source_video"));
726    }
727
728    #[test]
729    fn ltx2_audio_file_rejects_inline_payloads_above_limit() {
730        let mut req = valid_req();
731        req.model = "ltx-2-19b-distilled:fp8".to_string();
732        req.output_format = OutputFormat::Mp4;
733        req.audio_file = Some(vec![0; MAX_INLINE_AUDIO_BYTES + 1]);
734        let err = validate_generate_request(&req).unwrap_err();
735        assert!(err.contains("audio_file exceeds"), "got: {err}");
736        assert!(err.contains("64 MiB"), "got: {err}");
737    }
738
739    #[test]
740    fn ltx2_source_video_rejects_inline_payloads_above_limit() {
741        let mut req = valid_req();
742        req.model = "ltx-2-19b-distilled:fp8".to_string();
743        req.output_format = OutputFormat::Mp4;
744        req.source_video = Some(vec![0; MAX_INLINE_SOURCE_VIDEO_BYTES + 1]);
745        let err = validate_generate_request(&req).unwrap_err();
746        assert!(err.contains("source_video exceeds"), "got: {err}");
747        assert!(err.contains("64 MiB"), "got: {err}");
748    }
749
750    #[test]
751    fn ltx2_keyframe_pipeline_requires_multiple_keyframes() {
752        let mut req = valid_req();
753        req.model = "ltx-2-19b-distilled:fp8".to_string();
754        req.output_format = OutputFormat::Mp4;
755        req.pipeline = Some(crate::Ltx2PipelineMode::Keyframe);
756        req.frames = Some(17);
757        req.keyframes = Some(vec![crate::KeyframeCondition {
758            frame: 0,
759            image: png_bytes(),
760        }]);
761        assert!(validate_generate_request(&req)
762            .unwrap_err()
763            .contains("at least 2 keyframes"));
764    }
765
766    #[test]
767    fn keyframes_on_unknown_family_report_unknown_model_family() {
768        let mut req = valid_req();
769        req.model = "private-ltx2-style-model".to_string();
770        req.frames = Some(17);
771        req.keyframes = Some(vec![
772            crate::KeyframeCondition {
773                frame: 0,
774                image: png_bytes(),
775            },
776            crate::KeyframeCondition {
777                frame: 16,
778                image: png_bytes(),
779            },
780        ]);
781        let err = validate_generate_request(&req).unwrap_err();
782        assert!(err.contains("unknown model family"), "got: {err}");
783    }
784
785    #[test]
786    fn ltx2_allows_temporal_upscale_request() {
787        let mut req = valid_req();
788        req.model = "ltx-2-19b-distilled:fp8".to_string();
789        req.output_format = OutputFormat::Mp4;
790        req.temporal_upscale = Some(crate::Ltx2TemporalUpscale::X2);
791        validate_generate_request(&req).unwrap();
792    }
793
794    #[test]
795    fn ltx2_allows_x1_5_spatial_upscale_request() {
796        let mut req = valid_req();
797        req.model = "ltx-2.3-22b-distilled:fp8".to_string();
798        req.output_format = OutputFormat::Mp4;
799        req.spatial_upscale = Some(crate::Ltx2SpatialUpscale::X1_5);
800        validate_generate_request(&req).unwrap();
801    }
802
803    #[test]
804    fn empty_prompt_rejected() {
805        let mut req = valid_req();
806        req.prompt = "   ".to_string();
807        assert!(validate_generate_request(&req)
808            .unwrap_err()
809            .contains("prompt"));
810    }
811
812    #[test]
813    fn zero_dimensions_rejected() {
814        let mut req = valid_req();
815        req.width = 0;
816        assert!(validate_generate_request(&req).is_err());
817        req.width = 1024;
818        req.height = 0;
819        assert!(validate_generate_request(&req).is_err());
820    }
821
822    #[test]
823    fn dimensions_must_be_multiple_of_16() {
824        let mut req = valid_req();
825        req.width = 513; // not multiple of 16
826        assert!(validate_generate_request(&req)
827            .unwrap_err()
828            .contains("multiples of 16"));
829    }
830
831    #[test]
832    fn valid_non_square_dimensions() {
833        let mut req = valid_req();
834        req.width = 512;
835        req.height = 768;
836        assert!(validate_generate_request(&req).is_ok());
837    }
838
839    #[test]
840    fn oversized_image_rejected() {
841        let mut req = valid_req();
842        req.width = 1408;
843        req.height = 1408; // ~1.98MP > 1.8MP limit
844        assert!(validate_generate_request(&req)
845            .unwrap_err()
846            .contains("megapixels"));
847    }
848
849    #[test]
850    fn oversized_image_error_reports_current_megapixel_limit() {
851        let mut req = valid_req();
852        req.width = 1408;
853        req.height = 1408;
854        let err = validate_generate_request(&req).unwrap_err();
855        assert!(err.contains("1.8MP"), "got: {err}");
856    }
857
858    #[test]
859    fn zero_steps_rejected() {
860        let mut req = valid_req();
861        req.steps = 0;
862        assert!(validate_generate_request(&req).is_err());
863    }
864
865    #[test]
866    fn excessive_steps_rejected() {
867        let mut req = valid_req();
868        req.steps = 101;
869        assert!(validate_generate_request(&req).is_err());
870    }
871
872    #[test]
873    fn valid_step_counts() {
874        for steps in [1, 4, 20, 28, 50, 100] {
875            let mut req = valid_req();
876            req.steps = steps;
877            assert!(
878                validate_generate_request(&req).is_ok(),
879                "steps={steps} should be valid"
880            );
881        }
882    }
883
884    #[test]
885    fn ltx2_frames_must_still_follow_8n_plus_1() {
886        let mut req = valid_req();
887        req.model = "ltx-2-19b-distilled:fp8".to_string();
888        req.output_format = OutputFormat::Mp4;
889        req.frames = Some(10);
890        let err = validate_generate_request(&req).unwrap_err();
891        assert!(err.contains("8n+1"), "got: {err}");
892        assert!(err.contains("LTX-Video / LTX-2"), "got: {err}");
893    }
894
895    #[test]
896    fn non_ltx_models_do_not_apply_the_ltx_frame_grid_rule() {
897        let mut req = valid_req();
898        req.frames = Some(10);
899        assert!(validate_generate_request(&req).is_ok());
900    }
901
902    #[test]
903    fn zero_batch_rejected() {
904        let mut req = valid_req();
905        req.batch_size = 0;
906        assert!(validate_generate_request(&req).is_err());
907    }
908
909    #[test]
910    fn large_batch_accepted() {
911        let mut req = valid_req();
912        req.batch_size = 100;
913        assert!(validate_generate_request(&req).is_ok());
914    }
915
916    #[test]
917    fn negative_guidance_rejected() {
918        let mut req = valid_req();
919        req.guidance = -1.0;
920        assert!(validate_generate_request(&req).is_err());
921    }
922
923    #[test]
924    fn zero_guidance_valid() {
925        let mut req = valid_req();
926        req.guidance = 0.0;
927        assert!(validate_generate_request(&req).is_ok());
928    }
929
930    #[test]
931    fn high_guidance_valid() {
932        let mut req = valid_req();
933        req.guidance = 20.0;
934        assert!(validate_generate_request(&req).is_ok());
935    }
936
937    #[test]
938    fn guidance_over_100_rejected() {
939        let mut req = valid_req();
940        req.guidance = 100.1;
941        assert!(validate_generate_request(&req)
942            .unwrap_err()
943            .contains("guidance"));
944    }
945
946    #[test]
947    fn guidance_at_100_valid() {
948        let mut req = valid_req();
949        req.guidance = 100.0;
950        assert!(validate_generate_request(&req).is_ok());
951    }
952
953    #[test]
954    fn prompt_too_long_rejected() {
955        let mut req = valid_req();
956        req.prompt = "x".repeat(77_001);
957        assert!(validate_generate_request(&req)
958            .unwrap_err()
959            .contains("77,000"));
960    }
961
962    #[test]
963    fn prompt_at_limit_valid() {
964        let mut req = valid_req();
965        req.prompt = "x".repeat(77_000);
966        assert!(validate_generate_request(&req).is_ok());
967    }
968
969    #[test]
970    fn negative_prompt_too_long_rejected() {
971        let mut req = valid_req();
972        req.negative_prompt = Some("x".repeat(77_001));
973        assert!(validate_generate_request(&req)
974            .unwrap_err()
975            .contains("negative_prompt"));
976    }
977
978    #[test]
979    fn negative_prompt_at_limit_valid() {
980        let mut req = valid_req();
981        req.negative_prompt = Some("x".repeat(77_000));
982        assert!(validate_generate_request(&req).is_ok());
983    }
984
985    #[test]
986    fn negative_prompt_none_valid() {
987        let req = valid_req();
988        assert!(req.negative_prompt.is_none());
989        assert!(validate_generate_request(&req).is_ok());
990    }
991
992    #[test]
993    fn negative_prompt_empty_valid() {
994        let mut req = valid_req();
995        req.negative_prompt = Some(String::new());
996        assert!(validate_generate_request(&req).is_ok());
997    }
998
999    #[test]
1000    fn seed_is_optional() {
1001        let mut req = valid_req();
1002        req.seed = None;
1003        assert!(validate_generate_request(&req).is_ok());
1004    }
1005
1006    // ── img2img validation tests ────────────────────────────────────────────
1007
1008    #[test]
1009    fn img2img_strength_zero_accepted() {
1010        let mut req = valid_req();
1011        req.source_image = Some(png_bytes());
1012        req.strength = 0.0;
1013        assert!(validate_generate_request(&req).is_ok());
1014    }
1015
1016    #[test]
1017    fn img2img_strength_negative_rejected() {
1018        let mut req = valid_req();
1019        req.source_image = Some(png_bytes());
1020        req.strength = -0.1;
1021        assert!(validate_generate_request(&req)
1022            .unwrap_err()
1023            .contains("strength"));
1024    }
1025
1026    #[test]
1027    fn img2img_strength_one_accepted() {
1028        let mut req = valid_req();
1029        req.source_image = Some(png_bytes());
1030        req.strength = 1.0;
1031        assert!(validate_generate_request(&req).is_ok());
1032    }
1033
1034    #[test]
1035    fn img2img_strength_half_accepted() {
1036        let mut req = valid_req();
1037        req.source_image = Some(png_bytes());
1038        req.strength = 0.5;
1039        assert!(validate_generate_request(&req).is_ok());
1040    }
1041
1042    #[test]
1043    fn img2img_invalid_magic_bytes_rejected() {
1044        let mut req = valid_req();
1045        req.source_image = Some(vec![0x00, 0x01, 0x02, 0x03]);
1046        req.strength = 0.75;
1047        assert!(validate_generate_request(&req)
1048            .unwrap_err()
1049            .contains("PNG or JPEG"));
1050    }
1051
1052    #[test]
1053    fn img2img_jpeg_accepted() {
1054        let mut req = valid_req();
1055        req.source_image = Some(jpeg_bytes());
1056        req.strength = 0.75;
1057        assert!(validate_generate_request(&req).is_ok());
1058    }
1059
1060    #[test]
1061    fn img2img_no_source_image_skips_strength_check() {
1062        let mut req = valid_req();
1063        req.source_image = None;
1064        req.strength = 0.0; // Would fail if source_image present, but should pass without
1065        assert!(validate_generate_request(&req).is_ok());
1066    }
1067
1068    #[test]
1069    fn qwen_image_edit_requires_edit_images() {
1070        let mut req = valid_req();
1071        req.model = "qwen-image-edit:q4".to_string();
1072        let err = validate_generate_request(&req).unwrap_err();
1073        assert!(err.contains("requires edit_images"), "got: {err}");
1074    }
1075
1076    #[test]
1077    fn qwen_image_edit_rejects_batch_size_above_one() {
1078        let mut req = valid_req();
1079        req.model = "qwen-image-edit:q4".to_string();
1080        req.edit_images = Some(vec![png_bytes()]);
1081        req.batch_size = 2;
1082        let err = validate_generate_request(&req).unwrap_err();
1083        assert!(err.contains("batch_size = 1"), "got: {err}");
1084    }
1085
1086    #[test]
1087    fn qwen_image_edit_accepts_edit_images() {
1088        let mut req = valid_req();
1089        req.model = "qwen-image-edit:q4".to_string();
1090        req.edit_images = Some(vec![png_bytes()]);
1091        req.guidance = 4.0;
1092        assert!(validate_generate_request(&req).is_ok());
1093    }
1094
1095    #[test]
1096    fn qwen_image_edit_rejects_source_image_field() {
1097        let mut req = valid_req();
1098        req.model = "qwen-image-edit:q4".to_string();
1099        req.edit_images = Some(vec![png_bytes()]);
1100        req.source_image = Some(png_bytes());
1101        let err = validate_generate_request(&req).unwrap_err();
1102        assert!(
1103            err.contains("edit_images instead of source_image"),
1104            "got: {err}"
1105        );
1106    }
1107
1108    #[test]
1109    fn non_edit_models_reject_edit_images() {
1110        let mut req = valid_req();
1111        req.model = "flux-schnell:q8".to_string();
1112        req.edit_images = Some(vec![png_bytes()]);
1113        let err = validate_generate_request(&req).unwrap_err();
1114        assert!(
1115            err.contains("only supported for qwen-image-edit"),
1116            "got: {err}"
1117        );
1118    }
1119
1120    #[test]
1121    fn non_edit_models_reject_edit_images_before_format_validation() {
1122        let mut req = valid_req();
1123        req.model = "flux-schnell:q8".to_string();
1124        req.edit_images = Some(vec![b"not-an-image".to_vec()]);
1125        let err = validate_generate_request(&req).unwrap_err();
1126        assert!(
1127            err.contains("only supported for qwen-image-edit"),
1128            "got: {err}"
1129        );
1130    }
1131
1132    // ── ControlNet validation tests ────────────────────────────────────────
1133
1134    #[test]
1135    fn controlnet_valid_request() {
1136        let mut req = valid_req();
1137        req.control_image = Some(png_bytes());
1138        req.control_model = Some("controlnet-canny-sd15".to_string());
1139        req.control_scale = 0.8;
1140        assert!(validate_generate_request(&req).is_ok());
1141    }
1142
1143    #[test]
1144    fn controlnet_image_without_model_rejected() {
1145        let mut req = valid_req();
1146        req.control_image = Some(png_bytes());
1147        req.control_model = None;
1148        assert!(validate_generate_request(&req)
1149            .unwrap_err()
1150            .contains("control_model"));
1151    }
1152
1153    #[test]
1154    fn controlnet_model_without_image_rejected() {
1155        let mut req = valid_req();
1156        req.control_image = None;
1157        req.control_model = Some("controlnet-canny-sd15".to_string());
1158        assert!(validate_generate_request(&req)
1159            .unwrap_err()
1160            .contains("control_image"));
1161    }
1162
1163    #[test]
1164    fn controlnet_invalid_image_rejected() {
1165        let mut req = valid_req();
1166        req.control_image = Some(vec![0x00, 0x01, 0x02, 0x03]);
1167        req.control_model = Some("controlnet-canny-sd15".to_string());
1168        assert!(validate_generate_request(&req)
1169            .unwrap_err()
1170            .contains("PNG or JPEG"));
1171    }
1172
1173    #[test]
1174    fn controlnet_negative_scale_rejected() {
1175        let mut req = valid_req();
1176        req.control_image = Some(png_bytes());
1177        req.control_model = Some("controlnet-canny-sd15".to_string());
1178        req.control_scale = -0.1;
1179        assert!(validate_generate_request(&req)
1180            .unwrap_err()
1181            .contains("control_scale"));
1182    }
1183
1184    #[test]
1185    fn controlnet_zero_scale_accepted() {
1186        let mut req = valid_req();
1187        req.control_image = Some(png_bytes());
1188        req.control_model = Some("controlnet-canny-sd15".to_string());
1189        req.control_scale = 0.0;
1190        assert!(validate_generate_request(&req).is_ok());
1191    }
1192
1193    #[test]
1194    fn controlnet_high_scale_accepted() {
1195        let mut req = valid_req();
1196        req.control_image = Some(png_bytes());
1197        req.control_model = Some("controlnet-canny-sd15".to_string());
1198        req.control_scale = 2.0;
1199        assert!(validate_generate_request(&req).is_ok());
1200    }
1201
1202    #[test]
1203    fn controlnet_jpeg_accepted() {
1204        let mut req = valid_req();
1205        req.control_image = Some(jpeg_bytes());
1206        req.control_model = Some("controlnet-canny-sd15".to_string());
1207        assert!(validate_generate_request(&req).is_ok());
1208    }
1209    // ── Inpainting validation tests ───────────────────────────────────────
1210
1211    #[test]
1212    fn mask_without_source_image_rejected() {
1213        let mut req = valid_req();
1214        req.mask_image = Some(png_bytes());
1215        assert!(validate_generate_request(&req)
1216            .unwrap_err()
1217            .contains("mask_image requires source_image"));
1218    }
1219
1220    #[test]
1221    fn mask_with_source_image_accepted() {
1222        let mut req = valid_req();
1223        req.source_image = Some(png_bytes());
1224        req.mask_image = Some(png_bytes());
1225        assert!(validate_generate_request(&req).is_ok());
1226    }
1227
1228    #[test]
1229    fn mask_jpeg_accepted() {
1230        let mut req = valid_req();
1231        req.source_image = Some(png_bytes());
1232        req.mask_image = Some(jpeg_bytes());
1233        assert!(validate_generate_request(&req).is_ok());
1234    }
1235
1236    #[test]
1237    fn mask_invalid_bytes_rejected() {
1238        let mut req = valid_req();
1239        req.source_image = Some(png_bytes());
1240        req.mask_image = Some(vec![0x00, 0x01, 0x02, 0x03]);
1241        assert!(validate_generate_request(&req)
1242            .unwrap_err()
1243            .contains("mask_image must be a PNG or JPEG"));
1244    }
1245
1246    #[test]
1247    fn no_mask_no_source_passes() {
1248        let req = valid_req();
1249        assert!(validate_generate_request(&req).is_ok());
1250    }
1251
1252    // ── fit_to_model_dimensions tests ────────────────────────────────────
1253
1254    #[test]
1255    fn fit_same_aspect_downscale() {
1256        // 1024x1024 source -> 512x512 SD1.5 model
1257        assert_eq!(fit_to_model_dimensions(1024, 1024, 512, 512), (512, 512));
1258    }
1259
1260    #[test]
1261    fn fit_wide_source_downscale() {
1262        // 1920x1080 source -> 512x512 SD1.5 model
1263        // width-limited: w=512, h=512/1.778=287.9 -> 288 (16px aligned)
1264        assert_eq!(fit_to_model_dimensions(1920, 1080, 512, 512), (512, 288));
1265    }
1266
1267    #[test]
1268    fn fit_small_source_upscale_to_model_native() {
1269        // 512x512 source -> 1024x1024 FLUX model (upscale to native)
1270        assert_eq!(fit_to_model_dimensions(512, 512, 1024, 1024), (1024, 1024));
1271    }
1272
1273    #[test]
1274    fn fit_portrait_source() {
1275        // 768x1024 source -> 512x512 model
1276        // height-limited: h=512, w=512*0.75=384
1277        assert_eq!(fit_to_model_dimensions(768, 1024, 512, 512), (384, 512));
1278    }
1279
1280    #[test]
1281    fn fit_identity() {
1282        assert_eq!(
1283            fit_to_model_dimensions(1024, 1024, 1024, 1024),
1284            (1024, 1024)
1285        );
1286    }
1287
1288    #[test]
1289    fn fit_extreme_landscape() {
1290        // 3840x720 -> 1024x1024 model
1291        // width-limited: w=1024, h=1024/5.333=192
1292        assert_eq!(fit_to_model_dimensions(3840, 720, 1024, 1024), (1024, 192));
1293    }
1294
1295    #[test]
1296    fn fit_non_square_model_bounds() {
1297        // 1920x1080 -> 1024x768 model
1298        // src_ratio=1.778, model_ratio=1.333, width-limited: w=1024, h=1024/1.778=575.8 -> 576
1299        assert_eq!(fit_to_model_dimensions(1920, 1080, 1024, 768), (1024, 576));
1300    }
1301
1302    #[test]
1303    fn fit_dimensions_are_16px_aligned() {
1304        let (w, h) = fit_to_model_dimensions(1000, 600, 512, 512);
1305        assert!(w % 16 == 0, "width {w} must be 16px aligned");
1306        assert!(h % 16 == 0, "height {h} must be 16px aligned");
1307    }
1308
1309    #[test]
1310    fn fit_within_megapixel_limit() {
1311        let (w, h) = fit_to_model_dimensions(4096, 4096, 2048, 2048);
1312        let pixels = w as u64 * h as u64;
1313        assert!(
1314            pixels <= MAX_PIXELS,
1315            "{}x{} = {} pixels exceeds limit",
1316            w,
1317            h,
1318            pixels
1319        );
1320    }
1321
1322    #[test]
1323    fn fit_tiny_source_gets_model_native() {
1324        // 64x64 source -> 1024x1024 model
1325        assert_eq!(fit_to_model_dimensions(64, 64, 1024, 1024), (1024, 1024));
1326    }
1327
1328    #[test]
1329    fn fit_to_target_area_preserves_ratio_and_alignment() {
1330        let (w, h) = fit_to_target_area(1600, 900, 1024 * 1024, 16);
1331        assert_eq!((w, h), (1360, 768));
1332    }
1333
1334    // ── LoRA validation tests ──────────────────────────────────────────────
1335
1336    #[test]
1337    fn lora_none_valid() {
1338        let req = valid_req();
1339        assert!(req.lora.is_none());
1340        assert!(validate_generate_request(&req).is_ok());
1341    }
1342
1343    #[test]
1344    fn lora_scale_too_low_rejected() {
1345        let mut req = valid_req();
1346        req.lora = Some(crate::LoraWeight {
1347            path: "adapter.safetensors".to_string(),
1348            scale: -0.1,
1349        });
1350        let err = validate_generate_request(&req).unwrap_err();
1351        assert!(
1352            err.contains("lora scale"),
1353            "expected lora scale error: {err}"
1354        );
1355    }
1356
1357    #[test]
1358    fn lora_scale_too_high_rejected() {
1359        let mut req = valid_req();
1360        req.lora = Some(crate::LoraWeight {
1361            path: "adapter.safetensors".to_string(),
1362            scale: 2.1,
1363        });
1364        let err = validate_generate_request(&req).unwrap_err();
1365        assert!(
1366            err.contains("lora scale"),
1367            "expected lora scale error: {err}"
1368        );
1369    }
1370
1371    #[test]
1372    fn lora_scale_boundary_valid() {
1373        for scale in [0.0, 1.0, 2.0] {
1374            let mut req = valid_req();
1375            req.lora = Some(crate::LoraWeight {
1376                path: "adapter.safetensors".to_string(),
1377                scale,
1378            });
1379            assert!(
1380                validate_generate_request(&req).is_ok(),
1381                "scale={scale} should be valid"
1382            );
1383        }
1384    }
1385
1386    #[test]
1387    fn lora_path_not_found_passes_validation() {
1388        // Path existence is checked at the inference layer, not validation,
1389        // so remote LoRA paths (server-side files) work correctly.
1390        let mut req = valid_req();
1391        req.lora = Some(crate::LoraWeight {
1392            path: "/nonexistent/path/adapter.safetensors".to_string(),
1393            scale: 1.0,
1394        });
1395        assert!(validate_generate_request(&req).is_ok());
1396    }
1397
1398    #[test]
1399    fn lora_wrong_extension_rejected() {
1400        let mut req = valid_req();
1401        req.lora = Some(crate::LoraWeight {
1402            path: "/some/path/adapter.bin".to_string(),
1403            scale: 1.0,
1404        });
1405        let err = validate_generate_request(&req).unwrap_err();
1406        assert!(
1407            err.contains("safetensors"),
1408            "expected safetensors error: {err}"
1409        );
1410    }
1411
1412    // ── dimension_warning tests ────────────────────────────────────────────
1413
1414    #[test]
1415    fn dimension_warning_matching_returns_none() {
1416        assert!(dimension_warning(1024, 1024, "flux").is_none());
1417        assert!(dimension_warning(512, 512, "sd15").is_none());
1418        assert!(dimension_warning(1024, 1024, "sdxl").is_none());
1419        assert!(dimension_warning(1024, 1024, "wuerstchen").is_none());
1420    }
1421
1422    #[test]
1423    fn dimension_warning_non_matching_returns_some() {
1424        let warning = dimension_warning(256, 256, "flux");
1425        assert!(warning.is_some());
1426        let msg = warning.unwrap();
1427        assert!(msg.contains("256x256"), "should mention requested dims");
1428        assert!(msg.contains("flux"), "should mention model family");
1429        assert!(msg.contains("Suggested"), "should include suggestions");
1430    }
1431
1432    #[test]
1433    fn dimension_warning_unknown_family_returns_none() {
1434        assert!(dimension_warning(256, 256, "unknown-model").is_none());
1435    }
1436
1437    #[test]
1438    fn dimension_warning_empty_family_returns_none() {
1439        assert!(dimension_warning(512, 512, "").is_none());
1440    }
1441
1442    #[test]
1443    fn dimension_warning_sd15_at_1024_warns() {
1444        let warning = dimension_warning(1024, 1024, "sd15");
1445        assert!(warning.is_some(), "SD1.5 at 1024x1024 should warn");
1446        assert!(warning.unwrap().contains("512x512"));
1447    }
1448
1449    #[test]
1450    fn dimension_warning_sdxl_buckets_accepted() {
1451        for (w, h) in recommended_dimensions("sdxl") {
1452            assert!(
1453                dimension_warning(*w, *h, "sdxl").is_none(),
1454                "SDXL bucket {w}x{h} should not warn"
1455            );
1456        }
1457    }
1458
1459    #[test]
1460    fn dimension_warning_qwen_image_has_native_resolution() {
1461        let dims = recommended_dimensions("qwen-image");
1462        assert!(
1463            dims.contains(&(1328, 1328)),
1464            "must include native 1328x1328"
1465        );
1466        assert!(dims.contains(&(512, 512)), "must include 512x512");
1467        assert!(dims.contains(&(1024, 1024)), "must include 1024x1024");
1468        assert_eq!(dimension_warning(1328, 1328, "qwen-image"), None);
1469        assert_eq!(dimension_warning(512, 512, "qwen-image"), None);
1470    }
1471
1472    #[test]
1473    fn dimension_warning_qwen_image_edit_reuses_qwen_dimensions() {
1474        assert_eq!(
1475            recommended_dimensions("qwen-image-edit"),
1476            recommended_dimensions("qwen-image")
1477        );
1478        assert_eq!(dimension_warning(1024, 1024, "qwen-image-edit"), None);
1479    }
1480
1481    #[test]
1482    fn dimension_warning_flux2_uses_flux_dims() {
1483        assert_eq!(
1484            recommended_dimensions("flux2"),
1485            recommended_dimensions("flux"),
1486            "flux2 should share FLUX dimensions"
1487        );
1488    }
1489
1490    #[test]
1491    fn every_family_native_in_recommendations() {
1492        // Each family's native resolution (from ManifestDefaults) should appear
1493        // in its recommended list.
1494        let families = &[
1495            ("sd15", 512, 512),
1496            ("sdxl", 1024, 1024),
1497            ("sd3", 1024, 1024),
1498            ("flux", 1024, 1024),
1499            ("flux2", 1024, 1024),
1500            ("z-image", 1024, 1024),
1501            ("qwen-image", 1024, 1024),
1502            ("qwen-image-edit", 1024, 1024),
1503            ("wuerstchen", 1024, 1024),
1504            ("ltx-video", 768, 512),
1505        ];
1506        for (family, w, h) in families {
1507            let dims = recommended_dimensions(family);
1508            assert!(
1509                dims.contains(&(*w, *h)),
1510                "{family} native {w}x{h} missing from recommended list"
1511            );
1512        }
1513    }
1514
1515    #[test]
1516    fn dimension_warning_message_format() {
1517        let msg = dimension_warning(800, 600, "sd15").unwrap();
1518        assert!(msg.contains("800x600"));
1519        assert!(msg.contains("sd15"));
1520        assert!(msg.contains("Suggested:"));
1521        // Should list known alternatives
1522        assert!(msg.contains("512x512"));
1523    }
1524
1525    #[test]
1526    fn dimension_warning_truncates_long_lists() {
1527        // SDXL has 9 buckets but warning should show at most 4 + "N total"
1528        let msg = dimension_warning(800, 600, "sdxl").unwrap();
1529        assert!(msg.contains("total"), "long lists should show total count");
1530    }
1531
1532    // ── validate_upscale_request tests ────────────────────────────────────
1533
1534    fn valid_upscale_req() -> crate::UpscaleRequest {
1535        crate::UpscaleRequest {
1536            model: "real-esrgan-x4plus:fp16".to_string(),
1537            image: png_bytes(),
1538            output_format: crate::OutputFormat::Png,
1539            tile_size: None,
1540        }
1541    }
1542
1543    #[test]
1544    fn upscale_valid_request_passes() {
1545        assert!(validate_upscale_request(&valid_upscale_req()).is_ok());
1546    }
1547
1548    #[test]
1549    fn upscale_empty_model_rejected() {
1550        let mut req = valid_upscale_req();
1551        req.model = "  ".to_string();
1552        assert!(validate_upscale_request(&req)
1553            .unwrap_err()
1554            .contains("model"));
1555    }
1556
1557    #[test]
1558    fn upscale_empty_image_rejected() {
1559        let mut req = valid_upscale_req();
1560        req.image = vec![];
1561        assert!(validate_upscale_request(&req)
1562            .unwrap_err()
1563            .contains("empty"));
1564    }
1565
1566    #[test]
1567    fn upscale_invalid_image_format_rejected() {
1568        let mut req = valid_upscale_req();
1569        req.image = vec![0x00, 0x01, 0x02, 0x03];
1570        assert!(validate_upscale_request(&req)
1571            .unwrap_err()
1572            .contains("PNG or JPEG"));
1573    }
1574
1575    #[test]
1576    fn upscale_jpeg_accepted() {
1577        let mut req = valid_upscale_req();
1578        req.image = jpeg_bytes();
1579        assert!(validate_upscale_request(&req).is_ok());
1580    }
1581
1582    #[test]
1583    fn upscale_tile_size_too_small_rejected() {
1584        let mut req = valid_upscale_req();
1585        req.tile_size = Some(32);
1586        assert!(validate_upscale_request(&req)
1587            .unwrap_err()
1588            .contains("tile_size"));
1589    }
1590
1591    #[test]
1592    fn upscale_tile_size_zero_accepted() {
1593        let mut req = valid_upscale_req();
1594        req.tile_size = Some(0);
1595        assert!(validate_upscale_request(&req).is_ok());
1596    }
1597
1598    #[test]
1599    fn upscale_tile_size_64_accepted() {
1600        let mut req = valid_upscale_req();
1601        req.tile_size = Some(64);
1602        assert!(validate_upscale_request(&req).is_ok());
1603    }
1604
1605    #[test]
1606    fn upscale_tile_size_none_accepted() {
1607        let req = valid_upscale_req();
1608        assert!(validate_upscale_request(&req).is_ok());
1609    }
1610}