#include <stdio.h>
#include "pg_query.h"
#include "pg_query_internal.h"
#include "pg_query_json_plpgsql.h"
#include "pg_query_proctup_attrs.h"
#include <assert.h>
#include <catalog/pg_type.h>
#include <catalog/objectaddress.h>
#include <catalog/pg_language_d.h>
#include <catalog/pg_proc.h>
#include <commands/defrem.h>
#include <nodes/parsenodes.h>
#include <nodes/nodeFuncs.h>
#include <parser/parse_type.h>
#include <utils/builtins.h>
#include <utils/lsyscache.h>
#include <utils/syscache.h>
/*
 * Outcome of compiling one CREATE FUNCTION / DO statement.
 *
 * func  - the compiled function (NULL if compilation raised an error).
 * error - malloc'd error description, or NULL on success.  Both fields can be
 *         set at once: pg_query_raw_parse_plpgsql reports anything written to
 *         stderr during an otherwise successful compile as an error.
 */
typedef struct {
PLpgSQL_function *func;
PgQueryError* error;
} PgQueryInternalPlpgsqlFuncAndError;
static PgQueryInternalPlpgsqlFuncAndError pg_query_raw_parse_plpgsql(Node* stmt);
/*
 * Resolve the declared return type of a CREATE FUNCTION statement.
 *
 * Looks up returnType and stores its type OID in *prorettype_p and its SETOF
 * flag in *returnsSet_p.
 *
 * If the type exists only as a shell, an ERROR is raised for SQL-language
 * functions (languageOid == SQLlanguageId) and a NOTICE otherwise.
 *
 * If the type does not exist, an ERROR is raised.  This covers three cases:
 * - the language is neither internal nor C;
 * - type modifiers were given;
 * - otherwise ("Not implemented").  Shell types are never auto-created here.
 */
static void
compute_return_type(TypeName *returnType, Oid languageOid,
					Oid *prorettype_p, bool *returnsSet_p)
{
	Oid			rettype = InvalidOid;
	Type		typtup;

	typtup = LookupTypeName(NULL, returnType, NULL, false);

	if (typtup)
	{
		if (!((Form_pg_type) GETSTRUCT(typtup))->typisdefined)
		{
			if (languageOid == SQLlanguageId)
				ereport(ERROR,
						(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
						 errmsg("SQL function cannot return shell type %s",
								TypeNameToString(returnType))));
			else
				ereport(NOTICE,
						(errcode(ERRCODE_WRONG_OBJECT_TYPE),
						 errmsg("return type %s is only a shell",
								TypeNameToString(returnType))));
		}
		rettype = typeTypeId(typtup);
		ReleaseSysCache(typtup);
	}
	else
	{
		char	   *typnam = TypeNameToString(returnType);

		if (languageOid != INTERNALlanguageId &&
			languageOid != ClanguageId)
			ereport(ERROR,
					(errcode(ERRCODE_UNDEFINED_OBJECT),
					 errmsg("type \"%s\" does not exist", typnam)));

		if (returnType->typmods != NIL)
			ereport(ERROR,
					(errcode(ERRCODE_SYNTAX_ERROR),
					 errmsg("type modifier cannot be specified for shell type \"%s\"",
							typnam)));

		/* Creating a shell type for the missing return type is not supported. */
		elog(ERROR, "Not implemented");
	}

	*prorettype_p = rettype;
	*returnsSet_p = returnType->setof;
}
/*
 * Compile the body of a DO statement.
 *
 * The "as" option carries the source text; the "language" option names the
 * language and defaults to "plpgsql".  Bodies in any other language are not
 * compiled: a zero-filled PLpgSQL_function is returned in their place.
 */
static PLpgSQL_function *compile_do_stmt(DoStmt* stmt)
{
	const ListCell *cell;
	char	   *source = NULL;
	char	   *lang = "plpgsql";

	assert(IsA(stmt, DoStmt));

	foreach(cell, stmt->args)
	{
		DefElem    *opt = (DefElem *) lfirst(cell);

		if (strcmp(opt->defname, "language") == 0)
			lang = strVal(opt->arg);
		else if (strcmp(opt->defname, "as") == 0)
		{
			assert(IsA(opt->arg, String));
			source = strVal(opt->arg);
		}
	}

	assert(source != NULL);

	if (strcmp(lang, "plpgsql") == 0)
		return plpgsql_compile_inline(source);

	return (PLpgSQL_function *) palloc0(sizeof(PLpgSQL_function));
}
/*
 * Allocate a ProcTupWithAttrs whose heap tuple holds a copy of *procStruct.
 *
 * Everything is one palloc0'd chunk: the wrapper comes first, immediately
 * followed by the tuple header (t_bits offset, MAXALIGNed) and then the
 * MAXALIGNed FormData_pg_proc payload.  The tuple has invalid TIDs, no table
 * OID and claims Natts_pg_proc attributes, but only the fixed-size part of
 * pg_proc is actually stored in it.
 */
static ProcTupWithAttrs *
build_fake_proc_tuple(Form_pg_proc procStruct)
{
	const Size	headerSize = MAXALIGN(offsetof(HeapTupleHeaderData, t_bits));
	const Size	tupleSize = headerSize + MAXALIGN(sizeof(FormData_pg_proc));
	ProcTupWithAttrs *result;
	HeapTupleHeader header;

	result = (ProcTupWithAttrs *) palloc0(sizeof(ProcTupWithAttrs) + tupleSize);
	header = (HeapTupleHeader) ((char *) result + sizeof(ProcTupWithAttrs));

	/* In-memory tuple descriptor. */
	result->tup.t_data = header;
	result->tup.t_len = tupleSize;
	result->tup.t_tableOid = InvalidOid;
	ItemPointerSetInvalid(&(result->tup.t_self));

	/* On-"disk" header fields. */
	HeapTupleHeaderSetDatumLength(header, tupleSize);
	HeapTupleHeaderSetNatts(header, Natts_pg_proc);
	ItemPointerSetInvalid(&(header->t_ctid));
	header->t_hoff = headerSize;

	memcpy((char *) header + headerSize, procStruct, sizeof(FormData_pg_proc));

	return result;
}
/*
 * Build a fake pg_proc tuple describing a function or procedure.
 *
 * The fixed-size pg_proc fields (name, kind, volatility, argument count,
 * return type, SETOF flag) are copied into the tuple by
 * build_fake_proc_tuple().  The wrapper's values/nulls arrays are filled in
 * as well:
 * - proargtypes is always set;
 * - proallargtypes, proargmodes and proargnames are marked null when the
 *   caller passes PointerGetDatum(NULL);
 * - prosrc is converted to a text Datum.
 *
 * prosrc must not be NULL.
 */
static ProcTupWithAttrs *
pg_query_procedure_create(const char *procedureName,
						  bool returnsSet,
						  Oid returnType,
						  char prokind,
						  char volatility,
						  oidvector *parameterTypes,
						  Datum allParameterTypes,
						  Datum parameterModes,
						  Datum parameterNames,
						  const char *prosrc)
{
	ProcTupWithAttrs *wrapper;
	FormData_pg_proc procStruct = {0};
	Form_pg_proc tupleStruct;
	Datum	   *values;
	bool	   *nulls;
	int			parameterCount;

	Assert(PointerIsValid(prosrc));

	parameterCount = parameterTypes->dim1;

	namestrcpy(&procStruct.proname, procedureName);
	procStruct.prokind = prokind;
	procStruct.provolatile = volatility;
	procStruct.pronargs = parameterCount;
	procStruct.prorettype = returnType;
	procStruct.proretset = returnsSet;

	wrapper = build_fake_proc_tuple(&procStruct);

	/*
	 * procStruct lives on this function's stack.  Any pointer Datum must
	 * reference the copy stored inside the returned tuple instead; otherwise
	 * it would dangle as soon as we return.
	 */
	tupleStruct = (Form_pg_proc) GETSTRUCT(&wrapper->tup);

	values = wrapper->values;
	nulls = wrapper->nulls;

	values[Anum_pg_proc_proname - 1] = NameGetDatum(&tupleStruct->proname);
	values[Anum_pg_proc_prokind - 1] = CharGetDatum(prokind);
	values[Anum_pg_proc_provolatile - 1] = CharGetDatum(volatility);
	values[Anum_pg_proc_pronargs - 1] = UInt16GetDatum(parameterCount);
	values[Anum_pg_proc_prorettype - 1] = ObjectIdGetDatum(returnType);
	values[Anum_pg_proc_proretset - 1] = BoolGetDatum(returnsSet);
	values[Anum_pg_proc_proargtypes - 1] = PointerGetDatum(parameterTypes);

	if (allParameterTypes != PointerGetDatum(NULL))
		values[Anum_pg_proc_proallargtypes - 1] = allParameterTypes;
	else
		nulls[Anum_pg_proc_proallargtypes - 1] = true;

	if (parameterModes != PointerGetDatum(NULL))
		values[Anum_pg_proc_proargmodes - 1] = parameterModes;
	else
		nulls[Anum_pg_proc_proargmodes - 1] = true;

	if (parameterNames != PointerGetDatum(NULL))
		values[Anum_pg_proc_proargnames - 1] = parameterNames;
	else
		nulls[Anum_pg_proc_proargnames - 1] = true;

	values[Anum_pg_proc_prosrc - 1] = CStringGetTextDatum(prosrc);

	return wrapper;
}
/*
 * Build the fake pg_proc tuple for a CREATE FUNCTION / CREATE PROCEDURE
 * statement.
 *
 * *is_dml_trigger / *is_event_trigger are set when any component of the
 * declared return type name equals "trigger" / "event_trigger".
 *
 * Parameters are processed with interpret_function_parameter_list().  The
 * return type comes from:
 * - OUT parameters (or VOIDOID) for procedures;
 * - compute_return_type() when RETURNS is given;
 * - the OUT parameters otherwise.
 *
 * If no explicit all-argument type array was produced, one is built from the
 * input argument types.  The result is always marked VOLATILE.  `language`
 * is not used by this function.
 */
static ProcTupWithAttrs *
pg_query_create_function(CreateFunctionStmt *stmt,
						 const char *language,
						 const char *proc_source,
						 bool *is_dml_trigger,
						 bool *is_event_trigger)
{
	char	   *funcname;
	Oid			prorettype;
	bool		returnsSet;
	char		prokind;
	oidvector  *parameterTypes;
	List	   *parameterTypes_list = NIL;
	ArrayType  *allParameterTypes;
	ArrayType  *parameterModes;
	ArrayType  *parameterNames;
	List	   *inParameterNames_list = NIL;
	List	   *parameterDefaults;
	Oid			variadicArgType;
	Oid			requiredResultType;

	/*
	 * stmt->funcname may be schema-qualified ("schema.func"); the object
	 * name is always the last element of the list.
	 */
	funcname = strVal(llast(stmt->funcname));

	*is_dml_trigger = false;
	*is_event_trigger = false;

	if (stmt->returnType != NULL)
	{
		foreach_ptr(String, val, stmt->returnType->names)
		{
			if (strcmp(val->sval, "trigger") == 0)
				*is_dml_trigger = true;
			else if (strcmp(val->sval, "event_trigger") == 0)
				*is_event_trigger = true;
		}
	}

	interpret_function_parameter_list(palloc0(sizeof(ParseState)),
									  stmt->parameters,
									  InvalidOid,
									  stmt->is_procedure ? OBJECT_PROCEDURE : OBJECT_FUNCTION,
									  &parameterTypes,
									  &parameterTypes_list,
									  &allParameterTypes,
									  &parameterModes,
									  &parameterNames,
									  &inParameterNames_list,
									  &parameterDefaults,
									  &variadicArgType,
									  &requiredResultType);

	if (stmt->is_procedure)
	{
		Assert(!stmt->returnType);
		prorettype = requiredResultType ? requiredResultType : VOIDOID;
		returnsSet = false;
	}
	else if (stmt->returnType)
	{
		compute_return_type(stmt->returnType, InvalidOid,
							&prorettype, &returnsSet);
		if (OidIsValid(requiredResultType) && prorettype != requiredResultType)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
					 errmsg("function result type must be %s because of OUT parameters",
							format_type_be(requiredResultType))));
	}
	else if (OidIsValid(requiredResultType))
	{
		prorettype = requiredResultType;
		returnsSet = false;
	}
	else
	{
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
				 errmsg("function result type must be specified")));
	}

	prokind = stmt->is_procedure ? PROKIND_PROCEDURE : PROKIND_FUNCTION;

	/* Synthesize proallargtypes from the input argument types when absent. */
	if (allParameterTypes == NULL && parameterTypes->dim1 > 0)
	{
		Datum	   *typeDatums = palloc(parameterTypes->dim1 * sizeof(Datum));
		int			i;

		for (i = 0; i < parameterTypes->dim1; i++)
			typeDatums[i] = ObjectIdGetDatum(parameterTypes->values[i]);

		allParameterTypes = construct_array_builtin(typeDatums,
													parameterTypes->dim1,
													OIDOID);
	}

	return pg_query_procedure_create(funcname,
									 returnsSet,
									 prorettype,
									 prokind,
									 PROVOLATILE_VOLATILE,
									 parameterTypes,
									 allParameterTypes ? PointerGetDatum(allParameterTypes) : PointerGetDatum(NULL),
									 parameterModes ? PointerGetDatum(parameterModes) : PointerGetDatum(NULL),
									 parameterNames ? PointerGetDatum(parameterNames) : PointerGetDatum(NULL),
									 proc_source);
}
/*
 * Compile a CREATE FUNCTION / CREATE PROCEDURE statement through
 * plpgsql_compile_callback() using a fake pg_proc tuple.
 *
 * Options:
 * - "language" names the language (default "plpgsql");
 * - "as" is a list of String nodes, of which the last one is used as the
 *   source text.
 *
 * Non-plpgsql functions are not compiled; a zero-filled PLpgSQL_function is
 * returned instead.  For trigger / event trigger functions, fcinfo->context
 * points at a zeroed TriggerData / EventTriggerData node, presumably so the
 * compiler treats them as trigger functions.
 */
static PLpgSQL_function *
compile_create_function_stmt_via_callback(CreateFunctionStmt *stmt)
{
	char	   *body = NULL;
	char	   *lang = "plpgsql";
	bool		isDmlTrigger;
	bool		isEventTrigger;
	ProcTupWithAttrs *procTup;
	FmgrInfo	flinfo = {0};
	LOCAL_FCINFO(fcinfo, 0);
	TriggerData triggerData = {0};
	EventTriggerData eventTriggerData = {0};
	PLpgSQL_function *compiled;

	assert(IsA(stmt, CreateFunctionStmt));

	foreach_ptr(DefElem, opt, stmt->options)
	{
		if (strcmp(opt->defname, "language") == 0)
			lang = strVal(opt->arg);
		else if (strcmp(opt->defname, "as") == 0)
		{
			assert(IsA(opt->arg, List));
			foreach_ptr(String, chunk, (List *) opt->arg)
				body = chunk->sval;
		}
	}

	assert(body != NULL);

	if (strcmp(lang, "plpgsql") != 0)
		return (PLpgSQL_function *) palloc0(sizeof(PLpgSQL_function));

	procTup = pg_query_create_function(stmt, lang, body,
									   &isDmlTrigger, &isEventTrigger);

	/* Minimal call context: no function OID, no expression, no arguments. */
	flinfo.fn_oid = InvalidOid;
	flinfo.fn_expr = NULL;

	fcinfo->flinfo = &flinfo;
	fcinfo->fncollation = InvalidOid;
	fcinfo->resultinfo = NULL;
	fcinfo->isnull = false;
	fcinfo->nargs = 0;

	if (isDmlTrigger)
	{
		triggerData.type = T_TriggerData;
		fcinfo->context = (Node *) &triggerData;
	}
	else if (isEventTrigger)
	{
		eventTriggerData.type = T_EventTriggerData;
		fcinfo->context = (Node *) &eventTriggerData;
	}
	else
		fcinfo->context = NULL;

	compiled = (PLpgSQL_function *) palloc0(sizeof(PLpgSQL_function));
	plpgsql_compile_callback(fcinfo, &procTup->tup, NULL,
							 (CachedFunction *) compiled, true);

	return compiled;
}
PgQueryInternalPlpgsqlFuncAndError pg_query_raw_parse_plpgsql(Node* stmt)
{
PgQueryInternalPlpgsqlFuncAndError result = {0};
MemoryContext cctx = CurrentMemoryContext;
char stderr_buffer[STDERR_BUFFER_LEN + 1] = {0};
#ifndef DEBUG
int stderr_global;
int stderr_pipe[2];
#endif
#ifndef DEBUG
if (pipe(stderr_pipe) != 0) {
PgQueryError* error = malloc(sizeof(PgQueryError));
error->message = strdup("Failed to open pipe, too many open file descriptors")
result.error = error;
return result;
}
fcntl(stderr_pipe[0], F_SETFL, fcntl(stderr_pipe[0], F_GETFL) | O_NONBLOCK);
stderr_global = dup(STDERR_FILENO);
dup2(stderr_pipe[1], STDERR_FILENO);
close(stderr_pipe[1]);
#endif
PG_TRY();
{
if (IsA(stmt, CreateFunctionStmt)) {
result.func = compile_create_function_stmt_via_callback((CreateFunctionStmt *) stmt);
} else if (IsA(stmt, DoStmt)){
result.func = compile_do_stmt((DoStmt *) stmt);
} else {
elog(ERROR, "Unexpected node type for PL/pgSQL parsing: %d", nodeTag(stmt));
}
#ifndef DEBUG
read(stderr_pipe[0], stderr_buffer, STDERR_BUFFER_LEN);
#endif
if (strlen(stderr_buffer) > 0) {
PgQueryError* error = malloc(sizeof(PgQueryError));
error->message = strdup(stderr_buffer);
error->filename = "";
error->funcname = "";
error->context = "";
result.error = error;
}
}
PG_CATCH();
{
ErrorData* error_data;
PgQueryError* error;
MemoryContextSwitchTo(cctx);
error_data = CopyErrorData();
error = malloc(sizeof(PgQueryError));
error->message = strdup(error_data->message);
error->filename = strdup(error_data->filename);
error->funcname = strdup(error_data->funcname);
error->context = error_data->context ? strdup(error_data->context) : NULL;
error->lineno = error_data->lineno;
error->cursorpos = error_data->cursorpos;
result.error = error;
FlushErrorState();
}
PG_END_TRY();
#ifndef DEBUG
dup2(stderr_global, STDERR_FILENO);
close(stderr_pipe[0]);
close(stderr_global);
#endif
return result;
}
/*
 * Growable array of the CreateFunctionStmt / DoStmt nodes collected by
 * stmts_walker().
 */
typedef struct plStmts
{
Node **stmts;          /* palloc'd array of collected statement nodes */
int stmts_buf_size;    /* allocated capacity of stmts (doubled when full) */
int stmts_count;       /* number of entries currently in use */
} plStmts;
/*
 * raw_expression_tree_walker callback that appends every CreateFunctionStmt
 * and DoStmt node to state->stmts, doubling the array when it is full.
 *
 * RawStmt wrappers are unwrapped; collected statements are walked further
 * as well.  An error raised during the walk is discarded (FlushErrorState)
 * and that subtree's walk reports false.
 */
static bool stmts_walker(Node *node, plStmts *state)
{
	MemoryContext callerContext = CurrentMemoryContext;
	bool		walkResult;

	if (node == NULL)
		return false;

	if (IsA(node, RawStmt))
		return stmts_walker((Node *) ((RawStmt *) node)->stmt, state);

	if (IsA(node, CreateFunctionStmt) || IsA(node, DoStmt))
	{
		bool		full = state->stmts_count >= state->stmts_buf_size;

		if (full)
		{
			state->stmts_buf_size *= 2;
			state->stmts = (Node **) repalloc(state->stmts,
											  sizeof(Node *) * state->stmts_buf_size);
		}
		state->stmts[state->stmts_count++] = node;
	}

	PG_TRY();
	{
		walkResult = raw_expression_tree_walker(node, stmts_walker, (void *) state);
	}
	PG_CATCH();
	{
		MemoryContextSwitchTo(callerContext);
		FlushErrorState();
		walkResult = false;
	}
	PG_END_TRY();

	return walkResult;
}
PgQueryPlpgsqlParseResult pg_query_parse_plpgsql(const char* input)
{
MemoryContext ctx = NULL;
PgQueryPlpgsqlParseResult result = {0};
PgQueryInternalParsetreeAndError parse_result;
plStmts statements;
size_t i;
ctx = pg_query_enter_memory_context();
parse_result = pg_query_raw_parse(input, PG_QUERY_PARSE_DEFAULT);
result.error = parse_result.error;
if (result.error != NULL) {
pg_query_exit_memory_context(ctx);
return result;
}
statements.stmts_buf_size = 100;
statements.stmts = (Node**) palloc(statements.stmts_buf_size * sizeof(Node*));
statements.stmts_count = 0;
stmts_walker((Node*) parse_result.tree, &statements);
if (statements.stmts_count == 0) {
result.plpgsql_funcs = strdup("[]");
pg_query_exit_memory_context(ctx);
return result;
}
result.plpgsql_funcs = strdup("[\n");
for (i = 0; i < statements.stmts_count; i++) {
PgQueryInternalPlpgsqlFuncAndError func_and_error;
func_and_error = pg_query_raw_parse_plpgsql(statements.stmts[i]);
result.error = func_and_error.error;
if (result.error != NULL) {
pg_query_exit_memory_context(ctx);
return result;
}
if (func_and_error.func != NULL) {
char *func_json;
char *new_out;
size_t new_out_len;
func_json = plpgsqlToJSON(func_and_error.func);
plpgsql_free_function_memory(func_and_error.func);
new_out_len = strlen(result.plpgsql_funcs) + strlen(func_json) + 3;
new_out = malloc(new_out_len);
int n = snprintf(new_out, new_out_len, "%s%s,\n", result.plpgsql_funcs, func_json);
if (n < 0 || n >= new_out_len) {
PgQueryError* error = malloc(sizeof(PgQueryError));
error->message = strdup("Failed to output PL/pgSQL functions due to snprintf failure");
result.error = error;
} else {
free(result.plpgsql_funcs);
result.plpgsql_funcs = new_out;
}
pfree(func_json);
}
}
result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 2] = '\n';
result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 1] = ']';
free(parse_result.stderr_buffer);
pg_query_exit_memory_context(ctx);
return result;
}
/*
 * Release everything owned by a PgQueryPlpgsqlParseResult: the malloc'd JSON
 * text (free(NULL) is a no-op) and, when present, the error.
 */
void pg_query_free_plpgsql_parse_result(PgQueryPlpgsqlParseResult result)
{
	free(result.plpgsql_funcs);

	if (result.error != NULL)
		pg_query_free_error(result.error);
}