use std::collections::BTreeMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use thiserror::Error;
use tokio::task::{AbortHandle, JoinError, JoinHandle};
use tokio::time::{timeout, Duration};
use tokio_util::sync::CancellationToken;
use mabi_core::Protocol;
use crate::device::DeviceRegistry;
/// Convenience alias for results whose error type is [`RuntimeError`].
pub type RuntimeResult<T> = Result<T, RuntimeError>;
/// Errors surfaced by the service runtime.
#[derive(Debug, Error)]
pub enum RuntimeError {
/// A service-level failure described by a free-form message.
#[error("service error: {message}")]
Service { message: String },
/// A spawned task could not be joined; `message` is the rendered [`JoinError`].
#[error("service task failed: {message}")]
TaskJoin { message: String },
/// Readiness was not observed before the timeout elapsed; `seconds` is the
/// timeout expressed in whole seconds.
#[error("service readiness timed out after {seconds}s")]
ReadinessTimeout { seconds: u64 },
}
impl RuntimeError {
    /// Builds a [`RuntimeError::Service`] from anything convertible into a `String`.
    pub fn service(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Service { message }
    }
}
/// Maps a failed task join onto [`RuntimeError::TaskJoin`], keeping only the
/// join error's display text.
impl From<JoinError> for RuntimeError {
    fn from(error: JoinError) -> Self {
        let message = error.to_string();
        Self::TaskJoin { message }
    }
}
/// Lifecycle state of a managed service.
///
/// Serialized in `snake_case`; the default is [`ServiceState::Idle`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
/// Initial state; also what [`ServiceStatus::new`] assigns.
#[default]
Idle,
Starting,
Running,
Stopping,
/// Treated as terminal by [`ServiceStatus::is_terminal`].
Stopped,
/// Treated as terminal by [`ServiceStatus::is_terminal`].
Error,
}
/// Point-in-time status report for a service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceStatus {
/// Service name.
pub name: String,
/// Protocol associated with the service, if any.
pub protocol: Option<Protocol>,
/// Current lifecycle state.
pub state: ServiceState,
/// Readiness flag; [`ServiceHandle::readiness`] returns as soon as this is `true`.
pub ready: bool,
/// Start timestamp, if one has been recorded.
pub started_at: Option<DateTime<Utc>>,
/// Most recent error message, if one has been recorded.
pub last_error: Option<String>,
}
impl ServiceStatus {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
protocol: None,
state: ServiceState::Idle,
ready: false,
started_at: None,
last_error: None,
}
}
pub fn is_terminal(&self) -> bool {
matches!(self.state, ServiceState::Stopped | ServiceState::Error)
}
}
/// Serializable snapshot of a service: identity, status, and free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSnapshot {
/// Service name.
pub name: String,
/// Protocol associated with the service, if any.
pub protocol: Option<Protocol>,
/// Status captured for this snapshot.
pub status: ServiceStatus,
/// Arbitrary JSON values keyed by name; an absent field deserializes to an empty map.
#[serde(default)]
pub metadata: BTreeMap<String, JsonValue>,
}
impl ServiceSnapshot {
    /// Creates an empty snapshot for `name`: no protocol, no metadata, and a
    /// freshly initialised [`ServiceStatus`] carrying the same name.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        let status = ServiceStatus::new(name.as_str());
        Self {
            name,
            protocol: None,
            status,
            metadata: BTreeMap::new(),
        }
    }

    /// Builder-style helper that stores `value` under `key`, replacing any
    /// previous entry for that key, and returns the snapshot.
    pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
        let entry = (key.into(), value);
        self.metadata.extend(std::iter::once(entry));
        self
    }
}
/// Event broadcast to subscribers of a [`ServiceContext`].
///
/// Serialized with an internal `type` tag whose values are in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServiceEvent {
/// A state change; carries the `state` being reported.
StateChanged { state: ServiceState },
/// Emitted by [`ServiceContext::cancel`].
Cancelled,
/// Free-form text message.
Message { message: String },
}
/// A spawned task registered with a [`ServiceContext`] so it can be listed and aborted.
#[derive(Debug, Clone)]
struct TrackedTask {
/// Label reported by [`ServiceContext::tracked_tasks`].
label: String,
/// Handle used by [`ServiceContext::abort_tracked_tasks`] to abort the task.
abort: AbortHandle,
}
/// Shared state behind a [`ServiceContext`]; the context holds it in an `Arc`.
#[derive(Debug)]
struct ServiceContextInner {
/// Service name.
name: String,
/// Protocol associated with the service, if any.
protocol: Option<Protocol>,
/// Timestamp captured when the context was constructed.
started_at: DateTime<Utc>,
/// Root cancellation token for the service.
cancellation: CancellationToken,
/// Sending half of the event broadcast channel.
event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
/// Tasks registered through `track_task` / `spawn_task`.
tracked_tasks: Mutex<Vec<TrackedTask>>,
}
/// Cloneable handle to a service's shared runtime state: identity, cancellation
/// token, event channel, and tracked tasks.
///
/// Clones share the same underlying state through an `Arc`.
#[derive(Clone, Debug)]
pub struct ServiceContext {
/// Reference-counted shared state.
inner: Arc<ServiceContextInner>,
}
impl ServiceContext {
    /// Creates a context for the service `name`.
    ///
    /// The start timestamp is captured here with `Utc::now()`, a fresh
    /// cancellation token is created, and events are carried by a broadcast
    /// channel with a capacity of 64.
    pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
        // The initial receiver is dropped on purpose; receivers are created on
        // demand through `subscribe`.
        let (event_tx, _) = tokio::sync::broadcast::channel(64);
        Self {
            inner: Arc::new(ServiceContextInner {
                name: name.into(),
                protocol,
                started_at: Utc::now(),
                cancellation: CancellationToken::new(),
                event_tx,
                tracked_tasks: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Returns the service name.
    pub fn name(&self) -> &str {
        &self.inner.name
    }

    /// Returns the protocol associated with the service, if any.
    pub fn protocol(&self) -> Option<Protocol> {
        self.inner.protocol
    }

    /// Returns the timestamp captured when the context was created.
    pub fn started_at(&self) -> DateTime<Utc> {
        self.inner.started_at
    }

    /// Returns a clone of the context's root cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancellation.clone()
    }

    /// Returns a child of the root cancellation token.
    pub fn child_token(&self) -> CancellationToken {
        self.inner.cancellation.child_token()
    }

    /// Cancels the root token and broadcasts [`ServiceEvent::Cancelled`].
    pub fn cancel(&self) {
        self.inner.cancellation.cancel();
        // Best effort: sending fails when nobody is subscribed, which is not
        // an error for cancellation.
        let _ = self.emit(ServiceEvent::Cancelled);
    }

    /// Returns `true` once the root token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancellation.is_cancelled()
    }

    /// Creates a new receiver for events emitted after this call.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
        self.inner.event_tx.subscribe()
    }

    /// Broadcasts `event` to all current subscribers.
    ///
    /// # Errors
    ///
    /// Returns the broadcast channel's `SendError` (which hands the event
    /// back) when the send fails, e.g. because there are no subscribers.
    pub fn emit(
        &self,
        event: ServiceEvent,
    ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
        self.inner.event_tx.send(event)
    }

    /// Registers an already-spawned task under `label` so that
    /// [`abort_tracked_tasks`](Self::abort_tracked_tasks) can abort it.
    ///
    /// Entries whose task has already finished are discarded at this point, so
    /// the list stays bounded by the number of live tasks instead of growing
    /// with every task ever tracked.
    pub fn track_task(&self, label: impl Into<String>, handle: &JoinHandle<()>) {
        // Build the entry before locking to keep the critical section short.
        let entry = TrackedTask {
            label: label.into(),
            abort: handle.abort_handle(),
        };
        let mut tasks = self.inner.tracked_tasks.lock();
        tasks.retain(|task| !task.abort.is_finished());
        tasks.push(entry);
    }

    /// Spawns `future` on the Tokio runtime, tracks it under `label`, and
    /// returns its join handle.
    pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        let handle = tokio::spawn(future);
        // Single registration path shared with externally spawned tasks.
        self.track_task(label, &handle);
        handle
    }

    /// Returns the labels of the currently tracked tasks, in registration order.
    pub fn tracked_tasks(&self) -> Vec<String> {
        self.inner
            .tracked_tasks
            .lock()
            .iter()
            .map(|task| task.label.clone())
            .collect()
    }

    /// Requests abortion of every tracked task.
    pub fn abort_tracked_tasks(&self) {
        for task in self.inner.tracked_tasks.lock().iter() {
            task.abort.abort();
        }
    }
}
/// Lifecycle contract for a service driven by a [`ServiceHandle`].
#[async_trait]
pub trait ManagedService: Send + Sync {
/// Prepares the service; [`ServiceHandle::spawn`] awaits this before spawning `serve`.
async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;
/// Requests shutdown; [`ServiceHandle::stop`] calls this after cancelling the context.
async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;
/// Main service body; [`ServiceHandle::spawn`] runs it on its own Tokio task.
async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;
/// Returns the current status; polled by [`ServiceHandle::readiness`].
fn status(&self) -> ServiceStatus;
/// Produces a snapshot of the service.
async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;
/// Hook that receives a [`DeviceRegistry`]; the default implementation does
/// nothing and returns `Ok(())`.
fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
Ok(())
}
}
/// Owns a [`ManagedService`] together with its [`ServiceContext`] and the Tokio
/// task running the service's `serve` method.
pub struct ServiceHandle {
/// The service being driven.
service: Arc<dyn ManagedService>,
/// Context handed to the service's lifecycle methods.
context: ServiceContext,
/// Join handle of the spawned `serve` task; `None` until `spawn` succeeds and
/// again once `stop` or `wait` has taken it.
task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
}
impl ServiceHandle {
    /// Wraps `service` with an existing `context`; nothing is started yet.
    pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
        Self {
            service,
            context,
            task: Arc::new(tokio::sync::Mutex::new(None)),
        }
    }

    /// Wraps `service` with a freshly created context for `name` / `protocol`.
    pub fn named(
        name: impl Into<String>,
        protocol: Option<Protocol>,
        service: Arc<dyn ManagedService>,
    ) -> Self {
        Self::new(service, ServiceContext::new(name, protocol))
    }

    /// Returns a clone of the service's context.
    pub fn context(&self) -> ServiceContext {
        self.context.clone()
    }

    /// Starts the service and spawns its `serve` loop on a Tokio task.
    ///
    /// Does nothing when a serve task is already stored in the handle.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ManagedService::start`; in that case no task
    /// is spawned.
    pub async fn spawn(&self) -> RuntimeResult<()> {
        // The slot stays locked across `start` so concurrent `spawn` calls
        // cannot both start the service.
        let mut slot = self.task.lock().await;
        if slot.is_some() {
            return Ok(());
        }
        // NOTE(review): the context is reused, so spawning again after `stop`
        // hands the service an already-cancelled token — confirm this is intended.
        self.service.start(&self.context).await?;
        let service = Arc::clone(&self.service);
        let context = self.context.clone();
        *slot = Some(tokio::spawn(async move { service.serve(context).await }));
        Ok(())
    }

    /// Cancels the context, asks the service to stop, aborts tracked tasks,
    /// and joins the serve task.
    ///
    /// Cleanup (aborting tracked tasks and joining the serve task) always
    /// runs, even when the service's own `stop` fails, so a failing `stop`
    /// does not leave background tasks behind.
    ///
    /// # Errors
    ///
    /// Returns the error from `ManagedService::stop` if there is one;
    /// otherwise the join error or the error returned by `serve`.
    pub async fn stop(&self) -> RuntimeResult<()> {
        self.context.cancel();
        let stop_result = self.service.stop(&self.context).await;
        self.context.abort_tracked_tasks();
        let join_result = self.join_task().await;
        // The stop error takes precedence, matching what callers saw before.
        stop_result?;
        join_result
    }

    /// Waits for the serve task to finish without requesting cancellation.
    ///
    /// Returns `Ok(())` immediately when no serve task is stored.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::TaskJoin`] when the task cannot be joined, or
    /// the error returned by `serve`.
    pub async fn wait(&self) -> RuntimeResult<()> {
        self.join_task().await
    }

    /// Takes the stored serve task, if any, and awaits its result.
    ///
    /// The task slot stays locked until the join completes, which serialises
    /// concurrent `spawn` / `stop` / `wait` calls behind it.
    async fn join_task(&self) -> RuntimeResult<()> {
        let mut slot = self.task.lock().await;
        if let Some(handle) = slot.take() {
            handle.await??;
        }
        Ok(())
    }

    /// Polls the service status every 25 ms until it is ready or terminal.
    ///
    /// # Errors
    ///
    /// Returns [`RuntimeError::ReadinessTimeout`] when neither condition is
    /// observed within `max_wait`. The reported `seconds` is `max_wait`
    /// rounded up to whole seconds, so a sub-second timeout reads as `1s`
    /// rather than `0s`.
    pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
        let service = Arc::clone(&self.service);
        let poll = async move {
            loop {
                let status = service.status();
                if status.ready || status.is_terminal() {
                    return status;
                }
                tokio::time::sleep(Duration::from_millis(25)).await;
            }
        };
        timeout(max_wait, poll)
            .await
            .map_err(|_| RuntimeError::ReadinessTimeout {
                seconds: max_wait
                    .as_secs()
                    .saturating_add(u64::from(max_wait.subsec_nanos() > 0)),
            })
    }

    /// Returns the service's current status.
    pub fn status(&self) -> ServiceStatus {
        self.service.status()
    }

    /// Returns a snapshot produced by the service.
    pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
        self.service.snapshot().await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use tokio::time::Duration;

    use crate::service::{
        ManagedService, RuntimeResult, ServiceContext, ServiceHandle, ServiceSnapshot,
        ServiceState, ServiceStatus,
    };

    /// Minimal in-memory service whose status is driven by the lifecycle hooks.
    struct TestService {
        status: parking_lot::RwLock<ServiceStatus>,
    }

    impl TestService {
        fn new() -> Self {
            let initial = ServiceStatus::new("test");
            Self {
                status: parking_lot::RwLock::new(initial),
            }
        }

        /// Applies `change` to the status while holding the write lock; the
        /// lock is released before this returns, so it is never held across
        /// an `.await`.
        fn update(&self, change: impl FnOnce(&mut ServiceStatus)) {
            let mut guard = self.status.write();
            change(&mut *guard);
        }

        /// Records the stopped, not-ready state.
        fn mark_stopped(&self) {
            self.update(|status| {
                status.state = ServiceState::Stopped;
                status.ready = false;
            });
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let started_at = context.started_at();
            self.update(|status| {
                status.state = ServiceState::Starting;
                status.started_at = Some(started_at);
            });
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.mark_stopped();
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            self.update(|status| {
                status.state = ServiceState::Running;
                status.ready = true;
            });
            // Park until the handle cancels the context.
            let token = context.cancellation_token();
            token.cancelled().await;
            self.mark_stopped();
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            let guard = self.status.read();
            guard.clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            let status = self.status();
            let mut snapshot = ServiceSnapshot::new("test");
            snapshot.status = status;
            Ok(snapshot)
        }
    }

    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let service = Arc::new(TestService::new());
        let handle = ServiceHandle::named("test", None, service);

        handle.spawn().await.unwrap();
        let ready_status = handle.readiness(Duration::from_secs(1)).await.unwrap();
        assert!(ready_status.ready);

        handle.stop().await.unwrap();
        let final_state = handle.status().state;
        assert_eq!(final_state, ServiceState::Stopped);
    }
}