use crate::error::{PoKeysError, Result};
use log::{error, info, warn};
use notify::{EventKind, RecommendedWatcher, RecursiveMode, Watcher};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::Duration;
/// Location of the per-user model directory, relative to the user's home directory
/// (it is appended to the home path by `get_default_model_dir`).
pub const DEFAULT_MODEL_DIR: &str = ".config/pokeys/models";
/// Default retry interval.
/// NOTE(review): not referenced in this file; the unit (presumably seconds) should be
/// confirmed at its call sites.
pub const DEFAULT_RETRY_INTERVAL: u64 = 10;
/// Description of a single pin in a device model file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PinModel {
    /// Capability names supported by the pin (e.g. `"DigitalInput"`, `"Encoder_1A"`,
    /// `"MatrixKeyboard_Row1"`, `"PWM_1"`). A valid model requires at least one entry.
    pub capabilities: Vec<String>,
    /// Whether the pin may be used; defaults to `true` when omitted from the file.
    #[serde(default = "default_active")]
    pub active: bool,
}
/// Pin layout of a PoKeys device, normally loaded from a `<device>.yaml` model file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct DeviceModel {
    /// Model name; must not be empty.
    pub name: String,
    /// Pin descriptions keyed by pin number; at least one pin is required.
    pub pins: HashMap<u8, PinModel>,
}
/// Serde default for [`PinModel::active`]: a pin is considered usable unless the
/// model file explicitly marks it inactive.
fn default_active() -> bool {
    const PINS_ACTIVE_BY_DEFAULT: bool = true;
    PINS_ACTIVE_BY_DEFAULT
}
impl DeviceModel {
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = fs::read_to_string(path.as_ref()).map_err(|e| {
PoKeysError::ModelLoadError(path.as_ref().to_string_lossy().to_string(), e.to_string())
})?;
let model: DeviceModel = serde_yaml::from_str(&content).map_err(|e| {
PoKeysError::ModelParseError(path.as_ref().to_string_lossy().to_string(), e.to_string())
})?;
model.validate()?;
Ok(model)
}
pub fn validate(&self) -> Result<()> {
if self.name.is_empty() {
return Err(PoKeysError::ModelValidationError(
"Model name cannot be empty".to_string(),
));
}
if self.pins.is_empty() {
return Err(PoKeysError::ModelValidationError(
"Model must define at least one pin".to_string(),
));
}
for (pin_num, pin) in &self.pins {
if pin.capabilities.is_empty() {
return Err(PoKeysError::ModelValidationError(format!(
"Pin {} must have at least one capability",
pin_num
)));
}
}
self.validate_related_capabilities()?;
Ok(())
}
fn validate_related_capabilities(&self) -> Result<()> {
self.validate_encoder_pairs()?;
self.validate_matrix_keyboard()?;
self.validate_pwm_channels()?;
Ok(())
}
fn validate_encoder_pairs(&self) -> Result<()> {
let mut encoder_a_pins = HashMap::new();
for (pin_num, pin) in &self.pins {
for capability in &pin.capabilities {
if capability.starts_with("Encoder_") && capability.ends_with("A") {
let encoder_id = &capability[8..capability.len() - 1]; encoder_a_pins.insert(encoder_id.to_string(), *pin_num);
}
}
}
for (encoder_id, pin_a) in &encoder_a_pins {
let encoder_b_capability = format!("Encoder_{}B", encoder_id);
let mut found_b = false;
for (pin_num, pin) in &self.pins {
if *pin_num != *pin_a
&& pin
.capabilities
.iter()
.any(|cap| cap == &encoder_b_capability)
{
found_b = true;
break;
}
}
if !found_b {
return Err(PoKeysError::ModelValidationError(format!(
"Encoder {}A on pin {} has no corresponding {}B pin",
encoder_id, pin_a, encoder_b_capability
)));
}
}
Ok(())
}
fn validate_matrix_keyboard(&self) -> Result<()> {
let mut has_rows = false;
let mut has_columns = false;
for pin in self.pins.values() {
for capability in &pin.capabilities {
if capability.starts_with("MatrixKeyboard_Row") {
has_rows = true;
} else if capability.starts_with("MatrixKeyboard_Col") {
has_columns = true;
}
}
}
if has_rows && !has_columns {
return Err(PoKeysError::ModelValidationError(
"Matrix keyboard has rows but no columns".to_string(),
));
}
if !has_rows && has_columns {
return Err(PoKeysError::ModelValidationError(
"Matrix keyboard has columns but no rows".to_string(),
));
}
Ok(())
}
fn validate_pwm_channels(&self) -> Result<()> {
let mut pwm_channels = Vec::new();
for (pin_num, pin) in &self.pins {
for capability in &pin.capabilities {
if let Some(stripped) = capability.strip_prefix("PWM_") {
if let Ok(channel) = stripped.parse::<u32>() {
pwm_channels.push((channel, *pin_num));
}
}
}
}
pwm_channels.sort_by_key(|(channel, _)| *channel);
for (i, (channel, pin)) in pwm_channels.iter().enumerate() {
if *channel != (i + 1) as u32 {
return Err(PoKeysError::ModelValidationError(format!(
"PWM channels are not sequential: expected channel {}, found channel {} on pin {}",
i + 1,
channel,
pin
)));
}
}
Ok(())
}
pub fn is_pin_capability_supported(&self, pin_num: u8, capability: &str) -> bool {
if let Some(pin) = self.pins.get(&pin_num) {
pin.capabilities.iter().any(|cap| cap == capability)
} else {
false
}
}
pub fn get_pin_capabilities(&self, pin_num: u8) -> Vec<String> {
if let Some(pin) = self.pins.get(&pin_num) {
pin.capabilities.clone()
} else {
Vec::new()
}
}
pub fn get_related_capabilities(&self, pin_num: u8, capability: &str) -> Vec<(String, u8)> {
let mut related = Vec::new();
if capability.starts_with("Encoder_") && capability.len() >= 10 {
let encoder_id = &capability[8..capability.len() - 1]; let role = &capability[capability.len() - 1..];
let related_role = if role == "A" { "B" } else { "A" };
let related_capability = format!("Encoder_{}{}", encoder_id, related_role);
for (other_pin, pin_model) in &self.pins {
if *other_pin != pin_num
&& pin_model
.capabilities
.iter()
.any(|cap| cap == &related_capability)
{
related.push((related_capability, *other_pin));
break;
}
}
}
if capability.starts_with("MatrixKeyboard_Row") {
for (other_pin, pin_model) in &self.pins {
if *other_pin != pin_num {
for cap in &pin_model.capabilities {
if cap.starts_with("MatrixKeyboard_Col") {
related.push((cap.clone(), *other_pin));
}
}
}
}
}
if capability.starts_with("MatrixKeyboard_Col") {
for (other_pin, pin_model) in &self.pins {
if *other_pin != pin_num {
for cap in &pin_model.capabilities {
if cap.starts_with("MatrixKeyboard_Row") {
related.push((cap.clone(), *other_pin));
}
}
}
}
}
related
}
pub fn validate_pin_capability(&self, pin_num: u8, capability: &str) -> Result<()> {
if !self.is_pin_capability_supported(pin_num, capability) {
return Err(PoKeysError::UnsupportedPinCapability(
pin_num,
capability.to_string(),
));
}
let related = self.get_related_capabilities(pin_num, capability);
if capability.starts_with("Encoder_") && capability.ends_with("A") {
let encoder_id = &capability[8..capability.len() - 1]; let encoder_b_capability = format!("Encoder_{}B", encoder_id);
let mut found_b = false;
for (related_cap, related_pin) in &related {
if related_cap == &encoder_b_capability {
found_b = true;
if let Some(pin_model) = self.pins.get(related_pin) {
if !pin_model.active {
return Err(PoKeysError::RelatedPinInactive(
*related_pin,
related_cap.clone(),
));
}
}
break;
}
}
if !found_b {
return Err(PoKeysError::MissingRelatedCapability(
pin_num,
capability.to_string(),
encoder_b_capability,
));
}
}
if capability.starts_with("Encoder_") && capability.ends_with("B") {
let encoder_id = &capability[8..capability.len() - 1]; let encoder_a_capability = format!("Encoder_{}A", encoder_id);
let mut found_a = false;
for (related_cap, related_pin) in &related {
if related_cap == &encoder_a_capability {
found_a = true;
if let Some(pin_model) = self.pins.get(related_pin) {
if !pin_model.active {
return Err(PoKeysError::RelatedPinInactive(
*related_pin,
related_cap.clone(),
));
}
}
break;
}
}
if !found_a {
return Err(PoKeysError::MissingRelatedCapability(
pin_num,
capability.to_string(),
encoder_a_capability,
));
}
}
if capability.starts_with("MatrixKeyboard_Row") {
let mut found_col = false;
for pin in self.pins.values() {
if pin.active
&& pin
.capabilities
.iter()
.any(|cap| cap.starts_with("MatrixKeyboard_Col"))
{
found_col = true;
break;
}
}
if !found_col {
return Err(PoKeysError::MissingRelatedCapability(
pin_num,
capability.to_string(),
"MatrixKeyboard_Col".to_string(),
));
}
}
if capability.starts_with("MatrixKeyboard_Col") {
let mut found_row = false;
for pin in self.pins.values() {
if pin.active
&& pin
.capabilities
.iter()
.any(|cap| cap.starts_with("MatrixKeyboard_Row"))
{
found_row = true;
break;
}
}
if !found_row {
return Err(PoKeysError::MissingRelatedCapability(
pin_num,
capability.to_string(),
"MatrixKeyboard_Row".to_string(),
));
}
}
Ok(())
}
pub fn validate_led_matrix_config(
&self,
config: &crate::matrix::LedMatrixConfig,
) -> Result<()> {
if config.matrix_id < 1 || config.matrix_id > 2 {
return Err(PoKeysError::ModelValidationError(format!(
"Invalid matrix ID: {}. Must be 1 or 2",
config.matrix_id
)));
}
let pins = match config.matrix_id {
1 => crate::matrix::LED_MATRIX_1_PINS,
2 => crate::matrix::LED_MATRIX_2_PINS,
_ => {
return Err(PoKeysError::ModelValidationError(format!(
"Invalid matrix ID: {}",
config.matrix_id
)));
}
};
for &pin in &pins {
if !self.is_pin_capability_supported(pin, "DigitalOutput") {
return Err(PoKeysError::ModelValidationError(format!(
"Pin {} does not support DigitalOutput capability required for LED matrix {}",
pin, config.matrix_id
)));
}
}
Ok(())
}
pub fn reserve_led_matrix_pins(&mut self, matrix_id: u8) -> Result<()> {
let pins = match matrix_id {
1 => crate::matrix::LED_MATRIX_1_PINS,
2 => crate::matrix::LED_MATRIX_2_PINS,
_ => {
return Err(PoKeysError::ModelValidationError(format!(
"Invalid matrix ID: {}",
matrix_id
)));
}
};
for &pin in &pins {
if !self.is_pin_capability_supported(pin, "DigitalOutput") {
return Err(PoKeysError::ModelValidationError(format!(
"Cannot reserve pin {} for LED matrix {}: pin does not support DigitalOutput",
pin, matrix_id
)));
}
}
Ok(())
}
}
/// Returns the per-user model directory: [`DEFAULT_MODEL_DIR`] under the home directory,
/// or under the current directory (`.`) when no home directory can be determined.
pub fn get_default_model_dir() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    base.join(DEFAULT_MODEL_DIR)
}
pub fn get_model_path(device_name: &str, model_dir: Option<&Path>) -> PathBuf {
let dir = model_dir
.map(Path::to_path_buf)
.unwrap_or_else(get_default_model_dir);
dir.join(format!("{}.yaml", device_name))
}
/// Loads and validates the model for `device_name` from `model_dir` (or the default
/// per-user model directory).
///
/// # Errors
///
/// Propagates any error from [`DeviceModel::from_file`].
pub fn load_model(device_name: &str, model_dir: Option<&Path>) -> Result<DeviceModel> {
    DeviceModel::from_file(get_model_path(device_name, model_dir))
}
/// Watches a model directory, keeps an in-memory cache of the device models found in
/// it, and invokes a callback whenever a model is (re)loaded.
pub struct ModelMonitor {
    /// File-system watcher; `None` until `start` succeeds and again after `stop`.
    watcher: Option<RecommendedWatcher>,
    /// Directory whose `.yaml` model files are loaded and watched.
    watch_dir: PathBuf,
    /// Loaded models keyed by device name (the file stem); shared with the watcher thread.
    models: Arc<RwLock<HashMap<String, DeviceModel>>>,
    /// Invoked with `(device_name, model)` after each successful load.
    callback: Arc<dyn Fn(String, DeviceModel) + Send + Sync + 'static>,
    /// Whether `start` has succeeded and `stop` has not been called since.
    running: bool,
}
impl ModelMonitor {
pub fn new<F>(model_dir: PathBuf, callback: F) -> Self
where
F: Fn(String, DeviceModel) + Send + Sync + 'static,
{
Self {
watcher: None,
watch_dir: model_dir,
models: Arc::new(RwLock::new(HashMap::new())),
callback: Arc::new(callback),
running: false,
}
}
pub fn start(&mut self) -> Result<()> {
if self.running {
return Ok(());
}
if !self.watch_dir.exists() {
fs::create_dir_all(&self.watch_dir).map_err(|e| {
PoKeysError::ModelDirCreateError(
self.watch_dir.to_string_lossy().to_string(),
e.to_string(),
)
})?;
}
self.load_existing_models()?;
let (tx, rx) = std::sync::mpsc::channel();
let mut watcher = notify::recommended_watcher(tx)
.map_err(|e| PoKeysError::ModelWatcherError(e.to_string()))?;
watcher
.watch(&self.watch_dir, RecursiveMode::NonRecursive)
.map_err(|e| PoKeysError::ModelWatcherError(e.to_string()))?;
self.watcher = Some(watcher);
self.running = true;
let models = self.models.clone();
let callback = self.callback.clone();
std::thread::spawn(move || {
let mut debouncer = HashMap::new();
for res in rx {
match res {
Ok(event) => {
if let EventKind::Modify(_) | EventKind::Create(_) = event.kind {
for path in event.paths {
if path.extension().is_some_and(|ext| ext == "yaml") {
let now = std::time::Instant::now();
let path_str = path.to_string_lossy().to_string();
if debouncer.get(&path_str).is_none_or(|last| {
now.duration_since(*last) > Duration::from_millis(100)
}) {
debouncer.insert(path_str, now);
if let Some(file_name) = path.file_stem() {
let device_name =
file_name.to_string_lossy().to_string();
match DeviceModel::from_file(&path) {
Ok(model) => {
{
let mut models = models.write().unwrap();
models.insert(
device_name.clone(),
model.clone(),
);
}
callback(device_name, model);
}
Err(e) => {
error!(
"Failed to load model from {}: {}",
path.display(),
e
);
}
}
}
}
}
}
}
}
Err(e) => {
error!("Watch error: {:?}", e);
}
}
}
});
info!(
"Model monitor started, watching directory: {}",
self.watch_dir.display()
);
Ok(())
}
pub fn stop(&mut self) -> Result<()> {
if !self.running {
return Ok(());
}
self.watcher = None;
self.running = false;
info!("Model monitor stopped");
Ok(())
}
fn load_existing_models(&self) -> Result<()> {
if !self.watch_dir.exists() {
return Ok(());
}
let entries = fs::read_dir(&self.watch_dir).map_err(|e| {
PoKeysError::ModelDirReadError(
self.watch_dir.to_string_lossy().to_string(),
e.to_string(),
)
})?;
for entry in entries {
let entry = entry.map_err(|e| {
PoKeysError::ModelDirReadError(
self.watch_dir.to_string_lossy().to_string(),
e.to_string(),
)
})?;
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "yaml") {
if let Some(file_name) = path.file_stem() {
let device_name = file_name.to_string_lossy().to_string();
match DeviceModel::from_file(&path) {
Ok(model) => {
{
let mut models = self.models.write().unwrap();
models.insert(device_name.clone(), model.clone());
}
(self.callback)(device_name, model);
}
Err(e) => {
warn!("Failed to load model from {}: {}", path.display(), e);
}
}
}
}
}
Ok(())
}
pub fn get_model(&self, device_name: &str) -> Option<DeviceModel> {
let models = self.models.read().unwrap();
models.get(device_name).cloned()
}
pub fn get_all_models(&self) -> HashMap<String, DeviceModel> {
let models = self.models.read().unwrap();
models.clone()
}
}
/// Copies the bundled default model files into the user model directory (`model_dir`,
/// or the default per-user directory), creating the directory if needed. Files that
/// already exist in the destination are never overwritten.
///
/// Each bundled file is searched for in this order, and the first existing copy is used:
/// 1. `<executable dir>/models`
/// 2. `$CARGO_MANIFEST_DIR/models` (only when the variable is set at runtime)
/// 3. `pokeys-lib/models`, relative to the current directory
///
/// Files not found anywhere are logged and skipped.
///
/// # Errors
///
/// * [`PoKeysError::ModelDirCreateError`] if the destination cannot be created.
/// * [`PoKeysError::ModelDirReadError`] if the executable's path cannot be determined.
/// * [`PoKeysError::ModelLoadError`] if copying a file fails.
pub fn copy_default_models_to_user_dir(model_dir: Option<&Path>) -> Result<()> {
    let dir = model_dir
        .map(Path::to_path_buf)
        .unwrap_or_else(get_default_model_dir);
    if !dir.exists() {
        fs::create_dir_all(&dir).map_err(|e| {
            PoKeysError::ModelDirCreateError(dir.to_string_lossy().to_string(), e.to_string())
        })?;
    }
    let default_models = [
        "PoKeys56U.yaml",
        "PoKeys57U.yaml",
        "PoKeys56E.yaml",
        "PoKeys57E.yaml",
    ];
    let exe_model_dir = std::env::current_exe()
        .map_err(|e| {
            PoKeysError::ModelDirReadError(
                "Failed to get current executable path".to_string(),
                e.to_string(),
            )
        })?
        .parent()
        .ok_or_else(|| {
            PoKeysError::ModelDirReadError(
                "Failed to get parent directory of executable".to_string(),
                "No parent directory".to_string(),
            )
        })?
        .join("models");

    // Candidate source directories in priority order. Every candidate already ends in
    // `models` (the previous fallback joined "models" twice: pokeys-lib/models/models).
    let mut search_dirs = vec![exe_model_dir];
    if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
        search_dirs.push(PathBuf::from(manifest_dir).join("models"));
    }
    search_dirs.push(PathBuf::from("pokeys-lib/models"));

    for model_file in default_models {
        let user_file_path = dir.join(model_file);
        if user_file_path.exists() {
            continue;
        }
        let Some(source_file_path) = search_dirs
            .iter()
            .map(|search_dir| search_dir.join(model_file))
            .find(|candidate| candidate.exists())
        else {
            warn!("Default model file {} not found", model_file);
            continue;
        };
        fs::copy(&source_file_path, &user_file_path).map_err(|e| {
            PoKeysError::ModelLoadError(
                source_file_path.to_string_lossy().to_string(),
                e.to_string(),
            )
        })?;
        info!(
            "Copied default model file {} to {}",
            model_file,
            user_file_path.display()
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_device_model_validation() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );
        assert!(model.validate().is_ok());
        let mut invalid_model = model.clone();
        invalid_model.name = "".to_string();
        assert!(invalid_model.validate().is_err());
        let invalid_model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        assert!(invalid_model.validate().is_err());
        let mut invalid_model = model.clone();
        invalid_model.pins.insert(
            3,
            PinModel {
                capabilities: vec![],
                active: true,
            },
        );
        assert!(invalid_model.validate().is_err());
    }

    #[test]
    fn test_related_capabilities() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1A".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1B".to_string()],
                active: true,
            },
        );
        assert!(model.validate().is_ok());
        let related = model.get_related_capabilities(1, "Encoder_1A");
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].0, "Encoder_1B");
        assert_eq!(related[0].1, 2);
        let mut invalid_model = model.clone();
        invalid_model.pins.get_mut(&2).unwrap().capabilities = vec!["DigitalInput".to_string()];
        // Encoder_1A without a matching Encoder_1B must be rejected.
        assert!(invalid_model.validate().is_err());
    }

    #[test]
    fn test_validate_pin_capability_encoder() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1A".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1B".to_string()],
                active: true,
            },
        );
        assert!(model.validate_pin_capability(1, "Encoder_1A").is_ok());
        assert!(model.validate_pin_capability(2, "Encoder_1B").is_ok());
        assert!(matches!(
            model.validate_pin_capability(1, "DigitalOutput"),
            Err(PoKeysError::UnsupportedPinCapability(1, _))
        ));
        model.pins.get_mut(&2).unwrap().active = false;
        assert!(matches!(
            model.validate_pin_capability(1, "Encoder_1A"),
            Err(PoKeysError::RelatedPinInactive(2, _))
        ));
    }

    #[test]
    fn test_matrix_keyboard_validation() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Row1".to_string(),
                ],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Row2".to_string(),
                ],
                active: true,
            },
        );
        model.pins.insert(
            3,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Col1".to_string(),
                ],
                active: true,
            },
        );
        model.pins.insert(
            4,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Col2".to_string(),
                ],
                active: true,
            },
        );
        assert!(model.validate().is_ok());
        let related = model.get_related_capabilities(1, "MatrixKeyboard_Row1");
        assert_eq!(related.len(), 2);
        assert!(
            related
                .iter()
                .any(|(cap, pin)| cap == "MatrixKeyboard_Col1" && *pin == 3)
        );
        assert!(
            related
                .iter()
                .any(|(cap, pin)| cap == "MatrixKeyboard_Col2" && *pin == 4)
        );
        let mut invalid_model = model.clone();
        invalid_model.pins.remove(&3);
        invalid_model.pins.remove(&4);
        // Rows without any columns must be rejected.
        assert!(invalid_model.validate().is_err());
    }

    #[test]
    fn test_pwm_channels_must_be_sequential() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["PWM_1".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["PWM_3".to_string()],
                active: true,
            },
        );
        assert!(model.validate().is_err());
        model.pins.get_mut(&2).unwrap().capabilities = vec!["PWM_2".to_string()];
        assert!(model.validate().is_ok());
    }

    #[test]
    fn test_yaml_serialization() {
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );
        let yaml = serde_yaml::to_string(&model).unwrap();
        let deserialized: DeviceModel = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(model.name, deserialized.name);
        assert_eq!(model.pins.len(), deserialized.pins.len());
        for (pin_num, pin) in &model.pins {
            let deserialized_pin = deserialized.pins.get(pin_num).unwrap();
            assert_eq!(pin.capabilities, deserialized_pin.capabilities);
            assert_eq!(pin.active, deserialized_pin.active);
        }
    }

    #[test]
    fn test_model_file_loading() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("TestDevice.yaml");
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );
        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );
        let yaml = serde_yaml::to_string(&model).unwrap();
        fs::write(&file_path, yaml).unwrap();
        let loaded_model = DeviceModel::from_file(&file_path).unwrap();
        assert_eq!(model.name, loaded_model.name);
        assert_eq!(model.pins.len(), loaded_model.pins.len());
        for (pin_num, pin) in &model.pins {
            let loaded_pin = loaded_model.pins.get(pin_num).unwrap();
            assert_eq!(pin.capabilities, loaded_pin.capabilities);
            assert_eq!(pin.active, loaded_pin.active);
        }
    }
}