Skip to main content

mabi_runtime/
driver.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value as JsonValue;
7
8use mabi_core::Protocol;
9
10use crate::service::{ManagedService, RuntimeResult};
11use crate::session::RuntimeExtensions;
12
13/// Descriptor for a protocol driver registered with the runtime.
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
15pub struct ProtocolDescriptor {
16    /// Stable identifier used by the CLI and docs.
17    pub key: &'static str,
18    /// Human-readable protocol name.
19    pub display_name: &'static str,
20    /// Shared protocol enum.
21    pub protocol: Protocol,
22    /// Default listening port for the protocol.
23    pub default_port: u16,
24    /// Short description shown in help and inspection output.
25    pub description: &'static str,
26}
27
28/// Generic launch request used by runtime sessions and the CLI.
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30pub struct ProtocolLaunchSpec {
31    /// Stable protocol key matching a registered descriptor.
32    pub protocol: String,
33    /// Optional service name override.
34    #[serde(default)]
35    pub name: Option<String>,
36    /// Protocol-specific configuration payload.
37    #[serde(default)]
38    pub config: JsonValue,
39}
40
41impl ProtocolLaunchSpec {
42    /// Returns the stable protocol key.
43    pub fn key(&self) -> &str {
44        &self.protocol
45    }
46
47    /// Returns the runtime service name using the descriptor as fallback.
48    pub fn service_name(&self, descriptor: &ProtocolDescriptor) -> String {
49        self.name
50            .clone()
51            .unwrap_or_else(|| descriptor.key.to_string())
52    }
53}
54
55/// Registry-facing catalog entry.
56#[derive(Debug, Clone, Serialize)]
57pub struct ProtocolCatalogEntry {
58    pub descriptor: ProtocolDescriptor,
59    pub features: Vec<&'static str>,
60}
61
62/// Compile-time extensibility point for protocol service creation.
63#[async_trait]
64pub trait ProtocolDriver: Send + Sync {
65    /// Returns the driver descriptor.
66    fn descriptor(&self) -> ProtocolDescriptor;
67
68    /// Returns a short feature list for CLI inspection surfaces.
69    fn features(&self) -> &'static [&'static str] {
70        &[]
71    }
72
73    /// Returns an optional schema summary for CLI inspection surfaces.
74    fn schema(&self) -> Option<JsonValue> {
75        None
76    }
77
78    /// Builds a managed service from the generic launch request.
79    async fn build(
80        &self,
81        spec: ProtocolLaunchSpec,
82        extensions: RuntimeExtensions,
83    ) -> RuntimeResult<Arc<dyn ManagedService>>;
84}
85
86/// Shared registry for protocol drivers.
87#[derive(Default, Clone)]
88pub struct ProtocolDriverRegistry {
89    drivers: BTreeMap<String, Arc<dyn ProtocolDriver>>,
90}
91
92impl ProtocolDriverRegistry {
93    /// Creates an empty registry.
94    pub fn new() -> Self {
95        Self {
96            drivers: BTreeMap::new(),
97        }
98    }
99
100    /// Registers a driver by its descriptor key.
101    pub fn register(&mut self, driver: impl ProtocolDriver + 'static) {
102        let descriptor = driver.descriptor();
103        self.drivers
104            .insert(descriptor.key.to_string(), Arc::new(driver));
105    }
106
107    /// Extends the registry with all entries from another registry.
108    pub fn extend(&mut self, other: &Self) {
109        for (key, driver) in &other.drivers {
110            self.drivers.insert(key.clone(), Arc::clone(driver));
111        }
112    }
113
114    /// Returns a driver by key.
115    pub fn get(&self, key: &str) -> Option<Arc<dyn ProtocolDriver>> {
116        self.drivers.get(key).cloned()
117    }
118
119    /// Returns whether the registry contains a driver.
120    pub fn contains(&self, key: &str) -> bool {
121        self.drivers.contains_key(key)
122    }
123
124    /// Returns the registered descriptors in stable order.
125    pub fn descriptors(&self) -> Vec<ProtocolDescriptor> {
126        self.drivers
127            .values()
128            .map(|driver| driver.descriptor())
129            .collect()
130    }
131
132    /// Returns catalog entries with driver features in stable order.
133    pub fn catalog(&self) -> Vec<ProtocolCatalogEntry> {
134        self.drivers
135            .values()
136            .map(|driver| ProtocolCatalogEntry {
137                descriptor: driver.descriptor(),
138                features: driver.features().to_vec(),
139            })
140            .collect()
141    }
142
143    /// Returns a schema summary for the provided protocol key, if available.
144    pub fn schema(&self, key: &str) -> Option<JsonValue> {
145        self.get(key).and_then(|driver| driver.schema())
146    }
147
148    /// Returns the number of registered drivers.
149    pub fn len(&self) -> usize {
150        self.drivers.len()
151    }
152
153    /// Returns true when the registry is empty.
154    pub fn is_empty(&self) -> bool {
155        self.drivers.is_empty()
156    }
157}
158
#[cfg(test)]
mod tests {
    //! Unit tests for the launch spec helpers and the driver registry, using
    //! a no-op driver/service pair as fixtures.

    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;

    use mabi_core::Protocol;

    use crate::driver::{
        ProtocolDescriptor, ProtocolDriver, ProtocolDriverRegistry, ProtocolLaunchSpec,
    };
    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceSnapshot, ServiceStatus,
    };
    use crate::session::RuntimeExtensions;

    /// Service fixture whose lifecycle methods all succeed immediately.
    struct NullService;

    #[async_trait]
    impl ManagedService for NullService {
        async fn start(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn serve(&self, _context: ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            ServiceStatus::new("null")
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            Ok(ServiceSnapshot::new("null"))
        }
    }

    /// Driver fixture registered under the key `"null"` with one feature and
    /// no schema override.
    struct NullDriver;

    #[async_trait]
    impl ProtocolDriver for NullDriver {
        fn descriptor(&self) -> ProtocolDescriptor {
            ProtocolDescriptor {
                key: "null",
                display_name: "Null",
                protocol: Protocol::ModbusTcp,
                default_port: 0,
                description: "test driver",
            }
        }

        fn features(&self) -> &'static [&'static str] {
            &["feature-a"]
        }

        async fn build(
            &self,
            _spec: ProtocolLaunchSpec,
            _extensions: RuntimeExtensions,
        ) -> RuntimeResult<Arc<dyn ManagedService>> {
            Ok(Arc::new(NullService))
        }
    }

    #[test]
    fn registry_returns_descriptors() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        assert!(registry.contains("null"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.descriptors()[0].key, "null");
    }

    #[test]
    fn registry_returns_catalog_entries() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        let catalog = registry.catalog();
        assert_eq!(catalog.len(), 1);
        assert_eq!(catalog[0].descriptor.key, "null");
        assert_eq!(catalog[0].features, vec!["feature-a"]);
    }

    #[test]
    fn registry_lookup_misses_unknown_keys() {
        let registry = ProtocolDriverRegistry::new();
        assert!(registry.is_empty());
        assert!(!registry.contains("missing"));
        assert!(registry.get("missing").is_none());
        assert!(registry.schema("missing").is_none());
    }

    // `service_name` is synchronous, so these tests run as plain `#[test]`
    // functions and do not need an async runtime.
    #[test]
    fn launch_spec_keeps_service_name_override() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: Some("custom".into()),
            config: json!({"ok": true}),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.service_name(&descriptor), "custom");
    }

    #[test]
    fn launch_spec_falls_back_to_descriptor_key() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: None,
            config: json!(null),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.key(), "null");
        assert_eq!(spec.service_name(&descriptor), "null");
    }
}