use crate::sam3::{Sam3ImagePrediction, Sam3VideoFramePrediction, Sam3VideoState};
use anyhow::Result;
use rlx_core::weight_map::WeightMap;
use std::collections::HashMap;
/// Tracker tensors pulled out of a model checkpoint.
///
/// Populated by `extract_tracker_weights`; the derived `Default` yields an
/// empty, not-loaded container (`loaded == false`, `raw` empty).
#[derive(Debug, Clone, Default)]
pub struct Sam3TrackerWeights {
/// Whether any tracker tensor is present. `extract_tracker_weights` sets
/// this to `!raw.is_empty()`.
pub loaded: bool,
/// Tensors keyed by checkpoint name with the `tracker.` /
/// `detector.tracker.` prefix removed.
// NOTE(review): the tuple is presumably (flattened f32 data, shape dims) —
// confirm against `WeightMap::take`.
pub raw: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
/// Collects every tensor whose key starts with `tracker.` or
/// `detector.tracker.` from `weights` into a [`Sam3TrackerWeights`].
///
/// Each matching key is passed to `weights.take(..)` and the result is stored
/// in `raw` under the key with its prefix removed. `loaded` is `true` when at
/// least one tensor was stored.
///
/// If two keys reduce to the same suffix (one per prefix), the one visited
/// later replaces the earlier entry; visit order follows `weights.keys()`.
///
/// # Errors
///
/// Currently never returns `Err`: a key whose `take` fails is skipped. The
/// `Result` return type is kept for interface compatibility.
pub fn extract_tracker_weights(weights: &mut WeightMap) -> Result<Sam3TrackerWeights> {
    // Single source of truth for the recognised prefixes. A key can match at
    // most one of them, so the lookup order does not affect the result.
    const PREFIXES: [&str; 2] = ["tracker.", "detector.tracker."];

    // Resolve (full key, stripped suffix) pairs up front as owned strings:
    // iterating `keys()` borrows `weights`, so nothing can be taken from it
    // until this loop has finished.
    let mut matched: Vec<(String, String)> = Vec::new();
    for key in weights.keys() {
        if let Some(suffix) = PREFIXES.iter().find_map(|p| key.strip_prefix(*p)) {
            matched.push((key.to_string(), suffix.to_string()));
        }
    }

    let mut raw = HashMap::with_capacity(matched.len());
    for (full, suffix) in matched {
        // Best-effort: a tensor that cannot be taken is left out rather than
        // failing the whole extraction.
        // NOTE(review): confirm silently skipping is intended; propagating
        // the error would surface partially-loaded trackers.
        if let Ok(tensor) = weights.take(&full) {
            raw.insert(suffix, tensor);
        }
    }

    Ok(Sam3TrackerWeights {
        loaded: !raw.is_empty(),
        raw,
    })
}
/// Advances the video state by one frame using an already-computed image
/// prediction.
///
/// The tracker weights are not consulted (`_weights` is unused). The call:
/// - appends a clone of `image.scores` to `state.memory_tokens`,
/// - stores a clone of `image` in `state.last_prediction`,
/// - increments `state.frame_index` by one.
///
/// The returned prediction carries the frame index from *before* the
/// increment, the `image` itself, and the length of `state.memory_tokens`
/// after the append.
pub fn tracker_forward_native(
    _weights: &Sam3TrackerWeights,
    state: &mut Sam3VideoState,
    image: Sam3ImagePrediction,
) -> Sam3VideoFramePrediction {
    // Update the memory bank first so its new length can be reported.
    state.memory_tokens.push(image.scores.clone());
    let memory_len = state.memory_tokens.len();

    // `image` is moved into the return value below, so the state keeps a copy.
    state.last_prediction = Some(image.clone());

    // Report the index this frame was processed at, then step the counter.
    let frame_index = state.frame_index;
    state.frame_index += 1;

    Sam3VideoFramePrediction {
        frame_index,
        image,
        memory_len,
    }
}