pub mod cli;
pub mod config;
pub mod flow;
pub mod image_encoder;
pub mod mask_decoder;
pub mod mlp_ir;
pub mod preprocess;
pub mod profile;
pub mod prompt_encoder;
pub mod prompt_mask_ir;
// Silences `clippy::module_inception` for the `sam` module declared below.
// NOTE(review): presumably required because the enclosing module/crate is
// also named `sam` — confirm, and drop the attribute if the lint no longer fires.
#[allow(clippy::module_inception)]
pub mod sam;
pub mod transformer;
pub mod transformer_ir;
pub mod upscale_ir;
pub use config::{
EncoderKind, SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD,
SAM_PROMPT_EMBED_DIM, SamConfig, SamDecoderConfig, SamEncoderConfig,
};
pub use flow::{SamEncoderBuilt, SamEncoderFlow, build_sam_encoder_built};
pub use image_encoder::{
NeckWeights, apply_neck_host, build_sam_encoder_graph, build_sam_encoder_hir,
};
pub use mask_decoder::{MaskDecoderWeights, mask_decoder_forward};
pub use preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
pub use profile::{
SAM_PROFILE_FILE, sam_profile_default, sam_profile_near_weights, sam2_profile_default,
sam2_profile_near_weights, sam3_profile_default, sam3_profile_near_weights,
};
pub use prompt_encoder::{PromptEncoderOutput, PromptEncoderWeights, prompt_encoder_forward};
pub use sam::{MaskPrediction, SAM_MASK_IN_CHANS, Sam, sam_vit_b_config};
pub use rlx_runtime::Device;
pub use transformer::{TwoWayTransformerWeights, attention_forward, two_way_transformer_forward};
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Builds a `WeightMap` of zero-filled tensors whose names and shapes
    /// follow the `image_encoder.*` entries derived from
    /// `SamEncoderConfig::vit_b()`.
    ///
    /// Only names and shapes matter here; every value is `0.0`.
    fn synthetic_vit_b_weights() -> WeightMap {
        let cfg = SamEncoderConfig::vit_b();
        let embed = cfg.embed_dim;
        let head_dim = cfg.head_dim();
        let mlp_dim = embed * 4;
        let grid = SAM_EMBED_HW;
        let window = cfg.window_size;
        let patch = SAM_PATCH_SIZE;
        let out = cfg.out_chans;

        let mut tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Registers a zero tensor; the buffer length is the product of the shape.
        let mut put = |name: String, shape: Vec<usize>| {
            let numel: usize = shape.iter().product();
            tensors.insert(name, (vec![0.0f32; numel], shape));
        };

        // Patch embedding and positional embedding.
        for (name, shape) in [
            (
                "image_encoder.patch_embed.proj.weight",
                vec![embed, 3, patch, patch],
            ),
            ("image_encoder.patch_embed.proj.bias", vec![embed]),
            ("image_encoder.pos_embed", vec![1, grid, grid, embed]),
        ] {
            put(name.to_string(), shape);
        }

        // One group of tensors per encoder block.
        for i in 0..cfg.depth {
            let prefix = format!("image_encoder.blocks.{i}");
            // The relative-position table length depends on whether the block
            // index is listed in `global_attn_indexes` (full grid) or not
            // (attention window).
            let rel_len = if cfg.global_attn_indexes.contains(&i) {
                2 * grid - 1
            } else {
                2 * window - 1
            };
            let layer_tensors = [
                ("norm1.weight", vec![embed]),
                ("norm1.bias", vec![embed]),
                ("attn.qkv.weight", vec![3 * embed, embed]),
                ("attn.qkv.bias", vec![3 * embed]),
                ("attn.proj.weight", vec![embed, embed]),
                ("attn.proj.bias", vec![embed]),
                ("attn.rel_pos_h", vec![rel_len, head_dim]),
                ("attn.rel_pos_w", vec![rel_len, head_dim]),
                ("norm2.weight", vec![embed]),
                ("norm2.bias", vec![embed]),
                ("mlp.lin1.weight", vec![mlp_dim, embed]),
                ("mlp.lin1.bias", vec![mlp_dim]),
                ("mlp.lin2.weight", vec![embed, mlp_dim]),
                ("mlp.lin2.bias", vec![embed]),
            ];
            for (suffix, shape) in layer_tensors {
                put(format!("{prefix}.{suffix}"), shape);
            }
        }

        // Neck tensors (note: no `neck.2.bias` entry is registered).
        for (name, shape) in [
            ("image_encoder.neck.0.weight", vec![out, embed, 1, 1]),
            ("image_encoder.neck.0.bias", vec![out]),
            ("image_encoder.neck.1.weight", vec![out]),
            ("image_encoder.neck.1.bias", vec![out]),
            ("image_encoder.neck.2.weight", vec![out, out, 3, 3]),
            ("image_encoder.neck.3.weight", vec![out]),
            ("image_encoder.neck.3.bias", vec![out]),
        ] {
            put(name.to_string(), shape);
        }

        WeightMap::from_tensors(tensors)
    }

    /// The encoder graph builds from the synthetic weights, exposes a single
    /// output of shape `[1, out_chans, SAM_EMBED_HW, SAM_EMBED_HW]`, and leaves
    /// no key behind in the weight map.
    #[test]
    fn encoder_graph_builds() {
        let cfg = SamEncoderConfig::vit_b();
        let mut weights = synthetic_vit_b_weights();
        let (graph, _params, _pre) = build_sam_encoder_graph(&cfg, &mut weights).unwrap();

        assert_eq!(graph.outputs.len(), 1);

        let out_shape = graph.shape(graph.outputs[0]);
        let dims: Vec<usize> = out_shape
            .dims()
            .iter()
            .map(|dim| dim.unwrap_static())
            .collect();
        assert_eq!(dims, vec![1, cfg.out_chans, SAM_EMBED_HW, SAM_EMBED_HW]);

        // Anything still listed by `keys()` after the build is reported as a leftover.
        let leftovers: Vec<&str> = weights.keys().collect();
        assert!(leftovers.is_empty(), "leftover weights: {leftovers:?}");
    }

    /// The `src/sam.rlx.toml` profile parses and selects the `Direct` fusion policy.
    #[test]
    fn sam_rlx_toml_profile_loads() {
        use rlx_flow::CompileProfile;

        let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let profile_path = manifest_dir.join("src/sam.rlx.toml");
        let profile = CompileProfile::from_toml_path(&profile_path).unwrap();
        assert_eq!(profile.fusion.policy, rlx_flow::FusionPolicyKind::Direct);
    }

    /// Preprocessing a 100x80 RGB buffer yields a `3 * 1024 * 1024` element
    /// tensor and reports the resized dimensions.
    #[test]
    fn preprocess_round_trip_shapes() {
        let pixels = vec![128u8; 100 * 80 * 3];
        let (tensor, (resized_h, resized_w)) = preprocess_image(&pixels, 100, 80);

        assert_eq!(tensor.len(), 3 * 1024 * 1024);
        assert_eq!(resized_h, 1024);

        // The 80-pixel side scaled by 1024 / 100 and rounded.
        let expected_w = (80.0_f32 * (1024.0 / 100.0)).round() as usize;
        assert_eq!(resized_w, expected_w);
    }
}