from __future__ import annotations
import os
import re
import abc
import sys
import json
import email
import types
import inspect
import pathlib
import zipfile
import operator
import textwrap
import warnings
import functools
import itertools
import posixpath
import collections
from . import _meta
from ._collections import FreezableDefaultDict, Pair
from ._functools import method_cache, pass_none
from ._itertools import always_iterable, unique_everseen
from ._meta import PackageMetadata, SimplePath
from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
from typing import Any, Iterable, List, Mapping, Match, Optional, Set, cast
# The module's public API: the names exported by ``from <module> import *``.
__all__ = [
    'Distribution',
    'DistributionFinder',
    'PackageMetadata',
    'PackageNotFoundError',
    'distribution',
    'distributions',
    'entry_points',
    'files',
    'metadata',
    'packages_distributions',
    'requires',
    'version',
]
class PackageNotFoundError(ModuleNotFoundError):
    """Raised when no distribution metadata exists for a requested name."""

    @property
    def name(self) -> str:  # type: ignore[override]
        """The distribution name passed as the sole constructor argument."""
        # Exactly one positional argument is expected; anything else raises.
        [only] = self.args
        return only

    def __str__(self) -> str:
        return f"No package metadata was found for {self.name}"
class Sectioned:
    """
    A minimal, fast parser for INI-like text such as ``entry_points.txt``.

    Lines are stripped; a line of the form ``[name]`` opens a new section and
    every other line is yielded as a ``Pair`` of (current section name, line).
    Lines before the first section header carry a section name of ``None``.
    """

    _sample = textwrap.dedent(
        """
        [sec1]
        # comments ignored
        a = 1
        b = 2
        [sec2]
        a = 2
        """
    ).lstrip()

    @classmethod
    def section_pairs(cls, text):
        """Yield sectioned ``key = value`` lines with the value parsed by ``Pair.parse``.

        Blank lines, ``#`` comments and lines outside any section are skipped.
        """
        pairs = cls.read(text, filter_=cls.valid)
        return (
            pair._replace(value=Pair.parse(pair.value))
            for pair in pairs
            if pair.name is not None
        )

    @staticmethod
    def read(text, filter_=None):
        """Yield ``Pair(section, line)`` for each non-header line of ``text``.

        :param filter_: predicate applied to each stripped line; ``None``
            drops falsy (empty) lines, as the builtin ``filter`` does.
        """
        current = None
        for line in filter(filter_, map(str.strip, text.splitlines())):
            if line.startswith('[') and line.endswith(']'):
                current = line.strip('[]')
            else:
                yield Pair(current, line)

    @staticmethod
    def valid(line: str):
        """Return a truthy value for non-empty lines that are not ``#`` comments."""
        if not line:
            return line
        return not line.startswith('#')
class EntryPoint:
    """
    An entry point, e.g. one line of a section in ``entry_points.txt``.

    ``value`` has the form ``module[:attr] [extras]`` and is split into those
    parts by :attr:`pattern`. Instances are immutable: ``__setattr__`` always
    raises, and the constructor writes directly into the instance ``__dict__``.
    """

    pattern = re.compile(
        r'(?P<module>[\w.]+)\s*'
        r'(:\s*(?P<attr>[\w.]+)\s*)?'
        r'((?P<extras>\[.*\])\s*)?$'
    )

    name: str
    value: str
    group: str

    # Set by ``_for`` when the entry point is bound to a distribution.
    dist: Optional[Distribution] = None

    def __init__(self, name: str, value: str, group: str) -> None:
        # Bypass the immutability guard in __setattr__.
        vars(self).update(name=name, value=value, group=group)

    def load(self) -> Any:
        """Import the module named by ``value`` and return the designated object.

        Dotted ``attr`` components are resolved with successive ``getattr``;
        with no ``attr`` the module itself is returned.
        """
        match = cast(Match, self.pattern.match(self.value))
        module = import_module(match.group('module'))
        attrs = filter(None, (match.group('attr') or '').split('.'))
        return functools.reduce(getattr, attrs, module)

    @property
    def module(self) -> str:
        """The module part of ``value``."""
        match = self.pattern.match(self.value)
        assert match is not None
        return match.group('module')

    @property
    def attr(self) -> str:
        """The ``attr`` part of ``value`` (``None`` when absent)."""
        match = self.pattern.match(self.value)
        assert match is not None
        return match.group('attr')

    @property
    def extras(self) -> List[str]:
        """The bracketed extras of ``value`` as a list of words."""
        match = self.pattern.match(self.value)
        assert match is not None
        return re.findall(r'\w+', match.group('extras') or '')

    def _for(self, dist):
        """Bind this entry point to ``dist`` in place and return self."""
        vars(self).update(dist=dist)
        return self

    def matches(self, **params):
        """Return True if every given attribute equals the given value."""
        attrs = (getattr(self, param) for param in params)
        return all(map(operator.eq, params.values(), attrs))

    def _key(self):
        return self.name, self.value, self.group

    def __lt__(self, other):
        # Defer to Python's fallback (TypeError) for unrelated types instead of
        # failing with AttributeError on ``other._key``.
        if not isinstance(other, EntryPoint):
            return NotImplemented
        return self._key() < other._key()

    def __eq__(self, other):
        # Unrelated types are simply unequal (e.g. ``ep == None`` is False).
        if not isinstance(other, EntryPoint):
            return NotImplemented
        return self._key() == other._key()

    def __setattr__(self, name, value):
        raise AttributeError("EntryPoint objects are immutable.")

    def __repr__(self):
        return (
            f'EntryPoint(name={self.name!r}, value={self.value!r}, '
            f'group={self.group!r})'
        )

    def __hash__(self) -> int:
        return hash(self._key())
class EntryPoints(tuple):
    """
    An immutable collection of selectable EntryPoint objects.

    Indexing is by entry-point *name*, not by integer position.
    """

    __slots__ = ()

    def __getitem__(self, name: str) -> EntryPoint:
        """Return the first entry point whose ``name`` equals ``name``.

        :raises KeyError: When no entry point has that name.
        """
        try:
            return next(iter(self.select(name=name)))
        except StopIteration:
            # The exhausted iterator is an implementation detail; don't chain it.
            raise KeyError(name) from None

    def __repr__(self):
        """Show the class name so the name-based indexing is not mistaken for a plain tuple."""
        return '%s(%r)' % (self.__class__.__name__, tuple(self))

    def select(self, **params) -> EntryPoints:
        """Return the entry points whose attributes match all ``params``."""
        return EntryPoints(ep for ep in self if ep.matches(**params))

    @property
    def names(self) -> Set[str]:
        """The set of names of all entry points."""
        return {ep.name for ep in self}

    @property
    def groups(self) -> Set[str]:
        """The set of groups of all entry points."""
        return {ep.group for ep in self}

    @classmethod
    def _from_text_for(cls, text, dist):
        """Parse ``text`` and bind every resulting entry point to ``dist``."""
        return cls(ep._for(dist) for ep in cls._from_text(text))

    @staticmethod
    def _from_text(text):
        """Lazily yield EntryPoint objects parsed from ``entry_points.txt`` text (``None`` is treated as empty)."""
        return (
            EntryPoint(name=item.value.name, value=item.value.value, group=item.name)
            for item in Sectioned.section_pairs(text or '')
        )
class PackagePath(pathlib.PurePosixPath):
    """A path listed in a distribution's file manifest.

    ``hash``, ``size`` and ``dist`` are assigned after construction by the
    code that builds the file list.
    """

    hash: Optional[FileHash]
    size: int
    dist: Distribution

    def locate(self) -> SimplePath:
        """Resolve this relative path through the owning distribution."""
        return self.dist.locate_file(self)

    def read_text(self, encoding: str = 'utf-8') -> str:
        """Return the located file's contents decoded with ``encoding``."""
        located = self.locate()
        return located.read_text(encoding=encoding)

    def read_binary(self) -> bytes:
        """Return the located file's raw contents."""
        located = self.locate()
        return located.read_bytes()
class FileHash:
    """A ``mode=value`` hash specification, as found in a RECORD entry."""

    def __init__(self, spec: str) -> None:
        # Split at the first '='; with no '=' the whole spec becomes the mode.
        mode, _sep, value = spec.partition('=')
        self.mode = mode
        self.value = value

    def __repr__(self) -> str:
        return f'<FileHash mode: {self.mode} value: {self.value}>'
class DeprecatedNonAbstract:
    """Mixin that warns, rather than fails, when a class that still has
    abstract methods is instantiated."""

    def __new__(cls, *args, **kwargs):
        # Every attribute name defined anywhere along the MRO.
        every_name = {
            attr for klass in inspect.getmro(cls) for attr in vars(klass)
        }

        def _is_abstract(attr):
            return getattr(getattr(cls, attr), '__isabstractmethod__', False)

        unimplemented = {attr for attr in every_name if _is_abstract(attr)}
        if unimplemented:
            warnings.warn(
                f"Unimplemented abstract methods {unimplemented}",
                DeprecationWarning,
                stacklevel=2,
            )
        # Constructor arguments are left for __init__; object.__new__ gets none.
        return super().__new__(cls)
class Distribution(DeprecatedNonAbstract):
    """
    An abstract Python distribution package.

    Concrete subclasses supply :meth:`read_text` and :meth:`locate_file`;
    the other properties here (metadata, name, version, entry points, files,
    requires, origin) are all derived from those two methods.
    """

    @abc.abstractmethod
    def read_text(self, filename) -> Optional[str]:
        """
        Attempt to load the metadata file named ``filename``.

        :param filename: The name of the file in the distribution info.
        :return: The text if found, otherwise None.
        """

    @abc.abstractmethod
    def locate_file(self, path: str | os.PathLike[str]) -> SimplePath:
        """
        Given a path to a file in this distribution, return a SimplePath
        to it.
        """

    @classmethod
    def from_name(cls, name: str) -> Distribution:
        """
        Return the first Distribution discovered for ``name``.

        :param name: The name of the distribution package to search for.
        :return: The Distribution instance (or subclass thereof).
        :raises ValueError: When ``name`` is empty.
        :raises PackageNotFoundError: When no distribution was found.
        """
        if not name:
            raise ValueError("A distribution name is required.")
        try:
            return next(iter(cls.discover(name=name)))
        except StopIteration:
            # The exhausted iterator is an implementation detail; don't chain it.
            raise PackageNotFoundError(name) from None

    @classmethod
    def discover(
        cls, *, context: Optional[DistributionFinder.Context] = None, **kwargs
    ) -> Iterable[Distribution]:
        """
        Return an iterable of Distribution objects from every resolver.

        Pass a ``context`` or keyword arguments from which one is built,
        but not both.

        :raises ValueError: If both ``context`` and keyword arguments are given.
        """
        if context and kwargs:
            raise ValueError("cannot accept context and kwargs")
        context = context or DistributionFinder.Context(**kwargs)
        return itertools.chain.from_iterable(
            resolver(context) for resolver in cls._discover_resolvers()
        )

    @staticmethod
    def at(path: str | os.PathLike[str]) -> Distribution:
        """Return a Distribution for the metadata directory at ``path``."""
        return PathDistribution(pathlib.Path(path))

    @staticmethod
    def _discover_resolvers():
        """Return the ``find_distributions`` callables of the finders on
        ``sys.meta_path`` that define one."""
        declared = (
            getattr(finder, 'find_distributions', None) for finder in sys.meta_path
        )
        return filter(None, declared)

    @property
    def metadata(self) -> _meta.PackageMetadata:
        """
        Return the parsed metadata for this distribution.

        ``METADATA`` (dist-info) is tried first, then ``PKG-INFO``
        (egg-info), and finally ``read_text('')``.
        """
        from . import _adapters

        opt_text = (
            self.read_text('METADATA')
            or self.read_text('PKG-INFO')
            # For PathDistribution, '' reads the metadata path itself, which
            # supports egg-info *files* rather than directories.
            or self.read_text('')
        )
        # NOTE(review): if all three reads return None, None is passed through
        # to email.message_from_string; the cast only satisfies type checkers.
        text = cast(str, opt_text)
        return _adapters.Message(email.message_from_string(text))

    @property
    def name(self) -> str:
        """Return the 'Name' metadata for the distribution package."""
        return self.metadata['Name']

    @property
    def _normalized_name(self):
        """Return ``name`` normalized via ``Prepared.normalize``."""
        return Prepared.normalize(self.name)

    @property
    def version(self) -> str:
        """Return the 'Version' metadata for the distribution package."""
        return self.metadata['Version']

    @property
    def entry_points(self) -> EntryPoints:
        """Return the entry points from ``entry_points.txt``, bound to this distribution."""
        return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self)

    @property
    def files(self) -> Optional[List[PackagePath]]:
        """
        Files in this distribution.

        The listing comes from ``RECORD`` (dist-info), else
        ``installed-files.txt``, else ``SOURCES.txt`` (egg-info). Entries
        whose located file does not exist are dropped.

        :return: List of PackagePath, or (presumably, via ``pass_none``)
            None when no listing could be read.
        """

        def make_file(name, hash_=None, size_str=None):
            # Fields arrive positionally from csv rows: path, hash, size.
            result = PackagePath(name)
            result.hash = FileHash(hash_) if hash_ else None
            result.size = int(size_str) if size_str else None
            result.dist = self
            return result

        @pass_none
        def make_files(lines):
            # csv is imported lazily, only when a file listing is parsed.
            import csv

            return starmap(make_file, csv.reader(lines))

        @pass_none
        def skip_missing_files(package_paths):
            return list(filter(lambda path: path.locate().exists(), package_paths))

        return skip_missing_files(
            make_files(
                self._read_files_distinfo()
                or self._read_files_egginfo_installed()
                or self._read_files_egginfo_sources()
            )
        )

    def _read_files_distinfo(self):
        """Return the lines of ``RECORD`` if it has content."""
        text = self.read_text('RECORD')
        return text and text.splitlines()

    def _read_files_egginfo_installed(self):
        """
        Return the entries of ``installed-files.txt`` as quoted POSIX paths
        relative to ``self.locate_file('')``.

        Only applies when the distribution has a ``_path`` attribute.
        """
        text = self.read_text('installed-files.txt')
        # Entries are relative to the metadata directory itself.
        subdir = getattr(self, '_path', None)
        if not text or not subdir:
            return
        paths = (
            (subdir / name)
            .resolve()
            # walk_up=True (Python 3.12+) allows paths outside the base.
            .relative_to(self.locate_file('').resolve(), walk_up=True)
            .as_posix()
            for name in text.splitlines()
        )
        # Quote each path so csv.reader in ``files`` sees a single field.
        return map('"{}"'.format, paths)

    def _read_files_egginfo_sources(self):
        """Return the entries of ``SOURCES.txt`` as quoted paths."""
        text = self.read_text('SOURCES.txt')
        return text and map('"{}"'.format, text.splitlines())

    @property
    def requires(self) -> Optional[List[str]]:
        """Requirement strings from dist-info metadata or egg-info ``requires.txt``."""
        reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
        return reqs and list(reqs)

    def _read_dist_info_reqs(self):
        return self.metadata.get_all('Requires-Dist')

    def _read_egg_info_reqs(self):
        source = self.read_text('requires.txt')
        return pass_none(self._deps_from_requires_text)(source)

    @classmethod
    def _deps_from_requires_text(cls, source):
        return cls._convert_egg_info_reqs_to_simple_reqs(Sectioned.read(source))

    @staticmethod
    def _convert_egg_info_reqs_to_simple_reqs(sections):
        """
        Turn ``requires.txt`` sections into requirement strings with markers.

        A section named ``extra:markers`` yields requirements suffixed with
        ``; (markers) and extra == "extra"``; either part may be absent.
        """

        def make_condition(name):
            return name and f'extra == "{name}"'

        def quoted_marker(section):
            section = section or ''
            extra, sep, markers = section.partition(':')
            if extra and markers:
                markers = f'({markers})'
            conditions = list(filter(None, [markers, make_condition(extra)]))
            return '; ' + ' and '.join(conditions) if conditions else ''

        def url_req_space(req):
            # A URL requirement (contains '@') needs a space before the marker.
            return ' ' * ('@' in req)

        for section in sections:
            space = url_req_space(section.value)
            yield section.value + space + quoted_marker(section.name)

    @property
    def origin(self):
        """Parsed ``direct_url.json`` as nested SimpleNamespace objects, if present."""
        return self._load_json('direct_url.json')

    def _load_json(self, filename):
        return pass_none(json.loads)(
            self.read_text(filename),
            object_hook=lambda data: types.SimpleNamespace(**data),
        )
class DistributionFinder(MetaPathFinder):
    """A MetaPathFinder that can also locate distribution metadata."""

    class Context:
        """
        Search criteria handed to ``find_distributions``.

        Arbitrary keyword arguments become instance attributes. ``name``
        defaults to None and ``path`` defaults to ``sys.path``.
        """

        name = None

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        @property
        def path(self) -> List[str]:
            """The explicit ``path`` keyword, falling back to ``sys.path``."""
            return self.__dict__.get('path', sys.path)

    @abc.abstractmethod
    def find_distributions(self, context=Context()) -> Iterable[Distribution]:
        """
        Return an iterable of the Distribution objects matching ``context``.

        NOTE(review): the default ``Context()`` is created once at class
        definition time and shared by all calls.
        """
class FastPath:
    """
    Micro-optimized view of one search-path entry (a directory or a zip file)
    for locating metadata children.
    """

    # lru_cache on __new__ makes FastPath(root) return the same instance for
    # the same (cls, root); __init__ still runs on each call. The cache is
    # cleared through ``FastPath.__new__.cache_clear()``.
    @functools.lru_cache()
    def __new__(cls, root):
        return super().__new__(cls)

    def __init__(self, root):
        self.root = root

    def joinpath(self, child):
        # Replaced per-instance by zip_children() when root is a zip file.
        return pathlib.Path(self.root, child)

    def children(self):
        """Return the names directly under root.

        Tries a directory listing, then a zip listing; any exception in
        either is deliberately swallowed and an empty list is returned.
        """
        with suppress(Exception):
            return os.listdir(self.root or '.')
        with suppress(Exception):
            return self.zip_children()
        return []

    def zip_children(self):
        """Treat root as a zip file and return its unique top-level names (as dict keys)."""
        zip_path = zipfile.Path(self.root)
        names = zip_path.root.namelist()
        # Later joinpath() calls now resolve inside the zip archive.
        self.joinpath = zip_path.joinpath

        return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names)

    def search(self, name):
        """Search this root for the Prepared ``name`` via a Lookup keyed on mtime."""
        return self.lookup(self.mtime).search(name)

    @property
    def mtime(self):
        """Modification time of root, or None (after clearing ``lookup``'s cache) if it cannot be stat'ed."""
        with suppress(OSError):
            return os.stat(self.root).st_mtime
        self.lookup.cache_clear()

    # Presumably memoized per instance and argument by method_cache, so a new
    # mtime produces a fresh Lookup — confirm against _functools.method_cache.
    @method_cache
    def lookup(self, mtime):
        return Lookup(self)
class Lookup:
    """
    Index of the metadata entries directly under a FastPath.

    ``infos`` maps normalized names to ``*.dist-info`` / ``*.egg-info``
    children; ``eggs`` maps legacy-normalized names to an ``EGG-INFO``
    child when the root itself is a ``*.egg``. Both are frozen once built.
    """

    def __init__(self, path: FastPath):
        root_base = os.path.basename(path.root).lower()
        in_egg = root_base.endswith(".egg")
        self.infos = FreezableDefaultDict(list)
        self.eggs = FreezableDefaultDict(list)

        for entry in path.children():
            lowered = entry.lower()
            if lowered.endswith((".dist-info", ".egg-info")):
                # Drop the suffix, then everything from the first '-' (version).
                stem = lowered.rpartition(".")[0].partition("-")[0]
                self.infos[Prepared.normalize(stem)].append(path.joinpath(entry))
            elif in_egg and lowered == "egg-info":
                # The distribution name comes from the enclosing *.egg root.
                stem = root_base.rpartition(".")[0].partition("-")[0]
                self.eggs[Prepared.legacy_normalize(stem)].append(
                    path.joinpath(entry)
                )

        self.infos.freeze()
        self.eggs.freeze()

    def search(self, prepared: Prepared):
        """Return matching metadata paths; a falsy ``prepared`` matches everything."""
        if prepared:
            infos = self.infos[prepared.normalized]
            eggs = self.eggs[prepared.legacy_normalized]
        else:
            infos = itertools.chain.from_iterable(self.infos.values())
            eggs = itertools.chain.from_iterable(self.eggs.values())
        return itertools.chain(infos, eggs)
class Prepared:
    """A distribution name pre-normalized for matching against metadata directories."""

    normalized = None
    legacy_normalized = None

    def __init__(self, name: Optional[str]):
        self.name = name
        if name is not None:
            self.normalized = self.normalize(name)
            self.legacy_normalized = self.legacy_normalize(name)

    @staticmethod
    def normalize(name):
        """Collapse runs of '-', '_' and '.' into one '_' and lowercase the result."""
        return re.sub(r"[-_.]+", "_", name).lower()

    @staticmethod
    def legacy_normalize(name):
        """Lowercase and turn '-' into '_', leaving '.' and '_' as they are."""
        return name.lower().replace('-', '_')

    def __bool__(self):
        return bool(self.name)
class MetadataPathFinder(DistributionFinder):
    """A finder that locates distribution metadata on the entries of a search path."""

    @classmethod
    def find_distributions(
        cls, context=DistributionFinder.Context()
    ) -> Iterable[PathDistribution]:
        """
        Lazily yield a PathDistribution for every metadata entry matching
        ``context.name`` (all entries when the name is falsy) across the
        directories in ``context.path``.
        """
        return map(PathDistribution, cls._search_paths(context.name, context.path))

    @classmethod
    def _search_paths(cls, name, paths):
        """Chain the per-entry searches for ``name`` over ``paths``."""
        prepared = Prepared(name)
        per_entry = (FastPath(entry).search(prepared) for entry in paths)
        return itertools.chain.from_iterable(per_entry)

    @classmethod
    def invalidate_caches(cls) -> None:
        """Drop the cached FastPath instances."""
        FastPath.__new__.cache_clear()
class PathDistribution(Distribution):
    """A Distribution whose metadata lives under a path (directory or zip path)."""

    def __init__(self, path: SimplePath) -> None:
        """
        :param path: SimplePath pointing at the metadata directory.
        """
        self._path = path

    def read_text(self, filename: str | os.PathLike[str]) -> Optional[str]:
        # Failures treated as "file not available" (KeyError presumably
        # comes from zip-backed paths).
        absent = (
            FileNotFoundError,
            IsADirectoryError,
            KeyError,
            NotADirectoryError,
            PermissionError,
        )
        with suppress(*absent):
            target = self._path.joinpath(filename)
            return target.read_text(encoding='utf-8')
        return None

    read_text.__doc__ = Distribution.read_text.__doc__

    def locate_file(self, path: str | os.PathLike[str]) -> SimplePath:
        # Files are resolved relative to the directory containing the metadata.
        parent = self._path.parent
        return parent / path

    @property
    def _normalized_name(self):
        """
        Normalized name taken from the metadata directory's name when it
        looks like ``name-version.dist-info`` / ``.egg-info``, otherwise the
        generic metadata-based name.
        """
        stem = os.path.basename(str(self._path))
        from_stem = self._name_from_stem(stem)
        normalized = None if from_stem is None else Prepared.normalize(from_stem)
        return normalized or super()._normalized_name

    @staticmethod
    def _name_from_stem(stem):
        """Return the text before the first '-' of a ``.dist-info`` or
        ``.egg-info`` stem, or None for any other extension."""
        base, suffix = os.path.splitext(stem)
        if suffix in ('.dist-info', '.egg-info'):
            return base.partition('-')[0]
        return None
def distribution(distribution_name: str) -> Distribution:
    """Return the Distribution for ``distribution_name``.

    :raises PackageNotFoundError: When no such distribution is installed.
    """
    found = Distribution.from_name(distribution_name)
    return found
def distributions(**kwargs) -> Iterable[Distribution]:
    """Return all discovered Distribution objects; ``kwargs`` build the search Context."""
    discovered = Distribution.discover(**kwargs)
    return discovered
def metadata(distribution_name: str) -> _meta.PackageMetadata:
    """Return the parsed metadata for the named distribution package."""
    dist = Distribution.from_name(distribution_name)
    return dist.metadata
def version(distribution_name: str) -> str:
    """Return the 'Version' metadata of the named distribution package."""
    dist = distribution(distribution_name)
    return dist.version
# Deduplicate distributions by their ``_normalized_name`` (presumably keeping
# the first occurrence, per unique_everseen — confirm in _itertools).
_unique = functools.partial(unique_everseen, key=operator.attrgetter('_normalized_name'))
def entry_points(**params) -> EntryPoints:
    """Return entry points from all distinct installed distributions,
    filtered by ``params`` (e.g. ``group=``, ``name=``)."""
    distinct = _unique(distributions())
    collected = itertools.chain.from_iterable(dist.entry_points for dist in distinct)
    return EntryPoints(collected).select(**params)
def files(distribution_name: str) -> Optional[List[PackagePath]]:
    """Return the file listing of the named distribution package (see ``Distribution.files``)."""
    dist = distribution(distribution_name)
    return dist.files
def requires(distribution_name: str) -> Optional[List[str]]:
    """Return the requirement strings of the named distribution package."""
    dist = distribution(distribution_name)
    return dist.requires
def packages_distributions() -> Mapping[str, List[str]]:
    """Map each top-level importable package to the names of the distributions providing it.

    Declared ``top_level.txt`` entries win; otherwise names are inferred
    from the distribution's file listing.
    """
    providers = collections.defaultdict(list)
    for dist in distributions():
        packages = _top_level_declared(dist) or _top_level_inferred(dist)
        for package in packages:
            # Metadata is read per package, only for dists that provide any.
            providers[package].append(dist.metadata['Name'])
    return dict(providers)
def _top_level_declared(dist):
return (dist.read_text('top_level.txt') or '').split()
def _topmost(name: PackagePath) -> Optional[str]:
top, *rest = name.parts
return top if rest else None
def _get_toplevel_name(name: PackagePath) -> str:
    """Infer a possibly importable top-level name from a listed file path.

    Uses the first directory component when there is one, otherwise the
    module name per ``inspect.getmodulename``, otherwise the path string.
    """
    top = _topmost(name)
    if top:
        return top
    return inspect.getmodulename(name) or str(name)
def _top_level_inferred(dist):
    """Lazily yield the distinct inferred top-level names of ``dist``'s files
    that contain no '.' (i.e. that look importable)."""
    candidates = {_get_toplevel_name(path) for path in always_iterable(dist.files)}
    return filter(lambda candidate: '.' not in candidate, candidates)