import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
logger = logging.getLogger(__name__)
class TrainingState(NamedTuple):
    """Named tuple with a single field holding a Haiku parameter tree."""

    # Haiku parameters (annotated as `hk.Params`).
    params: hk.Params
def ffn_size(emb_size, widening_factor):
    """Return the feed-forward hidden width for a given embedding size.

    The width is ``int(widening_factor * emb_size) * 2 // 3`` rounded up to
    the next multiple of 8 (values already divisible by 8 are unchanged).

    Args:
        emb_size: Embedding (model) dimension.
        widening_factor: Multiplier applied to ``emb_size`` before the 2/3
            scaling.

    Returns:
        The adjusted hidden width as an ``int``.
    """
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    # Pad up to the next multiple of 8; adds 0 when already aligned.
    _ffn_size += (-_ffn_size) % 8
    # Lazy %-args: the message is only formatted when DEBUG is enabled.
    logger.debug("emb_size: %s adjusted ffn_size: %s", emb_size, _ffn_size)
    return _ffn_size
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Build a causal attention mask in which candidates do not see each other.

    Starts from a lower-triangular (causal) mask, zeroes the block where both
    the query and the key position are at or after ``candidate_start_offset``,
    then re-enables the diagonal of that block so each candidate position
    still attends to itself.

    Args:
        seq_len: Length of the sequence (rows and columns of the mask).
        candidate_start_offset: Index at which the candidate positions begin.
        dtype: Element type of the returned mask.

    Returns:
        Array of shape ``(1, 1, seq_len, seq_len)`` with entries 0 or 1.
    """
    full = jnp.ones((1, 1, seq_len, seq_len), dtype=dtype)
    mask = jnp.tril(full)

    # Remove all candidate -> candidate attention.
    candidates = slice(candidate_start_offset, None)
    mask = mask.at[:, :, candidates, candidates].set(0)

    # Restore self-attention for every candidate position.
    diag = jnp.arange(candidate_start_offset, seq_len)
    return mask.at[:, :, diag, diag].set(1)
class MHAOutput(NamedTuple):
    """Result tuple returned by the attention modules in this file."""

    # Output embeddings produced by the attention computation.
    embeddings: jax.Array
class DecoderOutput(NamedTuple):
    """Result tuple returned by a decoder layer."""

    # Embeddings after the layer's residual updates.
    embeddings: jax.Array
class TransformerOutput(NamedTuple):
    """Result tuple returned by the transformer stack."""

    # Embeddings produced by the last decoder layer.
    embeddings: jax.Array
@dataclass
class TransformerConfig:
    """Hyperparameters from which a ``Transformer`` module is built."""

    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    widening_factor: float = 4.0
    attn_output_multiplier: float = 1.0
    name: Optional[str] = None

    def make(self) -> "Transformer":
        """Instantiate a ``Transformer`` configured from this object.

        Returns:
            A new ``Transformer`` built from the attention, width and depth
            settings stored on this config.
        """
        # NOTE(review): `emb_size` and `name` are not forwarded to the
        # Transformer; confirm this is intentional before relying on them.
        transformer_kwargs = dict(
            num_layers=self.num_layers,
            key_size=self.key_size,
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        return Transformer(**transformer_kwargs)
def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
) -> jax.Array:
    """Apply a freshly created ``RMSNorm`` over the last axis of ``x``.

    Args:
        x: Input array to normalize.
        fixed_scale: When true, the norm is created without a scale parameter.

    Returns:
        The normalized array.
    """
    learn_scale = not fixed_scale
    return RMSNorm(axis=-1, create_scale=learn_scale)(x)
class Linear(hk.Linear):
    """Linear layer whose parameters are stored in float32.

    Parameters are cast to the dtype of the input for the forward pass, and
    both weight and bias use a constant-zero initializer.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        """Create the layer.

        Args:
            output_size: Size of the last output dimension.
            with_bias: Whether a bias vector is added to the output.
            name: Optional Haiku module name.
        """
        super().__init__(output_size=output_size, with_bias=with_bias, name=name)

    def __call__(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Project ``inputs`` along its last axis.

        Args:
            inputs: Array with at least one dimension.

        Returns:
            Array with the last dimension replaced by ``output_size``, in the
            dtype of ``inputs``.

        Raises:
            ValueError: If ``inputs`` has an empty shape (is a scalar).
        """
        compute_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        weight = hk.get_parameter(
            "w",
            [inputs.shape[-1], self.output_size],
            jnp.float32,
            init=hk.initializers.Constant(0),
        )
        result = jnp.dot(inputs, weight.astype(compute_dtype))
        if not self.with_bias:
            return result

        bias = hk.get_parameter(
            "b",
            [self.output_size],
            jnp.float32,
            init=hk.initializers.Constant(0),
        )
        bias = jnp.broadcast_to(bias, result.shape)
        return result + bias.astype(compute_dtype)
class RMSNorm(hk.RMSNorm):
    """RMS normalization computed in float32 over the last axis.

    ``axis`` is passed to the base class, but ``__call__`` always reduces over
    the last axis. The optional scale parameter is float32 and uses a
    constant-zero initializer.
    """

    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        """Create the norm.

        Args:
            axis: Axis specification forwarded to ``hk.RMSNorm``.
            eps: Constant added to the mean square before the reciprocal sqrt.
            name: Optional Haiku module name.
            create_scale: Whether a "scale" parameter is created.
        """
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        """Normalize ``inputs`` and return the result in the input dtype."""
        compute_dtype = inputs.dtype

        if self.create_scale:
            gain = hk.get_parameter(
                "scale",
                (inputs.shape[-1],),
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            gain = jnp.broadcast_to(gain.astype(jnp.float32), inputs.shape)
        else:
            gain = 1.0

        # All statistics are computed in float32, whatever the input dtype.
        x32 = inputs.astype(jnp.float32)
        gain = jnp.float32(gain)
        mean_sq = jnp.mean(jnp.square(x32), axis=[-1], keepdims=True)
        mean_sq = jnp.broadcast_to(mean_sq, x32.shape)
        normalized = x32 * jax.lax.rsqrt(mean_sq + self.eps)
        return (gain * normalized).astype(compute_dtype)
def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Swap the two halves of the last axis, negating the second half.

    For a last axis laid out as ``[a, b]`` (two equal halves) the result is
    ``[-b, a]``.
    """
    halves = jnp.split(x, 2, axis=-1)
    front, back = halves
    return jnp.concatenate([-back, front], axis=-1)
class RotaryEmbedding(hk.Module):
    """Applies rotary position embeddings to an array of attention heads.

    The phase for position ``p`` and frequency index ``j`` is
    ``p / base_exponent ** (2j / dim)``. The phase is inserted with a
    singleton axis at index 2, so ``x`` is expected to be laid out as
    ``(batch, seq, heads, dim)`` (as produced by the attention projections in
    this file).
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        """Create the module.

        Args:
            dim: Size of the rotated (last) dimension; must be even.
            name: Optional Haiku module name.
            base_exponent: Base of the geometric frequency progression.
        """
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        assert self.dim % 2 == 0

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        """Rotate ``x`` according to its positions.

        Args:
            x: Array to rotate; its last dimension must equal ``dim``.
            seq_dim: Index of the sequence axis of ``x`` (used for its length).
            offset: Scalar or per-batch position offset added to ``arange``.
            const_position: If given (including ``0``), every position uses
                this single value instead of ``arange + offset``.
            t: Optional explicit 2-D positions ``(batch, seq)``; only used
                when ``const_position`` is ``None``.

        Returns:
            The rotated array, cast back to the dtype of ``x``.
        """
        fprop_dtype = x.dtype
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Promote a scalar offset to a length-1 batch of offsets.
            offset = jnp.expand_dims(offset, 0)

        # Compare against None explicitly: a truthiness test would silently
        # ignore const_position=0 and fall back to arange-based positions.
        if const_position is not None:
            t = const_position * jnp.ones(
                (1, x.shape[seq_dim]),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)

        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        # Duplicate the frequencies to cover both halves of the last axis and
        # add a singleton axis so the phase broadcasts across heads.
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)
        return x
class MultiHeadAttention(hk.Module):
    """Grouped-query multi-head attention with rotary position embeddings.

    Queries use ``num_q_heads`` heads while keys and values use
    ``num_kv_heads`` heads; each key/value head is shared by
    ``num_q_heads // num_kv_heads`` query heads. Logits are multiplied by
    ``attn_output_multiplier`` and soft-capped with ``30 * tanh(x / 30)``
    before masking and softmax.
    """

    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        """Create the module.

        Args:
            num_q_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; must divide
                ``num_q_heads``.
            key_size: Per-head size of queries and keys.
            with_bias: Stored on the instance; the projections in this class
                are always created with ``with_bias=False``.
            value_size: Per-head size of values; defaults to ``key_size``.
            model_size: Output size; defaults to ``key_size * num_q_heads``.
            attn_output_multiplier: Scale applied to the attention logits.
            name: Optional Haiku module name.
        """
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: Optional[jax.Array],
    ) -> MHAOutput:
        """Compute attention of ``query`` over ``key`` / ``value``.

        Args:
            query: Query inputs, ``(batch, q_len, features)``.
            key: Key inputs, ``(batch, kv_len, features)``.
            value: Value inputs; first two dimensions must match ``key``.
            mask: Optional 4-D mask ``(batch or 1, 1, q_len or 1,
                kv_len or 1)``; positions where it is zero/false are
                suppressed. ``None`` disables masking.

        Returns:
            ``MHAOutput`` whose embeddings have last dimension ``model_size``.

        Raises:
            AssertionError: If the input or mask shapes are inconsistent.
            ValueError: If the expanded mask rank differs from the logits rank.
        """
        projection = self._linear_projection

        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert key.shape[0] in {
                1,
                query.shape[0],
            }, f"key/query shape: {key.shape}/{query.shape}"
            assert mask.shape[1] == 1
            assert mask.shape[2] in {
                1,
                query.shape[1],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert mask.shape[3] in {
                1,
                key.shape[1],
            }, f"mask/query shape: {mask.shape}/{key.shape}"

        assert self.num_q_heads % self.num_kv_heads == 0

        query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
        key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
        value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0)
        query_heads = rotate(query_heads, seq_dim=1, offset=0)

        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"

        # Group the query heads by the key/value head they share.
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        # Soft-cap the logits to the open interval (-30, 30).
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        if mask is not None:
            # Add the query-group axis only when a mask exists; indexing
            # before this check would raise TypeError for mask=None.
            mask = mask[:, :, None, :, :]
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)

        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)

        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        leading_dims = attn.shape[:2]
        # Merge all head axes back into a single feature axis.
        attn = jnp.reshape(attn, (*leading_dims, -1))

        final_projection = Linear(self.model_size, with_bias=False)
        return MHAOutput(final_projection(attn))

    @hk.transparent
    def _linear_projection(
        self,
        x: jax.Array,
        head_size: int,
        num_heads: int,
        name: Optional[str] = None,
    ) -> jax.Array:
        """Project ``x`` and split the last axis into ``(num_heads, head_size)``."""
        y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
        *leading_dims, _ = x.shape
        return y.reshape((*leading_dims, num_heads, head_size))
@dataclass
class MHABlock(hk.Module):
    """Self-attention block: one input serves as query, key and value."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,
        mask: jax.Array,
    ) -> MHAOutput:
        """Run self-attention over ``inputs``.

        Args:
            inputs: 3-D array; the last dimension is used as the output size
                of the attention module.
            mask: 4-D mask whose last two dimensions are each either 1 or the
                size of ``inputs`` along axis 1.

        Returns:
            ``MHAOutput`` with the attention embeddings.
        """
        _, seq_len, model_size = inputs.shape
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, seq_len}, str(mask.shape)
        assert mask.shape[3] in {1, seq_len}, str(mask.shape)

        attention = MultiHeadAttention(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            model_size=model_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        # Query, key and value all come from the same tensor.
        attended = attention(inputs, inputs, inputs, mask)
        return MHAOutput(embeddings=attended.embeddings)
@dataclass
class DenseBlock(hk.Module):
    """Gated feed-forward block: ``Linear(gelu(Linear(x)) * Linear_v(x))``.

    ``num_q_heads``, ``num_kv_heads`` and ``key_size`` are accepted as fields
    but are not read by ``__call__``.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Apply the gated feed-forward transform to a 3-D ``inputs`` array.

        Returns:
            Array whose last dimension equals that of ``inputs``.
        """
        _, _, model_size = inputs.shape

        # The three Linear modules are created in this order on purpose:
        # the unnamed ones receive their Haiku names by creation order.
        value_proj = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
            name="linear_v",
        )
        value_branch = value_proj(inputs)

        gate_proj = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
        )
        gate_branch = jax.nn.gelu(gate_proj(inputs))

        output_proj = Linear(model_size, with_bias=False)
        return output_proj(gate_branch * value_branch)
@dataclass
class DecoderLayer(hk.Module):
    """One decoder layer: attention and feed-forward, each as a residual.

    Each sub-block is RMS-normalized on both its input and its output before
    the result is added to the residual stream. ``num_layers`` and
    ``layer_index`` are stored but not read by ``__call__``.
    """

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,
        mask: jax.Array,
        padding_mask: Optional[jax.Array],
    ) -> DecoderOutput:
        """Apply the layer to ``inputs``.

        Args:
            inputs: Residual-stream embeddings entering the layer.
            mask: Attention mask forwarded to the attention block.
            padding_mask: Accepted for interface compatibility; unused.

        Returns:
            ``DecoderOutput`` holding the updated embeddings.
        """
        del padding_mask  # Not used by this layer.

        hidden = inputs

        # Attention sub-block: norm -> attention -> norm -> residual add.
        attention = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )
        attn_embeddings = attention(hk_rms_norm(hidden), mask).embeddings
        attn_embeddings = hk_rms_norm(attn_embeddings)
        hidden += attn_embeddings

        # Feed-forward sub-block: norm -> dense -> norm -> residual add.
        feed_forward = DenseBlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            widening_factor=self.widening_factor,
        )
        dense_embeddings = feed_forward(hk_rms_norm(hidden))
        dense_embeddings = hk_rms_norm(dense_embeddings)
        hidden += dense_embeddings

        return DecoderOutput(embeddings=hidden)
def layer_norm(x):
    """Normalize ``x`` with ``hk_rms_norm`` using its default settings."""
    normalized = hk_rms_norm(x)
    return normalized
@dataclass
class Transformer(hk.Module):
    """Stack of ``num_layers`` decoder layers sharing one attention mask."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,
        mask: jax.Array,
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Run the decoder stack over ``embeddings``.

        Args:
            embeddings: 3-D input embeddings; axis 1 is the sequence axis.
            mask: Per-position mask, indexed as 2-D here (presumably
                ``(batch, seq_len)`` - confirm against the caller).
            candidate_start_offset: When given, the recsys mask (candidates
                cannot attend to each other) is used instead of a plain
                causal mask.

        Returns:
            ``TransformerOutput`` with the embeddings of the last layer.
        """
        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape

        padding_mask = mask.copy()
        key_mask = mask[:, None, None, :]

        # Structural (position-only) part of the mask, shape (1, 1, S, S).
        if candidate_start_offset is None:
            structure = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(fprop_dtype)
        else:
            structure = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
        attn_mask = key_mask * structure

        hidden = embeddings
        for layer_idx in range(self.num_layers):
            layer = DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=f"decoder_layer_{layer_idx}",
                layer_index=layer_idx,
            )
            hidden = layer(hidden, attn_mask, padding_mask).embeddings

        return TransformerOutput(embeddings=hidden)