Skip to main content

rlx_sam/
profile.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-1 compile profiles for SAM v1 / SAM2 / SAM3 loaders.
17
18use std::path::Path;
19
20use rlx_flow::CompileProfile;
21
22use rlx_core::flow_bridge::profile_near_weights as load_profile_near_weights;
23
/// File name of the SAM compile-profile override (`sam.rlx.toml`).
///
/// Per the loaders in this module, the file is expected to sit next to the
/// safetensors weights; the same name is used for SAM v1, SAM2, and SAM3.
pub const SAM_PROFILE_FILE: &str = "sam.rlx.toml";
26
27/// Load `sam.rlx.toml` next to weights, or [`CompileProfile::sam_encoder`].
28pub fn sam_profile_near_weights(weights: &Path) -> CompileProfile {
29    load_profile_near_weights(weights, SAM_PROFILE_FILE, CompileProfile::sam_encoder())
30}
31
32/// SAM3 detector graphs — same defaults as SAM encoder unless overridden on disk.
33pub fn sam3_profile_near_weights(weights: &Path) -> CompileProfile {
34    load_profile_near_weights(weights, SAM_PROFILE_FILE, CompileProfile::sam3())
35}
36
37/// SAM2 checkpoint graphs — loads `sam.rlx.toml` next to weights when present.
38pub fn sam2_profile_near_weights(weights: &Path) -> CompileProfile {
39    load_profile_near_weights(weights, SAM_PROFILE_FILE, CompileProfile::sam2())
40}
41
42pub fn sam2_profile_default() -> CompileProfile {
43    CompileProfile::sam2()
44}
45
46/// Profile search for tests (no weights directory).
47pub fn sam_profile_default() -> CompileProfile {
48    CompileProfile::sam_encoder()
49}
50
51/// Default SAM3 profile (tests / inline builds).
52pub fn sam3_profile_default() -> CompileProfile {
53    CompileProfile::sam3()
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_flow::FusionPolicyKind;

    /// The `sam.rlx.toml` shipped under this crate's `src/` directory must
    /// parse and carry the settings asserted below.
    #[test]
    fn sam_rlx_toml_loads_for_sam3() {
        // `CARGO_MANIFEST_DIR` is baked in at compile time, so the lookup does
        // not depend on the working directory the test runs from.
        let manifest_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let toml_path = manifest_dir.join("src/sam.rlx.toml");

        let profile = CompileProfile::from_toml_path(&toml_path).unwrap();

        assert_eq!(profile.fusion.policy, FusionPolicyKind::Direct);
        assert!(profile.passes.dce);
        assert!(profile.backend.cpu.unfuse_regions);
    }
}