use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// Opaque identifier for an extension service.
///
/// `#[serde(transparent)]` makes the id (de)serialize as the bare inner
/// string (e.g. `"abc"`) rather than as a one-element tuple.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ServiceId(pub String);
impl ServiceId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn generate() -> Self {
Self(uuid::Uuid::new_v4().to_string())
}
}
impl std::fmt::Display for ServiceId {
    /// Formats the identifier as its raw string.
    ///
    /// Delegates to the inner string's `Display` impl so that formatter flags
    /// supplied by the caller (width, fill, alignment, precision such as
    /// `{:>12}` or `{:.8}`) are honoured. A nested `write!(f, "{}", ..)` would
    /// silently discard them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl From<&str> for ServiceId {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for ServiceId {
fn from(s: String) -> Self {
Self(s)
}
}
/// Lifecycle state of an extension service.
///
/// Variants serialize in snake_case (e.g. `"running"`, `"reconnecting"`).
/// `ServiceStatus::default()` is [`ServiceStatus::Starting`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceStatus {
    /// Initial state; the `Default` value.
    #[default]
    Starting,
    /// The only state in which `ExtensionService::is_healthy` can return `true`.
    Running,
    Stopping,
    Stopped,
    Error,
    Reconnecting,
}
/// Description of an extension service: identity, declared capabilities,
/// transport settings, free-form metadata, and runtime health bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionService {
    /// Identifier of this service.
    pub id: ServiceId,
    /// Service name.
    pub name: String,
    /// Service version string; stored as-is, not parsed or validated here.
    pub version: String,
    /// Free-form description; empty when omitted from serialized input.
    #[serde(default)]
    pub description: String,
    /// Declared capabilities, looked up by name via `has_capability` /
    /// `get_capability`. Required when deserializing.
    pub capabilities: Vec<Capability>,
    /// Transport settings (launch command and/or network endpoint).
    /// Required when deserializing.
    pub transport: TransportConfig,
    /// Arbitrary extra key/value data; empty when omitted.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
    /// Current lifecycle status; `ServiceStatus::default()` when omitted.
    #[serde(default)]
    pub status: ServiceStatus,
    /// Time of the most recent heartbeat. `Instant` is an opaque,
    /// process-local clock reading, so it is never serialized and is `None`
    /// after deserialization.
    #[serde(skip)]
    pub last_heartbeat: Option<Instant>,
    /// Retry counter; reset to 0 by `heartbeat()`. Defaults to 0 when omitted.
    #[serde(default)]
    pub retry_count: u32,
}
impl ExtensionService {
    /// Creates a service descriptor with a freshly generated [`ServiceId`].
    ///
    /// The new service has an empty description, no capabilities or metadata,
    /// the default transport, status [`ServiceStatus::Starting`], no recorded
    /// heartbeat and a retry count of zero. Use the builder methods to fill in
    /// the rest.
    #[must_use]
    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
        Self {
            id: ServiceId::generate(),
            name: name.into(),
            version: version.into(),
            description: String::new(),
            capabilities: Vec::new(),
            transport: TransportConfig::default(),
            metadata: HashMap::new(),
            status: ServiceStatus::Starting,
            last_heartbeat: None,
            retry_count: 0,
        }
    }

    /// Sets the human-readable description (builder style; consumes `self`).
    #[must_use]
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = desc.into();
        self
    }

    /// Appends a capability (builder style). Duplicate names are not rejected;
    /// lookups return the first match.
    #[must_use]
    pub fn capability(mut self, cap: Capability) -> Self {
        self.capabilities.push(cap);
        self
    }

    /// Replaces the transport configuration (builder style).
    #[must_use]
    pub fn transport(mut self, transport: TransportConfig) -> Self {
        self.transport = transport;
        self
    }

    /// Inserts a metadata entry, overwriting any existing value for `key`
    /// (builder style).
    #[must_use]
    pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Returns `true` if a capability with exactly this name is declared.
    pub fn has_capability(&self, name: &str) -> bool {
        self.get_capability(name).is_some()
    }

    /// Returns the first declared capability whose name equals `name`.
    pub fn get_capability(&self, name: &str) -> Option<&Capability> {
        self.capabilities.iter().find(|c| c.name == name)
    }

    /// Sets the lifecycle status.
    pub fn set_status(&mut self, status: ServiceStatus) {
        self.status = status;
    }

    /// Records a heartbeat at the current instant and resets `retry_count` to 0.
    pub fn heartbeat(&mut self) {
        self.last_heartbeat = Some(Instant::now());
        self.retry_count = 0;
    }

    /// Returns `true` when the service is [`ServiceStatus::Running`] and its
    /// last heartbeat was recorded strictly less than `timeout_secs` seconds ago.
    ///
    /// A service that has never sent a heartbeat is never healthy, and a
    /// `timeout_secs` of 0 always yields `false`.
    pub fn is_healthy(&self, timeout_secs: u64) -> bool {
        // Comparing full `Duration`s avoids truncating the elapsed time to whole
        // seconds; for an integer bound this gives the same answer as
        // `elapsed().as_secs() < timeout_secs`.
        self.status == ServiceStatus::Running
            && matches!(
                self.last_heartbeat,
                Some(last) if last.elapsed() < std::time::Duration::from_secs(timeout_secs)
            )
    }
}
/// A named feature that a service advertises, with an optional version and
/// capability-specific settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capability {
    /// Capability name; the key matched by `ExtensionService::has_capability`
    /// and `ExtensionService::get_capability`. Required when deserializing.
    pub name: String,
    /// Capability version; empty string when omitted.
    #[serde(default)]
    pub version: String,
    /// Capability-specific settings; empty map when omitted.
    #[serde(default)]
    pub config: HashMap<String, serde_json::Value>,
}
impl Capability {
    /// Creates a capability with the given name, an empty version string and
    /// no configuration entries.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            version: String::new(),
            config: HashMap::new(),
        }
    }

    /// Sets the capability version (builder style; consumes `self`).
    #[must_use]
    pub fn version(mut self, version: impl Into<String>) -> Self {
        self.version = version.into();
        self
    }

    /// Inserts a configuration entry, replacing any existing value for `key`
    /// (builder style).
    #[must_use]
    pub fn config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.config.insert(key.into(), value);
        self
    }
}
/// Transport settings for reaching or launching an extension service.
///
/// Which fields are meaningful depends on `transport_type`; nothing in this
/// type validates the combination. NOTE(review): presumably `command`/`args`/
/// `env`/`cwd` apply to stdio and `address`/`port` to network transports —
/// confirm against the transport implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
    /// Transport kind, serialized under the key `"type"`. Required when
    /// deserializing (no serde default).
    #[serde(rename = "type")]
    pub transport_type: TransportType,
    /// Optional address; `None` when omitted.
    #[serde(default)]
    pub address: Option<String>,
    /// Optional port; `None` when omitted.
    #[serde(default)]
    pub port: Option<u16>,
    /// Optional command; `None` when omitted.
    #[serde(default)]
    pub command: Option<String>,
    /// Arguments for `command`; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,
    /// Environment variables as name/value pairs; empty when omitted.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Optional working directory; `None` when omitted.
    #[serde(default)]
    pub cwd: Option<String>,
    /// Timeout in seconds; `default_timeout()` (30) when omitted.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// Whether to reconnect automatically; `true` when omitted.
    #[serde(default = "default_true")]
    pub auto_reconnect: bool,
    /// Retry limit; `default_max_retries()` (3) when omitted.
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Heartbeat interval in seconds; `default_heartbeat_interval()` (30) when omitted.
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_secs: u64,
}
/// Fallback `TransportConfig::timeout_secs`, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Fallback `TransportConfig::max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 3;
/// Fallback `TransportConfig::heartbeat_interval_secs`, in seconds.
const DEFAULT_HEARTBEAT_INTERVAL_SECS: u64 = 30;

/// serde default for `timeout_secs`; also used by `TransportConfig::default`.
fn default_timeout() -> u64 {
    DEFAULT_TIMEOUT_SECS
}

/// serde default for boolean flags that are on unless specified
/// (used for `auto_reconnect`).
fn default_true() -> bool {
    true
}

/// serde default for `max_retries`; also used by `TransportConfig::default`.
fn default_max_retries() -> u32 {
    DEFAULT_MAX_RETRIES
}

/// serde default for `heartbeat_interval_secs`; also used by
/// `TransportConfig::default`.
fn default_heartbeat_interval() -> u64 {
    DEFAULT_HEARTBEAT_INTERVAL_SECS
}
impl Default for TransportConfig {
fn default() -> Self {
Self {
transport_type: TransportType::Stdio,
address: None,
port: None,
command: None,
args: Vec::new(),
env: HashMap::new(),
cwd: None,
timeout_secs: default_timeout(),
auto_reconnect: true,
max_retries: default_max_retries(),
heartbeat_interval_secs: default_heartbeat_interval(),
}
}
}
/// Kind of transport used to communicate with a service.
///
/// Variants serialize in lowercase: `"stdio"`, `"tcp"`, `"unix"`,
/// `"websocket"`. The `Unix` variant is compiled only on Unix targets, so
/// `"unix"` is an unknown variant (and fails to deserialize) elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportType {
    Stdio,
    Tcp,
    // Only available on Unix targets.
    #[cfg(unix)]
    Unix,
    WebSocket,
}
/// A registered service together with its registration timestamps (UTC).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationInfo {
    /// The registered service descriptor.
    pub service: ExtensionService,
    /// When this registration record was created (`RegistrationInfo::new`).
    pub registered_at: chrono::DateTime<chrono::Utc>,
    /// When this record was last refreshed (`new` or `touch`).
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
impl RegistrationInfo {
pub fn new(service: ExtensionService) -> Self {
let now = chrono::Utc::now();
Self {
service,
registered_at: now,
updated_at: now,
}
}
pub fn touch(&mut self) {
self.updated_at = chrono::Utc::now();
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the service model types, including their serde wire format.
    use super::*;

    #[test]
    fn test_service_id() {
        let id = ServiceId::new("test-service");
        assert_eq!(id.to_string(), "test-service");
        let generated = ServiceId::generate();
        assert!(!generated.0.is_empty());
    }

    #[test]
    fn test_generated_ids_are_distinct() {
        assert_ne!(ServiceId::generate(), ServiceId::generate());
    }

    #[test]
    fn test_service_id_conversions_and_serde() {
        assert_eq!(ServiceId::from("abc"), ServiceId::new("abc"));
        assert_eq!(ServiceId::from(String::from("abc")), ServiceId::new("abc"));
        // `#[serde(transparent)]`: the id is a bare JSON string.
        let json = serde_json::to_string(&ServiceId::new("abc")).unwrap();
        assert_eq!(json, "\"abc\"");
        let back: ServiceId = serde_json::from_str(&json).unwrap();
        assert_eq!(back, ServiceId::new("abc"));
    }

    #[test]
    fn test_enum_wire_names() {
        assert_eq!(
            serde_json::to_string(&ServiceStatus::Reconnecting).unwrap(),
            "\"reconnecting\""
        );
        let status: ServiceStatus = serde_json::from_str("\"stopped\"").unwrap();
        assert_eq!(status, ServiceStatus::Stopped);
        assert_eq!(ServiceStatus::default(), ServiceStatus::Starting);
        assert_eq!(
            serde_json::to_string(&TransportType::WebSocket).unwrap(),
            "\"websocket\""
        );
    }

    #[test]
    fn test_extension_service_creation() {
        let service = ExtensionService::new("test-service", "1.0.0")
            .description("A test service")
            .capability(Capability::new("tools"));
        assert_eq!(service.name, "test-service");
        assert_eq!(service.version, "1.0.0");
        assert_eq!(service.description, "A test service");
        assert!(service.has_capability("tools"));
        assert!(!service.has_capability("resources"));
    }

    #[test]
    fn test_service_deserialize_applies_defaults() {
        let json = r#"{
            "id": "svc-1",
            "name": "svc",
            "version": "0.1.0",
            "capabilities": [{"name": "tools"}],
            "transport": {"type": "tcp", "address": "127.0.0.1", "port": 9000}
        }"#;
        let service: ExtensionService = serde_json::from_str(json).unwrap();
        assert_eq!(service.id, ServiceId::new("svc-1"));
        assert_eq!(service.description, "");
        assert!(service.metadata.is_empty());
        assert_eq!(service.status, ServiceStatus::Starting);
        assert!(service.last_heartbeat.is_none());
        assert_eq!(service.retry_count, 0);

        let cap = service.get_capability("tools").unwrap();
        assert_eq!(cap.version, "");
        assert!(cap.config.is_empty());

        let transport = &service.transport;
        assert_eq!(transport.transport_type, TransportType::Tcp);
        assert_eq!(transport.address.as_deref(), Some("127.0.0.1"));
        assert_eq!(transport.port, Some(9000));
        assert_eq!(transport.timeout_secs, 30);
        assert!(transport.auto_reconnect);
        assert_eq!(transport.max_retries, 3);
        assert_eq!(transport.heartbeat_interval_secs, 30);
    }

    #[test]
    fn test_service_status() {
        let mut service = ExtensionService::new("test", "1.0.0");
        assert_eq!(service.status, ServiceStatus::Starting);
        service.set_status(ServiceStatus::Running);
        assert_eq!(service.status, ServiceStatus::Running);
    }

    #[test]
    fn test_service_heartbeat() {
        let mut service = ExtensionService::new("test", "1.0.0");
        assert!(!service.is_healthy(30));
        service.set_status(ServiceStatus::Running);
        service.heartbeat();
        assert!(service.is_healthy(30));
    }

    #[test]
    fn test_health_requires_running_status() {
        let mut service = ExtensionService::new("test", "1.0.0");
        service.heartbeat();
        // Fresh heartbeat but still `Starting`.
        assert!(!service.is_healthy(30));
        service.set_status(ServiceStatus::Running);
        assert!(service.is_healthy(30));
        // A zero timeout can never be satisfied.
        assert!(!service.is_healthy(0));
    }

    #[test]
    fn test_heartbeat_resets_retry_count() {
        let mut service = ExtensionService::new("test", "1.0.0");
        service.retry_count = 5;
        service.heartbeat();
        assert_eq!(service.retry_count, 0);
        assert!(service.last_heartbeat.is_some());
    }

    #[test]
    fn test_metadata_overwrites_existing_key() {
        let service = ExtensionService::new("test", "1.0.0")
            .metadata("k", serde_json::json!(1))
            .metadata("k", serde_json::json!(2));
        assert_eq!(service.metadata.len(), 1);
        assert_eq!(service.metadata.get("k"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_capability() {
        let cap = Capability::new("tools")
            .version("1.0")
            .config("max_items", serde_json::json!(100));
        assert_eq!(cap.name, "tools");
        assert_eq!(cap.version, "1.0");
        assert_eq!(cap.config.get("max_items"), Some(&serde_json::json!(100)));
    }

    #[test]
    fn test_transport_config_defaults() {
        let config = TransportConfig::default();
        assert_eq!(config.transport_type, TransportType::Stdio);
        assert!(config.auto_reconnect);
        assert_eq!(config.timeout_secs, 30);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.heartbeat_interval_secs, 30);
    }

    #[test]
    fn test_registration_info() {
        let service = ExtensionService::new("test", "1.0.0");
        let reg = RegistrationInfo::new(service);
        assert!(reg.registered_at <= chrono::Utc::now());
        assert!(reg.updated_at <= chrono::Utc::now());
        // Both stamps come from a single clock read.
        assert_eq!(reg.registered_at, reg.updated_at);
    }

    #[test]
    fn test_registration_touch() {
        let mut reg = RegistrationInfo::new(ExtensionService::new("test", "1.0.0"));
        let registered = reg.registered_at;
        // Back-date `updated_at` so the refresh is observable without sleeping.
        let stale = registered - chrono::Duration::hours(1);
        reg.updated_at = stale;
        reg.touch();
        assert!(reg.updated_at > stale);
        assert_eq!(reg.registered_at, registered);
    }
}