use std::collections::HashMap;
/// Location of the published channel-position table (a JSON object mapping
/// channel names to `[x, y, z]` triples), fetched by
/// `PositionBank::download_and_cache` when the `hf-download` feature is enabled.
// Only referenced from feature-gated code: silence the lint solely for builds
// where that code is compiled out, so genuine dead code is still reported otherwise.
#[cfg_attr(not(feature = "hf-download"), allow(dead_code))]
const POSITIONS_URL: &str =
    "https://huggingface.co/brain-bzh/reve-positions/resolve/main/positions.json";
/// Lookup table from channel name to an `[x, y, z]` position.
///
/// NOTE(review): presumably these are electrode coordinates; the unit and
/// coordinate frame are defined by the source JSON, not by this type.
#[derive(Debug, Clone, Default)]
pub struct PositionBank {
    /// Channel name → `[x, y, z]` position.
    positions: HashMap<String, [f32; 3]>,
}
impl PositionBank {
pub fn from_json(path: &str) -> anyhow::Result<Self> {
let data = std::fs::read_to_string(path)?;
let map: HashMap<String, [f32; 3]> = serde_json::from_str(&data)?;
Ok(Self { positions: map })
}
pub fn from_json_str(json: &str) -> anyhow::Result<Self> {
let map: HashMap<String, [f32; 3]> = serde_json::from_str(json)?;
Ok(Self { positions: map })
}
#[cfg(feature = "hf-download")]
pub fn download_and_cache(cache_dir: Option<&str>) -> anyhow::Result<Self> {
let cache_path = match cache_dir {
Some(dir) => PathBuf::from(dir).join("reve_positions.json"),
None => {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
PathBuf::from(home).join(".cache").join("reve-rs").join("reve_positions.json")
}
};
if cache_path.exists() {
if let Ok(bank) = Self::from_json(cache_path.to_str().unwrap()) {
return Ok(bank);
}
}
let resp = ureq::get(POSITIONS_URL).call()?;
let body = resp.into_string()?;
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(&cache_path, &body)?;
Self::from_json_str(&body)
}
pub fn get_positions(&self, channel_names: &[&str]) -> Vec<f32> {
let mut result = Vec::with_capacity(channel_names.len() * 3);
for name in channel_names {
if let Some(pos) = self.positions.get(*name) {
result.extend_from_slice(pos);
} else {
eprintln!("Warning: channel '{}' not found in position bank, using [0,0,0]", name);
result.extend_from_slice(&[0.0, 0.0, 0.0]);
}
}
result
}
pub fn channel_names(&self) -> Vec<&str> {
self.positions.keys().map(|s| s.as_str()).collect()
}
pub fn len(&self) -> usize {
self.positions.len()
}
pub fn is_empty(&self) -> bool {
self.positions.is_empty()
}
}