tch-plus 0.18.3

Rust wrappers for the PyTorch C++ API (libtorch).
Documentation
import torch
from safetensors.torch import save_file

def normalize_key(k):
    """Rename a state-dict key from the hub model to its exported name.

    Two prefix rewrites are tried in order, each at most once:

    1. a leading ``"backbone."`` is removed entirely;
    2. then, on the result, a leading ``"linear_head."`` has only its
       ``"linear_"`` part removed, so ``"linear_head.weight"`` becomes
       ``"head.weight"``.

    Keys that match neither prefix are returned unchanged.
    """
    # (prefix to test for, leading text to cut when it matches)
    rewrites = (
        ("backbone.", "backbone."),
        ("linear_head.", "linear_"),
    )
    for prefix, cut in rewrites:
        if k.startswith(prefix):
            k = k[len(cut):]
    return k

# Load the 'dinov2_vits14_lc' entry point from the facebookresearch/dinov2
# torch.hub repo. NOTE(review): "_lc" presumably means "linear classifier",
# and layers=1 is forwarded to that entry point; confirm both against the
# dinov2 hubconf.
dinov2_vits14 = torch.hub.load(
    'facebookresearch/dinov2',
    'dinov2_vits14_lc',
    layers=1,
)
# Print the module tree so the parameter names being exported are visible.
print(dinov2_vits14)
# Rename every state-dict key with normalize_key, keeping the tensors as-is,
# then write the renamed mapping out in safetensors format.
weights = {
    normalize_key(name): tensor
    for name, tensor in dinov2_vits14.state_dict().items()
}
save_file(weights, "dinov2_vits14.safetensors")