# Raw operator ``set_grad``: takes ``src`` plus a ``grad_getter`` callable.
# ``grad_getter`` is listed in ``inputs_cvt`` with
# ``_mgb._SetGradCallbackPyWrapper`` as its converter.
decl_raw_opr(
    'set_grad',
    desc=('output equals to input, but grad(input) would be '
          'replaced by return value of given callback at runtime.'),
    inputs=[
        'src',
        Doc('grad_getter',
            'a function that takes the computing graph for grad '
            'computing and return grad var',
            'callable'),
    ],
    inputs_cvt=[('grad_getter', '_mgb._SetGradCallbackPyWrapper')]
)
# Raw operator ``callback_injector``: takes ``src`` plus a ``cb`` callable.
# ``cb`` is listed in ``inputs_cvt`` with
# ``_mgb._CompGraphCallbackPyWrapperNoEager`` as its converter.
decl_raw_opr(
    'callback_injector',
    desc=('arrange a callback to be called whenever this operator '
          'is executed'),
    inputs=[
        'src',
        Doc('cb',
            'the callback function, taking a value of '
            ':class:`.CompGraphCallbackValueProxy`',
            'callable'),
    ],
    inputs_cvt=[('cb', '_mgb._CompGraphCallbackPyWrapperNoEager')]
)
# Operator ``MarkNoBroadcastElemwise``: single input ``src``, declared with
# the param name 'Empty'.
decl_opr(
    'MarkNoBroadcastElemwise',
    inputs=['src'],
    params='Empty',
    desc=('assert the output var would never be broadcasted when involved '
          'in elementwise operations; this is useful to remove the reduce '
          'opr in grad so graph optimizer can work under some conditions')
)
# Operator ``Identity``: single input ``src``, declared with the param
# name 'Empty'.
decl_opr(
    'Identity',
    inputs=['src'],
    params='Empty',
    desc=('forward the input value; usually used for preventing the '
          'graph optimizer from removing certain vars so gradients can '
          'be correctly computed')
)
# Operator ``AssertEqual``: two inputs (``expect`` and ``get``), declared
# with the param name 'AssertEqual'.
decl_opr(
    'AssertEqual',
    inputs=['expect', 'get'],
    params='AssertEqual',
    desc='assert that values of the input vars are equal; used for debug'
)
# Raw operator ``timestamp``: the generated body forwards ``input``,
# ``dest``, ``dest_off`` and ``config`` to ``_mgb._Opr.timestamp`` and
# binds the result to ``output``.
decl_raw_opr(
    'timestamp',
    inputs=[
        Doc('input', 'an input var to be waited'),
        Doc('dest', 'a numpy 1-dim array to receive the output result'),
        Doc('dest_off', 'offset to write result in the dest array'),
    ],
    body=['output = _mgb._Opr.timestamp(input, dest, dest_off, config)'],
    desc=('get a timestamp when this operator is executed; this is '
          'useful for profiling a group of operators. The time measured '
          'in seconds when ``input`` is executed would be written to '
          '``dest_off`` in the ``dest`` array. Timestamp values are '
          'relative to a fixed point of current function and computing '
          'node')
)
# Raw operator ``virtual_dep``: the generated body forwards ``symvars``
# and ``config`` to ``_mgb._Opr.virtual_dep`` and binds the result to
# ``output``.
decl_raw_opr(
    'virtual_dep',
    inputs=[Doc('symvars', 'input symvars')],
    body=['output = _mgb._Opr.virtual_dep(symvars, config)'],
    desc=("Make a virtual dependency opr, to make sure inputs' "
          "operators finished when executing virtual_dep opr. "
          "Forward input(0) to output")
)
# Raw operator ``virtual_loss``: takes ``ys`` and ``y_grads``.
# ``local_defs`` injects a ``cvt_inp`` helper that calls
# ``_helper.canonize_input_vars`` with the surrounding ``comp_graph`` and
# ``config``; ``inputs_cvt`` applies it to ``y_grads`` only.
# NOTE(review): ``ys`` has no explicit converter here -- presumably it is
# handled by the generator's default input handling; confirm.
decl_raw_opr(
    'virtual_loss',
    # "w.r.t." already means "with respect to"; the previous text read
    # "w.r.t. to", which duplicated the preposition in the generated docs.
    desc='construct a loss var so that the gradients w.r.t. ``ys[i]`` are '
    '``y_grads[i]``',
    inputs=['ys', 'y_grads'],
    local_defs=['cvt_inp = lambda inp: '
                '_helper.canonize_input_vars(inp, comp_graph=comp_graph, '
                'config=config)'],
    inputs_cvt=[('y_grads', 'cvt_inp')]
)
# Operator ``PersistentOutputStorage``: single input ``inp``, declared
# with the param name 'PersistentOutputStorage'.
decl_opr(
    'PersistentOutputStorage',
    inputs=['inp'],
    params='PersistentOutputStorage',
    desc='copy input to an output var with persistent storage'
)
# Operator ``RequireInputDynamicStorage``: single input ``inp``, declared
# with the param name 'Empty'.
# NOTE(review): unlike the other declarations in this file, no ``desc`` is
# given; consider adding one once the intended wording is confirmed.
decl_opr(
    'RequireInputDynamicStorage',
    params='Empty',
    inputs=['inp']
)
# Raw operator ``shape_hint``: the generated body forwards ``input``,
# ``shape``, ``is_const`` and ``config`` to ``_mgb._Opr.shape_hint`` and
# binds the result to ``output``. ``is_const`` is documented as a bool
# with 'False' as the fourth ``Doc`` argument.
decl_raw_opr(
    'shape_hint',
    inputs=[
        Doc('input', 'input var the shape hint was on'),
        Doc('shape', 'given hint shape', 'list of int'),
        Doc('is_const',
            'whether treat given shape as constant',
            'bool',
            'False'),
    ],
    body=['output = _mgb._Opr.shape_hint(input, shape, is_const, config)'],
    desc='a special op providing shape hint only used in graph compilation'
)
# vim: ft=python