#ifndef LIBRARYFUNCS_H_
#define LIBRARYFUNCS_H_
#include <llvm/ADT/StringMap.h>
#include <llvm/Analysis/AliasAnalysis.h>
#include <llvm/Analysis/TargetLibraryInfo.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Instructions.h>
#include "Utils.h"
class GradientUtils;
// Registry of custom allocation handlers, keyed by function name. Any name
// present in this map is treated as an allocation function by
// isAllocationFunction. Each callback receives an IRBuilder, the original
// call, that call's arguments and the GradientUtils instance, and returns a
// Value -- presumably the shadow allocation to use for that call (TODO confirm
// at the definition site). Declared extern: the single definition lives in
// another translation unit.
extern llvm::StringMap<std::function<llvm::Value *(
llvm::IRBuilder<> &, llvm::CallInst *, llvm::ArrayRef<llvm::Value *>,
GradientUtils *)>>
shadowHandlers;
// Registry keyed by function name of callbacks that take an IRBuilder and a
// Value and return the CallInst they emit. Presumably these release a shadow
// allocation created through shadowHandlers -- NOTE(review): confirm against
// the definition. Declared extern; defined in another translation unit.
extern llvm::StringMap<
std::function<llvm::CallInst *(llvm::IRBuilder<> &, llvm::Value *)>>
shadowErasers;
/// Decide whether the function called `name` is treated as an allocator.
///
/// A name qualifies when it appears in a fixed list of allocator symbols,
/// when a custom handler is registered for it in `shadowHandlers`, or when
/// TargetLibraryInfo maps it to malloc, valloc, or one of the Itanium/MSVC
/// `operator new` / `operator new[]` variants.
///
/// @param name  Callee symbol name.
/// @param TLI   Target library info used to recognize standard allocators.
/// @return true if `name` is considered an allocation function.
static inline bool isAllocationFunction(const llvm::StringRef name,
                                        const llvm::TargetLibraryInfo &TLI) {
  using namespace llvm;

  // Allocators recognized purely by their symbol name.
  static const char *const NamedAllocators[] = {
      "enzyme_allocator",
      "calloc",
      "malloc",
      "_mlir_memref_to_llvm_alloc",
      "swift_allocObject",
      "__size_returning_new_experiment",
      "__rust_alloc",
      "__rust_alloc_zeroed",
      "julia.gc_alloc_obj",
      "jl_gc_alloc_typed",
      "ijl_gc_alloc_typed",
  };
  for (const char *Candidate : NamedAllocators) {
    if (name == Candidate)
      return true;
  }

  // A registered custom shadow handler also marks the function as an
  // allocator.
  if (shadowHandlers.count(name))
    return true;

  // Fall back to the standard library functions TLI knows about.
  LibFunc Recognized;
  if (!TLI.getLibFunc(name, Recognized))
    return false;

  switch (Recognized) {
  // C allocators.
  case LibFunc_malloc:
  case LibFunc_valloc:
  // operator new(unsigned int) family.
  case LibFunc_Znwj:
  case LibFunc_ZnwjRKSt9nothrow_t:
  case LibFunc_ZnwjSt11align_val_t:
  case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t:
  // operator new(unsigned long) family.
  case LibFunc_Znwm:
  case LibFunc_ZnwmRKSt9nothrow_t:
  case LibFunc_ZnwmSt11align_val_t:
  case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
  // operator new[](unsigned int) family.
  case LibFunc_Znaj:
  case LibFunc_ZnajRKSt9nothrow_t:
  case LibFunc_ZnajSt11align_val_t:
  case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t:
  // operator new[](unsigned long) family.
  case LibFunc_Znam:
  case LibFunc_ZnamRKSt9nothrow_t:
  case LibFunc_ZnamSt11align_val_t:
  case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
  // MSVC operator new / new[] variants.
  case LibFunc_msvc_new_int:
  case LibFunc_msvc_new_int_nothrow:
  case LibFunc_msvc_new_longlong:
  case LibFunc_msvc_new_longlong_nothrow:
  case LibFunc_msvc_new_array_int:
  case LibFunc_msvc_new_array_int_nothrow:
  case LibFunc_msvc_new_array_longlong:
  case LibFunc_msvc_new_array_longlong_nothrow:
    return true;
  default:
    return false;
  }
}
/// Decide whether the function called `name` is treated as a deallocator.
///
/// A name qualifies when it is one of a fixed set of deallocator symbols, or
/// when TargetLibraryInfo maps it to `free` or to one of the Itanium/MSVC
/// `operator delete` / `operator delete[]` variants listed below.
///
/// @param name  Callee symbol name.
/// @param TLI   Target library info used to recognize standard deallocators.
/// @return true if `name` is considered a deallocation function.
static inline bool isDeallocationFunction(const llvm::StringRef name,
                                          const llvm::TargetLibraryInfo &TLI) {
  using namespace llvm;

  // Name-matched deallocators are checked BEFORE consulting TLI. Otherwise an
  // LLVM whose TLI recognizes one of these names as a LibFunc that is not
  // listed in the switch below would make this function return false for it.
  // This mirrors the ordering used by isAllocationFunction.
  if (name == "free" || name == "_mlir_memref_to_llvm_free" ||
      name == "__rust_dealloc" || name == "swift_release")
    return true;

  // Sized + aligned operator delete / delete[], for both the 64-bit (`m`) and
  // 32-bit (`j`) size_t manglings. These are matched by name, presumably
  // because the corresponding LibFunc enumerators are not available in every
  // supported LLVM version. The aligned operator new / new[] counterparts are
  // recognized by isAllocationFunction.
  if (name == "_ZdlPvmSt11align_val_t" || name == "_ZdaPvmSt11align_val_t" ||
      name == "_ZdlPvjSt11align_val_t" || name == "_ZdaPvjSt11align_val_t")
    return true;

  llvm::LibFunc libfunc;
  if (!TLI.getLibFunc(name, libfunc))
    return false;

  switch (libfunc) {
  case LibFunc_free:
  case LibFunc_ZdaPv:
  case LibFunc_ZdlPv:
  case LibFunc_msvc_delete_array_ptr32:
  case LibFunc_msvc_delete_array_ptr64:
  case LibFunc_msvc_delete_ptr32:
  case LibFunc_msvc_delete_ptr64:
  case LibFunc_ZdaPvRKSt9nothrow_t:
  case LibFunc_ZdaPvj:
  case LibFunc_ZdaPvm:
  case LibFunc_ZdlPvRKSt9nothrow_t:
  case LibFunc_ZdlPvj:
  case LibFunc_ZdlPvm:
  case LibFunc_ZdlPvSt11align_val_t:
  case LibFunc_ZdaPvSt11align_val_t:
  case LibFunc_msvc_delete_array_ptr32_int:
  case LibFunc_msvc_delete_array_ptr32_nothrow:
  case LibFunc_msvc_delete_array_ptr64_longlong:
  case LibFunc_msvc_delete_array_ptr64_nothrow:
  case LibFunc_msvc_delete_ptr32_int:
  case LibFunc_msvc_delete_ptr32_nothrow:
  case LibFunc_msvc_delete_ptr64_longlong:
  case LibFunc_msvc_delete_ptr64_nothrow:
  case LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t:
  case LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t:
    return true;
  default:
    return false;
  }
}
/// Emit a memset that zero-fills the memory returned by a known allocation.
///
/// @param bb        Builder at whose insertion point the memset is emitted.
/// @param toZero    Result of the allocation call. This is a pointer, an
///                  integer holding an address, or, for
///                  `__size_returning_new_experiment`, an aggregate whose
///                  element 0 is the pointer.
/// @param argValues Arguments of the allocation call. The size operand is
///                  taken from here; its index depends on `funcName`.
/// @param funcName  Allocation function name; must satisfy
///                  isAllocationFunction (asserted).
/// @param TLI       Target library info; only used by the assertion.
/// @param orig      Original call; consulted for `enzyme_allocator` to find
///                  which argument holds the size.
static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb,
                                       llvm::Value *toZero,
                                       llvm::ArrayRef<llvm::Value *> argValues,
                                       const llvm::StringRef funcName,
                                       const llvm::TargetLibraryInfo &TLI,
                                       llvm::CallInst *orig) {
  using namespace llvm;
  assert(isAllocationFunction(funcName, TLI));

  // These allocators already hand back zero-initialized memory.
  if (funcName == "calloc" || funcName == "__rust_alloc_zeroed")
    return;

  // Select the operand that carries the allocation size.
  // NOTE(review): for swift_allocObject, argument 0 is presumably the type
  // metadata pointer rather than a byte count (if it follows the usual
  // (metadata, size, alignMask) signature). Confirm that callers never zero
  // Swift objects through this path.
  Value *allocSize = argValues[0];
  if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" ||
      funcName == "ijl_gc_alloc_typed") {
    allocSize = argValues[1];
  }
  if (funcName == "enzyme_allocator") {
    auto index = getAllocationIndexFromCall(orig);
    // Dereferencing an empty index would be undefined behavior.
    assert(index && "enzyme_allocator call has no allocation-size index");
    allocSize = argValues[*index];
  }

  // Build an i8* view of the memory to be zeroed.
  Value *dst_arg = toZero;
  if (funcName == "__size_returning_new_experiment")
    dst_arg = bb.CreateExtractValue(dst_arg, 0);
  if (dst_arg->getType()->isIntegerTy()) {
    dst_arg = bb.CreateIntToPtr(dst_arg, getInt8PtrTy(toZero->getContext()));
  } else {
    // Take the address space from the pointer actually being zeroed. `toZero`
    // may be an aggregate (see above), on which getPointerAddressSpace() is
    // invalid.
    dst_arg = bb.CreateBitCast(
        dst_arg, getInt8PtrTy(toZero->getContext(),
                              dst_arg->getType()->getPointerAddressSpace()));
  }

  auto val_arg = ConstantInt::get(Type::getInt8Ty(toZero->getContext()), 0);
  auto len_arg =
      bb.CreateZExtOrTrunc(allocSize, Type::getInt64Ty(toZero->getContext()));
  auto memset = bb.CreateMemSet(dst_arg, val_arg, len_arg, MaybeAlign());
  memset->addParamAttr(0, Attribute::NonNull);

  // For a compile-time-constant size, also annotate the destination as
  // dereferenceable for exactly that many bytes.
  if (auto CI = dyn_cast<ConstantInt>(allocSize)) {
    auto derefBytes = CI->getLimitedValue();
#if LLVM_VERSION_MAJOR >= 14
    memset->addDereferenceableParamAttr(0, derefBytes);
    memset->setAttributes(
        memset->getAttributes().addDereferenceableOrNullParamAttr(
            memset->getContext(), 0, derefBytes));
#else
    memset->addDereferenceableAttr(llvm::AttributeList::FirstArgIndex,
                                   derefBytes);
    memset->addDereferenceableOrNullAttr(llvm::AttributeList::FirstArgIndex,
                                         derefBytes);
#endif
  }
}
// Declaration only; the definition lives in another translation unit.
// Presumably emits, via `builder`, a call that releases `tofree`, which is
// assumed to have been produced by the allocation function named
// `allocationfn` -- TODO confirm at the definition site.
//   builder       IR builder used to emit the call.
//   tofree        Value to release.
//   allocationfn  Name of the function that allocated `tofree`.
//   debuglocation Debug location; presumably attached to the emitted call.
//   TLI           Target library info for the module.
//   orig          Original allocation call (role not visible here).
//   gutils        GradientUtils context (role not visible here).
// Returns the emitted CallInst; whether it can be null is not determinable
// from this declaration.
llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
llvm::Value *tofree,
llvm::StringRef allocationfn,
const llvm::DebugLoc &debuglocation,
const llvm::TargetLibraryInfo &TLI,
llvm::CallInst *orig,
GradientUtils *gutils);
/// Returns true if `TmpOrig` is a call-like instruction (llvm::CallBase) that
/// is treated as an allocation. That is the case when the call site or the
/// resolved callee carries the "enzyme_allocation" function attribute, or
/// when the callee name is accepted by isAllocationFunction. Non-call values
/// yield false.
static inline bool isAllocationCall(const llvm::Value *TmpOrig,
                                    llvm::TargetLibraryInfo &TLI) {
  const auto *Call = llvm::dyn_cast<llvm::CallBase>(TmpOrig);
  if (!Call)
    return false;

  // Marker attached directly to this call site.
  const auto CallSiteFnAttrs =
      Call->getAttributes().getAttributes(llvm::AttributeList::FunctionIndex);
  if (CallSiteFnAttrs.hasAttribute("enzyme_allocation"))
    return true;

  // Marker attached to the callee itself.
  const auto Callee = getFunctionFromCall(Call);
  if (Callee && Callee->hasFnAttribute("enzyme_allocation"))
    return true;

  return isAllocationFunction(getFuncNameFromCall(Call), TLI);
}
/// Returns true if `TmpOrig` is a call-like instruction (llvm::CallBase)
/// whose callee name is accepted by isDeallocationFunction; false for any
/// other value.
static inline bool isDeallocationCall(const llvm::Value *TmpOrig,
                                      llvm::TargetLibraryInfo &TLI) {
  const auto *Call = llvm::dyn_cast<llvm::CallBase>(TmpOrig);
  return Call != nullptr &&
         isDeallocationFunction(getFuncNameFromCall(Call), TLI);
}
#endif