from __future__ import annotations
import collections
import contextlib
import dataclasses
import enum
import functools
import itertools
import pathlib
import re
# Matches preprocessor include directives: `#include` or `#import`, optionally
# suffixed with `_next`, followed by a path in "quotes" or <angle brackets>.
# Group 1 is the literal `_next` suffix when present (an empty string in
# `findall` results otherwise); group 2 is the path between the delimiters.
_INCLUDES = re.compile(r'#\s*(?:include|import)(_next)?\s*["<]([^>"]+)[>"]')
class CompileStatus(enum.Enum):
    """Compilation state recorded on a header.

    `Header.compile_status` defaults to `NotCompiled`.
    """
    NotCompiled = 1
    Success = 2
    Failure = 3
class IncludeDir(enum.Enum):
    """Category of include directory a header belongs to.

    Members are ordered by their numeric value, which lets headers (and
    tuples containing an `IncludeDir`) be sorted.
    """
    LibCxx = 1
    Builtin = 2
    SysrootModule = 3
    Sysroot = 4
    Framework = 5

    def __lt__(self, other):
        """Order members by value; defer to `other` for foreign types."""
        # Returning NotImplemented (rather than touching `other.value`) lets
        # Python try the reflected operation and raise a clean TypeError.
        if not isinstance(other, IncludeDir):
            return NotImplemented
        return self.value < other.value
@dataclasses.dataclass
class Header:
    """A single header file and its position in the include graph.

    Hashing, equality and ordering are defined solely by the
    ``(include_dir, rel)`` pair; all other fields are mutable bookkeeping.
    """
    include_dir: IncludeDir
    # Relative path of the header; compared against the path written in
    # include directives.
    rel: str
    # Location on disk. Must be set before `content` is accessed.
    abs: pathlib.Path | None = None
    # Links to neighbouring headers. `next` is what an `#include_next` of this
    # header's own name resolves to, and is followed by `all_headers`.
    prev: None | Header = None
    next: None | Header = None
    # Name of the module owning this header, or None when not yet assigned.
    root_module: None | str = None
    textual: bool = False
    umbrella: bool = False
    compile_status: CompileStatus = CompileStatus.NotCompiled
    # presumably every header reachable from this one (transitively) -- the
    # code only ever tests membership in it; TODO confirm against the producer.
    deps: list[Header] = dataclasses.field(default_factory=list)
    # Headers this one includes directly.
    direct_deps: set[Header] = dataclasses.field(default_factory=set)
    exports: None | list[str] = dataclasses.field(default_factory=list)
    # Extra keyword arguments, each mapping to a list of values.
    kwargs: dict[str, list[str]] = dataclasses.field(
        default_factory=lambda: collections.defaultdict(list))

    def __hash__(self):
        return hash((self.include_dir, self.rel))

    def __eq__(self, other):
        # A header also compares equal to a bare ``(include_dir, rel)`` tuple.
        if isinstance(other, Header):
            return (self.include_dir, self.rel) == (other.include_dir,
                                                    other.rel)
        return (self.include_dir, self.rel) == other

    def __lt__(self, other):
        return (self.include_dir, self.rel) < (other.include_dir, other.rel)

    @property
    def pretty_name(self):
        """Human-readable ``<IncludeDir name>/<rel>`` identifier."""
        return f'{self.include_dir.name}/{self.rel}'

    def __repr__(self):
        return self.pretty_name

    @property
    def submodule_name(self):
        """`rel` with every ``.``, ``/`` and ``-`` replaced by ``_``."""
        return self.rel.replace('.', '_').replace('/', '_').replace('-', '_')

    @functools.cached_property
    def content(self) -> str:
        """Text of the file at `abs`, read once; undecodable bytes are dropped."""
        return self.abs.read_text(errors='ignore')

    def calculate_direct_deps(self, includes: dict[str, Header],
                              sysroot: pathlib.Path) -> set[Header]:
        """Return the headers named by include directives in this file.

        Args:
            includes: Maps an include path to the header it resolves to.
            sysroot: Directory used to resolve includes that are symlinks.

        Returns:
            The resolved headers that are also present in `self.deps`.
            `self.direct_deps` is not modified.
        """
        direct = set()

        def find_include(is_next, include) -> bool:
            """Resolve one directive; return True if a dep was recorded."""
            header = None
            first = includes.get(include)
            if first is not None:
                if not is_next or self.rel != include:
                    header = first
                elif self.next is not None:
                    # `#include_next` of our own name: take the next header
                    # in the chain rather than ourselves.
                    header = self.next
            if header is not None and header in self.deps:
                direct.add(header)
                return True
            return False

        for is_next, include in _INCLUDES.findall(self.content):
            if not find_include(is_next, include):
                # The include may be a symlink inside the sysroot; retry with
                # the link target. readlink() raises OSError (of which
                # FileNotFoundError is a subclass) when that is not the case.
                with contextlib.suppress(OSError):
                    find_include(is_next, str((sysroot / include).readlink()))
        return direct

    # NOTE(review): the cache is keyed on `self` (i.e. on include_dir/rel) and
    # the result is frozen on first call; later changes to `direct_deps` are
    # not reflected.
    @functools.cache
    def _required_deps(self) -> tuple[set[Header], set[Header]]:
        """Return ``(nontextual, textual)`` deps needed by this header.

        Direct deps are walked; textual deps are expanded recursively (their
        own direct deps are needed too), non-textual deps are not.
        """
        nontextual = set()
        textual = set()
        todo = [self]
        while todo:
            hdr = todo.pop()
            for dep in hdr.direct_deps:
                if dep.textual and dep not in textual:
                    todo.append(dep)
                    textual.add(dep)
                elif not dep.textual:
                    nontextual.add(dep)
        return nontextual, textual

    @property
    def required_deps(self) -> set[Header]:
        """Non-textual headers required by this header."""
        return self._required_deps()[0]

    @property
    def required_textual_deps(self) -> set[Header]:
        """Textual headers required (transitively via textual headers)."""
        return self._required_deps()[1]

    def find_loop(self) -> list[Header] | None:
        """Find an include cycle that passes through this header.

        Only dependencies whose `deps` contain this header are followed.

        Returns:
            A list starting and ending with `self` in which each header
            directly depends on the one after it, or None if no such cycle
            can be found.
        """
        chain = [self]
        visited = {self}
        while chain:
            tail = chain[-1]
            if self in tail.direct_deps:
                return chain + [self]
            for dep in tail.direct_deps:
                if dep not in visited and self in dep.deps:
                    visited.add(dep)
                    chain.append(dep)
                    break
            else:
                # Dead end: every promising dep of `tail` was already explored.
                # Backtrack and try the remaining deps of its predecessor
                # instead of giving up on the whole search.
                chain.pop()
        return None
def calculate_rdeps(headers: list[Header]) -> dict[Header, list[Header]]:
    """Invert the `deps` relation of `headers`.

    Returns:
        A `defaultdict(list)` mapping each dependency to the headers that list
        it in their `deps`, in the order those headers appear in `headers`.
    """
    reverse_map = collections.defaultdict(list)
    edges = ((dependency, dependent)
             for dependent in headers
             for dependency in dependent.deps)
    for dependency, dependent in edges:
        reverse_map[dependency].append(dependent)
    return reverse_map
def all_headers(graph: dict[str, Header]):
    """Yield every header reachable from `graph`.

    For each value of `graph`, in the mapping's iteration order, the header
    itself is yielded first, followed by each successor reached through its
    `next` link until a link is None.
    """

    def walk(node):
        # Follow the `next` chain starting at `node`.
        while node is not None:
            yield node
            node = node.next

    for head in graph.values():
        yield from walk(head)
@dataclasses.dataclass
class Target:
    """A named build target and the headers it contains.

    Equality, hashing and ordering are based on `name` only.
    """
    include_dir: IncludeDir
    name: str
    headers: list[Header] = dataclasses.field(default_factory=list)

    def __lt__(self, other):
        # Defer to the other operand instead of failing on a missing `.name`.
        if not isinstance(other, Target):
            return NotImplemented
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Target):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    @property
    def kwargs(self) -> dict[str, set[str]]:
        """Merged `kwargs` of every grouped header and its textual deps.

        Reads each header's `group` attribute (assigned by `run_build`), and
        for every member unions the `kwargs` values of the member itself and
        of its `required_textual_deps`.
        """
        merged = collections.defaultdict(set)
        for header in self.headers:
            for single in header.group:
                for dep in {single} | single.required_textual_deps:
                    for key, values in dep.kwargs.items():
                        merged[key].update(values)
        return merged

    @property
    def header_deps(self) -> set[Header]:
        """Union of `required_deps` over all headers of this target."""
        direct_deps = set()
        for hdr in self.headers:
            direct_deps.update(hdr.required_deps)
        return direct_deps

    @property
    def public_deps(self) -> list[str]:
        """Sorted names of the other modules this target's headers require."""
        return sorted({
            hdr.root_module
            for hdr in self.header_deps
            if hdr.root_module is not None and hdr.root_module != self.name
        })
def run_build(graph: dict[str, Header]) -> list[Target]:
    """Order the headers of `graph` into a list of build targets.

    Non-textual headers that already have a `root_module` are grouped into one
    target per module. Non-textual headers without one are placed into
    generated `sys_stage<N>` targets. A target is emitted only once none of
    its headers has a dependency that is still unbuilt.

    Side effects: every header gets `rdeps`, `group`, `mod_deps` and
    `unbuilt_deps` attributes, and headers placed in a stage get their
    `root_module` set to that stage's name.

    Returns:
        The targets in the order they were emitted.

    If headers remain that cannot be scheduled, the detected loops are
    printed, `breakpoint()` is invoked and the process exits with status 1.
    """
    # Non-textual headers that already name a module, keyed by module name.
    unbuilt_modules: dict[str, list[Header]] = collections.defaultdict(list)
    # Non-textual headers without a module; these end up in sys_stage targets.
    unbuilt_headers: set[Header] = set()
    for header in all_headers(graph):
        if not header.textual:
            if header.root_module is None:
                unbuilt_headers.add(header)
            else:
                unbuilt_modules[header.root_module].append(header)
        # Scratch attributes are set on every header, textual or not, because
        # the rdeps pass below iterates over all of them.
        header.rdeps = set()
        header.group = [header]
        header.mod_deps = set()
        # Required deps that live outside this header's own module.
        header.unbuilt_deps = set(
            dep for dep in header.required_deps
            if (header.root_module is None or header.root_module != dep.root_module)
            and dep != header)

    # Union-find over module-less headers: maps a header to its parent; a
    # header absent from the mapping is its own root.
    parents = {}

    def find(header):
        # Returns the root of `header`, compressing the path on the way.
        if header in parents:
            parents[header] = find(parents[header])
            return parents[header]
        else:
            return header

    # Merge headers that depend on each other (header requires dep, and dep's
    # deps contain header). `dep > header` handles each such pair once.
    for header in sorted(unbuilt_headers):
        for dep in header.required_deps:
            if dep > header and header in dep.deps:
                assert header.include_dir == IncludeDir.Sysroot and dep.include_dir == IncludeDir.Sysroot, (
                    header, dep)
                # The smaller root always becomes the parent.
                x, y = sorted([find(header), find(dep)])
                if x != y:
                    parents[y] = x

    # Collect the members of each merged set under its root.
    loops = collections.defaultdict(list)
    for header in unbuilt_headers:
        loops[find(header)].append(header)

    for headers in loops.values():
        if len(headers) == 1:
            continue
        headers.sort()
        # The smallest member represents the whole group: it carries the
        # group's combined external deps, and the other members are no longer
        # scheduled individually.
        headers[0].group = headers
        headers[0].unbuilt_deps = set.union(
            *[header.unbuilt_deps for header in headers]) - set(headers)
        for header in headers[1:]:
            unbuilt_headers.remove(header)

    # Reverse edges: dep.rdeps holds every header still waiting on dep.
    for header in all_headers(graph):
        for dep in header.unbuilt_deps:
            dep.rdeps.add(header)

    build_gn = []
    for i in itertools.count():
        # Emit every module whose headers have no unbuilt deps left, repeating
        # until a full pass emits nothing.
        while True:
            n_remaining = len(unbuilt_modules)
            for mod, headers in list(unbuilt_modules.items()):
                if not any(header.unbuilt_deps for header in headers):
                    build_gn.append(
                        Target(
                            include_dir=headers[0].include_dir,
                            name=mod,
                            headers=sorted(headers),
                        ))
                    del unbuilt_modules[mod]
                    # Unblock everything that was waiting on these headers.
                    for header in headers:
                        for rdep in header.rdeps:
                            rdep.mod_deps.add(header.root_module)
                            rdep.unbuilt_deps.remove(header)
            if n_remaining == len(unbuilt_modules):
                break

        # Start a new stage target and fill it with every module-less header
        # that is now buildable, again iterating to a fixed point.
        sysroot_mod = f'sys_stage{i + 1}'
        build_gn.append(Target(
            include_dir=IncludeDir.Sysroot,
            name=sysroot_mod,
        ))
        while True:
            n_remaining = len(unbuilt_headers)
            for header in list(unbuilt_headers):
                if not header.unbuilt_deps:
                    build_gn[-1].headers.append(header)
                    unbuilt_headers.remove(header)
                    # NOTE: this inner loop rebinds `header` to each member of
                    # the group (the group includes the header itself).
                    for header in header.group:
                        header.root_module = sysroot_mod
                        for rdep in header.rdeps:
                            rdep.mod_deps.add(header.root_module)
                            rdep.unbuilt_deps.remove(header)
            if n_remaining == len(unbuilt_headers):
                break

        build_gn[-1].headers.sort()
        # An empty stage means no further progress is possible.
        if not build_gn[-1].headers:
            break

    # Drop the empty stage that terminated the loop.
    build_gn.pop()
    if not unbuilt_modules and not unbuilt_headers:
        return build_gn
    else:
        print(
            "Dependency loop in sysroot. You probably want to make one of them textual."
        )
        print("The following headers are in a dependency loop:")
        # Headers already reported as part of an earlier chain are skipped.
        seen = set()
        for header in unbuilt_headers:
            if header not in seen:
                chain = header.find_loop()
                if chain is not None:
                    print(' -> '.join([header.pretty_name for header in chain]))
                    seen.update(chain)
        # NOTE(review): drops into the debugger before exiting, and `exit` is
        # the builtin installed by the `site` module -- confirm both are
        # intended for non-interactive use.
        breakpoint()
        exit(1)