#pragma once
#include "shim.h"
#include <torch/torch.h>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
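// Copies msg into a malloc'd, NUL-terminated buffer so it can cross the C ABI.
// Returns nullptr if allocation fails; the caller is expected to free() it.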
static inline char* make_error(const std::string& msg) {
  char* err = (char*)malloc(msg.size() + 1);
  if (err) {
    memcpy(err, msg.c_str(), msg.size() + 1);
  }
  return err;
}
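// Maps a FLODL_* dtype code to the corresponding torch dtype.
// Unknown codes fall back to float32.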
static inline torch::ScalarType to_scalar_type(int dtype) {
  switch (dtype) {
    case FLODL_FLOAT16: return torch::kFloat16;
    case FLODL_BFLOAT16: return torch::kBFloat16;
    case FLODL_FLOAT32: return torch::kFloat32;
    case FLODL_FLOAT64: return torch::kFloat64;
    case FLODL_INT32: return torch::kInt32;
    case FLODL_INT64: return torch::kInt64;
    default: return torch::kFloat32;
  }
}
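// Inverse of to_scalar_type: maps a torch dtype back to its FLODL_* code,
// falling back to FLODL_FLOAT32 for dtypes the shim does not expose.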
static inline int from_scalar_type(torch::ScalarType st) {
  switch (st) {
    case torch::kFloat16: return FLODL_FLOAT16;
    case torch::kBFloat16: return FLODL_BFLOAT16;
    case torch::kFloat32: return FLODL_FLOAT32;
    case torch::kFloat64: return FLODL_FLOAT64;
    case torch::kInt32: return FLODL_INT32;
    case torch::kInt64: return FLODL_INT64;
    default: return FLODL_FLOAT32;
  }
}
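// Builds a torch::Device from FLODL_* device codes; anything other than
// FLODL_CUDA is treated as CPU, and device_index is ignored for CPU.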
static inline torch::Device to_device(int device_type, int device_index) {
  if (device_type == FLODL_CUDA) {
    return torch::Device(torch::kCUDA, (c10::DeviceIndex)device_index);
  }
  return torch::Device(torch::kCPU);
}
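// Collapses a torch::Device back to a FLODL_* device-type code.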
static inline int from_device_type(const torch::Device& dev) {
  if (dev.is_cuda()) return FLODL_CUDA;
  return FLODL_CPU;
}
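// Returns the CUDA device's index, or 0 for CPU devices.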
static inline int from_device_index(const torch::Device& dev) {
  if (dev.is_cuda()) return (int)dev.index();
  return 0;
}
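// Moves a tensor onto the heap and hands out an opaque FlodlTensor handle.
// The handle owns the heap allocation: whichever C-API entry point releases
// it must delete it as a torch::Tensor* to avoid leaking.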
static inline FlodlTensor wrap(torch::Tensor t) {
  auto* p = new torch::Tensor(std::move(t));
  return (FlodlTensor)p;
}
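// Recovers a reference to the torch::Tensor behind a handle produced by
// wrap(), without taking ownership.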
static inline torch::Tensor& unwrap(FlodlTensor t) {
  return *((torch::Tensor*)t);
}
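// Converts a C array of handles into a vector of tensors. The copies are
// shallow: each element shares storage with the tensor behind its handle.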
static inline std::vector<torch::Tensor> unwrap_list(FlodlTensor* tensors, int count) {
  std::vector<torch::Tensor> result;
  result.reserve(count);
  for (int i = 0; i < count; i++) {
    result.push_back(unwrap(tensors[i]));
  }
  return result;
}
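// Wraps a raw shape buffer in a non-owning IntArrayRef; the caller must keep
// `shape` alive for as long as the returned view is used.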
static inline torch::IntArrayRef make_shape(const int64_t* shape, int ndim) {
  return torch::IntArrayRef(shape, ndim);
}
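// Chooses a floating-point dtype for building a mask from t: keeps t's dtype
// if it is already floating-point, otherwise falls back to float32.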
static inline torch::ScalarType mask_dtype(const torch::Tensor& t) {
  return t.is_floating_point() ? t.scalar_type() : torch::kFloat32;
}
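
// Minimal usage sketch for the handle lifecycle (assumes FlodlTensor and the
// FLODL_* codes come from shim.h; the explicit delete below stands in for
// whatever destroy entry point the C API actually exposes):
//
//   FlodlTensor h = wrap(torch::ones({2, 3}, to_scalar_type(FLODL_FLOAT32)));
//   torch::Tensor& t = unwrap(h);       // borrow; ownership stays with h
//   int dtype = from_scalar_type(t.scalar_type());
//   delete (torch::Tensor*)h;           // release the heap allocation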