Skip to main content

mabi_runtime/
service.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use thiserror::Error;
10use tokio::task::{AbortHandle, JoinError, JoinHandle};
11use tokio::time::{timeout, Duration};
12use tokio_util::sync::CancellationToken;
13
14use mabi_core::Protocol;
15
16use crate::device::DeviceRegistry;
17
/// Runtime-level result type: a `Result` whose error type is always [`RuntimeError`].
pub type RuntimeResult<T> = Result<T, RuntimeError>;
20
/// Runtime-level errors.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum RuntimeError {
    /// Free-form, message-based failure; see [`RuntimeError::service`] for a shortcut.
    #[error("service error: {message}")]
    Service { message: String },

    /// Joining a spawned task failed; produced from a `JoinError` by the `From` impl,
    /// with `message` holding that error's `Display` text.
    #[error("service task failed: {message}")]
    TaskJoin { message: String },

    /// Readiness polling hit its deadline; `seconds` is the timeout in whole seconds.
    #[error("service readiness timed out after {seconds}s")]
    ReadinessTimeout { seconds: u64 },
}
33
34impl RuntimeError {
35    /// Convenience constructor for message-based errors.
36    pub fn service(message: impl Into<String>) -> Self {
37        Self::Service {
38            message: message.into(),
39        }
40    }
41}
42
43impl From<JoinError> for RuntimeError {
44    fn from(error: JoinError) -> Self {
45        Self::TaskJoin {
46            message: error.to_string(),
47        }
48    }
49}
50
/// Shared service lifecycle states.
///
/// Serialized with snake_case variant names; [`ServiceState::Idle`] is the default.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
    /// Initial state; this is what `ServiceStatus::new` assigns.
    #[default]
    Idle,
    /// Presumably set while startup work is in progress — confirm against implementors.
    Starting,
    /// Presumably set while the service is serving — confirm against implementors.
    Running,
    /// Presumably set while a graceful stop is in progress — confirm against implementors.
    Stopping,
    /// Terminal state (see `ServiceStatus::is_terminal`).
    Stopped,
    /// Terminal state (see `ServiceStatus::is_terminal`).
    Error,
}
63
/// Current service status snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceStatus {
    /// Service name.
    pub name: String,
    /// Protocol associated with the service, if any.
    pub protocol: Option<Protocol>,
    /// Current lifecycle state.
    pub state: ServiceState,
    /// Readiness flag; `ServiceHandle::readiness` polls until this is `true` or the
    /// state is terminal.
    pub ready: bool,
    /// Start timestamp, if one has been recorded.
    pub started_at: Option<DateTime<Utc>>,
    /// Presumably the most recent error message, if any — confirm against implementors.
    pub last_error: Option<String>,
}
74
75impl ServiceStatus {
76    /// Creates a fresh idle status.
77    pub fn new(name: impl Into<String>) -> Self {
78        Self {
79            name: name.into(),
80            protocol: None,
81            state: ServiceState::Idle,
82            ready: false,
83            started_at: None,
84            last_error: None,
85        }
86    }
87
88    /// Returns true when the service is terminal.
89    pub fn is_terminal(&self) -> bool {
90        matches!(self.state, ServiceState::Stopped | ServiceState::Error)
91    }
92}
93
/// Structured snapshot used by the CLI and tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSnapshot {
    /// Service name.
    pub name: String,
    /// Protocol associated with the service, if any.
    pub protocol: Option<Protocol>,
    /// Lifecycle status captured in this snapshot.
    pub status: ServiceStatus,
    /// Free-form key/value metadata; an absent field deserializes to an empty map.
    #[serde(default)]
    pub metadata: BTreeMap<String, JsonValue>,
}
103
104impl ServiceSnapshot {
105    /// Creates an empty snapshot.
106    pub fn new(name: impl Into<String>) -> Self {
107        let name = name.into();
108        Self {
109            status: ServiceStatus::new(name.clone()),
110            name,
111            protocol: None,
112            metadata: BTreeMap::new(),
113        }
114    }
115
116    /// Adds metadata to the snapshot.
117    pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
118        self.metadata.insert(key.into(), value);
119        self
120    }
121}
122
/// Events emitted by the shared runtime context.
///
/// Serialized as an internally tagged enum: the variant name is written in snake_case
/// under the `type` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServiceEvent {
    /// Carries a lifecycle state; presumably sent by services when their state changes
    /// (nothing in this file emits it).
    StateChanged { state: ServiceState },
    /// Emitted by `ServiceContext::cancel`.
    Cancelled,
    /// Free-form text message.
    Message { message: String },
}
131
/// Bookkeeping entry for one background task registered with a `ServiceContext`.
#[derive(Debug, Clone)]
struct TrackedTask {
    /// Caller-supplied label, reported by `ServiceContext::tracked_tasks`.
    label: String,
    /// Handle used to abort the task without owning its `JoinHandle`.
    abort: AbortHandle,
}
137
/// State shared by every clone of a `ServiceContext`.
#[derive(Debug)]
struct ServiceContextInner {
    /// Service name.
    name: String,
    /// Protocol associated with the service, if any.
    protocol: Option<Protocol>,
    /// Timestamp taken when the context was created.
    started_at: DateTime<Utc>,
    /// Root cancellation token; child tokens are derived from it.
    cancellation: CancellationToken,
    /// Sending half of the broadcast channel carrying `ServiceEvent`s.
    event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
    /// Background tasks registered under this context.
    tracked_tasks: Mutex<Vec<TrackedTask>>,
}
147
/// Shared runtime context provided to all managed services.
///
/// Clones share one inner state through an `Arc`, so cancellation, events, and tracked
/// tasks are visible through every clone.
#[derive(Clone, Debug)]
pub struct ServiceContext {
    /// Reference-counted shared state.
    inner: Arc<ServiceContextInner>,
}
153
154impl ServiceContext {
155    /// Creates a new service context.
156    pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
157        let (event_tx, _) = tokio::sync::broadcast::channel(64);
158        Self {
159            inner: Arc::new(ServiceContextInner {
160                name: name.into(),
161                protocol,
162                started_at: Utc::now(),
163                cancellation: CancellationToken::new(),
164                event_tx,
165                tracked_tasks: Mutex::new(Vec::new()),
166            }),
167        }
168    }
169
170    /// Returns the service name.
171    pub fn name(&self) -> &str {
172        &self.inner.name
173    }
174
175    /// Returns the service protocol, if one exists.
176    pub fn protocol(&self) -> Option<Protocol> {
177        self.inner.protocol
178    }
179
180    /// Returns when the context was created.
181    pub fn started_at(&self) -> DateTime<Utc> {
182        self.inner.started_at
183    }
184
185    /// Returns the shared cancellation token.
186    pub fn cancellation_token(&self) -> CancellationToken {
187        self.inner.cancellation.clone()
188    }
189
190    /// Returns a child token for scoped tasks.
191    pub fn child_token(&self) -> CancellationToken {
192        self.inner.cancellation.child_token()
193    }
194
195    /// Cancels the context and all child scopes.
196    pub fn cancel(&self) {
197        self.inner.cancellation.cancel();
198        let _ = self.emit(ServiceEvent::Cancelled);
199    }
200
201    /// Returns whether cancellation has been requested.
202    pub fn is_cancelled(&self) -> bool {
203        self.inner.cancellation.is_cancelled()
204    }
205
206    /// Subscribes to service events.
207    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
208        self.inner.event_tx.subscribe()
209    }
210
211    /// Emits a service event.
212    pub fn emit(
213        &self,
214        event: ServiceEvent,
215    ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
216        self.inner.event_tx.send(event)
217    }
218
219    /// Tracks an externally-spawned task under the context.
220    pub fn track_task(&self, label: impl Into<String>, handle: &JoinHandle<()>) {
221        self.inner.tracked_tasks.lock().push(TrackedTask {
222            label: label.into(),
223            abort: handle.abort_handle(),
224        });
225    }
226
227    /// Spawns and tracks a unit-returning background task.
228    pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
229    where
230        F: std::future::Future<Output = ()> + Send + 'static,
231    {
232        let label = label.into();
233        let handle = tokio::spawn(future);
234        self.inner.tracked_tasks.lock().push(TrackedTask {
235            label,
236            abort: handle.abort_handle(),
237        });
238        handle
239    }
240
241    /// Returns the tracked task labels.
242    pub fn tracked_tasks(&self) -> Vec<String> {
243        self.inner
244            .tracked_tasks
245            .lock()
246            .iter()
247            .map(|task| task.label.clone())
248            .collect()
249    }
250
251    /// Aborts all tracked tasks.
252    pub fn abort_tracked_tasks(&self) {
253        for task in self.inner.tracked_tasks.lock().iter() {
254            task.abort.abort();
255        }
256    }
257}
258
/// Shared lifecycle contract for protocol services.
///
/// Implementations must be `Send + Sync`. `ServiceHandle` stores them as
/// `Arc<dyn ManagedService>`, awaits `start` before spawning `serve` on a tokio task,
/// and calls `stop` after cancelling the context.
#[async_trait]
pub trait ManagedService: Send + Sync {
    /// Performs any non-blocking startup work.
    ///
    /// `ServiceHandle::spawn` awaits this first; an error prevents `serve` from being
    /// spawned.
    async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Requests a graceful stop.
    ///
    /// `ServiceHandle::stop` calls this after the context has been cancelled.
    async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Runs the service until completion or cancellation.
    ///
    /// Receives an owned clone of the context because it runs on its own spawned task.
    async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;

    /// Returns the current status.
    ///
    /// Synchronous: `ServiceHandle::readiness` calls it repeatedly while polling.
    fn status(&self) -> ServiceStatus;

    /// Returns a structured snapshot.
    async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;

    /// Publishes any controller-visible device ports exposed by this service.
    ///
    /// The default implementation registers nothing and returns `Ok(())`.
    fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
        Ok(())
    }
}
282
/// Shared handle for spawning, stopping, and inspecting managed services.
pub struct ServiceHandle {
    /// The managed service, shared with the spawned serve task.
    service: Arc<dyn ManagedService>,
    /// Runtime context passed to every lifecycle call.
    context: ServiceContext,
    /// Join handle of the spawned `serve` task; `None` before `spawn` and again once
    /// `stop` or `wait` has taken it.
    task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
}
289
290impl ServiceHandle {
291    /// Creates a new handle around a service and context.
292    pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
293        Self {
294            service,
295            context,
296            task: Arc::new(tokio::sync::Mutex::new(None)),
297        }
298    }
299
300    /// Creates a handle for a named service.
301    pub fn named(
302        name: impl Into<String>,
303        protocol: Option<Protocol>,
304        service: Arc<dyn ManagedService>,
305    ) -> Self {
306        Self::new(service, ServiceContext::new(name, protocol))
307    }
308
309    /// Returns the shared service context.
310    pub fn context(&self) -> ServiceContext {
311        self.context.clone()
312    }
313
314    /// Spawns the service task if it is not already running.
315    pub async fn spawn(&self) -> RuntimeResult<()> {
316        let mut guard = self.task.lock().await;
317        if guard.is_some() {
318            return Ok(());
319        }
320
321        self.service.start(&self.context).await?;
322
323        let service = self.service.clone();
324        let context = self.context.clone();
325        *guard = Some(tokio::spawn(async move { service.serve(context).await }));
326        Ok(())
327    }
328
329    /// Requests service shutdown and waits for the service task.
330    pub async fn stop(&self) -> RuntimeResult<()> {
331        self.context.cancel();
332        self.service.stop(&self.context).await?;
333        self.context.abort_tracked_tasks();
334
335        if let Some(handle) = self.task.lock().await.take() {
336            handle.await??;
337        }
338
339        Ok(())
340    }
341
342    /// Waits for the service task to finish if it was spawned.
343    pub async fn wait(&self) -> RuntimeResult<()> {
344        if let Some(handle) = self.task.lock().await.take() {
345            handle.await??;
346        }
347        Ok(())
348    }
349
350    /// Waits until the service reports readiness or the timeout elapses.
351    pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
352        let service = self.service.clone();
353        timeout(max_wait, async move {
354            loop {
355                let status = service.status();
356                if status.ready || status.is_terminal() {
357                    return status;
358                }
359                tokio::time::sleep(Duration::from_millis(25)).await;
360            }
361        })
362        .await
363        .map_err(|_| RuntimeError::ReadinessTimeout {
364            seconds: max_wait.as_secs(),
365        })
366    }
367
368    /// Returns the latest status.
369    pub fn status(&self) -> ServiceStatus {
370        self.service.status()
371    }
372
373    /// Returns the latest snapshot.
374    pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
375        self.service.snapshot().await
376    }
377}
378
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceHandle, ServiceSnapshot,
        ServiceState, ServiceStatus,
    };

    /// Minimal in-memory service whose only behavior is updating its own status.
    struct TestService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl TestService {
        fn new() -> Self {
            let status = parking_lot::RwLock::new(ServiceStatus::new("test"));
            Self { status }
        }

        /// Sets state and readiness together under a short-lived write lock, so no
        /// guard is ever alive across an `.await`.
        fn transition(&self, state: ServiceState, ready: bool) {
            let mut guard = self.status.write();
            guard.state = state;
            guard.ready = ready;
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let started_at = context.started_at();
            let mut guard = self.status.write();
            guard.state = ServiceState::Starting;
            guard.started_at = Some(started_at);
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.transition(ServiceState::Stopped, false);
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            // Report ready immediately, then park until the context is cancelled.
            self.transition(ServiceState::Running, true);
            let token = context.cancellation_token();
            token.cancelled().await;
            self.transition(ServiceState::Stopped, false);
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            self.status.read().clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            Ok(ServiceSnapshot {
                status: self.status(),
                ..ServiceSnapshot::new("test")
            })
        }
    }

    /// Spawning makes the service ready; stopping leaves it in the `Stopped` state.
    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let test_service = Arc::new(TestService::new());
        let handle = ServiceHandle::named("test", None, test_service);

        handle.spawn().await.unwrap();
        let ready_status = handle.readiness(Duration::from_secs(1)).await.unwrap();
        assert!(ready_status.ready);

        handle.stop().await.unwrap();
        assert_eq!(handle.status().state, ServiceState::Stopped);
    }
}
453}