#ifndef __COMMON_HPP__
#define __COMMON_HPP__
#include "ggml_extend.hpp"
class DownSampleBlock : public GGMLBlock {
protected:
int channels;
int out_channels;
bool vae_downsample;
public:
DownSampleBlock(int channels,
int out_channels,
bool vae_downsample = false)
: channels(channels),
out_channels(out_channels),
vae_downsample(vae_downsample) {
if (vae_downsample) {
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
} else {
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
if (vae_downsample) {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = conv->forward(ctx, x);
} else {
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
x = conv->forward(ctx, x);
}
return x; }
};
// Upsampling block: a factor-2 ggml_upscale (GGML_SCALE_MODE_NEAREST) followed
// by a 3x3 Conv2d registered under the key "conv".
class UpSampleBlock : public GGMLBlock {
protected:
    int channels;
    int out_channels;

public:
    // channels / out_channels are forwarded to Conv2d as its first two arguments.
    UpSampleBlock(int channels,
                  int out_channels)
        : channels(channels),
          out_channels(out_channels) {
        blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
    }

    // Upscales x by 2 and returns the convolution of the upscaled tensor.
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto conv     = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
        auto upscaled = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST);
        return conv->forward(ctx, upscaled);
    }
};
// Residual block: GroupNorm32 -> SiLU -> conv, an optional additive embedding
// projection, then GroupNorm32 -> SiLU -> conv, plus a skip connection.
//
// Sub-blocks registered in `blocks`:
//   "in_layers.0"  / "in_layers.2"  : first norm and convolution
//   "emb_layers.1"                  : Linear(emb_channels, out_channels), only when !skip_t_emb
//   "out_layers.0" / "out_layers.3" : second norm and convolution
//   "skip_connection"               : 1x1 convolution, only when out_channels != channels
//
// `dims` selects the convolution type (2 -> Conv2d, 3 -> Conv3dnx1x1).
class ResBlock : public GGMLBlock {
protected:
    int64_t channels;
    int64_t emb_channels;
    int64_t out_channels;
    std::pair<int, int> kernel_size;
    int dims;
    bool skip_t_emb;
    bool exchange_temb_dims;

    // Builds the convolution matching `dims`. For dims == 3 only the first
    // component of kernel_size / padding is passed on to Conv3dnx1x1.
    std::shared_ptr<GGMLBlock> conv_nd(int dims,
                                       int64_t in_channels,
                                       int64_t out_channels,
                                       std::pair<int, int> kernel_size,
                                       std::pair<int, int> padding) {
        GGML_ASSERT(dims == 2 || dims == 3);
        if (dims != 3) {
            return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
        }
        return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first));
    }

public:
    ResBlock(int64_t channels,
             int64_t emb_channels,
             int64_t out_channels,
             std::pair<int, int> kernel_size = {3, 3},
             int dims                        = 2,
             bool exchange_temb_dims         = false,
             bool skip_t_emb                 = false)
        : channels(channels),
          emb_channels(emb_channels),
          out_channels(out_channels),
          kernel_size(kernel_size),
          dims(dims),
          skip_t_emb(skip_t_emb),
          exchange_temb_dims(exchange_temb_dims) {
        // Padding is half the kernel size in each direction.
        std::pair<int, int> padding = {kernel_size.first / 2, kernel_size.second / 2};

        blocks["in_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(channels));
        blocks["in_layers.2"] = conv_nd(dims, channels, out_channels, kernel_size, padding);

        if (!skip_t_emb) {
            blocks["emb_layers.1"] = std::shared_ptr<GGMLBlock>(new Linear(emb_channels, out_channels));
        }

        blocks["out_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
        blocks["out_layers.3"] = conv_nd(dims, out_channels, out_channels, kernel_size, padding);

        // A projection on the skip path is only needed when the channel count changes.
        if (out_channels != channels) {
            blocks["skip_connection"] = conv_nd(dims, channels, out_channels, {1, 1}, {0, 0});
        }
    }

    // Runs the block on x. `emb` may be NULL only when the block was built with
    // skip_t_emb == true; otherwise the assertion below fires.
    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = NULL) {
        auto norm_a = std::dynamic_pointer_cast<GroupNorm32>(blocks["in_layers.0"]);
        auto conv_a = std::dynamic_pointer_cast<UnaryBlock>(blocks["in_layers.2"]);
        auto norm_b = std::dynamic_pointer_cast<GroupNorm32>(blocks["out_layers.0"]);
        auto conv_b = std::dynamic_pointer_cast<UnaryBlock>(blocks["out_layers.3"]);

        if (emb == NULL) {
            GGML_ASSERT(skip_t_emb);
        }

        // First stage: norm -> SiLU -> conv.
        auto h = norm_a->forward(ctx, x);
        h      = ggml_silu_inplace(ctx, h);
        h      = conv_a->forward(ctx, h);

        // Embedding stage: SiLU -> Linear, reshaped to 4-D and added onto h.
        if (!skip_t_emb) {
            auto emb_proj = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);

            auto t_emb = emb_proj->forward(ctx, ggml_silu(ctx, emb));

            if (dims != 2) {
                // One leading singleton dimension; optionally swap dims 1 and 2.
                t_emb = ggml_reshape_4d(ctx, t_emb, 1, t_emb->ne[0], t_emb->ne[1], t_emb->ne[2]);
                if (exchange_temb_dims) {
                    t_emb = ggml_cont(ctx, ggml_permute(ctx, t_emb, 0, 2, 1, 3));
                }
            } else {
                // Two leading singleton dimensions.
                t_emb = ggml_reshape_4d(ctx, t_emb, 1, 1, t_emb->ne[0], t_emb->ne[1]);
            }

            h = ggml_add(ctx, h, t_emb);
        }

        // Second stage: norm -> SiLU -> conv.
        h = norm_b->forward(ctx, h);
        h = ggml_silu_inplace(ctx, h);
        h = conv_b->forward(ctx, h);

        // Skip path: identity, or the 1x1 projection when channel counts differ.
        auto residual = x;
        if (out_channels != channels) {
            auto skip = std::dynamic_pointer_cast<UnaryBlock>(blocks["skip_connection"]);
            residual  = skip->forward(ctx, x);
        }

        return ggml_add(ctx, h, residual);
    }
};
class GEGLU : public GGMLBlock {
protected:
int64_t dim_in;
int64_t dim_out;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_types, GGML_TYPE_F32);
enum ggml_type bias_wtype = GGML_TYPE_F32;
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
}
public:
GEGLU(int64_t dim_in, int64_t dim_out)
: dim_in(dim_in), dim_out(dim_out) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["proj.weight"];
struct ggml_tensor* b = params["proj.bias"];
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2);
auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b);
gate = ggml_gelu_inplace(ctx, gate);
x = ggml_mul(ctx, x, gate);
return x;
}
};
// Feed-forward block: GEGLU(dim -> dim * mult) under "net.0" followed by
// Linear(dim * mult -> dim_out) under "net.2".
class FeedForward : public GGMLBlock {
public:
    FeedForward(int64_t dim,
                int64_t dim_out,
                int64_t mult = 4) {
        int64_t hidden_dim = dim * mult;

        blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, hidden_dim));
        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, dim_out));
    }

    // Returns net.2(net.0(x)).
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto gated_proj = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]);
        auto out_proj   = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);

        auto hidden = gated_proj->forward(ctx, x);
        return out_proj->forward(ctx, hidden);
    }
};
// Multi-head attention of `x` (queries) over `context` (keys / values).
//
// Sub-blocks registered in `blocks`:
//   "to_q"     : Linear(query_dim   -> n_head * d_head, false)
//   "to_k"     : Linear(context_dim -> n_head * d_head, false)
//   "to_v"     : Linear(context_dim -> n_head * d_head, false)
//   "to_out.0" : Linear(n_head * d_head -> query_dim)
// NOTE(review): the trailing `false` passed to Linear is presumably the bias
// switch — confirm against Linear's constructor.
class CrossAttention : public GGMLBlock {
protected:
    int64_t query_dim;
    int64_t context_dim;
    int64_t n_head;
    int64_t d_head;
    bool flash_attn;

public:
    // The initializer list follows the member declaration order, which is the
    // order in which C++ initialises members regardless of how the list is written.
    CrossAttention(int64_t query_dim,
                   int64_t context_dim,
                   int64_t n_head,
                   int64_t d_head,
                   bool flash_attn = false)
        : query_dim(query_dim),
          context_dim(context_dim),
          n_head(n_head),
          d_head(d_head),
          flash_attn(flash_attn) {
        int64_t inner_dim = d_head * n_head;

        blocks["to_q"]     = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
        blocks["to_k"]     = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
        blocks["to_v"]     = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
        blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, query_dim));
    }

    // Projects x to queries and context to keys / values, runs
    // ggml_nn_attention_ext with n_head heads and no mask, then applies the
    // output projection.
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                ggml_backend_t backend,
                                struct ggml_tensor* x,
                                struct ggml_tensor* context) {
        auto to_q     = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
        auto to_k     = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
        auto to_v     = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
        auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);

        auto q = to_q->forward(ctx, x);
        auto k = to_k->forward(ctx, context);
        auto v = to_v->forward(ctx, context);

        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn);

        x = to_out_0->forward(ctx, x);
        return x;
    }
};
// Transformer block made of three residual sub-layers, each preceded by a
// LayerNorm:
//   1. attn1: attention of x over itself
//   2. attn2: attention of x over `context`
//   3. ff   : FeedForward(dim, dim)
// When constructed with ff_in == true, an extra residual
// LayerNorm ("norm_in") + FeedForward ("ff_in") stage runs first.
class BasicTransformerBlock : public GGMLBlock {
protected:
    int64_t n_head;
    int64_t d_head;
    bool ff_in;

public:
    BasicTransformerBlock(int64_t dim,
                          int64_t n_head,
                          int64_t d_head,
                          int64_t context_dim,
                          bool ff_in      = false,
                          bool flash_attn = false)
        : n_head(n_head), d_head(d_head), ff_in(ff_in) {
        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
        blocks["ff"]    = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
        blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
        blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
        blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));

        if (ff_in) {
            blocks["norm_in"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
            blocks["ff_in"]   = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
        }
    }

    // Runs the (optional) input feed-forward stage followed by the three
    // residual sub-layers and returns the result.
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                ggml_backend_t backend,
                                struct ggml_tensor* x,
                                struct ggml_tensor* context) {
        auto self_attn  = std::dynamic_pointer_cast<CrossAttention>(blocks["attn1"]);
        auto cross_attn = std::dynamic_pointer_cast<CrossAttention>(blocks["attn2"]);
        auto ff_block   = std::dynamic_pointer_cast<FeedForward>(blocks["ff"]);
        auto norm1      = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
        auto norm2      = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
        auto norm3      = std::dynamic_pointer_cast<LayerNorm>(blocks["norm3"]);

        if (ff_in) {
            auto norm_in     = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_in"]);
            auto ff_in_block = std::dynamic_pointer_cast<FeedForward>(blocks["ff_in"]);

            // Residual: x + ff_in(norm_in(x)).
            auto pre = norm_in->forward(ctx, x);
            pre      = ff_in_block->forward(ctx, pre);
            x        = ggml_add(ctx, pre, x);
        }

        // Sub-layer 1: the normalised x is passed as both x and context of attn1.
        auto normed   = norm1->forward(ctx, x);
        auto attended = self_attn->forward(ctx, backend, normed, normed);
        x             = ggml_add(ctx, attended, x);

        // Sub-layer 2: attention over the external context.
        normed   = norm2->forward(ctx, x);
        attended = cross_attn->forward(ctx, backend, normed, context);
        x        = ggml_add(ctx, attended, x);

        // Sub-layer 3: feed-forward.
        normed      = norm3->forward(ctx, x);
        auto ff_out = ff_block->forward(ctx, normed);
        x           = ggml_add(ctx, ff_out, x);

        return x;
    }
};
class SpatialTransformer : public GGMLBlock {
protected:
int64_t in_channels; int64_t n_head;
int64_t d_head;
int64_t depth = 1; int64_t context_dim = 768;
public:
SpatialTransformer(int64_t in_channels,
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim,
bool flash_attn = false)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
depth(depth),
context_dim(context_dim) {
int64_t inner_dim = n_head * d_head; blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* x,
struct ggml_tensor* context) {
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto x_in = x;
int64_t n = x->ne[3];
int64_t h = x->ne[1];
int64_t w = x->ne[0];
int64_t inner_dim = n_head * d_head;
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x);
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n);
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
x = transformer_block->forward(ctx, backend, x, context);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n);
x = proj_out->forward(ctx, x);
x = ggml_add(ctx, x, x_in);
return x;
}
};
// Blends two tensors with a single learned scalar:
//   out = alpha * x_spatial + (1 - alpha) * x_temporal,
// where alpha = sigmoid(mix_factor) and "mix_factor" is a one-element F32
// parameter.
class AlphaBlender : public GGMLBlock {
protected:
    // Allocates the one-element F32 "mix_factor" parameter. `tensor_types` and
    // `prefix` are not consulted here.
    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
        params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    }

    // Reads the current mix_factor value from the backend and squashes it
    // through sigmoid.
    float get_alpha() {
        float mix_factor = ggml_backend_tensor_get_f32(params["mix_factor"]);
        return sigmoid(mix_factor);
    }

public:
    AlphaBlender() {
    }

    // Returns alpha * x_spatial + (1 - alpha) * x_temporal. Alpha is read once,
    // when this call builds the graph nodes.
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x_spatial,
                                struct ggml_tensor* x_temporal) {
        float alpha = get_alpha();

        auto spatial_part  = ggml_scale(ctx, x_spatial, alpha);
        auto temporal_part = ggml_scale(ctx, x_temporal, 1.0f - alpha);

        return ggml_add(ctx, spatial_part, temporal_part);
    }
};
// ResBlock extended with a second, dims == 3 ResBlock ("time_stack") whose
// output is blended with the first ResBlock's output by an AlphaBlender
// ("time_mixer").
class VideoResBlock : public ResBlock {
public:
    // NOTE(review): `video_kernel_size` is accepted but never referenced; the
    // time_stack block is built with `kernel_size` instead — confirm whether
    // that is intended.
    VideoResBlock(int channels,
                  int emb_channels,
                  int out_channels,
                  std::pair<int, int> kernel_size = {3, 3},
                  int64_t video_kernel_size       = 3,
                  int dims                        = 2)
        : ResBlock(channels, emb_channels, out_channels, kernel_size, dims) {
        blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, emb_channels, out_channels, kernel_size, 3, true));
        blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
    }

    // Runs the base ResBlock, regroups the result by num_video_frames, runs the
    // time_stack block on the regrouped tensor, blends both results and restores
    // the original 4-D shape.
    // NOTE(review): num_video_frames is used as a divisor without a check —
    // callers presumably always pass a positive value that divides x->ne[3].
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                struct ggml_tensor* x,
                                struct ggml_tensor* emb,
                                int num_video_frames) {
        auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
        auto time_mixer = std::dynamic_pointer_cast<AlphaBlender>(blocks["time_mixer"]);

        x = ResBlock::forward(ctx, x, emb);

        int64_t frames  = num_video_frames;
        int64_t batch   = x->ne[3] / frames;
        int64_t chans   = x->ne[2];
        int64_t height  = x->ne[1];
        int64_t width   = x->ne[0];

        // Split ne[3] into (frames, batch), merge the two lowest extents, then
        // swap dims 1 and 2.
        x = ggml_reshape_4d(ctx, x, width * height, chans, frames, batch);
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));

        // Kept as the first operand of the blend below.
        auto spatial = x;

        // Regroup the embedding by frames as well.
        emb = ggml_reshape_4d(ctx, emb, emb->ne[0], frames, batch, emb->ne[3]);

        auto temporal = time_stack->forward(ctx, x, emb);

        x = time_mixer->forward(ctx, spatial, temporal);

        // Undo the permutation and the regrouping.
        x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
        x = ggml_reshape_4d(ctx, x, width, height, chans, frames * batch);

        return x;
    }
};
#endif