import argparse
import sys
from pathlib import Path
import torch
import torch.nn as nn
class ImageScaleInput(nn.Module):
    """Convert a [0, 255] RGB image tensor to ImageNet-normalized form.

    Assumes input of shape (N, 3, H, W) — TODO confirm against the export
    dummy, which is (1, 3, 256, 256). Scales to [0, 1], then applies the
    standard ImageNet per-channel mean/std normalization.
    """

    # ImageNet channel statistics (RGB order).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        # Register the constants once as non-persistent buffers: they follow
        # the module through .to(device) and get baked into the ONNX graph,
        # but do not appear in state_dict. The original rebuilt these tensors
        # on every forward call.
        self.register_buffer(
            "mean", torch.tensor(self._MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self._STD).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return the normalized tensor: (x / 255 - mean) / std."""
        return (x / 255.0 - self.mean) / self.std
class ImageScaleOutput(nn.Module):
    """Map an ImageNet-normalized tensor back to a [0, 255] RGB image.

    Inverse of ImageScaleInput's normalization: de-normalize with the
    ImageNet per-channel mean/std, clamp to [0, 1], then scale to [0, 255].
    Assumes input of shape (N, 3, H, W) — TODO confirm against the export
    dummy, which is (1, 3, 256, 256).
    """

    # ImageNet channel statistics (RGB order).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        # Register the constants once as non-persistent buffers: they follow
        # the module through .to(device) and get baked into the ONNX graph,
        # but do not appear in state_dict. The original rebuilt these tensors
        # on every forward call.
        self.register_buffer(
            "mean", torch.tensor(self._MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self._STD).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return (x * std + mean) clamped to [0, 1] and scaled to [0, 255]."""
        return (x * self.std + self.mean).clamp(0, 1) * 255.0
def convert(model_type: str, output_path: str):
    """Load a DeOldify generator and export it to ONNX at *output_path*.

    Must be run from the DeOldify repository root (the weights and the
    ``deoldify`` package are resolved relative to the CWD). The exported
    graph wraps the generator with [0, 255] input/output scaling and has
    dynamic height/width axes. Exits the process on misconfiguration.
    """
    if not Path("deoldify").is_dir():
        print(
            "Error: This script must be run from the DeOldify repository root.",
            file=sys.stderr,
        )
        print(" git clone https://github.com/jantic/DeOldify.git", file=sys.stderr)
        print(" cd DeOldify", file=sys.stderr)
        sys.exit(1)

    # Deferred import: only resolvable once we know we're in the repo root.
    from deoldify.generators import gen_inference_deep, gen_inference_wide

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Each variant maps to its builder function and pretrained weights name.
    builders = {
        "artistic": (gen_inference_deep, "ColorizeArtistic_gen"),
        "stable": (gen_inference_wide, "ColorizeStable_gen"),
    }
    try:
        build, weights_name = builders[model_type]
    except KeyError:
        print(f"Unknown model type: {model_type}", file=sys.stderr)
        sys.exit(1)
    learn = build(root_folder=Path("."), weights_name=weights_name)

    generator = learn.model.eval().to(device)
    # Full pipeline: [0,255] image in -> normalized generator -> [0,255] image out.
    pipeline = nn.Sequential(ImageScaleInput(), generator, ImageScaleOutput())
    pipeline = pipeline.eval().to(device)

    sample = torch.randn(1, 3, 256, 256, device=device)
    torch.onnx.export(
        pipeline,
        sample,
        output_path,
        opset_version=12,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {2: "height", 3: "width"},
            "output": {2: "height", 3: "width"},
        },
    )
    print(f"Exported {model_type} model to {output_path}")
if __name__ == "__main__":
    # CLI entry point: pick a model variant and an output file, then export.
    cli = argparse.ArgumentParser(
        description="Convert DeOldify models to ONNX format"
    )
    cli.add_argument(
        "--model",
        choices=["artistic", "stable"],
        required=True,
        help="Model variant to convert",
    )
    cli.add_argument(
        "--output", required=True, help="Output path for the ONNX model"
    )
    opts = cli.parse_args()
    convert(opts.model, opts.output)