import ast
import logging
import re
from collections.abc import Iterable, Iterator
from string import Formatter
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
)
# Shared logger for every templating / slicing diagnostic emitted by this module.
templater_logger: logging.Logger = logging.getLogger(name="sqlfluff.templater")
# Minimal configuration object read by ``PythonTemplater.slice_file``: the
# single flag decides whether a wrapped query is unwrapped (trimmed) or
# padded with extra "templated" slices.
FluffConfig = NamedTuple(
    "FluffConfig",
    [("templater_unwrap_wrapped_queries", bool)],
)
class FormatterInterface:
    """Placeholder interface type; it declares no members in this module."""
class SQLTemplaterError(Exception):
    """Raised when a template cannot be rendered.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message):
        # Forward the message to ``Exception`` so that ``str(err)``,
        # ``err.args`` and pickling all carry it. Without this call ``args``
        # stayed empty: the error printed as an empty string and could not be
        # unpickled (``cls()`` would be called without ``message``).
        super().__init__(message)
        self.message = message
def zero_slice(i: int) -> slice:
    """Return an empty slice positioned at ``i`` (``slice(i, i)``)."""
    start = stop = i
    return slice(start, stop)
def offset_slice(start: int, offset: int) -> slice:
    """Return the slice covering ``offset`` items beginning at ``start``."""
    stop = start + offset
    return slice(start, stop)
def findall(substr: str, in_str: str) -> Iterator[int]:
    """Yield every index at which ``substr`` starts within ``in_str``.

    Matches may overlap, because each search resumes one character after the
    previous hit. Nothing is yielded when either string is empty.
    """
    if substr and in_str:
        position = in_str.find(substr)
        while position >= 0:
            yield position
            position = in_str.find(substr, position + 1)
# One resolved mapping between a span of the raw (source) file and the
# matching span of the rendered (templated) output. ``slice_type`` holds
# values such as "literal" or "templated".
TemplatedFileSlice = NamedTuple(
    "TemplatedFileSlice",
    [
        ("slice_type", str),
        ("source_slice", slice),
        ("templated_slice", slice),
    ],
)
class RawFileSlice(NamedTuple):
    """One contiguous chunk of the raw template produced by the slicer.

    Fields:
        raw: The exact text of the chunk.
        slice_type: Category such as "literal", "escaped" or "templated".
        source_idx: Offset of ``raw`` within the source string.
        block_idx: Defaults to 0.
        tag: Defaults to None.
    """

    raw: str
    slice_type: str
    source_idx: int
    block_idx: int = 0
    tag: Optional[str] = None

    def end_source_idx(self) -> int:
        """Return the source offset just past the end of this chunk."""
        return len(self.raw) + self.source_idx

    def source_slice(self) -> slice:
        """Return the half-open source range covered by this chunk."""
        begin = self.source_idx
        return slice(begin, begin + len(self.raw))

    def is_source_only_slice(self) -> bool:
        """Return True if this chunk is a comment or a block start/mid/end marker."""
        source_only_types = {"comment", "block_end", "block_start", "block_mid"}
        return self.slice_type in source_only_types
class TemplatedFile:
    """Container pairing a source string with its rendering and slice maps.

    It only stores what it is given; no validation or computation happens
    here.
    """

    def __init__(
        self,
        source_str: str,
        fname: str,
        templated_str: Optional[str] = None,
        sliced_file: Optional[List[TemplatedFileSlice]] = None,
        raw_sliced: Optional[List[RawFileSlice]] = None,
    ):
        # Identity of the file.
        self.fname = fname
        self.source_str = source_str
        # Rendering output and the mappings between source and output.
        self.templated_str = templated_str
        self.raw_sliced = raw_sliced
        self.sliced_file = sliced_file
class IntermediateFileSlice(NamedTuple):
    """An in-progress mapping between a source span and a templated span.

    Groups one or more raw slices (``slice_buffer``) whose combined source
    span ``source_slice`` corresponds to ``templated_slice`` in the rendered
    output. ``intermediate_type`` is "compound" or "invariant" as produced in
    this module.
    """

    intermediate_type: str
    source_slice: slice
    templated_slice: slice
    slice_buffer: List[RawFileSlice]

    def _trim_end(
        self,
        templated_str: str,
        target_end: str = "head",
    ) -> Tuple["IntermediateFileSlice", List[TemplatedFileSlice]]:
        """Peel literal, block and comment elements off one end of this slice.

        Literals are only peeled while their text really appears at that end
        of the templated span. Zero-width elements (block markers, comments)
        consume source space only.

        Args:
            templated_str: The rendered string the slices refer to.
            target_end: ``"head"`` to trim from the start; any other value
                trims from the end (callers pass ``"tail"``).

        Returns:
            The remaining slice (typed ``"compound"``) and the trimmed
            ``TemplatedFileSlice`` objects in file order.
        """
        is_head = target_end == "head"
        target_idx = 0 if is_head else -1
        # One-element tuples (note the trailing commas). A bare parenthesised
        # string would make the ``in`` test below a substring check.
        terminator_types = ("block_start",) if is_head else ("block_end",)
        main_source_slice = self.source_slice
        main_templated_slice = self.templated_slice
        # Work on a copy so that trimming never mutates this slice's buffer.
        slice_buffer = list(self.slice_buffer)
        end_buffer: List[TemplatedFileSlice] = []

        while slice_buffer and slice_buffer[target_idx].slice_type in (
            "literal",
            "block_start",
            "block_end",
            "comment",
        ):
            focus = slice_buffer[target_idx]
            templater_logger.debug(" %s Focus: %s", target_end, focus)
            if focus.slice_type in ("block_start", "block_end", "comment"):
                # Zero width in the templated output.
                templated_len = 0
            else:
                # A literal: check it actually sits at this end of the
                # templated span before claiming it.
                templated_len = len(focus.raw)
                if is_head:
                    check_slice = offset_slice(
                        main_templated_slice.start,
                        templated_len,
                    )
                else:
                    check_slice = slice(
                        main_templated_slice.stop - templated_len,
                        main_templated_slice.stop,
                    )
                if templated_str[check_slice] != focus.raw:
                    templater_logger.debug(" Nope")
                    break

            if is_head:
                division = (
                    main_source_slice.start + len(focus.raw),
                    main_templated_slice.start + templated_len,
                )
                end_buffer.append(
                    TemplatedFileSlice(
                        focus.slice_type,
                        slice(main_source_slice.start, division[0]),
                        slice(main_templated_slice.start, division[1]),
                    )
                )
                main_source_slice = slice(division[0], main_source_slice.stop)
                main_templated_slice = slice(division[1], main_templated_slice.stop)
            else:
                division = (
                    main_source_slice.stop - len(focus.raw),
                    main_templated_slice.stop - templated_len,
                )
                # Prepend so the tail buffer stays in file order.
                end_buffer.insert(
                    0,
                    TemplatedFileSlice(
                        focus.slice_type,
                        slice(division[0], main_source_slice.stop),
                        slice(division[1], main_templated_slice.stop),
                    ),
                )
                main_source_slice = slice(main_source_slice.start, division[0])
                main_templated_slice = slice(main_templated_slice.start, division[1])

            slice_buffer.pop(target_idx)
            # Stop after consuming a block opening (head) / closing (tail).
            if focus.slice_type in terminator_types:
                break

        new_intermediate = self.__class__(
            "compound",
            main_source_slice,
            main_templated_slice,
            slice_buffer,
        )
        return new_intermediate, end_buffer

    def trim_ends(
        self,
        templated_str: str,
    ) -> Tuple[
        List[TemplatedFileSlice],
        "IntermediateFileSlice",
        List[TemplatedFileSlice],
    ]:
        """Trim both ends of this slice.

        Returns:
            ``(head_buffer, remaining_slice, tail_buffer)``.
        """
        new_slice, head_buffer = self._trim_end(
            templated_str=templated_str,
            target_end="head",
        )
        new_slice, tail_buffer = new_slice._trim_end(
            templated_str=templated_str,
            target_end="tail",
        )
        return head_buffer, new_slice, tail_buffer

    def try_simple(self) -> TemplatedFileSlice:
        """Convert to a ``TemplatedFileSlice`` if the buffer holds exactly one element.

        Raises:
            ValueError: if the buffer does not hold exactly one element.
        """
        if len(self.slice_buffer) == 1:
            return TemplatedFileSlice(
                self.slice_buffer[0].slice_type,
                self.source_slice,
                self.templated_slice,
            )
        raise ValueError("IntermediateFileSlice is not simple!")

    def coalesce(self) -> TemplatedFileSlice:
        """Collapse the whole slice into one ``TemplatedFileSlice`` of the dominant type."""
        return TemplatedFileSlice(
            PythonTemplater._coalesce_types(self.slice_buffer),
            self.source_slice,
            self.templated_slice,
        )
class PythonTemplater:
    """A templater using Python format strings (``str.format`` syntax).

    Besides rendering, the class slices the raw template and the rendered
    output into aligned ``TemplatedFileSlice`` regions. Positions in the
    rendered SQL can then be mapped back to the source file.
    """

    name = "python"
    config_subsection = ("context",)

    def __init__(self, override_context: Optional[Dict[str, Any]] = None) -> None:
        """Initialise the templater.

        Args:
            override_context: Optional context mapping, stored as
                ``self.override_context`` (``{}`` when falsy). Nothing in
                this class reads it back.
        """
        self.default_context = dict(test_value="__test__")
        self.override_context = override_context or {}

    @staticmethod
    def infer_type(s: Any) -> Any:
        """Convert ``s`` to a Python literal if it parses as one.

        ``ast.literal_eval`` accepts literal syntax only, so no arbitrary code
        runs. Any value that cannot be evaluated is returned unchanged. This
        covers non-strings, invalid syntax, unhashable dict keys or set
        members such as ``"{[1]: 2}"``, and overly deep nesting.
        """
        try:
            return ast.literal_eval(s)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # These are the exception types documented for malformed input to
            # ``ast.literal_eval``. TypeError (unhashable key/member) used to
            # escape and crash context loading.
            return s

    def get_context(self, context: Dict[str, str]) -> Dict[str, Any]:
        """Return a new dict with every value of ``context`` run through ``infer_type``.

        The caller's dict is left untouched. It used to be converted in
        place, which also altered any other holder of the same dict.
        """
        return {key: self.infer_type(value) for key, value in context.items()}

    def process(
        self,
        *,
        in_str: str,
        fname: str,
        context: Dict[str, str],
        config: Optional[FluffConfig] = None,
    ) -> Tuple[TemplatedFile, List[SQLTemplaterError]]:
        """Render ``in_str`` with ``context`` and slice the result.

        Returns:
            A ``TemplatedFile`` and an always-empty error list. Rendering
            failures are raised rather than returned.

        Raises:
            SQLTemplaterError: if a field referenced by the template is
                missing from the context.
        """
        live_context = self.get_context(context)

        def render_func(raw_str: str) -> str:
            """Format ``raw_str`` with ``live_context`` (with dotted-name support)."""
            # ``str.format`` would read "{a.b}" as attribute access on ``a``.
            # Rewrite such fields as item lookups on the magic "sqlfluff"
            # context entry, keeping any ":spec" suffix: "{sqlfluff[a.b]}".
            raw_str_with_dot_notation_hack = re.sub(
                r"{([^:}]*\.[^:}]*)(:\S*)?}",
                r"{sqlfluff[\1]\2}",
                raw_str,
            )
            templater_logger.debug(
                " Raw String with Dot Notation Hack: %r",
                raw_str_with_dot_notation_hack,
            )
            try:
                return raw_str_with_dot_notation_hack.format(**live_context)
            except KeyError as err:
                missing_key = err.args[0]
                if missing_key == "sqlfluff":
                    raise SQLTemplaterError(
                        "Failure in Python templating: magic key 'sqlfluff' "
                        "missing from context. This key is required "
                        "for template variables containing '.'. "
                        "https://docs.sqlfluff.com/en/stable/"
                        "perma/python_templating.html",
                    ) from err
                # The missing key may be an int (e.g. "{x[0]}" on a dict);
                # guard the substring test so it cannot raise TypeError.
                if isinstance(missing_key, str) and "." in missing_key:
                    raise SQLTemplaterError(
                        f"Failure in Python templating: {err} key missing from 'sqlfluff' "
                        "dict in context. Template variables containing '.' are "
                        "required to use the 'sqlfluff' magic fixed context key. "
                        "https://docs.sqlfluff.com/en/stable/"
                        "perma/python_templating.html",
                    ) from err
                raise SQLTemplaterError(
                    f"Failure in Python templating: {err}. Have you configured your "
                    "variables? https://docs.sqlfluff.com/en/stable/"
                    "perma/variables.html",
                ) from err

        raw_sliced, sliced_file, new_str = self.slice_file(
            in_str,
            render_func=render_func,
            config=config,
        )
        return (
            TemplatedFile(
                source_str=in_str,
                templated_str=new_str,
                fname=fname,
                sliced_file=sliced_file,
                raw_sliced=raw_sliced,
            ),
            [],
        )

    def slice_file(
        self,
        raw_str: str,
        render_func: Callable[[str], str],
        config: Optional[FluffConfig] = None,
        append_to_templated: str = "",
    ) -> Tuple[List[RawFileSlice], List[TemplatedFileSlice], str]:
        """Render ``raw_str`` and map regions of the output back to the source.

        Args:
            raw_str: The raw template text.
            render_func: Callable that renders a raw string.
            config: Optional config. Its ``templater_unwrap_wrapped_queries``
                flag controls wrapped-query handling and defaults to True.
            append_to_templated: Accepted for interface compatibility; unused.

        Returns:
            ``(raw_sliced, sliced_file, templated_str)``.
        """
        templater_logger.info("Slicing File Template")
        templater_logger.debug(" Raw String: %r", raw_str)
        templated_str = render_func(raw_str)
        templater_logger.debug(" Templated String: %r", templated_str)
        raw_sliced = list(self._slice_template(raw_str))
        templater_logger.debug(" Raw Sliced:")
        for idx, raw_slice in enumerate(raw_sliced):
            templater_logger.debug(" %s: %r", idx, raw_slice)
        literals = [
            raw_slice.raw
            for raw_slice in raw_sliced
            if raw_slice.slice_type == "literal"
        ]
        templater_logger.debug(" Literals: %s", literals)
        # At most two passes: a second pass re-slices the string if unwrapping
        # trimmed it in the first.
        for loop_idx in range(2):
            templater_logger.debug(" # Slice Loop %s", loop_idx)
            raw_occurrences = self._substring_occurrences(raw_str, literals)
            templated_occurrences = self._substring_occurrences(templated_str, literals)
            templater_logger.debug(
                " Occurrences: Raw: %s, Templated: %s",
                raw_occurrences,
                templated_occurrences,
            )
            split_sliced = list(
                self._split_invariants(
                    raw_sliced,
                    literals,
                    raw_occurrences,
                    templated_occurrences,
                    templated_str,
                ),
            )
            templater_logger.debug(" Split Sliced:")
            for idx, split_slice in enumerate(split_sliced):
                templater_logger.debug(" %s: %r", idx, split_slice)
            sliced_file = list(
                self._split_uniques_coalesce_rest(
                    split_sliced,
                    raw_occurrences,
                    templated_occurrences,
                    templated_str,
                ),
            )
            templater_logger.debug(" Fully Sliced:")
            for idx, templ_slice in enumerate(sliced_file):
                templater_logger.debug(" %s: %r", idx, templ_slice)
            unwrap_wrapped = (
                True if config is None else config.templater_unwrap_wrapped_queries
            )
            sliced_file, new_templated_str = self._check_for_wrapped(
                sliced_file,
                templated_str,
                unwrap_wrapped=unwrap_wrapped,
            )
            if new_templated_str == templated_str:
                break
            templated_str = new_templated_str
        return raw_sliced, sliced_file, new_templated_str

    @classmethod
    def _check_for_wrapped(
        cls,
        slices: List[TemplatedFileSlice],
        templated_str: str,
        unwrap_wrapped: bool = True,
    ) -> Tuple[List[TemplatedFileSlice], str]:
        """Handle templated text outside the first/last slice ("wrapping").

        With ``unwrap_wrapped`` the slices are returned unchanged and the
        templated string is trimmed to the span they cover. Otherwise, zero-
        source-width "templated" slices are added, in place on ``slices``,
        to cover any uncovered prefix and suffix.
        """
        if not slices:
            return slices, templated_str
        first_slice = slices[0]
        last_slice = slices[-1]
        if unwrap_wrapped:
            return (
                slices,
                templated_str[
                    first_slice.templated_slice.start : last_slice.templated_slice.stop
                ],
            )
        if (
            first_slice.source_slice.start == 0
            and first_slice.templated_slice.start != 0
        ):
            slices.insert(
                0,
                TemplatedFileSlice(
                    "templated",
                    slice(0, 0),
                    slice(0, first_slice.templated_slice.start),
                ),
            )
        if last_slice.templated_slice.stop != len(templated_str):
            slices.append(
                TemplatedFileSlice(
                    "templated",
                    zero_slice(last_slice.source_slice.stop),
                    slice(last_slice.templated_slice.stop, len(templated_str)),
                ),
            )
        return slices, templated_str

    @classmethod
    def _substring_occurrences(
        cls,
        in_str: str,
        substrings: Iterable[str],
    ) -> Dict[str, List[int]]:
        """Map each substring to every (possibly overlapping) index where it occurs."""
        return {substring: list(findall(substring, in_str)) for substring in substrings}

    @staticmethod
    def _sorted_occurrence_tuples(
        occurrences: Dict[str, List[int]],
    ) -> List[Tuple[str, int]]:
        """Flatten occurrences into ``(substring, index)`` pairs sorted by index, then text."""
        return sorted(
            ((raw, idx) for raw in occurrences for idx in occurrences[raw]),
            key=lambda x: (x[1], x[0]),
        )

    @classmethod
    def _slice_template(cls, in_str: str) -> Iterator[RawFileSlice]:
        """Split a format string into literal, escaped and templated raw slices.

        ``Formatter.parse`` reports an escaped ``{{`` / ``}}`` as a single
        brace inside the literal text. Each one is therefore emitted as an
        "escaped" slice with doubled raw text and a source width of 2.
        """
        fmt = Formatter()
        in_idx = 0  # Running offset into ``in_str``.
        for literal_text, field_name, format_spec, conversion in fmt.parse(in_str):
            if literal_text:
                escape_chars = cls._sorted_occurrence_tuples(
                    cls._substring_occurrences(literal_text, ["}", "{"]),
                )
                idx = 0  # Offset within ``literal_text``.
                # Walk the escapes in ascending position order. The previous
                # code popped from the end of this ascending list, which would
                # mis-order output if a chunk ever contained more than one.
                for char, pos in escape_chars:
                    if pos > idx:
                        yield RawFileSlice(literal_text[idx:pos], "literal", in_idx)
                        in_idx += pos - idx
                    idx = pos + len(char)
                    yield RawFileSlice(literal_text[pos:idx] * 2, "escaped", in_idx)
                    in_idx += 2
                if literal_text[idx:]:
                    yield RawFileSlice(literal_text[idx:], "literal", in_idx)
                    in_idx += len(literal_text) - idx
            # An empty field name (e.g. a positional "{}") yields no slice.
            if field_name:
                constructed_token = "{{{field_name}{conv}{spec}}}".format(
                    field_name=field_name,
                    conv=f"!{conversion}" if conversion else "",
                    spec=f":{format_spec}" if format_spec else "",
                )
                yield RawFileSlice(constructed_token, "templated", in_idx)
                in_idx += len(constructed_token)

    @classmethod
    def _split_invariants(
        cls,
        raw_sliced: List[RawFileSlice],
        literals: List[str],
        raw_occurrences: Dict[str, List[int]],
        templated_occurrences: Dict[str, List[int]],
        templated_str: str,
    ) -> Iterator[IntermediateFileSlice]:
        """Split the raw slices on literals occurring exactly once on both sides.

        Each such literal (an "invariant") is yielded as its own "invariant"
        slice. Everything between invariants is grouped into "compound"
        slices.
        """
        invariants = [
            literal
            for literal in literals
            if len(raw_occurrences[literal]) == 1
            and len(templated_occurrences[literal]) == 1
        ]
        # Drop invariants whose relative order differs between the raw and
        # templated strings. Longer literals are checked first and take
        # precedence.
        for linv in sorted(invariants, key=len, reverse=True):
            if linv not in invariants:
                continue
            source_pos, templ_pos = raw_occurrences[linv], templated_occurrences[linv]
            # Iterate a copy because entries are removed as we go.
            for tinv in invariants.copy():
                if tinv != linv:
                    src_dir = source_pos > raw_occurrences[tinv]
                    tmp_dir = templ_pos > templated_occurrences[tinv]
                    if src_dir != tmp_dir:
                        templater_logger.debug(
                            " Invariant found out of order: %r",
                            tinv,
                        )
                        invariants.remove(tinv)

        buffer: List[RawFileSlice] = []
        idx: Optional[int] = None  # Source start of the pending buffer.
        templ_idx = 0  # Templated position just after the last invariant.
        for raw_file_slice in raw_sliced:
            if raw_file_slice.raw in invariants:
                if buffer:
                    yield IntermediateFileSlice(
                        "compound",
                        slice(idx, raw_file_slice.source_idx),
                        slice(templ_idx, templated_occurrences[raw_file_slice.raw][0]),
                        buffer,
                    )
                buffer = []
                idx = None
                yield IntermediateFileSlice(
                    "invariant",
                    offset_slice(
                        raw_file_slice.source_idx,
                        len(raw_file_slice.raw),
                    ),
                    offset_slice(
                        templated_occurrences[raw_file_slice.raw][0],
                        len(raw_file_slice.raw),
                    ),
                    [
                        RawFileSlice(
                            raw_file_slice.raw,
                            raw_file_slice.slice_type,
                            templated_occurrences[raw_file_slice.raw][0],
                        ),
                    ],
                )
                templ_idx = templated_occurrences[raw_file_slice.raw][0] + len(
                    raw_file_slice.raw,
                )
            else:
                buffer.append(
                    RawFileSlice(
                        raw_file_slice.raw,
                        raw_file_slice.slice_type,
                        raw_file_slice.source_idx,
                    ),
                )
                if idx is None:
                    idx = raw_file_slice.source_idx
        if buffer:
            yield IntermediateFileSlice(
                "compound",
                slice((idx or 0), (idx or 0) + sum(len(slc.raw) for slc in buffer)),
                slice(templ_idx, len(templated_str)),
                buffer,
            )

    @staticmethod
    def _filter_occurrences(
        file_slice: slice,
        occurrences: Dict[str, List[int]],
    ) -> Dict[str, List[int]]:
        """Keep only positions inside ``file_slice``; drop keys left with none."""
        filtered = {
            key: [pos for pos in occurrences[key] if file_slice.start <= pos < file_slice.stop]
            for key in occurrences
        }
        return {key: positions for key, positions in filtered.items() if positions}

    @staticmethod
    def _coalesce_types(elems: List[RawFileSlice]) -> str:
        """Return the single dominant slice type for a group of raw slices.

        ``block_*`` types count as "templated". If one type remains it is
        returned; otherwise the first present of "templated", "escaped",
        "literal" wins.

        Raises:
            RuntimeError: if no priority type is present (e.g. empty ``elems``).
        """
        types = {elem.slice_type for elem in elems}
        for typ in list(types):
            if typ.startswith("block_"):
                types.remove(typ)
                types.add("templated")
        if len(types) == 1:
            return types.pop()
        priority = ["templated", "escaped", "literal"]
        for p in priority:
            if p in types:
                return p
        raise RuntimeError(
            f"Exhausted priorities in _coalesce_types! {types!r}",
        )

    @classmethod
    def _split_uniques_coalesce_rest(
        cls,
        split_file: List[IntermediateFileSlice],
        raw_occurrences: Dict[str, List[int]],
        templ_occurrences: Dict[str, List[int]],
        templated_str: str,
    ) -> Iterator[TemplatedFileSlice]:
        """Within each compound section, split on unique literals; coalesce the rest.

        For each section:

        1. Simple (single-element) sections are yielded directly.
        2. Matching literals are trimmed off both ends.
        3. Literals unique in both raw and templated text ("two way uniques")
           are used as split points, recursing on the gaps.
        4. Failing that, literals unique only in the raw text ("one way
           uniques") are used as approximate landmarks.

        Anything that cannot be split is yielded as one coalesced slice.
        """
        tail_buffer: List[TemplatedFileSlice] = []
        templater_logger.debug(" _split_uniques_coalesce_rest: %s", split_file)
        for int_file_slice in split_file:
            # Flush any tail trimmed from the previous section first.
            if tail_buffer:
                templater_logger.debug(
                    " Yielding Tail Buffer [start]: %s",
                    tail_buffer,
                )
                yield from tail_buffer
                tail_buffer = []

            # Zero-length templated span: nothing to split.
            if (
                int_file_slice.templated_slice.stop
                - int_file_slice.templated_slice.start
                == 0
            ):
                point_combo = int_file_slice.coalesce()
                templater_logger.debug(
                    " Yielding Point Combination: %s",
                    point_combo,
                )
                yield point_combo
                continue

            try:
                simple_elem = int_file_slice.try_simple()
                templater_logger.debug(" Yielding Simple: %s", simple_elem)
                yield simple_elem
                continue
            except ValueError:
                pass

            head_buffer, int_file_slice, tail_buffer = int_file_slice.trim_ends(
                templated_str=templated_str,
            )
            if head_buffer:
                yield from head_buffer
            if not int_file_slice.slice_buffer:
                continue

            try:
                simple_elem = int_file_slice.try_simple()
                templater_logger.debug(" Yielding Simple: %s", simple_elem)
                yield simple_elem
                continue
            except ValueError:
                pass

            templater_logger.debug(" Intermediate Slice: %s", int_file_slice)
            coalesced = int_file_slice.coalesce()
            raw_occs = cls._filter_occurrences(
                int_file_slice.source_slice,
                raw_occurrences,
            )
            templ_occs = cls._filter_occurrences(
                int_file_slice.templated_slice,
                templ_occurrences,
            )
            # ``get`` because a literal may be in the source span but absent
            # from the templated span.
            one_way_uniques = [
                key
                for key in raw_occs.keys()
                if len(raw_occs[key]) == 1 and len(templ_occs.get(key, [])) >= 1
            ]
            two_way_uniques = [
                key for key in one_way_uniques if len(templ_occs[key]) == 1
            ]
            if not raw_occs or not templ_occs or not one_way_uniques:
                templater_logger.debug(
                    " No Anchors or Uniques. Yielding Whole: %s",
                    coalesced,
                )
                yield coalesced
                continue

            templater_logger.debug(
                " Intermediate Slice [post trim]: %s: %r",
                int_file_slice,
                templated_str[int_file_slice.templated_slice],
            )
            templater_logger.debug(" One Way Uniques: %s", one_way_uniques)
            templater_logger.debug(" Two Way Uniques: %s", two_way_uniques)
            # (source, templated) cursor advanced as pieces are yielded.
            starts = (
                int_file_slice.source_slice.start,
                int_file_slice.templated_slice.start,
            )

            if two_way_uniques:
                bookmark_idx = 0  # First buffer index not yet yielded.
                for idx, raw_slice in enumerate(int_file_slice.slice_buffer):
                    pos = 0
                    unq: Optional[str] = None
                    # If several uniques match, the last one in the list wins.
                    for unique in two_way_uniques:
                        if unique in raw_slice.raw:
                            pos = raw_slice.raw.index(unique)
                            unq = unique
                    if unq:
                        unique_position = (
                            raw_occs[unq][0],
                            templ_occs[unq][0],
                        )
                        templater_logger.debug(
                            " Handling Unique: %r, %s, %s, %r",
                            unq,
                            pos,
                            unique_position,
                            raw_slice,
                        )
                        # Recurse on the whole elements before this one.
                        if idx > bookmark_idx:
                            yield from cls._split_uniques_coalesce_rest(
                                [
                                    IntermediateFileSlice(
                                        "compound",
                                        slice(starts[0], unique_position[0] - pos),
                                        slice(starts[1], unique_position[1] - pos),
                                        int_file_slice.slice_buffer[bookmark_idx:idx],
                                    ),
                                ],
                                raw_occs,
                                templ_occs,
                                templated_str,
                            )
                        # Part of this element preceding the unique.
                        if pos > 0:
                            yield TemplatedFileSlice(
                                raw_slice.slice_type,
                                slice(unique_position[0] - pos, unique_position[0]),
                                slice(unique_position[1] - pos, unique_position[1]),
                            )
                        # The unique itself.
                        starts = (
                            unique_position[0] + len(unq),
                            unique_position[1] + len(unq),
                        )
                        yield TemplatedFileSlice(
                            raw_slice.slice_type,
                            slice(unique_position[0], starts[0]),
                            slice(unique_position[1], starts[1]),
                        )
                        bookmark_idx = idx + 1
                        # Remainder of this element after the unique.
                        if raw_slice.raw[pos + len(unq) :]:
                            remnant_length = len(raw_slice.raw) - (len(unq) + pos)
                            _starts = starts
                            starts = (
                                starts[0] + remnant_length,
                                starts[1] + remnant_length,
                            )
                            yield TemplatedFileSlice(
                                raw_slice.slice_type,
                                slice(_starts[0], starts[0]),
                                slice(_starts[1], starts[1]),
                            )
                if bookmark_idx == 0:
                    # Safety valve: no unique was located in the buffer, so
                    # fall back to the coalesced slice rather than recurse.
                    templater_logger.info(
                        " Safety Value Info: %s, %r",
                        two_way_uniques,
                        templated_str[int_file_slice.templated_slice],
                    )
                    templater_logger.warning(
                        " Python templater safety value unexpectedly triggered. "
                        "Please report your raw and compiled query on github for "
                        "debugging.",
                    )
                    yield coalesced
                    continue
                # Recurse on whatever follows the last unique.
                if len(int_file_slice.slice_buffer) > bookmark_idx:
                    yield from cls._split_uniques_coalesce_rest(
                        [
                            IntermediateFileSlice(
                                "compound",
                                slice(starts[0], int_file_slice.source_slice.stop),
                                slice(starts[1], int_file_slice.templated_slice.stop),
                                int_file_slice.slice_buffer[
                                    bookmark_idx : len(int_file_slice.slice_buffer)
                                ],
                            ),
                        ],
                        raw_occs,
                        templ_occs,
                        templated_str,
                    )
                continue

            # Only one-way uniques: use them as approximate landmarks.
            owu_templ_tuples = cls._sorted_occurrence_tuples(
                {key: templ_occs[key] for key in one_way_uniques},
            )
            templater_logger.debug(
                " Handling One Way Uniques: %s",
                owu_templ_tuples,
            )
            stops = (
                int_file_slice.source_slice.stop,
                int_file_slice.templated_slice.stop,
            )
            this_owu_idx: Optional[int] = None
            last_owu_idx: Optional[int] = None
            for raw, template_idx in owu_templ_tuples:
                raw_idx = raw_occs[raw][0]
                raw_len = len(raw)
                last_owu_idx = this_owu_idx
                try:
                    this_owu_idx = next(
                        idx
                        for idx, slc in enumerate(int_file_slice.slice_buffer)
                        if slc.raw == raw
                    )
                except StopIteration:
                    templater_logger.info(
                        "One Way Unique %r not found in slice buffer. Skipping...",
                        raw,
                    )
                    continue
                templater_logger.debug(
                    " Handling OWU: %r @%s (raw @%s) [this_owu_idx: %s, "
                    "last_owu_dx: %s]",
                    raw,
                    template_idx,
                    raw_idx,
                    this_owu_idx,
                    last_owu_idx,
                )
                if template_idx > starts[1]:
                    # Something precedes this literal in the templated span.
                    sub_section: Optional[List[RawFileSlice]] = None
                    if starts[1] == int_file_slice.templated_slice.stop:
                        sub_section = int_file_slice.slice_buffer[:this_owu_idx]
                    elif raw_idx > starts[0] and last_owu_idx != this_owu_idx:
                        if last_owu_idx:
                            sub_section = int_file_slice.slice_buffer[
                                last_owu_idx + 1 : this_owu_idx
                            ]
                        else:
                            sub_section = int_file_slice.slice_buffer[:this_owu_idx]
                    if sub_section:
                        templater_logger.debug(
                            " Attempting Subsplit [pre]: %s, %r",
                            sub_section,
                            templated_str[slice(starts[1], template_idx)],
                        )
                        yield from cls._split_uniques_coalesce_rest(
                            [
                                IntermediateFileSlice(
                                    "compound",
                                    slice(starts[0], raw_idx),
                                    slice(starts[1], template_idx),
                                    sub_section,
                                ),
                            ],
                            raw_occs,
                            templ_occs,
                            templated_str,
                        )
                    else:
                        # "Tricky" case: the literal comes later in the
                        # templated text but not later in the raw text.
                        # Over-estimate the source span as templated.
                        if last_owu_idx is None or last_owu_idx + 1 >= len(
                            int_file_slice.slice_buffer,
                        ):
                            cur_idx = 0
                        else:
                            cur_idx = last_owu_idx + 1
                        block_ends = sum(
                            slc.slice_type == "block_end"
                            for slc in int_file_slice.slice_buffer[cur_idx:]
                        )
                        block_start_indices = [
                            idx
                            for idx, slc in enumerate(
                                int_file_slice.slice_buffer[:cur_idx],
                            )
                            if slc.slice_type == "block_start"
                        ]
                        if len(block_start_indices) > block_ends:
                            offset = block_start_indices[-1 - block_ends] + 1
                            elem_sub_buffer = int_file_slice.slice_buffer[offset:]
                            cur_idx -= offset
                        else:
                            elem_sub_buffer = int_file_slice.slice_buffer
                        include_start = raw_idx > elem_sub_buffer[0].source_idx
                        end_point = elem_sub_buffer[-1].end_source_idx()
                        if include_start:
                            start_point = elem_sub_buffer[0].source_idx
                        else:
                            start_point = elem_sub_buffer[cur_idx].source_idx
                        tricky = TemplatedFileSlice(
                            "templated",
                            slice(start_point, end_point),
                            slice(starts[1], template_idx),
                        )
                        templater_logger.debug(
                            " Yielding Tricky Case : %s",
                            tricky,
                        )
                        yield tricky
                owu_literal_slice = TemplatedFileSlice(
                    "literal",
                    offset_slice(raw_idx, raw_len),
                    offset_slice(template_idx, raw_len),
                )
                templater_logger.debug(
                    " Yielding Unique: %r, %s",
                    raw,
                    owu_literal_slice,
                )
                yield owu_literal_slice
                starts = (
                    raw_idx + raw_len,
                    template_idx + raw_len,
                )
            if starts[1] < stops[1] and last_owu_idx is not None:
                templater_logger.debug(" Attempting Subsplit [post].")
                yield from cls._split_uniques_coalesce_rest(
                    [
                        IntermediateFileSlice(
                            "compound",
                            slice(raw_idx + raw_len, stops[0]),
                            slice(starts[1], stops[1]),
                            int_file_slice.slice_buffer[last_owu_idx + 1 :],
                        ),
                    ],
                    raw_occs,
                    templ_occs,
                    templated_str,
                )
        if tail_buffer:
            templater_logger.debug(
                " Yielding Tail Buffer [end]: %s",
                tail_buffer,
            )
            yield from tail_buffer
def process_from_rust(
    string: str,
    fname: str,
    live_context: Dict[str, Any],
) -> TemplatedFile:
    """Render ``string`` with a ``PythonTemplater`` and return the templated file.

    Presumably the entry point used from Rust (per its name) — confirm with
    callers.

    Args:
        string: The raw template text.
        fname: File name recorded on the resulting ``TemplatedFile``.
        live_context: Used as both the templater's override context and the
            rendering context.

    Returns:
        The ``TemplatedFile`` produced by ``PythonTemplater.process``.

    Raises:
        ValueError: if ``process`` returns a non-empty error list.
        SQLTemplaterError: propagated from rendering (e.g. a missing key).
    """
    templater = PythonTemplater(override_context=live_context)
    output, errors = templater.process(
        in_str=string,
        fname=fname,
        context=live_context,
    )
    if errors:
        # Include the collected errors so the failure is diagnosable instead
        # of surfacing as a message-less ValueError.
        details = "; ".join(str(getattr(err, "message", err)) for err in errors)
        raise ValueError(f"Python templating failed for {fname!r}: {details}")
    return output