Skip to main content

mabi_runtime/
session.rs

1use std::sync::Arc;
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value as JsonValue;
6use tokio::time::Duration;
7
8use mabi_core::Protocol;
9
10use crate::device::{DeviceRegistry, DynDevicePort};
11use crate::driver::{ProtocolDriverRegistry, ProtocolLaunchSpec};
12use crate::service::{
13    RuntimeError, RuntimeResult, ServiceHandle, ServiceSnapshot, RUNTIME_CONTRACT_VERSION,
14    SNAPSHOT_METADATA_VERSION,
15};
16
17/// Decorates controller-visible device ports at runtime.
18pub trait DevicePortLayer: Send + Sync {
19    fn decorate(&self, protocol: Option<Protocol>, port: DynDevicePort) -> DynDevicePort;
20}
21
22/// Shared runtime extensions consumed by sessions and protocol drivers.
23#[derive(Clone, Default)]
24pub struct RuntimeExtensions {
25    device_layers: Vec<Arc<dyn DevicePortLayer>>,
26    protocol_configs: std::collections::BTreeMap<String, JsonValue>,
27}
28
29impl RuntimeExtensions {
30    /// Creates an empty extension set.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Appends a device-layer decorator.
36    pub fn add_device_layer(&mut self, layer: Arc<dyn DevicePortLayer>) {
37        self.device_layers.push(layer);
38    }
39
40    /// Inserts a protocol-scoped configuration payload.
41    pub fn insert_protocol_config(&mut self, protocol: impl Into<String>, value: JsonValue) {
42        self.protocol_configs.insert(protocol.into(), value);
43    }
44
45    /// Returns a protocol-scoped configuration payload.
46    pub fn protocol_config(&self, protocol: &str) -> Option<&JsonValue> {
47        self.protocol_configs.get(protocol)
48    }
49
50    /// Applies all registered device layers in insertion order.
51    pub fn decorate_device_port(
52        &self,
53        protocol: Option<Protocol>,
54        mut port: DynDevicePort,
55    ) -> DynDevicePort {
56        for layer in &self.device_layers {
57            port = layer.decorate(protocol, port);
58        }
59        port
60    }
61}
62
63/// Runtime session configuration shared by CLI controllers.
64#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
65pub struct RuntimeSessionSpec {
66    /// Protocol services that should be launched in the session.
67    #[serde(default)]
68    pub services: Vec<ProtocolLaunchSpec>,
69    /// Optional readiness timeout override in milliseconds.
70    #[serde(default)]
71    pub readiness_timeout: Option<u64>,
72}
73
74/// Session-level runtime snapshot envelope for runner-facing consumers.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RuntimeSessionSnapshot {
77    pub contract_version: String,
78    pub snapshot_metadata_version: String,
79    pub captured_at: DateTime<Utc>,
80    pub services: Vec<ServiceSnapshot>,
81}
82
83impl RuntimeSessionSnapshot {
84    /// Builds a session snapshot envelope around normalized service snapshots.
85    pub fn new(services: Vec<ServiceSnapshot>) -> Self {
86        Self {
87            contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
88            snapshot_metadata_version: SNAPSHOT_METADATA_VERSION.to_string(),
89            captured_at: Utc::now(),
90            services,
91        }
92    }
93}
94
95impl RuntimeSessionSpec {
96    /// Returns the configured readiness timeout or a fallback duration.
97    pub fn readiness_duration(&self, fallback: Duration) -> Duration {
98        self.readiness_timeout
99            .map(Duration::from_millis)
100            .unwrap_or(fallback)
101    }
102}
103
104/// Same-process runtime session containing multiple managed services.
105pub struct RuntimeSession {
106    spec: RuntimeSessionSpec,
107    devices: DeviceRegistry,
108    handles: Vec<ServiceHandle>,
109}
110
111impl RuntimeSession {
112    /// Builds a runtime session from the provided spec and registry.
113    pub async fn new(
114        spec: RuntimeSessionSpec,
115        registry: &ProtocolDriverRegistry,
116        extensions: RuntimeExtensions,
117    ) -> RuntimeResult<Self> {
118        if spec.services.is_empty() {
119            return Err(RuntimeError::config(
120                "runtime session requires at least one service",
121            ));
122        }
123
124        let devices = DeviceRegistry::new();
125        let mut handles = Vec::with_capacity(spec.services.len());
126
127        for launch in &spec.services {
128            let driver = registry.get(launch.key()).ok_or_else(|| {
129                RuntimeError::config(format!("unknown protocol driver: {}", launch.key()))
130            })?;
131            let descriptor = driver.descriptor();
132            let service = driver.build(launch.clone(), extensions.clone()).await?;
133            let service_protocol = service.status().protocol.or(Some(descriptor.protocol));
134
135            let service_devices = DeviceRegistry::new();
136            service.register_devices(&service_devices)?;
137            for (device_id, port) in service_devices.entries() {
138                devices.register(
139                    device_id,
140                    extensions.decorate_device_port(service_protocol, port),
141                );
142            }
143
144            handles.push(ServiceHandle::named(
145                launch.service_name(&descriptor),
146                service_protocol,
147                service,
148            ));
149        }
150
151        Ok(Self {
152            spec,
153            devices,
154            handles,
155        })
156    }
157
158    /// Starts all managed services and waits for readiness.
159    pub async fn start(&self, fallback_readiness_timeout: Duration) -> RuntimeResult<()> {
160        let readiness_timeout = self.spec.readiness_duration(fallback_readiness_timeout);
161        let mut started = Vec::new();
162
163        for handle in &self.handles {
164            if let Err(error) = handle.spawn().await {
165                self.stop_started(&started).await;
166                return Err(error);
167            }
168
169            match handle.readiness(readiness_timeout).await {
170                Ok(status) if status.ready && !status.is_terminal() => started.push(handle),
171                Ok(status) => {
172                    self.stop_started(&started).await;
173                    return Err(RuntimeError::protocol(format!(
174                        "service failed to become ready: {} ({:?})",
175                        status.name, status.state
176                    )));
177                }
178                Err(error) => {
179                    self.stop_started(&started).await;
180                    return Err(error);
181                }
182            }
183        }
184
185        Ok(())
186    }
187
188    async fn stop_started(&self, started: &[&ServiceHandle]) {
189        for handle in started.iter().rev() {
190            let _ = handle.stop().await;
191        }
192    }
193
194    /// Stops all managed services in reverse order.
195    pub async fn stop(&self) -> RuntimeResult<()> {
196        let mut errors = Vec::new();
197        for handle in self.handles.iter().rev() {
198            if let Err(error) = handle.stop().await {
199                errors.push(error.to_string());
200            }
201        }
202        if errors.is_empty() {
203            Ok(())
204        } else {
205            Err(RuntimeError::protocol(errors.join("; ")))
206        }
207    }
208
209    /// Returns the shared controller-visible device registry.
210    pub fn devices(&self) -> DeviceRegistry {
211        self.devices.clone()
212    }
213
214    /// Returns the current service snapshots.
215    pub async fn snapshots(&self) -> RuntimeResult<Vec<ServiceSnapshot>> {
216        let mut snapshots = Vec::with_capacity(self.handles.len());
217        for handle in &self.handles {
218            snapshots.push(handle.snapshot().await?);
219        }
220        Ok(snapshots)
221    }
222
223    /// Returns a session-level snapshot envelope with normalized service snapshots.
224    pub async fn session_snapshot(&self) -> RuntimeResult<RuntimeSessionSnapshot> {
225        Ok(RuntimeSessionSnapshot::new(self.snapshots().await?))
226    }
227
228    /// Returns the managed handles.
229    pub fn handles(&self) -> &[ServiceHandle] {
230        &self.handles
231    }
232}
233
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use mabi_core::Protocol;

    use crate::device::DeviceRegistry;
    use crate::driver::{
        ProtocolDescriptor, ProtocolDriver, ProtocolDriverRegistry, ProtocolLaunchSpec,
    };
    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceSnapshot, ServiceState, ServiceStatus,
    };
    use crate::session::{RuntimeExtensions, RuntimeSession, RuntimeSessionSpec};

    /// Test double that reports ready once `serve` runs and idles until cancelled.
    struct NullService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl NullService {
        fn new() -> Self {
            let initial = ServiceStatus::new("null");
            Self {
                status: parking_lot::RwLock::new(initial),
            }
        }

        /// Records a stopped, not-ready status.
        fn mark_stopped(&self) {
            let mut guard = self.status.write();
            guard.state = ServiceState::Stopped;
            guard.ready = false;
        }
    }

    #[async_trait]
    impl ManagedService for NullService {
        async fn start(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            // Readiness is only reported later, from `serve`.
            self.status.write().state = ServiceState::Starting;
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.mark_stopped();
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            // Scope the write guard so it is released before awaiting cancellation.
            {
                let mut guard = self.status.write();
                guard.state = ServiceState::Running;
                guard.ready = true;
            }
            context.cancellation_token().cancelled().await;
            self.mark_stopped();
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            self.status.read().clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let current = self.status();
            let mut snapshot = ServiceSnapshot::new("null");
            snapshot.status = current;
            Ok(snapshot)
        }

        fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
            Ok(())
        }
    }

    /// Driver that always builds a fresh [`NullService`].
    struct NullDriver;

    #[async_trait]
    impl ProtocolDriver for NullDriver {
        fn descriptor(&self) -> ProtocolDescriptor {
            ProtocolDescriptor {
                key: "null",
                display_name: "Null",
                protocol: Protocol::ModbusTcp,
                default_port: 0,
                description: "null driver",
            }
        }

        async fn build(
            &self,
            _spec: ProtocolLaunchSpec,
            _extensions: RuntimeExtensions,
        ) -> RuntimeResult<Arc<dyn ManagedService>> {
            let service: Arc<dyn ManagedService> = Arc::new(NullService::new());
            Ok(service)
        }
    }

    #[tokio::test]
    async fn session_starts_and_stops_services() {
        let mut drivers = ProtocolDriverRegistry::new();
        drivers.register(NullDriver);

        let launch = ProtocolLaunchSpec {
            protocol: "null".into(),
            name: Some("test-null".into()),
            config: serde_json::json!({}),
        };
        let spec = RuntimeSessionSpec {
            services: vec![launch],
            readiness_timeout: Some(1_000),
        };

        let session = RuntimeSession::new(spec, &drivers, RuntimeExtensions::default())
            .await
            .unwrap();
        session.start(Duration::from_secs(1)).await.unwrap();

        let service_snapshots = session.snapshots().await.unwrap();
        assert_eq!(service_snapshots.len(), 1);
        assert!(service_snapshots[0].runtime_metadata().is_some());

        let envelope = session.session_snapshot().await.unwrap();
        assert_eq!(envelope.services.len(), 1);
        assert_eq!(envelope.contract_version, "runtime-contract-v1");

        session.stop().await.unwrap();
    }
}