Skip to main content

mabi_runtime/
session.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5use tokio::time::Duration;
6
7use mabi_core::Protocol;
8
9use crate::device::{DeviceRegistry, DynDevicePort};
10use crate::driver::{ProtocolDriverRegistry, ProtocolLaunchSpec};
11use crate::service::{RuntimeError, RuntimeResult, ServiceHandle, ServiceSnapshot};
12
13/// Decorates controller-visible device ports at runtime.
14pub trait DevicePortLayer: Send + Sync {
15    fn decorate(&self, protocol: Option<Protocol>, port: DynDevicePort) -> DynDevicePort;
16}
17
18/// Shared runtime extensions consumed by sessions and protocol drivers.
19#[derive(Clone, Default)]
20pub struct RuntimeExtensions {
21    device_layers: Vec<Arc<dyn DevicePortLayer>>,
22    protocol_configs: std::collections::BTreeMap<String, JsonValue>,
23}
24
25impl RuntimeExtensions {
26    /// Creates an empty extension set.
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Appends a device-layer decorator.
32    pub fn add_device_layer(&mut self, layer: Arc<dyn DevicePortLayer>) {
33        self.device_layers.push(layer);
34    }
35
36    /// Inserts a protocol-scoped configuration payload.
37    pub fn insert_protocol_config(&mut self, protocol: impl Into<String>, value: JsonValue) {
38        self.protocol_configs.insert(protocol.into(), value);
39    }
40
41    /// Returns a protocol-scoped configuration payload.
42    pub fn protocol_config(&self, protocol: &str) -> Option<&JsonValue> {
43        self.protocol_configs.get(protocol)
44    }
45
46    /// Applies all registered device layers in insertion order.
47    pub fn decorate_device_port(
48        &self,
49        protocol: Option<Protocol>,
50        mut port: DynDevicePort,
51    ) -> DynDevicePort {
52        for layer in &self.device_layers {
53            port = layer.decorate(protocol, port);
54        }
55        port
56    }
57}
58
59/// Runtime session configuration shared by CLI controllers.
60#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
61pub struct RuntimeSessionSpec {
62    /// Protocol services that should be launched in the session.
63    #[serde(default)]
64    pub services: Vec<ProtocolLaunchSpec>,
65    /// Optional readiness timeout override in milliseconds.
66    #[serde(default)]
67    pub readiness_timeout: Option<u64>,
68}
69
70impl RuntimeSessionSpec {
71    /// Returns the configured readiness timeout or a fallback duration.
72    pub fn readiness_duration(&self, fallback: Duration) -> Duration {
73        self.readiness_timeout
74            .map(Duration::from_millis)
75            .unwrap_or(fallback)
76    }
77}
78
79/// Same-process runtime session containing multiple managed services.
80pub struct RuntimeSession {
81    spec: RuntimeSessionSpec,
82    devices: DeviceRegistry,
83    handles: Vec<ServiceHandle>,
84}
85
86impl RuntimeSession {
87    /// Builds a runtime session from the provided spec and registry.
88    pub async fn new(
89        spec: RuntimeSessionSpec,
90        registry: &ProtocolDriverRegistry,
91        extensions: RuntimeExtensions,
92    ) -> RuntimeResult<Self> {
93        if spec.services.is_empty() {
94            return Err(RuntimeError::service(
95                "runtime session requires at least one service",
96            ));
97        }
98
99        let devices = DeviceRegistry::new();
100        let mut handles = Vec::with_capacity(spec.services.len());
101
102        for launch in &spec.services {
103            let driver = registry.get(launch.key()).ok_or_else(|| {
104                RuntimeError::service(format!("unknown protocol driver: {}", launch.key()))
105            })?;
106            let descriptor = driver.descriptor();
107            let service = driver.build(launch.clone(), extensions.clone()).await?;
108            let service_protocol = service.status().protocol.or(Some(descriptor.protocol));
109
110            let service_devices = DeviceRegistry::new();
111            service.register_devices(&service_devices)?;
112            for (device_id, port) in service_devices.entries() {
113                devices.register(
114                    device_id,
115                    extensions.decorate_device_port(service_protocol, port),
116                );
117            }
118
119            handles.push(ServiceHandle::named(
120                launch.service_name(&descriptor),
121                service_protocol,
122                service,
123            ));
124        }
125
126        Ok(Self {
127            spec,
128            devices,
129            handles,
130        })
131    }
132
133    /// Starts all managed services and waits for readiness.
134    pub async fn start(&self, fallback_readiness_timeout: Duration) -> RuntimeResult<()> {
135        let readiness_timeout = self.spec.readiness_duration(fallback_readiness_timeout);
136        let mut started = Vec::new();
137
138        for handle in &self.handles {
139            if let Err(error) = handle.spawn().await {
140                self.stop_started(&started).await;
141                return Err(error);
142            }
143
144            match handle.readiness(readiness_timeout).await {
145                Ok(status) if status.ready && !status.is_terminal() => started.push(handle),
146                Ok(status) => {
147                    self.stop_started(&started).await;
148                    return Err(RuntimeError::service(format!(
149                        "service failed to become ready: {} ({:?})",
150                        status.name, status.state
151                    )));
152                }
153                Err(error) => {
154                    self.stop_started(&started).await;
155                    return Err(error);
156                }
157            }
158        }
159
160        Ok(())
161    }
162
163    async fn stop_started(&self, started: &[&ServiceHandle]) {
164        for handle in started.iter().rev() {
165            let _ = handle.stop().await;
166        }
167    }
168
169    /// Stops all managed services in reverse order.
170    pub async fn stop(&self) -> RuntimeResult<()> {
171        let mut errors = Vec::new();
172        for handle in self.handles.iter().rev() {
173            if let Err(error) = handle.stop().await {
174                errors.push(error.to_string());
175            }
176        }
177        if errors.is_empty() {
178            Ok(())
179        } else {
180            Err(RuntimeError::service(errors.join("; ")))
181        }
182    }
183
184    /// Returns the shared controller-visible device registry.
185    pub fn devices(&self) -> DeviceRegistry {
186        self.devices.clone()
187    }
188
189    /// Returns the current service snapshots.
190    pub async fn snapshots(&self) -> RuntimeResult<Vec<ServiceSnapshot>> {
191        let mut snapshots = Vec::with_capacity(self.handles.len());
192        for handle in &self.handles {
193            snapshots.push(handle.snapshot().await?);
194        }
195        Ok(snapshots)
196    }
197
198    /// Returns the managed handles.
199    pub fn handles(&self) -> &[ServiceHandle] {
200        &self.handles
201    }
202}
203
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use mabi_core::Protocol;

    use crate::device::DeviceRegistry;
    use crate::driver::{
        ProtocolDescriptor, ProtocolDriver, ProtocolDriverRegistry, ProtocolLaunchSpec,
    };
    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceSnapshot, ServiceState, ServiceStatus,
    };
    use crate::session::{RuntimeExtensions, RuntimeSession, RuntimeSessionSpec};

    /// Minimal managed service: reports ready once served and idles until its
    /// cancellation token fires.
    struct NullService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl NullService {
        fn new() -> Self {
            let initial = ServiceStatus::new("null");
            Self {
                status: parking_lot::RwLock::new(initial),
            }
        }

        /// Sets the lifecycle state and readiness flag under a single write lock.
        fn set_status(&self, state: ServiceState, ready: bool) {
            let mut guard = self.status.write();
            guard.state = state;
            guard.ready = ready;
        }
    }

    #[async_trait]
    impl ManagedService for NullService {
        async fn start(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            // Only the state moves here; readiness is announced by `serve`.
            self.status.write().state = ServiceState::Starting;
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.set_status(ServiceState::Stopped, false);
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            self.set_status(ServiceState::Running, true);
            context.cancellation_token().cancelled().await;
            self.set_status(ServiceState::Stopped, false);
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            let guard = self.status.read();
            guard.clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let current = self.status();
            let mut snapshot = ServiceSnapshot::new("null");
            snapshot.status = current;
            Ok(snapshot)
        }

        fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
            Ok(())
        }
    }

    /// Driver that always builds a fresh [`NullService`].
    struct NullDriver;

    #[async_trait]
    impl ProtocolDriver for NullDriver {
        fn descriptor(&self) -> ProtocolDescriptor {
            ProtocolDescriptor {
                key: "null",
                display_name: "Null",
                protocol: Protocol::ModbusTcp,
                default_port: 0,
                description: "null driver",
            }
        }

        async fn build(
            &self,
            _spec: ProtocolLaunchSpec,
            _extensions: RuntimeExtensions,
        ) -> RuntimeResult<Arc<dyn ManagedService>> {
            let service = NullService::new();
            Ok(Arc::new(service))
        }
    }

    #[tokio::test]
    async fn session_starts_and_stops_services() {
        let mut drivers = ProtocolDriverRegistry::new();
        drivers.register(NullDriver);

        let launch = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: Some("test-null".into()),
            config: serde_json::json!({}),
        };
        let spec = RuntimeSessionSpec {
            services: vec![launch],
            readiness_timeout: Some(1_000),
        };

        let session = RuntimeSession::new(spec, &drivers, RuntimeExtensions::default())
            .await
            .expect("session should build");
        session
            .start(Duration::from_secs(1))
            .await
            .expect("session should start");

        let snapshots = session
            .snapshots()
            .await
            .expect("snapshots should be available");
        assert_eq!(snapshots.len(), 1);

        session.stop().await.expect("session should stop");
    }
}