#ifndef __LTXV_HPP__
#define __LTXV_HPP__
#include "common.hpp"
#include "ggml_extend.hpp"
namespace LTXV {
class CausalConv3d : public GGMLBlock {
protected:
int time_kernel_size;
public:
CausalConv3d(int64_t in_channels,
int64_t out_channels,
int kernel_size = 3,
std::tuple<int> stride = {1, 1, 1},
int dilation = 1,
bool bias = true) {
time_kernel_size = kernel_size / 2;
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
out_channels,
{kernel_size, kernel_size, kernel_size},
stride,
{0, kernel_size / 2, kernel_size / 2},
{dilation, 1, 1},
bias));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
bool causal = true) {
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
if (causal) {
auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); auto first_frame_pad = first_frame;
for (int i = 1; i < time_kernel_size - 1; i++) {
first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
}
x = ggml_concat(ctx, first_frame_pad, x, 2);
} else {
auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); int64_t offset = h->nb[2] * h->ne[2];
auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); auto first_frame_pad = first_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2);
}
auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); auto last_frame_pad = last_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2);
}
x = ggml_concat(ctx, first_frame_pad, x, 2);
x = ggml_concat(ctx, x, last_frame_pad, 2);
}
x = conv->forward(ctx, x);
return x;
}
};
};
#endif