use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, warn};
use super::traits::{CustomHardwareProvider, HardwareDevice, HardwareDiscovery};
/// Registry of custom hardware providers together with the plugin metadata
/// recorded for plugin files found on disk.
///
/// Providers live in a name-keyed map behind a shared `RwLock`; plugin metadata,
/// the plugin search directories and the optional hardware discovery backend
/// are plain fields owned by the registry.
#[derive(Debug)]
pub struct CustomProviderRegistry {
/// Registered providers, keyed by the name given to `register_provider`.
providers: Arc<RwLock<HashMap<String, Arc<dyn CustomHardwareProvider>>>>,
/// Metadata recorded by `load_plugin_metadata`, keyed by plugin name.
plugin_metadata: HashMap<String, PluginMetadata>,
/// Directories that `discover_plugins` scans for plugin files.
plugin_paths: Vec<PathBuf>,
/// Backend consulted by `discover_hardware`; `None` until one is set.
hardware_discovery: Option<Box<dyn HardwareDiscovery>>,
}
/// Descriptive record kept for every plugin the registry knows about.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
/// Plugin name; the registry uses it as the key of its metadata map.
pub name: String,
/// Plugin version string.
pub version: String,
/// Plugin author.
pub author: String,
/// Human-readable description of the plugin.
pub description: String,
/// License identifier string.
pub license: String,
/// Minimum RONN version the plugin declares; an empty string only triggers
/// a warning in `validate_plugin`.
pub min_ronn_version: String,
/// Names of the hardware the plugin claims to support.
pub supported_hardware: Vec<String>,
/// ABI version; `validate_plugin` rejects any value other than the one it expects.
pub abi_version: u32,
/// Path of the plugin file this metadata was recorded for.
pub plugin_path: PathBuf,
/// Current load status of the plugin.
pub status: PluginStatus,
}
/// Load state of a plugin as tracked in `PluginMetadata::status`.
///
/// `get_statistics` counts plugins per variant under the keys `"loaded"`,
/// `"error"`, `"disabled"` and `"incompatible"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum PluginStatus {
/// Status assigned by `load_plugin_metadata` to every plugin it records.
Loaded,
/// Loading failed; the payload is presumably the failure message
/// (never constructed in this file — TODO confirm).
LoadError(String),
/// Plugin is disabled (never constructed in this file).
Disabled,
/// Plugin is incompatible (never constructed in this file).
Incompatible,
}
/// Interface a provider plugin exposes.
///
/// NOTE(review): nothing in this file calls these methods; the per-method
/// descriptions are derived from the signatures only and should be confirmed
/// against real implementors.
pub trait ProviderPlugin: Send + Sync {
/// Returns the plugin's metadata.
fn get_metadata(&self) -> PluginMetadata;
/// Creates a provider from `config`; the expected format of `config` is not
/// specified here — TODO confirm.
fn create_provider(&self, config: &str) -> Result<Box<dyn CustomHardwareProvider>>;
/// Reports whether the hardware this plugin targets is available.
fn is_hardware_available(&self) -> bool;
/// Returns the plugin's ABI version; the default implementation returns `1`.
fn get_abi_version(&self) -> u32 {
1 }
}
impl CustomProviderRegistry {
/// Creates an empty registry.
///
/// The registry starts with no providers, no plugin metadata and no hardware
/// discovery backend. Its plugin search list is pre-populated with
/// `./plugins`, `/usr/local/lib/ronn/plugins` and `/opt/ronn/plugins`.
pub fn new() -> Self {
    const DEFAULT_PLUGIN_DIRS: [&str; 3] = [
        "./plugins",
        "/usr/local/lib/ronn/plugins",
        "/opt/ronn/plugins",
    ];

    let plugin_paths: Vec<PathBuf> = DEFAULT_PLUGIN_DIRS
        .iter()
        .copied()
        .map(PathBuf::from)
        .collect();
    let providers = Arc::new(RwLock::new(HashMap::new()));

    Self {
        providers,
        plugin_metadata: HashMap::new(),
        plugin_paths,
        hardware_discovery: None,
    }
}
/// Appends `path` to the plugin search list unless an equal path is already present.
pub fn add_plugin_path<P: AsRef<Path>>(&mut self, path: P) {
    let candidate = path.as_ref();
    let already_listed = self
        .plugin_paths
        .iter()
        .any(|known| known.as_path() == candidate);
    if already_listed {
        return;
    }
    self.plugin_paths.push(candidate.to_path_buf());
}
/// Installs `discovery` as the hardware discovery backend, replacing any
/// backend that was set before.
pub fn set_hardware_discovery(&mut self, discovery: Box<dyn HardwareDiscovery>) {
    // The previously installed backend (if any) is simply dropped.
    drop(self.hardware_discovery.replace(discovery));
}
/// Scans every existing plugin search directory and returns the metadata of
/// all plugins found.
///
/// A directory that fails to scan is reported with a warning and skipped, so
/// this function currently always returns `Ok`.
pub fn discover_plugins(&mut self) -> Result<Vec<PluginMetadata>> {
    // Scanning needs `&mut self`, so iterate over a snapshot of the path list.
    let search_dirs = self.plugin_paths.clone();
    let mut found = Vec::new();

    for dir in search_dirs.iter().filter(|dir| dir.exists()) {
        match self.scan_plugin_directory(dir) {
            Ok(plugins) => found.extend(plugins),
            Err(e) => warn!("Failed to scan plugin directory {:?}: {}", dir, e),
        }
    }

    info!("Discovered {} plugins", found.len());
    Ok(found)
}
fn scan_plugin_directory(&mut self, dir: &Path) -> Result<Vec<PluginMetadata>> {
let mut plugins = Vec::new();
if !dir.is_dir() {
return Ok(plugins);
}
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if let Some(extension) = path.extension() {
let is_plugin = match extension.to_str() {
Some("so") => true, Some("dylib") => true, Some("dll") => true, _ => false,
};
if is_plugin {
match self.load_plugin_metadata(&path) {
Ok(metadata) => {
plugins.push(metadata);
}
Err(e) => {
error!("Failed to load plugin metadata from {:?}: {}", path, e);
}
}
}
}
}
Ok(plugins)
}
/// Builds a metadata record for `plugin_path`, stores it in the registry and
/// returns a copy.
///
/// The plugin name is the file stem of `plugin_path` (or `"unknown"` when the
/// stem is missing or not valid UTF-8). The plugin file itself is not opened:
/// every other field is a fixed placeholder value, and the status is always
/// `PluginStatus::Loaded`. An existing record with the same name is replaced.
fn load_plugin_metadata(&mut self, plugin_path: &Path) -> Result<PluginMetadata> {
    let stem = plugin_path.file_stem().and_then(|s| s.to_str());
    let plugin_name = String::from(stem.unwrap_or("unknown"));
    let description = format!("Custom hardware provider plugin: {}", plugin_name);

    let metadata = PluginMetadata {
        name: plugin_name.clone(),
        version: String::from("1.0.0"),
        author: String::from("Unknown"),
        description,
        license: String::from("MIT"),
        min_ronn_version: String::from("0.1.0"),
        supported_hardware: vec![String::from("Custom")],
        abi_version: 1,
        plugin_path: plugin_path.to_path_buf(),
        status: PluginStatus::Loaded,
    };

    let returned = metadata.clone();
    self.plugin_metadata.insert(plugin_name, metadata);
    Ok(returned)
}
/// Registers `provider` under `name`.
///
/// # Errors
///
/// Returns an error when the provider map lock is poisoned or when a provider
/// is already registered under `name`; the existing registration is left
/// untouched in that case.
pub fn register_provider(
    &mut self,
    name: String,
    provider: Arc<dyn CustomHardwareProvider>,
) -> Result<()> {
    use std::collections::hash_map::Entry;

    let mut providers = self
        .providers
        .write()
        .map_err(|_| anyhow!("Lock poisoned"))?;

    // A single entry lookup replaces the former `contains_key` + `insert`
    // pair and removes the need to clone `name` for the log message.
    let slot = match providers.entry(name) {
        Entry::Occupied(existing) => {
            return Err(anyhow!("Provider {} already registered", existing.key()));
        }
        Entry::Vacant(slot) => slot,
    };
    info!("Registered custom hardware provider: {}", slot.key());
    slot.insert(provider);
    Ok(())
}
/// Returns a shared handle to the provider registered under `name`.
///
/// Yields `None` when no such provider exists or when the provider map lock
/// is poisoned.
pub fn get_provider(&self, name: &str) -> Option<Arc<dyn CustomHardwareProvider>> {
    self.providers
        .read()
        .ok()
        .and_then(|guard| guard.get(name).cloned())
}
/// Returns the names of all registered providers, in the map's iteration order.
///
/// A poisoned provider map lock yields an empty list.
pub fn list_providers(&self) -> Vec<String> {
    match self.providers.read() {
        Ok(guard) => guard.keys().cloned().collect(),
        Err(_) => Vec::new(),
    }
}
/// Looks up the metadata recorded for the plugin called `name`.
pub fn get_plugin_metadata(&self, name: &str) -> Option<&PluginMetadata> {
    let known_plugins = &self.plugin_metadata;
    known_plugins.get(name)
}
pub fn get_all_plugin_metadata(&self) -> &HashMap<String, PluginMetadata> {
&self.plugin_metadata
}
/// Removes the provider registered under `name`, along with any plugin
/// metadata stored under the same name.
///
/// # Errors
///
/// Returns an error when the provider map lock is poisoned or when no
/// provider is registered under `name`.
pub fn unregister_provider(&mut self, name: &str) -> Result<()> {
    let mut guard = self
        .providers
        .write()
        .map_err(|_| anyhow!("Lock poisoned"))?;

    if guard.remove(name).is_none() {
        return Err(anyhow!("Provider {} not found", name));
    }

    self.plugin_metadata.remove(name);
    info!("Unregistered provider: {}", name);
    Ok(())
}
/// Asks the configured discovery backend for the available devices.
///
/// Without a backend a warning is logged and an empty list is returned.
pub fn discover_hardware(&self) -> Result<Vec<HardwareDevice>> {
    match self.hardware_discovery.as_ref() {
        Some(discovery) => discovery.discover_devices(),
        None => {
            warn!("No hardware discovery service configured");
            Ok(Vec::new())
        }
    }
}
/// Takes a snapshot of the registry's current counters.
///
/// The provider count is reported as `0` when the provider map lock is
/// poisoned. Plugins are tallied per status under the keys `"loaded"`,
/// `"error"`, `"disabled"` and `"incompatible"`; statuses with no plugins
/// have no entry.
pub fn get_statistics(&self) -> RegistryStatistics {
    let registered_providers = match self.providers.read() {
        Ok(guard) => guard.len(),
        Err(_) => 0,
    };

    let mut plugin_status_counts: HashMap<String, usize> = HashMap::new();
    for metadata in self.plugin_metadata.values() {
        let label = match &metadata.status {
            PluginStatus::Loaded => "loaded",
            PluginStatus::LoadError(_) => "error",
            PluginStatus::Disabled => "disabled",
            PluginStatus::Incompatible => "incompatible",
        };
        *plugin_status_counts.entry(label.to_string()).or_insert(0) += 1;
    }

    RegistryStatistics {
        registered_providers,
        discovered_plugins: self.plugin_metadata.len(),
        plugin_paths: self.plugin_paths.clone(),
        plugin_status_counts,
        has_hardware_discovery: self.hardware_discovery.is_some(),
    }
}
/// Records the metadata of the single plugin file at `plugin_path`.
///
/// # Errors
///
/// Returns an error when `plugin_path` does not exist, when it exists but is
/// not a regular file (for example a directory), or when recording its
/// metadata fails.
pub fn load_plugin_from_file<P: AsRef<Path>>(&mut self, plugin_path: P) -> Result<()> {
    let plugin_path = plugin_path.as_ref();
    if !plugin_path.exists() {
        return Err(anyhow!("Plugin file does not exist: {:?}", plugin_path));
    }
    // `exists()` is also true for directories, which can never be a plugin
    // file; reject them instead of recording them as loaded plugins.
    if !plugin_path.is_file() {
        return Err(anyhow!(
            "Plugin path is not a regular file: {:?}",
            plugin_path
        ));
    }

    let metadata = self.load_plugin_metadata(plugin_path)?;
    info!("Loaded plugin: {} v{}", metadata.name, metadata.version);
    Ok(())
}
fn validate_plugin(&self, metadata: &PluginMetadata) -> Result<()> {
const CURRENT_ABI_VERSION: u32 = 1;
if metadata.abi_version != CURRENT_ABI_VERSION {
return Err(anyhow!(
"Plugin {} has incompatible ABI version: {} (expected {})",
metadata.name,
metadata.abi_version,
CURRENT_ABI_VERSION
));
}
if metadata.min_ronn_version.is_empty() {
warn!(
"Plugin {} does not specify minimum RONN version",
metadata.name
);
}
Ok(())
}
pub fn shutdown(&mut self) -> Result<()> {
let providers = self
.providers
.read()
.map_err(|_| anyhow!("Lock poisoned"))?;
for (name, _provider) in providers.iter() {
debug!("Shutting down provider: {}", name);
}
info!("Custom provider registry shutdown complete");
Ok(())
}
}
impl Default for CustomProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of registry counters produced by `CustomProviderRegistry::get_statistics`.
#[derive(Debug, Clone)]
pub struct RegistryStatistics {
/// Number of providers registered at the time of the snapshot.
pub registered_providers: usize,
/// Number of plugin metadata records held by the registry.
pub discovered_plugins: usize,
/// Copy of the registry's plugin search directories.
pub plugin_paths: Vec<PathBuf>,
/// Plugin count per status, keyed by `"loaded"`, `"error"`, `"disabled"` or
/// `"incompatible"`.
pub plugin_status_counts: HashMap<String, usize>,
/// Whether a hardware discovery backend was configured.
pub has_hardware_discovery: bool,
}
/// Hardware discovery backend that never reports any device.
#[derive(Debug)]
pub struct DefaultHardwareDiscovery;
/// Placeholder discovery: every query answers "no devices".
impl HardwareDiscovery for DefaultHardwareDiscovery {
    /// Always succeeds with an empty device list.
    fn discover_devices(&self) -> Result<Vec<HardwareDevice>> {
        let no_devices = Vec::new();
        Ok(no_devices)
    }

    /// No device is ever available, whatever `_device_id` is.
    fn is_device_available(&self, _device_id: &str) -> bool {
        false
    }

    /// No device information exists, whatever `_device_id` is.
    fn get_device_info(&self, _device_id: &str) -> Option<HardwareDevice> {
        None
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
/// Minimal provider used by the tests in this module.
#[derive(Debug)]
struct MockProvider {
/// Value returned by `provider_name`.
name: String,
}
/// Test double: reports fixed capability and statistics values, supports no
/// compilation and has no device memory.
impl CustomHardwareProvider for MockProvider {
    /// Returns the name the mock was constructed with.
    fn provider_name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns a fixed capability description for a fictional "Mock" device.
    fn get_hardware_capability(&self) -> super::super::traits::HardwareCapability {
        let power_profile = super::super::traits::PowerProfile {
            idle_power_watts: 1.0,
            peak_power_watts: 10.0,
            tdp_watts: 5.0,
            efficiency_tops_per_watt: 2.0,
        };

        super::super::traits::HardwareCapability {
            vendor: String::from("Mock"),
            model: String::from("TestDevice"),
            architecture_version: String::from("1.0"),
            supported_data_types: vec![ronn_core::DataType::F32],
            max_memory_bytes: 1024 * 1024 * 1024,
            peak_tops: 10.0,
            memory_bandwidth_gbps: 100.0,
            supported_operations: vec![String::from("Add")],
            features: HashMap::new(),
            power_profile,
        }
    }

    /// The mock hardware is always available.
    fn is_hardware_available(&self) -> bool {
        true
    }

    /// Initialization is a no-op that always succeeds.
    fn initialize(&mut self) -> Result<()> {
        Ok(())
    }

    /// Compilation is unsupported and always fails.
    fn compile_subgraph(
        &self,
        _subgraph: &ronn_core::SubGraph,
    ) -> Result<Box<dyn super::super::traits::CustomKernel>> {
        let unsupported = anyhow!("Not implemented");
        Err(unsupported)
    }

    /// Panics: the mock has no device memory.
    fn get_device_memory(&self) -> &dyn super::super::traits::DeviceMemory {
        panic!("Not implemented")
    }

    /// Returns all-zero statistics.
    fn get_performance_stats(&self) -> super::super::traits::ProviderStats {
        super::super::traits::ProviderStats {
            total_operations: 0,
            average_execution_time_us: 0.0,
            memory_usage_bytes: 0,
            peak_memory_bytes: 0,
            hardware_utilization: 0.0,
            current_power_watts: 0.0,
            total_energy_joules: 0.0,
        }
    }

    /// Shutdown is a no-op that always succeeds.
    fn shutdown(&mut self) -> Result<()> {
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
/// A freshly created registry has no providers.
#[test]
fn test_registry_creation() {
    let fresh = CustomProviderRegistry::new();
    let names = fresh.list_providers();
    assert!(names.is_empty());
}
/// A registered provider is listed by name and can be fetched again.
#[test]
fn test_provider_registration() -> Result<()> {
    let mut registry = CustomProviderRegistry::new();
    let mock = Arc::new(MockProvider {
        name: String::from("test_provider"),
    });

    registry.register_provider(String::from("test_provider"), mock.clone())?;

    let names = registry.list_providers();
    assert_eq!(names.len(), 1);
    assert!(names.iter().any(|n| n == "test_provider"));

    assert!(registry.get_provider("test_provider").is_some());
    Ok(())
}
/// Registering a second provider under an already used name is rejected.
#[test]
fn test_duplicate_registration() {
    let mut registry = CustomProviderRegistry::new();
    let make_mock = || {
        Arc::new(MockProvider {
            name: "test_provider".to_string(),
        })
    };

    registry
        .register_provider("test_provider".to_string(), make_mock())
        .unwrap();

    let second_attempt = registry.register_provider("test_provider".to_string(), make_mock());
    assert!(second_attempt.is_err());
}
/// Unregistering removes the provider from the registry's listing.
#[test]
fn test_provider_unregistration() -> Result<()> {
    let mut registry = CustomProviderRegistry::new();
    let mock = Arc::new(MockProvider {
        name: String::from("test_provider"),
    });

    registry.register_provider(String::from("test_provider"), mock)?;
    let count_after_register = registry.list_providers().len();
    assert_eq!(count_after_register, 1);

    registry.unregister_provider("test_provider")?;
    let count_after_unregister = registry.list_providers().len();
    assert_eq!(count_after_unregister, 0);
    Ok(())
}
/// Statistics reflect one registered provider, the default plugin paths and
/// the absence of a discovery backend.
#[test]
fn test_registry_statistics() -> Result<()> {
    let mut registry = CustomProviderRegistry::new();
    let mock = Arc::new(MockProvider {
        name: String::from("test_provider"),
    });
    registry.register_provider(String::from("test_provider"), mock)?;

    let RegistryStatistics {
        registered_providers,
        plugin_paths,
        has_hardware_discovery,
        ..
    } = registry.get_statistics();

    assert_eq!(registered_providers, 1);
    assert!(!plugin_paths.is_empty());
    assert!(!has_hardware_discovery);
    Ok(())
}
/// The default discovery backend reports no devices at all.
#[test]
fn test_hardware_discovery() {
    let backend = DefaultHardwareDiscovery;

    let found = backend.discover_devices().unwrap();
    assert_eq!(found.len(), 0);

    let available = backend.is_device_available("test_device");
    assert!(!available);

    let info = backend.get_device_info("test_device");
    assert!(info.is_none());
}
/// A hand-built metadata record keeps the values it was constructed with.
#[test]
fn test_plugin_metadata() {
    let record = PluginMetadata {
        name: String::from("test_plugin"),
        version: String::from("1.0.0"),
        author: String::from("Test Author"),
        description: String::from("Test plugin"),
        license: String::from("MIT"),
        min_ronn_version: String::from("0.1.0"),
        supported_hardware: vec![String::from("TestHW")],
        abi_version: 1,
        plugin_path: PathBuf::from("/test/plugin.so"),
        status: PluginStatus::Loaded,
    };

    assert_eq!(record.name, "test_plugin");
    assert_eq!(record.abi_version, 1);
    assert_eq!(record.status, PluginStatus::Loaded);
}
}