tch 0.24.0

Rust wrappers for the PyTorch C++ API (libtorch).
Documentation
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt

# Each entry pairs a torchvision EfficientNet constructor with the safetensors
# file its weights are written to (relative to the current working directory).
_EXPORTS = (
    (torchvision.models.efficientnet_b0, 'efficientnet-b0.safetensors'),
    (torchvision.models.efficientnet_b1, 'efficientnet-b1.safetensors'),
    (torchvision.models.efficientnet_b2, 'efficientnet-b2.safetensors'),
    (torchvision.models.efficientnet_b3, 'efficientnet-b3.safetensors'),
    (torchvision.models.efficientnet_b4, 'efficientnet-b4.safetensors'),
)

# Build the variants one at a time, in order b0..b4, and save each model's
# state dict before moving on to the next one.
for _build_model, _out_path in _EXPORTS:
    m = _build_model(pretrained=True)
    stt.save_file(m.state_dict(), _out_path)