gam_terms/analytic_penalties/
manifest.rs1use super::*;
18
19pub trait PenaltyManifest: AnalyticPenalty {
20 const KIND_TAG: &'static str;
21 const PYTHON_WRAPPER: &'static str;
22 const ROW_BLOCK_DIAGONAL: bool;
23
24 fn dispatch_tier(&self) -> PenaltyTier {
25 self.tier()
26 }
27}
28
29impl PenaltyManifest for ARDPenalty {
30 const KIND_TAG: &'static str = "ard";
31 const PYTHON_WRAPPER: &'static str = "ARDPenalty";
32 const ROW_BLOCK_DIAGONAL: bool = true;
33}
34
35impl PenaltyManifest for BlockOrthogonalityPenalty {
36 const KIND_TAG: &'static str = "block_orthogonality";
37 const PYTHON_WRAPPER: &'static str = "BlockOrthogonalityPenalty";
38 const ROW_BLOCK_DIAGONAL: bool = false;
39}
40
41impl PenaltyManifest for BlockSparsityPenalty {
42 const KIND_TAG: &'static str = "block_sparsity";
43 const PYTHON_WRAPPER: &'static str = "BlockSparsityPenalty";
44 const ROW_BLOCK_DIAGONAL: bool = false;
45}
46
47impl PenaltyManifest for DecoderIncoherencePenalty {
48 const KIND_TAG: &'static str = "decoder_incoherence";
49 const PYTHON_WRAPPER: &'static str = "DecoderIncoherencePenalty";
50 const ROW_BLOCK_DIAGONAL: bool = false;
51}
52
53impl PenaltyManifest for IBPAssignmentPenalty {
54 const KIND_TAG: &'static str = "ibp_assignment";
55 const PYTHON_WRAPPER: &'static str = "IBPAssignmentPenalty";
56 const ROW_BLOCK_DIAGONAL: bool = true;
57}
58
59impl PenaltyManifest for IsometryPenalty {
60 const KIND_TAG: &'static str = "isometry";
61 const PYTHON_WRAPPER: &'static str = "IsometryPenalty";
62 const ROW_BLOCK_DIAGONAL: bool = false;
63}
64
65impl PenaltyManifest for IvaeRidgeMeanGauge {
66 const KIND_TAG: &'static str = "ivae_ridge_mean_gauge";
67 const PYTHON_WRAPPER: &'static str = "IvaeRidgeMeanGauge";
68 const ROW_BLOCK_DIAGONAL: bool = false;
69}
70
71impl PenaltyManifest for JumpReLUPenalty {
72 const KIND_TAG: &'static str = "jumprelu";
73 const PYTHON_WRAPPER: &'static str = "JumpReLUPenalty";
74 const ROW_BLOCK_DIAGONAL: bool = true;
75}
76
77impl PenaltyManifest for MechanismSparsityPenalty {
78 const KIND_TAG: &'static str = "mechanism_sparsity";
79 const PYTHON_WRAPPER: &'static str = "MechanismSparsityPenalty";
80 const ROW_BLOCK_DIAGONAL: bool = false;
81}
82
83impl PenaltyManifest for ShapeMonotonicityPenalty {
84 const KIND_TAG: &'static str = "monotonicity";
85 const PYTHON_WRAPPER: &'static str = "MonotonicityPenalty";
86 const ROW_BLOCK_DIAGONAL: bool = false;
87}
88
89impl PenaltyManifest for NestedPrefixPenalty {
90 const KIND_TAG: &'static str = "nested_prefix";
91 const PYTHON_WRAPPER: &'static str = "NestedPrefixPenalty";
92 const ROW_BLOCK_DIAGONAL: bool = true;
93}
94
95impl PenaltyManifest for NuclearNormPenalty {
96 const KIND_TAG: &'static str = "nuclear_norm";
97 const PYTHON_WRAPPER: &'static str = "NuclearNormPenalty";
98 const ROW_BLOCK_DIAGONAL: bool = false;
99}
100
101impl PenaltyManifest for OrthogonalityPenalty {
102 const KIND_TAG: &'static str = "orthogonality";
103 const PYTHON_WRAPPER: &'static str = "OrthogonalityPenalty";
104 const ROW_BLOCK_DIAGONAL: bool = false;
105}
106
107impl PenaltyManifest for ParametricRowPrecisionPriorPenalty {
108 const KIND_TAG: &'static str = "parametric_row_precision_prior";
109 const PYTHON_WRAPPER: &'static str = "ParametricAuxConditionalPriorPenalty";
110 const ROW_BLOCK_DIAGONAL: bool = true;
111}
112
113impl PenaltyManifest for RowPrecisionPriorPenalty {
114 const KIND_TAG: &'static str = "row_precision_prior";
115 const PYTHON_WRAPPER: &'static str = "AuxConditionalPriorPenalty";
116 const ROW_BLOCK_DIAGONAL: bool = true;
117}
118
119impl PenaltyManifest for ScadMcpPenalty {
120 const KIND_TAG: &'static str = "scad_mcp";
121 const PYTHON_WRAPPER: &'static str = "ScadMcpPenalty";
122 const ROW_BLOCK_DIAGONAL: bool = true;
123}
124
125impl PenaltyManifest for SheafConsistencyPenalty {
126 const KIND_TAG: &'static str = "sheaf_consistency";
127 const PYTHON_WRAPPER: &'static str = "SheafConsistencyPenalty";
128 const ROW_BLOCK_DIAGONAL: bool = false;
129}
130
131impl PenaltyManifest for SoftmaxAssignmentSparsityPenalty {
132 const KIND_TAG: &'static str = "softmax_assignment_sparsity";
133 const PYTHON_WRAPPER: &'static str = "SoftmaxAssignmentSparsityPenalty";
134 const ROW_BLOCK_DIAGONAL: bool = true;
135}
136
137impl PenaltyManifest for SparsityPenalty {
138 const KIND_TAG: &'static str = "sparsity";
139 const PYTHON_WRAPPER: &'static str = "SparsityPenalty";
140 const ROW_BLOCK_DIAGONAL: bool = true;
141}
142
143impl PenaltyManifest for TopKActivationPenalty {
144 const KIND_TAG: &'static str = "topk_activation";
145 const PYTHON_WRAPPER: &'static str = "TopKActivationPenalty";
146 const ROW_BLOCK_DIAGONAL: bool = true;
147}
148
149impl PenaltyManifest for TotalVariationPenalty {
150 const KIND_TAG: &'static str = "total_variation";
151 const PYTHON_WRAPPER: &'static str = "TotalVariationPenalty";
152 const ROW_BLOCK_DIAGONAL: bool = false;
153}
154
/// Single source of truth listing every registered analytic penalty.
///
/// This is a callback-style macro: `analytic_penalty_registry!(my_macro)`
/// expands to one invocation of `my_macro! { ... }` whose body is a sequence
/// of `register!(Name, PenaltyType);` entries, one per penalty, in the fixed
/// order written below. `register!` is never expanded by this macro itself;
/// the tokens are handed verbatim to the callback, which is expected to match
/// on them.
///
/// The first identifier of each entry is a short name (presumably used by
/// consumers as an enum-variant name — confirm at the call sites), and the
/// second is the implementing Rust type. Consumers may depend on the entry
/// order, so append new penalties rather than reordering.
///
/// ```ignore
/// macro_rules! count_entries {
///     ($(register!($name:ident, $ty:ident);)*) => { 0usize $(+ { let _ = stringify!($name); 1 })* };
/// }
/// let n = analytic_penalty_registry!(count_entries);
/// ```
#[macro_export]
macro_rules! analytic_penalty_registry {
    ($callback:ident) => {
        $callback! {
            register!(Isometry, IsometryPenalty);
            register!(Sparsity, SparsityPenalty);
            register!(SoftmaxAssignmentSparsity, SoftmaxAssignmentSparsityPenalty);
            register!(IBPAssignment, IBPAssignmentPenalty);
            register!(Ard, ARDPenalty);
            register!(TopKActivation, TopKActivationPenalty);
            register!(JumpReLU, JumpReLUPenalty);
            register!(TotalVariation, TotalVariationPenalty);
            register!(NuclearNorm, NuclearNormPenalty);
            register!(BlockSparsity, BlockSparsityPenalty);
            register!(MechanismSparsity, MechanismSparsityPenalty);
            register!(Monotonicity, ShapeMonotonicityPenalty);
            register!(NestedPrefix, NestedPrefixPenalty);
            register!(RowPrecisionPrior, RowPrecisionPriorPenalty);
            register!(IvaeRidgeMeanGauge, IvaeRidgeMeanGauge);
            register!(ParametricRowPrecisionPrior, ParametricRowPrecisionPriorPenalty);
            register!(ScadMcp, ScadMcpPenalty);
            register!(BlockOrthogonality, BlockOrthogonalityPenalty);
            register!(DecoderIncoherence, DecoderIncoherencePenalty);
            register!(Orthogonality, OrthogonalityPenalty);
            register!(SheafConsistency, SheafConsistencyPenalty);
        }
    };
}