Skip to main content

mabi_runtime/
service.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value as JsonValue};
9use thiserror::Error;
10use tokio::task::{AbortHandle, JoinError, JoinHandle};
11use tokio::time::{timeout, Duration};
12use tokio_util::sync::CancellationToken;
13
14use mabi_core::Protocol;
15
16use crate::device::DeviceRegistry;
17
/// Runtime-level result type.
///
/// Shorthand for `Result<T, RuntimeError>`, returned by the fallible APIs in this module.
pub type RuntimeResult<T> = Result<T, RuntimeError>;

/// Stable runtime contract version consumed by Forge and Trials.
///
/// Embedded in [`ServiceRuntimeMetadata::contract_version`] and
/// [`ServiceReadinessReport::contract_version`].
pub const RUNTIME_CONTRACT_VERSION: &str = "runtime-contract-v1";

/// Stable service snapshot metadata contract version.
///
/// Embedded in [`ServiceRuntimeMetadata::snapshot_metadata_version`].
pub const SNAPSHOT_METADATA_VERSION: &str = "snapshot-metadata-v1";

/// Reserved metadata key for runtime-owned service snapshot fields.
///
/// [`ServiceSnapshot::ensure_runtime_metadata`] overwrites whatever is stored under this key,
/// so services should not put their own data here.
pub const RUNTIME_METADATA_KEY: &str = "_runtime";
29
/// Machine-readable runtime error classification.
///
/// Serialized in `snake_case` (e.g. `ConfigError` becomes `"config_error"`). The `Display`
/// implementation writes the same strings.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeErrorKind {
    /// Built by [`RuntimeError::protocol`].
    ProtocolError,
    /// Built by [`RuntimeError::config`].
    ConfigError,
    /// Built by [`RuntimeError::bind`] (bind/listen/address allocation failures).
    BindError,
    /// Built by [`RuntimeError::timeout`]; also the kind of [`RuntimeError::ReadinessTimeout`].
    Timeout,
    /// Built by [`RuntimeError::internal`]; also the kind of `Service` and `TaskJoin` errors.
    InternalError,
}
40
41impl std::fmt::Display for RuntimeErrorKind {
42    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        formatter.write_str(match self {
44            Self::ProtocolError => "protocol_error",
45            Self::ConfigError => "config_error",
46            Self::BindError => "bind_error",
47            Self::Timeout => "timeout",
48            Self::InternalError => "internal_error",
49        })
50    }
51}
52
/// Structured runtime error payload for machine consumers.
///
/// Built by [`RuntimeError::info`]. Serializes as `{"kind": "<snake_case kind>", "message": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeErrorInfo {
    /// Stable classification of the error.
    pub kind: RuntimeErrorKind,
    /// Human-readable message as returned by [`RuntimeError::message`].
    pub message: String,
}
59
/// Runtime-level errors.
///
/// Use [`RuntimeError::kind`] or [`RuntimeError::info`] for a stable, machine-readable view.
/// The `Display` output is meant for humans.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RuntimeError {
    /// Free-form service error; its kind is [`RuntimeErrorKind::InternalError`].
    #[error("service error: {message}")]
    Service { message: String },

    /// Failed service task; its kind is [`RuntimeErrorKind::InternalError`].
    #[error("service task failed: {message}")]
    TaskJoin { message: String },

    /// Readiness wait exceeded its budget; its kind is [`RuntimeErrorKind::Timeout`].
    #[error("service readiness timed out after {seconds}s")]
    ReadinessTimeout { seconds: u64 },

    /// Error carrying an explicit [`RuntimeErrorKind`].
    ///
    /// Built by the `protocol`, `config`, `bind`, `timeout`, and `internal` constructors.
    #[error("{kind}: {message}")]
    Classified {
        kind: RuntimeErrorKind,
        message: String,
    },
}
79
80impl RuntimeError {
81    /// Convenience constructor for message-based errors.
82    pub fn service(message: impl Into<String>) -> Self {
83        Self::Service {
84            message: message.into(),
85        }
86    }
87
88    /// Creates a protocol-level runtime error.
89    pub fn protocol(message: impl Into<String>) -> Self {
90        Self::classified(RuntimeErrorKind::ProtocolError, message)
91    }
92
93    /// Creates a configuration-level runtime error.
94    pub fn config(message: impl Into<String>) -> Self {
95        Self::classified(RuntimeErrorKind::ConfigError, message)
96    }
97
98    /// Creates a bind/listen/address allocation runtime error.
99    pub fn bind(message: impl Into<String>) -> Self {
100        Self::classified(RuntimeErrorKind::BindError, message)
101    }
102
103    /// Creates a timeout runtime error.
104    pub fn timeout(message: impl Into<String>) -> Self {
105        Self::classified(RuntimeErrorKind::Timeout, message)
106    }
107
108    /// Creates an internal runtime error.
109    pub fn internal(message: impl Into<String>) -> Self {
110        Self::classified(RuntimeErrorKind::InternalError, message)
111    }
112
113    fn classified(kind: RuntimeErrorKind, message: impl Into<String>) -> Self {
114        Self::Classified {
115            kind,
116            message: message.into(),
117        }
118    }
119
120    /// Returns the stable machine-readable error kind.
121    pub fn kind(&self) -> RuntimeErrorKind {
122        match self {
123            Self::Service { .. } | Self::TaskJoin { .. } => RuntimeErrorKind::InternalError,
124            Self::ReadinessTimeout { .. } => RuntimeErrorKind::Timeout,
125            Self::Classified { kind, .. } => *kind,
126        }
127    }
128
129    /// Returns the human-readable error message without losing the stable kind.
130    pub fn message(&self) -> String {
131        match self {
132            Self::Service { message }
133            | Self::TaskJoin { message }
134            | Self::Classified { message, .. } => message.clone(),
135            Self::ReadinessTimeout { seconds } => {
136                format!("service readiness timed out after {seconds}s")
137            }
138        }
139    }
140
141    /// Returns the structured runtime error payload.
142    pub fn info(&self) -> RuntimeErrorInfo {
143        RuntimeErrorInfo {
144            kind: self.kind(),
145            message: self.message(),
146        }
147    }
148}
149
150impl From<JoinError> for RuntimeError {
151    fn from(error: JoinError) -> Self {
152        Self::internal(format!("service task failed: {error}"))
153    }
154}
155
/// Shared service lifecycle states.
///
/// Serialized in `snake_case`. The default is [`ServiceState::Idle`], which is the state of a
/// freshly created [`ServiceStatus`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
    /// Initial state (the `Default`).
    #[default]
    Idle,
    Starting,
    Running,
    Stopping,
    /// Terminal; see [`ServiceStatus::is_terminal`].
    Stopped,
    /// Terminal; see [`ServiceStatus::is_terminal`].
    Error,
}
168
/// Current service status snapshot.
///
/// Reported by [`ManagedService::status`] and polled by [`ServiceHandle::readiness`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceStatus {
    /// Service name.
    pub name: String,
    /// Protocol served, if any (`None` in a freshly created status).
    pub protocol: Option<Protocol>,
    /// Current lifecycle state.
    pub state: ServiceState,
    /// Readiness flag; `ServiceHandle::readiness` returns once this is `true` or the state is
    /// terminal.
    pub ready: bool,
    /// Start time, if the service recorded one.
    pub started_at: Option<DateTime<Utc>>,
    /// Text of the most recent error, if any.
    pub last_error: Option<String>,
}
179
180impl ServiceStatus {
181    /// Creates a fresh idle status.
182    pub fn new(name: impl Into<String>) -> Self {
183        Self {
184            name: name.into(),
185            protocol: None,
186            state: ServiceState::Idle,
187            ready: false,
188            started_at: None,
189            last_error: None,
190        }
191    }
192
193    /// Returns true when the service is terminal.
194    pub fn is_terminal(&self) -> bool {
195        matches!(self.state, ServiceState::Stopped | ServiceState::Error)
196    }
197}
198
/// Structured snapshot used by the CLI and tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSnapshot {
    /// Service name.
    pub name: String,
    /// Snapshot-level protocol. [`ServiceRuntimeMetadata::from_snapshot`] falls back to it when
    /// `status.protocol` is `None`.
    pub protocol: Option<Protocol>,
    /// Status captured with the snapshot.
    pub status: ServiceStatus,
    /// Free-form metadata. The [`RUNTIME_METADATA_KEY`] entry is reserved for runtime-owned
    /// fields. A missing field deserializes as an empty map.
    #[serde(default)]
    pub metadata: BTreeMap<String, JsonValue>,
}
208
209impl ServiceSnapshot {
210    /// Creates an empty snapshot.
211    pub fn new(name: impl Into<String>) -> Self {
212        let name = name.into();
213        Self {
214            status: ServiceStatus::new(name.clone()),
215            name,
216            protocol: None,
217            metadata: BTreeMap::new(),
218        }
219    }
220
221    /// Adds metadata to the snapshot.
222    pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
223        self.metadata.insert(key.into(), value);
224        self
225    }
226
227    /// Adds or refreshes runtime-owned metadata under the reserved `_runtime` key.
228    pub fn with_runtime_metadata(mut self) -> Self {
229        self.ensure_runtime_metadata();
230        self
231    }
232
233    /// Ensures runtime-owned metadata exists under the reserved `_runtime` key.
234    pub fn ensure_runtime_metadata(&mut self) {
235        let metadata = ServiceRuntimeMetadata::from_snapshot(self);
236        self.metadata
237            .insert(RUNTIME_METADATA_KEY.to_string(), json!(metadata));
238    }
239
240    /// Returns parsed runtime metadata when present.
241    pub fn runtime_metadata(&self) -> Option<ServiceRuntimeMetadata> {
242        self.metadata
243            .get(RUNTIME_METADATA_KEY)
244            .and_then(|value| serde_json::from_value(value.clone()).ok())
245    }
246}
247
/// Runtime-owned stable service snapshot metadata.
///
/// Stored under [`RUNTIME_METADATA_KEY`] in [`ServiceSnapshot::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServiceRuntimeMetadata {
    /// Set to [`RUNTIME_CONTRACT_VERSION`] by `from_snapshot`.
    pub contract_version: String,
    /// Set to [`SNAPSHOT_METADATA_VERSION`] by `from_snapshot`.
    pub snapshot_metadata_version: String,
    /// Time at which the metadata was built.
    pub captured_at: DateTime<Utc>,
    /// Copied from the snapshot's `status.name`.
    pub service_name: String,
    /// Protocol rendered with `Display`; the status protocol is preferred over the snapshot's.
    pub protocol: Option<String>,
    // The remaining fields are copied verbatim from the snapshot's status.
    pub state: ServiceState,
    pub ready: bool,
    pub started_at: Option<DateTime<Utc>>,
    pub last_error: Option<String>,
}
261
262impl ServiceRuntimeMetadata {
263    /// Builds runtime metadata from the current service snapshot.
264    pub fn from_snapshot(snapshot: &ServiceSnapshot) -> Self {
265        let protocol = snapshot
266            .status
267            .protocol
268            .or(snapshot.protocol)
269            .map(|protocol| protocol.to_string());
270        Self {
271            contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
272            snapshot_metadata_version: SNAPSHOT_METADATA_VERSION.to_string(),
273            captured_at: Utc::now(),
274            service_name: snapshot.status.name.clone(),
275            protocol,
276            state: snapshot.status.state,
277            ready: snapshot.status.ready,
278            started_at: snapshot.status.started_at,
279            last_error: snapshot.status.last_error.clone(),
280        }
281    }
282}
283
/// Structured readiness report for runner-facing health checks.
///
/// Produced by [`ServiceHandle::readiness_report`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServiceReadinessReport {
    /// Set to [`RUNTIME_CONTRACT_VERSION`] by `from_status`.
    pub contract_version: String,
    /// Time at which the report was built.
    pub checked_at: DateTime<Utc>,
    pub service_name: String,
    /// Protocol rendered with `Display`, if the status has one.
    pub protocol: Option<String>,
    pub state: ServiceState,
    pub ready: bool,
    /// Readiness wait budget, in milliseconds.
    pub timeout_ms: u64,
    /// Error from the readiness wait (e.g. a timeout), if any.
    pub error: Option<RuntimeErrorInfo>,
}
296
297impl ServiceReadinessReport {
298    /// Builds a readiness report from a status and optional error.
299    pub fn from_status(
300        status: ServiceStatus,
301        timeout: Duration,
302        error: Option<RuntimeErrorInfo>,
303    ) -> Self {
304        Self {
305            contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
306            checked_at: Utc::now(),
307            service_name: status.name,
308            protocol: status.protocol.map(|protocol| protocol.to_string()),
309            state: status.state,
310            ready: status.ready,
311            timeout_ms: timeout.as_millis() as u64,
312            error,
313        }
314    }
315}
316
/// Events emitted by the shared runtime context.
///
/// Serialized as internally tagged JSON with a `snake_case` `type` field, for example
/// `{"type": "state_changed", "state": "running"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServiceEvent {
    /// Carries a lifecycle state.
    StateChanged { state: ServiceState },
    /// Emitted by [`ServiceContext::cancel`].
    Cancelled,
    /// Free-form text message.
    Message { message: String },
}
325
/// A task registered with a [`ServiceContext`]: a label plus the handle used to abort it.
#[derive(Debug, Clone)]
struct TrackedTask {
    label: String,
    abort: AbortHandle,
}

/// State shared by every clone of a [`ServiceContext`].
#[derive(Debug)]
struct ServiceContextInner {
    name: String,
    protocol: Option<Protocol>,
    /// Set when the context is created.
    started_at: DateTime<Utc>,
    /// Root token; `ServiceContext::child_token` derives scoped tokens from it.
    cancellation: CancellationToken,
    /// Broadcast sender (capacity 64). `send` returns an error when there are no receivers.
    event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
    /// Tasks registered via `track_task`/`spawn_task`; nothing in this module removes entries.
    tracked_tasks: Mutex<Vec<TrackedTask>>,
}

/// Shared runtime context provided to all managed services.
///
/// Cloning is cheap. All clones share one name, cancellation token, event channel, and
/// tracked-task list through an `Arc`.
#[derive(Clone, Debug)]
pub struct ServiceContext {
    inner: Arc<ServiceContextInner>,
}
347
348impl ServiceContext {
349    /// Creates a new service context.
350    pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
351        let (event_tx, _) = tokio::sync::broadcast::channel(64);
352        Self {
353            inner: Arc::new(ServiceContextInner {
354                name: name.into(),
355                protocol,
356                started_at: Utc::now(),
357                cancellation: CancellationToken::new(),
358                event_tx,
359                tracked_tasks: Mutex::new(Vec::new()),
360            }),
361        }
362    }
363
364    /// Returns the service name.
365    pub fn name(&self) -> &str {
366        &self.inner.name
367    }
368
369    /// Returns the service protocol, if one exists.
370    pub fn protocol(&self) -> Option<Protocol> {
371        self.inner.protocol
372    }
373
374    /// Returns when the context was created.
375    pub fn started_at(&self) -> DateTime<Utc> {
376        self.inner.started_at
377    }
378
379    /// Returns the shared cancellation token.
380    pub fn cancellation_token(&self) -> CancellationToken {
381        self.inner.cancellation.clone()
382    }
383
384    /// Returns a child token for scoped tasks.
385    pub fn child_token(&self) -> CancellationToken {
386        self.inner.cancellation.child_token()
387    }
388
389    /// Cancels the context and all child scopes.
390    pub fn cancel(&self) {
391        self.inner.cancellation.cancel();
392        let _ = self.emit(ServiceEvent::Cancelled);
393    }
394
395    /// Returns whether cancellation has been requested.
396    pub fn is_cancelled(&self) -> bool {
397        self.inner.cancellation.is_cancelled()
398    }
399
400    /// Subscribes to service events.
401    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
402        self.inner.event_tx.subscribe()
403    }
404
405    /// Emits a service event.
406    pub fn emit(
407        &self,
408        event: ServiceEvent,
409    ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
410        self.inner.event_tx.send(event)
411    }
412
413    /// Tracks an externally-spawned task under the context.
414    pub fn track_task(&self, label: impl Into<String>, handle: &JoinHandle<()>) {
415        self.inner.tracked_tasks.lock().push(TrackedTask {
416            label: label.into(),
417            abort: handle.abort_handle(),
418        });
419    }
420
421    /// Spawns and tracks a unit-returning background task.
422    pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
423    where
424        F: std::future::Future<Output = ()> + Send + 'static,
425    {
426        let label = label.into();
427        let handle = tokio::spawn(future);
428        self.inner.tracked_tasks.lock().push(TrackedTask {
429            label,
430            abort: handle.abort_handle(),
431        });
432        handle
433    }
434
435    /// Returns the tracked task labels.
436    pub fn tracked_tasks(&self) -> Vec<String> {
437        self.inner
438            .tracked_tasks
439            .lock()
440            .iter()
441            .map(|task| task.label.clone())
442            .collect()
443    }
444
445    /// Aborts all tracked tasks.
446    pub fn abort_tracked_tasks(&self) {
447        for task in self.inner.tracked_tasks.lock().iter() {
448            task.abort.abort();
449        }
450    }
451}
452
/// Shared lifecycle contract for protocol services.
///
/// [`ServiceHandle`] drives implementations as follows:
/// - `spawn` awaits [`start`](Self::start), then runs [`serve`](Self::serve) on a spawned tokio
///   task.
/// - `stop` cancels the context, awaits [`stop`](Self::stop), aborts tracked tasks, and joins
///   the `serve` task.
#[async_trait]
pub trait ManagedService: Send + Sync {
    /// Performs any non-blocking startup work.
    ///
    /// `ServiceHandle::spawn` calls this before spawning `serve`; an error prevents the spawn.
    async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Requests a graceful stop.
    ///
    /// `ServiceHandle::stop` calls this after the context has already been cancelled.
    async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Runs the service until completion or cancellation.
    ///
    /// Runs on its own task. `ServiceHandle::stop` waits for this future to finish, so it should
    /// return once the context's cancellation token fires.
    async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;

    /// Returns the current status.
    ///
    /// `ServiceHandle::readiness` polls this every 25 ms, so it should be cheap.
    fn status(&self) -> ServiceStatus;

    /// Returns a structured snapshot.
    ///
    /// `ServiceHandle::snapshot` adds runtime metadata to the result.
    async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;

    /// Publishes any controller-visible device ports exposed by this service.
    ///
    /// The default implementation registers nothing.
    fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
        Ok(())
    }
}
476
/// Shared handle for spawning, stopping, and inspecting managed services.
pub struct ServiceHandle {
    service: Arc<dyn ManagedService>,
    /// Context shared with the service; `stop` cancels it.
    context: ServiceContext,
    /// Join handle of the `serve` task. `spawn` sets it; `stop` and `wait` take it.
    task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
}
483
484impl ServiceHandle {
485    /// Creates a new handle around a service and context.
486    pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
487        Self {
488            service,
489            context,
490            task: Arc::new(tokio::sync::Mutex::new(None)),
491        }
492    }
493
494    /// Creates a handle for a named service.
495    pub fn named(
496        name: impl Into<String>,
497        protocol: Option<Protocol>,
498        service: Arc<dyn ManagedService>,
499    ) -> Self {
500        Self::new(service, ServiceContext::new(name, protocol))
501    }
502
503    /// Returns the shared service context.
504    pub fn context(&self) -> ServiceContext {
505        self.context.clone()
506    }
507
508    /// Spawns the service task if it is not already running.
509    pub async fn spawn(&self) -> RuntimeResult<()> {
510        let mut guard = self.task.lock().await;
511        if guard.is_some() {
512            return Ok(());
513        }
514
515        self.service.start(&self.context).await?;
516
517        let service = self.service.clone();
518        let context = self.context.clone();
519        *guard = Some(tokio::spawn(async move { service.serve(context).await }));
520        Ok(())
521    }
522
523    /// Requests service shutdown and waits for the service task.
524    pub async fn stop(&self) -> RuntimeResult<()> {
525        self.context.cancel();
526        self.service.stop(&self.context).await?;
527        self.context.abort_tracked_tasks();
528
529        if let Some(handle) = self.task.lock().await.take() {
530            handle.await??;
531        }
532
533        Ok(())
534    }
535
536    /// Waits for the service task to finish if it was spawned.
537    pub async fn wait(&self) -> RuntimeResult<()> {
538        if let Some(handle) = self.task.lock().await.take() {
539            handle.await??;
540        }
541        Ok(())
542    }
543
544    /// Waits until the service reports readiness or the timeout elapses.
545    pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
546        let service = self.service.clone();
547        timeout(max_wait, async move {
548            loop {
549                let status = service.status();
550                if status.ready || status.is_terminal() {
551                    return status;
552                }
553                tokio::time::sleep(Duration::from_millis(25)).await;
554            }
555        })
556        .await
557        .map_err(|_| {
558            RuntimeError::timeout(format!(
559                "service readiness timed out after {}ms",
560                max_wait.as_millis()
561            ))
562        })
563    }
564
565    /// Returns a structured readiness report without discarding status context.
566    pub async fn readiness_report(&self, max_wait: Duration) -> ServiceReadinessReport {
567        match self.readiness(max_wait).await {
568            Ok(status) => ServiceReadinessReport::from_status(status, max_wait, None),
569            Err(error) => {
570                let status = self.status();
571                ServiceReadinessReport::from_status(status, max_wait, Some(error.info()))
572            }
573        }
574    }
575
576    /// Returns the latest status.
577    pub fn status(&self) -> ServiceStatus {
578        self.service.status()
579    }
580
581    /// Returns the latest snapshot.
582    pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
583        Ok(self.service.snapshot().await?.with_runtime_metadata())
584    }
585}
586
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use crate::service::{
        ManagedService, RuntimeError, RuntimeErrorKind, RuntimeResult, ServiceContext,
        ServiceHandle, ServiceSnapshot, ServiceState, ServiceStatus, RUNTIME_CONTRACT_VERSION,
        RUNTIME_METADATA_KEY, SNAPSHOT_METADATA_VERSION,
    };

    /// Minimal service whose status is driven entirely by its lifecycle hooks.
    struct TestService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl TestService {
        fn new() -> Self {
            let initial = ServiceStatus::new("test");
            Self {
                status: parking_lot::RwLock::new(initial),
            }
        }

        /// Sets the lifecycle state and readiness flag together.
        fn set_phase(&self, state: ServiceState, ready: bool) {
            let mut guard = self.status.write();
            guard.state = state;
            guard.ready = ready;
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let started_at = context.started_at();
            let mut guard = self.status.write();
            guard.state = ServiceState::Starting;
            guard.started_at = Some(started_at);
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.set_phase(ServiceState::Stopped, false);
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            // Ready immediately, then park until the handle cancels the context.
            self.set_phase(ServiceState::Running, true);
            context.cancellation_token().cancelled().await;
            self.set_phase(ServiceState::Stopped, false);
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            ServiceStatus::clone(&self.status.read())
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let mut snapshot = ServiceSnapshot::new("test");
            snapshot.status = self.status();
            Ok(snapshot)
        }
    }

    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let handle = ServiceHandle::named("test", None, Arc::new(TestService::new()));
        handle.spawn().await.unwrap();

        let budget = Duration::from_secs(1);
        let ready_status = handle.readiness(budget).await.unwrap();
        assert!(ready_status.ready);

        let report = handle.readiness_report(budget).await;
        assert!(report.ready);
        assert_eq!(report.contract_version, RUNTIME_CONTRACT_VERSION);
        let report_json = serde_json::to_value(&report).unwrap();
        assert!(report_json["checked_at"].is_string());

        let snapshot = handle.snapshot().await.unwrap();
        assert!(snapshot.metadata.contains_key(RUNTIME_METADATA_KEY));
        let runtime = snapshot.runtime_metadata().expect("runtime metadata");
        assert_eq!(runtime.contract_version, RUNTIME_CONTRACT_VERSION);
        assert_eq!(runtime.snapshot_metadata_version, SNAPSHOT_METADATA_VERSION);
        assert_eq!(runtime.service_name, "test");
        assert!(runtime.ready);

        handle.stop().await.unwrap();
        assert_eq!(handle.status().state, ServiceState::Stopped);
    }

    #[test]
    fn runtime_error_info_uses_stable_kinds() {
        let error = RuntimeError::config("invalid launch config");
        assert_eq!(error.kind(), RuntimeErrorKind::ConfigError);

        let info = error.info();
        assert_eq!(info.message, "invalid launch config");

        let encoded = serde_json::to_value(info).unwrap();
        assert_eq!(encoded["kind"], "config_error");
        assert_eq!(encoded["message"], "invalid launch config");
    }
}