import sys
import os
# Put ../target/debug (relative to this file) at the front of the module
# search path so the import below can pick up a module located there.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'target', 'debug'))

try:
    import rstorch_python as rstorch
except ImportError as exc:
    # Diagnostics go to stderr so they are not mixed into the demo output on
    # stdout; the exit status stays 1.
    print(
        "Error: rstorch_python not built yet. Please run 'maturin develop' first.",
        file=sys.stderr,
    )
    # Surface the underlying reason too: an ImportError can also come from a
    # module that exists but fails to load.
    print(f"  (import failed with: {exc})", file=sys.stderr)
    sys.exit(1)
def demonstrate_devices():
    """Print a walkthrough of the device objects and helpers in ``rstorch``.

    Builds CPU, CUDA and Metal ``PyDevice`` objects from string specifiers,
    prints each one together with its ``type`` and ``index`` attributes,
    compares devices with ``==``, and finally prints the ``rstorch.cpu``
    constant and the results of the module-level device query functions.
    """
    rule = "=" * 60

    def describe(label, device):
        # One section per device: the object itself, its two attributes,
        # then a blank separator line.
        print(f"{label}: {device}")
        print(f" Type: {device.type}")
        print(f" Index: {device.index}")
        print()

    print(rule)
    print("DEVICE MANAGEMENT")
    print(rule)

    cpu = rstorch.PyDevice("cpu")
    describe("CPU device", cpu)

    # Both CUDA devices are constructed before either one is printed.
    cuda0 = rstorch.PyDevice("cuda:0")
    cuda1 = rstorch.PyDevice("cuda:1")
    describe("CUDA device 0", cuda0)
    describe("CUDA device 1", cuda1)

    metal = rstorch.PyDevice("metal:0")
    describe("Metal device", metal)

    # A second, independently constructed CPU device for the equality check.
    cpu2 = rstorch.PyDevice("cpu")
    print("Device equality test:")
    same_kind = cpu == cpu2
    print(f" cpu == cpu2: {same_kind}")
    different_kind = cpu == cuda0
    print(f" cpu == cuda0: {different_kind}")
    print()

    print("Device constants:")
    print(f" rstorch.cpu: {rstorch.cpu}")
    print()

    print("Device utility functions:")
    # The argument-free query functions share one output format.
    for query in ("device_count", "is_available", "cuda_is_available", "mps_is_available"):
        outcome = getattr(rstorch, query)()
        print(f" {query}(): {outcome}")
    cpu_name = rstorch.get_device_name(cpu)
    print(f" get_device_name(cpu): {cpu_name}")
    print()
def demonstrate_dtypes():
    """Print a walkthrough of the data-type objects exposed by ``rstorch``.

    Constructs several ``PyDType`` objects by name, prints each with its
    ``itemsize``, prints the ``name`` / ``is_floating_point`` / ``is_signed``
    attributes for three of them, compares the ``"float32"`` and ``"f32"``
    spellings with ``==``, and prints the module-level dtype constants and
    aliases.
    """
    rule = "=" * 60
    print(rule)
    print("DATA TYPE HANDLING")
    print(rule)

    # All five dtypes are constructed up front, before anything is printed
    # about them.
    float32 = rstorch.PyDType("float32")
    float64 = rstorch.PyDType("float64")
    int32 = rstorch.PyDType("int32")
    int64 = rstorch.PyDType("int64")
    bool_type = rstorch.PyDType("bool")

    print("Data types:")
    sized = (
        ("float32", float32),
        ("float64", float64),
        ("int32", int32),
        ("int64", int64),
        ("bool", bool_type),
    )
    for label, dtype in sized:
        print(f" {label}: {dtype} (size: {dtype.itemsize} bytes)")
    print()

    # Same three attributes for a float, an integer and the boolean dtype.
    for heading, dtype in (("Float32", float32), ("Int32", int32), ("Bool", bool_type)):
        print(f"{heading} properties:")
        print(f" Name: {dtype.name}")
        print(f" Is floating point: {dtype.is_floating_point}")
        print(f" Is signed: {dtype.is_signed}")
        print()

    f32_alias = rstorch.PyDType("f32")
    print("Type aliases:")
    print(f" 'float32' == 'f32': {float32 == f32_alias}")
    print()

    print("Type constants:")
    # Look each constant up right before printing it.
    for attr in ("float32", "float64", "int32", "int64", "bool"):
        print(f" rstorch.{attr}: {getattr(rstorch, attr)}")
    print()

    print("PyTorch-style aliases:")
    for attr in ("float", "double", "long", "int"):
        print(f" rstorch.{attr}: {getattr(rstorch, attr)}")
    print()
def _expect_value_error(label, constructor, argument):
    """Call ``constructor(argument)`` and report whether ``ValueError`` was raised.

    Args:
        label: Short description of the bad input, used in the printed line.
        constructor: Callable to invoke with ``argument``.
        argument: The invalid value handed to ``constructor``.

    Only ``ValueError`` is handled; any other exception type propagates to
    the caller. A blank line is printed after the report.
    """
    try:
        constructor(argument)
    except ValueError as exc:
        print(f"✓ Caught ValueError for {label}: {exc}")
    else:
        # Without this branch an input that is unexpectedly accepted would
        # leave no trace in the demo output.
        print(f"✗ Expected ValueError for {label}, but none was raised")
    print()


def demonstrate_error_handling():
    """Print a walkthrough of the error behaviour of the ``rstorch`` bindings.

    First constructs a ``rstorch.TorshError`` and prints its ``str`` and
    ``repr`` forms. Then passes a series of invalid arguments to
    ``rstorch.PyDevice`` and ``rstorch.PyDType`` and reports, for each one,
    whether the call raised ``ValueError``.

    Raises:
        Whatever non-``ValueError`` exception one of the invalid-argument
        probes raises; those are deliberately not caught here.
    """
    rule = "=" * 60
    print(rule)
    print("ERROR HANDLING")
    print(rule)

    # Creating the error object is best-effort: on failure the demo reports
    # the problem and carries on with the remaining checks.
    try:
        error = rstorch.TorshError("This is a test error")
        print(f"TorshError created: {error}")
        print(f" Repr: {repr(error)}")
    except Exception as e:
        print(f"Error creating TorshError: {e}")
    print()

    _expect_value_error("invalid device", rstorch.PyDevice, "invalid_device")
    _expect_value_error("invalid CUDA ID", rstorch.PyDevice, "cuda:abc")
    _expect_value_error("negative device ID", rstorch.PyDevice, -1)
    _expect_value_error("invalid dtype", rstorch.PyDType, "invalid_dtype")
    _expect_value_error("unsupported dtype", rstorch.PyDType, "uint16")
def demonstrate_version():
    """Print the version section: a banner, then ``rstorch.__version__``."""
    rule = "=" * 60
    for banner_line in (rule, "VERSION INFORMATION", rule):
        print(banner_line)
    version = rstorch.__version__
    print(f"ToRSh Python version: {version}")
    print()
def main():
    """Run every demonstration in order, framed by a title and closing notes."""
    rule = "=" * 60

    print("\n" + rule)
    print("TORSH PYTHON BINDINGS - BASIC USAGE EXAMPLES")
    print(rule)
    print()

    # Order matters for the output: version, devices, dtypes, error handling.
    demos = (
        demonstrate_version,
        demonstrate_devices,
        demonstrate_dtypes,
        demonstrate_error_handling,
    )
    for run_demo in demos:
        run_demo()

    print(rule)
    print("All demonstrations completed successfully!")
    print(rule)
    print()
    print("Note: Tensor operations are currently disabled due to dependency issues.")
    print(" Please check the TODO.md file for the roadmap to re-enable them.")
# Run the demonstrations only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()