import torch
from torch.nn.utils import remove_weight_norm
def count_parameters(model, verbose=False):
    """Count the total number of elements across all of *model*'s parameters.

    Args:
        model: any ``torch.nn.Module``.
        verbose: if True, print a per-parameter breakdown.

    Returns:
        Total element count over ``model.named_parameters()``.
    """
    total = 0
    for name, p in model.named_parameters():
        # numel() gives the element count directly; the original
        # torch.ones_like(p).sum().item() allocated a whole ones tensor
        # just to count elements.
        count = p.numel()
        if verbose:
            print(f"{name}: {count} parameters")
        total += count
    return total
def count_nonzero_parameters(model, verbose=False):
    """Count the non-zero elements across all of *model*'s parameters.

    Useful for measuring the effective size of a pruned/sparsified model.

    Args:
        model: any ``torch.nn.Module``.
        verbose: if True, print a per-parameter breakdown.

    Returns:
        Total non-zero element count over ``model.named_parameters()``.
    """
    total = 0
    for name, p in model.named_parameters():
        count = torch.count_nonzero(p).item()
        total += count
        if verbose:
            print(f"{name}: {count} non-zero parameters")
    return total
def retain_grads(module):
    """Mark every trainable parameter of *module* to retain its ``.grad``.

    Calling ``retain_grad()`` ensures gradients are kept on the tensors
    after backward, so they can be inspected (e.g. by ``get_grad_norm``).
    """
    for param in module.parameters():
        if not param.requires_grad:
            continue
        param.retain_grad()
def get_grad_norm(module, p=2):
    """Return the p-norm of all gradients currently stored on *module*.

    Args:
        module: any ``torch.nn.Module`` whose parameters may hold ``.grad``.
        p: order of the norm (default 2, the Euclidean norm).

    Returns:
        A scalar tensor with the combined gradient norm, or ``0.0`` when
        no parameter has a gradient.
    """
    norm = 0
    for param in module.parameters():
        # Skip parameters whose .grad is None (e.g. before any backward
        # pass) — the original code crashed on torch.abs(None).
        if param.requires_grad and param.grad is not None:
            norm = norm + (torch.abs(param.grad) ** p).sum()
    return norm ** (1 / p)
def create_weights(s_real, s_gen, alpha):
    """Build per-pair weights ``exp(alpha * (real[-1] - gen[-1]))``.

    For each paired entry of *s_real* and *s_gen*, the last element of
    each sequence is compared and the (exponentiated, alpha-scaled)
    difference becomes the weight. Computed under ``torch.no_grad()``
    so the weights do not participate in autograd.

    Returns:
        A list of scalar tensors, one per zipped pair.
    """
    with torch.no_grad():
        return [
            torch.exp(alpha * (real[-1] - fake[-1]))
            for real, fake in zip(s_real, s_gen)
        ]
def _get_candidates(module: torch.nn.Module):
candidates = []
for key in module.__dict__.keys():
if hasattr(module, key + '_v'):
candidates.append(key)
return candidates
def remove_all_weight_norm(model : torch.nn.Module, verbose=False):
    """Strip weight normalization from every submodule of *model*.

    Walks all named submodules, detects weight-normed attributes via
    ``_get_candidates``, and calls ``remove_weight_norm`` on each.

    Args:
        model: the module tree to clean up.
        verbose: if True, print each removal.
    """
    for name, m in model.named_modules():
        candidates = _get_candidates(m)
        for candidate in candidates:
            try:
                remove_weight_norm(m, name=candidate)
                if verbose: print(f'removed weight norm on weight {name}.{candidate}')
            # remove_weight_norm raises ValueError when no weight-norm
            # hook is registered for this name; the original bare
            # `except:` also swallowed KeyboardInterrupt/SystemExit.
            except ValueError:
                pass