# Declaration of the raw operator 'collective_comm'.
#
# `inputs` lists the documented parameters of the operator, `body` holds the
# source lines that form its implementation, and `desc` is the operator's
# summary text.
#
# NOTE(review): `config` is referenced by the body lines but is not declared in
# `inputs`; presumably the generator supplies it -- confirm.
decl_raw_opr(
    "collective_comm",
    inputs=[
        Doc("input", "Input var."),
        Doc(
            "key",
            "The key to NCCL cliques. Operators with same key "
            "belong to the same NCCL operation.",
            "str",
        ),
        Doc(
            "nr_devices",
            "Total number of devices involved in the NCCL operation "
            "to which this operator belongs.",
            "int",
        ),
        Doc("is_root", "whether this node is root node", "bool"),
        Doc("rank", "rank of this node, if is -1, generate one", "int"),
        # NOTE(review): unlike the neighbouring scalar inputs, the next two
        # entries carry no type string -- confirm whether 'str' / 'int' were
        # intended.
        Doc("server_addr", "rpc server ip address"),
        Doc("port", "server rpc listening port"),
        Doc(
            "param",
            "The only component of *param* is *mode*, which refers "
            "to a specific NCCL operation type.",
            ":class:`~megbrain.opr_param_defs.CollectiveComm`",
        ),
        Doc(
            "dtype",
            "Data type of inputs and outputs. Currently this is required "
            "by BROADCAST and optional to other operations. If specified, "
            "it must be consistent with the *dtype* of inputs (if any).",
            ":class:`~megbrain.opr_param_defs.DType`",
            "None",
        ),
        Doc(
            "backend",
            "Backend for collective communication, nccl or ucx",
            "str",
            "'nccl'",
        ),
        Doc("local_grad", "whether use local grad", "bool", "False"),
        Doc(
            "output_buffer",
            "The external dev buffer reserving output result",
            ":class:`.SharedND`",
            "None",
        ),
        Doc(
            "disable",
            "If true, the execution will return directly and the "
            "output is a random value. All the disable should be same "
            "in one collective communication group.",
            ":class:`.SharedScalar`",
            "_mgb.SharedScalar(0)",
        ),
    ],
    body=[
        # A SymbolVar input goes to the *_with_input entry point.
        "if isinstance(input, _mgb.SymbolVar):",
        (
            " output = _mgb._Opr.collective_comm_with_input("
            "input, key, nr_devices, is_root, rank, local_grad, "
            "server_addr, port, [param.serialize()], dtype, backend, "
            "output_buffer, config, disable)"
        ),
        # Otherwise `input` must be a CompGraph and the *_without_input
        # entry point is used with the same argument list.
        "else:",
        " assert isinstance(input, _mgb.CompGraph)",
        (
            " output = _mgb._Opr.collective_comm_without_input("
            "input, key, nr_devices, is_root, rank, local_grad, "
            "server_addr, port, [param.serialize()], dtype, backend, "
            "output_buffer, config, disable)"
        ),
    ],
    desc=(
        "collective communication between multiple CompNodes "
        "on multiple machines"
    ),
)
# vim: ft=python