# tch 0.23.0
#
# Rust wrappers for the PyTorch C++ API (libtorch).
# Documentation: https://docs.rs/tch
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.

[package]
name = "tch"
version = "0.23.0"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
# Target auto-discovery is disabled: the library, examples and tests are all
# declared explicitly in the [lib], [[example]] and [[test]] sections.
autobenches = false
autobins = false
autoexamples = false
autolib = false
autotests = false
build = "build.rs"
categories = ["science"]
edition = "2021"
# Keep the stable-diffusion example's media files out of the published package.
exclude = ["examples/stable-diffusion/media/*"]
keywords = [
    "pytorch",
    "deep-learning",
    "machine-learning",
]
# SPDX 2.1 license expression; the legacy "MIT/Apache-2.0" form with a "/"
# separator is deprecated by Cargo.
license = "MIT OR Apache-2.0"
readme = "README.md"
repository = "https://github.com/LaurentMazare/tch-rs"
description = "Rust wrappers for the PyTorch C++ api (libtorch)."

# docs.rs builds the crate with the `doc-only` feature, which is forwarded to
# torch-sys/doc-only in [features] (presumably so docs can build without a
# local libtorch — confirm).
[package.metadata.docs.rs]
features = ["doc-only"]

[features]
# Forwarded unchanged to the same-named features of the torch-sys crate.
doc-only = ["torch-sys/doc-only"]
download-libtorch = ["torch-sys/download-libtorch"]
python-extension = ["torch-sys/python-extension"]
# Enables the optional `cpython` dependency; the reinforcement-learning
# example lists this feature in its `required-features`.
rl-python = ["cpython"]
# Enables nothing by itself; presumably checked via cfg(feature = "cuda-tests")
# to gate tests that need a CUDA device — confirm.
cuda-tests = []

# Library target, declared explicitly because `autolib = false`.
[lib]
path = "src/lib.rs"
name = "tch"

# Examples are declared explicitly (`autoexamples = false`). An example with
# `required-features` is skipped by Cargo unless all of those features
# (including implicit optional-dependency features) are enabled.
[[example]]
path = "examples/basics.rs"
name = "basics"

[[example]]
path = "examples/char-rnn/main.rs"
name = "char-rnn"

[[example]]
path = "examples/cifar/main.rs"
name = "cifar"

[[example]]
path = "examples/custom-optimizer/main.rs"
name = "custom-optimizer"

[[example]]
path = "examples/gan/main.rs"
name = "gan"

[[example]]
path = "examples/jit/main.rs"
name = "jit"

[[example]]
path = "examples/jit-quantized/main.rs"
name = "jit-quantized"

[[example]]
path = "examples/jit-trace/main.rs"
name = "jit-trace"

[[example]]
path = "examples/jit-train/main.rs"
name = "jit-train"

# Needs the optional regex, clap, serde_json and memmap2 dependencies.
[[example]]
path = "examples/llama/main.rs"
name = "llama"
required-features = ["regex", "clap", "serde_json", "memmap2"]

[[example]]
path = "examples/memory_test.rs"
name = "memory_test"

[[example]]
path = "examples/min-gpt/main.rs"
name = "min-gpt"

[[example]]
path = "examples/mnist/main.rs"
name = "mnist"

[[example]]
path = "examples/neural-style-transfer/main.rs"
name = "neural-style-transfer"

[[example]]
path = "examples/pretrained-models/main.rs"
name = "pretrained-models"

# Needs the rl-python feature, which enables the optional cpython dependency.
[[example]]
path = "examples/reinforcement-learning/main.rs"
name = "reinforcement-learning"
required-features = ["rl-python"]

# Needs the optional regex dependency.
[[example]]
path = "examples/stable-diffusion/main.rs"
name = "stable-diffusion"
required-features = ["regex"]

[[example]]
path = "examples/tensor-tools.rs"
name = "tensor-tools"

[[example]]
path = "examples/transfer-learning/main.rs"
name = "transfer-learning"

[[example]]
path = "examples/translation/main.rs"
name = "translation"

[[example]]
path = "examples/vae/main.rs"
name = "vae"

[[example]]
path = "examples/yolo/main.rs"
name = "yolo"

# Integration tests are declared explicitly (`autotests = false`).
[[test]]
path = "tests/autocast.rs"
name = "autocast"

[[test]]
path = "tests/data_tests.rs"
name = "data_tests"

[[test]]
path = "tests/device_tests.rs"
name = "device_tests"

[[test]]
path = "tests/display_tests.rs"
name = "display_tests"

[[test]]
path = "tests/jit_tests.rs"
name = "jit_tests"

[[test]]
path = "tests/nn_tests.rs"
name = "nn_tests"

[[test]]
path = "tests/serialization_tests.rs"
name = "serialization_tests"

[[test]]
path = "tests/tensor_indexing.rs"
name = "tensor_indexing"

[[test]]
path = "tests/tensor_tests.rs"
name = "tensor_tests"

# NOTE(review): judging by its name, test_utils.rs may be a helper module
# shared by the other tests rather than a test suite of its own; if so it
# does not need its own target — confirm against the test sources.
[[test]]
path = "tests/test_utils.rs"
name = "test_utils"

[[test]]
path = "tests/var_store.rs"
name = "var_store"

[[test]]
path = "tests/vision_tests.rs"
name = "vision_tests"

# Optional dependencies are only compiled when their implicit same-named
# feature is enabled, e.g. through an example's `required-features` or the
# `rl-python` feature (for cpython).
[dependencies]
clap = { version = "4.2.4", features = ["derive"], optional = true }
cpython = { version = "0.7.1", optional = true }
half = "2"
image = { version = "0.24.5", optional = true }
lazy_static = "1.3.0"
libc = "0.2.0"
memmap2 = { version = "0.6.1", optional = true }
ndarray = "0.16.1"
rand = "0.8"
regex = { version = "1.6.0", optional = true }
safetensors = "0.3.0"
serde_json = { version = "1.0.96", optional = true }
thiserror = "1"
torch-sys = "0.23.0"
zip = "0.6"

[dev-dependencies.anyhow]
version = "^1.0.60"