use std::collections::BTreeMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value as JsonValue};
use thiserror::Error;
use tokio::task::{AbortHandle, JoinError, JoinHandle};
use tokio::time::{timeout, Duration};
use tokio_util::sync::CancellationToken;
use mabi_core::Protocol;
use crate::device::DeviceRegistry;
/// Result alias used by every fallible operation in the service runtime.
pub type RuntimeResult<T> = std::result::Result<T, RuntimeError>;
/// Version tag of the runtime contract, stamped into readiness reports and snapshot
/// runtime metadata.
pub const RUNTIME_CONTRACT_VERSION: &str = "runtime-contract-v1";

/// Version tag of the runtime-metadata layout stored inside service snapshots.
pub const SNAPSHOT_METADATA_VERSION: &str = "snapshot-metadata-v1";

/// Key in `ServiceSnapshot::metadata` under which runtime metadata is stored.
pub const RUNTIME_METADATA_KEY: &str = "_runtime";
/// Stable, machine-readable classification of runtime failures.
///
/// Serialized in `snake_case` (e.g. `ConfigError` -> `"config_error"`); the `Display`
/// implementation renders exactly the same spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeErrorKind {
    /// Produced by [`RuntimeError::protocol`].
    ProtocolError,
    /// Produced by [`RuntimeError::config`].
    ConfigError,
    /// Produced by [`RuntimeError::bind`].
    BindError,
    /// Produced by [`RuntimeError::timeout`] and by `RuntimeError::ReadinessTimeout`.
    Timeout,
    /// Produced by [`RuntimeError::internal`] and by the unclassified error variants.
    InternalError,
}

impl RuntimeErrorKind {
    /// Returns the `snake_case` label shared by `Display` and the serde representation.
    const fn as_str(self) -> &'static str {
        match self {
            Self::ProtocolError => "protocol_error",
            Self::ConfigError => "config_error",
            Self::BindError => "bind_error",
            Self::Timeout => "timeout",
            Self::InternalError => "internal_error",
        }
    }
}

impl std::fmt::Display for RuntimeErrorKind {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_str();
        formatter.write_str(label)
    }
}
/// Serializable summary of a [`RuntimeError`]: its stable kind plus its message text.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct RuntimeErrorInfo {
    /// Stable classification, serialized in `snake_case`.
    pub kind: RuntimeErrorKind,
    /// Message text as returned by [`RuntimeError::message`].
    pub message: String,
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RuntimeError {
#[error("service error: {message}")]
Service { message: String },
#[error("service task failed: {message}")]
TaskJoin { message: String },
#[error("service readiness timed out after {seconds}s")]
ReadinessTimeout { seconds: u64 },
#[error("{kind}: {message}")]
Classified {
kind: RuntimeErrorKind,
message: String,
},
}
impl RuntimeError {
pub fn service(message: impl Into<String>) -> Self {
Self::Service {
message: message.into(),
}
}
pub fn protocol(message: impl Into<String>) -> Self {
Self::classified(RuntimeErrorKind::ProtocolError, message)
}
pub fn config(message: impl Into<String>) -> Self {
Self::classified(RuntimeErrorKind::ConfigError, message)
}
pub fn bind(message: impl Into<String>) -> Self {
Self::classified(RuntimeErrorKind::BindError, message)
}
pub fn timeout(message: impl Into<String>) -> Self {
Self::classified(RuntimeErrorKind::Timeout, message)
}
pub fn internal(message: impl Into<String>) -> Self {
Self::classified(RuntimeErrorKind::InternalError, message)
}
fn classified(kind: RuntimeErrorKind, message: impl Into<String>) -> Self {
Self::Classified {
kind,
message: message.into(),
}
}
pub fn kind(&self) -> RuntimeErrorKind {
match self {
Self::Service { .. } | Self::TaskJoin { .. } => RuntimeErrorKind::InternalError,
Self::ReadinessTimeout { .. } => RuntimeErrorKind::Timeout,
Self::Classified { kind, .. } => *kind,
}
}
pub fn message(&self) -> String {
match self {
Self::Service { message }
| Self::TaskJoin { message }
| Self::Classified { message, .. } => message.clone(),
Self::ReadinessTimeout { seconds } => {
format!("service readiness timed out after {seconds}s")
}
}
}
pub fn info(&self) -> RuntimeErrorInfo {
RuntimeErrorInfo {
kind: self.kind(),
message: self.message(),
}
}
}
impl From<JoinError> for RuntimeError {
fn from(error: JoinError) -> Self {
Self::internal(format!("service task failed: {error}"))
}
}
/// Lifecycle state of a managed service, serialized in `snake_case`.
///
/// [`ServiceState::Stopped`] and [`ServiceState::Error`] are treated as terminal by
/// [`ServiceStatus::is_terminal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
    /// Initial state and the [`Default`] value.
    #[default]
    Idle,
    Starting,
    Running,
    Stopping,
    /// Terminal state.
    Stopped,
    /// Terminal state.
    Error,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceStatus {
pub name: String,
pub protocol: Option<Protocol>,
pub state: ServiceState,
pub ready: bool,
pub started_at: Option<DateTime<Utc>>,
pub last_error: Option<String>,
}
impl ServiceStatus {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
protocol: None,
state: ServiceState::Idle,
ready: false,
started_at: None,
last_error: None,
}
}
pub fn is_terminal(&self) -> bool {
matches!(self.state, ServiceState::Stopped | ServiceState::Error)
}
}
/// Point-in-time view of a service: identity, current [`ServiceStatus`], and free-form
/// JSON metadata keyed by string.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ServiceSnapshot {
    pub name: String,
    pub protocol: Option<Protocol>,
    pub status: ServiceStatus,
    /// Extra snapshot data; when absent from serialized input, it deserializes as an empty map.
    #[serde(default)]
    pub metadata: BTreeMap<String, JsonValue>,
}

impl ServiceSnapshot {
    /// Creates an empty snapshot whose status is a fresh `ServiceStatus` with the same name.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        let status = ServiceStatus::new(name.as_str());
        Self {
            name,
            protocol: None,
            status,
            metadata: BTreeMap::new(),
        }
    }

    /// Builder-style insert (or replace) of a metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: JsonValue) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Builder-style variant of [`ensure_runtime_metadata`](Self::ensure_runtime_metadata).
    pub fn with_runtime_metadata(mut self) -> Self {
        self.ensure_runtime_metadata();
        self
    }

    /// Inserts, or replaces, the [`RUNTIME_METADATA_KEY`] entry with
    /// [`ServiceRuntimeMetadata`] built from the current status and stamped with the
    /// current time.
    pub fn ensure_runtime_metadata(&mut self) {
        let runtime = ServiceRuntimeMetadata::from_snapshot(self);
        let encoded = json!(runtime);
        self.metadata.insert(RUNTIME_METADATA_KEY.to_owned(), encoded);
    }

    /// Decodes the [`RUNTIME_METADATA_KEY`] entry.
    ///
    /// Returns `None` when the entry is missing or does not deserialize as
    /// [`ServiceRuntimeMetadata`].
    pub fn runtime_metadata(&self) -> Option<ServiceRuntimeMetadata> {
        let raw = self.metadata.get(RUNTIME_METADATA_KEY)?;
        serde_json::from_value(raw.clone()).ok()
    }
}
/// Runtime bookkeeping stored in a snapshot under [`RUNTIME_METADATA_KEY`].
///
/// Mirrors the snapshot's [`ServiceStatus`] at capture time and tags it with the
/// contract and metadata-layout versions.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct ServiceRuntimeMetadata {
    pub contract_version: String,
    pub snapshot_metadata_version: String,
    /// Wall-clock time at which this metadata was built.
    pub captured_at: DateTime<Utc>,
    pub service_name: String,
    /// Protocol rendered through its `Display` implementation.
    pub protocol: Option<String>,
    pub state: ServiceState,
    pub ready: bool,
    pub started_at: Option<DateTime<Utc>>,
    pub last_error: Option<String>,
}

impl ServiceRuntimeMetadata {
    /// Captures runtime metadata from `snapshot`, stamped with the current time.
    ///
    /// The protocol comes from the status when set, otherwise from the snapshot's own
    /// `protocol`. The service name and all other fields come from the status.
    pub fn from_snapshot(snapshot: &ServiceSnapshot) -> Self {
        let status = &snapshot.status;
        let protocol = match &status.protocol {
            Some(protocol) => Some(protocol.to_string()),
            None => snapshot.protocol.as_ref().map(ToString::to_string),
        };
        Self {
            contract_version: String::from(RUNTIME_CONTRACT_VERSION),
            snapshot_metadata_version: String::from(SNAPSHOT_METADATA_VERSION),
            captured_at: Utc::now(),
            service_name: status.name.clone(),
            protocol,
            state: status.state,
            ready: status.ready,
            started_at: status.started_at,
            last_error: status.last_error.clone(),
        }
    }
}
/// Serializable outcome of a readiness check.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ServiceReadinessReport {
    /// Always [`RUNTIME_CONTRACT_VERSION`].
    pub contract_version: String,
    /// Time at which the report was built.
    pub checked_at: DateTime<Utc>,
    pub service_name: String,
    /// Protocol rendered through its `Display` implementation.
    pub protocol: Option<String>,
    pub state: ServiceState,
    pub ready: bool,
    /// Readiness budget in whole milliseconds, saturated at `u64::MAX`.
    pub timeout_ms: u64,
    /// Error information when the readiness check failed.
    pub error: Option<RuntimeErrorInfo>,
}

impl ServiceReadinessReport {
    /// Builds a report from an observed `status`, stamped with the current time.
    ///
    /// `timeout` is recorded in whole milliseconds. Durations whose millisecond count does
    /// not fit in a `u64` are clamped to `u64::MAX` rather than wrapped.
    pub fn from_status(
        status: ServiceStatus,
        timeout: Duration,
        error: Option<RuntimeErrorInfo>,
    ) -> Self {
        // `as_millis` returns `u128`; an `as u64` cast would silently keep only the low
        // 64 bits and record a meaningless value for huge timeouts (e.g. `Duration::MAX`).
        let timeout_ms = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
        Self {
            contract_version: RUNTIME_CONTRACT_VERSION.to_string(),
            checked_at: Utc::now(),
            service_name: status.name,
            protocol: status.protocol.map(|protocol| protocol.to_string()),
            state: status.state,
            ready: status.ready,
            timeout_ms,
            error,
        }
    }
}
/// Lifecycle notification broadcast to subscribers of a `ServiceContext`.
///
/// Serialized as an internally tagged object, e.g. `{"type":"state_changed","state":"running"}`
/// or `{"type":"cancelled"}`. Within this file only `Cancelled` is emitted (by
/// `ServiceContext::cancel`); other variants are sent by callers of `ServiceContext::emit`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServiceEvent {
    StateChanged { state: ServiceState },
    Cancelled,
    Message { message: String },
}
/// Bookkeeping entry for a task registered with a `ServiceContext`: a label used for
/// introspection, plus the handle used to abort the task.
#[derive(Clone, Debug)]
struct TrackedTask {
    label: String,
    abort: AbortHandle,
}

/// Shared state behind every clone of a `ServiceContext`.
#[derive(Debug)]
struct ServiceContextInner {
    name: String,
    protocol: Option<Protocol>,
    /// Creation time of the context.
    started_at: DateTime<Utc>,
    /// Root token; cancelling it also cancels every child token handed out.
    cancellation: CancellationToken,
    /// Broadcast sender for lifecycle events.
    event_tx: tokio::sync::broadcast::Sender<ServiceEvent>,
    /// Tasks registered for abort on shutdown; entries are never removed.
    tracked_tasks: Mutex<Vec<TrackedTask>>,
}
/// Cheaply clonable handle to a service's shared runtime context.
///
/// Holds the service identity, a cancellation token, an event broadcaster, and the list
/// of auxiliary tasks to abort on shutdown. All clones share the same state.
#[derive(Clone, Debug)]
pub struct ServiceContext {
    inner: Arc<ServiceContextInner>,
}

impl ServiceContext {
    /// Creates a context for service `name`, stamped with the current time.
    ///
    /// The new context has a fresh cancellation token, an event channel of capacity 64,
    /// and no tracked tasks.
    pub fn new(name: impl Into<String>, protocol: Option<Protocol>) -> Self {
        // The initial receiver is dropped; listeners attach later through `subscribe`.
        let (event_tx, _) = tokio::sync::broadcast::channel(64);
        Self {
            inner: Arc::new(ServiceContextInner {
                name: name.into(),
                protocol,
                started_at: Utc::now(),
                cancellation: CancellationToken::new(),
                event_tx,
                tracked_tasks: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Service name given at construction.
    pub fn name(&self) -> &str {
        &self.inner.name
    }

    /// Protocol given at construction, if any.
    pub fn protocol(&self) -> Option<Protocol> {
        self.inner.protocol
    }

    /// Time at which this context was created.
    pub fn started_at(&self) -> DateTime<Utc> {
        self.inner.started_at
    }

    /// Returns a clone of the context's own cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.inner.cancellation.clone()
    }

    /// Returns a child token. It is cancelled when this context is cancelled, but
    /// cancelling the child does not affect the context.
    pub fn child_token(&self) -> CancellationToken {
        self.inner.cancellation.child_token()
    }

    /// Cancels the context and broadcasts [`ServiceEvent::Cancelled`].
    ///
    /// The send result is ignored because having no subscribers is not an error here.
    /// Every call emits the event, even when the context was already cancelled.
    pub fn cancel(&self) {
        self.inner.cancellation.cancel();
        let _ = self.emit(ServiceEvent::Cancelled);
    }

    /// Whether [`cancel`](Self::cancel) has been called on any clone of this context.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancellation.is_cancelled()
    }

    /// Subscribes to events emitted after this call.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ServiceEvent> {
        self.inner.event_tx.subscribe()
    }

    /// Broadcasts `event` and returns the number of receivers it was sent to.
    ///
    /// # Errors
    ///
    /// Returns tokio's `SendError`, which hands the event back, when there are no
    /// active subscribers.
    pub fn emit(
        &self,
        event: ServiceEvent,
    ) -> Result<usize, tokio::sync::broadcast::error::SendError<ServiceEvent>> {
        self.inner.event_tx.send(event)
    }

    /// Registers an already-spawned task so that
    /// [`abort_tracked_tasks`](Self::abort_tracked_tasks) can abort it.
    ///
    /// Only an abort handle is kept and the caller retains `handle`, so tasks with any
    /// output type can be tracked.
    pub fn track_task<T>(&self, label: impl Into<String>, handle: &JoinHandle<T>) {
        // Build the entry before locking so the critical section is just the push.
        let entry = TrackedTask {
            label: label.into(),
            abort: handle.abort_handle(),
        };
        self.inner.tracked_tasks.lock().push(entry);
    }

    /// Spawns `future` on the tokio runtime, tracks it under `label`, and returns its
    /// join handle.
    pub fn spawn_task<F>(&self, label: impl Into<String>, future: F) -> JoinHandle<()>
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        let label = label.into();
        let handle = tokio::spawn(future);
        self.track_task(label, &handle);
        handle
    }

    /// Labels of all tracked tasks, in registration order.
    ///
    /// Finished or aborted tasks are still listed, because entries are never removed.
    pub fn tracked_tasks(&self) -> Vec<String> {
        self.inner
            .tracked_tasks
            .lock()
            .iter()
            .map(|task| task.label.clone())
            .collect()
    }

    /// Requests abortion of every tracked task. The entries themselves are kept.
    pub fn abort_tracked_tasks(&self) {
        for task in self.inner.tracked_tasks.lock().iter() {
            task.abort.abort();
        }
    }
}
/// A long-running service whose lifecycle is driven by a [`ServiceHandle`].
///
/// [`ServiceHandle::spawn`] calls `start` and then runs `serve` on a spawned tokio task.
/// [`ServiceHandle::stop`] cancels the context, calls `stop`, and then joins that task.
#[async_trait]
pub trait ManagedService: Send + Sync {
    /// Prepares the service. Called by `ServiceHandle::spawn` before `serve` is spawned;
    /// returning an error prevents the `serve` task from being spawned.
    async fn start(&self, context: &ServiceContext) -> RuntimeResult<()>;
    /// Shuts the service down. Called by `ServiceHandle::stop` after the context has been
    /// cancelled and before the `serve` task is joined.
    async fn stop(&self, context: &ServiceContext) -> RuntimeResult<()>;
    /// Main loop, run on its own tokio task with an owned clone of the context.
    /// `ServiceHandle::stop` awaits this task after cancelling `context`, so it should
    /// return once cancellation is observed. Its result is what `stop`/`wait` report.
    async fn serve(&self, context: ServiceContext) -> RuntimeResult<()>;
    /// Current status. `ServiceHandle::readiness` polls this every 25 ms from async code,
    /// so it presumably should be cheap and non-blocking.
    fn status(&self) -> ServiceStatus;
    /// Produces a snapshot; `ServiceHandle::snapshot` adds runtime metadata on top of it.
    async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot>;
    /// Registers this service's devices with the registry. The default does nothing.
    /// NOTE(review): not called anywhere in this file — presumably invoked by whatever
    /// owns the `DeviceRegistry`; confirm.
    fn register_devices(&self, _registry: &DeviceRegistry) -> RuntimeResult<()> {
        Ok(())
    }
}
pub struct ServiceHandle {
service: Arc<dyn ManagedService>,
context: ServiceContext,
task: Arc<tokio::sync::Mutex<Option<JoinHandle<RuntimeResult<()>>>>>,
}
impl ServiceHandle {
pub fn new(service: Arc<dyn ManagedService>, context: ServiceContext) -> Self {
Self {
service,
context,
task: Arc::new(tokio::sync::Mutex::new(None)),
}
}
pub fn named(
name: impl Into<String>,
protocol: Option<Protocol>,
service: Arc<dyn ManagedService>,
) -> Self {
Self::new(service, ServiceContext::new(name, protocol))
}
pub fn context(&self) -> ServiceContext {
self.context.clone()
}
pub async fn spawn(&self) -> RuntimeResult<()> {
let mut guard = self.task.lock().await;
if guard.is_some() {
return Ok(());
}
self.service.start(&self.context).await?;
let service = self.service.clone();
let context = self.context.clone();
*guard = Some(tokio::spawn(async move { service.serve(context).await }));
Ok(())
}
pub async fn stop(&self) -> RuntimeResult<()> {
self.context.cancel();
self.service.stop(&self.context).await?;
self.context.abort_tracked_tasks();
if let Some(handle) = self.task.lock().await.take() {
handle.await??;
}
Ok(())
}
pub async fn wait(&self) -> RuntimeResult<()> {
if let Some(handle) = self.task.lock().await.take() {
handle.await??;
}
Ok(())
}
pub async fn readiness(&self, max_wait: Duration) -> RuntimeResult<ServiceStatus> {
let service = self.service.clone();
timeout(max_wait, async move {
loop {
let status = service.status();
if status.ready || status.is_terminal() {
return status;
}
tokio::time::sleep(Duration::from_millis(25)).await;
}
})
.await
.map_err(|_| {
RuntimeError::timeout(format!(
"service readiness timed out after {}ms",
max_wait.as_millis()
))
})
}
pub async fn readiness_report(&self, max_wait: Duration) -> ServiceReadinessReport {
match self.readiness(max_wait).await {
Ok(status) => ServiceReadinessReport::from_status(status, max_wait, None),
Err(error) => {
let status = self.status();
ServiceReadinessReport::from_status(status, max_wait, Some(error.info()))
}
}
}
pub fn status(&self) -> ServiceStatus {
self.service.status()
}
pub async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
Ok(self.service.snapshot().await?.with_runtime_metadata())
}
}
#[cfg(test)]
mod tests {
    //! Lifecycle, readiness, snapshot-metadata and error-classification tests for the
    //! service runtime contract.
    use std::sync::Arc;

    use async_trait::async_trait;
    use parking_lot::RwLock;
    use tokio::time::Duration;

    use crate::service::{
        ManagedService, RuntimeError, RuntimeErrorKind, RuntimeResult, ServiceContext,
        ServiceHandle, ServiceSnapshot, ServiceState, ServiceStatus, RUNTIME_CONTRACT_VERSION,
        RUNTIME_METADATA_KEY, SNAPSHOT_METADATA_VERSION,
    };

    /// Upper bound for every readiness wait in these tests.
    const READY_WAIT: Duration = Duration::from_secs(1);

    /// In-memory service whose status is driven purely by lifecycle calls. `serve`
    /// parks until the context is cancelled.
    struct TestService {
        status: RwLock<ServiceStatus>,
    }

    impl TestService {
        fn new() -> Self {
            Self {
                status: RwLock::new(ServiceStatus::new("test")),
            }
        }

        /// Updates `state` and `ready` under a single write lock.
        fn set_state(&self, state: ServiceState, ready: bool) {
            let mut current = self.status.write();
            current.state = state;
            current.ready = ready;
        }
    }

    #[async_trait]
    impl ManagedService for TestService {
        async fn start(&self, context: &ServiceContext) -> RuntimeResult<()> {
            let mut current = self.status.write();
            current.state = ServiceState::Starting;
            current.started_at = Some(context.started_at());
            Ok(())
        }

        async fn stop(&self, _context: &ServiceContext) -> RuntimeResult<()> {
            self.set_state(ServiceState::Stopped, false);
            Ok(())
        }

        async fn serve(&self, context: ServiceContext) -> RuntimeResult<()> {
            self.set_state(ServiceState::Running, true);
            context.cancellation_token().cancelled().await;
            self.set_state(ServiceState::Stopped, false);
            Ok(())
        }

        fn status(&self) -> ServiceStatus {
            self.status.read().clone()
        }

        async fn snapshot(&self) -> RuntimeResult<ServiceSnapshot> {
            Ok(ServiceSnapshot {
                status: self.status(),
                ..ServiceSnapshot::new("test")
            })
        }
    }

    #[tokio::test]
    async fn handle_spawns_and_stops_service() {
        let handle = ServiceHandle::named("test", None, Arc::new(TestService::new()));
        handle.spawn().await.unwrap();

        let observed = handle.readiness(READY_WAIT).await.unwrap();
        assert!(observed.ready);

        let report = handle.readiness_report(READY_WAIT).await;
        assert!(report.ready);
        assert_eq!(report.contract_version, RUNTIME_CONTRACT_VERSION);
        let encoded = serde_json::to_value(&report).unwrap();
        assert!(encoded["checked_at"].is_string());

        let snapshot = handle.snapshot().await.unwrap();
        assert!(snapshot.metadata.contains_key(RUNTIME_METADATA_KEY));
        let runtime = snapshot.runtime_metadata().expect("runtime metadata");
        assert_eq!(runtime.contract_version, RUNTIME_CONTRACT_VERSION);
        assert_eq!(runtime.snapshot_metadata_version, SNAPSHOT_METADATA_VERSION);
        assert_eq!(runtime.service_name, "test");
        assert!(runtime.ready);

        handle.stop().await.unwrap();
        assert_eq!(handle.status().state, ServiceState::Stopped);
    }

    #[test]
    fn runtime_error_info_uses_stable_kinds() {
        let error = RuntimeError::config("invalid launch config");
        let info = error.info();
        assert_eq!(error.kind(), RuntimeErrorKind::ConfigError);
        assert_eq!(info.message, "invalid launch config");

        let encoded = serde_json::to_value(&info).unwrap();
        assert_eq!(encoded["kind"], "config_error");
        assert_eq!(encoded["message"], "invalid launch config");
    }
}