def truncate_domain(
domain: DslPlanDomain,
truncation: Truncation,
) -> DslPlanDomain:
match truncation:
case Truncation.Filter(_):
for m in domain.margins:
m.invariant = None return domain
case Truncation.GroupBy(keys, aggs):
for agg in aggs:
check_infallible(agg, True)
def with_truncation(lf):
return lf.group_by(keys).agg(aggs)
def without_invariant(m):
m.invariant = None
return m
return FrameDomain.new_with_margins(
[ Seriesdomain.new_from_field(f)
for f in domain.simulate_schema(with_truncation)
],
margins=[
without_invariant(m.clone())
for m in domain.margins
if m.by.is_subset(HashSet.from_iter(keys))
],
)