from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class QueryEngine(Protocol):
    """Structural interface for engines that evaluate string queries.

    Any object exposing a compatible ``execute`` method satisfies this
    protocol; because it is ``@runtime_checkable``, ``isinstance`` checks
    against it test only for the presence of ``execute``.
    """

    def execute(self, store: Any, query: str) -> list[dict[str, Any]]:
        """Evaluate *query* against *store* and return the matching rows."""
class Query:
    """Fluent query builder over a duck-typed graph store.

    A ``Query`` holds a *working set*: a list of item dicts produced by a
    source step (``nodes``/``edges``), narrowed or replaced by later steps
    (``where``, ``where_fn``, ``follow``, ``limit``), and read by terminal
    methods (``collect``, ``collect_ids``, ``count``, ``first``, ``len``,
    iteration). Every fluent step updates this instance and returns ``self``.

    The store must provide ``all_nodes``, ``query_nodes_by_type``,
    ``all_edges``, ``outgoing_edges``, ``incoming_edges`` and ``get_node``.
    Items are accessed with ``.get``, so they are expected to be dicts;
    the keys read here are ``node_id``, ``edge_id``, ``subtype``,
    ``edge_type``, ``properties``, ``source_id`` and ``target_id``.

    Args:
        store: Graph store backing the query.
        engine: Optional engine used by :meth:`raw` for string queries.
    """

    # Accepted values for follow(direction=...).
    _DIRECTIONS = ("out", "in", "both")

    def __init__(self, store: Any, engine: QueryEngine | None = None):
        self._store = store
        self._engine = engine
        # None means "no source step called yet", distinct from an empty result.
        self._working_set: list[dict[str, Any]] | None = None
        # NOTE(review): not read or written anywhere else in this class.
        self._steps: list[tuple[str, Any]] = []

    def raw(self, query_str: str) -> list[dict[str, Any]]:
        """Run *query_str* through the registered engine.

        Raises:
            RuntimeError: If no engine was passed to the constructor.
        """
        if self._engine is None:
            raise RuntimeError(
                "No query engine registered. Pass engine= to Query(), "
                "or use the fluent API (.nodes(), .where(), .follow())."
            )
        return self._engine.execute(self._store, query_str)

    def nodes(self, node_type: str | None = None, subtype: str | None = None) -> Query:
        """Start the working set from nodes, optionally filtered.

        Args:
            node_type: If given, fetch via ``store.query_nodes_by_type``;
                otherwise fetch ``store.all_nodes()``.
            subtype: If given, keep only nodes whose ``subtype`` equals it.
        """
        if node_type is not None:
            source = self._store.query_nodes_by_type(node_type)
        else:
            source = self._store.all_nodes()
        # Always materialize a fresh list: later steps slice/len it, and it
        # must not alias a container owned by the store.
        if subtype is not None:
            self._working_set = [n for n in source if n.get("subtype") == subtype]
        else:
            self._working_set = list(source)
        return self

    def edges(self, edge_type: str | None = None) -> Query:
        """Start the working set from edges, optionally filtered by ``edge_type``."""
        source = self._store.all_edges()
        if edge_type is not None:
            self._working_set = [e for e in source if e.get("edge_type") == edge_type]
        else:
            self._working_set = list(source)
        return self

    def where(self, **kwargs: Any) -> Query:
        """Keep items whose ``properties`` match every ``key=value`` given.

        Matching uses ``properties.get(key) == value``, so ``key=None`` also
        matches items that lack ``key``. Items whose ``properties`` is
        missing or ``None`` are treated as having no properties.

        Raises:
            RuntimeError: If no source step has been called.
        """
        if self._working_set is None:
            raise RuntimeError("Call .nodes() or .edges() before .where()")

        def matches(item: dict[str, Any]) -> bool:
            props = item.get("properties") or {}
            return all(props.get(k) == v for k, v in kwargs.items())

        self._working_set = [item for item in self._working_set if matches(item)]
        return self

    def where_fn(self, predicate: Any) -> Query:
        """Keep items for which ``predicate(item)`` is truthy.

        Raises:
            RuntimeError: If no source step has been called.
        """
        if self._working_set is None:
            raise RuntimeError("Call .nodes() or .edges() before .where_fn()")
        self._working_set = [item for item in self._working_set if predicate(item)]
        return self

    def follow(self, edge_type: str | None = None, direction: str = "out") -> Query:
        """Replace the working set with nodes adjacent to the current nodes.

        Args:
            edge_type: If given, traverse only edges whose ``edge_type`` equals it.
            direction: ``"out"`` follows outgoing edges to their ``target_id``;
                ``"in"`` follows incoming edges to their ``source_id``;
                ``"both"`` does both.

        Returns:
            ``self``. The new working set holds each reachable, existing node
            once, in first-seen order. Current items without a ``node_id``
            (e.g. edges) are skipped.

        Raises:
            RuntimeError: If no source step has been called.
            ValueError: If *direction* is not ``"out"``, ``"in"`` or ``"both"``.
        """
        if self._working_set is None:
            raise RuntimeError("Call .nodes() before .follow()")
        if direction not in self._DIRECTIONS:
            raise ValueError(
                f"direction must be one of {self._DIRECTIONS!r}, got {direction!r}"
            )
        targets: list[dict[str, Any]] = []
        seen: set[Any] = set()
        for node in self._working_set:
            node_id = node.get("node_id")
            if node_id is None:
                continue
            for edge, target_id in self._neighbor_edges(node_id, direction):
                if edge_type is not None and edge.get("edge_type") != edge_type:
                    continue
                if not target_id or target_id in seen:
                    continue
                # Mark before the lookup so a dangling id is fetched only once.
                seen.add(target_id)
                target = self._store.get_node(target_id)
                if target is not None:
                    targets.append(target)
        self._working_set = targets
        return self

    def _neighbor_edges(self, node_id: Any, direction: str) -> list[tuple[dict[str, Any], Any]]:
        """Return ``(edge, far-end id)`` pairs touching *node_id* in *direction*."""
        pairs: list[tuple[dict[str, Any], Any]] = []
        if direction in ("out", "both"):
            pairs.extend(
                (e, e.get("target_id", "")) for e in self._store.outgoing_edges(node_id)
            )
        if direction in ("in", "both"):
            pairs.extend(
                (e, e.get("source_id", "")) for e in self._store.incoming_edges(node_id)
            )
        return pairs

    def limit(self, n: int) -> Query:
        """Truncate the working set to at most *n* items (no-op before a source step).

        Raises:
            ValueError: If *n* is negative (a negative slice would drop
                items from the end instead of limiting).
        """
        if n < 0:
            raise ValueError(f"limit must be non-negative, got {n}")
        if self._working_set is not None:
            self._working_set = self._working_set[:n]
        return self

    def collect(self) -> list[dict[str, Any]]:
        """Return a new list of the working-set items (empty before a source step)."""
        return [] if self._working_set is None else list(self._working_set)

    def collect_ids(self) -> list[str]:
        """Return each item's ``node_id``, else ``edge_id``, else ``""``."""
        if self._working_set is None:
            return []
        return [
            item.get("node_id") or item.get("edge_id") or ""
            for item in self._working_set
        ]

    def count(self) -> int:
        """Return the number of items in the working set (0 before a source step)."""
        return len(self._working_set) if self._working_set is not None else 0

    def first(self) -> dict[str, Any] | None:
        """Return the first working-set item, or ``None`` if there is none."""
        return self._working_set[0] if self._working_set else None

    def __len__(self) -> int:
        return self.count()

    def __iter__(self):
        return iter(self._working_set or [])