import argparse
import torch
from scipy.io import wavfile
from models import model_dict
from utils.silk_features import load_inference_data
from utils import endoscopy
# Inference script: loads a model checkpoint, runs the model on the features
# and signals found in the input folder, and writes the result as a wav file.
#
# Usage: <script> INPUT CHECKPOINT OUTPUT [--debug]

# Set to True to skip command-line parsing and use the hard-coded paths below
# (e.g. when launching the script from an IDE or debugger).
debug = False

if debug:
    # Mirror every attribute the argument parser would provide, including
    # `debug`, which is read further down in the script.
    args = argparse.Namespace(
        input='testitems/all_0_orig.se',
        checkpoint='testout/checkpoints/checkpoint_epoch_5.pth',
        output='out.wav',
        debug=False,
    )
else:
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str, help='path to folder with features and signals')
    parser.add_argument('checkpoint', type=str, help='checkpoint file')
    parser.add_argument('output', type=str, help='output file')
    parser.add_argument('--debug', action='store_true', help='enables debug output')
    args = parser.parse_args()

torch.set_num_threads(2)

input_folder = args.input
checkpoint_file = args.checkpoint
output_file = args.output

# Make sure the output path carries a .wav extension.
if not output_file.endswith('.wav'):
    output_file += '.wav'

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# trusted sources.
checkpoint = torch.load(checkpoint_file, map_location="cpu")
setup = checkpoint['setup']
model_setup = setup['model']

# Checkpoints without a model name entry fall back to the default model.
if 'name' not in model_setup:
    print('warning: did not find model name entry in setup, using pitchpostfilter per default')
    model_name = 'pitchpostfilter'
else:
    model_name = model_setup['name']

# Rebuild the model with the constructor arguments stored in the checkpoint,
# then restore its weights.
model = model_dict[model_name](*model_setup['args'], **model_setup['kwargs'])
model.load_state_dict(checkpoint['state_dict'])

signal, features, periods, numbits = load_inference_data(input_folder, **setup['data'])

if args.debug:
    endoscopy.init()

try:
    output = model.process(signal, features, periods, numbits, debug=args.debug)

    # The result is written with a fixed sample rate of 16000 Hz.
    wavfile.write(output_file, 16000, output.cpu().numpy())
finally:
    # Close the debug output even if processing or writing fails.
    if args.debug:
        endoscopy.close()