Skip to main content

shaperail_runtime/registry/
mod.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use deadpool_redis::Pool;
5use redis::AsyncCommands;
6use shaperail_core::{ServiceRegistryEntry, ServiceStatus, ShaperailError};
7
/// Redis key prefix for service registry entries.
///
/// A service named `users-api` is stored under `shaperail:services:users-api`.
const REGISTRY_PREFIX: &str = "shaperail:services:";

/// Interval, in seconds, between background heartbeats sent by
/// `ServiceRegistry::start_heartbeat`.
const HEARTBEAT_INTERVAL_SECS: u64 = 10;

/// Expiry, in seconds, applied (via `SET … EX`) every time an entry is written.
///
/// Each successful heartbeat rewrites the entry and so restarts this TTL. With the
/// current values an entry survives three consecutive missed heartbeats and then
/// expires before a fourth would be sent. Expiry deletes the key outright: the
/// service disappears from lookups rather than being marked unhealthy.
const REGISTRY_TTL_SECS: u64 = 35;

// Compile-time guard: a TTL that does not outlast one heartbeat interval would let
// live services drop out of the registry between beats.
const _: () = assert!(REGISTRY_TTL_SECS > HEARTBEAT_INTERVAL_SECS);
17
18/// Redis-backed service registry for multi-service workspace discovery.
19///
20/// Services register on startup and send periodic heartbeats.
21/// Other services discover peers by querying the registry.
22#[derive(Clone)]
23pub struct ServiceRegistry {
24    pool: Arc<Pool>,
25}
26
27impl ServiceRegistry {
28    /// Create a new service registry backed by the given Redis pool.
29    pub fn new(pool: Arc<Pool>) -> Self {
30        Self { pool }
31    }
32
33    /// Register a service in the registry. Sets a TTL so stale entries expire.
34    pub async fn register(&self, entry: &ServiceRegistryEntry) -> Result<(), ShaperailError> {
35        let key = format!("{REGISTRY_PREFIX}{}", entry.name);
36        let value = serde_json::to_string(entry).map_err(|e| {
37            ShaperailError::Internal(format!("Failed to serialize registry entry: {e}"))
38        })?;
39
40        let mut conn = self
41            .pool
42            .get()
43            .await
44            .map_err(|e| ShaperailError::Internal(format!("Redis connection error: {e}")))?;
45
46        redis::cmd("SET")
47            .arg(&key)
48            .arg(&value)
49            .arg("EX")
50            .arg(REGISTRY_TTL_SECS)
51            .query_async::<()>(&mut *conn)
52            .await
53            .map_err(|e| ShaperailError::Internal(format!("Failed to register service: {e}")))?;
54
55        Ok(())
56    }
57
58    /// Update heartbeat for a service (refreshes TTL and updates timestamp + status).
59    pub async fn heartbeat(&self, name: &str) -> Result<(), ShaperailError> {
60        let key = format!("{REGISTRY_PREFIX}{name}");
61
62        let mut conn = self
63            .pool
64            .get()
65            .await
66            .map_err(|e| ShaperailError::Internal(format!("Redis connection error: {e}")))?;
67
68        let value: Option<String> = conn
69            .get(&key)
70            .await
71            .map_err(|e| ShaperailError::Internal(format!("Failed to read registry: {e}")))?;
72
73        let Some(value) = value else {
74            return Err(ShaperailError::NotFound);
75        };
76
77        let mut entry: ServiceRegistryEntry = serde_json::from_str(&value).map_err(|e| {
78            ShaperailError::Internal(format!("Failed to parse registry entry: {e}"))
79        })?;
80
81        entry.status = ServiceStatus::Healthy;
82        entry.last_heartbeat = chrono::Utc::now().to_rfc3339();
83
84        let updated = serde_json::to_string(&entry).map_err(|e| {
85            ShaperailError::Internal(format!("Failed to serialize registry entry: {e}"))
86        })?;
87
88        redis::cmd("SET")
89            .arg(&key)
90            .arg(&updated)
91            .arg("EX")
92            .arg(REGISTRY_TTL_SECS)
93            .query_async::<()>(&mut *conn)
94            .await
95            .map_err(|e| ShaperailError::Internal(format!("Failed to update heartbeat: {e}")))?;
96
97        Ok(())
98    }
99
100    /// Deregister a service (mark as stopped and remove).
101    pub async fn deregister(&self, name: &str) -> Result<(), ShaperailError> {
102        let key = format!("{REGISTRY_PREFIX}{name}");
103
104        let mut conn = self
105            .pool
106            .get()
107            .await
108            .map_err(|e| ShaperailError::Internal(format!("Redis connection error: {e}")))?;
109
110        conn.del::<_, ()>(&key)
111            .await
112            .map_err(|e| ShaperailError::Internal(format!("Failed to deregister service: {e}")))?;
113
114        Ok(())
115    }
116
117    /// Look up a service by name.
118    pub async fn lookup(&self, name: &str) -> Result<Option<ServiceRegistryEntry>, ShaperailError> {
119        let key = format!("{REGISTRY_PREFIX}{name}");
120
121        let mut conn = self
122            .pool
123            .get()
124            .await
125            .map_err(|e| ShaperailError::Internal(format!("Redis connection error: {e}")))?;
126
127        let value: Option<String> = conn
128            .get(&key)
129            .await
130            .map_err(|e| ShaperailError::Internal(format!("Failed to read registry: {e}")))?;
131
132        match value {
133            Some(v) => {
134                let entry: ServiceRegistryEntry = serde_json::from_str(&v).map_err(|e| {
135                    ShaperailError::Internal(format!("Failed to parse registry entry: {e}"))
136                })?;
137                Ok(Some(entry))
138            }
139            None => Ok(None),
140        }
141    }
142
143    /// List all registered services.
144    pub async fn list_services(&self) -> Result<Vec<ServiceRegistryEntry>, ShaperailError> {
145        let pattern = format!("{REGISTRY_PREFIX}*");
146
147        let mut conn = self
148            .pool
149            .get()
150            .await
151            .map_err(|e| ShaperailError::Internal(format!("Redis connection error: {e}")))?;
152
153        let keys: Vec<String> = redis::cmd("KEYS")
154            .arg(&pattern)
155            .query_async(&mut *conn)
156            .await
157            .map_err(|e| ShaperailError::Internal(format!("Failed to list services: {e}")))?;
158
159        let mut services = Vec::new();
160        for key in &keys {
161            let value: Option<String> = conn
162                .get(key)
163                .await
164                .map_err(|e| ShaperailError::Internal(format!("Failed to read registry: {e}")))?;
165            if let Some(v) = value {
166                if let Ok(entry) = serde_json::from_str::<ServiceRegistryEntry>(&v) {
167                    services.push(entry);
168                }
169            }
170        }
171
172        Ok(services)
173    }
174
175    /// Discover a service that exposes a specific resource.
176    pub async fn discover_resource(
177        &self,
178        resource_name: &str,
179    ) -> Result<Option<ServiceRegistryEntry>, ShaperailError> {
180        let services = self.list_services().await?;
181        Ok(services.into_iter().find(|s| {
182            s.status == ServiceStatus::Healthy && s.resources.iter().any(|r| r == resource_name)
183        }))
184    }
185
186    /// Start a background heartbeat task for the given service name.
187    /// Returns a `tokio::task::JoinHandle` that can be aborted on shutdown.
188    pub fn start_heartbeat(&self, name: String) -> tokio::task::JoinHandle<()> {
189        let registry = self.clone();
190        tokio::spawn(async move {
191            let mut interval = tokio::time::interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
192            loop {
193                interval.tick().await;
194                if let Err(e) = registry.heartbeat(&name).await {
195                    tracing::warn!("Service registry heartbeat failed for '{name}': {e}");
196                }
197            }
198        })
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn registry_key_format() {
        let key = format!("{REGISTRY_PREFIX}users-api");
        assert_eq!(key, "shaperail:services:users-api");
    }

    #[test]
    fn registry_constants() {
        assert_eq!(HEARTBEAT_INTERVAL_SECS, 10);
        assert_eq!(REGISTRY_TTL_SECS, 35);
    }

    /// Pins the relationship the TTL exists for, not just the literal values:
    /// a live service must survive three consecutive failed heartbeats before
    /// its entry expires.
    #[test]
    fn ttl_outlasts_three_missed_heartbeats() {
        assert!(REGISTRY_TTL_SECS > 3 * HEARTBEAT_INTERVAL_SECS);
    }
}