1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value as JsonValue};
9use thiserror::Error;
10use tokio::task::{AbortHandle, JoinError, JoinHandle};
11use tokio::time::{timeout, Duration};
12use tokio_util::sync::CancellationToken;
13
14use mabi_core::Protocol;
15
16use crate::device::DeviceRegistry;
17
/// Result alias whose error type is always [`RuntimeError`].
pub type RuntimeResult<T> = Result<T, RuntimeError>;

/// Version tag written into the `contract_version` field of
/// [`ServiceRuntimeMetadata`] and [`ServiceReadinessReport`].
pub const RUNTIME_CONTRACT_VERSION: &str = "runtime-contract-v1";

/// Version tag written into the `snapshot_metadata_version` field of
/// [`ServiceRuntimeMetadata`].
pub const SNAPSHOT_METADATA_VERSION: &str = "snapshot-metadata-v1";

/// Key under which [`ServiceSnapshot::ensure_runtime_metadata`] stores the
/// serialized [`ServiceRuntimeMetadata`] in the snapshot's metadata map.
pub const RUNTIME_METADATA_KEY: &str = "_runtime";
29
/// Coarse classification attached to a [`RuntimeError`].
///
/// Variants serialize in `snake_case` (for example `config_error`), the same
/// labels the `Display` implementation writes.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeErrorKind {
    /// Serialized as `protocol_error`.
    ProtocolError,
    /// Serialized as `config_error`.
    ConfigError,
    /// Serialized as `bind_error`.
    BindError,
    /// Serialized as `timeout`.
    Timeout,
    /// Serialized as `internal_error`.
    InternalError,
}
40
41impl std::fmt::Display for RuntimeErrorKind {
42 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 formatter.write_str(match self {
44 Self::ProtocolError => "protocol_error",
45 Self::ConfigError => "config_error",
46 Self::BindError => "bind_error",
47 Self::Timeout => "timeout",
48 Self::InternalError => "internal_error",
49 })
50 }
51}
52
/// Serializable summary of a [`RuntimeError`], as produced by
/// [`RuntimeError::info`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeErrorInfo {
    /// Classification of the error.
    pub kind: RuntimeErrorKind,
    /// Message text of the error (see [`RuntimeError::message`]).
    pub message: String,
}
59
/// Error type returned by runtime operations.
///
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RuntimeError {
    /// Unclassified service failure; `kind()` reports it as `InternalError`.
    #[error("service error: {message}")]
    Service { message: String },

    /// Failure of a service task; `kind()` reports it as `InternalError`.
    #[error("service task failed: {message}")]
    TaskJoin { message: String },

    /// Readiness was not reached within `seconds`; `kind()` reports `Timeout`.
    #[error("service readiness timed out after {seconds}s")]
    ReadinessTimeout { seconds: u64 },

    /// Error carrying an explicit [`RuntimeErrorKind`]; displayed as
    /// `"{kind}: {message}"`.
    #[error("{kind}: {message}")]
    Classified {
        kind: RuntimeErrorKind,
        message: String,
    },
}
79
80impl RuntimeError {
81 pub fn service(message: impl Into<String>) -> Self {
83 Self::Service {
84 message: message.into(),
85 }
86 }
87
88 pub fn protocol(message: impl Into<String>) -> Self {
90 Self::classified(RuntimeErrorKind::ProtocolError, message)
91 }
92
93 pub fn config(message: impl Into<String>) -> Self {
95 Self::classified(RuntimeErrorKind::ConfigError, message)
96 }
97
98 pub fn bind(message: impl Into<String>) -> Self {
100 Self::classified(RuntimeErrorKind::BindError, message)
101 }
102
103 pub fn timeout(message: impl Into<String>) -> Self {
105 Self::classified(RuntimeErrorKind::Timeout, message)
106 }
107
108 pub fn internal(message: impl Into<String>) -> Self {
110 Self::classified(RuntimeErrorKind::InternalError, message)
111 }
112
113 fn classified(kind: RuntimeErrorKind, message: impl Into<String>) -> Self {
114 Self::Classified {
115 kind,
116 message: message.into(),
117 }
118 }
119
120 pub fn kind(&self) -> RuntimeErrorKind {
122 match self {
123 Self::Service { .. } | Self::TaskJoin { .. } => RuntimeErrorKind::InternalError,
124 Self::ReadinessTimeout { .. } => RuntimeErrorKind::Timeout,
125 Self::Classified { kind, .. } => *kind,
126 }
127 }
128
129 pub fn message(&self) -> String {
131 match self {
132 Self::Service { message }
133 | Self::TaskJoin { message }
134 | Self::Classified { message, .. } => message.clone(),
135 Self::ReadinessTimeout { seconds } => {
136 format!("service readiness timed out after {seconds}s")
137 }
138 }
139 }
140
141 pub fn info(&self) -> RuntimeErrorInfo {
143 RuntimeErrorInfo {
144 kind: self.kind(),
145 message: self.message(),
146 }
147 }
148}
149
150impl From<JoinError> for RuntimeError {
151 fn from(error: JoinError) -> Self {
152 Self::internal(format!("service task failed: {error}"))
153 }
154}
155
/// Lifecycle state reported by a service; serialized in `snake_case`.
///
/// `Idle` is the default. `Stopped` and `Error` are the states
/// [`ServiceStatus::is_terminal`] treats as terminal.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
    /// Default state, also used by [`ServiceStatus::new`].
    #[default]
    Idle,
    Starting,
    Running,
    Stopping,
    /// Terminal state.
    Stopped,
    /// Terminal state.
    Error,
}
168
/// Point-in-time status reported by a [`ManagedService`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceStatus {
    /// Service name.
    pub name: String,
    /// Protocol of the service, if any.
    pub protocol: Option<Protocol>,
    /// Current lifecycle state.
    pub state: ServiceState,
    /// Readiness flag; `ServiceHandle::readiness` polls until this is `true`
    /// or the state is terminal.
    pub ready: bool,
    /// Start timestamp, if one has been recorded.
    pub started_at: Option<DateTime<Utc>>,
    /// Most recent error text, if any.
    pub last_error: Option<String>,
}
179
180impl ServiceStatus {
181 pub fn new(name: impl Into<String>) -> Self {
183 Self {
184 name: name.into(),
185 protocol: None,
186 state: ServiceState::Idle,
187 ready: false,
188 started_at: None,
189 last_error: None,
190 }
191 }
192
193 pub fn is_terminal(&self) -> bool {
195 matches!(self.state, ServiceState::Stopped | ServiceState::Error)
196 }
197}
198
/// Serializable snapshot of a service: identity, status and free-form
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSnapshot {
    /// Service name.
    pub name: String,
    /// Protocol of the service, if any.
    pub protocol: Option<Protocol>,
    /// Status captured with the snapshot.
    pub status: ServiceStatus,
    /// Arbitrary JSON metadata keyed by name; defaults to an empty map when
    /// the field is absent during deserialization.
    #[serde(default)]
    pub metadata: BTreeMap<String, JsonValue>,
}
208
209impl ServiceSnapshot {
210 pub fn new(name: impl Into<String>) -> Self {
212 let name = name.into();
213 Self {
214 status: ServiceStatus::new(name.clone()),
215 name,
216 protocol: None,
217 metadata: BTreeMap::new(),
218 }
219 }
220
221 pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
223 self.metadata.insert(key.into(), value);
224 self
225 }
226
227 pub fn with_runtime_metadata(mut self) -> Self {
229 self.ensure_runtime_metadata();
230 self
231 }
232
233 pub fn ensure_runtime_metadata(&mut self) {
235 let metadata = ServiceRuntimeMetadata::from_snapshot(self);
236 self.metadata
237 .insert(RUNTIME_METADATA_KEY.to_string(), json!(metadata));
238 }
239
240 pub fn runtime_metadata(&self) -> Option<ServiceRuntimeMetadata> {
242 self.metadata
243 .get(RUNTIME_METADATA_KEY)
244 .and_then(|value| serde_json::from_value(value.clone()).ok())
245 }
246}
247
/// Runtime metadata derived from a [`ServiceSnapshot`] and embedded in its
/// metadata map under [`RUNTIME_METADATA_KEY`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServiceRuntimeMetadata {
    /// Set to [`RUNTIME_CONTRACT_VERSION`] by [`Self::from_snapshot`].
    pub contract_version: String,
    /// Set to [`SNAPSHOT_METADATA_VERSION`] by [`Self::from_snapshot`].
    pub snapshot_metadata_version: String,
    /// Time at which the metadata was computed.
    pub captured_at: DateTime<Utc>,
    /// Name taken from the snapshot's status.
    pub service_name: String,
    /// String form of the protocol, if one is known.
    pub protocol: Option<String>,
    /// State copied from the snapshot's status.
    pub state: ServiceState,
    /// Readiness flag copied from the snapshot's status.
    pub ready: bool,
    /// Start timestamp copied from the snapshot's status.
    pub started_at: Option<DateTime<Utc>>,
    /// Last error copied from the snapshot's status.
    pub last_error: Option<String>,
}
261
262impl ServiceRuntimeMetadata {
263 pub fn from_snapshot(snapshot: &ServiceSnapshot) -> Self {
265 let protocol = snapshot
266 .status
267 .protocol
268 .or(snapshot.protocol)
269 .map(|protocol| protocol.to_string());
270 Self {
271 contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
272 snapshot_metadata_version: SNAPSHOT_METADATA_VERSION.to_string(),
273 captured_at: Utc::now(),
274 service_name: snapshot.status.name.clone(),
275 protocol,
276 state: snapshot.status.state,
277 ready: snapshot.status.ready,
278 started_at: snapshot.status.started_at,
279 last_error: snapshot.status.last_error.clone(),
280 }
281 }
282}
283
/// Serializable result of a readiness check (see
/// `ServiceHandle::readiness_report`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServiceReadinessReport {
    /// Set to [`RUNTIME_CONTRACT_VERSION`] by [`Self::from_status`].
    pub contract_version: String,
    /// Time at which the report was built.
    pub checked_at: DateTime<Utc>,
    /// Name taken from the status.
    pub service_name: String,
    /// String form of the status protocol, if any.
    pub protocol: Option<String>,
    /// State taken from the status.
    pub state: ServiceState,
    /// Readiness flag taken from the status.
    pub ready: bool,
    /// Wait budget of the check, in milliseconds.
    pub timeout_ms: u64,
    /// Error information when the check failed.
    pub error: Option<RuntimeErrorInfo>,
}
296
297impl ServiceReadinessReport {
298 pub fn from_status(
300 status: ServiceStatus,
301 timeout: Duration,
302 error: Option<RuntimeErrorInfo>,
303 ) -> Self {
304 Self {
305 contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
306 checked_at: Utc::now(),
307 service_name: status.name,
308 protocol: status.protocol.map(|protocol| protocol.to_string()),
309 state: status.state,
310 ready: status.ready,
311 timeout_ms: timeout.as_millis() as u64,
312 error,
313 }
314 }
315}
316
/// Event broadcast through a [`ServiceContext`].
///
/// Serialized as an internally tagged object: the variant name goes in the
/// `type` field in `snake_case` (for example `state_changed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServiceEvent {
    /// Carries a service state.
    StateChanged { state: ServiceState },
    /// Emitted by [`ServiceContext::cancel`].
    Cancelled,
    /// Free-form text message.
    Message { message: String },
}
325
/// A spawned task registered with a [`ServiceContext`].
#[derive(Debug, Clone)]
struct TrackedTask {
    /// Label returned by `ServiceContext::tracked_tasks`.
    label: String,
    /// Handle used by `ServiceContext::abort_tracked_tasks` to abort the task.
    abort: AbortHandle,
}
331
/// State shared by every clone of a [`ServiceContext`].
#[derive(Debug)]
struct ServiceContextInner {
    /// Service name.
    name: String,
    /// Protocol of the service, if any.
    protocol: Option<Protocol>,
    /// Time at which the context was constructed.
    started_at: DateTime<Utc>,
    /// Root cancellation token for the service.
    cancellation: CancellationToken,
    /// Sending half of the event broadcast channel.
    event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
    /// Tasks registered through `track_task` / `spawn_task`.
    tracked_tasks: Mutex<Vec<TrackedTask>>,
}
341
/// Handle to per-service shared state: cancellation, events and tracked
/// tasks.
///
/// Cloning is cheap: clones share the same inner state through an `Arc`.
#[derive(Clone, Debug)]
pub struct ServiceContext {
    inner: Arc<ServiceContextInner>,
}
347
348impl ServiceContext {
349 pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
351 let (event_tx, _) = tokio::sync::broadcast::channel(64);
352 Self {
353 inner: Arc::new(ServiceContextInner {
354 name: name.into(),
355 protocol,
356 started_at: Utc::now(),
357 cancellation: CancellationToken::new(),
358 event_tx,
359 tracked_tasks: Mutex::new(Vec::new()),
360 }),
361 }
362 }
363
364 pub fn name(&self) -> &str {
366 &self.inner.name
367 }
368
369 pub fn protocol(&self) -> Option<Protocol> {
371 self.inner.protocol
372 }
373
374 pub fn started_at(&self) -> DateTime<Utc> {
376 self.inner.started_at
377 }
378
379 pub fn cancellation_token(&self) -> CancellationToken {
381 self.inner.cancellation.clone()
382 }
383
384 pub fn child_token(&self) -> CancellationToken {
386 self.inner.cancellation.child_token()
387 }
388
389 pub fn cancel(&self) {
391 self.inner.cancellation.cancel();
392 let _ = self.emit(ServiceEvent::Cancelled);
393 }
394
395 pub fn is_cancelled(&self) -> bool {
397 self.inner.cancellation.is_cancelled()
398 }
399
400 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
402 self.inner.event_tx.subscribe()
403 }
404
405 pub fn emit(
407 &self,
408 event: ServiceEvent,
409 ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
410 self.inner.event_tx.send(event)
411 }
412
413 pub fn track_task(&self, label: impl Into<String>, handle: &JoinHandle<()>) {
415 self.inner.tracked_tasks.lock().push(TrackedTask {
416 label: label.into(),
417 abort: handle.abort_handle(),
418 });
419 }
420
421 pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
423 where
424 F: std::future::Future<Output = ()> + Send + 'static,
425 {
426 let label = label.into();
427 let handle = tokio::spawn(future);
428 self.inner.tracked_tasks.lock().push(TrackedTask {
429 label,
430 abort: handle.abort_handle(),
431 });
432 handle
433 }
434
435 pub fn tracked_tasks(&self) -> Vec<String> {
437 self.inner
438 .tracked_tasks
439 .lock()
440 .iter()
441 .map(|task| task.label.clone())
442 .collect()
443 }
444
445 pub fn abort_tracked_tasks(&self) {
447 for task in self.inner.tracked_tasks.lock().iter() {
448 task.abort.abort();
449 }
450 }
451}
452
/// Lifecycle contract for a service driven by a `ServiceHandle`.
#[async_trait]
pub trait ManagedService: Send + Sync {
    /// Awaited by `ServiceHandle::spawn` before the serve task is spawned; an
    /// error here makes `spawn` return without spawning.
    async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Awaited by `ServiceHandle::stop` after the context has been cancelled.
    async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;

    /// Run on a spawned Tokio task by `ServiceHandle::spawn`; receives its
    /// own clone of the context.
    async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;

    /// Current status; called synchronously, including from the readiness
    /// polling loop in `ServiceHandle::readiness`.
    fn status(&self) -> ServiceStatus;

    /// Produces a snapshot of the service. `ServiceHandle::snapshot` adds
    /// runtime metadata to the returned value.
    async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;

    /// Hook taking a [`DeviceRegistry`]; the default implementation does
    /// nothing and returns `Ok(())`.
    fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
        Ok(())
    }
}
476
/// Owns a [`ManagedService`], its [`ServiceContext`], and the join handle of
/// the spawned serve task.
pub struct ServiceHandle {
    /// The managed service.
    service: Arc<dyn ManagedService>,
    /// Context passed to the service's lifecycle methods.
    context: ServiceContext,
    /// Join handle of the serve task; `None` until `spawn` succeeds and again
    /// after `stop` or `wait` has taken it.
    task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
}
483
484impl ServiceHandle {
485 pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
487 Self {
488 service,
489 context,
490 task: Arc::new(tokio::sync::Mutex::new(None)),
491 }
492 }
493
494 pub fn named(
496 name: impl Into<String>,
497 protocol: Option<Protocol>,
498 service: Arc<dyn ManagedService>,
499 ) -> Self {
500 Self::new(service, ServiceContext::new(name, protocol))
501 }
502
503 pub fn context(&self) -> ServiceContext {
505 self.context.clone()
506 }
507
508 pub async fn spawn(&self) -> RuntimeResult<()> {
510 let mut guard = self.task.lock().await;
511 if guard.is_some() {
512 return Ok(());
513 }
514
515 self.service.start(&self.context).await?;
516
517 let service = self.service.clone();
518 let context = self.context.clone();
519 *guard = Some(tokio::spawn(async move { service.serve(context).await }));
520 Ok(())
521 }
522
523 pub async fn stop(&self) -> RuntimeResult<()> {
525 self.context.cancel();
526 self.service.stop(&self.context).await?;
527 self.context.abort_tracked_tasks();
528
529 if let Some(handle) = self.task.lock().await.take() {
530 handle.await??;
531 }
532
533 Ok(())
534 }
535
536 pub async fn wait(&self) -> RuntimeResult<()> {
538 if let Some(handle) = self.task.lock().await.take() {
539 handle.await??;
540 }
541 Ok(())
542 }
543
544 pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
546 let service = self.service.clone();
547 timeout(max_wait, async move {
548 loop {
549 let status = service.status();
550 if status.ready || status.is_terminal() {
551 return status;
552 }
553 tokio::time::sleep(Duration::from_millis(25)).await;
554 }
555 })
556 .await
557 .map_err(|_| {
558 RuntimeError::timeout(format!(
559 "service readiness timed out after {}ms",
560 max_wait.as_millis()
561 ))
562 })
563 }
564
565 pub async fn readiness_report(&self, max_wait: Duration) -> ServiceReadinessReport {
567 match self.readiness(max_wait).await {
568 Ok(status) => ServiceReadinessReport::from_status(status, max_wait, None),
569 Err(error) => {
570 let status = self.status();
571 ServiceReadinessReport::from_status(status, max_wait, Some(error.info()))
572 }
573 }
574 }
575
576 pub fn status(&self) -> ServiceStatus {
578 self.service.status()
579 }
580
581 pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
583 Ok(self.service.snapshot().await?.with_runtime_metadata())
584 }
585}
586
// Unit tests for the service handle lifecycle and the error-info contract.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use crate::service::{
        ManagedService, RuntimeError, RuntimeErrorKind, RuntimeResult, ServiceContext,
        ServiceHandle, ServiceSnapshot, ServiceState, ServiceStatus, RUNTIME_CONTRACT_VERSION,
        RUNTIME_METADATA_KEY, SNAPSHOT_METADATA_VERSION,
    };

    // Minimal service whose only state is a status behind a lock.
    struct TestService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl TestService {
        // Starts out with an idle status named "test".
        fn new() -> Self {
            Self {
                status: parking_lot::RwLock::new(ServiceStatus::new("test")),
            }
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        // Marks the service as starting and records the context start time.
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let mut status = self.status.write();
            status.state = ServiceState::Starting;
            status.started_at = Some(context.started_at());
            Ok(())
        }

        // Marks the service as stopped and not ready.
        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            let mut status = self.status.write();
            status.state = ServiceState::Stopped;
            status.ready = false;
            Ok(())
        }

        // Reports running/ready, then parks until the context is cancelled.
        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            {
                // Scoped so the write guard is released before the await below.
                let mut status = self.status.write();
                status.state = ServiceState::Running;
                status.ready = true;
            }
            context.cancellation_token().cancelled().await;
            let mut status = self.status.write();
            status.state = ServiceState::Stopped;
            status.ready = false;
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            self.status.read().clone()
        }

        // Snapshot named "test" carrying the current status.
        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let mut snapshot = ServiceSnapshot::new("test");
            snapshot.status = self.status();
            Ok(snapshot)
        }
    }

    // End-to-end: spawn, wait for readiness, check the report and snapshot
    // metadata, then stop.
    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let service = Arc::new(TestService::new());
        let handle = ServiceHandle::named("test", None, service);
        handle.spawn().await.unwrap();
        let status = handle.readiness(Duration::from_secs(1)).await.unwrap();
        assert!(status.ready);
        let report = handle.readiness_report(Duration::from_secs(1)).await;
        assert!(report.ready);
        assert_eq!(report.contract_version, RUNTIME_CONTRACT_VERSION);
        assert!(serde_json::to_value(&report).unwrap()["checked_at"].is_string());

        let snapshot = handle.snapshot().await.unwrap();
        assert!(snapshot.metadata.contains_key(RUNTIME_METADATA_KEY));
        let runtime = snapshot.runtime_metadata().expect("runtime metadata");
        assert_eq!(runtime.contract_version, RUNTIME_CONTRACT_VERSION);
        assert_eq!(runtime.snapshot_metadata_version, SNAPSHOT_METADATA_VERSION);
        assert_eq!(runtime.service_name, "test");
        assert!(runtime.ready);

        handle.stop().await.unwrap();
        assert_eq!(handle.status().state, ServiceState::Stopped);
    }

    // The kind and message survive `info()` and serialize with snake_case
    // kind labels.
    #[test]
    fn runtime_error_info_uses_stable_kinds() {
        let error = RuntimeError::config("invalid launch config");
        assert_eq!(error.kind(), RuntimeErrorKind::ConfigError);
        assert_eq!(error.info().message, "invalid launch config");

        let value = serde_json::to_value(error.info()).unwrap();
        assert_eq!(value["kind"], "config_error");
        assert_eq!(value["message"], "invalid launch config");
    }
}