import mlx.core as mx
import mlx.optimizers as optim
import mlx_distributed_tests
import mlx_tests
from mlx.nn.utils import average_gradients, fsdp_apply_gradients
class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    """Distributed tests executed against the NCCL backend.

    Adds tests for ``sum_scatter``, group splitting, collectives on
    sub-groups and ``fsdp_apply_gradients`` on top of the tests inherited from
    ``MLXDistributedCommonTestCase``.

    ``test_groups`` asserts a world size of 8 and
    ``test_fsdp_ddp_apply_gradients`` splits the world into groups of 4, so
    the suite is written for an 8-process launch.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize the NCCL backend and set the class-level tolerances."""
        # NOTE(review): strict=True presumably makes initialization fail
        # loudly when NCCL is unavailable instead of falling back to a
        # single-process group -- confirm against the mx.distributed.init docs.
        _ = mx.distributed.init(strict=True, backend="nccl")
        # NOTE(review): atol/rtol are not read in this class; presumably they
        # are consumed by the inherited common tests -- verify in the base.
        cls.atol = 1e-4
        cls.rtol = 1e-4

    @staticmethod
    def _max_rel_error(actual, expected, rtol):
        """Return the largest element-wise error of ``actual`` vs ``expected``.

        The absolute error is divided element-wise by ``|expected|`` when
        ``rtol > 0``; otherwise the maximum absolute error is returned.
        """
        error = (actual - expected).abs()
        if rtol > 0:
            error = error / expected.abs()
        return error.max()

    def test_sum_scatter(self):
        """Compare ``sum_scatter`` with this rank's slice of ``all_sum``."""
        world = mx.distributed.init()

        # (dtype, maximum tolerated relative error) pairs.
        dtypes = [
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (8,),
            (64,),
            (1024,),
            (1024, 1024),
        ]
        # The key is derived from the rank, so every rank draws its own input.
        key = mx.random.key(world.rank())

        for dt, rtol in dtypes:
            for sh in sizes:
                x = (mx.random.uniform(shape=sh, key=key) * 10).astype(dt)

                y = mx.distributed.sum_scatter(x)
                z = mx.distributed.all_sum(x)

                # The reference is the rank-th equally sized chunk of the full
                # sum, taken along the first axis.
                chunk = sh[0] // world.size()
                start = world.rank() * chunk
                stop = start + chunk
                z_ref = z[start:stop]

                self.assertLessEqual(self._max_rel_error(y, z_ref, rtol), rtol)

    def test_groups(self):
        """Check world size/rank and the groups produced by ``split``."""
        world = mx.distributed.init()
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        # A second init() must report the same size and rank.
        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        # Splitting on parity: two groups of four, ranked by world rank // 2.
        sub = world.split(world.rank() % 2)
        self.assertEqual(sub.size(), 4)
        self.assertEqual(sub.rank(), world.rank() // 2)

        # Splitting on rank // 2: four groups of two.
        sub = world.split(world.rank() // 2)
        self.assertEqual(sub.size(), 2)

    def test_all_reduce_split(self):
        """Check ``all_sum``/``all_max``/``all_min`` on a parity sub-group."""
        world = mx.distributed.init()

        # (dtype, maximum tolerated relative error) pairs.
        dtypes = [
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        # Same key on every rank: each rank builds the identical stacked
        # input and contributes only the row matching its group rank, so the
        # local reduction over axis 0 serves as the reference.
        key = mx.random.key(0)
        group = world.split(world.rank() % 2)

        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(group.size(),) + sh, key=key) * 10
                ).astype(dt)

                # All sum
                y = mx.distributed.all_sum(x[group.rank()], group=group)
                z = x.sum(0)
                self.assertLessEqual(self._max_rel_error(y, z, rtol), rtol)

                # All max
                y = mx.distributed.all_max(x[group.rank()], group=group)
                z = x.max(0)
                self.assertTrue(mx.all(y == z))

                # All min
                y = mx.distributed.all_min(x[group.rank()], group=group)
                z = x.min(0)
                self.assertTrue(mx.all(y == z))

    def test_all_gather_split(self):
        """Check the shape and content of ``all_gather`` on a sub-group."""
        world = mx.distributed.init()
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        sub = world.split(world.rank() % 2)

        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_gather(x, group=sub)
            # The gathered result is expected to grow along the first axis.
            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
            self.assertTrue(mx.all(y == 1))

    def test_fsdp_apply_gradients(self):
        """Check ``fsdp_apply_gradients`` on the default group.

        Covers three cases: a plain SGD update, an update with gradient
        clipping (``max_norm``), and agreement with the
        ``average_gradients`` + ``apply_gradients`` path.
        """
        world = mx.distributed.init()
        N = world.size()

        params = {
            "w1": mx.ones((N * 10, 8)),
            "w2": mx.ones((N * 20,)),
        }
        grads = {
            "w1": mx.ones((N * 10, 8)) * 0.1,
            "w2": mx.ones((N * 20,)) * 0.1,
        }
        optimizer = optim.SGD(learning_rate=0.1)

        # Plain update: called without max_norm, the result is used directly
        # as the updated parameter tree.
        updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer)
        mx.eval(updated_params_fsdp)

        self.assertEqual(updated_params_fsdp["w1"].shape, (N * 10, 8))
        self.assertEqual(updated_params_fsdp["w2"].shape, (N * 20,))

        # Expected value: 1 - lr * grad = 1 - 0.1 * 0.1 = 0.99.
        self.assertTrue(
            mx.allclose(
                updated_params_fsdp["w1"], mx.ones((N * 10, 8)) * 0.99, atol=1e-6
            )
        )
        self.assertTrue(
            mx.allclose(updated_params_fsdp["w2"], mx.ones((N * 20,)) * 0.99, atol=1e-6)
        )

        # Clipped update: with max_norm the call yields (params, grad_norm).
        grads = {
            "w1": mx.ones((N * 10, 8)) * 10.0,
            "w2": mx.ones((N * 20,)) * 10.0,
        }
        new_params_clipped, grad_norm = fsdp_apply_gradients(
            grads, params, optimizer, max_norm=1.0
        )
        mx.eval(new_params_clipped, grad_norm)

        self.assertIsNotNone(grad_norm)
        # L2 norm over all elements, each with gradient 10 (10**2 == 100).
        expected_norm = mx.sqrt((N * 10 * 8 + N * 20) * 100.0)
        self.assertTrue(mx.allclose(grad_norm, expected_norm, atol=1e-4, rtol=1e-4))

        self.assertEqual(new_params_clipped["w1"].shape, (N * 10, 8))
        self.assertEqual(new_params_clipped["w2"].shape, (N * 20,))

        # Gradients are expected to be rescaled by max_norm / norm before the
        # SGD step: 1 - lr * grad * scale.
        scale = 1.0 / expected_norm
        expected_update = 1.0 - 0.1 * 10.0 * scale
        self.assertTrue(
            mx.allclose(
                new_params_clipped["w1"],
                mx.ones((N * 10, 8)) * expected_update,
                atol=1e-4,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                new_params_clipped["w2"],
                mx.ones((N * 20,)) * expected_update,
                atol=1e-4,
                rtol=1e-4,
            )
        )

        # Equivalence: the FSDP path must match averaging the gradients and
        # applying them with an identically configured optimizer.
        params = {"w": mx.ones((N * 4,))}
        grads = {"w": mx.ones((N * 4,)) * 0.5}

        optimizer_fsdp = optim.SGD(learning_rate=0.1)
        updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer_fsdp)

        optimizer_ddp = optim.SGD(learning_rate=0.1)
        avg_grads = average_gradients(grads)
        updated_params_ddp = optimizer_ddp.apply_gradients(avg_grads, params)

        mx.eval(updated_params_ddp, updated_params_fsdp)
        self.assertTrue(
            mx.allclose(
                updated_params_fsdp["w"], updated_params_ddp["w"], atol=1e-6, rtol=1e-4
            ),
        )

    def test_fsdp_ddp_apply_gradients(self):
        """Check ``fsdp_apply_gradients`` with separate FSDP and DP groups.

        The world is split into FSDP groups of ``S`` ranks and data-parallel
        groups of ``N // S`` ranks; the same three cases as in
        ``test_fsdp_apply_gradients`` are then exercised with both groups
        passed explicitly.
        """
        world = mx.distributed.init()
        N = world.size()
        S = 4

        # Ranks sharing rank // S form one FSDP group; ranks sharing rank % S
        # form one data-parallel group.
        fsdp_group = world.split(world.rank() // S)
        dp_group = world.split(world.rank() % S)
        self.assertEqual(fsdp_group.size(), S)
        self.assertEqual(dp_group.size(), N // S)

        params = {
            "w1": mx.ones((S * 10, 8)),
            "w2": mx.ones((S * 20,)),
        }
        grads = {
            "w1": mx.ones((S * 10, 8)) * 0.1,
            "w2": mx.ones((S * 20,)) * 0.1,
        }
        optimizer = optim.SGD(learning_rate=0.1)

        # Plain update.
        updated = fsdp_apply_gradients(
            grads,
            params,
            optimizer,
            fsdp_group=fsdp_group,
            dp_group=dp_group,
        )
        mx.eval(updated)

        self.assertEqual(updated["w1"].shape, (S * 10, 8))
        self.assertEqual(updated["w2"].shape, (S * 20,))

        # Expected value: 1 - lr * grad = 1 - 0.1 * 0.1 = 0.99.
        self.assertTrue(
            mx.allclose(updated["w1"], mx.ones((S * 10, 8)) * 0.99, atol=1e-6)
        )
        self.assertTrue(
            mx.allclose(updated["w2"], mx.ones((S * 20,)) * 0.99, atol=1e-6)
        )

        # Clipped update: with max_norm the call yields (params, grad_norm).
        grads_big = {
            "w1": mx.ones((S * 10, 8)) * 10.0,
            "w2": mx.ones((S * 20,)) * 10.0,
        }
        optimizer2 = optim.SGD(learning_rate=0.1)
        clipped, grad_norm = fsdp_apply_gradients(
            grads_big,
            params,
            optimizer2,
            fsdp_group=fsdp_group,
            dp_group=dp_group,
            max_norm=1.0,
        )
        mx.eval(clipped, grad_norm)

        self.assertIsNotNone(grad_norm)
        # L2 norm over all elements, each with gradient 10 (10**2 == 100).
        expected_norm = mx.sqrt((S * 10 * 8 + S * 20) * 100.0)
        self.assertTrue(mx.allclose(grad_norm, expected_norm, atol=1e-4, rtol=1e-4))

        self.assertEqual(clipped["w1"].shape, (S * 10, 8))
        self.assertEqual(clipped["w2"].shape, (S * 20,))

        # Gradients are expected to be rescaled by max_norm / norm before the
        # SGD step: 1 - lr * grad * scale.
        scale = 1.0 / expected_norm
        expected_update = 1.0 - 0.1 * 10.0 * scale
        self.assertTrue(
            mx.allclose(
                clipped["w1"],
                mx.ones((S * 10, 8)) * expected_update,
                atol=1e-4,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                clipped["w2"],
                mx.ones((S * 20,)) * expected_update,
                atol=1e-4,
                rtol=1e-4,
            )
        )

        # Equivalence: the hybrid path must match averaging the gradients
        # (default group) and applying them with an identical optimizer.
        params_eq = {"w": mx.ones((S * 4,))}
        grads_eq = {"w": mx.ones((S * 4,)) * 0.5}

        optimizer_hybrid = optim.SGD(learning_rate=0.1)
        updated_hybrid = fsdp_apply_gradients(
            grads_eq,
            params_eq,
            optimizer_hybrid,
            fsdp_group=fsdp_group,
            dp_group=dp_group,
        )

        optimizer_ddp = optim.SGD(learning_rate=0.1)
        avg_grads = average_gradients(grads_eq)
        updated_ddp = optimizer_ddp.apply_gradients(avg_grads, params_eq)

        mx.eval(updated_hybrid, updated_ddp)
        self.assertTrue(
            mx.allclose(updated_hybrid["w"], updated_ddp["w"], atol=1e-6, rtol=1e-4),
        )

    def test_fsdp_peak_memory(self):
        """Assert the FSDP step has a lower peak memory than the DDP step.

        Runs ten clipped Adam steps through each path, resetting the
        peak-memory counter before each loop, and compares the recorded peaks.
        """
        world = mx.distributed.init()
        N = world.size()

        mx.random.seed(42)
        params = {
            "w1": mx.random.normal((N * 1024, 1024)),
            "w2": mx.random.normal((N * 2048, 512)),
        }
        grads = {
            "w1": mx.random.normal((N * 1024, 1024)),
            "w2": mx.random.normal((N * 2048, 512)),
        }
        # Materialize the inputs before any peak-memory measurement starts.
        mx.eval(params, grads)

        optimizer_ddp = optim.Adam(learning_rate=0.01)
        optimizer_fsdp = optim.Adam(learning_rate=0.01)

        def pseudo_step_ddp(grads, params, optimizer):
            """Average the gradients, clip them, then apply the optimizer."""
            grads = average_gradients(grads)
            grads, grad_norm = optim.clip_grad_norm(grads, max_norm=1.0)
            params = optimizer.apply_gradients(grads, params)
            return grad_norm, params

        def pseudo_step_fsdp(grads, params, optimizer):
            """Perform the clipped update through ``fsdp_apply_gradients``."""
            params, grad_norm = fsdp_apply_gradients(
                grads, params, optimizer, max_norm=1.0
            )
            return grad_norm, params

        mx.reset_peak_memory()
        for _ in range(10):
            grad_norm, params = pseudo_step_ddp(grads, params, optimizer_ddp)
            mx.eval(grad_norm, params)
        ddp_peak_memory = mx.get_peak_memory()

        # The FSDP loop continues from the parameters left by the DDP loop.
        mx.reset_peak_memory()
        for _ in range(10):
            grad_norm, params = pseudo_step_fsdp(grads, params, optimizer_fsdp)
            mx.eval(grad_norm, params)
        fsdp_peak_memory = mx.get_peak_memory()

        # assertLess reports both values when the comparison fails.
        self.assertLess(fsdp_peak_memory, ddp_peak_memory)
if __name__ == "__main__":
    # Hand control to the MLX test runner when executed as a script.
    mlx_tests.MLXTestRunner()