1pub mod cli;
50pub mod config;
51pub mod flow;
52pub mod image_encoder;
53pub mod mask_decoder;
54pub mod mlp_ir;
55pub mod preprocess;
56pub mod profile;
57pub mod prompt_encoder;
58pub mod prompt_mask_ir;
59#[allow(clippy::module_inception)]
60pub mod sam;
61pub mod transformer;
62pub mod transformer_ir;
63pub mod upscale_ir;
64
65pub use config::{
66 EncoderKind, SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD,
67 SAM_PROMPT_EMBED_DIM, SamConfig, SamDecoderConfig, SamEncoderConfig,
68};
69pub use flow::{SamEncoderBuilt, SamEncoderFlow, build_sam_encoder_built};
70pub use image_encoder::{
71 NeckWeights, apply_neck_host, build_sam_encoder_graph, build_sam_encoder_hir,
72};
73pub use mask_decoder::{MaskDecoderWeights, mask_decoder_forward};
74pub use preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
75pub use profile::{
76 SAM_PROFILE_FILE, sam_profile_default, sam_profile_near_weights, sam2_profile_default,
77 sam2_profile_near_weights, sam3_profile_default, sam3_profile_near_weights,
78};
79pub use prompt_encoder::{PromptEncoderOutput, PromptEncoderWeights, prompt_encoder_forward};
80pub use sam::{MaskPrediction, SAM_MASK_IN_CHANS, Sam, sam_vit_b_config};
81
82pub use rlx_runtime::Device;
85pub use transformer::{TwoWayTransformerWeights, attention_forward, two_way_transformer_forward};
86
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Pairs `shape` with a zero-filled `f32` buffer whose length is the
    /// product of all dimensions in `shape`.
    fn zeros(shape: Vec<usize>) -> (Vec<f32>, Vec<usize>) {
        let len: usize = shape.iter().product();
        (vec![0.0f32; len], shape)
    }

    /// Builds a `WeightMap` of all-zero tensors for `SamEncoderConfig::vit_b()`.
    ///
    /// The tensor names and shapes are first gathered as `(name, shape)` specs
    /// (patch embedding, positional embedding, one group per encoder block,
    /// then the neck) and only afterwards materialised as zero buffers.
    fn synthetic_vit_b_weights() -> WeightMap {
        let cfg = SamEncoderConfig::vit_b();
        let embed = cfg.embed_dim;
        let head = cfg.head_dim();
        let hidden = embed * 4;
        let grid = SAM_EMBED_HW;
        let window = cfg.window_size;
        let patch = SAM_PATCH_SIZE;
        let neck = cfg.out_chans;

        // Patch embedding and positional embedding.
        let mut specs: Vec<(String, Vec<usize>)> = vec![
            (
                String::from("image_encoder.patch_embed.proj.weight"),
                vec![embed, 3, patch, patch],
            ),
            (
                String::from("image_encoder.patch_embed.proj.bias"),
                vec![embed],
            ),
            (
                String::from("image_encoder.pos_embed"),
                vec![1, grid, grid, embed],
            ),
        ];

        // Per-block tensors. The relative-position tables are sized from the
        // embedding grid for blocks listed in `global_attn_indexes`, and from
        // the window size for all other blocks.
        for i in 0..cfg.depth {
            let lp = format!("image_encoder.blocks.{i}");
            let rel_size = if cfg.global_attn_indexes.contains(&i) {
                grid
            } else {
                window
            };
            let rel_len = 2 * rel_size - 1;

            specs.extend([
                (format!("{lp}.norm1.weight"), vec![embed]),
                (format!("{lp}.norm1.bias"), vec![embed]),
                (format!("{lp}.attn.qkv.weight"), vec![3 * embed, embed]),
                (format!("{lp}.attn.qkv.bias"), vec![3 * embed]),
                (format!("{lp}.attn.proj.weight"), vec![embed, embed]),
                (format!("{lp}.attn.proj.bias"), vec![embed]),
                (format!("{lp}.attn.rel_pos_h"), vec![rel_len, head]),
                (format!("{lp}.attn.rel_pos_w"), vec![rel_len, head]),
                (format!("{lp}.norm2.weight"), vec![embed]),
                (format!("{lp}.norm2.bias"), vec![embed]),
                (format!("{lp}.mlp.lin1.weight"), vec![hidden, embed]),
                (format!("{lp}.mlp.lin1.bias"), vec![hidden]),
                (format!("{lp}.mlp.lin2.weight"), vec![embed, hidden]),
                (format!("{lp}.mlp.lin2.bias"), vec![embed]),
            ]);
        }

        // Neck tensors.
        specs.extend([
            (
                String::from("image_encoder.neck.0.weight"),
                vec![neck, embed, 1, 1],
            ),
            (String::from("image_encoder.neck.0.bias"), vec![neck]),
            (String::from("image_encoder.neck.1.weight"), vec![neck]),
            (String::from("image_encoder.neck.1.bias"), vec![neck]),
            (
                String::from("image_encoder.neck.2.weight"),
                vec![neck, neck, 3, 3],
            ),
            (String::from("image_encoder.neck.3.weight"), vec![neck]),
            (String::from("image_encoder.neck.3.bias"), vec![neck]),
        ]);

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| (name, zeros(shape)))
            .collect();

        WeightMap::from_tensors(tensors)
    }

    /// Builds the encoder graph from the synthetic weights and checks that it
    /// has a single output of shape `[1, out_chans, SAM_EMBED_HW, SAM_EMBED_HW]`
    /// and that the weight map reports no keys afterwards.
    // NOTE(review): the empty-keys check presumably relies on the builder
    // removing each weight as it consumes it; confirm against `WeightMap`.
    #[test]
    fn encoder_graph_builds() {
        let cfg = SamEncoderConfig::vit_b();
        let mut weights = synthetic_vit_b_weights();
        let (graph, _params, _pre) = build_sam_encoder_graph(&cfg, &mut weights).unwrap();

        assert_eq!(graph.outputs.len(), 1);

        let out_shape = graph.shape(graph.outputs[0]);
        let dims: Vec<usize> = out_shape
            .dims()
            .iter()
            .map(|dim| dim.unwrap_static())
            .collect();
        assert_eq!(dims, vec![1, cfg.out_chans, SAM_EMBED_HW, SAM_EMBED_HW]);

        let leftovers: Vec<&str> = weights.keys().collect();
        assert!(leftovers.is_empty(), "leftover weights: {leftovers:?}");
    }

    /// Loads `src/sam.rlx.toml` relative to the crate manifest directory and
    /// checks that its fusion policy is `Direct`.
    #[test]
    fn sam_rlx_toml_profile_loads() {
        use rlx_flow::{CompileProfile, FusionPolicyKind};

        let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let profile_path = manifest_dir.join("src/sam.rlx.toml");
        let profile = CompileProfile::from_toml_path(&profile_path).unwrap();

        assert_eq!(profile.fusion.policy, FusionPolicyKind::Direct);
    }

    /// Runs `preprocess_image` on a uniform `100 * 80 * 3`-byte buffer and
    /// checks the output length and the reported `(h, w)` pair.
    #[test]
    fn preprocess_round_trip_shapes() {
        let pixels = vec![128u8; 100 * 80 * 3];
        let (tensor, (out_h, out_w)) = preprocess_image(&pixels, 100, 80);

        // Expected width: 80 scaled by the factor that maps 100 to 1024.
        let expected_w = (80.0_f32 * (1024.0 / 100.0)).round() as usize;

        assert_eq!(tensor.len(), 3 * 1024 * 1024);
        assert_eq!(out_h, 1024);
        assert_eq!(out_w, expected_w);
    }
}