import unittest
import mlx.core as mx
import mlx_tests
class TestDefaultDevice(unittest.TestCase):
    """Checks which device MLX reports as the default.

    Unlike the other test classes in this file, this one derives from plain
    ``unittest.TestCase`` rather than ``mlx_tests.MLXTestCase``.
    """

    def test_mlx_default_device(self):
        """The default device is the GPU when available, otherwise the CPU."""
        device = mx.default_device()
        if mx.is_available(mx.gpu):
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            # Device and DeviceType must compare equal in both directions.
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            # Compare like with like: a Device against a Device, and the
            # device's DeviceType against a DeviceType.
            self.assertEqual(device, mx.Device(mx.cpu))
            self.assertEqual(device.type, mx.cpu)
            # Without a GPU, selecting it as the default must be rejected.
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)
class TestDevice(mlx_tests.MLXTestCase):
    """Tests for getting/setting the default device and running ops on devices."""

    def test_device(self):
        """Setting the default via a Device or a DeviceType both take effect."""
        device = mx.default_device()
        # Restore the previous default even if an assertion fails, so a
        # failure here does not leak a changed default into later tests.
        try:
            cpu = mx.Device(mx.cpu)
            mx.set_default_device(cpu)
            self.assertEqual(mx.default_device(), cpu)
            self.assertEqual(str(cpu), "Device(cpu, 0)")

            mx.set_default_device(mx.cpu)
            self.assertEqual(mx.default_device(), mx.cpu)
            # Device and DeviceType must compare equal in both directions.
            self.assertEqual(cpu, mx.cpu)
            self.assertEqual(mx.cpu, cpu)
        finally:
            mx.set_default_device(device)

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_device_context(self):
        """Inside ``mx.stream(d)`` the default device is ``d``; it is restored on exit."""
        default = mx.default_device()
        # Pick whichever device is NOT the current default.
        diff = mx.cpu if default == mx.gpu else mx.gpu
        self.assertNotEqual(default, diff)
        with mx.stream(diff):
            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
            mx.eval(a)
            self.assertEqual(mx.default_device(), diff)
        self.assertEqual(mx.default_device(), default)

    def test_op_on_device(self):
        """Ops accept ``None``, a Device, or a DeviceType as ``stream``."""
        x = mx.array(1.0)
        y = mx.array(1.0)
        a = mx.add(x, y, stream=None)
        b = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(a.item(), b.item())
        b = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(a.item(), b.item())
        # Use the same backend-agnostic availability check as the rest of this
        # file; ``mx.metal.is_available()`` is specific to the Metal backend.
        if mx.is_available(mx.gpu):
            b = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(a.item(), b.item())
class TestStream(mlx_tests.MLXTestCase):
    """Tests for default/new streams and dispatching ops onto them."""

    def test_stream(self):
        """Streams report the device they were created for; GPU requests fail without a GPU."""
        here = mx.default_device()
        has_gpu = mx.is_available(mx.gpu)

        default_s = mx.default_stream(here)
        self.assertEqual(default_s.device, mx.default_device())
        fresh_s = mx.new_stream(here)
        self.assertEqual(fresh_s.device, mx.default_device())
        # A newly created stream is distinct from the device's default stream.
        self.assertNotEqual(default_s, fresh_s)

        if not has_gpu:
            with self.assertRaises(ValueError):
                mx.default_stream(mx.gpu)
        else:
            gpu_default = mx.default_stream(mx.gpu)
            self.assertEqual(gpu_default.device, mx.gpu)

        cpu_default = mx.default_stream(mx.cpu)
        self.assertEqual(cpu_default.device, mx.cpu)
        cpu_fresh = mx.new_stream(mx.cpu)
        self.assertEqual(cpu_fresh.device, mx.cpu)

        if not has_gpu:
            with self.assertRaises(ValueError):
                mx.new_stream(mx.gpu)
        else:
            gpu_fresh = mx.new_stream(mx.gpu)
            self.assertEqual(gpu_fresh.device, mx.gpu)

    def test_op_on_stream(self):
        """The same add yields the same value on every default and new stream."""
        lhs = mx.array(1.0)
        rhs = mx.array(1.0)
        reference = mx.add(lhs, rhs, stream=mx.default_stream(mx.default_device()))

        # GPU streams first (only when a GPU exists), then CPU streams.
        targets = []
        if mx.is_available(mx.gpu):
            targets.append(mx.default_stream(mx.gpu))
            targets.append(mx.new_stream(mx.gpu))
        targets.append(mx.default_stream(mx.cpu))
        targets.append(mx.new_stream(mx.cpu))

        for target in targets:
            result = mx.add(lhs, rhs, stream=target)
            self.assertEqual(reference.item(), result.item())
class TestDeviceInfo(mlx_tests.MLXTestCase):
    """Tests for ``mx.device_count`` and ``mx.device_info``."""

    def _assert_basic_info(self, info):
        """Assert *info* is a dict with a non-empty "device_name" and an "architecture" key."""
        self.assertIsInstance(info, dict)
        self.assertIn("device_name", info)
        # assertGreater reports the actual length on failure, unlike assertTrue.
        self.assertGreater(len(info["device_name"]), 0)
        self.assertIn("architecture", info)

    def test_device_count(self):
        """There is exactly one CPU device and a non-negative number of GPUs."""
        cpu_count = mx.device_count(mx.cpu)
        self.assertIsInstance(cpu_count, int)
        self.assertEqual(cpu_count, 1)
        gpu_count = mx.device_count(mx.gpu)
        self.assertIsInstance(gpu_count, int)
        self.assertGreaterEqual(gpu_count, 0)

    def test_device_info_cpu(self):
        """CPU device info carries a device name and architecture."""
        self._assert_basic_info(mx.device_info(mx.cpu))

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_device_info_gpu(self):
        """Every GPU index reported by ``device_count`` has valid device info."""
        gpu_count = mx.device_count(mx.gpu)
        for i in range(gpu_count):
            # subTest makes a failure report which GPU index was bad.
            with self.subTest(gpu_index=i):
                self._assert_basic_info(mx.device_info(mx.Device(mx.gpu, i)))

    def test_device_info_default(self):
        """``device_info()`` with no argument returns a dict with a device name."""
        info = mx.device_info()
        self.assertIsInstance(info, dict)
        self.assertIn("device_name", info)
if __name__ == "__main__":
    # Run this module's tests through the shared MLX test runner.
    mlx_tests.MLXTestRunner()