import numpy as np
import pytest
from megengine import tensor
from megengine.module import Conv1d, Conv2d, Conv3d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out, fill_
def test_fill_():
    """Check that ``fill_`` sets every element of a tensor to the given scalar.

    The tensor starts as zeros and is inspected after the call without being
    reassigned, so the test relies on ``fill_`` modifying its argument in place.
    """
    shape = (2, 3, 4)
    target = tensor(np.zeros(shape), dtype=np.float32)

    fill_(target, 5.0)

    expected = np.full(shape=shape, fill_value=5.0, dtype=np.float32)
    np.testing.assert_array_equal(target.numpy(), expected)
def test_calculate_fan_in_and_fan_out():
    """Check fan-in / fan-out computed from the weights of several layers.

    Covers a Linear layer, 1-D and 2-D convolutions, and grouped 2-D and 3-D
    convolutions, plus the error raised when a Linear bias is passed instead
    of a weight.
    """

    def expect_fans(module, want_in, want_out):
        # Compute both fans from the module's weight and compare to expectations.
        got_in, got_out = calculate_fan_in_and_fan_out(module.weight)
        assert got_in == want_in
        assert got_out == want_out

    linear = Linear(in_features=3, out_features=8)
    expect_fans(linear, 3, 8)

    # Passing the bias tensor must be rejected with ValueError.
    with pytest.raises(ValueError):
        calculate_fan_in_and_fan_out(linear.bias)

    # Plain convolutions: channels multiplied by the kernel volume.
    expect_fans(Conv1d(in_channels=2, out_channels=3, kernel_size=5), 2 * 5, 3 * 5)
    expect_fans(
        Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)),
        2 * 5 * 7,
        3 * 5 * 7,
    )

    # Grouped convolutions: expected values use the per-group channel counts.
    expect_fans(
        Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2),
        2 // 2 * 5 * 7,
        4 // 2 * 5 * 7,
    )
    expect_fans(
        Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2),
        2 // 2 * 5 * 7 * 9,
        4 // 2 * 5 * 7 * 9,
    )