import math
import time
from pathlib import Path

import julius
import numpy as np
import torch
import torchaudio.transforms
from moshi.models import loaders
from pydantic import BaseModel
from torch import nn
from torch.nn import functional as F
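
# Audio preprocessing for the TTS service: band-split denoising with a
# (streaming-capable) Demucs model, loudness normalization, and encoding of
# 24 kHz PCM into Mimi latents.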
def normalize_loudness(
wav: torch.Tensor,
sample_rate: int,
loudness_headroom_db: float = 18,
energy_floor: float = 2e-3
):
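    """Normalize `wav` to -`loudness_headroom_db` LUFS, as measured by
    torchaudio's ITU-R BS.1770 loudness meter. The DC offset is always
    removed; near-silent inputs (std below `energy_floor`) and clips the
    meter rejects are returned without gain adjustment."""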
wav = wav - wav.mean(dim=-1, keepdim=True)
energy = wav.std()
if energy < energy_floor:
return wav
transform = torchaudio.transforms.Loudness(sample_rate)
try:
input_loudness_db = transform(wav).item()
except RuntimeError:
return wav
delta_loudness = -loudness_headroom_db - input_loudness_db
gain = 10.0 ** (delta_loudness / 20.0)
output = gain * wav
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
return output
def sinc(t: torch.Tensor) -> torch.Tensor:
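    """sin(t) / t, with the removable singularity at t == 0 set to 1."""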
return torch.where(t == 0, torch.ones(1, device=t.device, dtype=t.dtype), torch.sin(t) / t)
def kernel_upsample2(zeros=56, device=None):
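    """Windowed-sinc (Hann) FIR kernel that interpolates the midpoints between
    consecutive samples; used by `upsample2`."""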
win = torch.hann_window(4 * zeros + 1, periodic=False, device=device)
winodd = win[1::2]
t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros, device=device)
t *= math.pi
kernel = (sinc(t) * winodd).view(1, 1, -1)
return kernel
def upsample2(x, zeros=56):
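    """Upsample `x` by a factor of 2 along the last axis using windowed-sinc
    interpolation: interleave the input samples with interpolated midpoints."""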
*other, time = x.shape
kernel = kernel_upsample2(zeros, x.device).to(x)
out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
y = torch.stack([x, out], dim=-1)
return y.view(*other, -1)
def kernel_downsample2(zeros=56, device=None):
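    """Windowed-sinc (Hann) FIR kernel used by `downsample2`; identical in
    construction to `kernel_upsample2`."""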
win = torch.hann_window(4 * zeros + 1, periodic=False, device=device)
winodd = win[1::2]
t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros, device=device)
t.mul_(math.pi)
kernel = (sinc(t) * winodd).view(1, 1, -1)
return kernel
def downsample2(x, zeros=56):
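    """Downsample `x` by a factor of 2 along the last axis: average the even
    samples with a windowed-sinc interpolation of the odd samples."""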
if x.shape[-1] % 2 != 0:
x = F.pad(x, (0, 1))
xeven = x[..., ::2]
xodd = x[..., 1::2]
*other, time = xodd.shape
kernel = kernel_downsample2(zeros, x.device).to(x)
out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
*other, time)
return out.view(*other, -1).mul(0.5)
class BLSTM(nn.Module):
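    """LSTM over time-first input; when bidirectional, a linear layer maps the
    doubled hidden size back to `dim` channels."""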
def __init__(self, dim, layers=2, bi=True):
super().__init__()
klass = nn.LSTM
self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
self.linear = None
if bi:
self.linear = nn.Linear(2 * dim, dim)
def forward(self, x, hidden=None):
x, hidden = self.lstm(x, hidden)
if self.linear:
x = self.linear(x)
return x, hidden
def rescale_conv(conv, reference):
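    """Rescale the weights (and bias) of `conv` so that the weight std becomes
    the geometric mean of its current value and `reference`."""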
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
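    """Apply `rescale_conv` to every Conv1d / ConvTranspose1d in `module`."""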
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
class Demucs(nn.Module):
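    """Waveform-to-waveform Demucs denoiser: a convolutional encoder/decoder
    with U-Net skip connections and an LSTM bottleneck (bidirectional when
    `causal=False`), optionally run at 2x or 4x the input sample rate
    internally (`resample`)."""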
def __init__(self,
chin=1,
chout=1,
hidden=48,
depth=5,
kernel_size=8,
stride=4,
causal=True,
resample=4,
growth=2,
max_hidden=10_000,
normalize=True,
glu=True,
rescale=0.1,
floor=1e-3,
sample_rate=16_000):
super().__init__()
if resample not in [1, 2, 4]:
raise ValueError("Resample should be 1, 2 or 4.")
self.chin = chin
self.chout = chout
self.hidden = hidden
self.depth = depth
self.kernel_size = kernel_size
self.stride = stride
self.causal = causal
self.floor = floor
self.resample = resample
self.normalize = normalize
self.sample_rate = sample_rate
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
activation = nn.GLU(1) if glu else nn.ReLU()
ch_scale = 2 if glu else 1
for index in range(depth):
encode = []
encode += [
nn.Conv1d(chin, hidden, kernel_size, stride),
nn.ReLU(),
nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
]
self.encoder.append(nn.Sequential(*encode))
decode = []
decode += [
nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
chout = hidden
chin = hidden
hidden = min(int(growth * hidden), max_hidden)
self.lstm = BLSTM(chin, bi=not causal)
if rescale:
rescale_module(self, reference=rescale)
def valid_length(self, length):
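        """Nearest valid input length >= `length` such that the strided
        convolutions (after internal resampling) leave no time steps over;
        `forward` zero-pads its input up to this length."""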
length = math.ceil(length * self.resample)
for idx in range(self.depth):
length = math.ceil((length - self.kernel_size) / self.stride) + 1
length = max(length, 1)
for idx in range(self.depth):
length = (length - 1) * self.stride + self.kernel_size
length = int(math.ceil(length / self.resample))
return int(length)
@property
def total_stride(self):
return self.stride ** self.depth // self.resample
def forward(self, mix):
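        """Denoise `mix` of shape [batch, time] or [batch, channels, time],
        returning output of the same length. When `normalize` is set, the
        input is standardized by the mono std and rescaled on the way out."""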
if mix.dim() == 2:
mix = mix.unsqueeze(1)
if self.normalize:
mono = mix.mean(dim=1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
mix = mix / (self.floor + std)
else:
std = 1
length = mix.shape[-1]
x = mix
x = F.pad(x, (0, self.valid_length(length) - length))
if self.resample == 2:
x = upsample2(x)
elif self.resample == 4:
x = upsample2(x)
x = upsample2(x)
skips = []
for encode in self.encoder:
x = encode(x)
skips.append(x)
x = x.permute(2, 0, 1)
x, _ = self.lstm(x)
x = x.permute(1, 2, 0)
for decode in self.decoder:
skip = skips.pop(-1)
x = x + skip[..., :x.shape[-1]]
x = decode(x)
if self.resample == 2:
x = downsample2(x)
elif self.resample == 4:
x = downsample2(x)
x = downsample2(x)
x = x[..., :length]
return std * x
def fast_conv(conv, x):
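    """Conv1d for batch size 1, rewritten as a single matrix multiply when the
    kernel is 1x1 or spans the whole input; falls back to `conv` otherwise."""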
batch, chin, length = x.shape
chout, chin, kernel = conv.weight.shape
assert batch == 1
if kernel == 1:
x = x.view(chin, length)
out = torch.addmm(
conv.bias.view(-1, 1),
conv.weight.view(chout, chin), x)
elif length == kernel:
x = x.view(chin * kernel, 1)
out = torch.addmm(
conv.bias.view(-1, 1),
conv.weight.view(chout, chin * kernel), x)
else:
out = conv(x)
return out.view(batch, chout, -1)
class DemucsStreamer:
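    """Real-time wrapper around `Demucs`: audio is fed in chunks of arbitrary
    length, processed frame by frame with cached convolution and LSTM state,
    and the statistics used for normalization are tracked with an exponential
    moving average spanning roughly `mean_decay_duration` seconds. A `dry`
    fraction of the input is mixed back into the output."""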
def __init__(self, demucs,
dry=0,
num_frames=1,
resample_lookahead=64,
resample_buffer=256,
mean_decay_duration: float = 10.):
device = next(iter(demucs.parameters())).device
self.demucs = demucs
self.lstm_state = None
self.conv_state = None
self.dry = dry
self.resample_lookahead = resample_lookahead
resample_buffer = min(demucs.total_stride, resample_buffer)
self.resample_buffer = resample_buffer
self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
self.total_length = self.frame_length + self.resample_lookahead
self.stride = demucs.total_stride * num_frames
self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device)
self.frames = 0
self.total_time = 0
self.mean_variance = 0.
self.mean_total = 0.
mean_receptive_field_in_samples = mean_decay_duration * demucs.sample_rate
mean_receptive_field_in_frames = mean_receptive_field_in_samples / demucs.total_stride
self.mean_decay = 1 - 1 / mean_receptive_field_in_frames
self.pending = torch.zeros(demucs.chin, 0, device=device)
bias = demucs.decoder[0][2].bias
weight = demucs.decoder[0][2].weight
chin, chout, kernel = weight.shape
self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
self._weight = weight.permute(1, 2, 0).contiguous()
@property
    def variance(self):
return self.mean_variance / self.mean_total
def reset_time_per_frame(self):
self.total_time = 0
self.frames = 0
@property
def time_per_frame(self):
return self.total_time / self.frames
def flush(self):
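        """Zero-pad and process the remaining pending audio, resetting the
        streaming state; call this once there is no more input and you want
        the last chunk back."""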
self.lstm_state = None
self.conv_state = None
pending_length = self.pending.shape[1]
padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
out = self.feed(padding)
return out[:, :pending_length]
def feed(self, wav):
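        """Feed a chunk of audio of shape [channels, time] and return whatever
        denoised audio is ready. Samples that do not yet fill a whole frame
        are kept in `self.pending` for the next call."""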
begin = time.time()
demucs = self.demucs
resample_buffer = self.resample_buffer
stride = self.stride
resample = demucs.resample
if wav.dim() != 2:
raise ValueError("input wav should be two dimensional.")
chin, _ = wav.shape
if chin != demucs.chin:
raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
self.pending = torch.cat([self.pending, wav], dim=1)
outs = []
while self.pending.shape[1] >= self.total_length:
self.frames += 1
frame = self.pending[:, :self.total_length]
dry_signal = frame[:, :stride]
if demucs.normalize:
mono = frame.mean(0)
variance = (mono**2).mean()
self.mean_variance = self.mean_variance * self.mean_decay + (1 - self.mean_decay) * variance
self.mean_total = self.mean_total * self.mean_decay + (1 - self.mean_decay)
frame = frame / (demucs.floor + torch.sqrt(self.variance))
padded_frame = torch.cat([self.resample_in, frame], dim=-1)
self.resample_in[:] = frame[:, stride - resample_buffer:stride]
frame = padded_frame
if resample == 4:
frame = upsample2(upsample2(frame))
elif resample == 2:
frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]
            frame = frame[:, :resample * self.frame_length]
out, extra = self._separate_frame(frame)
padded_out = torch.cat([self.resample_out, out, extra], 1)
self.resample_out[:] = out[:, -resample_buffer:]
if resample == 4:
out = downsample2(downsample2(padded_out))
elif resample == 2:
out = downsample2(padded_out)
else:
out = padded_out
out = out[:, resample_buffer // resample:]
out = out[:, :stride]
if demucs.normalize:
out *= torch.sqrt(self.variance)
out = self.dry * dry_signal + (1 - self.dry) * out
outs.append(out)
self.pending = self.pending[:, stride:]
self.total_time += time.time() - begin
if outs:
out = torch.cat(outs, 1)
else:
out = torch.zeros(chin, 0, device=wav.device)
return out
def _separate_frame(self, frame):
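        """Apply the model to one (resampled) frame, reusing the cached
        per-layer convolution state so only the new time steps are computed;
        returns the valid output and the `extra` right-hand tail."""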
demucs = self.demucs
skips = []
next_state = []
first = self.conv_state is None
stride = self.stride * demucs.resample
x = frame[None]
for idx, encode in enumerate(demucs.encoder):
stride //= demucs.stride
length = x.shape[2]
if idx == demucs.depth - 1:
x = fast_conv(encode[0], x)
x = encode[1](x)
x = fast_conv(encode[2], x)
x = encode[3](x)
else:
if not first:
prev = self.conv_state.pop(0)
prev = prev[..., stride:]
tgt = (length - demucs.kernel_size) // demucs.stride + 1
missing = tgt - prev.shape[-1]
offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
x = x[..., offset:]
x = encode[1](encode[0](x))
x = fast_conv(encode[2], x)
x = encode[3](x)
if not first:
x = torch.cat([prev, x], -1)
next_state.append(x)
skips.append(x)
x = x.permute(2, 0, 1)
x, self.lstm_state = demucs.lstm(x, self.lstm_state)
x = x.permute(1, 2, 0)
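        # From here on, `x` holds only fully valid samples. `extra` collects
        # the extra samples to the right of the frame; they are used only to
        # flush the stream once it is over, so the decoder layers are applied
        # to them separately.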
extra = None
for idx, decode in enumerate(demucs.decoder):
skip = skips.pop(-1)
x += skip[..., :x.shape[-1]]
x = fast_conv(decode[0], x)
x = decode[1](x)
if extra is not None:
skip = skip[..., x.shape[-1]:]
extra += skip[..., :extra.shape[-1]]
extra = decode[2](decode[1](decode[0](extra)))
x = decode[2](x)
next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
if extra is None:
extra = x[..., -demucs.stride:]
else:
extra[..., :demucs.stride] += next_state[-1]
x = x[..., :-demucs.stride]
if not first:
prev = self.conv_state.pop(0)
x[..., :demucs.stride] += prev
if idx != demucs.depth - 1:
x = decode[3](x)
extra = decode[3](extra)
self.conv_state = next_state
return x[0], extra[0]
def get_demucs():
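    """Download and load the pretrained dns64 Demucs denoiser checkpoint from
    the facebookresearch/denoiser release."""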
model = Demucs(hidden=64)
url = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/dns64-a7761ff99a7d5bb6.th"
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
model.load_state_dict(state_dict)
return model
class Config(BaseModel):
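    """Processor configuration; any field can be overridden via `init`."""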
log_folder: Path = Path.home() / 'tmp/tts-service'
hf_repo: str = loaders.DEFAULT_REPO
mimi_weight: Path = Path.home() / 'models/moshi/moshi_e9d43d50@500/e9d43d50_500_mimi_voice.safetensors'
config_path: Path | None = None
device: str = "cpu"
dry_fraction: float = 0.02
num_cpu_threads: int = 8
class Processor:
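    """Loads the Mimi codec and the Demucs denoiser, then turns raw 24 kHz
    mono PCM into unquantized Mimi latents."""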
def __init__(self, config_override: dict):
print(config_override)
config = Config(**config_override)
torch.set_num_threads(config.num_cpu_threads)
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
config.hf_repo, mimi_weights=config.mimi_weight, config_path=config.config_path,
)
self.dry_fraction = config.dry_fraction
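        # Force 16 quantizer codebooks; this pokes a private knob of
        # moshi.models.loaders before the model is built.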
loaders._quantizer_kwargs["n_q"] = 16
print("loading mimi")
self._mimi = checkpoint_info.get_mimi(device=config.device)
print("mimi loaded")
self._length = 24000 * 10
self._demucs = get_demucs()
self._lowpass = julius.lowpass.LowPassFilter(8 / 24)
self._downsample = julius.resample.ResampleFrac(24, 16)
self._upsample = julius.resample.ResampleFrac(16, 24)
@torch.no_grad()
def run_one(self, pcm: np.ndarray):
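        """Denoise 10 s of 24 kHz mono PCM and encode it to Mimi latents.

        The signal is split at 8 kHz: the low band is resampled to 16 kHz,
        denoised by Demucs (with a small dry mix), and resampled back; the
        untouched high band is then added back before loudness normalization
        and (unquantized) encoding."""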
print(pcm.shape)
wav = torch.from_numpy(pcm[None, None, :self._length]).float()
assert wav.shape[-1] == self._length
low = self._lowpass(wav)
high = wav - low
low = self._downsample(low, full=True)
denoised = self._demucs(low)
denoised = (1 - self.dry_fraction) * denoised + self.dry_fraction * low
denoised = self._upsample(denoised, output_length=wav.shape[-1])
denoised = denoised + high
denoised = normalize_loudness(denoised, 24000)
latent = self._mimi.encode_to_latent(denoised, quantize=False)
latent = latent.cpu().numpy()
print(latent.shape)
return latent
def init(config: dict):
processor = Processor(config)
return processor
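
# Minimal usage sketch (assumes `pcm` is a 10 s, 24 kHz mono float recording,
# i.e. an np.ndarray with at least 240_000 samples):
#     processor = init({"device": "cpu"})
#     latent = processor.run_one(pcm)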