pokeys_lib/
models.rs

1//! Device model definitions and validation
2//!
3//! This module provides structures and functions for loading and validating
4//! device models from YAML files. Device models define the capabilities of
5//! each pin on a PoKeys device, ensuring that users can only assign supported
6//! functions to pins.
7
8use crate::error::{PoKeysError, Result};
9use log::{error, info, warn};
10use notify::{EventKind, RecommendedWatcher, RecursiveMode, Watcher};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, RwLock};
16use std::time::Duration;
17
/// Location of the device model files, relative to the user's home directory.
///
/// [`get_default_model_dir`] joins this onto the home directory.
pub const DEFAULT_MODEL_DIR: &str = ".config/pokeys/models";

/// Number of seconds to wait between attempts to load a model.
pub const DEFAULT_RETRY_INTERVAL: u64 = 10;
23
24/// Pin model defining the capabilities of a single pin
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
26pub struct PinModel {
27    /// List of capabilities supported by this pin
28    pub capabilities: Vec<String>,
29
30    /// Whether the pin is active
31    #[serde(default = "default_active")]
32    pub active: bool,
33}
34
35/// Device model defining the capabilities of all pins on a device
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub struct DeviceModel {
38    /// Device model name
39    pub name: String,
40
41    /// Map of pin numbers to pin models
42    pub pins: HashMap<u8, PinModel>,
43}
44
/// Serde default for [`PinModel::active`]: pins are considered active unless
/// the model file says otherwise.
fn default_active() -> bool {
    let active_unless_stated = true;
    active_unless_stated
}
49
50impl DeviceModel {
51    /// Load a device model from a YAML file
52    ///
53    /// # Arguments
54    ///
55    /// * `path` - Path to the YAML file
56    ///
57    /// # Returns
58    ///
59    /// * `Result<DeviceModel>` - The loaded device model or an error
60    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
61        let content = fs::read_to_string(path.as_ref()).map_err(|e| {
62            PoKeysError::ModelLoadError(path.as_ref().to_string_lossy().to_string(), e.to_string())
63        })?;
64
65        let model: DeviceModel = serde_yaml::from_str(&content).map_err(|e| {
66            PoKeysError::ModelParseError(path.as_ref().to_string_lossy().to_string(), e.to_string())
67        })?;
68
69        model.validate()?;
70
71        Ok(model)
72    }
73
74    /// Validate the device model
75    ///
76    /// Checks that the model is well-formed and that all related capabilities
77    /// are properly defined.
78    ///
79    /// # Returns
80    ///
81    /// * `Result<()>` - Ok if the model is valid, an error otherwise
82    pub fn validate(&self) -> Result<()> {
83        // Check that the model has a name
84        if self.name.is_empty() {
85            return Err(PoKeysError::ModelValidationError(
86                "Model name cannot be empty".to_string(),
87            ));
88        }
89
90        // Check that the model has at least one pin
91        if self.pins.is_empty() {
92            return Err(PoKeysError::ModelValidationError(
93                "Model must define at least one pin".to_string(),
94            ));
95        }
96
97        // Check that all pins have at least one capability
98        for (pin_num, pin) in &self.pins {
99            if pin.capabilities.is_empty() {
100                return Err(PoKeysError::ModelValidationError(format!(
101                    "Pin {} must have at least one capability",
102                    pin_num
103                )));
104            }
105        }
106
107        // Validate related capabilities
108        self.validate_related_capabilities()?;
109
110        Ok(())
111    }
112
113    /// Validate that all related capabilities are properly defined
114    ///
115    /// For example, if a pin has the capability "Encoder_1A", there must be
116    /// another pin with the capability "Encoder_1B".
117    ///
118    /// # Returns
119    ///
120    /// * `Result<()>` - Ok if all related capabilities are valid, an error otherwise
121    fn validate_related_capabilities(&self) -> Result<()> {
122        // Validate encoder pairs
123        self.validate_encoder_pairs()?;
124
125        // Validate matrix keyboard rows and columns
126        self.validate_matrix_keyboard()?;
127
128        // Validate PWM channels
129        self.validate_pwm_channels()?;
130
131        Ok(())
132    }
133
134    /// Validate encoder pairs
135    ///
136    /// Checks that for each encoder A pin, there is a corresponding encoder B pin.
137    ///
138    /// # Returns
139    ///
140    /// * `Result<()>` - Ok if all encoder pairs are valid, an error otherwise
141    fn validate_encoder_pairs(&self) -> Result<()> {
142        // Find all encoder A pins
143        let mut encoder_a_pins = HashMap::new();
144        for (pin_num, pin) in &self.pins {
145            for capability in &pin.capabilities {
146                if capability.starts_with("Encoder_") && capability.ends_with("A") {
147                    let encoder_id = &capability[8..capability.len() - 1]; // Extract "1" from "Encoder_1A"
148                    encoder_a_pins.insert(encoder_id.to_string(), *pin_num);
149                }
150            }
151        }
152
153        // Check that each encoder A pin has a corresponding encoder B pin
154        for (encoder_id, pin_a) in &encoder_a_pins {
155            let encoder_b_capability = format!("Encoder_{}B", encoder_id);
156            let mut found_b = false;
157
158            for (pin_num, pin) in &self.pins {
159                if *pin_num != *pin_a
160                    && pin
161                        .capabilities
162                        .iter()
163                        .any(|cap| cap == &encoder_b_capability)
164                {
165                    found_b = true;
166                    break;
167                }
168            }
169
170            if !found_b {
171                return Err(PoKeysError::ModelValidationError(format!(
172                    "Encoder {}A on pin {} has no corresponding {}B pin",
173                    encoder_id, pin_a, encoder_b_capability
174                )));
175            }
176        }
177
178        Ok(())
179    }
180
181    /// Validate matrix keyboard rows and columns
182    ///
183    /// Checks that if there are matrix keyboard rows, there are also matrix keyboard columns.
184    ///
185    /// # Returns
186    ///
187    /// * `Result<()>` - Ok if the matrix keyboard configuration is valid, an error otherwise
188    fn validate_matrix_keyboard(&self) -> Result<()> {
189        let mut has_rows = false;
190        let mut has_columns = false;
191
192        for pin in self.pins.values() {
193            for capability in &pin.capabilities {
194                if capability.starts_with("MatrixKeyboard_Row") {
195                    has_rows = true;
196                } else if capability.starts_with("MatrixKeyboard_Col") {
197                    has_columns = true;
198                }
199            }
200        }
201
202        if has_rows && !has_columns {
203            return Err(PoKeysError::ModelValidationError(
204                "Matrix keyboard has rows but no columns".to_string(),
205            ));
206        }
207
208        if !has_rows && has_columns {
209            return Err(PoKeysError::ModelValidationError(
210                "Matrix keyboard has columns but no rows".to_string(),
211            ));
212        }
213
214        Ok(())
215    }
216
217    /// Validate PWM channels
218    ///
219    /// Checks that PWM channels are properly defined.
220    ///
221    /// # Returns
222    ///
223    /// * `Result<()>` - Ok if the PWM channels are valid, an error otherwise
224    fn validate_pwm_channels(&self) -> Result<()> {
225        // For now, just check that PWM channels are numbered sequentially
226        let mut pwm_channels = Vec::new();
227
228        for (pin_num, pin) in &self.pins {
229            for capability in &pin.capabilities {
230                if let Some(stripped) = capability.strip_prefix("PWM_") {
231                    if let Ok(channel) = stripped.parse::<u32>() {
232                        pwm_channels.push((channel, *pin_num));
233                    }
234                }
235            }
236        }
237
238        // Sort by channel number
239        pwm_channels.sort_by_key(|(channel, _)| *channel);
240
241        // Check that channels are sequential starting from 1
242        for (i, (channel, pin)) in pwm_channels.iter().enumerate() {
243            if *channel != (i + 1) as u32 {
244                return Err(PoKeysError::ModelValidationError(format!(
245                    "PWM channels are not sequential: expected channel {}, found channel {} on pin {}",
246                    i + 1,
247                    channel,
248                    pin
249                )));
250            }
251        }
252
253        Ok(())
254    }
255
256    /// Check if a pin supports a specific capability
257    ///
258    /// # Arguments
259    ///
260    /// * `pin_num` - The pin number to check
261    /// * `capability` - The capability to check for
262    ///
263    /// # Returns
264    ///
265    /// * `bool` - True if the pin supports the capability, false otherwise
266    pub fn is_pin_capability_supported(&self, pin_num: u8, capability: &str) -> bool {
267        if let Some(pin) = self.pins.get(&pin_num) {
268            pin.capabilities.iter().any(|cap| cap == capability)
269        } else {
270            false
271        }
272    }
273
274    /// Get all capabilities for a pin
275    ///
276    /// # Arguments
277    ///
278    /// * `pin_num` - The pin number to get capabilities for
279    ///
280    /// # Returns
281    ///
282    /// * `Vec<String>` - List of capabilities supported by the pin
283    pub fn get_pin_capabilities(&self, pin_num: u8) -> Vec<String> {
284        if let Some(pin) = self.pins.get(&pin_num) {
285            pin.capabilities.clone()
286        } else {
287            Vec::new()
288        }
289    }
290
291    /// Get related capabilities for a specific capability
292    ///
293    /// For example, if the capability is "Encoder_1A", this will return
294    /// [("Encoder_1B", pin_num)] where pin_num is the pin that has the "Encoder_1B" capability.
295    ///
296    /// # Arguments
297    ///
298    /// * `pin_num` - The pin number with the capability
299    /// * `capability` - The capability to find related capabilities for
300    ///
301    /// # Returns
302    ///
303    /// * `Vec<(String, u8)>` - List of related capabilities and their pin numbers
304    pub fn get_related_capabilities(&self, pin_num: u8, capability: &str) -> Vec<(String, u8)> {
305        let mut related = Vec::new();
306
307        // Check for encoder capabilities
308        if capability.starts_with("Encoder_") && capability.len() >= 10 {
309            let encoder_id = &capability[8..capability.len() - 1]; // Extract "1" from "Encoder_1A"
310            let role = &capability[capability.len() - 1..]; // Extract "A" from "Encoder_1A"
311
312            let related_role = if role == "A" { "B" } else { "A" };
313            let related_capability = format!("Encoder_{}{}", encoder_id, related_role);
314
315            // Find the pin with the related capability
316            for (other_pin, pin_model) in &self.pins {
317                if *other_pin != pin_num
318                    && pin_model
319                        .capabilities
320                        .iter()
321                        .any(|cap| cap == &related_capability)
322                {
323                    related.push((related_capability, *other_pin));
324                    break;
325                }
326            }
327        }
328
329        // Check for matrix keyboard capabilities
330        if capability.starts_with("MatrixKeyboard_Row") {
331            // For a row, all columns are related
332            for (other_pin, pin_model) in &self.pins {
333                if *other_pin != pin_num {
334                    for cap in &pin_model.capabilities {
335                        if cap.starts_with("MatrixKeyboard_Col") {
336                            related.push((cap.clone(), *other_pin));
337                        }
338                    }
339                }
340            }
341        }
342
343        if capability.starts_with("MatrixKeyboard_Col") {
344            // For a column, all rows are related
345            for (other_pin, pin_model) in &self.pins {
346                if *other_pin != pin_num {
347                    for cap in &pin_model.capabilities {
348                        if cap.starts_with("MatrixKeyboard_Row") {
349                            related.push((cap.clone(), *other_pin));
350                        }
351                    }
352                }
353            }
354        }
355
356        related
357    }
358
359    /// Validate that a pin can be configured with a specific capability
360    ///
361    /// This checks both that the pin supports the capability and that any
362    /// related capabilities are properly configured.
363    ///
364    /// # Arguments
365    ///
366    /// * `pin_num` - The pin number to check
367    /// * `capability` - The capability to check for
368    ///
369    /// # Returns
370    ///
371    /// * `Result<()>` - Ok if the capability is valid, an error otherwise
372    pub fn validate_pin_capability(&self, pin_num: u8, capability: &str) -> Result<()> {
373        // Check that the pin exists and supports the capability
374        if !self.is_pin_capability_supported(pin_num, capability) {
375            return Err(PoKeysError::UnsupportedPinCapability(
376                pin_num,
377                capability.to_string(),
378            ));
379        }
380
381        // Check related capabilities
382        let related = self.get_related_capabilities(pin_num, capability);
383
384        // For encoder capabilities, ensure the related pin is configured
385        if capability.starts_with("Encoder_") && capability.ends_with("A") {
386            let encoder_id = &capability[8..capability.len() - 1]; // Extract "1" from "Encoder_1A"
387            let encoder_b_capability = format!("Encoder_{}B", encoder_id);
388
389            // Check if any pin has the B capability
390            let mut found_b = false;
391            for (related_cap, related_pin) in &related {
392                if related_cap == &encoder_b_capability {
393                    found_b = true;
394
395                    // Check if the related pin is active
396                    if let Some(pin_model) = self.pins.get(related_pin) {
397                        if !pin_model.active {
398                            return Err(PoKeysError::RelatedPinInactive(
399                                *related_pin,
400                                related_cap.clone(),
401                            ));
402                        }
403                    }
404
405                    break;
406                }
407            }
408
409            if !found_b {
410                return Err(PoKeysError::MissingRelatedCapability(
411                    pin_num,
412                    capability.to_string(),
413                    encoder_b_capability,
414                ));
415            }
416        }
417
418        // For encoder B capabilities, ensure the related A pin is configured
419        if capability.starts_with("Encoder_") && capability.ends_with("B") {
420            let encoder_id = &capability[8..capability.len() - 1]; // Extract "1" from "Encoder_1B"
421            let encoder_a_capability = format!("Encoder_{}A", encoder_id);
422
423            // Check if any pin has the A capability
424            let mut found_a = false;
425            for (related_cap, related_pin) in &related {
426                if related_cap == &encoder_a_capability {
427                    found_a = true;
428
429                    // Check if the related pin is active
430                    if let Some(pin_model) = self.pins.get(related_pin) {
431                        if !pin_model.active {
432                            return Err(PoKeysError::RelatedPinInactive(
433                                *related_pin,
434                                related_cap.clone(),
435                            ));
436                        }
437                    }
438
439                    break;
440                }
441            }
442
443            if !found_a {
444                return Err(PoKeysError::MissingRelatedCapability(
445                    pin_num,
446                    capability.to_string(),
447                    encoder_a_capability,
448                ));
449            }
450        }
451
452        // For matrix keyboard rows, ensure there's at least one column
453        if capability.starts_with("MatrixKeyboard_Row") {
454            let mut found_col = false;
455            for pin in self.pins.values() {
456                if pin.active
457                    && pin
458                        .capabilities
459                        .iter()
460                        .any(|cap| cap.starts_with("MatrixKeyboard_Col"))
461                {
462                    found_col = true;
463                    break;
464                }
465            }
466
467            if !found_col {
468                return Err(PoKeysError::MissingRelatedCapability(
469                    pin_num,
470                    capability.to_string(),
471                    "MatrixKeyboard_Col".to_string(),
472                ));
473            }
474        }
475
476        // For matrix keyboard columns, ensure there's at least one row
477        if capability.starts_with("MatrixKeyboard_Col") {
478            let mut found_row = false;
479            for pin in self.pins.values() {
480                if pin.active
481                    && pin
482                        .capabilities
483                        .iter()
484                        .any(|cap| cap.starts_with("MatrixKeyboard_Row"))
485                {
486                    found_row = true;
487                    break;
488                }
489            }
490
491            if !found_row {
492                return Err(PoKeysError::MissingRelatedCapability(
493                    pin_num,
494                    capability.to_string(),
495                    "MatrixKeyboard_Row".to_string(),
496                ));
497            }
498        }
499
500        Ok(())
501    }
502
503    /// Validate LED matrix configuration against device model
504    ///
505    /// # Arguments
506    ///
507    /// * `config` - LED matrix configuration to validate
508    ///
509    /// # Returns
510    ///
511    /// * `Result<()>` - Ok if valid, error if invalid
512    pub fn validate_led_matrix_config(
513        &self,
514        config: &crate::matrix::LedMatrixConfig,
515    ) -> Result<()> {
516        // Check if matrix ID is valid (1 or 2)
517        if config.matrix_id < 1 || config.matrix_id > 2 {
518            return Err(PoKeysError::ModelValidationError(format!(
519                "Invalid matrix ID: {}. Must be 1 or 2",
520                config.matrix_id
521            )));
522        }
523
524        // Get the pins for this matrix
525        let pins = match config.matrix_id {
526            1 => crate::matrix::LED_MATRIX_1_PINS,
527            2 => crate::matrix::LED_MATRIX_2_PINS,
528            _ => {
529                return Err(PoKeysError::ModelValidationError(format!(
530                    "Invalid matrix ID: {}",
531                    config.matrix_id
532                )));
533            }
534        };
535
536        // Validate that all required pins support the necessary capabilities
537        for &pin in &pins {
538            if !self.is_pin_capability_supported(pin, "DigitalOutput") {
539                return Err(PoKeysError::ModelValidationError(format!(
540                    "Pin {} does not support DigitalOutput capability required for LED matrix {}",
541                    pin, config.matrix_id
542                )));
543            }
544        }
545
546        Ok(())
547    }
548
549    /// Reserve LED matrix pins in the device model
550    ///
551    /// # Arguments
552    ///
553    /// * `matrix_id` - Matrix ID (1 or 2)
554    ///
555    /// # Returns
556    ///
557    /// * `Result<()>` - Ok if pins reserved successfully, error if conflict
558    pub fn reserve_led_matrix_pins(&mut self, matrix_id: u8) -> Result<()> {
559        // Get the pins for this matrix
560        let pins = match matrix_id {
561            1 => crate::matrix::LED_MATRIX_1_PINS,
562            2 => crate::matrix::LED_MATRIX_2_PINS,
563            _ => {
564                return Err(PoKeysError::ModelValidationError(format!(
565                    "Invalid matrix ID: {}",
566                    matrix_id
567                )));
568            }
569        };
570
571        // For now, just validate that the pins exist and support the capability
572        // In a full implementation, this would track pin reservations
573        for &pin in &pins {
574            if !self.is_pin_capability_supported(pin, "DigitalOutput") {
575                return Err(PoKeysError::ModelValidationError(format!(
576                    "Cannot reserve pin {} for LED matrix {}: pin does not support DigitalOutput",
577                    pin, matrix_id
578                )));
579            }
580        }
581
582        Ok(())
583    }
584}
585
586/// Get the default model directory path
587///
588/// This returns the default directory for device model files, which is
589/// ~/.config/pokeys/models on Unix systems and %APPDATA%\pokeys\models on Windows.
590///
591/// # Returns
592///
593/// * `PathBuf` - The default model directory path
594pub fn get_default_model_dir() -> PathBuf {
595    let mut path = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
596    path.push(DEFAULT_MODEL_DIR);
597    path
598}
599
600/// Get the path to a device model file
601///
602/// # Arguments
603///
604/// * `device_name` - The name of the device model
605/// * `model_dir` - Optional custom directory for model files
606///
607/// # Returns
608///
609/// * `PathBuf` - The path to the model file
610pub fn get_model_path(device_name: &str, model_dir: Option<&Path>) -> PathBuf {
611    let dir = model_dir
612        .map(Path::to_path_buf)
613        .unwrap_or_else(get_default_model_dir);
614    dir.join(format!("{}.yaml", device_name))
615}
616
617/// Load a device model by name
618///
619/// # Arguments
620///
621/// * `device_name` - The name of the device model
622/// * `model_dir` - Optional custom directory for model files
623///
624/// # Returns
625///
626/// * `Result<DeviceModel>` - The loaded device model or an error
627pub fn load_model(device_name: &str, model_dir: Option<&Path>) -> Result<DeviceModel> {
628    let path = get_model_path(device_name, model_dir);
629    DeviceModel::from_file(path)
630}
631
632/// Model monitor for watching for changes to model files
633pub struct ModelMonitor {
634    /// The watcher for file system events
635    watcher: Option<RecommendedWatcher>,
636
637    /// The directory being watched
638    watch_dir: PathBuf,
639
640    /// Loaded models
641    models: Arc<RwLock<HashMap<String, DeviceModel>>>,
642
643    /// Callback for model updates
644    callback: Arc<dyn Fn(String, DeviceModel) + Send + Sync + 'static>,
645
646    /// Whether the monitor is running
647    running: bool,
648}
649
650impl ModelMonitor {
651    /// Create a new model monitor
652    ///
653    /// # Arguments
654    ///
655    /// * `model_dir` - Directory containing model files
656    /// * `callback` - Callback function called when a model is updated
657    ///
658    /// # Returns
659    ///
660    /// * `ModelMonitor` - The new model monitor
661    pub fn new<F>(model_dir: PathBuf, callback: F) -> Self
662    where
663        F: Fn(String, DeviceModel) + Send + Sync + 'static,
664    {
665        Self {
666            watcher: None,
667            watch_dir: model_dir,
668            models: Arc::new(RwLock::new(HashMap::new())),
669            callback: Arc::new(callback),
670            running: false,
671        }
672    }
673
674    /// Start monitoring for model file changes
675    ///
676    /// # Returns
677    ///
678    /// * `Result<()>` - Ok if monitoring started successfully, an error otherwise
679    pub fn start(&mut self) -> Result<()> {
680        if self.running {
681            return Ok(());
682        }
683
684        // Create the model directory if it doesn't exist
685        if !self.watch_dir.exists() {
686            fs::create_dir_all(&self.watch_dir).map_err(|e| {
687                PoKeysError::ModelDirCreateError(
688                    self.watch_dir.to_string_lossy().to_string(),
689                    e.to_string(),
690                )
691            })?;
692        }
693
694        // Load existing models
695        self.load_existing_models()?;
696
697        // Create a channel for the watcher
698        let (tx, rx) = std::sync::mpsc::channel();
699
700        // Create the watcher
701        let mut watcher = notify::recommended_watcher(tx)
702            .map_err(|e| PoKeysError::ModelWatcherError(e.to_string()))?;
703
704        // Start watching the directory
705        watcher
706            .watch(&self.watch_dir, RecursiveMode::NonRecursive)
707            .map_err(|e| PoKeysError::ModelWatcherError(e.to_string()))?;
708
709        self.watcher = Some(watcher);
710        self.running = true;
711
712        // Clone the models and callback for the thread
713        let models = self.models.clone();
714        let callback = self.callback.clone();
715
716        // Spawn a thread to handle file system events
717        std::thread::spawn(move || {
718            let mut debouncer = HashMap::new();
719
720            for res in rx {
721                match res {
722                    Ok(event) => {
723                        if let EventKind::Modify(_) | EventKind::Create(_) = event.kind {
724                            for path in event.paths {
725                                if path.extension().is_some_and(|ext| ext == "yaml") {
726                                    // Debounce the event
727                                    let now = std::time::Instant::now();
728                                    let path_str = path.to_string_lossy().to_string();
729
730                                    // Only process the event if it's been at least 100ms since the last event for this file
731                                    if debouncer.get(&path_str).is_none_or(|last| {
732                                        now.duration_since(*last) > Duration::from_millis(100)
733                                    }) {
734                                        debouncer.insert(path_str, now);
735
736                                        // Get the device name from the file name
737                                        if let Some(file_name) = path.file_stem() {
738                                            let device_name =
739                                                file_name.to_string_lossy().to_string();
740
741                                            // Try to load the model
742                                            match DeviceModel::from_file(&path) {
743                                                Ok(model) => {
744                                                    // Update the model in the map
745                                                    {
746                                                        let mut models = models.write().unwrap();
747                                                        models.insert(
748                                                            device_name.clone(),
749                                                            model.clone(),
750                                                        );
751                                                    }
752
753                                                    // Call the callback
754                                                    callback(device_name, model);
755                                                }
756                                                Err(e) => {
757                                                    error!(
758                                                        "Failed to load model from {}: {}",
759                                                        path.display(),
760                                                        e
761                                                    );
762                                                }
763                                            }
764                                        }
765                                    }
766                                }
767                            }
768                        }
769                    }
770                    Err(e) => {
771                        error!("Watch error: {:?}", e);
772                    }
773                }
774            }
775        });
776
777        info!(
778            "Model monitor started, watching directory: {}",
779            self.watch_dir.display()
780        );
781        Ok(())
782    }
783
784    /// Stop monitoring for model file changes
785    ///
786    /// # Returns
787    ///
788    /// * `Result<()>` - Ok if monitoring stopped successfully, an error otherwise
789    pub fn stop(&mut self) -> Result<()> {
790        if !self.running {
791            return Ok(());
792        }
793
794        self.watcher = None;
795        self.running = false;
796
797        info!("Model monitor stopped");
798        Ok(())
799    }
800
801    /// Load existing model files from the watch directory
802    ///
803    /// # Returns
804    ///
805    /// * `Result<()>` - Ok if models loaded successfully, an error otherwise
806    fn load_existing_models(&self) -> Result<()> {
807        if !self.watch_dir.exists() {
808            return Ok(());
809        }
810
811        let entries = fs::read_dir(&self.watch_dir).map_err(|e| {
812            PoKeysError::ModelDirReadError(
813                self.watch_dir.to_string_lossy().to_string(),
814                e.to_string(),
815            )
816        })?;
817
818        for entry in entries {
819            let entry = entry.map_err(|e| {
820                PoKeysError::ModelDirReadError(
821                    self.watch_dir.to_string_lossy().to_string(),
822                    e.to_string(),
823                )
824            })?;
825            let path = entry.path();
826
827            if path.extension().is_some_and(|ext| ext == "yaml") {
828                if let Some(file_name) = path.file_stem() {
829                    let device_name = file_name.to_string_lossy().to_string();
830
831                    match DeviceModel::from_file(&path) {
832                        Ok(model) => {
833                            // Update the model in the map
834                            {
835                                let mut models = self.models.write().unwrap();
836                                models.insert(device_name.clone(), model.clone());
837                            }
838
839                            // Call the callback
840                            (self.callback)(device_name, model);
841                        }
842                        Err(e) => {
843                            warn!("Failed to load model from {}: {}", path.display(), e);
844                        }
845                    }
846                }
847            }
848        }
849
850        Ok(())
851    }
852
853    /// Get a model by name
854    ///
855    /// # Arguments
856    ///
857    /// * `device_name` - The name of the device model
858    ///
859    /// # Returns
860    ///
861    /// * `Option<DeviceModel>` - The model if found, None otherwise
862    pub fn get_model(&self, device_name: &str) -> Option<DeviceModel> {
863        let models = self.models.read().unwrap();
864        models.get(device_name).cloned()
865    }
866
867    /// Get all loaded models
868    ///
869    /// # Returns
870    ///
871    /// * `HashMap<String, DeviceModel>` - Map of device names to models
872    pub fn get_all_models(&self) -> HashMap<String, DeviceModel> {
873        let models = self.models.read().unwrap();
874        models.clone()
875    }
876}
877
878/// Copy default model files to the user's model directory
879///
880/// This function copies the default model files from the package to the user's
881/// model directory if they don't already exist.
882///
883/// # Arguments
884///
885/// * `model_dir` - Optional custom directory for model files
886///
887/// # Returns
888///
889/// * `Result<()>` - Ok if the files were copied successfully, an error otherwise
890pub fn copy_default_models_to_user_dir(model_dir: Option<&Path>) -> Result<()> {
891    let dir = model_dir
892        .map(Path::to_path_buf)
893        .unwrap_or_else(get_default_model_dir);
894
895    // Create the directory if it doesn't exist
896    if !dir.exists() {
897        fs::create_dir_all(&dir).map_err(|e| {
898            PoKeysError::ModelDirCreateError(dir.to_string_lossy().to_string(), e.to_string())
899        })?;
900    }
901
902    // List of default models
903    let default_models = [
904        "PoKeys56U.yaml",
905        "PoKeys57U.yaml",
906        "PoKeys56E.yaml",
907        "PoKeys57E.yaml",
908    ];
909
910    // Get the path to the package's model directory
911    let package_model_dir = std::env::current_exe()
912        .map_err(|e| {
913            PoKeysError::ModelDirReadError(
914                "Failed to get current executable path".to_string(),
915                e.to_string(),
916            )
917        })?
918        .parent()
919        .ok_or_else(|| {
920            PoKeysError::ModelDirReadError(
921                "Failed to get parent directory of executable".to_string(),
922                "No parent directory".to_string(),
923            )
924        })?
925        .join("models");
926
927    // If the package model directory doesn't exist, try to find it in the crate directory
928    let package_model_dir = if !package_model_dir.exists() {
929        // Try to find the models in the crate directory
930        let crate_dir = std::env::var("CARGO_MANIFEST_DIR")
931            .map(PathBuf::from)
932            .unwrap_or_else(|_| {
933                // If CARGO_MANIFEST_DIR is not set, use a relative path
934                PathBuf::from("pokeys-lib/models")
935            });
936
937        crate_dir.join("models")
938    } else {
939        package_model_dir
940    };
941
942    // Copy each default model file if it doesn't exist in the user's directory
943    for model_file in &default_models {
944        let user_file_path = dir.join(model_file);
945
946        // Skip if the file already exists
947        if user_file_path.exists() {
948            continue;
949        }
950
951        // Try to find the model file in the package
952        let package_file_path = package_model_dir.join(model_file);
953
954        if package_file_path.exists() {
955            // Copy the file
956            fs::copy(&package_file_path, &user_file_path).map_err(|e| {
957                PoKeysError::ModelLoadError(
958                    package_file_path.to_string_lossy().to_string(),
959                    e.to_string(),
960                )
961            })?;
962
963            info!(
964                "Copied default model file {} to {}",
965                model_file,
966                user_file_path.display()
967            );
968        } else {
969            // If the file doesn't exist in the package, try to find it in the source directory
970            let source_file_path = PathBuf::from(format!("pokeys-lib/models/{}", model_file));
971
972            if source_file_path.exists() {
973                // Copy the file
974                fs::copy(&source_file_path, &user_file_path).map_err(|e| {
975                    PoKeysError::ModelLoadError(
976                        source_file_path.to_string_lossy().to_string(),
977                        e.to_string(),
978                    )
979                })?;
980
981                info!(
982                    "Copied default model file {} to {}",
983                    model_file,
984                    user_file_path.display()
985                );
986            } else {
987                warn!("Default model file {} not found", model_file);
988            }
989        }
990    }
991
992    Ok(())
993}
994
// Unit tests for `DeviceModel` validation, related-capability lookup, YAML
// round-tripping, and loading a model from a file on disk.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// `validate` accepts a well-formed model and rejects an empty name, an empty
    /// pin map, and a pin with no capabilities.
    #[test]
    fn test_device_model_validation() {
        // Create a valid model
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };

        // Add some pins
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );

        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );

        // Validate the model
        assert!(model.validate().is_ok());

        // Test empty name
        let mut invalid_model = model.clone();
        invalid_model.name = "".to_string();
        assert!(invalid_model.validate().is_err());

        // Test empty pins
        let invalid_model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };
        assert!(invalid_model.validate().is_err());

        // Test pin with no capabilities
        let mut invalid_model = model.clone();
        invalid_model.pins.insert(
            3,
            PinModel {
                capabilities: vec![],
                active: true,
            },
        );
        assert!(invalid_model.validate().is_err());
    }

    /// An `Encoder_1A` pin reports its `Encoder_1B` partner (capability name and
    /// pin number) through `get_related_capabilities`.
    #[test]
    fn test_related_capabilities() {
        // Create a model with encoder pins
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };

        // Add encoder pins
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1A".to_string()],
                active: true,
            },
        );

        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "Encoder_1B".to_string()],
                active: true,
            },
        );

        // Validate the model
        assert!(model.validate().is_ok());

        // Test related capabilities
        // Each entry is (related capability name, pin number that provides it).
        let related = model.get_related_capabilities(1, "Encoder_1A");
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].0, "Encoder_1B");
        assert_eq!(related[0].1, 2);

        // Test missing related capability
        // NOTE(review): `invalid_model` is built but never asserted on, so this
        // section checks nothing. Either assert the intended `validate()` /
        // `get_related_capabilities()` result or remove the setup.
        let mut invalid_model = model.clone();
        invalid_model.pins.get_mut(&2).unwrap().capabilities = vec!["DigitalInput".to_string()];

        // The model should still validate (we only warn about related capabilities)
        // Note: We're not validating encoder pairs in this test
    }

    /// A matrix-keyboard row pin reports every column pin (name and pin number)
    /// as related through `get_related_capabilities`.
    #[test]
    fn test_matrix_keyboard_validation() {
        // Create a model with matrix keyboard pins
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };

        // Add matrix keyboard pins
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Row1".to_string(),
                ],
                active: true,
            },
        );

        model.pins.insert(
            2,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Row2".to_string(),
                ],
                active: true,
            },
        );

        model.pins.insert(
            3,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Col1".to_string(),
                ],
                active: true,
            },
        );

        model.pins.insert(
            4,
            PinModel {
                capabilities: vec![
                    "DigitalInput".to_string(),
                    "MatrixKeyboard_Col2".to_string(),
                ],
                active: true,
            },
        );

        // Validate the model
        assert!(model.validate().is_ok());

        // Test related capabilities
        // Row 1 relates to both columns; the order of the results is not asserted.
        let related = model.get_related_capabilities(1, "MatrixKeyboard_Row1");
        assert_eq!(related.len(), 2);
        assert!(
            related
                .iter()
                .any(|(cap, pin)| cap == "MatrixKeyboard_Col1" && *pin == 3)
        );
        assert!(
            related
                .iter()
                .any(|(cap, pin)| cap == "MatrixKeyboard_Col2" && *pin == 4)
        );

        // Test missing columns
        // NOTE(review): `invalid_model` is built but never asserted on, so this
        // section checks nothing. Either assert the intended `validate()` /
        // `get_related_capabilities()` result or remove the setup.
        let mut invalid_model = model.clone();
        invalid_model.pins.remove(&3);
        invalid_model.pins.remove(&4);

        // The model should still validate (we only warn about related capabilities)
        // Note: We're not validating matrix keyboard rows/columns in this test
    }

    /// A model survives a `serde_yaml` serialize/deserialize round trip with its
    /// name and per-pin capabilities and active flags intact.
    #[test]
    fn test_yaml_serialization() {
        // Create a model
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };

        // Add some pins
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );

        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );

        // Serialize to YAML
        let yaml = serde_yaml::to_string(&model).unwrap();

        // Deserialize from YAML
        let deserialized: DeviceModel = serde_yaml::from_str(&yaml).unwrap();

        // Check that the models are equal
        // NOTE(review): DeviceModel derives PartialEq, so
        // `assert_eq!(model, deserialized)` would cover the checks below in one line.
        assert_eq!(model.name, deserialized.name);
        assert_eq!(model.pins.len(), deserialized.pins.len());

        for (pin_num, pin) in &model.pins {
            let deserialized_pin = deserialized.pins.get(pin_num).unwrap();
            assert_eq!(pin.capabilities, deserialized_pin.capabilities);
            assert_eq!(pin.active, deserialized_pin.active);
        }
    }

    /// A model written as YAML to a temporary file is read back with matching
    /// name and pin data by `DeviceModel::from_file`.
    #[test]
    fn test_model_file_loading() {
        // Create a temporary directory
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("TestDevice.yaml");

        // Create a model
        let mut model = DeviceModel {
            name: "TestDevice".to_string(),
            pins: HashMap::new(),
        };

        // Add some pins
        model.pins.insert(
            1,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "DigitalOutput".to_string()],
                active: true,
            },
        );

        model.pins.insert(
            2,
            PinModel {
                capabilities: vec!["DigitalInput".to_string(), "AnalogInput".to_string()],
                active: true,
            },
        );

        // Serialize to YAML
        let yaml = serde_yaml::to_string(&model).unwrap();

        // Write to file
        fs::write(&file_path, yaml).unwrap();

        // Load the model
        let loaded_model = DeviceModel::from_file(&file_path).unwrap();

        // Check that the models are equal
        assert_eq!(model.name, loaded_model.name);
        assert_eq!(model.pins.len(), loaded_model.pins.len());

        for (pin_num, pin) in &model.pins {
            let loaded_pin = loaded_model.pins.get(pin_num).unwrap();
            assert_eq!(pin.capabilities, loaded_pin.capabilities);
            assert_eq!(pin.active, loaded_pin.active);
        }
    }
}