#ifndef __MMDIT_HPP__
#define __MMDIT_HPP__
#include "ggml_extend.hpp"
#include "model.h"
#define MMDIT_GRAPH_SIZE 10240
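// Two-layer feed-forward block (Linear -> GELU -> Linear) used inside the transformer blocks.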
struct Mlp : public GGMLBlock {
public:
Mlp(int64_t in_features,
int64_t hidden_features = -1,
int64_t out_features = -1,
bool bias = true) {
if (hidden_features == -1) {
hidden_features = in_features;
}
if (out_features == -1) {
out_features = in_features;
}
blocks["fc1"] = std::shared_ptr<GGMLBlock>(new Linear(in_features, hidden_features, bias));
blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_features, out_features, bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = fc2->forward(ctx, x);
return x;
}
};
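// Splits the latent image into non-overlapping patches with a strided Conv2d and optionally
// flattens the result into a token sequence. With dynamic_img_pad the input is padded on the
// right/bottom so that its spatial size becomes a multiple of patch_size.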
struct PatchEmbed : public GGMLBlock {
protected:
bool flatten;
bool dynamic_img_pad;
int patch_size;
public:
PatchEmbed(int64_t img_size = 224,
int patch_size = 16,
int64_t in_chans = 3,
int64_t embed_dim = 1536,
bool bias = true,
bool flatten = true,
bool dynamic_img_pad = true)
: patch_size(patch_size),
flatten(flatten),
dynamic_img_pad(dynamic_img_pad) {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_chans,
embed_dim,
{patch_size, patch_size},
{patch_size, patch_size},
{0, 0},
{1, 1},
bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto proj = std::dynamic_pointer_cast<Conv2d>(blocks["proj"]);
if (dynamic_img_pad) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0);
        }
x = proj->forward(ctx, x);
if (flatten) {
x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
}
return x;
}
};
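// Embeds scalar diffusion timesteps: a sinusoidal frequency embedding followed by a
// Linear -> SiLU -> Linear MLP.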
struct TimestepEmbedder : public GGMLBlock {
protected:
int64_t frequency_embedding_size;
public:
TimestepEmbedder(int64_t hidden_size,
int64_t frequency_embedding_size = 256)
: frequency_embedding_size(frequency_embedding_size) {
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true));
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
auto t_freq = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size);
auto t_emb = mlp_0->forward(ctx, t_freq);
t_emb = ggml_silu_inplace(ctx, t_emb);
t_emb = mlp_2->forward(ctx, t_emb);
return t_emb;
}
};
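// Embeds a pooled conditioning vector (e.g. the pooled text embedding) with a
// Linear -> SiLU -> Linear MLP.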
struct VectorEmbedder : public GGMLBlock {
public:
VectorEmbedder(int64_t input_dim,
int64_t hidden_size) {
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(input_dim, hidden_size, true, true));
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["mlp.0"]);
auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["mlp.2"]);
x = mlp_0->forward(ctx, x);
x = ggml_silu_inplace(ctx, x);
x = mlp_2->forward(ctx, x);
return x;
}
};
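// Multi-head self-attention. pre_attention() computes q/k/v (optionally applying per-head
// RMSNorm or LayerNorm to q and k when qk_norm is "rms" or "ln"); post_attention() applies the
// output projection. With pre_only there is no output projection: the block only produces
// q/k/v for a joint attention computed elsewhere.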
class SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
bool pre_only;
std::string qk_norm;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) {
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim));
}
if (qk_norm == "rms") {
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(d_head, 1.0e-6));
} else if (qk_norm == "ln") {
blocks["ln_q"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
blocks["ln_k"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_head, 1.0e-6));
}
}
std::vector<struct ggml_tensor*> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks["qkv"]);
auto qkv = qkv_proj->forward(ctx, x);
auto qkv_vec = split_qkv(ctx, qkv);
int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
        auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
        auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
        auto v = qkv_vec[2];
if (qk_norm == "rms" || qk_norm == "ln") {
auto ln_q = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_q"]);
auto ln_k = std::dynamic_pointer_cast<UnaryBlock>(blocks["ln_k"]);
q = ln_q->forward(ctx, q);
k = ln_k->forward(ctx, k);
}
        q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]);
        k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]);
return {q, k, v};
}
struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
GGML_ASSERT(!pre_only);
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
        x = proj->forward(ctx, x);
        return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);
        x = post_attention(ctx, x);
        return x;
}
};
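// adaLN modulation: returns x * (1 + scale) + shift, with scale and shift broadcast over the
// token dimension.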
__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* shift,
struct ggml_tensor* scale) {
    scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);
    shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);
    x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
x = ggml_add(ctx, x, shift);
return x;
}
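// A DiT block "dismantled" into pre_attention/post_attention halves so that the attention itself
// can be computed jointly over context and image tokens by block_mixing() below.
// adaLN_modulation.1 produces 6 modulation vectors per token stream (2 when pre_only,
// 9 when self_attn enables the extra image-only attention used by MMDiT-x layers).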
struct DismantledBlock : public GGMLBlock {
public:
int64_t num_heads;
bool pre_only;
bool self_attn;
public:
DismantledBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
}
if (!pre_only) {
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
blocks["mlp"] = std::shared_ptr<GGMLBlock>(new Mlp(hidden_size, mlp_hidden_dim));
}
int64_t n_mods = 6;
if (pre_only) {
n_mods = 2;
}
if (self_attn) {
n_mods = 9;
}
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, n_mods * hidden_size));
}
std::tuple<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention_x(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
GGML_ASSERT(self_attn);
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 9;
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));
int64_t offset = m->nb[1] * m->ne[1];
        auto shift_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);
        auto gate_msa   = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);
        auto shift_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);
        auto scale_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);
        auto gate_mlp   = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);
        auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6);
        auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7);
        auto gate_msa2  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8);
auto x_norm = norm1->forward(ctx, x);
auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2);
auto qkv2 = attn2->pre_attention(ctx, attn2_in);
return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}};
}
std::pair<std::vector<struct ggml_tensor*>, std::vector<struct ggml_tensor*>> pre_attention(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
int64_t n_mods = 6;
if (pre_only) {
n_mods = 2;
}
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));
int64_t offset = m->nb[1] * m->ne[1];
        auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);
        if (!pre_only) {
            auto gate_msa  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2);
            auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3);
            auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4);
            auto gate_mlp  = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5);
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
} else {
auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
auto qkv = attn->pre_attention(ctx, attn_in);
return {qkv, {NULL, NULL, NULL, NULL, NULL}};
}
}
struct ggml_tensor* post_attention_x(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* attn2_out,
struct ggml_tensor* x,
struct ggml_tensor* gate_msa,
struct ggml_tensor* shift_mlp,
struct ggml_tensor* scale_mlp,
struct ggml_tensor* gate_mlp,
struct ggml_tensor* gate_msa2) {
GGML_ASSERT(!pre_only);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto attn2 = std::dynamic_pointer_cast<SelfAttention>(blocks["attn2"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
        gate_msa  = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);
        gate_mlp  = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);
        gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]);
attn_out = attn->post_attention(ctx, attn_out);
attn2_out = attn2->post_attention(ctx, attn2_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
return x;
}
struct ggml_tensor* post_attention(struct ggml_context* ctx,
struct ggml_tensor* attn_out,
struct ggml_tensor* x,
struct ggml_tensor* gate_msa,
struct ggml_tensor* shift_mlp,
struct ggml_tensor* scale_mlp,
struct ggml_tensor* gate_mlp) {
GGML_ASSERT(!pre_only);
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto mlp = std::dynamic_pointer_cast<Mlp>(blocks["mlp"]);
        gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]);
        gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]);
attn_out = attn->post_attention(ctx, attn_out);
x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* c) {
auto attn = std::dynamic_pointer_cast<SelfAttention>(blocks["attn"]);
if (self_attn) {
auto qkv_intermediates = pre_attention_x(ctx, x, c);
auto qkv = std::get<0>(qkv_intermediates);
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);
            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads);
            x = post_attention_x(ctx,
attn_out,
attn2_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4],
intermediates[5]);
            return x;
        } else {
auto qkv_intermediates = pre_attention(ctx, x, c);
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);
            x = post_attention(ctx,
attn_out,
intermediates[0],
intermediates[1],
intermediates[2],
intermediates[3],
intermediates[4]);
            return x;
        }
}
};
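// Joint attention over the concatenated context and image token streams: q/k/v from both blocks
// are concatenated along the sequence dimension, attention is computed once, and the result is
// split back so each stream finishes with its own post_attention. When the context block is
// pre_only, the context stream is not updated and NULL is returned for it.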
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
std::shared_ptr<DismantledBlock> context_block,
std::shared_ptr<DismantledBlock> x_block) {
auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
auto context_qkv = context_qkv_intermediates.first;
auto context_intermediates = context_qkv_intermediates.second;
std::vector<ggml_tensor*> x_qkv, x_qkv2, x_intermediates;
if (x_block->self_attn) {
auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c);
x_qkv = std::get<0>(x_qkv_intermediates);
x_qkv2 = std::get<1>(x_qkv_intermediates);
x_intermediates = std::get<2>(x_qkv_intermediates);
} else {
auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
x_qkv = x_qkv_intermediates.first;
x_intermediates = x_qkv_intermediates.second;
}
std::vector<struct ggml_tensor*> qkv;
for (int i = 0; i < 3; i++) {
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
}
    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads);
    attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));
    auto context_attn = ggml_view_3d(ctx,
                                     attn,
                                     attn->ne[0],
                                     attn->ne[1],
                                     context->ne[1],
                                     attn->nb[1],
                                     attn->nb[2],
                                     0);
    context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3));
    auto x_attn = ggml_view_3d(ctx,
                               attn,
                               attn->ne[0],
                               attn->ne[1],
                               x->ne[1],
                               attn->nb[1],
                               attn->nb[2],
                               attn->nb[2] * context->ne[1]);
    x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3));
if (!context_block->pre_only) {
context = context_block->post_attention(ctx,
context_attn,
context_intermediates[0],
context_intermediates[1],
context_intermediates[2],
context_intermediates[3],
context_intermediates[4]);
} else {
context = NULL;
}
if (x_block->self_attn) {
auto attn2 = ggml_nn_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);
x = x_block->post_attention_x(ctx,
x_attn,
attn2,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4],
x_intermediates[5]);
} else {
x = x_block->post_attention(ctx,
x_attn,
x_intermediates[0],
x_intermediates[1],
x_intermediates[2],
x_intermediates[3],
x_intermediates[4]);
}
return {context, x};
}
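// One MMDiT layer: a context DismantledBlock and an image DismantledBlock mixed via joint attention.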
struct JointBlock : public GGMLBlock {
public:
JointBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c) {
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
return block_mixing(ctx, backend, context, x, c, context_block, x_block);
}
};
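// Final adaLN modulation followed by a linear projection from hidden_size to
// patch_size * patch_size * out_channels per token.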
struct FinalLayer : public GGMLBlock {
public:
FinalLayer(int64_t hidden_size,
int64_t patch_size,
int64_t out_channels) {
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* c) {
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
        auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));
        m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]);
        m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));
int64_t offset = m->nb[1] * m->ne[1];
        auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);
        auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);
x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
return x;
}
};
struct MMDiT : public GGMLBlock {
protected:
int64_t input_size = -1;
int64_t patch_size = 2;
int64_t in_channels = 16;
    int64_t d_self = -1;  // highest block index with a second self-attention (attn2); -1 means none
    int64_t depth  = 24;
float mlp_ratio = 4.0f;
int64_t adm_in_channels = 2048;
int64_t out_channels = 16;
int64_t pos_embed_max_size = 192;
    int64_t num_patchs = 36864;  // = pos_embed_max_size * pos_embed_max_size
    int64_t context_size = 4096;
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32;
params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1);
}
public:
MMDiT(const String2GGMLType& tensor_types = {}) {
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
continue;
size_t jb = tensor_name.find("joint_blocks.");
if (jb != std::string::npos) {
                tensor_name = tensor_name.substr(jb);
                int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13) - 13).c_str());  // digits after "joint_blocks."
if (block_depth + 1 > depth) {
depth = block_depth + 1;
}
if (tensor_name.find("attn.ln") != std::string::npos) {
if (tensor_name.find(".bias") != std::string::npos) {
qk_norm = "ln";
} else {
qk_norm = "rms";
}
}
if (tensor_name.find("attn2") != std::string::npos) {
if (block_depth > d_self) {
d_self = block_depth;
}
}
}
}
if (d_self >= 0) {
pos_embed_max_size *= 2;
num_patchs *= 4;
}
        LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", (int)depth, (int)(d_self + 1));
int64_t default_out_channels = in_channels;
hidden_size = 64 * depth;
context_embedder_out_dim = 64 * depth;
int64_t num_heads = depth;
blocks["x_embedder"] = std::shared_ptr<GGMLBlock>(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
blocks["t_embedder"] = std::shared_ptr<GGMLBlock>(new TimestepEmbedder(hidden_size));
if (adm_in_channels != -1) {
blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
}
blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, context_embedder_out_dim, true, true));
for (int i = 0; i < depth; i++) {
blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,
num_heads,
mlp_ratio,
qk_norm,
true,
i == depth - 1,
i <= d_self));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
}
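    // Crops an (h/patch_size) x (w/patch_size) window from the centre of the learned pos_embed
    // grid and flattens it to match the patch token sequence.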
struct ggml_tensor*
cropped_pos_embed(struct ggml_context* ctx,
int64_t h,
int64_t w) {
auto pos_embed = params["pos_embed"];
h = (h + 1) / patch_size;
w = (w + 1) / patch_size;
GGML_ASSERT(h <= pos_embed_max_size && h > 0);
GGML_ASSERT(w <= pos_embed_max_size && w > 0);
int64_t top = (pos_embed_max_size - h) / 2;
int64_t left = (pos_embed_max_size - w) / 2;
auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size);
        spatial_pos_embed = ggml_view_3d(ctx,
                                         spatial_pos_embed,
                                         hidden_size,
                                         pos_embed_max_size,
                                         h,
                                         spatial_pos_embed->nb[1],
                                         spatial_pos_embed->nb[2],
                                         spatial_pos_embed->nb[2] * top);
        spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3));
        spatial_pos_embed = ggml_view_3d(ctx,
                                         spatial_pos_embed,
                                         hidden_size,
                                         h,
                                         w,
                                         spatial_pos_embed->nb[1],
                                         spatial_pos_embed->nb[2],
                                         spatial_pos_embed->nb[2] * left);
        spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3));
        spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1);
        return spatial_pos_embed;
}
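    // Inverse of patchify: rearranges each token's patch_size * patch_size * out_channels values
    // back into a latent image of shape [W, H, C, N] (ggml dimension order).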
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x,
int64_t h,
int64_t w) {
int64_t n = x->ne[2];
int64_t c = out_channels;
int64_t p = patch_size;
h = (h + 1) / p;
w = (w + 1) / p;
GGML_ASSERT(h * w == x->ne[1]);
        x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n);
        x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3));
        x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n);
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
        x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n);
        return x;
}
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* c_mod,
struct ggml_tensor* context,
std::vector<int> skip_layers = std::vector<int>()) {
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
for (int i = 0; i < depth; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
}
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
auto context_x = block->forward(ctx, backend, context, x, c_mod);
context = context_x.first;
x = context_x.second;
}
x = final_layer->forward(ctx, x, c_mod);
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* t,
struct ggml_tensor* y = NULL,
struct ggml_tensor* context = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
auto x_embedder = std::dynamic_pointer_cast<PatchEmbed>(blocks["x_embedder"]);
auto t_embedder = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
int64_t w = x->ne[0];
int64_t h = x->ne[1];
        auto patch_embed = x_embedder->forward(ctx, x);
        auto pos_embed   = cropped_pos_embed(ctx, h, w);
        x = ggml_add(ctx, patch_embed, pos_embed);
        auto c = t_embedder->forward(ctx, t);
        if (y != NULL && adm_in_channels != -1) {
auto y_embedder = std::dynamic_pointer_cast<VectorEmbedder>(blocks["y_embedder"]);
            y = y_embedder->forward(ctx, y);
            c = ggml_add(ctx, c, y);
}
if (context != NULL) {
auto context_embedder = std::dynamic_pointer_cast<Linear>(blocks["context_embedder"]);
            context = context_embedder->forward(ctx, context);
        }
x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers);
x = unpatchify(ctx, x, h, w);
return x;
}
};
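// GGMLRunner wrapper around MMDiT: allocates the parameters, builds the compute graph for a set of
// inputs and runs it on the backend. test() and load_from_file_and_test() are small self-checks.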
struct MMDiTRunner : public GGMLRunner {
MMDiT mmdit;
MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
}
std::string get_desc() {
return "mmdit";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
mmdit.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
std::vector<int> skip_layers = std::vector<int>()) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);
x = to_backend(x);
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
struct ggml_tensor* out = mmdit.forward(compute_ctx,
runtime_backend,
x,
timesteps,
y,
context,
skip_layers);
ggml_build_forward_expand(gf, out);
return gf;
}
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, skip_layers);
};
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
struct ggml_init_params params;
        params.mem_size   = static_cast<size_t>(10 * 1024 * 1024);
        params.mem_buffer = NULL;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != NULL);
{
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 128, 128, 16, 1);
std::vector<float> timesteps_vec(1, 999.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
ggml_set_f32(x, 0.01f);
auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 154, 1);
ggml_set_f32(context, 0.01f);
auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 2048, 1);
ggml_set_f32(y, 0.01f);
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("mmdit test done in %dms", t1 - t0);
}
}
static void load_from_file_and_test(const std::string& file_path) {
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
{
LOG_INFO("loading from '%s'", file_path.c_str());
mmdit->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
mmdit->get_param_tensors(tensors, "model.diffusion_model");
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) {
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
return;
}
bool success = model_loader.load_tensors(tensors);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("mmdit model loaded");
}
mmdit->test();
}
};
#endif