1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value as JsonValue;
7
8use mabi_core::Protocol;
9
10use crate::service::{ManagedService, RuntimeResult};
11use crate::session::RuntimeExtensions;
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
15pub struct ProtocolDescriptor {
16 pub key: &'static str,
18 pub display_name: &'static str,
20 pub protocol: Protocol,
22 pub default_port: u16,
24 pub description: &'static str,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30pub struct ProtocolLaunchSpec {
31 pub protocol: String,
33 #[serde(default)]
35 pub name: Option<String>,
36 #[serde(default)]
38 pub config: JsonValue,
39}
40
41impl ProtocolLaunchSpec {
42 pub fn key(&self) -> &str {
44 &self.protocol
45 }
46
47 pub fn service_name(&self, descriptor: &ProtocolDescriptor) -> String {
49 self.name
50 .clone()
51 .unwrap_or_else(|| descriptor.key.to_string())
52 }
53}
54
55#[derive(Debug, Clone, Serialize)]
57pub struct ProtocolCatalogEntry {
58 pub descriptor: ProtocolDescriptor,
59 pub features: Vec<&'static str>,
60}
61
62#[async_trait]
64pub trait ProtocolDriver: Send + Sync {
65 fn descriptor(&self) -> ProtocolDescriptor;
67
68 fn features(&self) -> &'static [&'static str] {
70 &[]
71 }
72
73 fn schema(&self) -> Option<JsonValue> {
75 None
76 }
77
78 async fn build(
80 &self,
81 spec: ProtocolLaunchSpec,
82 extensions: RuntimeExtensions,
83 ) -> RuntimeResult<Arc<dyn ManagedService>>;
84}
85
86#[derive(Default, Clone)]
88pub struct ProtocolDriverRegistry {
89 drivers: BTreeMap<String, Arc<dyn ProtocolDriver>>,
90}
91
92impl ProtocolDriverRegistry {
93 pub fn new() -> Self {
95 Self {
96 drivers: BTreeMap::new(),
97 }
98 }
99
100 pub fn register(&mut self, driver: impl ProtocolDriver + 'static) {
102 let descriptor = driver.descriptor();
103 self.drivers
104 .insert(descriptor.key.to_string(), Arc::new(driver));
105 }
106
107 pub fn extend(&mut self, other: &Self) {
109 for (key, driver) in &other.drivers {
110 self.drivers.insert(key.clone(), Arc::clone(driver));
111 }
112 }
113
114 pub fn get(&self, key: &str) -> Option<Arc<dyn ProtocolDriver>> {
116 self.drivers.get(key).cloned()
117 }
118
119 pub fn contains(&self, key: &str) -> bool {
121 self.drivers.contains_key(key)
122 }
123
124 pub fn descriptors(&self) -> Vec<ProtocolDescriptor> {
126 self.drivers
127 .values()
128 .map(|driver| driver.descriptor())
129 .collect()
130 }
131
132 pub fn catalog(&self) -> Vec<ProtocolCatalogEntry> {
134 self.drivers
135 .values()
136 .map(|driver| ProtocolCatalogEntry {
137 descriptor: driver.descriptor(),
138 features: driver.features().to_vec(),
139 })
140 .collect()
141 }
142
143 pub fn schema(&self, key: &str) -> Option<JsonValue> {
145 self.get(key).and_then(|driver| driver.schema())
146 }
147
148 pub fn len(&self) -> usize {
150 self.drivers.len()
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.drivers.is_empty()
156 }
157}
158
#[cfg(test)]
mod tests {
    //! Unit tests for the driver registry and launch spec helpers, built on a
    //! no-op driver/service pair.

    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;

    use mabi_core::Protocol;

    use crate::driver::{
        ProtocolDescriptor, ProtocolDriver, ProtocolDriverRegistry, ProtocolLaunchSpec,
    };
    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceSnapshot, ServiceStatus,
    };
    use crate::session::RuntimeExtensions;

    /// Service whose lifecycle methods all succeed without doing anything.
    struct NullService;

    #[async_trait]
    impl ManagedService for NullService {
        async fn start(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        async fn serve(&self, _context: ServiceContext) -> RuntimeResult<()> {
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            ServiceStatus::new("null")
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            Ok(ServiceSnapshot::new("null"))
        }
    }

    /// Driver registered under the key `"null"` that builds [`NullService`].
    struct NullDriver;

    #[async_trait]
    impl ProtocolDriver for NullDriver {
        fn descriptor(&self) -> ProtocolDescriptor {
            ProtocolDescriptor {
                key: "null",
                display_name: "Null",
                protocol: Protocol::ModbusTcp,
                default_port: 0,
                description: "test driver",
            }
        }

        fn features(&self) -> &'static [&'static str] {
            &["feature-a"]
        }

        async fn build(
            &self,
            _spec: ProtocolLaunchSpec,
            _extensions: RuntimeExtensions,
        ) -> RuntimeResult<Arc<dyn ManagedService>> {
            Ok(Arc::new(NullService))
        }
    }

    #[test]
    fn registry_returns_descriptors() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        assert!(registry.contains("null"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.descriptors()[0].key, "null");
    }

    #[test]
    fn registry_returns_catalog_entries() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        let catalog = registry.catalog();
        assert_eq!(catalog.len(), 1);
        assert_eq!(catalog[0].descriptor.key, "null");
        assert_eq!(catalog[0].features, vec!["feature-a"]);
    }

    #[test]
    fn registry_replaces_driver_with_same_key() {
        let mut registry = ProtocolDriverRegistry::new();
        registry.register(NullDriver);
        registry.register(NullDriver);
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn registry_reports_unknown_keys_as_absent() {
        let mut registry = ProtocolDriverRegistry::new();
        assert!(registry.is_empty());
        registry.register(NullDriver);
        assert!(!registry.contains("missing"));
        assert!(registry.get("missing").is_none());
        assert!(registry.schema("missing").is_none());
        // `NullDriver` keeps the default `schema` implementation, which is `None`.
        assert!(registry.schema("null").is_none());
    }

    // Plain `#[test]`: nothing here awaits, so no async runtime is needed.
    #[test]
    fn launch_spec_keeps_service_name_override() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: Some("custom".into()),
            config: json!({"ok": true}),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.service_name(&descriptor), "custom");
    }

    #[test]
    fn launch_spec_falls_back_to_descriptor_key() {
        let spec = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: None,
            config: json!(null),
        };
        let descriptor = NullDriver.descriptor();
        assert_eq!(spec.key(), "null");
        assert_eq!(spec.service_name(&descriptor), "null");
    }
}