#include <algorithm>
#include <assert.h>
#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include <cmath>
#include <iostream>
#include <fstream>
#include <stdio.h>
#include <stdlib.h>
#include <regex>
#include <sycl/sycl.hpp>
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
#endif
#include <sycl/half_type.hpp>
#include "ggml.h"
#include "ggml-sycl.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-sycl/add-id.hpp"
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/norm.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml-sycl/sycl_hw.hpp"
static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;
int g_ggml_sycl_enable_flash_attention = 1;
static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {};
info.device_count = dpct::dev_mgr::instance().device_count();
if (info.device_count == 0) {
GGML_LOG_ERROR("%s: failed to initialize: %s\n", GGML_SYCL_NAME, __func__);
return info;
}
GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
int64_t total_vram = 0;
for (int i = 0; i < info.device_count; ++i) {
info.devices[i].vmm = 0;
dpct::device_info prop;
sycl::device device = dpct::dev_mgr::instance().get_device(i);
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, device)));
info.default_tensor_split[i] = total_vram;
total_vram += prop.get_global_mem_size();
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].nsm = prop.get_max_compute_units() / 16; info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
info.devices[i].warp_size = WARP_SIZE;
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
}
for (int id = 0; id < info.device_count; ++id) {
info.default_tensor_split[id] /= total_vram;
}
return info;
}
const ggml_sycl_device_info & ggml_sycl_info() {
static ggml_sycl_device_info info = ggml_sycl_init();
return info;
}
static void print_device_detail(int id, sycl::device &device, std::string device_type) {
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::get_device_info(prop, device)));
std::string version;
version += std::to_string(prop.get_major_version());
version += ".";
version += std::to_string(prop.get_minor_version());
device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
std::string name = std::string(prop.get_name());
name = std::regex_replace(name, std::regex("\\(R\\)"), "");
name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
auto global_mem_size = prop.get_global_mem_size()/1000000;
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
name.c_str(), version.c_str(), prop.get_max_compute_units(),
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}
static void print_device_opt_feature(int device_count) {
GGML_LOG_INFO("SYCL Optimization Feature:\n");
GGML_LOG_INFO(
"|ID| Device Type|Reorder|\n");
GGML_LOG_INFO(
"|--|-------------------|-------|\n");
std::map<std::string, size_t> DeviceNums;
for (int id = 0; id < device_count; ++id) {
sycl::device device = dpct::dev_mgr::instance().get_device(id);
std::string backend_type = get_device_backend_and_type(device);
int type_id = DeviceNums[backend_type]++;
std::stringstream device_type;
device_type << "[" << backend_type << ":" << std::to_string(type_id)
<< "]";
std::string device_type_s = device_type.str();
device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
}
}
void ggml_backend_sycl_print_sycl_devices() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
int device_count = dpct::dev_mgr::instance().device_count();
std::map<std::string, size_t> DeviceNums;
GGML_LOG_INFO("Found %d SYCL devices:\n", device_count);
GGML_LOG_INFO(
"| | | | "
" |Max | |Max |Global | |\n");
GGML_LOG_INFO(
"| | | | "
" |compute|Max work|sub |mem | |\n");
GGML_LOG_INFO(
"|ID| Device Type| "
"Name|Version|units |group |group|size | Driver version|\n");
GGML_LOG_INFO(
"|--|-------------------|---------------------------------------|------"
"-|-------|--------|-----|-------|---------------------|\n");
for (int id = 0; id < device_count; ++id) {
sycl::device device = dpct::dev_mgr::instance().get_device(id);
std::string backend_type = get_device_backend_and_type(device);
int type_id = DeviceNums[backend_type]++;
std::stringstream device_type;
device_type << "[" << backend_type << ":" << std::to_string(type_id)
<< "]";
print_device_detail(id, device, device_type.str());
}
print_device_opt_feature(device_count);
}
static inline int get_sycl_env(const char *env_name, int default_val) {
char *user_device_string = getenv(env_name);
int user_number = default_val;
unsigned n;
if (user_device_string != NULL &&
sscanf(user_device_string, " %u", &n) == 1) {
user_number = (int)n;
} else {
user_number = default_val;
}
return user_number;
}
static void ggml_check_sycl() try {
static bool initialized = false;
if (!initialized) {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
#ifdef SYCL_FLASH_ATTN
g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
#else
g_ggml_sycl_enable_flash_attention = 0;
#endif
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
#endif
#if defined(GGML_SYCL_F16)
GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
#endif
#if defined(GGML_SYCL_GRAPH)
GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
#endif
#if defined(GGML_SYCL_DNNL)
GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
#endif
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
#ifdef GGML_SYCL_GRAPH
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
#endif
#if GGML_SYCL_DNNL
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
#ifdef SYCL_FLASH_ATTN
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
#else
GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
g_ggml_sycl_enable_flash_attention);
#endif
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
if (g_ggml_sycl_use_async_mem_op) {
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
g_ggml_sycl_use_async_mem_op = 0;
break;
}
}
}
#endif
if (CHECK_TRY_ERROR(g_all_sycl_device_count =
dpct::dev_mgr::instance().device_count()) != 0) {
initialized = true;
g_sycl_loaded = false;
return;
}
GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
initialized = true;
g_sycl_loaded = true;
ggml_backend_sycl_print_sycl_devices();
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
inline void check_allow_gpu_index(const int device_index) {
if (device_index >= ggml_sycl_info().device_count) {
char error_buf[256];
snprintf(
error_buf,
sizeof(error_buf),
"%s error: device_index:%d is out of range: [0-%d]",
__func__,
device_index,
ggml_sycl_info().device_count - 1);
GGML_LOG_ERROR("%s\n", error_buf);
assert(false);
}
}
GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
for(int i=0;i<max_len;i++) id_list[i] = -1;
for (int i=0;i< ggml_sycl_info().device_count;i++){
if (i>=max_len) break;
id_list[i] = i;
}
return;
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
struct ggml_backend_sycl_buffer_context {
int device;
void * dev_ptr = nullptr;
queue_ptr stream;
std::string name;
optimize_feature opt_feature;
std::vector<ggml_tensor_extra_gpu *> tensor_extras;
ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
device(device), dev_ptr(dev_ptr), stream(stream) {
check_allow_gpu_index(device);
name = (GGML_SYCL_NAME + std::to_string(device));
opt_feature = ggml_sycl_info().devices[device].opt_feature;
}
~ggml_backend_sycl_buffer_context() {
if (dev_ptr != nullptr) {
ggml_sycl_set_device(device);
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
}
for (ggml_tensor_extra_gpu * extra : tensor_extras) {
release_extra_gpu(extra);
}
}
};
static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft);
static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name;
}
static void
ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
ggml_sycl_set_device(ctx->device);
delete ctx;
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
return ctx->dev_ptr;
}
static enum ggml_status
ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
if (!g_ggml_sycl_disable_optimize) {
switch (tensor->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:{
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra);
break;
}
default:
break;
}
}
if (ggml_is_quantized(tensor->type)) {
size_t original_size = ggml_nbytes(tensor);
size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
if (padded_size > original_size && tensor->view_src == nullptr) {
SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
(char *)tensor->data + original_size, 0,
padded_size - original_size).wait()));
}
}
return GGML_STATUS_SUCCESS;
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor,
const void *data, size_t offset,
size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
ggml_sycl_set_device(ctx->device);
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
#ifndef _WIN32
char * host_buf = (char *) malloc(size);
memcpy(host_buf, data, size);
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
free(host_buf);
#else
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
#endif
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor *tensor,
void *data, size_t offset,
size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
ggml_sycl_set_device(ctx->device);
auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
SYCL_CHECK(CHECK_TRY_ERROR(
stream.memcpy(data, (const char *)tensor->data + offset, size)
.wait()));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
const void *ptr_src, size_t size) {
char *host_buf = (char *)malloc(size);
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
free(host_buf);
}
static bool
ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor *src,
ggml_tensor *dst) try {
bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer);
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
if (is_cpy_supported) {
ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
ggml_sycl_set_device(src_ctx->device);
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
ggml_sycl_set_device(dst_ctx->device);
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
queue_ptr stream_dst = dst_ctx->stream;
queue_ptr stream_src = src_ctx->stream;
size_t size = ggml_nbytes(src);
dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
#if 0#endif
return true;
}
return false;
GGML_UNUSED(buffer);
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
uint8_t value) try {
GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size);
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
ggml_sycl_set_device(ctx->device);
queue_ptr stream = ctx->stream;
SYCL_CHECK(
CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
constexpr size_t MAX_CHUNK = 2ULL << 30; for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) {
size_t chunk = std::min(buffer->size - off, MAX_CHUNK);
SYCL_CHECK(CHECK_TRY_ERROR(
(*stream)
.memset(static_cast<char*>(ctx->dev_ptr) + off, value, chunk)
.wait()
));
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
size_t offset, size_t size) {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value);
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
SYCL_CHECK(ggml_sycl_set_device(ctx->device));
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
if (size == 0) {
return; }
if (tensor->data == nullptr) {
GGML_ABORT("Error: Tensor data pointer is null.\n");
}
void * target_ptr = static_cast<char *>(tensor->data) + offset;
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
}
static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
if (buffer == nullptr) {
return;
}
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
if (ctx != nullptr) {
for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
release_extra_gpu(extra);
}
ctx->tensor_extras.clear(); }
}
static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
ggml_backend_sycl_buffer_free_buffer,
ggml_backend_sycl_buffer_get_base,
ggml_backend_sycl_buffer_init_tensor,
ggml_backend_sycl_buffer_memset_tensor,
ggml_backend_sycl_buffer_set_tensor,
ggml_backend_sycl_buffer_get_tensor,
NULL,
NULL,
ggml_backend_sycl_buffer_cpy_tensor,
ggml_backend_sycl_buffer_clear,
ggml_backend_sycl_buffer_reset,
};
struct ggml_backend_sycl_buffer_type_context {
int device;
std::string name;
queue_ptr stream = nullptr;
};
static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
return ctx->name.c_str();
}
static ggml_backend_buffer_t
ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) try {
ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
ggml_sycl_set_device(buft_ctx->device);
const queue_ptr stream = buft_ctx->stream;
size = std::max(size, (size_t)1);
void * dev_ptr;
SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
size, *stream)));
if (!dev_ptr) {
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size);
return nullptr;
}
ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
GGML_UNUSED(buft);
}
static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
return dpct::get_current_device().get_max_mem_alloc_size();
GGML_UNUSED(buft);
}
static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
size_t size = ggml_nbytes(tensor);
int64_t ne0 = tensor->ne[0];
if (ggml_is_quantized(tensor->type)) {
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
}
return size;
GGML_UNUSED(buft);
}
static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
ggml_backend_sycl_buffer_type_get_name,
ggml_backend_sycl_buffer_type_alloc_buffer,
ggml_backend_sycl_buffer_type_get_alignment,
ggml_backend_sycl_buffer_type_get_max_size,
ggml_backend_sycl_buffer_type_get_alloc_size,
NULL,
};
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
auto dev_count = ggml_backend_sycl_get_device_count();
if (device>=dev_count or device<0) {
GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
device, dev_count-1);
GGML_ASSERT(device<dev_count);
}
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
static bool ggml_backend_sycl_buffer_type_initialized = false;
if (!ggml_backend_sycl_buffer_type_initialized) {
for (int i = 0; i < dev_count; i++) {
auto & device_i = dpct::dev_mgr::instance().get_device(i);
queue_ptr stream = &(device_i.default_queue());
ggml_backend_sycl_buffer_types[i] = {
ggml_backend_sycl_buffer_type_interface,
ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
};
}
ggml_backend_sycl_buffer_type_initialized = true;
}
return &ggml_backend_sycl_buffer_types[device];
}
static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
int device = ctx->device;
if (device>=ggml_sycl_info().device_count or device<0) {
GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
device, ggml_sycl_info().device_count-1);
GGML_ASSERT(device<ggml_sycl_info().device_count);
}
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
static bool ggml_backend_sycl_buffer_type_initialized = false;
if (!ggml_backend_sycl_buffer_type_initialized) {
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
ggml_backend_sycl_buffer_types[i] = {
ggml_backend_sycl_buffer_type_interface,
nullptr,
new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
};
}
ggml_backend_sycl_buffer_type_initialized = true;
}
return &ggml_backend_sycl_buffer_types[device];
}
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX;
int64_t max_compute_capability = INT_MIN;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
min_compute_capability = ggml_sycl_info().devices[i].cc;
}
if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
max_compute_capability = ggml_sycl_info().devices[i].cc;
}
}
}
switch(type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 64;
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return 1;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_IQ3_S:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
default:
GGML_ABORT("fatal error");
}
}
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
const int64_t nrows = ggml_nrows(tensor);
const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
*row_low = id == 0 ? 0 : nrows*tensor_split[id];
*row_low -= *row_low % rounding;
if (id == ggml_sycl_info().device_count - 1) {
*row_high = nrows;
} else {
*row_high = nrows*tensor_split[id + 1];
*row_high -= *row_high % rounding;
}
}
static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
}
struct ggml_backend_sycl_split_buffer_type_context {
std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
};
struct ggml_backend_sycl_split_buffer_context {
~ggml_backend_sycl_split_buffer_context() try {
for (ggml_tensor_extra_gpu * extra : tensor_extras) {
release_extra_gpu(extra, streams);
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
std::vector<ggml_tensor_extra_gpu *> tensor_extras;
std::vector<queue_ptr> streams;
};
static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
delete ctx;
}
static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
return (void *)0x1000;
GGML_UNUSED(buffer);
}
static enum ggml_status
ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str());
GGML_ASSERT(tensor->view_src == nullptr);
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
const int64_t ne0 = tensor->ne[0];
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
ctx->tensor_extras.push_back(extra);
ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
ggml_sycl_set_device(i);
const queue_ptr stream = ctx->streams[i];
char * buf;
SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
size, *stream)));
if (!buf) {
char err_buf[1024];
snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size);
throw std::runtime_error(err_buf);
}
if (size > original_size) {
SYCL_CHECK(CHECK_TRY_ERROR(
(*stream)
.memset(buf + original_size, 0, size - original_size)
.wait()));
}
extra->data_device[i] = buf;
for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
SYCL_CHECK(
CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
}
}
tensor->extra = extra;
return GGML_STATUS_SUCCESS;
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void
ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
ggml_tensor *tensor, const void *data,
size_t offset, size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
const int64_t ne0 = tensor->ne[0];
const size_t nb1 = tensor->nb[1];
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
const size_t offset_split = row_low*nb1;
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
const char * buf_host = (const char *)data + offset_split;
ggml_sycl_set_device(i);
const queue_ptr stream = ctx->streams[i];
SYCL_CHECK(CHECK_TRY_ERROR(
(*stream)
.memcpy(extra->data_device[i], buf_host, original_size)
.wait()));
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void
ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor *tensor, void *data,
size_t offset, size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
const int64_t ne0 = tensor->ne[0];
const size_t nb1 = tensor->nb[1];
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
const size_t offset_split = row_low*nb1;
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
char * buf_host = (char *)data + offset_split;
ggml_sycl_set_device(i);
const queue_ptr stream = ctx->streams[i];
SYCL_CHECK(CHECK_TRY_ERROR(
(*stream)
.memcpy(buf_host, extra->data_device[i], original_size)
.wait()));
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
GGML_UNUSED(buffer);
GGML_UNUSED(value);
}
static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
ggml_backend_sycl_split_buffer_free_buffer,
ggml_backend_sycl_split_buffer_get_base,
ggml_backend_sycl_split_buffer_init_tensor,
NULL,
ggml_backend_sycl_split_buffer_set_tensor,
ggml_backend_sycl_split_buffer_get_tensor,
NULL,
NULL,
NULL,
ggml_backend_sycl_split_buffer_clear,
NULL,
};
static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return GGML_SYCL_NAME "_Split";
GGML_UNUSED(buft);
}
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name;
}
static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
}
static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
GGML_UNUSED(buft);
}
static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
size_t total_size = 0;
const int64_t ne0 = tensor->ne[0];
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
int64_t row_low, row_high;
get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
int64_t nrows_split = row_high - row_low;
if (nrows_split == 0) {
continue;
}
total_size += ggml_nbytes_split(tensor, nrows_split);
if (ne0 % MATRIX_ROW_PADDING != 0) {
total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
}
}
return total_size;
}
static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return false;
GGML_UNUSED(buft);
}
static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
ggml_backend_sycl_split_buffer_type_get_name,
ggml_backend_sycl_split_buffer_type_alloc_buffer,
ggml_backend_sycl_split_buffer_type_get_alignment,
NULL, ggml_backend_sycl_split_buffer_type_get_alloc_size,
ggml_backend_sycl_split_buffer_type_is_host,
};
ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
ggml_check_sycl();
static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
if (all_zero) {
tensor_split_arr = ggml_sycl_info().default_tensor_split;
} else {
float split_sum = 0.0f;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
tensor_split_arr[i] = split_sum;
split_sum += tensor_split[i];
}
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
tensor_split_arr[i] /= split_sum;
}
}
auto it = buft_map.find(tensor_split_arr);
if (it != buft_map.end()) {
return &it->second;
}
struct ggml_backend_buffer_type buft {
ggml_backend_sycl_split_buffer_type_interface,
ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
};
auto result = buft_map.emplace(tensor_split_arr, buft);
return &result.first->second;
}
static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
return GGML_SYCL_NAME "_Host";
GGML_UNUSED(buft);
}
inline void * aligned_malloc_host(size_t alignment, size_t size) {
#ifdef _WIN32
return _aligned_malloc(size, alignment);
#else
return aligned_alloc(alignment, size);
#endif
}
inline void free_aligned_mem_host(void * memblock) {
#ifdef _WIN32
_aligned_free(memblock);
#else
free(memblock);
#endif
}
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
free_aligned_mem_host((void *)buffer->context);
}
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
if (ptr == nullptr) {
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
{
ggml_backend_sycl_host_buffer_type_name,
ggml_backend_sycl_host_buffer_type_alloc_buffer,
ggml_backend_cpu_buffer_type()->iface.get_alignment,
NULL, ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
ggml_backend_cpu_buffer_type()->iface.is_host,
},
ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
nullptr,
};
return &ggml_backend_sycl_buffer_type_host;
}
struct ggml_sycl_pool_leg : public ggml_sycl_pool {
static const int MAX_SYCL_BUFFERS = 256;
int device;
queue_ptr qptr;
struct ggml_sycl_buffer {
void * ptr = nullptr;
size_t size = 0;
};
ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
size_t pool_size = 0;
explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}
~ggml_sycl_pool_leg() {
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
pool_size -= b.size;
}
}
GGML_ASSERT(pool_size == 0);
}
void * alloc(size_t size, size_t * actual_size) override {
#ifdef DEBUG_sycl_MALLOC
int nnz = 0;
size_t max_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
ggml_sycl_buffer& b = buffer_pool[i];
if (b.ptr != nullptr) {
#ifdef DEBUG_sycl_MALLOC
++nnz;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
size_t diff = b.size - size;
if (diff < best_diff) {
best_diff = diff;
ibest = i;
if (!best_diff) {
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
return ptr;
}
}
}
}
}
if (ibest >= 0) {
ggml_sycl_buffer& b = buffer_pool[ibest];
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
return ptr;
}
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
SYCL_CHECK(
CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
look_ahead_size, *qptr)));
if (!ptr) {
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size);
return nullptr;
}
*actual_size = look_ahead_size;
pool_size += look_ahead_size;
#ifdef DEBUG_SYCL_MALLOC
GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
#endif
return ptr;
}
void free(void * ptr, size_t size) override {
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
ggml_sycl_buffer& b = buffer_pool[i];
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
return;
}
}
GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n");
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
pool_size -= size;
}
};
struct ggml_sycl_pool_host : public ggml_sycl_pool {
queue_ptr qptr;
int device;
inline static int counter{ 0 };
struct ggml_sycl_buffer {
void * ptr = nullptr;
size_t size = 0;
};
static constexpr int MAX_POOL_SIZE{ 64 };
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
size_t pool_size = 0;
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
~ggml_sycl_pool_host() {
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
b.ptr = nullptr;
pool_size -= b.size;
b.size = 0;
}
}
counter = 0;
}
void * alloc(size_t size, size_t * actual_size) override {
if (counter == MAX_POOL_SIZE) {
ggml_sycl_buffer b = buffer_pool[0];
void * ptr = b.ptr;
*actual_size = b.size;
counter = 1;
return ptr;
}
ggml_sycl_buffer & b = buffer_pool[counter];
if (b.ptr == nullptr) {
void * ptr;
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
if (!ptr) {
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
return nullptr;
}
pool_size += size;
*actual_size = size;
counter = counter + 1;
return ptr;
} else {
++counter;
b.size = size;
return b.ptr;
}
}
void free(void * ptr, size_t size) override {
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
return;
}
}
}
};
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
}
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
}
typedef void (*ggml_sycl_op_mul_mat_t)(
ggml_backend_sycl_context & ctx,
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
const int64_t src1_ncols, const int64_t src1_padded_row_size,
const queue_ptr &stream);
static void mul_mat_p021_f16_f32(
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
const sycl::nd_item<3> &item_ct1) {
const sycl::half *x = (const sycl::half *)vx;
const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1);
const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0);
const int channel_x = channel / (nchannels_y / nchannels_x);
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
float tmp = 0.0f;
for (int col_x0 = 0; col_x0 < ncols_x;
col_x0 += item_ct1.get_local_range(2)) {
const int col_x = col_x0 + item_ct1.get_local_id(2);
if (col_x >= ncols_x) {
break;
}
const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
const float xi =
sycl::vec<sycl::half, 1>(x[ix])
.convert<float, sycl::rounding_mode::automatic>()[0];
const int row_y = col_x;
const int iy = channel*nrows_y + row_y;
tmp += xi * y[iy];
}
const int idst = channel*nrows_dst + row_dst;
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (item_ct1.get_local_id(2) == 0) {
dst[idst] = tmp;
}
}
static void mul_mat_vec_nc_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
const sycl::nd_item<3> &item_ct1) {
const sycl::half *x = (const sycl::half *)vx;
const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1);
const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0);
const int channel_x = channel / channel_x_divisor;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
const int idst = channel*nrows_dst + row_dst;
float tmp = 0.0f;
for (int col_x0 = 0; col_x0 < ncols_x;
col_x0 += item_ct1.get_local_range(2)) {
const int col_x = col_x0 + item_ct1.get_local_id(2);
if (col_x >= ncols_x) {
break;
}
const int row_y = col_x;
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const int iy = channel * channel_stride_y + row_y;
const float xi =
sycl::vec<sycl::half, 1>(x[ix])
.convert<float, sycl::rounding_mode::automatic>()[0];
tmp += xi * y[iy];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (item_ct1.get_local_id(2) == 0) {
dst[idst] = tmp;
}
}
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(1);
const int col = item_ct1.get_local_id(2);
float sum = 0.0f;
for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum, item_ct1);
if (col == 0) {
dst[row] = sum;
}
}
template<typename T>
static inline void ggml_sycl_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template <ggml_sort_order order>
__dpct_inline__ static void
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
uint8_t *dpct_local) {
int col_index = item_ct1.get_local_id(2);
int row = item_ct1.get_group(1);
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col >= ncols_pad) {
return;
}
}
const float * x_row = x + row * ncols;
auto dst_row = (int *)dpct_local;
for (int i=0;i<tasks_per_thread;i++){
int col = col_index*tasks_per_thread+i;
dst_row[col] = col;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
: x_row[dst_row[col]] <
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
: x_row[dst_row[col]] >
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
}
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
}
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
const sycl::nd_item<3> &item_ct1) {
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1);
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (col >= ncols) {
return;
}
const int i = row*ncols + col;
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
dst[i] = scale * x[i] + bias;
}
template <typename Ti, typename To>
static void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op,
const sycl::nd_item<3> &item_ct1) {
int idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = sycl::max(0, start_h);
const int eh = sycl::min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = sycl::max(0, start_w);
const int ew = sycl::min(iw, start_w + kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if DPCT_COMPATIBILITY_TEMP >= 350
Ti cur = *(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
default:
res = (To) sycl::nan(uint32_t(0));
break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
float *dst, const int ncols_x,
const int nrows_x,
const int nchannels_x,
const int nchannels_y,
queue_ptr stream) {
const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
nchannels_y, item_ct1);
});
}
}
static void ggml_mul_mat_vec_nc_f16_f32_sycl(
const void *vx, const float *y, float *dst, const int ncols_x,
const int nrows_x, const int row_stride_x, const int nchannels_x,
const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x, channel_stride_y,
nchannels_y / nchannels_x, item_ct1);
});
}
}
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
const int k, queue_ptr stream) {
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
scale_f32(x, dst, scale, bias, k, item_ct1);
});
}
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
k_sum_rows_f32(x, dst, ncols, item_ct1);
});
}
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, ggml_sort_order order,
queue_ptr stream, int device) {
const int ncols_pad = next_power_of_2(ncols);
int nth = 1;
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
while (nth < ncols_pad && nth < max_block_size)
nth *= 2;
if (nth > max_block_size)
nth = max_block_size;
const int tasks_per_thread = ncols_pad / nth;
const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo);
if (order == GGML_SORT_ORDER_ASC) {
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
} else if (order == GGML_SORT_ORDER_DESC) {
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
} else {
GGML_ABORT("fatal error");
}
}
static void top_k_f32_sycl(
const float * src,
int32_t * dst_indices,
const int64_t ncols,
const int64_t nrows,
const int k,
dpct::queue_ptr main_stream
) {
const int block_size = 128;
const sycl::range<1> block_dims(block_size);
const sycl::range<1> grid_dims(nrows);
main_stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
cgh.parallel_for(
sycl::nd_range<1>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<1> item_ct1) {
const int row = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(0);
if (row >= nrows) return;
const float * src_row = src + row * ncols;
int32_t * dst_idx_row = dst_indices + row * k;
float local_vals[32];
int local_idx[32];
for (int i = 0; i < k; i++) {
local_vals[i] = -FLT_MAX;
local_idx[i] = -1;
}
for (int col = tid; col < ncols; col += block_size) {
float val = src_row[col];
if (val > local_vals[k-1]) {
int pos = k - 1;
while (pos > 0 && val > local_vals[pos - 1]) {
pos--;
}
for (int i = k - 1; i > pos; i--) {
local_vals[i] = local_vals[i - 1];
local_idx[i] = local_idx[i - 1];
}
local_vals[pos] = val;
local_idx[pos] = col;
}
}
for (int i = 0; i < k; i++) {
shared_vals[tid * k + i] = local_vals[i];
shared_idx[tid * k + i] = local_idx[i];
}
item_ct1.barrier(sycl::access::fence_space::local_space);
if (tid == 0) {
float final_vals[32];
int final_idx[32];
for (int i = 0; i < k; i++) {
final_vals[i] = -FLT_MAX;
final_idx[i] = -1;
}
for (int t = 0; t < block_size; t++) {
for (int i = 0; i < k; i++) {
float val = shared_vals[t * k + i];
int idx = shared_idx[t * k + i];
if (val > final_vals[k-1]) {
int pos = k - 1;
while (pos > 0 && val > final_vals[pos - 1]) {
pos--;
}
for (int j = k - 1; j > pos; j--) {
final_vals[j] = final_vals[j - 1];
final_idx[j] = final_idx[j - 1];
}
final_vals[pos] = val;
final_idx[pos] = idx;
}
}
}
for (int i = 0; i < k; i++) {
dst_idx_row[i] = final_idx[i];
}
if (k > 1) {
int32_t temp = dst_idx_row[0];
dst_idx_row[0] = dst_idx_row[1];
dst_idx_row[1] = temp;
}
}
});
});
}
static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, queue_ptr stream) {
const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float);
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<float, 1> shared_data(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1);
float max_val = -INFINITY;
int max_idx = -1;
for (int col = tid; col < ncols; col += 256) {
float val = x[row * ncols + col];
if (val > max_val) {
max_val = val;
max_idx = col;
}
}
shared_data[tid] = max_val;
shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int stride = 256/2; stride > 0; stride >>= 1) {
if (tid < stride) {
float val1 = shared_data[tid];
float val2 = shared_data[tid + stride];
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
if (tid == 0) {
dst[row] = shared_indices[0];
}
});
});
}
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
const int ncols_x, const int nrows_x,
const int rows_per_channel, const int n_past,
queue_ptr stream) {
const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
const sycl::range<3> block_nums(1, block_num_x, nrows_x);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
diag_mask_inf_f32(x, dst, ncols_x,
rows_per_channel, n_past,
item_ct1);
});
}
static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
const struct ggml_tensor *src,
int64_t i3, int64_t i2,
int64_t i1_low, int64_t i1_high,
queue_ptr stream) try {
dpct::memcpy_direction kind;
char * src_ptr;
if (ggml_backend_buffer_is_host(src->buffer)) {
kind = dpct::host_to_device;
src_ptr = (char *) src->data;
} else if (ggml_backend_buffer_is_sycl(src->buffer)) {
kind = dpct::device_to_device;
src_ptr = (char *) src->data;
} else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
kind = dpct::device_to_device;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
int id;
SYCL_CHECK(CHECK_TRY_ERROR(
id = get_current_device_id()));
src_ptr = (char *) extra->data_device[id];
} else {
GGML_ABORT("fatal error");
}
char * dst_ptr = (char *) dst;
GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
kind, *stream));
} else if (nb0 == ts) {
return CHECK_TRY_ERROR(
dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
ts * ne0 / bs, i1_diff, kind, *stream));
} else {
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
if (r != 0) return r;
}
return 0;
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
inline void ggml_sycl_op_mul_mat_sycl(
ggml_backend_sycl_context & ctx,
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
const int64_t src1_ncols, const int64_t src1_padded_row_size,
const queue_ptr &stream) try {
GGML_ASSERT(src0_dd_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
GGML_ASSERT(dst_dd_i != nullptr);
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne00 == ne10);
const int64_t row_diff = row_high - row_low;
int id;
SYCL_CHECK(
CHECK_TRY_ERROR(id = get_current_device_id()));
const int64_t ne0 = dst->ne[0]; int ldc = id == ctx.device ? ne0 : row_diff;
#ifdef GGML_SYCL_F16
bool use_fp16 = true; #else
bool use_fp16 = false;
#endif
if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
if (src0->type != GGML_TYPE_F16) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, 2,
" : converting src0 to fp16");
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
GGML_ASSERT(to_fp16_sycl != nullptr);
size_t ne = row_diff*ne00;
src0_as_f16.alloc(ne);
to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
}
const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
? (const sycl::half *)src0_dd_i
: src0_as_f16.get();
ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
if (src1->type != GGML_TYPE_F16) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, 2,
" : converting src1 to fp16");
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
GGML_ASSERT(to_fp16_sycl != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_f16.alloc(ne);
to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
? (const sycl::half *)src1->data + src1_padded_row_size
: src1_as_f16.get();
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
#endif
{
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc,
dpct::library_data_t::real_half)));
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2,
" : converting dst to fp32");
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
} else {
ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
if (src0->type != GGML_TYPE_F32) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2,
" : converting src0 to fp32");
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
GGML_ASSERT(to_fp32_sycl != nullptr);
src0_ddq_as_f32.alloc(row_diff*ne00);
to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
}
if (src1->type != GGML_TYPE_F32) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2,
" : converting src1 to fp32");
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
GGML_ASSERT(to_fp32_sycl != nullptr);
src1_ddq_as_f32.alloc(src1_ncols*ne10);
to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
}
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
#endif
{
const float alpha = 1.0f;
const float beta = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
}
}
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_padded_row_size);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = dst->src[0]->ne[1];
const int64_t IW = dst->src[0]->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
sycl::range<3> block_nums(1, 1, num_blocks);
main_stream->parallel_for(
sycl::nd_range<3>(block_nums *
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
parallel_elements, src0_dd, dst_dd, op,
item_ct1);
});
}
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne = ggml_nelements(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
}
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
}
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
main_stream->parallel_for(
sycl::range<1>(nrows),
[=](sycl::id<1> row) {
dst_dd[row] /= ncols;
}
);
}
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
main_stream, ctx.device);
}
static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(src0->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
const int k = dst->ne[0];
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
GGML_ASSERT(k > 0 && k <= 32);
GGML_ASSERT(k <= ncols);
top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
}
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t ne01 = dst->src[0]->ne[1];
const int nrows0 = ggml_nrows(dst->src[0]);
const int n_past = ((int32_t *) dst->op_params)[0];
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
}
static void tri_f32_sycl(
const float * src,
float * dst,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
const ggml_tri_type ttype,
dpct::queue_ptr main_stream
) {
const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
const int64_t idx = (int64_t) tid[0];
const int64_t i0 = idx % ne0;
const int64_t t1 = idx / ne0;
const int64_t i1 = t1 % ne1;
bool keep = false;
switch (ttype) {
case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
default: keep = false; break;
}
dst[idx] = keep ? src[idx] : 0.0f;
});
}
static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(src0->data);
float * dst_dd = static_cast<float *>(dst->data);
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
const int64_t ne0 = src0->ne[0];
const int64_t ne1 = src0->ne[1];
const int64_t ne2 = src0->ne[2];
const int64_t ne3 = src0->ne[3];
tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
}
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
float scale;
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
SYCL_CHECK(0);
}
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
static bool peer_access_enabled = false;
const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
if (peer_access_enabled == enable_peer_access) {
return;
}
#ifdef NDEBUG
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
SYCL_CHECK(ggml_sycl_set_device(i));
}
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
SYCL_CHECK(ggml_sycl_set_device(i));
for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
if (i == id_other) {
continue;
}
if (i != main_device && id_other != main_device) {
continue;
}
}
}
#endif
peer_access_enabled = enable_peer_access;
}
template <template <int> typename quantize_f>
static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
ggml_sycl_op_mul_mat_t op) try {
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
const int64_t nrows1 = ggml_nrows(src1);
GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
const int64_t i02_divisor = ne12 / ne02;
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1);
int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
GGML_ASSERT(!(split && ne02 > 1));
GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
if (split) {
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
tensor_split = buft_ctx->tensor_split;
}
struct dev_data {
ggml_sycl_pool_alloc<char> src0_dd_alloc;
ggml_sycl_pool_alloc<float> src1_ddf_alloc;
ggml_sycl_pool_alloc<char> src1_ddq_alloc;
ggml_sycl_pool_alloc<float> dst_dd_alloc;
char *src0_dd = nullptr;
float *src1_ddf = nullptr; char *src1_ddq = nullptr; float *dst_dd = nullptr;
int64_t row_low;
int64_t row_high;
};
dev_data dev[GGML_SYCL_MAX_DEVICES];
int used_devices = 0;
queue_ptr main_stream = ctx.stream();
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
dev[i].row_low = 0;
dev[i].row_high = ne01;
if (split) {
const int64_t rounding = get_row_rounding(src0->type, tensor_split);
if (i != 0) {
dev[i].row_low = ne01*tensor_split[i];
if (dev[i].row_low < ne01) {
dev[i].row_low -= dev[i].row_low % rounding;
}
}
if (i != ggml_sycl_info().device_count - 1) {
dev[i].row_high = ne01*tensor_split[i + 1];
if (dev[i].row_high < ne01) {
dev[i].row_high -= dev[i].row_high % rounding;
}
}
}
}
constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
continue;
}
used_devices++;
const bool src1_on_device = i == ctx.device;
const bool dst_on_device = i == ctx.device;
ggml_sycl_set_device(i);
queue_ptr stream = ctx.stream(i, 0);
if (src0_is_contiguous) {
dev[i].src0_dd = (char *) src0->data;
} else {
dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
}
if (src1_on_device && src1_is_contiguous) {
dev[i].src1_ddf = (float *) src1->data;
} else {
dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
}
if constexpr(quantize_enabled) {
dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
if (src1_on_device && src1_is_contiguous) {
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2, " : converting src1 to Q8_1");
try {
quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
} catch (sycl::exception const &exc) {
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
}
}
if (dst_on_device) {
dev[i].dst_dd = (float *) dst->data;
} else {
const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
}
}
if (split && used_devices > 1) {
ggml_sycl_set_device(ctx.device);
SYCL_CHECK(CHECK_TRY_ERROR(
*src0_extra->events[ctx.device][0] =
ctx.stream()->ext_oneapi_submit_barrier()));
}
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
continue;
}
const bool src1_on_device = i == ctx.device;
const bool dst_on_device = i == ctx.device;
const int64_t row_diff = dev[i].row_high - dev[i].row_low;
ggml_sycl_set_device(i);
queue_ptr stream = ctx.stream(i, is);
if (split && (i != ctx.device || is != 0)) {
SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
{*src0_extra->events[ctx.device][0]})));
}
for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
const int64_t i03 = i0 / ne12;
const int64_t i02 = i0 % ne12;
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
if (i == ctx.device) {
dst_dd_i += dev[i].row_low; }
if (src1_is_contiguous) {
if (i != ctx.device) {
if constexpr (quantize_enabled) {
char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
SYCL_CHECK(
CHECK_TRY_ERROR(stream
->memcpy(src1_ddq_i, src1_ddq_i_source,
src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
.wait()));
} else {
float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
SYCL_CHECK(
CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
src1_ncols * ne10 * sizeof(float))));
}
}
} else {
if (src1_on_device) {
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
src1_col_0 + src1_ncols, stream));
} else {
GGML_ABORT("src1 is non-contiguous and not on device");
}
if constexpr (quantize_enabled) {
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2, " : converting src1 to Q8_1");
try {
quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
src1_padded_col_size, stream);
} catch (const sycl::exception & exc) {
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
<< "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
}
}
}
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
}
if (src1->type == GGML_TYPE_F16) {
src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
}
SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
if (!dst_on_device) {
void * dst_off_device = dst->data;
if (split) {
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
row_diff * sizeof(float), row_diff * sizeof(float),
src1_ncols, dpct::device_to_device, *stream)));
} else {
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0;
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(dhf_dst_i, dst_dd_i,
src1_ncols * ne0 * sizeof(float)).wait()));
}
}
if (split && (i != ctx.device || is != 0)) {
SYCL_CHECK(CHECK_TRY_ERROR(
*src0_extra->events[i][is] =
stream->ext_oneapi_submit_barrier()));
}
}
}
}
if (split && ggml_sycl_info().device_count > 1) {
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
ggml_sycl_set_device(ctx.device);
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
if (dev[i].row_low == dev[i].row_high) {
continue;
}
for (int64_t is = 0; is < is_max; ++is) {
SYCL_CHECK(CHECK_TRY_ERROR(
ctx.stream()->ext_oneapi_submit_barrier(
{*src0_extra->events[i][is]})));
}
}
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_repeat_back(ctx, dst);
}
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 2);
ggml_sycl_op_get_rows(ctx, dst);
}
static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_norm(ctx, dst);
}
static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_rms_norm(ctx, dst);
}
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 2);
ggml_sycl_op_rms_norm_back(ctx, dst);
}
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_l2_norm(ctx, dst);
}
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_group_norm(ctx, dst);
}
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne12 = src1->ne[2];
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();
void * src0_ddq = src0->data;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->ne[1] == 1);
GGML_ASSERT(src1->ne[3] == 1);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t ne12 = src1->ne[2];
const int64_t nb11 = src1->nb[1];
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr main_stream = ctx.stream();
void * src0_ddq = src0->data;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
const int64_t row_stride_x = nb01 / sizeof(sycl::half);
const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
const int64_t channel_stride_y = nb11 / sizeof(float);
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
if (i13 >= ne13 || i12 >= ne12) {
return;
}
const int64_t i03 = i13 / r3;
const int64_t i02 = i12 / r2;
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
}
static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) try {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ggml_is_contiguous(dst));
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
queue_ptr queue = ctx.stream();
dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
float * dst_ddf = static_cast<float *>(dst->data);
const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
const size_t type_size_src0 = ggml_type_size(src0->type);
const size_t type_size_src1 = ggml_type_size(src1->type);
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
int64_t s11 = nb11 / type_size_src1;
int64_t s12 = nb12 / type_size_src1;
int64_t s13 = nb13 / type_size_src1;
ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
if (src1->type != GGML_TYPE_F16) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, 2,
" : converting src1 to fp16");
int last_dim=0;
int last_str=0;
size_t largest_str=0;
for(int i = 0; i< 4; i++){
if(src1->nb[i] == largest_str){
if(src1->ne[last_dim] == 1){
last_str = i;
last_dim = i;
}
}
if(src1->nb[i] > largest_str){
largest_str = src1->nb[i];
last_str = i;
last_dim = i;
}
}
#if GGML_SYCL_DNNL
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
src1_f16_alloc.alloc(ne_src1);
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
GGML_ASSERT(to_fp16_sycl != nullptr);
to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
# else
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
#endif
src1_f16 = src1_f16_alloc.get();
s11 = ne10;
s12 = ne11 * s11;
s13 = ne12 * s12;
is_src1_cont_2 = true;
}
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
const void * alpha = &alpha_f32;
const void * beta = &beta_f32;
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
GGML_ASSERT(ne10 == ne00);
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
int64_t str_a0 = nb00 / type_size_src0;
int64_t str_a1 = nb01 / type_size_src0;
int64_t str_a2 = nb02 / type_size_src0;
int64_t str_b0 = nb10 / type_size_src1;
int64_t str_b1 = nb11 / type_size_src1;
int64_t str_b2 = nb12 / type_size_src1;
auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
const sycl::half *src1, float *dst,
int64_t a0, int64_t a1, int64_t batcha,
int64_t , int64_t b1, int64_t batchb,
int64_t sa0, int64_t sa1, int64_t sa2,
int64_t sb0, int64_t sb1, int64_t sb2,
int64_t sd2) {
bool supported_broadcast = batchb == batcha ? true
: batchb == 1 || batcha == 1 ? true
: false;
if (supported_broadcast) {
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
} else {
int64_t batches0 = batcha;
int64_t batches1 = batchb;
if (batches0 > batches1) {
int64_t num_mul_mats = batches1;
int64_t sub_batch = batches0 / num_mul_mats;
for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
const sycl::half *src1_shifted = src1 + (sb2 * i0);
float *dst_shifted = dst + (sd2 * i0 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, sub_batch, 1);
}
} else {
int64_t num_mul_mats = batches0;
int64_t sub_batch = batches1 / num_mul_mats;
for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
const sycl::half *src0_shifted = src0 + (sa2 * i1);
const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
float *dst_shifted = dst + (sd2 * i1 * sub_batch);
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
queue, 1, sub_batch);
}
}
}
};
const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
if (cont_batches_dim2_a && cont_batches_dim2_b) {
int64_t batches0 = ne02 * ne03;
int64_t batches1 = ne12 * ne13;
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
str_b2, nb2 / sizeof(float));
} else if (cont_batches_dim3_a && cont_batches_dim3_b) {
int64_t batches0 = ne02 * ne03;
int64_t batches1 = ne12 * ne13;
int64_t str_a3 = nb03 / type_size_src0;
int64_t str_b3 = nb13 / type_size_src1;
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
str_b3, nb2 / sizeof(float));
} else {
for (int64_t b_a = 0; b_a < ne03; b_a++) {
const sycl::half *src0_f16_shifted
= src0_f16 + (nb03 * b_a / type_size_src0);
const sycl::half *src1_f16_shifted
= src1_f16 + (nb13 * b_a / type_size_src1);
float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
int64_t batches0 = ne02;
int64_t batches1 = ne12;
launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
}
}
}
else
#endif
{
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
const int64_t smb = ne12 == 1 ? s13 : s12;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13);
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
});
});
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
}
}
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::exit(1);
}
enum class mul_mat_algo {
DMMV = 0,
MMVQ = 1,
MUL_MAT_SYCL = 2,
};
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
GGML_UNUSED(type);
return false;
}
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
return false;
}
}
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
}
}
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return true;
default:
return false;
}
}
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_F16:
return true;
default:
return false;
}
}
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
bool use_async = g_ggml_sycl_use_async_mem_op;
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
if (use_async) {
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
}
#else
GGML_ASSERT(!use_async);
#endif
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
}
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
bool use_async = g_ggml_sycl_use_async_mem_op;
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
if (use_async) {
syclex::async_free(*stream, ptr);
return;
}
#else
GGML_ASSERT(!use_async);
#endif
sycl::free(ptr, *stream);
}
struct sycl_reorder_temp_buffer {
void * ptr = nullptr;
dpct::queue_ptr stream;
sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) {
ptr = sycl_ext_malloc_device(stream, size);
#ifdef GGML_SYCL_HOST_MEM_FALLBACK
if (!ptr) {
ptr = sycl::malloc_host(size, *stream);
if (ptr) {
host_fallback = true;
GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size);
}
}
#endif
}
~sycl_reorder_temp_buffer() {
if (!ptr) {
return;
}
if (host_fallback) {
sycl::free(ptr, *stream);
} else {
sycl_ext_free(stream, ptr);
}
}
explicit operator bool() const { return ptr != nullptr; }
sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete;
sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete;
private:
bool host_fallback = false;
};
static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
int offset_blks = offset / sizeof(block_q4_0);
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
auto reorder_event = stream->parallel_for(
size / sizeof(block_q4_0),
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q4_0* x = (const block_q4_0*)tmp_buf;
const int ib = i;
for (int j = 0; j < QK4_0/2; j ++)
{
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
}
*(d_ptr + ib) = x[ib].d;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
return true;
}
static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
GGML_ASSERT((size % sizeof(block_q8_0) == 0));
GGML_ASSERT((offset % sizeof(block_q8_0) == 0));
int offset_blks = offset / sizeof(block_q8_0);
auto qs_ptr = data_device + offset_blks * QK8_0;
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks;
auto reorder_event = stream->parallel_for(
size / sizeof(block_q8_0),
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q8_0* x = (const block_q8_0*)tmp_buf;
const int ib = i;
for (int j = 0; j < QK8_0; j++)
{
*((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j];
}
*(d_ptr + ib) = x[ib].d;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
return true;
}
static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
const int nblocks = size / sizeof(block_q4_K);
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
auto * qs_ptr = data_device;
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
const block_q4_K * x = (const block_q4_K *) tmp_buf;
const int ib = i;
for (int j = 0; j < QK_K / 2; ++j) {
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
}
for (int j = 0; j < K_SCALE_SIZE; ++j) {
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
}
dm_ptr[ib] = x[ib].dm;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
return true;
}
static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
const int nblocks = size / sizeof(block_q6_K);
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
auto * ql_ptr = data_device;
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
const block_q6_K * x = (const block_q6_K *) tmp_buf;
const int ib = i;
const uint8_t * ql = x[ib].ql;
const uint8_t * qh = x[ib].qh;
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
for (int j = 0; j < QK_K / 2; ++j) {
base_ql_ptr[j] = ql[j];
}
for (int j = 0; j < QK_K / 4; ++j) {
base_qh_ptr[j] = qh[j];
}
for (int j = 0; j < QK_K / 16; ++j) {
base_scales_ptr[j] = x[ib].scales[j];
}
dm_ptr[ib] = x[ib].d;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
return true;
}
static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
uint8_t * data_device = (uint8_t *) src0->data;
size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1];
size_t size = ggml_nbytes(src0);
switch (src0->type) {
case GGML_TYPE_Q4_0:
return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q8_0:
return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q4_K:
return reorder_qw_q4_k(data_device, size, 0, stream);
case GGML_TYPE_Q6_K:
return reorder_qw_q6_k(data_device, size, 0, stream);
default:
GGML_ABORT("reorder_qw() called with unsupported type");
return false;
}
}
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
return !g_ggml_sycl_disable_optimize && ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT && dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
}
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * ,
ggml_tensor * dst, mul_mat_algo mm_algorithm) {
if (!should_reorder_tensor(*ctx, dst)) {
return;
}
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
if (!extra || extra->optimized_feature.reorder) {
return; }
switch (mm_algorithm) {
case mul_mat_algo::DMMV:
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
return;
}
break;
case mul_mat_algo::MMVQ:
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
return;
}
break;
case mul_mat_algo::MUL_MAT_SYCL:
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
return;
}
break;
}
if (reorder_qw(src0, ctx->stream())) {
extra->optimized_feature.reorder = true; }
}
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
}
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
}
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 2);
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
int64_t min_compute_capability = INT_MAX;
if (split) {
ggml_backend_sycl_split_buffer_type_context * buft_ctx =
(ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
auto & tensor_split = buft_ctx->tensor_split;
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
continue;
}
if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
min_compute_capability = ggml_sycl_info().devices[id].cc;
}
}
} else {
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
}
bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
#ifdef SYCL_USE_XMX
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
#endif
if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
ggml_sycl_supports_reorder_mmvq(src0->type)))) {
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
}
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
if (src0->ne[3] == 1 && src1->ne[3] == 1) {
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
} else {
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
} else if (use_mul_mat_vec_q) {
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
if (extra && extra->optimized_feature.reorder) {
ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
} else {
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
}
} else if (use_mul_mat_q) {
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
} else {
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
}
}
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
__dpct_inline__ static void k_copy_src1_to_contiguous(
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
const sycl::nd_item<3> &item_ct1, int &src1_row) {
int32_t iid1 = item_ct1.get_group(2);
int32_t id = item_ct1.get_group(1);
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
if (row_id_i != i02) {
return;
}
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
if (item_ct1.get_local_id(2) == 0) {
src1_row =
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
item_ct1.barrier();
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
#pragma unroll
for (int i = item_ct1.get_local_id(2); i < ne10;
i += item_ct1.get_local_range(2)) {
src1_row_contiguous[i] = src1_row_original[i];
}
}
__dpct_inline__ static void k_copy_dst_from_contiguous(
char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
size_t nb2, const sycl::nd_item<3> &item_ct1) {
int32_t i = item_ct1.get_group(2);
const int32_t i1 = row_mapping[i].i1;
const int32_t i2 = row_mapping[i].i2;
const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
#pragma unroll
for (int j = item_ct1.get_local_id(2); j < ne0;
j += item_ct1.get_local_range(2)) {
dst_row_original[j] = dst_row_contiguous[j];
}
}
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_tensor *dst) try {
scope_op_debug_print scope_dbg_print(__func__, dst, 3);
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
const ggml_tensor *ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
const queue_ptr stream = ctx.stream();
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
char *src0_original = (char *)src0->data;
char *src1_original = (char *)src1->data;
char *dst_original = (char *)dst->data;
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = nb02;
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
}
} else {
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
}
num_src1_rows++;
}
}
if (num_src1_rows == 0) {
continue;
}
ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh);
char *__restrict src1_contiguous_get =
src1_contiguous.get();
int *__restrict dev_cur_src1_row_get =
dev_cur_src1_row.get();
mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0];
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get,
dev_cur_src1_row_get,
dev_row_mapping_get, ids_dev, i02,
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
item_ct1, src1_row_acc);
});
});
}
src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
sycl::range<3> grid_dims(1, 1, num_src1_rows);
stream->submit([&](sycl::handler &cgh) {
const char *__restrict dst_contiguous_get =
dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get,
dev_row_mapping_get,
ne0, nb1, nb2, item_ct1);
});
});
}
}
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_scale(ctx, dst);
}
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_diag_mask_inf(ctx, dst);
}
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
ggml_sycl_op_pool2d(ctx, dst);
}
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 2);
ggml_sycl_op_im2col(ctx, dst);
}
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_sum(ctx, dst);
}
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_sum_rows(ctx, dst);
}
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_mean(ctx, dst);
}
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_argsort(ctx, dst);
}
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_argmax(ctx, dst);
}
static void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
return;
}
check_allow_gpu_index(main_device);
dpct::select_device(main_device);
if (g_ggml_sycl_debug) {
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, dpct::dev_mgr::instance().get_device(main_device))));
GGML_LOG_INFO("Using device %d (%s) as main device\n",
main_device, prop.get_name());
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
}
switch (dst->op) {
case GGML_OP_ARGMAX:
ggml_sycl_argmax(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_sycl_op_conv_transpose_1d(ctx, dst);
break;
case GGML_OP_REPEAT:
ggml_sycl_repeat(ctx, dst);
break;
case GGML_OP_REPEAT_BACK:
ggml_sycl_repeat_back(ctx, dst);
break;
case GGML_OP_GET_ROWS:
ggml_sycl_get_rows(ctx, dst);
break;
case GGML_OP_SET:
ggml_sycl_op_set(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_sycl_op_set_rows(ctx, dst);
break;
case GGML_OP_DUP:
ggml_sycl_dup(ctx, dst);
break;
case GGML_OP_ADD:
case GGML_OP_ADD1: ggml_sycl_add(ctx, dst);
break;
case GGML_OP_ADD_ID:
ggml_sycl_add_id(ctx, dst);
break;
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_sycl_count_equal(ctx, dst);
break;
case GGML_OP_ACC:
ggml_sycl_acc(ctx, dst);
break;
case GGML_OP_MUL:
ggml_sycl_mul(ctx, dst);
break;
case GGML_OP_LOG:
ggml_sycl_log(ctx, dst);
break;
case GGML_OP_DIV:
ggml_sycl_div(ctx, dst);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_NEG:
ggml_sycl_neg(ctx, dst);
break;
case GGML_UNARY_OP_STEP:
ggml_sycl_step(ctx, dst);
break;
case GGML_UNARY_OP_GELU:
ggml_sycl_gelu(ctx, dst);
break;
case GGML_UNARY_OP_SILU:
ggml_sycl_silu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_sycl_gelu_quick(ctx, dst);
break;
case GGML_UNARY_OP_GELU_ERF:
ggml_sycl_gelu_erf(ctx, dst);
break;
case GGML_UNARY_OP_TANH:
ggml_sycl_tanh(ctx, dst);
break;
case GGML_UNARY_OP_RELU:
ggml_sycl_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
ggml_sycl_sigmoid(ctx, dst);
break;
case GGML_UNARY_OP_HARDSIGMOID:
ggml_sycl_hardsigmoid(ctx, dst);
break;
case GGML_UNARY_OP_HARDSWISH:
ggml_sycl_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_sycl_exp(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_sycl_softplus(ctx, dst);
break;
case GGML_UNARY_OP_SGN:
ggml_sycl_sgn(ctx, dst);
break;
case GGML_UNARY_OP_ABS:
ggml_sycl_abs(ctx, dst);
break;
case GGML_UNARY_OP_ELU:
ggml_sycl_elu(ctx, dst);
break;
case GGML_UNARY_OP_FLOOR:
ggml_sycl_floor(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_sycl_ceil(ctx, dst);
break;
case GGML_UNARY_OP_ROUND:
ggml_sycl_round(ctx, dst);
break;
case GGML_UNARY_OP_TRUNC:
ggml_sycl_trunc(ctx, dst);
break;
default:
return false;
}
break;
case GGML_OP_GLU:
switch (ggml_get_glu_op(dst)) {
case GGML_GLU_OP_REGLU:
ggml_sycl_reglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU:
ggml_sycl_geglu(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU:
ggml_sycl_swiglu(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU_OAI:
ggml_sycl_swiglu_oai(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_sycl_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_sycl_geglu_quick(ctx, dst);
break;
default:
return false;
}
break;
case GGML_OP_NORM:
ggml_sycl_norm(ctx, dst);
break;
case GGML_OP_GROUP_NORM:
ggml_sycl_group_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_sycl_op_concat(ctx, dst);
break;
case GGML_OP_PAD_REFLECT_1D:
ggml_sycl_op_pad_reflect_1d(ctx,dst);
break;
case GGML_OP_UPSCALE:
ggml_sycl_upscale(ctx, dst);
break;
case GGML_OP_PAD:
ggml_sycl_pad(ctx, dst);
break;
case GGML_OP_LEAKY_RELU:
ggml_sycl_leaky_relu(ctx, dst);
break;
case GGML_OP_RMS_NORM_BACK:
ggml_sycl_rms_norm_back(ctx, dst);
break;
case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_sycl_l2_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
}
ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
break;
case GGML_OP_MUL_MAT_ID:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
}
ggml_sycl_mul_mat_id(ctx, dst);
break;
case GGML_OP_OUT_PROD:
ggml_sycl_op_out_prod(ctx, dst);
break;
case GGML_OP_SCALE:
ggml_sycl_scale(ctx, dst);
break;
case GGML_OP_SQR:
ggml_sycl_sqr(ctx, dst);
break;
case GGML_OP_SQRT:
ggml_sycl_sqrt(ctx, dst);
break;
case GGML_OP_SIN:
ggml_sycl_sin(ctx, dst);
break;
case GGML_OP_COS:
ggml_sycl_cos(ctx, dst);
break;
case GGML_OP_CLAMP:
ggml_sycl_clamp(ctx, dst);
break;
case GGML_OP_CPY:
ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
break;
case GGML_OP_CONT:
ggml_sycl_dup(ctx, dst);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
break;
case GGML_OP_TRI:
ggml_sycl_op_tri(ctx, dst);
break;
case GGML_OP_DIAG_MASK_INF:
ggml_sycl_diag_mask_inf(ctx, dst);
break;
case GGML_OP_SOFT_MAX:
ggml_sycl_op_soft_max(ctx, dst);
break;
case GGML_OP_SOFT_MAX_BACK:
ggml_sycl_op_soft_max_back(ctx, dst);
break;
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_sycl_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_sycl_im2col(ctx, dst);
break;
case GGML_OP_POOL_2D:
ggml_sycl_pool2d(ctx, dst);
break;
case GGML_OP_SUM:
ggml_sycl_sum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_sycl_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_sycl_mean(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
case GGML_OP_TOP_K:
ggml_sycl_op_top_k(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_sycl_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_GATED_DELTA_NET:
ggml_sycl_gated_delta_net(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_sycl_flash_attn_ext(ctx, dst);
break;
default:
return false;
}
return true;
} catch (sycl::exception & e) {
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
std::exit(1);
}
GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
size_t description_size) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, dpct::dev_mgr::instance().get_device(device))));
snprintf(description, description_size, "%s", prop.get_name());
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
void ggml_backend_sycl_get_device_memory(int device, size_t *free,
size_t *total) try {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
ggml_sycl_set_device(device);
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
return sycl_ctx->name.c_str();
}
static void ggml_backend_sycl_free(ggml_backend_t backend) {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
delete sycl_ctx;
delete backend;
}
static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
ggml_tensor *tensor,
const void *data, size_t offset,
size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR(
(stream)->memcpy((char *)tensor->data + offset, data, size)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
const ggml_tensor *tensor,
void *data, size_t offset,
size_t size) try {
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str());
GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset);
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
data, (const char *)tensor->data + offset, size)));
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
const ggml_tensor *src,
ggml_tensor *dst) try {
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) &&
ggml_backend_buffer_is_sycl(src->buffer);
GGML_SYCL_DEBUG("[SYCL] call %s", __func__);
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str());
GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str());
GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported);
if (is_cpy_supported) {
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
dst->data, src->data, ggml_nbytes(dst))));
return true;
}
return false;
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
GGML_UNUSED(backend);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
ggml_sycl_set_main_device(sycl_ctx->device);
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
}
}
#endif
bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
}
}
#ifdef GGML_SYCL_GRAPH
static bool check_graph_compatibility(ggml_cgraph * cgraph) {
if (ggml_sycl_info().device_count > 1) {
GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
const ggml_op node_op = cgraph->nodes[i]->op;
switch (node_op) {
default:
break;
case GGML_OP_CONCAT:
[[fallthrough]];
case GGML_OP_MUL_MAT_ID:
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
ggml_op_name(node_op));
return false;
case GGML_OP_MUL_MAT:
if (!g_ggml_sycl_use_async_mem_op) {
GGML_LOG_INFO(
"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
"oneAPI async memory allocation extension "
"%s\n",
__func__, ggml_op_name(node_op));
return false;
}
}
}
return true;
}
#endif
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
if (use_sycl_graph) {
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
if (!graph_support) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
return GGML_STATUS_SUCCESS;
}
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
if (!sycl_ctx->exec_graph || !graph_update_support) {
auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
model_sycl_graph.finalize();
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
} else {
try {
sycl_ctx->exec_graph->update(model_sycl_graph);
GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
} catch (sycl::exception const & e) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
sycl_ctx->exec_graph = std::make_unique<
sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
}
}
sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
} else
#endif
{
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
}
return GGML_STATUS_SUCCESS;
}
static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
try
{
ggml_backend_sycl_context *sycl_ctx =
(ggml_backend_sycl_context *)backend->context;
sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
}
catch (sycl::exception const &exc)
{
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
if (ggml_backend_is_sycl(backend)) {
SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
} else
GGML_ABORT("fatal error");
} catch (sycl::exception const& exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static ggml_backend_i ggml_backend_sycl_interface = {
ggml_backend_sycl_get_name,
ggml_backend_sycl_free,
ggml_backend_sycl_set_tensor_async,
ggml_backend_sycl_get_tensor_async,
NULL,
NULL,
NULL, ggml_backend_sycl_synchronize,
NULL,
NULL,
NULL,
NULL,
ggml_backend_sycl_graph_compute,
ggml_backend_sycl_event_record,
ggml_backend_sycl_event_wait,
NULL,
};
static ggml_guid_t ggml_backend_sycl_guid() {
static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
return &guid;
}
bool ggml_backend_is_sycl(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
}
int ggml_backend_sycl_get_device_count() {
return ggml_sycl_info().device_count;
}
struct ggml_backend_sycl_device_context {
int device;
std::string name;
std::string description;
int op_offload_min_batch_size;
};
static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
return ctx->name.c_str();
}
static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
return ctx->description.c_str();
}
static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
ggml_sycl_set_device(ctx->device);
SYCL_CHECK(CHECK_TRY_ERROR(
dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
}
static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_GPU;
}
static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_sycl_device_get_name(dev);
props->description = ggml_backend_sycl_device_get_description(dev);
props->type = ggml_backend_sycl_device_get_type(dev);
ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
#ifdef GGML_SYCL_NO_PEER_COPY
bool events = false;
#else
bool events = true;
#endif
props->caps = {
true,
host_buffer,
false,
events,
};
}
static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
return ggml_backend_sycl_init(ctx->device);
}
static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
return ggml_backend_sycl_buffer_type(ctx->device);
}
static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return ggml_backend_sycl_host_buffer_type();
}
static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
GGML_UNUSED(dev);
GGML_UNUSED(ptr);
GGML_UNUSED(size);
GGML_UNUSED(max_tensor_size);
return nullptr;
}
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_sycl_device_context *sycl_ctx =
(ggml_backend_sycl_device_context *)dev->context;
int device = sycl_ctx->device;
switch (op->op) {
case GGML_OP_CONV_TRANSPOSE_1D:
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1]->type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
return false;
}
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_CEIL:
return true;
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
#if defined (GGML_SYCL_F16)
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
#else
return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
default:
return false;
}
case GGML_OP_GLU:
switch (ggml_get_glu_op(op)) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;
}
break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) {
return false;
}
if (a->ne[3] != b->ne[3]) {
return false;
}
ggml_type src0_type = op->src[0]->type;
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
return false;
}
if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
return false;
}
return true;
}
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
}
}
case GGML_OP_SET:
return (op->type == GGML_TYPE_F32) &&
(op->src[0] && op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F32) &&
(op->src[1]->type == GGML_TYPE_F32);
case GGML_OP_SET_ROWS:
{
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
}
break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1]->type;
if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
return false;
}
case GGML_OP_REPEAT_BACK:
{
ggml_type src0_type = op->src[0]->type;
return src0_type == GGML_TYPE_F32;
}
case GGML_OP_CONCAT:
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ADD_ID:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT:
return true;
case GGML_OP_PAD_REFLECT_1D:
return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
#if defined (GGML_SYCL_F16)
return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
#else
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_TRI:
{
const ggml_tensor * src0 = op->src[0];
return src0 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0);
}
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_IM2COL:
case GGML_OP_UPSCALE:
return true;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ARGSORT:
return op->src[0]->ne[0] * sizeof(int) <=
ggml_sycl_info().devices[device].smpbo;
case GGML_OP_TOP_K: {
const ggml_tensor * src0 = op->src[0];
const int k = op->ne[0];
return src0 &&
op->type == GGML_TYPE_I32 &&
src0->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) &&
k > 0 && k <= 32;
}
case GGML_OP_POOL_2D:
return true;
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_PAD:
if (ggml_get_op_params_i32(op, 8) != 0) {
return false;
}
return ggml_is_contiguous(op->src[0]);
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_GATED_DELTA_NET:
return true;
case GGML_OP_SSM_CONV:
return op->type == GGML_TYPE_F32 &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_ROLL:
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(device, op);
default:
return false;
}
GGML_UNUSED(dev);
}
static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
return false;
}
ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
return buft_ctx->device == sycl_ctx->device;
}
static int64_t get_op_batch_size(const ggml_tensor * op) {
switch (op->op) {
case GGML_OP_GET_ROWS:
return 0;
case GGML_OP_MUL_MAT:
return op->ne[1];
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
return op->ne[2];
default:
return ggml_nrows(op);
}
}
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
}
static ggml_backend_event_t
ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
#ifdef GGML_SYCL_NO_PEER_COPY
return nullptr;
#else
sycl::event *event_ptr = new sycl::event();
return new ggml_backend_event{
dev,
event_ptr,
};
#endif
}
static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
GGML_UNUSED(dev);
if (event == nullptr) {
return;
}
if (event->context != nullptr) {
sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
delete sycl_event;
event->context = nullptr;
}
delete event;
} catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
GGML_UNUSED(dev);
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
} catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
ggml_backend_sycl_device_get_name,
ggml_backend_sycl_device_get_description,
ggml_backend_sycl_device_get_memory,
ggml_backend_sycl_device_get_type,
ggml_backend_sycl_device_get_props,
ggml_backend_sycl_device_init,
ggml_backend_sycl_device_get_buffer_type,
ggml_backend_sycl_device_get_host_buffer_type,
ggml_backend_sycl_device_buffer_from_host_ptr,
ggml_backend_sycl_device_supports_op,
ggml_backend_sycl_device_supports_buft,
ggml_backend_sycl_device_offload_op,
ggml_backend_sycl_device_event_new,
ggml_backend_sycl_device_event_free,
ggml_backend_sycl_device_event_synchronize,
};
struct ggml_backend_sycl_reg_context {
std::vector<ggml_backend_dev_t> devices;
};
static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
GGML_UNUSED(reg);
return GGML_SYCL_NAME;
}
static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
return ctx->devices.size();
}
static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
GGML_ASSERT(index < ctx->devices.size());
return ctx->devices[index];
}
static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
GGML_UNUSED(reg);
if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
return (void *)ggml_backend_sycl_split_buffer_type;
}
GGML_UNUSED(name);
return nullptr;
}
static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
ggml_backend_sycl_reg_get_name,
ggml_backend_sycl_reg_get_device_count,
ggml_backend_sycl_reg_get_device,
ggml_backend_sycl_reg_get_proc_address,
};
ggml_backend_reg_t ggml_backend_sycl_reg() {
static ggml_backend_reg reg;
static bool initialized = false;
{
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
dev_ctx->device = i;
dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
ggml_sycl_set_device(i);
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
prop, dpct::dev_mgr::instance().get_device(i))));
dev_ctx->description = prop.get_name();
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_backend_dev_t dev = new ggml_backend_device {
ggml_backend_sycl_device_interface,
®,
dev_ctx
};
ctx->devices.push_back(dev);
}
reg = ggml_backend_reg {
GGML_BACKEND_API_VERSION,
ggml_backend_sycl_reg_interface,
ctx
};
}
initialized = true;
}
return ®
}
ggml_backend_t ggml_backend_sycl_init(int device) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
ggml_check_sycl();
check_allow_gpu_index(device);
ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
if (ctx == nullptr) {
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return nullptr;
};
ggml_backend_t sycl_backend = new ggml_backend {
ggml_backend_sycl_guid(),
ggml_backend_sycl_interface,
ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
ctx
};
return sycl_backend;
}
GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg)