from typing import Optional
import numpy as np
from ..functional.nn import embedding as embedding_func
from ..tensor import Parameter
from . import init
from .module import Module
class Embedding(Module):
    r"""Lookup-table layer that owns one ``weight`` row per index.

    The layer holds a single ``weight`` parameter whose array shape is
    ``(num_embeddings, embedding_dim)`` and passes it, together with the
    inputs, to the functional ``embedding`` op in :meth:`forward`.

    Args:
        num_embeddings: number of rows in the weight table.
        embedding_dim: number of columns in the weight table.
        padding_idx: not supported yet; anything other than ``None`` raises.
        max_norm: not supported yet; anything other than ``None`` raises.
        norm_type: not supported yet; anything other than ``None`` raises.
        initial_weight: optional starting weights. Must provide ``.numpy()``
            returning an array of shape ``(num_embeddings, embedding_dim)``;
            that array is wrapped in a new ``Parameter``. When ``None`` the
            weight is created here and initialized by
            :meth:`reset_parameters`.
        freeze: when true, :meth:`forward` uses ``self.weight.detach()``
            instead of ``self.weight``.
        **kwargs: forwarded unchanged to ``Module.__init__``.

    Raises:
        ValueError: if ``padding_idx``, ``max_norm`` or ``norm_type`` is
            given, or if ``initial_weight`` has the wrong shape.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
        initial_weight: Optional[Parameter] = None,
        freeze: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Reject the options that are accepted in the signature but not
        # implemented, so callers fail loudly instead of being ignored.
        if padding_idx is not None:
            raise ValueError("Not support padding index now.")
        if max_norm is not None or norm_type is not None:
            raise ValueError("Not support weight normalize now.")

        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.freeze = freeze

        if initial_weight is None:
            # Allocate a float32 table of the right shape, then let
            # reset_parameters() apply the layer's actual initialization.
            self.weight = Parameter(
                np.random.uniform(
                    size=(self.num_embeddings, self.embedding_dim)
                ).astype(np.float32)
            )
            self.reset_parameters()
        else:
            # Convert once and reuse the array for both the shape check and
            # the new Parameter (the conversion used to be done twice).
            weight_array = initial_weight.numpy()
            if weight_array.shape != (num_embeddings, embedding_dim):
                raise ValueError(
                    "The weight shape should match num_embeddings and embedding_dim"
                )
            self.weight = Parameter(weight_array)

    def reset_parameters(self) -> None:
        """Re-initialize ``self.weight`` via ``init.normal_`` with its defaults."""
        # NOTE(review): init.normal_ is presumably an in-place normal
        # initializer; its default mean/std are defined outside this file.
        init.normal_(self.weight)

    def forward(self, inputs):
        """Apply the functional ``embedding`` op to ``inputs`` and the weight.

        Args:
            inputs: passed through unchanged as the first argument of the
                functional op (presumably integer indices -- confirm against
                that function's documentation).

        Returns:
            Whatever the functional ``embedding`` op returns.
        """
        # A frozen layer hands over a detached weight; presumably this keeps
        # gradients from reaching the table -- confirm against detach().
        weight = self.weight.detach() if self.freeze else self.weight
        return embedding_func(inputs, weight)

    @classmethod
    def from_pretrained(
        cls,
        embeddings: Parameter,
        freeze: Optional[bool] = True,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: Optional[float] = None,
    ):
        """Build an instance whose weight starts from ``embeddings``.

        The table size is taken from ``embeddings.shape``: the first
        dimension becomes ``num_embeddings`` and the second becomes
        ``embedding_dim``.

        Args:
            embeddings: 2-dimensional weights, passed to the constructor as
                ``initial_weight``.
            freeze: forwarded to the constructor; defaults to ``True`` here
                (the constructor itself defaults to ``False``).
            padding_idx: forwarded to the constructor.
            max_norm: forwarded to the constructor.
            norm_type: forwarded to the constructor.

        Returns:
            The newly constructed instance of ``cls``.

        Raises:
            ValueError: if ``embeddings`` is not 2-dimensional, or for any
                reason the constructor rejects its arguments.
        """
        shape = embeddings.shape
        if len(shape) != 2:
            raise ValueError("Embeddings parameter is expected to be 2-dimensional")

        return cls(
            num_embeddings=shape[0],
            embedding_dim=shape[1],
            initial_weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            freeze=freeze,
        )