use std::collections::BTreeMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use mabi_core::Protocol;
use crate::service::{ManagedService, RuntimeResult};
use crate::session::RuntimeExtensions;
/// Static metadata describing one protocol a [`ProtocolDriver`] can launch.
///
/// Returned by [`ProtocolDriver::descriptor`]; the registry keys drivers by [`Self::key`].
///
/// NOTE(review): because the string fields are `&'static str`, the derived `Deserialize`
/// can only borrow from `'static` input (e.g. string literals). Deserializing from a
/// runtime buffer will not type-check; confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProtocolDescriptor {
    /// Unique key under which `ProtocolDriverRegistry::register` stores the driver.
    /// It is also the fallback service name used by `ProtocolLaunchSpec::service_name`.
    pub key: &'static str,
    /// Display name of the protocol.
    pub display_name: &'static str,
    /// The `mabi_core` protocol this descriptor corresponds to.
    pub protocol: Protocol,
    /// Default port for the protocol.
    /// Not read anywhere in this file; presumably used when config gives no port. Confirm.
    pub default_port: u16,
    /// Free-form description of the protocol/driver.
    pub description: &'static str,
}
/// Serializable request describing one protocol service to launch.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProtocolLaunchSpec {
    /// Key of the protocol driver this spec targets (returned by [`ProtocolLaunchSpec::key`]).
    pub protocol: String,
    /// Optional override for the service name; `None` when omitted from serialized input.
    #[serde(default)]
    pub name: Option<String>,
    /// Driver-specific configuration; JSON `null` when omitted from serialized input.
    #[serde(default)]
    pub config: JsonValue,
}

impl ProtocolLaunchSpec {
    /// Returns the driver key requested by this spec (the `protocol` field).
    pub fn key(&self) -> &str {
        self.protocol.as_str()
    }

    /// Returns the name the launched service should use.
    ///
    /// An explicit `name` wins. Otherwise the name falls back to `descriptor.key`.
    pub fn service_name(&self, descriptor: &ProtocolDescriptor) -> String {
        match self.name.as_deref() {
            Some(custom) => custom.to_owned(),
            None => descriptor.key.to_owned(),
        }
    }
}
/// One entry of [`ProtocolDriverRegistry::catalog`]: a driver's descriptor plus its features.
///
/// Only `Serialize` is derived, so entries are output-only from this crate's perspective.
#[derive(Debug, Clone, Serialize)]
pub struct ProtocolCatalogEntry {
    /// Metadata returned by [`ProtocolDriver::descriptor`].
    pub descriptor: ProtocolDescriptor,
    /// Feature labels copied from [`ProtocolDriver::features`].
    pub features: Vec<&'static str>,
}
/// A pluggable factory that turns a [`ProtocolLaunchSpec`] into a running service.
///
/// Drivers are stored as `Arc<dyn ProtocolDriver>` by [`ProtocolDriverRegistry`]. The
/// `Send + Sync` supertraits allow those shared handles to cross threads.
#[async_trait]
pub trait ProtocolDriver: Send + Sync {
    /// Returns the static metadata for this driver.
    ///
    /// [`ProtocolDriverRegistry::register`] stores the driver under the returned `key`.
    fn descriptor(&self) -> ProtocolDescriptor;
    /// Returns feature labels advertised by this driver. The default is an empty list.
    ///
    /// Surfaced through [`ProtocolDriverRegistry::catalog`].
    fn features(&self) -> &'static [&'static str] {
        &[]
    }
    /// Returns an optional JSON description of this driver's configuration. The default is `None`.
    ///
    /// NOTE(review): presumably a JSON Schema for `ProtocolLaunchSpec::config`. Confirm with
    /// implementors.
    fn schema(&self) -> Option<JsonValue> {
        None
    }
    /// Creates the [`ManagedService`] described by `spec`.
    ///
    /// `extensions` is passed through unchanged from the caller. Its contents are defined in
    /// `crate::session`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when it cannot build the service.
    /// Invalid `spec.config` is presumably one such case; confirm per driver.
    async fn build(
        &self,
        spec: ProtocolLaunchSpec,
        extensions: RuntimeExtensions,
    ) -> RuntimeResult<Arc<dyn ManagedService>>;
}
/// Collection of protocol drivers, ordered by key.
///
/// Drivers live behind [`Arc`], so cloning or merging a registry shares the drivers rather
/// than copying them. Every listing method ([`descriptors`](Self::descriptors),
/// [`catalog`](Self::catalog)) yields drivers in lexicographic key order.
#[derive(Default, Clone)]
pub struct ProtocolDriverRegistry {
    /// Drivers keyed by their descriptor's `key`.
    drivers: BTreeMap<String, Arc<dyn ProtocolDriver>>,
}

impl ProtocolDriverRegistry {
    /// Creates a registry that holds no drivers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `driver` under the key reported by its descriptor.
    ///
    /// Any driver previously registered under the same key is replaced.
    pub fn register(&mut self, driver: impl ProtocolDriver + 'static) {
        let key = driver.descriptor().key.to_owned();
        let shared: Arc<dyn ProtocolDriver> = Arc::new(driver);
        self.drivers.insert(key, shared);
    }

    /// Copies every driver of `other` into this registry, sharing the underlying `Arc`s.
    ///
    /// Drivers from `other` replace drivers here that use the same key; `other` is unchanged.
    pub fn extend(&mut self, other: &Self) {
        let incoming = other
            .drivers
            .iter()
            .map(|(key, driver)| (key.clone(), Arc::clone(driver)));
        self.drivers.extend(incoming);
    }

    /// Returns a shared handle to the driver registered under `key`, if any.
    pub fn get(&self, key: &str) -> Option<Arc<dyn ProtocolDriver>> {
        self.drivers.get(key).map(Arc::clone)
    }

    /// Reports whether a driver is registered under `key`.
    pub fn contains(&self, key: &str) -> bool {
        self.drivers.contains_key(key)
    }

    /// Returns the descriptor of every registered driver, in key order.
    pub fn descriptors(&self) -> Vec<ProtocolDescriptor> {
        let mut found = Vec::with_capacity(self.drivers.len());
        for driver in self.drivers.values() {
            found.push(driver.descriptor());
        }
        found
    }

    /// Returns a catalog entry (descriptor + features) for every driver, in key order.
    pub fn catalog(&self) -> Vec<ProtocolCatalogEntry> {
        let mut entries = Vec::with_capacity(self.drivers.len());
        for driver in self.drivers.values() {
            entries.push(ProtocolCatalogEntry {
                descriptor: driver.descriptor(),
                features: driver.features().to_vec(),
            });
        }
        entries
    }

    /// Returns the configuration schema of the driver under `key`.
    ///
    /// Yields `None` when no such driver exists or the driver exposes no schema.
    pub fn schema(&self, key: &str) -> Option<JsonValue> {
        let driver = self.drivers.get(key)?;
        driver.schema()
    }

    /// Number of registered drivers.
    pub fn len(&self) -> usize {
        self.drivers.len()
    }

    /// Reports whether no drivers are registered.
    pub fn is_empty(&self) -> bool {
        self.drivers.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;

    use mabi_core::Protocol;

    use crate::driver::{
        ProtocolDescriptor, ProtocolDriver, ProtocolDriverRegistry, ProtocolLaunchSpec,
    };
    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceSnapshot, ServiceStatus,
    };
    use crate::session::RuntimeExtensions;

    /// Service stub whose lifecycle methods all succeed immediately.
    struct NullService;

    #[async_trait]
    impl ManagedService for NullService {
        async fn start(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn serve(&self, _context: ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            ServiceStatus::new("null")
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            Ok(ServiceSnapshot::new("null"))
        }
    }

    /// Driver stub registered under the key `"null"` that advertises one feature.
    struct NullDriver;

    #[async_trait]
    impl ProtocolDriver for NullDriver {
        fn descriptor(&self) -> ProtocolDescriptor {
            ProtocolDescriptor {
                key: "null",
                display_name: "Null",
                protocol: Protocol::ModbusTcp,
                default_port: 0,
                description: "test driver",
            }
        }

        fn features(&self) -> &'static [&'static str] {
            &["feature-a"]
        }

        async fn build(
            &self,
            _spec: ProtocolLaunchSpec,
            _extensions: RuntimeExtensions,
        ) -> RuntimeResult<Arc<dyn ManagedService>> {
            Ok(Arc::new(NullService))
        }
    }

    #[test]
    fn registry_returns_descriptors() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        assert!(registry.contains("null"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.descriptors()[0].key, "null");
    }

    #[test]
    fn registry_returns_catalog_entries() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        let catalog = registry.catalog();
        assert_eq!(catalog.len(), 1);
        assert_eq!(catalog[0].descriptor.key, "null");
        assert_eq!(catalog[0].features, vec!["feature-a"]);
    }

    #[test]
    fn registry_lookup_and_schema() {
        let mut registry = ProtocolDriverRegistry::new();
        assert!(registry.is_empty());
        registry.register(NullDriver);
        assert!(registry.get("null").is_some());
        assert!(registry.get("missing").is_none());
        assert!(!registry.contains("missing"));
        // NullDriver keeps the trait's default `schema`, and unknown keys yield nothing.
        assert!(registry.schema("null").is_none());
        assert!(registry.schema("missing").is_none());
    }

    #[test]
    fn registry_register_replaces_same_key() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        registry.register(NullDriver);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn registry_extend_copies_drivers() {
        let mut source = ProtocolDriverRegistry::new();
        source.register(NullDriver);
        let mut target = ProtocolDriverRegistry::new();
        target.extend(&source);
        assert_eq!(target.len(), 1);
        assert!(target.contains("null"));
        // `extend` only borrows `other`, so the source registry is left intact.
        assert_eq!(source.len(), 1);
    }

    // Plain `#[test]`: nothing here is async, so no runtime is needed.
    #[test]
    fn launch_spec_keeps_service_name_override() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: Some("custom".into()),
            config: json!({"ok": true}),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.service_name(&descriptor), "custom");
    }

    #[test]
    fn launch_spec_falls_back_to_descriptor_key() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: None,
            config: json!(null),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.key(), "null");
        assert_eq!(spec.service_name(&descriptor), "null");
    }
}