1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use parking_lot::Mutex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use thiserror::Error;
10use tokio::task::{AbortHandle, JoinError, JoinHandle};
11use tokio::time::{timeout, Duration};
12use tokio_util::sync::CancellationToken;
13
14use mabi_core::Protocol;
15
16use crate::device::DeviceRegistry;
17
18pub type RuntimeResult<T> = Result<T, RuntimeError>;
20
21#[derive(Debug, Error)]
23pub enum RuntimeError {
24 #[error("service error: {message}")]
25 Service { message: String },
26
27 #[error("service task failed: {message}")]
28 TaskJoin { message: String },
29
30 #[error("service readiness timed out after {seconds}s")]
31 ReadinessTimeout { seconds: u64 },
32}
33
34impl RuntimeError {
35 pub fn service(message: impl Into<String>) -> Self {
37 Self::Service {
38 message: message.into(),
39 }
40 }
41}
42
43impl From<JoinError> for RuntimeError {
44 fn from(error: JoinError) -> Self {
45 Self::TaskJoin {
46 message: error.to_string(),
47 }
48 }
49}
50
51#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
53#[serde(rename_all = "snake_case")]
54pub enum ServiceState {
55 #[default]
56 Idle,
57 Starting,
58 Running,
59 Stopping,
60 Stopped,
61 Error,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ServiceStatus {
67 pub name: String,
68 pub protocol: Option<Protocol>,
69 pub state: ServiceState,
70 pub ready: bool,
71 pub started_at: Option<DateTime<Utc>>,
72 pub last_error: Option<String>,
73}
74
75impl ServiceStatus {
76 pub fn new(name: impl Into<String>) -> Self {
78 Self {
79 name: name.into(),
80 protocol: None,
81 state: ServiceState::Idle,
82 ready: false,
83 started_at: None,
84 last_error: None,
85 }
86 }
87
88 pub fn is_terminal(&self) -> bool {
90 matches!(self.state, ServiceState::Stopped | ServiceState::Error)
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ServiceSnapshot {
97 pub name: String,
98 pub protocol: Option<Protocol>,
99 pub status: ServiceStatus,
100 #[serde(default)]
101 pub metadata: BTreeMap<String, JsonValue>,
102}
103
104impl ServiceSnapshot {
105 pub fn new(name: impl Into<String>) -> Self {
107 let name = name.into();
108 Self {
109 status: ServiceStatus::new(name.clone()),
110 name,
111 protocol: None,
112 metadata: BTreeMap::new(),
113 }
114 }
115
116 pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
118 self.metadata.insert(key.into(), value);
119 self
120 }
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125#[serde(tag = "type", rename_all = "snake_case")]
126pub enum ServiceEvent {
127 StateChanged { state: ServiceState },
128 Cancelled,
129 Message { message: String },
130}
131
132#[derive(Debug, Clone)]
133struct TrackedTask {
134 label: String,
135 abort: AbortHandle,
136}
137
138#[derive(Debug)]
139struct ServiceContextInner {
140 name: String,
141 protocol: Option<Protocol>,
142 started_at: DateTime<Utc>,
143 cancellation: CancellationToken,
144 event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
145 tracked_tasks: Mutex<Vec<TrackedTask>>,
146}
147
148#[derive(Clone, Debug)]
150pub struct ServiceContext {
151 inner: Arc<ServiceContextInner>,
152}
153
154impl ServiceContext {
155 pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
157 let (event_tx, _) = tokio::sync::broadcast::channel(64);
158 Self {
159 inner: Arc::new(ServiceContextInner {
160 name: name.into(),
161 protocol,
162 started_at: Utc::now(),
163 cancellation: CancellationToken::new(),
164 event_tx,
165 tracked_tasks: Mutex::new(Vec::new()),
166 }),
167 }
168 }
169
170 pub fn name(&self) -> &str {
172 &self.inner.name
173 }
174
175 pub fn protocol(&self) -> Option<Protocol> {
177 self.inner.protocol
178 }
179
180 pub fn started_at(&self) -> DateTime<Utc> {
182 self.inner.started_at
183 }
184
185 pub fn cancellation_token(&self) -> CancellationToken {
187 self.inner.cancellation.clone()
188 }
189
190 pub fn child_token(&self) -> CancellationToken {
192 self.inner.cancellation.child_token()
193 }
194
195 pub fn cancel(&self) {
197 self.inner.cancellation.cancel();
198 let _ = self.emit(ServiceEvent::Cancelled);
199 }
200
201 pub fn is_cancelled(&self) -> bool {
203 self.inner.cancellation.is_cancelled()
204 }
205
206 pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
208 self.inner.event_tx.subscribe()
209 }
210
211 pub fn emit(
213 &self,
214 event: ServiceEvent,
215 ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
216 self.inner.event_tx.send(event)
217 }
218
219 pub fn track_task(&self, label: impl Into<String>, handle: &JoinHandle<()>) {
221 self.inner.tracked_tasks.lock().push(TrackedTask {
222 label: label.into(),
223 abort: handle.abort_handle(),
224 });
225 }
226
227 pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
229 where
230 F: std::future::Future<Output = ()> + Send + 'static,
231 {
232 let label = label.into();
233 let handle = tokio::spawn(future);
234 self.inner.tracked_tasks.lock().push(TrackedTask {
235 label,
236 abort: handle.abort_handle(),
237 });
238 handle
239 }
240
241 pub fn tracked_tasks(&self) -> Vec<String> {
243 self.inner
244 .tracked_tasks
245 .lock()
246 .iter()
247 .map(|task| task.label.clone())
248 .collect()
249 }
250
251 pub fn abort_tracked_tasks(&self) {
253 for task in self.inner.tracked_tasks.lock().iter() {
254 task.abort.abort();
255 }
256 }
257}
258
259#[async_trait]
261pub trait ManagedService: Send + Sync {
262 async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;
264
265 async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;
267
268 async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;
270
271 fn status(&self) -> ServiceStatus;
273
274 async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;
276
277 fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
279 Ok(())
280 }
281}
282
283pub struct ServiceHandle {
285 service: Arc<dyn ManagedService>,
286 context: ServiceContext,
287 task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
288}
289
290impl ServiceHandle {
291 pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
293 Self {
294 service,
295 context,
296 task: Arc::new(tokio::sync::Mutex::new(None)),
297 }
298 }
299
300 pub fn named(
302 name: impl Into<String>,
303 protocol: Option<Protocol>,
304 service: Arc<dyn ManagedService>,
305 ) -> Self {
306 Self::new(service, ServiceContext::new(name, protocol))
307 }
308
309 pub fn context(&self) -> ServiceContext {
311 self.context.clone()
312 }
313
314 pub async fn spawn(&self) -> RuntimeResult<()> {
316 let mut guard = self.task.lock().await;
317 if guard.is_some() {
318 return Ok(());
319 }
320
321 self.service.start(&self.context).await?;
322
323 let service = self.service.clone();
324 let context = self.context.clone();
325 *guard = Some(tokio::spawn(async move { service.serve(context).await }));
326 Ok(())
327 }
328
329 pub async fn stop(&self) -> RuntimeResult<()> {
331 self.context.cancel();
332 self.service.stop(&self.context).await?;
333 self.context.abort_tracked_tasks();
334
335 if let Some(handle) = self.task.lock().await.take() {
336 handle.await??;
337 }
338
339 Ok(())
340 }
341
342 pub async fn wait(&self) -> RuntimeResult<()> {
344 if let Some(handle) = self.task.lock().await.take() {
345 handle.await??;
346 }
347 Ok(())
348 }
349
350 pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
352 let service = self.service.clone();
353 timeout(max_wait, async move {
354 loop {
355 let status = service.status();
356 if status.ready || status.is_terminal() {
357 return status;
358 }
359 tokio::time::sleep(Duration::from_millis(25)).await;
360 }
361 })
362 .await
363 .map_err(|_| RuntimeError::ReadinessTimeout {
364 seconds: max_wait.as_secs(),
365 })
366 }
367
368 pub fn status(&self) -> ServiceStatus {
370 self.service.status()
371 }
372
373 pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
375 self.service.snapshot().await
376 }
377}
378
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use super::{
        ManagedService, RuntimeResult, ServiceContext, ServiceHandle, ServiceSnapshot,
        ServiceState, ServiceStatus,
    };

    /// Minimal in-memory service whose status is mutated by the lifecycle hooks.
    struct TestService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl TestService {
        fn new() -> Self {
            let status = parking_lot::RwLock::new(ServiceStatus::new("test"));
            Self { status }
        }

        /// Sets `state` and `ready` under one write lock, released on return
        /// so no guard is ever held across an `.await`.
        fn mark(&self, state: ServiceState, ready: bool) {
            let mut guard = self.status.write();
            guard.state = state;
            guard.ready = ready;
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let mut guard = self.status.write();
            guard.state = ServiceState::Starting;
            guard.started_at = Some(context.started_at());
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.mark(ServiceState::Stopped, false);
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            self.mark(ServiceState::Running, true);
            // Park until the handle cancels the context.
            context.cancellation_token().cancelled().await;
            self.mark(ServiceState::Stopped, false);
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            let guard = self.status.read();
            guard.clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let mut snap = ServiceSnapshot::new("test");
            snap.status = self.status();
            Ok(snap)
        }
    }

    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let svc = Arc::new(TestService::new());
        let svc_handle = ServiceHandle::named("test", None, svc);

        svc_handle.spawn().await.unwrap();
        let ready_status = svc_handle
            .readiness(Duration::from_secs(1))
            .await
            .unwrap();
        assert!(ready_status.ready);

        svc_handle.stop().await.unwrap();
        assert_eq!(svc_handle.status().state, ServiceState::Stopped);
    }
}