from collections import OrderedDict
from .module import Module
class Sequential(Module):
    """Container that holds sub-modules and applies them one after another.

    Each sub-module is stored as an attribute on the instance, and
    ``layer_keys`` records those attribute names in insertion order.
    Construct it either from positional modules, which are named ``"0"``,
    ``"1"``, ... by position, or from a single ``OrderedDict`` that maps
    attribute names to modules.
    """

    def __init__(self, *args, **kwargs):
        """Register the given sub-modules in order.

        :param args: either exactly one ``OrderedDict`` of ``name -> module``
            (its keys are used as attribute names), or any number of modules,
            which are named by their position as strings.
        :param kwargs: forwarded unchanged to ``Module.__init__``.
        """
        super().__init__(**kwargs)
        self.layer_keys = []
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            named_modules = args[0].items()
        else:
            named_modules = ((str(idx), module) for idx, module in enumerate(args))
        for key, module in named_modules:
            setattr(self, key, module)
            self.layer_keys.append(key)

    def __getitem__(self, idx):
        """Return the sub-module at position ``idx``, or a new container for a slice.

        For a slice, a new instance of ``self.__class__`` is built from an
        ``OrderedDict`` of the selected names and module objects. Keyword
        arguments originally given to this instance's constructor are not
        passed to the new one.
        """
        if isinstance(idx, slice):
            # Look up only the selected layers rather than every layer.
            selected = self.layer_keys[idx]
            return self.__class__(
                OrderedDict((key, getattr(self, key)) for key in selected)
            )
        return getattr(self, self.layer_keys[idx])

    def __setitem__(self, idx, module):
        """Replace the sub-module at integer position ``idx`` with ``module``.

        The existing attribute name at that position is reused, so
        ``layer_keys`` is unchanged. A slice index is not supported: the
        resulting list of names is not a valid attribute name, so
        ``setattr`` raises ``TypeError``.
        """
        setattr(self, self.layer_keys[idx], module)

    def __delitem__(self, idx):
        """Remove the sub-module(s) at ``idx``, which may be an int or a slice.

        Both the attribute(s) and the matching ``layer_keys`` entries are
        removed. The remaining names are not renumbered.
        """
        targets = self.layer_keys[idx]
        if not isinstance(idx, slice):
            targets = [targets]
        # Delete the attributes first; the key list is only trimmed afterwards.
        for key in targets:
            delattr(self, key)
        del self.layer_keys[idx]

    def __len__(self):
        """Return the number of registered sub-modules."""
        return len(self.layer_keys)

    def __iter__(self):
        """Iterate over a snapshot list of the sub-modules, in order."""
        return iter(self.layer_values)

    @property
    def layer_values(self):
        """List of the sub-module objects, ordered as in ``layer_keys``."""
        return [getattr(self, key) for key in self.layer_keys]

    def forward(self, inp):
        """Feed ``inp`` through every sub-module in order.

        Each layer is called with the previous layer's output, and the last
        layer's output is returned. With no layers, ``inp`` is returned
        unchanged.
        """
        for layer in self.layer_values:
            inp = layer(inp)
        return inp