tch 0.24.0

Rust wrappers for the PyTorch C++ API (libtorch).
Documentation
import torch
from safetensors.torch import save_file

def normalize_key(k):
    """Rename a DINOv2 hub state-dict key for the exported safetensors file.

    Two prefix rules are tried in order, each on the result of the previous:

    * a leading ``"backbone."`` is removed entirely, e.g.
      ``"backbone.blocks.0.norm1.weight"`` -> ``"blocks.0.norm1.weight"``;
    * a leading ``"linear_head."`` loses only its ``"linear_"`` part, e.g.
      ``"linear_head.weight"`` -> ``"head.weight"``.

    Keys matching neither prefix are returned unchanged.
    """
    # Each rule is (prefix to test for, number of leading characters to cut).
    # The second rule deliberately cuts fewer characters than it tests for,
    # so the resulting key keeps its "head." component.
    rules = (
        ("backbone.", len("backbone.")),
        ("linear_head.", len("linear_")),
    )
    for prefix, cut in rules:
        if k.startswith(prefix):
            k = k[cut:]
    return k

# Export each DINOv2 ViT-*/14 linear-classifier checkpoint from torch.hub to
# dinov2_vit{s,b,l,g}14.safetensors in the current directory.
for model_size in ["small", "base", "large", "giant"]:
    # Hub entry points are keyed by the first letter of the size name
    # (dinov2_vits14_lc, dinov2_vitb14_lc, ...).
    letter = model_size[0]
    # NOTE(review): layers=1 presumably selects the single-layer linear-head
    # variant of the classifier — confirm against the dinov2 hubconf.
    model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_lc', layers=1)
    # Strip/rename the wrapper prefixes so keys match the exported layout.
    weights = {normalize_key(k): v for k, v in model.state_dict().items()}
    save_file(weights, f"dinov2_vit{letter}14.safetensors")
    # Release this model before the next (larger) one is loaded; otherwise the
    # previous checkpoint stays referenced and two models coexist in memory.
    del model, weights