import os
from setuptools import setup
from setuptools.command.test import test as TestCommand
from setuptools.command.sdist import sdist as SdistCommand
import sys
# RustExtension comes from setuptools-rust. When the package cannot be
# imported, attempt to install it with pip through the running interpreter
# and retry the import; exit with pip's status code if that fails.
try:
    from setuptools_rust import RustExtension
except ImportError:
    import subprocess

    errno = subprocess.call(
        [sys.executable, "-m", "pip", "install", "setuptools-rust"]
    )
    if errno:
        print("Please install setuptools-rust package")
        raise SystemExit(errno)

    # pip returned 0, so the import is retried.
    from setuptools_rust import RustExtension
class CargoModifiedSdist(SdistCommand):
    """sdist command that makes the packaged pyo3 path dependency absolute.

    After the regular release tree is built, the ``Cargo.toml`` copied into
    it is rewritten so that ``dependencies.pyo3.path`` is an absolute path
    resolved against the directory containing this file.
    """

    def make_release_tree(self, base_dir, files):
        """Build the release tree, then patch the ``Cargo.toml`` inside it.

        Args:
            base_dir: Directory in which the release tree is created.
            files: Files to place in the release tree; passed through to the
                parent implementation unchanged.

        Raises:
            FileNotFoundError: If the release tree contains no ``Cargo.toml``.
        """
        super().make_release_tree(base_dir, files)

        # Imported here rather than at module level, so ``toml`` is only
        # required when this command actually runs.
        import toml

        cargo_loc = os.path.join(base_dir, "Cargo.toml")
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would hide the missing file.
        if not os.path.exists(cargo_loc):
            raise FileNotFoundError(
                "Cargo.toml not found in release tree: {}".format(cargo_loc)
            )

        # TOML documents are UTF-8; do not rely on the platform default.
        with open(cargo_loc, "r", encoding="utf-8") as f:
            cargo_toml = toml.load(f)

        pyo3_dep = cargo_toml.get("dependencies", {}).get("pyo3")
        # Only a table-style dependency with a "path" key needs rewriting; a
        # plain version requirement has no path to resolve.
        if not isinstance(pyo3_dep, dict) or "path" not in pyo3_dep:
            return

        # The path is resolved against this file's directory, not against
        # the release tree.
        base_path = os.path.dirname(__file__)
        pyo3_dep["path"] = os.path.abspath(
            os.path.join(base_path, pyo3_dep["path"])
        )

        with open(cargo_loc, "w", encoding="utf-8") as f:
            toml.dump(cargo_toml, f)
class PyTest(TestCommand):
    """Test command that runs ``test_rust`` and then the pytest suite."""

    # This command declares no command-line options of its own.
    user_options = []

    def run(self):
        """Run the ``test_rust`` command, then invoke ``pytest tests``.

        ``check_call`` raises ``subprocess.CalledProcessError`` when the
        pytest process exits with a non-zero status.
        """
        from subprocess import check_call

        self.run_command("test_rust")
        check_call(["pytest", "tests"])
# Packages needed to run setup.py itself.
setup_requires = ["setuptools-rust>=0.10.1", "wheel"]
# Runtime dependencies of the installed package.
install_requires = ["torch>=1.1.0", "transformers==2.2.1"]
# Test dependencies: the runtime dependencies plus the test tooling.
test_requires = install_requires + ["pytest", "pytest-benchmark"]

setup(
    name="rust_transformers",
    version="0.1.0",
    packages=["rust_transformers"],
    rust_extensions=[
        RustExtension(
            "rust_transformers.rust_transformers",
            "Cargo.toml",
            debug=False,
        )
    ],
    install_requires=install_requires,
    setup_requires=setup_requires,
    # The setuptools keyword is ``tests_require``; ``test_requires`` is not a
    # recognised option, so the test dependencies were previously ignored.
    tests_require=test_requires,
    include_package_data=True,
    zip_safe=False,
    cmdclass={"test": PyTest, "sdist": CargoModifiedSdist},
)