#include "postgres.h"
#include "catalog/pg_collation.h"
#include "catalog/pg_type.h"
#include "miscadmin.h"
#include "nodes/execnodes.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "nodes/pathnodes.h"
#include "utils/builtins.h"
#include "utils/lsyscache.h"
/*
 * Forward declarations of file-local (static) helpers.  leftmostLoc is
 * defined later in this file; the remaining helpers are presumably defined
 * in parts of the file not shown here.  NOTE(review): confirm they exist,
 * otherwise these static prototypes are dead declarations.
 */
static bool expression_returns_set_walker(Node *node, void *context);
static int leftmostLoc(int loc1, int loc2);
static bool fix_opfuncids_walker(Node *node, void *context);
/*
 * NOTE(review): "bool (*walker) ()" is an old-style, unprototyped function
 * pointer type.  Under C23 an empty parameter list means "(void)", so these
 * declarations would no longer accept two-argument walker callbacks if the
 * file were ever compiled as C23.
 */
static bool planstate_walk_subplans(List *plans, bool (*walker) (),
void *context);
static bool planstate_walk_members(PlanState **planstates, int nplans,
bool (*walker) (), void *context);
/*
 * NOTE(review): this conditional block is empty and has no effect; it looks
 * like a leftover from removed assert-only declarations.
 */
#ifdef USE_ASSERT_CHECKING
#endif
int
exprLocation(const Node *expr)
{
int loc;
if (expr == NULL)
return -1;
switch (nodeTag(expr))
{
case T_RangeVar:
loc = ((const RangeVar *) expr)->location;
break;
case T_TableFunc:
loc = ((const TableFunc *) expr)->location;
break;
case T_Var:
loc = ((const Var *) expr)->location;
break;
case T_Const:
loc = ((const Const *) expr)->location;
break;
case T_Param:
loc = ((const Param *) expr)->location;
break;
case T_Aggref:
loc = ((const Aggref *) expr)->location;
break;
case T_GroupingFunc:
loc = ((const GroupingFunc *) expr)->location;
break;
case T_WindowFunc:
loc = ((const WindowFunc *) expr)->location;
break;
case T_SubscriptingRef:
loc = exprLocation((Node *) ((const SubscriptingRef *) expr)->refexpr);
break;
case T_FuncExpr:
{
const FuncExpr *fexpr = (const FuncExpr *) expr;
loc = leftmostLoc(fexpr->location,
exprLocation((Node *) fexpr->args));
}
break;
case T_NamedArgExpr:
{
const NamedArgExpr *na = (const NamedArgExpr *) expr;
loc = leftmostLoc(na->location,
exprLocation((Node *) na->arg));
}
break;
case T_OpExpr:
case T_DistinctExpr:
case T_NullIfExpr:
{
const OpExpr *opexpr = (const OpExpr *) expr;
loc = leftmostLoc(opexpr->location,
exprLocation((Node *) opexpr->args));
}
break;
case T_ScalarArrayOpExpr:
{
const ScalarArrayOpExpr *saopexpr = (const ScalarArrayOpExpr *) expr;
loc = leftmostLoc(saopexpr->location,
exprLocation((Node *) saopexpr->args));
}
break;
case T_BoolExpr:
{
const BoolExpr *bexpr = (const BoolExpr *) expr;
loc = leftmostLoc(bexpr->location,
exprLocation((Node *) bexpr->args));
}
break;
case T_SubLink:
{
const SubLink *sublink = (const SubLink *) expr;
loc = leftmostLoc(exprLocation(sublink->testexpr),
sublink->location);
}
break;
case T_FieldSelect:
loc = exprLocation((Node *) ((const FieldSelect *) expr)->arg);
break;
case T_FieldStore:
loc = exprLocation((Node *) ((const FieldStore *) expr)->arg);
break;
case T_RelabelType:
{
const RelabelType *rexpr = (const RelabelType *) expr;
loc = leftmostLoc(rexpr->location,
exprLocation((Node *) rexpr->arg));
}
break;
case T_CoerceViaIO:
{
const CoerceViaIO *cexpr = (const CoerceViaIO *) expr;
loc = leftmostLoc(cexpr->location,
exprLocation((Node *) cexpr->arg));
}
break;
case T_ArrayCoerceExpr:
{
const ArrayCoerceExpr *cexpr = (const ArrayCoerceExpr *) expr;
loc = leftmostLoc(cexpr->location,
exprLocation((Node *) cexpr->arg));
}
break;
case T_ConvertRowtypeExpr:
{
const ConvertRowtypeExpr *cexpr = (const ConvertRowtypeExpr *) expr;
loc = leftmostLoc(cexpr->location,
exprLocation((Node *) cexpr->arg));
}
break;
case T_CollateExpr:
loc = exprLocation((Node *) ((const CollateExpr *) expr)->arg);
break;
case T_CaseExpr:
loc = ((const CaseExpr *) expr)->location;
break;
case T_CaseWhen:
loc = ((const CaseWhen *) expr)->location;
break;
case T_ArrayExpr:
loc = ((const ArrayExpr *) expr)->location;
break;
case T_RowExpr:
loc = ((const RowExpr *) expr)->location;
break;
case T_RowCompareExpr:
loc = exprLocation((Node *) ((const RowCompareExpr *) expr)->largs);
break;
case T_CoalesceExpr:
loc = ((const CoalesceExpr *) expr)->location;
break;
case T_MinMaxExpr:
loc = ((const MinMaxExpr *) expr)->location;
break;
case T_SQLValueFunction:
loc = ((const SQLValueFunction *) expr)->location;
break;
case T_XmlExpr:
{
const XmlExpr *xexpr = (const XmlExpr *) expr;
loc = leftmostLoc(xexpr->location,
exprLocation((Node *) xexpr->args));
}
break;
case T_NullTest:
{
const NullTest *nexpr = (const NullTest *) expr;
loc = leftmostLoc(nexpr->location,
exprLocation((Node *) nexpr->arg));
}
break;
case T_BooleanTest:
{
const BooleanTest *bexpr = (const BooleanTest *) expr;
loc = leftmostLoc(bexpr->location,
exprLocation((Node *) bexpr->arg));
}
break;
case T_CoerceToDomain:
{
const CoerceToDomain *cexpr = (const CoerceToDomain *) expr;
loc = leftmostLoc(cexpr->location,
exprLocation((Node *) cexpr->arg));
}
break;
case T_CoerceToDomainValue:
loc = ((const CoerceToDomainValue *) expr)->location;
break;
case T_SetToDefault:
loc = ((const SetToDefault *) expr)->location;
break;
case T_TargetEntry:
loc = exprLocation((Node *) ((const TargetEntry *) expr)->expr);
break;
case T_IntoClause:
loc = exprLocation((Node *) ((const IntoClause *) expr)->rel);
break;
case T_List:
{
ListCell *lc;
loc = -1;
foreach(lc, (const List *) expr)
{
loc = exprLocation((Node *) lfirst(lc));
if (loc >= 0)
break;
}
}
break;
case T_A_Expr:
{
const A_Expr *aexpr = (const A_Expr *) expr;
loc = leftmostLoc(aexpr->location,
exprLocation(aexpr->lexpr));
}
break;
case T_ColumnRef:
loc = ((const ColumnRef *) expr)->location;
break;
case T_ParamRef:
loc = ((const ParamRef *) expr)->location;
break;
case T_A_Const:
loc = ((const A_Const *) expr)->location;
break;
case T_FuncCall:
{
const FuncCall *fc = (const FuncCall *) expr;
loc = leftmostLoc(fc->location,
exprLocation((Node *) fc->args));
}
break;
case T_A_ArrayExpr:
loc = ((const A_ArrayExpr *) expr)->location;
break;
case T_ResTarget:
loc = ((const ResTarget *) expr)->location;
break;
case T_MultiAssignRef:
loc = exprLocation(((const MultiAssignRef *) expr)->source);
break;
case T_TypeCast:
{
const TypeCast *tc = (const TypeCast *) expr;
loc = exprLocation(tc->arg);
loc = leftmostLoc(loc, tc->typeName->location);
loc = leftmostLoc(loc, tc->location);
}
break;
case T_CollateClause:
loc = exprLocation(((const CollateClause *) expr)->arg);
break;
case T_SortBy:
loc = exprLocation(((const SortBy *) expr)->node);
break;
case T_WindowDef:
loc = ((const WindowDef *) expr)->location;
break;
case T_RangeTableSample:
loc = ((const RangeTableSample *) expr)->location;
break;
case T_TypeName:
loc = ((const TypeName *) expr)->location;
break;
case T_ColumnDef:
loc = ((const ColumnDef *) expr)->location;
break;
case T_Constraint:
loc = ((const Constraint *) expr)->location;
break;
case T_FunctionParameter:
loc = exprLocation((Node *) ((const FunctionParameter *) expr)->argType);
break;
case T_XmlSerialize:
loc = ((const XmlSerialize *) expr)->location;
break;
case T_GroupingSet:
loc = ((const GroupingSet *) expr)->location;
break;
case T_WithClause:
loc = ((const WithClause *) expr)->location;
break;
case T_InferClause:
loc = ((const InferClause *) expr)->location;
break;
case T_OnConflictClause:
loc = ((const OnConflictClause *) expr)->location;
break;
case T_CommonTableExpr:
loc = ((const CommonTableExpr *) expr)->location;
break;
case T_PlaceHolderVar:
loc = exprLocation((Node *) ((const PlaceHolderVar *) expr)->phexpr);
break;
case T_InferenceElem:
loc = exprLocation((Node *) ((const InferenceElem *) expr)->expr);
break;
case T_PartitionElem:
loc = ((const PartitionElem *) expr)->location;
break;
case T_PartitionSpec:
loc = ((const PartitionSpec *) expr)->location;
break;
case T_PartitionBoundSpec:
loc = ((const PartitionBoundSpec *) expr)->location;
break;
case T_PartitionRangeDatum:
loc = ((const PartitionRangeDatum *) expr)->location;
break;
default:
loc = -1;
break;
}
return loc;
}
/*
 * leftmostLoc - pick the leftmost of two location values.
 *
 * A negative location is treated as "unknown".  When both inputs are known
 * the smaller one is returned; when only one is known, that one is
 * returned; when neither is known, loc2 is returned unchanged.
 */
static int
leftmostLoc(int loc1, int loc2)
{
	if (loc1 >= 0 && loc2 >= 0)
		return (loc1 < loc2) ? loc1 : loc2;

	/* at most one side is known: prefer loc1 if it is, else fall back */
	return (loc1 >= 0) ? loc1 : loc2;
}
/*
 * FLATCOPY - allocate a new nodetype-sized chunk with palloc and memcpy the
 * contents of "node" into it, assigning the new pointer to "newnode".  Only
 * the top-level struct is copied; pointer fields still point at the
 * original's children.  NOTE(review): "newnode" is evaluated twice, so it
 * must be a side-effect-free lvalue.
 */
#define FLATCOPY(newnode, node, nodetype) \
( (newnode) = (nodetype *) palloc(sizeof(nodetype)), \
memcpy((newnode), (node), sizeof(nodetype)) )
/*
 * CHECKFLATCOPY - same as FLATCOPY, but first asserts (via AssertMacro) that
 * "node" has tag nodetype.  NOTE(review): "node" and "newnode" are each
 * evaluated more than once; pass only side-effect-free expressions.
 */
#define CHECKFLATCOPY(newnode, node, nodetype) \
( AssertMacro(IsA((node), nodetype)), \
(newnode) = (nodetype *) palloc(sizeof(nodetype)), \
memcpy((newnode), (node), sizeof(nodetype)) )
/*
 * MUTATE - run the caller's mutator on "oldfield" and store the result, cast
 * to fieldtype, in "newfield".  The expansion refers to identifiers named
 * "mutator" and "context", which must be in scope wherever the macro is used.
 */
#define MUTATE(newfield, oldfield, fieldtype) \
( (newfield) = (fieldtype) mutator((Node *) (oldfield), context) )
/*
 * raw_expression_tree_walker - visit the children of a parse-tree node.
 *
 * Calls walker(child, context) on selected sub-fields of "node", in the
 * fixed order written below for each node type.  The walker is never
 * applied to "node" itself, only to its children.  As soon as any walker
 * call returns true, the walk stops and true is returned.  Otherwise false
 * is returned.
 *
 * Behavior by input:
 * - A NULL node returns false immediately.
 * - Node types listed with no children (e.g. T_A_Const, T_ColumnRef,
 *   T_Alias) are accepted but produce no walker calls.
 * - Fields of a node that are not named below are skipped.
 * - Any node type not listed raises elog(ERROR, "unrecognized node type").
 *
 * Many of the fields handed to walker are List-valued (args, targetList,
 * ...).  This function's T_List case iterates a list's elements, so a walker
 * presumably descends through lists by calling back into this function.
 * NOTE(review): confirm against callers.
 *
 * NOTE(review): "bool (*walker) ()" is an unprototyped function-pointer
 * type.  In C23 an empty parameter list means "(void)", which would reject
 * the two-argument calls made below.
 */
bool
raw_expression_tree_walker(Node *node,
bool (*walker) (),
void *context)
{
ListCell *temp;
if (node == NULL)
return false;
/*
 * Called once per recursion level.  In PostgreSQL this raises an error when
 * the stack is too deep (presumably guarding against very deeply nested
 * trees).
 */
check_stack_depth();
switch (nodeTag(node))
{
/* Node types with no sub-fields that this walker visits. */
case T_SetToDefault:
case T_CurrentOfExpr:
case T_SQLValueFunction:
case T_Integer:
case T_Float:
case T_String:
case T_BitString:
case T_Null:
case T_ParamRef:
case T_A_Const:
case T_A_Star:
break;
/* Alias: none of its fields are visited. */
case T_Alias:
break;
case T_RangeVar:
return walker(((RangeVar *) node)->alias, context);
case T_GroupingFunc:
return walker(((GroupingFunc *) node)->args, context);
case T_SubLink:
{
SubLink *sublink = (SubLink *) node;
if (walker(sublink->testexpr, context))
return true;
if (walker(sublink->subselect, context))
return true;
}
break;
case T_CaseExpr:
{
CaseExpr *caseexpr = (CaseExpr *) node;
if (walker(caseexpr->arg, context))
return true;
/*
 * The CaseWhen nodes themselves are not passed to walker; each arm's
 * condition (expr) and result are walked directly.  lfirst_node presumably
 * checks that every element really is a CaseWhen.
 */
foreach(temp, caseexpr->args)
{
CaseWhen *when = lfirst_node(CaseWhen, temp);
if (walker(when->expr, context))
return true;
if (walker(when->result, context))
return true;
}
if (walker(caseexpr->defresult, context))
return true;
}
break;
case T_RowExpr:
return walker(((RowExpr *) node)->args, context);
case T_CoalesceExpr:
return walker(((CoalesceExpr *) node)->args, context);
case T_MinMaxExpr:
return walker(((MinMaxExpr *) node)->args, context);
case T_XmlExpr:
{
XmlExpr *xexpr = (XmlExpr *) node;
if (walker(xexpr->named_args, context))
return true;
if (walker(xexpr->args, context))
return true;
}
break;
case T_NullTest:
return walker(((NullTest *) node)->arg, context);
case T_BooleanTest:
return walker(((BooleanTest *) node)->arg, context);
case T_JoinExpr:
{
JoinExpr *join = (JoinExpr *) node;
if (walker(join->larg, context))
return true;
if (walker(join->rarg, context))
return true;
if (walker(join->quals, context))
return true;
if (walker(join->alias, context))
return true;
}
break;
case T_IntoClause:
{
IntoClause *into = (IntoClause *) node;
if (walker(into->rel, context))
return true;
if (walker(into->viewQuery, context))
return true;
}
break;
/* Visit each list element in order, stopping at the first true result. */
case T_List:
foreach(temp, (List *) node)
{
if (walker((Node *) lfirst(temp), context))
return true;
}
break;
case T_InsertStmt:
{
InsertStmt *stmt = (InsertStmt *) node;
if (walker(stmt->relation, context))
return true;
if (walker(stmt->cols, context))
return true;
if (walker(stmt->selectStmt, context))
return true;
if (walker(stmt->onConflictClause, context))
return true;
if (walker(stmt->returningList, context))
return true;
if (walker(stmt->withClause, context))
return true;
}
break;
case T_DeleteStmt:
{
DeleteStmt *stmt = (DeleteStmt *) node;
if (walker(stmt->relation, context))
return true;
if (walker(stmt->usingClause, context))
return true;
if (walker(stmt->whereClause, context))
return true;
if (walker(stmt->returningList, context))
return true;
if (walker(stmt->withClause, context))
return true;
}
break;
case T_UpdateStmt:
{
UpdateStmt *stmt = (UpdateStmt *) node;
if (walker(stmt->relation, context))
return true;
if (walker(stmt->targetList, context))
return true;
if (walker(stmt->whereClause, context))
return true;
if (walker(stmt->fromClause, context))
return true;
if (walker(stmt->returningList, context))
return true;
if (walker(stmt->withClause, context))
return true;
}
break;
case T_SelectStmt:
{
SelectStmt *stmt = (SelectStmt *) node;
if (walker(stmt->distinctClause, context))
return true;
if (walker(stmt->intoClause, context))
return true;
if (walker(stmt->targetList, context))
return true;
if (walker(stmt->fromClause, context))
return true;
if (walker(stmt->whereClause, context))
return true;
if (walker(stmt->groupClause, context))
return true;
if (walker(stmt->havingClause, context))
return true;
if (walker(stmt->windowClause, context))
return true;
if (walker(stmt->valuesLists, context))
return true;
if (walker(stmt->sortClause, context))
return true;
if (walker(stmt->limitOffset, context))
return true;
if (walker(stmt->limitCount, context))
return true;
if (walker(stmt->lockingClause, context))
return true;
if (walker(stmt->withClause, context))
return true;
/* the two operands of a set operation (if any) are visited last */
if (walker(stmt->larg, context))
return true;
if (walker(stmt->rarg, context))
return true;
}
break;
case T_A_Expr:
{
A_Expr *expr = (A_Expr *) node;
if (walker(expr->lexpr, context))
return true;
if (walker(expr->rexpr, context))
return true;
}
break;
case T_BoolExpr:
{
BoolExpr *expr = (BoolExpr *) node;
if (walker(expr->args, context))
return true;
}
break;
/* ColumnRef: none of its fields are visited. */
case T_ColumnRef:
break;
case T_FuncCall:
{
FuncCall *fcall = (FuncCall *) node;
if (walker(fcall->args, context))
return true;
if (walker(fcall->agg_order, context))
return true;
if (walker(fcall->agg_filter, context))
return true;
if (walker(fcall->over, context))
return true;
}
break;
case T_NamedArgExpr:
return walker(((NamedArgExpr *) node)->arg, context);
case T_A_Indices:
{
A_Indices *indices = (A_Indices *) node;
if (walker(indices->lidx, context))
return true;
if (walker(indices->uidx, context))
return true;
}
break;
case T_A_Indirection:
{
A_Indirection *indir = (A_Indirection *) node;
if (walker(indir->arg, context))
return true;
if (walker(indir->indirection, context))
return true;
}
break;
case T_A_ArrayExpr:
return walker(((A_ArrayExpr *) node)->elements, context);
case T_ResTarget:
{
ResTarget *rt = (ResTarget *) node;
if (walker(rt->indirection, context))
return true;
if (walker(rt->val, context))
return true;
}
break;
case T_MultiAssignRef:
return walker(((MultiAssignRef *) node)->source, context);
case T_TypeCast:
{
TypeCast *tc = (TypeCast *) node;
if (walker(tc->arg, context))
return true;
if (walker(tc->typeName, context))
return true;
}
break;
case T_CollateClause:
return walker(((CollateClause *) node)->arg, context);
case T_SortBy:
return walker(((SortBy *) node)->node, context);
case T_WindowDef:
{
WindowDef *wd = (WindowDef *) node;
if (walker(wd->partitionClause, context))
return true;
if (walker(wd->orderClause, context))
return true;
if (walker(wd->startOffset, context))
return true;
if (walker(wd->endOffset, context))
return true;
}
break;
case T_RangeSubselect:
{
RangeSubselect *rs = (RangeSubselect *) node;
if (walker(rs->subquery, context))
return true;
if (walker(rs->alias, context))
return true;
}
break;
case T_RangeFunction:
{
RangeFunction *rf = (RangeFunction *) node;
if (walker(rf->functions, context))
return true;
if (walker(rf->alias, context))
return true;
if (walker(rf->coldeflist, context))
return true;
}
break;
case T_RangeTableSample:
{
RangeTableSample *rts = (RangeTableSample *) node;
if (walker(rts->relation, context))
return true;
if (walker(rts->args, context))
return true;
if (walker(rts->repeatable, context))
return true;
}
break;
case T_RangeTableFunc:
{
RangeTableFunc *rtf = (RangeTableFunc *) node;
if (walker(rtf->docexpr, context))
return true;
if (walker(rtf->rowexpr, context))
return true;
if (walker(rtf->namespaces, context))
return true;
if (walker(rtf->columns, context))
return true;
if (walker(rtf->alias, context))
return true;
}
break;
case T_RangeTableFuncCol:
{
RangeTableFuncCol *rtfc = (RangeTableFuncCol *) node;
if (walker(rtfc->colexpr, context))
return true;
if (walker(rtfc->coldefexpr, context))
return true;
}
break;
case T_TypeName:
{
TypeName *tn = (TypeName *) node;
if (walker(tn->typmods, context))
return true;
if (walker(tn->arrayBounds, context))
return true;
}
break;
case T_ColumnDef:
{
ColumnDef *coldef = (ColumnDef *) node;
if (walker(coldef->typeName, context))
return true;
if (walker(coldef->raw_default, context))
return true;
if (walker(coldef->collClause, context))
return true;
}
break;
case T_IndexElem:
{
IndexElem *indelem = (IndexElem *) node;
if (walker(indelem->expr, context))
return true;
}
break;
case T_GroupingSet:
return walker(((GroupingSet *) node)->content, context);
case T_LockingClause:
return walker(((LockingClause *) node)->lockedRels, context);
case T_XmlSerialize:
{
XmlSerialize *xs = (XmlSerialize *) node;
if (walker(xs->expr, context))
return true;
if (walker(xs->typeName, context))
return true;
}
break;
case T_WithClause:
return walker(((WithClause *) node)->ctes, context);
case T_InferClause:
{
InferClause *stmt = (InferClause *) node;
if (walker(stmt->indexElems, context))
return true;
if (walker(stmt->whereClause, context))
return true;
}
break;
case T_OnConflictClause:
{
OnConflictClause *stmt = (OnConflictClause *) node;
if (walker(stmt->infer, context))
return true;
if (walker(stmt->targetList, context))
return true;
if (walker(stmt->whereClause, context))
return true;
}
break;
case T_CommonTableExpr:
return walker(((CommonTableExpr *) node)->ctequery, context);
/* Any node type not listed above is treated as a hard error. */
default:
elog(ERROR, "unrecognized node type: %d",
(int) nodeTag(node));
break;
}
return false;
}