#include "logger_bridge.hpp"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

#include <NvOnnxParser.h>
class RustLoggerImpl : public nvinfer1::ILogger {
public:
RustLoggerImpl(RustLogCallback callback, void* user_data)
: callback_(callback), user_data_(user_data) {}
void log(Severity severity, const char* msg) noexcept override {
if (callback_) {
callback_(user_data_, static_cast<int32_t>(severity), msg);
}
}
private:
RustLogCallback callback_;
void* user_data_;
};
/// Opaque handle handed across the C ABI. The bridge owns `impl`:
/// destroy_rust_logger_bridge deletes the impl and then the bridge itself.
struct RustLoggerBridge {
    // Default to null so the bridge never holds an indeterminate owning pointer,
    // however it is constructed.
    RustLoggerImpl* impl = nullptr;
};
extern "C" {
/// Allocates a logger bridge that routes TensorRT log messages to `callback`.
///
/// @param callback  required; a null callback yields nullptr
/// @param user_data opaque pointer handed back on every callback invocation
/// @return a new bridge owned by the caller (release it with destroy_rust_logger_bridge),
///         or nullptr if `callback` is null or allocation fails
RustLoggerBridge* create_rust_logger_bridge(RustLogCallback callback, void* user_data) {
    if (!callback) {
        return nullptr;
    }
    try {
        // Both allocations stay under unique_ptr until everything is constructed, so a
        // failure while creating the impl does not leak the already-allocated bridge.
        std::unique_ptr<RustLoggerBridge> bridge(new RustLoggerBridge());
        std::unique_ptr<RustLoggerImpl> impl(new RustLoggerImpl(callback, user_data));
        bridge->impl = impl.release();
        return bridge.release();
    } catch (...) {
        // Exceptions must not propagate across the C ABI boundary.
        return nullptr;
    }
}
/// Frees a bridge and the ILogger implementation it owns. Passing nullptr is a no-op.
void destroy_rust_logger_bridge(RustLoggerBridge* logger) {
    if (logger == nullptr) {
        return;
    }
    RustLoggerImpl* const owned_impl = logger->impl;
    delete owned_impl;
    delete logger;
}
/// Returns the bridge's logger as an nvinfer1::ILogger*, or nullptr for a null bridge.
/// The pointer is borrowed: destroy_rust_logger_bridge is what frees it.
nvinfer1::ILogger* get_logger_interface(RustLoggerBridge* logger) {
    if (logger == nullptr) {
        return nullptr;
    }
    return logger->impl;
}
#ifdef TRTX_LINK_TENSORRT_RTX
void* create_infer_builder(void* logger) {
if (!logger) {
return nullptr;
}
try {
auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
return nvinfer1::createInferBuilder(*ilogger);
} catch (...) {
return nullptr;
}
}
void* create_infer_runtime(void* logger) {
if (!logger) {
return nullptr;
}
try {
auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
return nvinfer1::createInferRuntime(*ilogger);
} catch (...) {
return nullptr;
}
}
#endif
#ifdef TRTX_LINK_TENSORRT_ONNXPARSER
void* create_onnx_parser(void* network, void* logger) {
if (!network || !logger) {
return nullptr;
}
try {
auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
return nvonnxparser::createParser(*inetwork, *ilogger);
} catch (...) {
return nullptr;
}
}
#endif
void builder_config_set_memory_pool_limit(void* config, int32_t pool_type, size_t limit) {
if (!config) return;
try {
auto* iconfig = static_cast<nvinfer1::IBuilderConfig*>(config);
iconfig->setMemoryPoolLimit(static_cast<nvinfer1::MemoryPoolType>(pool_type), limit);
} catch (...) {
}
}
void* network_add_concatenation(void* network, void** inputs, int32_t nb_inputs) {
if (!network || !inputs || nb_inputs <= 0) return nullptr;
try {
auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
std::vector<nvinfer1::ITensor*> tensors;
tensors.reserve(nb_inputs);
for (int32_t i = 0; i < nb_inputs; ++i) {
tensors.push_back(static_cast<nvinfer1::ITensor*>(inputs[i]));
}
auto* layer = inetwork->addConcatenation(tensors.data(), nb_inputs);
return layer; } catch (...) {
return nullptr;
}
}
/// Adds an assertion layer on `condition`.
///
/// @param network   nvinfer1::INetworkDefinition* passed as void*
/// @param condition nvinfer1::ITensor* passed as void*
/// @param message   assertion message; null is replaced by an empty string
/// @return the layer returned by addAssertion (as void*), or nullptr on a null
///         network/condition or failure
void* network_add_assertion(void* network, void* condition, const char* message) {
    if (network == nullptr || condition == nullptr) {
        return nullptr;
    }
    try {
        auto& net = *static_cast<nvinfer1::INetworkDefinition*>(network);
        auto& cond = *static_cast<nvinfer1::ITensor*>(condition);
        const char* const text = (message != nullptr) ? message : "";
        return net.addAssertion(cond, text);
    } catch (...) {
        return nullptr;
    }
}
void* network_add_loop(void* network) {
if (!network) return nullptr;
try {
auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
return inetwork->addLoop();
} catch (...) {
return nullptr;
}
}
void* network_add_if_conditional(void* network) {
if (!network) return nullptr;
try {
auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
return inetwork->addIfConditional();
} catch (...) {
return nullptr;
}
}
/// Copies a tensor's dimensions into a caller-provided buffer.
///
/// @param tensor  nvinfer1::ITensor* passed as void*
/// @param dims    output buffer; must hold at least nvinfer1::Dims::MAX_DIMS entries
/// @param nb_dims receives the rank from getDimensions(), capped at MAX_DIMS so it never
///                exceeds the number of entries written. A negative rank is passed through.
/// @return `tensor` on success, nullptr on a null argument or failure
void* tensor_get_dimensions(void* tensor, int32_t* dims, int32_t* nb_dims) {
    if (!tensor || !dims || !nb_dims) {
        return nullptr;
    }
    try {
        auto* itensor = static_cast<nvinfer1::ITensor*>(tensor);
        const nvinfer1::Dims dimensions = itensor->getDimensions();
        // Copy MAX_DIMS into a local so it is used by value, not bound by reference.
        const int32_t max_dims = nvinfer1::Dims::MAX_DIMS;
        const int32_t count = dimensions.nbDims < max_dims ? dimensions.nbDims : max_dims;
        for (int32_t i = 0; i < count; ++i) {
            // Dims entries may be wider than int32_t (e.g. int64_t with Dims64); narrow explicitly.
            dims[i] = static_cast<int32_t>(dimensions.d[i]);
        }
        *nb_dims = count;
        return tensor;
    } catch (...) {
        return nullptr;
    }
}
int32_t tensor_get_type(void* tensor) {
if (!tensor) return -1;
try {
auto* itensor = static_cast<nvinfer1::ITensor*>(tensor);
return static_cast<int32_t>(itensor->getType());
} catch (...) {
return -1;
}
}
void* builder_build_serialized_network(void* builder, void* network, void* config, size_t* out_size) {
if (!builder || !network || !config || !out_size) return nullptr;
try {
auto* ibuilder = static_cast<nvinfer1::IBuilder*>(builder);
auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
auto* iconfig = static_cast<nvinfer1::IBuilderConfig*>(config);
auto* serialized = ibuilder->buildSerializedNetwork(*inetwork, *iconfig);
if (!serialized) return nullptr;
*out_size = serialized->size();
void* data = malloc(*out_size);
if (data) {
memcpy(data, serialized->data(), *out_size);
}
delete serialized;
return data;
} catch (...) {
return nullptr;
}
}
void* runtime_deserialize_cuda_engine(void* runtime, const void* data, size_t size) {
if (!runtime || !data) return nullptr;
try {
auto* iruntime = static_cast<nvinfer1::IRuntime*>(runtime);
return iruntime->deserializeCudaEngine(data, size);
} catch (...) {
return nullptr;
}
}
bool parser_parse(void* parser, const void* data, size_t size) {
if (!parser || !data) return false;
try {
auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
return iparser->parse(data, size);
} catch (...) {
return false;
}
}
int32_t parser_get_nb_errors(void* parser) {
if (!parser) return 0;
try {
auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
return iparser->getNbErrors();
} catch (...) {
return 0;
}
}
/// Returns the parser error at `index` as void*.
///
/// @param parser nvonnxparser::IParser* passed as void*
/// @param index  zero-based; must be in [0, parser_get_nb_errors(parser))
/// @return the error (not freed by anything in this bridge; presumably owned by the
///         parser — confirm its lifetime), or nullptr on a null parser, an
///         out-of-range index, or failure
void* parser_get_error(void* parser, int32_t index) {
    if (!parser) {
        return nullptr;
    }
    try {
        auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
        // The index comes from across the FFI boundary; range-check it here rather
        // than relying on the parser implementation to do so.
        if (index < 0 || index >= iparser->getNbErrors()) {
            return nullptr;
        }
        // const_cast only to fit the void* return type; callers must treat the error as read-only.
        return const_cast<nvonnxparser::IParserError*>(iparser->getError(index));
    } catch (...) {
        return nullptr;
    }
}
const char* parser_error_desc(void* error) {
if (!error) return nullptr;
try {
auto* ierror = static_cast<nvonnxparser::IParserError*>(error);
return ierror->desc();
} catch (...) {
return nullptr;
}
}
/// Destroys an nvinfer1::IBuilder passed as void*. Null is a no-op.
void delete_builder(void* builder) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::IBuilder*>(builder);
}
/// Destroys an nvinfer1::INetworkDefinition passed as void*. Null is a no-op.
void delete_network(void* network) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::INetworkDefinition*>(network);
}
/// Destroys an nvinfer1::IBuilderConfig passed as void*. Null is a no-op.
void delete_config(void* config) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::IBuilderConfig*>(config);
}
/// Destroys an nvinfer1::IRuntime passed as void*. Null is a no-op.
void delete_runtime(void* runtime) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::IRuntime*>(runtime);
}
/// Destroys an nvinfer1::ICudaEngine passed as void*. Null is a no-op.
void delete_engine(void* engine) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::ICudaEngine*>(engine);
}
/// Destroys an nvinfer1::IExecutionContext passed as void*. Null is a no-op.
void delete_context(void* context) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvinfer1::IExecutionContext*>(context);
}
/// Destroys an nvonnxparser::IParser passed as void*. Null is a no-op.
void delete_parser(void* parser) {
    // Deleting a null pointer does nothing, so no explicit check is needed.
    delete static_cast<nvonnxparser::IParser*>(parser);
}
/// Returns NV_TENSORRT_VERSION as defined by the TensorRT headers this file was
/// compiled against (a compile-time value, not queried from the loaded library).
uint32_t get_tensorrt_version() {
    return static_cast<uint32_t>(NV_TENSORRT_VERSION);
}
}