# deoldify 0.1.0
#
# A Rust tool that colorizes grayscale and black-and-white images using the DeOldify model and the ONNX Runtime.
# Documentation
#!/usr/bin/env python3
"""Convert manually downloaded DeOldify pretrained weights to ONNX format.

This script must be run from within a clone of https://github.com/jantic/DeOldify
with the model weights (.pth files) placed in the ./models/ directory.

Setup:
    git clone https://github.com/jantic/DeOldify.git
    cd DeOldify
    pip install -r requirements.txt
    # Download weights into DeOldify/models/:
    #   - ColorizeArtistic_gen.pth from https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth
    #   - ColorizeStable_gen.pth (see DeOldify README for link)

Usage:
    python convert_model.py --model artistic --output ColorizeArtistic.onnx
    python convert_model.py --model stable --output ColorizeStable.onnx
"""

import argparse
import sys
from pathlib import Path

import torch
import torch.nn as nn


class ImageScaleInput(nn.Module):
    """Rescale 0-255 pixel values to [0, 1], then apply ImageNet mean/std normalization."""

    def forward(self, x):
        # Build the per-channel statistics on the same device as the input.
        device = x.device
        channel_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
        channel_std = torch.tensor([0.229, 0.224, 0.225], device=device)
        unit_range = x / 255.0
        # Reshape to (1, 3, 1, 1) so the statistics broadcast over batch and spatial dims.
        return (unit_range - channel_mean.view(1, 3, 1, 1)) / channel_std.view(1, 3, 1, 1)


class ImageScaleOutput(nn.Module):
    """Undo ImageNet mean/std normalization and return values clamped to 0-255."""

    def forward(self, x):
        # Build the per-channel statistics on the same device as the input.
        device = x.device
        channel_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        channel_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        # Invert the normalization, then clip to [0, 1] before scaling up.
        unit_range = torch.clamp(x * channel_std + channel_mean, 0, 1)
        return unit_range * 255.0


def convert(model_type: str, output_path: str):
    """Export a DeOldify generator, wrapped with pixel scaling, to an ONNX file.

    The exported graph is ``ImageScaleInput -> generator -> ImageScaleOutput``,
    so it takes and returns 0-255 pixel values. Must be run with the DeOldify
    repository root as the current working directory.

    Args:
        model_type: ``"artistic"`` or ``"stable"``.
        output_path: Destination path for the ``.onnx`` file. Missing parent
            directories are created.

    Raises:
        SystemExit: With exit code 1 if the working directory is not a
            DeOldify checkout, ``model_type`` is unknown, or the expected
            weights file is missing from ``./models/``.
    """
    # Ensure we're inside the DeOldify repo
    if not Path("deoldify").is_dir():
        print(
            "Error: This script must be run from the DeOldify repository root.",
            file=sys.stderr,
        )
        print("  git clone https://github.com/jantic/DeOldify.git", file=sys.stderr)
        print("  cd DeOldify", file=sys.stderr)
        sys.exit(1)

    # Validate the model type up front so a bad value fails fast, before the
    # heavy deoldify import and any weight loading.
    weights_by_type = {
        "artistic": "ColorizeArtistic_gen",
        "stable": "ColorizeStable_gen",
    }
    weights_name = weights_by_type.get(model_type)
    if weights_name is None:
        print(f"Unknown model type: {model_type}", file=sys.stderr)
        sys.exit(1)

    # The weights are expected at ./models/<weights_name>.pth; report a missing
    # file clearly instead of failing somewhere inside the loader.
    weights_file = Path("models") / f"{weights_name}.pth"
    if not weights_file.is_file():
        print(f"Error: weights file not found: {weights_file}", file=sys.stderr)
        print(
            "  Download it into the DeOldify models/ directory first.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Imported lazily: the package only resolves from inside the DeOldify checkout.
    from deoldify.generators import gen_inference_deep, gen_inference_wide

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # "artistic" uses the deep generator, "stable" the wide one.
    build_learner = gen_inference_deep if model_type == "artistic" else gen_inference_wide
    learn = build_learner(root_folder=Path("."), weights_name=weights_name)

    generator = learn.model.eval().to(device)
    model = nn.Sequential(ImageScaleInput(), generator, ImageScaleOutput()).eval().to(device)

    # Example input handed to the exporter; height and width are declared
    # dynamic below, so 256x256 is only the example size.
    dummy = torch.randn(1, 3, 256, 256, device=device)

    # Make sure the destination directory exists so the export does not fail
    # on a missing parent folder.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(
        model,
        dummy,
        output_path,
        opset_version=12,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {2: "height", 3: "width"},
            "output": {2: "height", 3: "width"},
        },
    )
    print(f"Exported {model_type} model to {output_path}")


if __name__ == "__main__":
    # Command-line entry point: both flags are mandatory.
    cli = argparse.ArgumentParser(description="Convert DeOldify models to ONNX format")
    cli.add_argument(
        "--model",
        required=True,
        choices=["artistic", "stable"],
        help="Model variant to convert",
    )
    cli.add_argument("--output", required=True, help="Output path for the ONNX model")
    options = cli.parse_args()
    convert(model_type=options.model, output_path=options.output)