Skip to main content

suture_driver/
registry.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use crate::{DriverError, SutureDriver};
5
6/// Registry of file format drivers, dispatching by file extension.
7pub struct DriverRegistry {
8    extension_map: HashMap<String, String>,
9    drivers: HashMap<String, Box<dyn SutureDriver>>,
10}
11
12impl DriverRegistry {
13    pub fn new() -> Self {
14        Self {
15            extension_map: HashMap::new(),
16            drivers: HashMap::new(),
17        }
18    }
19
20    /// Register a driver for its supported extensions.
21    pub fn register(&mut self, driver: Box<dyn SutureDriver>) {
22        let name = driver.name().to_string();
23        for ext in driver.supported_extensions() {
24            self.extension_map.insert(ext.to_lowercase(), name.clone());
25        }
26        self.drivers.insert(name, driver);
27    }
28
29    /// Get a driver for the given file path (by extension).
30    pub fn get_for_path(&self, path: &Path) -> Result<&dyn SutureDriver, DriverError> {
31        let ext = path
32            .extension()
33            .and_then(|e| e.to_str())
34            .map(|e| format!(".{}", e.to_lowercase()))
35            .ok_or_else(|| DriverError::UnsupportedExtension(path.to_string_lossy().to_string()))?;
36
37        self.get(&ext)
38    }
39
40    /// Get a driver for a specific extension string (e.g., ".json").
41    pub fn get(&self, extension: &str) -> Result<&dyn SutureDriver, DriverError> {
42        let ext = extension.to_lowercase();
43
44        let driver_name = self
45            .extension_map
46            .get(&ext)
47            .ok_or_else(|| DriverError::DriverNotFound(ext.clone()))?;
48
49        self.drivers
50            .get(driver_name)
51            .map(|d| d.as_ref())
52            .ok_or(DriverError::DriverNotFound(ext))
53    }
54
55    /// List all registered drivers with their extensions.
56    pub fn list(&self) -> Vec<(&str, Vec<&str>)> {
57        let mut result: Vec<(&str, Vec<&str>)> = self
58            .drivers
59            .values()
60            .map(|d| {
61                let exts: Vec<&str> = d.supported_extensions().to_vec();
62                (d.name(), exts)
63            })
64            .collect();
65        result.sort_by_key(|(name, _)| *name);
66        result
67    }
68}
69
70impl Default for DriverRegistry {
71    fn default() -> Self {
72        Self::new()
73    }
74}