#include "models.h"
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model),
n_embd_head(model.hparams.n_embd_head_k()),
n_embd_altup(model.hparams.n_embd_altup),
n_altup(model.hparams.n_altup),
i_altup_act(model.hparams.i_altup_act) {
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
cb(inpL, "inp_scaled", -1);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_per_layer = build_inp_per_layer();
ggml_build_forward_expand(gf, inp_per_layer);
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
{
ggml_tensor * target_magnitude = calc_magnitude(inpL);
ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
ggml_tensor * altup_added =
ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); ggml_tensor * new_magnitude = calc_magnitude(altup_added);
altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
inpL = ggml_concat(ctx0, inpL, altup_added, 2); cb(inpL, "inp_stacked", -1);
}
for (int il = 0; il < n_layer; ++il) {
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
ggml_tensor * cur = inpL; ggml_tensor * predictions = altup_predict(cur, il);
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); cur = active_prediction;
cb(cur, "active_prediction", il);
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_tensor * laurel_out = laurel(cur, il);
if (hparams.has_kv(il)) {
auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
cb(Vcur, "Vcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo,
NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
hparams.f_attention_scale, il);
} else {
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL, model.layers[il].wo_s,
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
}
cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
cur = ggml_add(ctx0, cur, active_prediction); cb(cur, "attn_gated", il);
ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
1.0f / sqrtf(2.0f)); cb(attn_laurel, "attn_laurel", il);
cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
{
ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
if (il < n_layer_sparsity) {
gate_proj = gaussian_topk(gate_proj);
}
gate_proj = ggml_gelu(ctx0, gate_proj);
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = build_lora_mm(model.layers[il].ffn_down, cur);
cb(cur, "ffn_out", il);
}
cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il);
ggml_tensor * first_prediction; {
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
first_prediction = ggml_gelu(ctx0, first_prediction); cb(first_prediction, "first_prediction_gated", il);
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); cb(first_prediction, "first_prediction_scaled", il);
first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); first_prediction =
build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
cb(first_prediction, "first_prediction_out", il);
}
{
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
ggml_tensor * slice_rest = ggml_view_3d(
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); corrected = ggml_concat(ctx0, slice_first, tmp, 2); }
cur = corrected; cur = build_cvec(cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = inpL;
{
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); ggml_tensor * alt_slice =
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
ggml_tensor * altup_unembd =
ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
cb(altup_unembd, "altup_unembd", -1);
cur = ggml_view_2d_slice(ctx0, cur, 0); for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); cb(cur, "unembd_merged", -1);
}
{
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
{
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
ggml_tensor * inp_per_layer;
float tok_embd_scale = sqrtf((float) n_embd_altup);
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_inp_tokens = inp->tokens;
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale);
cb(inp_per_layer, "inp_per_layer_selected", -1);
res->add_input(std::move(inp));
} else {
const int64_t embd_size = model.per_layer_tok_embd->ne[0];
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
cb(inp_per_layer, "inp_per_layer_multimodal", -1);
}
return inp_per_layer;
}
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
ggml_tensor * per_layer_proj;
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
cb(per_layer_proj, "per_layer_proj", -1);
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);
inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
return inp_per_layer;
}
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
ggml_tensor * tmp = cur;
tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
tmp = ggml_add(ctx0, tmp, cur);
cb(tmp, "laurel_out", il);
return tmp;
}
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
ggml_tensor * mean = ggml_mean(ctx0, x);
ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
1.0f / (float) (x->ne[0] - 1)));
ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);
router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);
ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
return ggml_tanh(ctx0, output); }
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); ggml_tensor * modalities = altup_compute_router_modalities(activated, il); cb(modalities, "modalities", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
cb(all_coefs, "all_coefs", il);
all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs);
predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
predictions = ggml_add(ctx0, predictions, cur);
cb(predictions, "predictions", il);
return predictions;
}
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); cb(modalities, "modalities", il);
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); cb(innovation, "innovation", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); cb(all_coefs, "all_coefs", il);
all_coefs = ggml_transpose(ctx0, all_coefs); all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup);
innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); corrected = ggml_add(ctx0, corrected, predictions); cb(corrected, "corrected", il);
return corrected;
}