/*
 * tree-sitter-postgres 1.1.6
 *
 * Postgres grammar for tree-sitter
 * Documentation
 */
/**
 * External scanner for PL/pgSQL tree-sitter grammar.
 *
 * Provides the _sql_expression token type that captures SQL expression
 * fragments. In PL/pgSQL, SQL expressions are consumed until a
 * context-specific delimiter is found (;, THEN, LOOP, etc.). Since
 * tree-sitter can't dynamically change the delimiter set, we use a simple
 * heuristic: consume everything that looks like SQL, respecting balanced
 * parentheses/brackets and string literals, stopping at tokens that are
 * unambiguously PL/pgSQL structure.
 */
#include "tree_sitter/parser.h"

#include <string.h>

/*
 * External token kinds this scanner can emit. Each value is used both as
 * an index into the `valid_symbols` array handed to the scan function and
 * as the value stored in `lexer->result_symbol`.
 * NOTE(review): presumably the order must match the `externals` list in
 * the grammar definition — confirm against grammar.js.
 */
enum TokenType {
  SQL_BODY = 0,
};

/* Creates the scanner payload. This scanner keeps no state, so there is
 * nothing to allocate and the payload is always NULL. */
void *tree_sitter_plpgsql_external_scanner_create(void) {
  return NULL;
}
/* Releases the scanner payload. The payload is ignored; this function
 * frees nothing. */
void tree_sitter_plpgsql_external_scanner_destroy(void *payload) {
  (void)payload;
}
/* Serializes scanner state into `buffer`. Nothing is written and the
 * reported length is always 0; both arguments are ignored. */
unsigned tree_sitter_plpgsql_external_scanner_serialize(void *payload, char *buffer) {
  const unsigned bytes_written = 0;
  (void)payload;
  (void)buffer;
  return bytes_written;
}
/* Restores scanner state from `buffer`. All three arguments are ignored
 * and nothing is read or modified. */
void tree_sitter_plpgsql_external_scanner_deserialize(void *payload, const char *buffer, unsigned length) {
  (void)payload;
  (void)buffer;
  (void)length;
}

/* Advances past leading spaces, tabs, newlines and carriage returns.
 * The characters are advanced with the `skip` flag set to true. */
static void skip_whitespace(TSLexer *lexer) {
  for (;;) {
    switch (lexer->lookahead) {
      case ' ':
      case '\t':
      case '\n':
      case '\r':
        lexer->advance(lexer, true);
        break;  /* leaves the switch only; the loop re-checks lookahead */
      default:
        return;
    }
  }
}

/*
 * Keep the scanner self-contained for WebAssembly builds.
 *
 * Zed loads tree-sitter grammars as Wasm modules. If this scanner calls
 * libc ctype helpers like isalpha/isalnum/tolower, the compiled Wasm can
 * contain imports named "isalpha", "isalnum", or "tolower". Zed's grammar
 * runtime does not provide those imports, causing errors like:
 *
 *   Failed to instantiate Wasm module: invalid import 'isalnum'
 *
 * For delimiter keyword detection we only need ASCII PL/pgSQL keywords, so
 * small local helpers are sufficient and avoid external libc imports.
 */
/* True when `c` is an ASCII letter (a-z or A-Z); false for every other
 * value, including negative and non-ASCII code points. */
static bool is_ascii_alpha(int c) {
  const bool is_lower = ('a' <= c) && (c <= 'z');
  const bool is_upper = ('A' <= c) && (c <= 'Z');
  return is_lower || is_upper;
}

/* True when `c` is one of the ASCII decimal digits 0-9. */
static bool is_ascii_digit(int c) {
  return !(c < '0' || c > '9');
}

/* True when `c` is an ASCII letter or an ASCII decimal digit. */
static bool is_ascii_alnum(int c) {
  if (is_ascii_digit(c)) {
    return true;
  }
  return is_ascii_alpha(c);
}

/* Maps ASCII upper-case letters to lower case; every other value is
 * returned unchanged (converted to char). */
static char ascii_tolower(int c) {
  if (c < 'A' || c > 'Z') {
    return (char)c;
  }
  return (char)(c + ('a' - 'A'));
}

/* Maximum number of characters of a word kept for keyword comparison.
 * Longer words are truncated; that is harmless because every keyword in
 * the tables below is shorter, so a truncated word can never match one. */
enum { SQL_BODY_WORD_MAX = 30 };

/* Lower-cased words that end a SQL fragment when met at paren depth 0.
 * This is a heuristic: the real delimiter depends on context, so the
 * scanner errs on the side of stopping early and lets the grammar match
 * the keyword. */
static const char *const SQL_BODY_DELIMITER_KEYWORDS[] = {
  /* Expression terminators */
  "then", "loop", "into", "using", "when", "elsif", "elseif", "else",
  "end", "declare", "begin", "exception",
  /* Statement-starting keywords — must not be swallowed */
  "if", "case", "for", "foreach", "while", "return", "raise", "assert",
  "execute", "perform", "call", "open", "fetch", "move", "close", "null",
  "exit", "continue", "commit", "rollback", "get", "do",
  /* Additional context-sensitive delimiters */
  "next", "query", "reverse", "by", "alias", "strict", "cursor", "slice",
  "array", "all",
};

/* Lower-cased words that are naturally followed by a value, so a NULL
 * right after them is the SQL literal rather than a NULL statement. */
static const char *const SQL_BODY_VALUE_KEYWORDS[] = {
  "is", "not", "and", "or", "in", "like", "ilike", "between", "similar",
  "as",
};

/* Returns true when `word` equals one of the `count` entries of `list`. */
static bool sql_body_word_in(const char *word, const char *const *list,
                             size_t count) {
  for (size_t i = 0; i < count; i++) {
    if (strcmp(word, list[i]) == 0) {
      return true;
    }
  }
  return false;
}

/* Returns true for operator characters (and comma) after which a value is
 * expected next. */
static bool sql_body_op_expects_value(int c) {
  switch (c) {
    case '+': case '-': case '*': case '/': case '%':
    case '<': case '>': case '=': case '~': case '!':
    case '@': case '#': case '^': case '&': case '|':
    case '?': case ',':
      return true;
    default:
      return false;
  }
}

/* Consumes a quoted run that starts at the current `quote` character and
 * ends at the next unpaired `quote`; a doubled quote is an escaped quote.
 * An unterminated run is consumed to end of input. */
static void sql_body_skip_quoted(TSLexer *lexer, int quote) {
  lexer->advance(lexer, false);  /* opening quote */
  while (lexer->lookahead != 0) {
    const bool at_quote = lexer->lookahead == quote;
    lexer->advance(lexer, false);
    if (at_quote) {
      if (lexer->lookahead != quote) {
        return;  /* that was the closing quote */
      }
      lexer->advance(lexer, false);  /* second half of a doubled quote */
    }
  }
}

/* Consumes identifier characters (ASCII letters/digits, '_', '$', and
 * optionally code points >= 0x80) without any keyword checking. */
static void sql_body_skip_identifier(TSLexer *lexer, bool allow_non_ascii) {
  while (is_ascii_alnum(lexer->lookahead) || lexer->lookahead == '_' ||
         lexer->lookahead == '$' ||
         (allow_non_ascii && lexer->lookahead >= 0x80)) {
    lexer->advance(lexer, false);
  }
}

/* Consumes up to, but not including, the next newline or end of input. */
static void sql_body_skip_line_comment(TSLexer *lexer) {
  while (lexer->lookahead != 0 && lexer->lookahead != '\n') {
    lexer->advance(lexer, false);
  }
}

/* Consumes the remainder of a block comment whose opening marker has
 * already been consumed. Nested comments are balanced; an unterminated
 * comment is consumed to end of input. */
static void sql_body_skip_block_comment(TSLexer *lexer) {
  int nesting = 1;
  while (lexer->lookahead != 0 && nesting > 0) {
    const int first = lexer->lookahead;
    lexer->advance(lexer, false);
    if (first == '/' && lexer->lookahead == '*') {
      nesting++;
      lexer->advance(lexer, false);
    } else if (first == '*' && lexer->lookahead == '/') {
      nesting--;
      lexer->advance(lexer, false);
    }
  }
}

/* Finishes a scan that ran into a delimiter. The caller must have called
 * mark_end() just before the delimiter so the token excludes it. Emits
 * SQL_BODY only if something was consumed before the delimiter. */
static bool sql_body_stop(TSLexer *lexer, bool has_content) {
  if (!has_content) {
    return false;
  }
  lexer->result_symbol = SQL_BODY;
  return true;
}

/**
 * Scans one SQL_BODY token: a SQL fragment running up to, but not
 * including, the first PL/pgSQL delimiter found outside parentheses and
 * brackets.
 *
 * Delimiters at depth 0 are `;`, `<<`, `:=`, `..`, an unbalanced `)` or
 * `]`, and the words in SQL_BODY_DELIMITER_KEYWORDS. Single-quoted
 * strings, double-quoted identifiers and comments are consumed whole, so
 * delimiters inside them are ignored. Dollar-quoted strings are NOT
 * tracked: a `$` is consumed as an ordinary character.
 *
 * @param payload        Unused; the scanner is stateless.
 * @param lexer          Lexer positioned at the candidate token start.
 * @param valid_symbols  Flags indexed by TokenType; scanning is attempted
 *                       only when valid_symbols[SQL_BODY] is set.
 * @return true with lexer->result_symbol set to SQL_BODY when a non-empty
 *         fragment was consumed, false otherwise.
 */
bool tree_sitter_plpgsql_external_scanner_scan(
  void *payload, TSLexer *lexer, const bool *valid_symbols
) {
  (void)payload;
  if (!valid_symbols[SQL_BODY]) return false;

  skip_whitespace(lexer);

  if (lexer->lookahead == 0) return false;

  /* Don't start on a semicolon — that's a delimiter, not SQL content */
  if (lexer->lookahead == ';') return false;

  int depth = 0;
  bool has_content = false;

  /*
   * Whether the next token at depth 0 is expected to be a value. `null` is
   * a PL/pgSQL delimiter only as a fresh statement (`NULL;`); after an
   * operator, an opening paren, a comma, or a keyword such as IS/NOT/AND
   * it is the SQL literal and belongs to the fragment. Starts out false so
   * a bare leading NULL is left for the grammar's own NULL statement.
   */
  bool expecting_value = false;

  while (lexer->lookahead != 0) {
    /* At depth 0, semicolon terminates */
    if (depth == 0 && lexer->lookahead == ';') break;

    /* At depth 0, << terminates (block/loop label) */
    if (depth == 0 && lexer->lookahead == '<') {
      lexer->mark_end(lexer);
      lexer->advance(lexer, false);
      if (lexer->lookahead == '<') {
        return sql_body_stop(lexer, has_content);
      }
      /* Single < is a comparison operator. This branch consumes it before
       * the catch-all can see it, so the flag must be set here. */
      has_content = true;
      expecting_value = true;
      continue;
    }

    /* At depth 0, := terminates (assignment operator) */
    if (depth == 0 && lexer->lookahead == ':') {
      lexer->mark_end(lexer);
      lexer->advance(lexer, false);
      if (lexer->lookahead == '=') {
        return sql_body_stop(lexer, has_content);
      }
      if (lexer->lookahead == ':') {
        /* :: typecast — consume both colons */
        lexer->advance(lexer, false);
      }
      has_content = true;
      continue;
    }

    /* At depth 0, .. terminates (range operator in FOR loops) */
    if (depth == 0 && lexer->lookahead == '.') {
      lexer->mark_end(lexer);
      lexer->advance(lexer, false);
      if (lexer->lookahead == '.') {
        return sql_body_stop(lexer, has_content);
      }
      has_content = true;
      /* Single dot. A word glued to it is the next part of a qualified
       * name (`a.query`, `r.next`), never a delimiter keyword, so consume
       * it here before the keyword check can cut the name in half. */
      if (is_ascii_alpha(lexer->lookahead) || lexer->lookahead == '_') {
        sql_body_skip_identifier(lexer, true);
        expecting_value = false;
      }
      continue;
    }

    /* Track balanced parens/brackets */
    if (lexer->lookahead == '(' || lexer->lookahead == '[') {
      depth++;
      lexer->advance(lexer, false);
      has_content = true;
      expecting_value = true;
      continue;
    }
    if (lexer->lookahead == ')' || lexer->lookahead == ']') {
      if (depth == 0) {
        /* Unbalanced close — belongs to an enclosing construct */
        break;
      }
      depth--;
      lexer->advance(lexer, false);
      has_content = true;
      expecting_value = false;
      continue;
    }

    /* String literals ('...') and quoted identifiers ("...") — consume
     * whole, so semicolons and keywords inside them are not delimiters. */
    if (lexer->lookahead == '\'' || lexer->lookahead == '"') {
      sql_body_skip_quoted(lexer, lexer->lookahead);
      has_content = true;
      expecting_value = false;
      continue;
    }

    /* Dollar sign: consumed as a plain character; dollar-quoted string
     * bodies are scanned like any other SQL text. */
    if (lexer->lookahead == '$') {
      lexer->advance(lexer, false);
      has_content = true;
      expecting_value = false;
      continue;
    }

    /* '-' is either the start of a line comment or the minus operator */
    if (lexer->lookahead == '-') {
      lexer->advance(lexer, false);
      has_content = true;
      if (lexer->lookahead == '-') {
        sql_body_skip_line_comment(lexer);
      } else if (depth == 0) {
        /* Lone minus: same flag update the catch-all gives other
         * operators, so `a - NULL` keeps its NULL. */
        expecting_value = true;
      }
      continue;
    }

    /* '/' is either the start of a block comment or the division operator */
    if (lexer->lookahead == '/') {
      lexer->advance(lexer, false);
      has_content = true;
      if (lexer->lookahead == '*') {
        lexer->advance(lexer, false);
        sql_body_skip_block_comment(lexer);
      } else if (depth == 0) {
        /* Lone slash: value expected next, as for '-' above */
        expecting_value = true;
      }
      continue;
    }

    /* At depth 0, check words against the PL/pgSQL delimiter keywords.
     * The position is marked first so a delimiter is left unconsumed. */
    if (depth == 0 && is_ascii_alpha(lexer->lookahead)) {
      lexer->mark_end(lexer);

      char word[SQL_BODY_WORD_MAX + 1];
      int len = 0;
      while (is_ascii_alnum(lexer->lookahead) || lexer->lookahead == '_') {
        if (len < SQL_BODY_WORD_MAX) {
          word[len++] = ascii_tolower(lexer->lookahead);
        }
        lexer->advance(lexer, false);
      }
      word[len] = '\0';

      /* NULL in value position is the SQL literal, not a NULL statement */
      if (expecting_value && strcmp(word, "null") == 0) {
        has_content = true;
        expecting_value = false;
        continue;
      }

      if (sql_body_word_in(word, SQL_BODY_DELIMITER_KEYWORDS,
                           sizeof SQL_BODY_DELIMITER_KEYWORDS /
                               sizeof SQL_BODY_DELIMITER_KEYWORDS[0])) {
        /* Stop before this keyword — it's a PL/pgSQL delimiter */
        return sql_body_stop(lexer, has_content);
      }

      /* Ordinary word: a value is expected next only after IS, NOT, AND… */
      expecting_value = sql_body_word_in(
          word, SQL_BODY_VALUE_KEYWORDS,
          sizeof SQL_BODY_VALUE_KEYWORDS / sizeof SQL_BODY_VALUE_KEYWORDS[0]);
      has_content = true;
      continue;
    }

    /* Identifiers starting with underscore or non-ASCII */
    if (depth == 0 && (lexer->lookahead == '_' || lexer->lookahead >= 0x80)) {
      sql_body_skip_identifier(lexer, true);
      has_content = true;
      expecting_value = false;
      continue;
    }

    /* Inside parens, consume identifiers without keyword checking */
    if (depth > 0 &&
        (is_ascii_alpha(lexer->lookahead) || lexer->lookahead == '_')) {
      sql_body_skip_identifier(lexer, false);
      has_content = true;
      continue;
    }

    /* Everything else (operators, digits, whitespace, etc.) — consume */
    const int c = lexer->lookahead;
    lexer->advance(lexer, false);
    has_content = true;
    if (depth == 0) {
      if (sql_body_op_expects_value(c)) {
        expecting_value = true;
      } else if (c >= '0' && c <= '9') {
        expecting_value = false;
      }
      /* Other characters leave expecting_value unchanged */
    }
  }

  if (!has_content) {
    return false;
  }

  /* Ended on ';', an unbalanced close, or end of input: the token runs up
   * to the current position. */
  lexer->mark_end(lexer);
  lexer->result_symbol = SQL_BODY;
  return true;
}