Skip to main content

gam_terms/analytic_penalties/
manifest.rs

1//! Analytic penalty registry manifests.
2//!
3//! Add a primitive by implementing [`PenaltyManifest`] for its concrete
4//! penalty type here and registering it in [`analytic_penalty_registry`].
5//!
6//! Inlined into `gam-terms` during the #1521 crate carve. This module was the
7//! last `analytic_penalties` submodule still bridged to the old monolith via a
8//! `#[path = "../../../../src/terms/analytic_penalties/manifest.rs"]` shim; when
9//! `83dd4411c` deleted the monolith `src/terms/` tree the shim's target vanished
10//! and the crate stopped compiling (the trait + every `impl` + the
11//! `analytic_penalty_registry!` macro consumed by `registry.rs` live here). The
12//! content is unchanged from the monolith version except the explicit
13//! `crate::terms::analytic_penalties::{…}` import list is replaced by the
14//! crate-idiomatic `use super::*;` (every sibling submodule pulls the shared
15//! penalty types this way — see the `mod.rs` re-export block).
16
17use super::*;
18
19pub trait PenaltyManifest: AnalyticPenalty {
20    const KIND_TAG: &'static str;
21    const PYTHON_WRAPPER: &'static str;
22    const ROW_BLOCK_DIAGONAL: bool;
23
24    fn dispatch_tier(&self) -> PenaltyTier {
25        self.tier()
26    }
27}
28
29impl PenaltyManifest for ARDPenalty {
30    const KIND_TAG: &'static str = "ard";
31    const PYTHON_WRAPPER: &'static str = "ARDPenalty";
32    const ROW_BLOCK_DIAGONAL: bool = true;
33}
34
35impl PenaltyManifest for BlockOrthogonalityPenalty {
36    const KIND_TAG: &'static str = "block_orthogonality";
37    const PYTHON_WRAPPER: &'static str = "BlockOrthogonalityPenalty";
38    const ROW_BLOCK_DIAGONAL: bool = false;
39}
40
41impl PenaltyManifest for BlockSparsityPenalty {
42    const KIND_TAG: &'static str = "block_sparsity";
43    const PYTHON_WRAPPER: &'static str = "BlockSparsityPenalty";
44    const ROW_BLOCK_DIAGONAL: bool = false;
45}
46
47impl PenaltyManifest for DecoderIncoherencePenalty {
48    const KIND_TAG: &'static str = "decoder_incoherence";
49    const PYTHON_WRAPPER: &'static str = "DecoderIncoherencePenalty";
50    const ROW_BLOCK_DIAGONAL: bool = false;
51}
52
53impl PenaltyManifest for IBPAssignmentPenalty {
54    const KIND_TAG: &'static str = "ibp_assignment";
55    const PYTHON_WRAPPER: &'static str = "IBPAssignmentPenalty";
56    const ROW_BLOCK_DIAGONAL: bool = true;
57}
58
59impl PenaltyManifest for IsometryPenalty {
60    const KIND_TAG: &'static str = "isometry";
61    const PYTHON_WRAPPER: &'static str = "IsometryPenalty";
62    const ROW_BLOCK_DIAGONAL: bool = false;
63}
64
65impl PenaltyManifest for IvaeRidgeMeanGauge {
66    const KIND_TAG: &'static str = "ivae_ridge_mean_gauge";
67    const PYTHON_WRAPPER: &'static str = "IvaeRidgeMeanGauge";
68    const ROW_BLOCK_DIAGONAL: bool = false;
69}
70
71impl PenaltyManifest for JumpReLUPenalty {
72    const KIND_TAG: &'static str = "jumprelu";
73    const PYTHON_WRAPPER: &'static str = "JumpReLUPenalty";
74    const ROW_BLOCK_DIAGONAL: bool = true;
75}
76
77impl PenaltyManifest for MechanismSparsityPenalty {
78    const KIND_TAG: &'static str = "mechanism_sparsity";
79    const PYTHON_WRAPPER: &'static str = "MechanismSparsityPenalty";
80    const ROW_BLOCK_DIAGONAL: bool = false;
81}
82
83impl PenaltyManifest for ShapeMonotonicityPenalty {
84    const KIND_TAG: &'static str = "monotonicity";
85    const PYTHON_WRAPPER: &'static str = "MonotonicityPenalty";
86    const ROW_BLOCK_DIAGONAL: bool = false;
87}
88
89impl PenaltyManifest for NestedPrefixPenalty {
90    const KIND_TAG: &'static str = "nested_prefix";
91    const PYTHON_WRAPPER: &'static str = "NestedPrefixPenalty";
92    const ROW_BLOCK_DIAGONAL: bool = true;
93}
94
95impl PenaltyManifest for NuclearNormPenalty {
96    const KIND_TAG: &'static str = "nuclear_norm";
97    const PYTHON_WRAPPER: &'static str = "NuclearNormPenalty";
98    const ROW_BLOCK_DIAGONAL: bool = false;
99}
100
101impl PenaltyManifest for OrthogonalityPenalty {
102    const KIND_TAG: &'static str = "orthogonality";
103    const PYTHON_WRAPPER: &'static str = "OrthogonalityPenalty";
104    const ROW_BLOCK_DIAGONAL: bool = false;
105}
106
107impl PenaltyManifest for ParametricRowPrecisionPriorPenalty {
108    const KIND_TAG: &'static str = "parametric_row_precision_prior";
109    const PYTHON_WRAPPER: &'static str = "ParametricAuxConditionalPriorPenalty";
110    const ROW_BLOCK_DIAGONAL: bool = true;
111}
112
113impl PenaltyManifest for RowPrecisionPriorPenalty {
114    const KIND_TAG: &'static str = "row_precision_prior";
115    const PYTHON_WRAPPER: &'static str = "AuxConditionalPriorPenalty";
116    const ROW_BLOCK_DIAGONAL: bool = true;
117}
118
119impl PenaltyManifest for ScadMcpPenalty {
120    const KIND_TAG: &'static str = "scad_mcp";
121    const PYTHON_WRAPPER: &'static str = "ScadMcpPenalty";
122    const ROW_BLOCK_DIAGONAL: bool = true;
123}
124
125impl PenaltyManifest for SheafConsistencyPenalty {
126    const KIND_TAG: &'static str = "sheaf_consistency";
127    const PYTHON_WRAPPER: &'static str = "SheafConsistencyPenalty";
128    const ROW_BLOCK_DIAGONAL: bool = false;
129}
130
131impl PenaltyManifest for SoftmaxAssignmentSparsityPenalty {
132    const KIND_TAG: &'static str = "softmax_assignment_sparsity";
133    const PYTHON_WRAPPER: &'static str = "SoftmaxAssignmentSparsityPenalty";
134    const ROW_BLOCK_DIAGONAL: bool = true;
135}
136
137impl PenaltyManifest for SparsityPenalty {
138    const KIND_TAG: &'static str = "sparsity";
139    const PYTHON_WRAPPER: &'static str = "SparsityPenalty";
140    const ROW_BLOCK_DIAGONAL: bool = true;
141}
142
143impl PenaltyManifest for TopKActivationPenalty {
144    const KIND_TAG: &'static str = "topk_activation";
145    const PYTHON_WRAPPER: &'static str = "TopKActivationPenalty";
146    const ROW_BLOCK_DIAGONAL: bool = true;
147}
148
149impl PenaltyManifest for TotalVariationPenalty {
150    const KIND_TAG: &'static str = "total_variation";
151    const PYTHON_WRAPPER: &'static str = "TotalVariationPenalty";
152    const ROW_BLOCK_DIAGONAL: bool = false;
153}
154
/// Feeds the list of registered analytic penalties to a caller-supplied macro.
///
/// `analytic_penalty_registry!(callback)` expands to a single invocation
/// `callback! { ... }`. Its body contains one `register!(Name, PenaltyType);` line
/// per registered penalty, in the fixed order listed in this macro. This macro
/// never defines `register!` itself. The callback must match those lines and
/// decide what each entry turns into.
///
/// ```ignore
/// macro_rules! penalty_names {
///     ($(register!($name:ident, $ty:ident);)*) => {
///         [$(stringify!($name)),*]
///     };
/// }
/// let names = analytic_penalty_registry!(penalty_names);
/// ```
///
/// To add a penalty, implement `PenaltyManifest` for its type in this module and
/// add a `register!` line below.
/// NOTE(review): callers may depend on entry order, so append new entries rather
/// than reorder existing ones unless every callback has been checked.
#[macro_export]
macro_rules! analytic_penalty_registry {
    ($callback:ident) => {
        $callback! {
            register!(Isometry, IsometryPenalty);
            register!(Sparsity, SparsityPenalty);
            register!(SoftmaxAssignmentSparsity, SoftmaxAssignmentSparsityPenalty);
            register!(IBPAssignment, IBPAssignmentPenalty);
            register!(Ard, ARDPenalty);
            register!(TopKActivation, TopKActivationPenalty);
            register!(JumpReLU, JumpReLUPenalty);
            register!(TotalVariation, TotalVariationPenalty);
            register!(NuclearNorm, NuclearNormPenalty);
            register!(BlockSparsity, BlockSparsityPenalty);
            register!(MechanismSparsity, MechanismSparsityPenalty);
            register!(Monotonicity, ShapeMonotonicityPenalty);
            register!(NestedPrefix, NestedPrefixPenalty);
            register!(RowPrecisionPrior, RowPrecisionPriorPenalty);
            register!(IvaeRidgeMeanGauge, IvaeRidgeMeanGauge);
            register!(ParametricRowPrecisionPrior, ParametricRowPrecisionPriorPenalty);
            register!(ScadMcp, ScadMcpPenalty);
            register!(BlockOrthogonality, BlockOrthogonalityPenalty);
            register!(DecoderIncoherence, DecoderIncoherencePenalty);
            register!(Orthogonality, OrthogonalityPenalty);
            register!(SheafConsistency, SheafConsistencyPenalty);
        }
    };
}