matrixcode_core/matrixrpc/registry/
service.rs1use std::collections::HashMap;
6use std::sync::Arc;
7
8use tokio::sync::RwLock;
9
10use crate::matrixrpc::service::{ExtensionService, RegistrationInfo, ServiceId, ServiceStatus};
11
12#[derive(Debug, thiserror::Error)]
14pub enum RegistryError {
15 #[error("Service '{0}' already exists in registry")]
17 AlreadyExists(String),
18
19 #[error("Service '{0}' not found in registry")]
21 NotFound(String),
22
23 #[error("Service '{0}' is not running (status: {1:?})")]
25 NotRunning(String, ServiceStatus),
26
27 #[error("Invalid service state: {0}")]
29 InvalidState(String),
30
31 #[error("Internal error: {0}")]
33 Internal(String),
34}
35
36#[derive(Debug, Clone, Default)]
38pub struct ServiceFilter {
39 pub status: Option<ServiceStatus>,
41
42 pub capability: Option<String>,
44
45 pub transport_type: Option<String>,
47}
48
49impl ServiceFilter {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn status(mut self, status: ServiceStatus) -> Self {
57 self.status = Some(status);
58 self
59 }
60
61 pub fn capability(mut self, cap: impl Into<String>) -> Self {
63 self.capability = Some(cap.into());
64 self
65 }
66
67 pub fn transport_type(mut self, transport: impl Into<String>) -> Self {
69 self.transport_type = Some(transport.into());
70 self
71 }
72
73 pub fn matches(&self, service: &ExtensionService) -> bool {
75 if let Some(status) = &self.status {
76 if service.status != *status {
77 return false;
78 }
79 }
80
81 if let Some(cap) = &self.capability {
82 if !service.has_capability(cap) {
83 return false;
84 }
85 }
86
87 true
88 }
89}
90
/// Per-status counts of registered services, as returned by
/// `RegistryService::stats`.
///
/// `total` counts every registration. Statuses other than the four tracked
/// below contribute to `total` but to none of the per-status counters.
#[derive(Debug, Clone, Default)]
pub struct RegistryStats {
    /// Number of registered services, regardless of status.
    pub total: usize,

    /// Services whose status is `ServiceStatus::Running`.
    pub running: usize,

    /// Services whose status is `ServiceStatus::Stopped`.
    pub stopped: usize,

    /// Services whose status is `ServiceStatus::Error`.
    pub errors: usize,

    /// Services whose status is `ServiceStatus::Reconnecting`.
    pub reconnecting: usize,
}
109
110#[derive(Debug)]
114pub struct RegistryService {
115 services: Arc<RwLock<HashMap<ServiceId, RegistrationInfo>>>,
117
118 name_index: Arc<RwLock<HashMap<String, ServiceId>>>,
120
121 capability_index: Arc<RwLock<HashMap<String, Vec<ServiceId>>>>,
123
124 heartbeat_timeout_secs: u64,
126}
127
128impl Default for RegistryService {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl RegistryService {
135 pub fn new() -> Self {
137 Self {
138 services: Arc::new(RwLock::new(HashMap::new())),
139 name_index: Arc::new(RwLock::new(HashMap::new())),
140 capability_index: Arc::new(RwLock::new(HashMap::new())),
141 heartbeat_timeout_secs: 60,
142 }
143 }
144
145 pub fn with_heartbeat_timeout(mut self, secs: u64) -> Self {
147 self.heartbeat_timeout_secs = secs;
148 self
149 }
150
151 pub async fn register(&self, service: ExtensionService) -> Result<ServiceId, RegistryError> {
153 let name = service.name.clone();
154 let id = service.id.clone();
155 let capabilities: Vec<_> = service.capabilities.iter().map(|c| c.name.clone()).collect();
156
157 {
159 let name_index = self.name_index.read().await;
160 if name_index.contains_key(&name) {
161 return Err(RegistryError::AlreadyExists(name));
162 }
163 }
164
165 let registration = RegistrationInfo::new(service);
167
168 {
170 let mut services = self.services.write().await;
171 services.insert(id.clone(), registration);
172 }
173
174 {
176 let mut name_index = self.name_index.write().await;
177 name_index.insert(name, id.clone());
178 }
179
180 {
182 let mut cap_index = self.capability_index.write().await;
183 for cap in capabilities {
184 cap_index
185 .entry(cap)
186 .or_insert_with(Vec::new)
187 .push(id.clone());
188 }
189 }
190
191 Ok(id)
192 }
193
194 pub async fn unregister(&self, id: &ServiceId) -> Result<RegistrationInfo, RegistryError> {
196 let registration = {
198 let mut services = self.services.write().await;
199 services
200 .remove(id)
201 .ok_or_else(|| RegistryError::NotFound(id.to_string()))?
202 };
203
204 {
206 let mut name_index = self.name_index.write().await;
207 name_index.remove(®istration.service.name);
208 }
209
210 {
212 let mut cap_index = self.capability_index.write().await;
213 for cap in ®istration.service.capabilities {
214 if let Some(ids) = cap_index.get_mut(&cap.name) {
215 ids.retain(|sid| sid != id);
216 }
217 }
218 }
219
220 Ok(registration)
221 }
222
223 pub async fn get(&self, id: &ServiceId) -> Option<ExtensionService> {
225 let services = self.services.read().await;
226 services.get(id).map(|r| r.service.clone())
227 }
228
229 pub async fn get_by_name(&self, name: &str) -> Option<ExtensionService> {
231 let name_index = self.name_index.read().await;
232 let id = name_index.get(name)?;
233
234 let services = self.services.read().await;
235 services.get(id).map(|r| r.service.clone())
236 }
237
238 pub async fn get_registration(&self, id: &ServiceId) -> Option<RegistrationInfo> {
240 let services = self.services.read().await;
241 services.get(id).cloned()
242 }
243
244 pub async fn update_status(
246 &self,
247 id: &ServiceId,
248 status: ServiceStatus,
249 ) -> Result<(), RegistryError> {
250 let mut services = self.services.write().await;
251 let registration = services
252 .get_mut(id)
253 .ok_or_else(|| RegistryError::NotFound(id.to_string()))?;
254
255 registration.service.set_status(status);
256 registration.touch();
257 Ok(())
258 }
259
260 pub async fn heartbeat(&self, id: &ServiceId) -> Result<(), RegistryError> {
262 let mut services = self.services.write().await;
263 let registration = services
264 .get_mut(id)
265 .ok_or_else(|| RegistryError::NotFound(id.to_string()))?;
266
267 registration.service.heartbeat();
268 registration.touch();
269 Ok(())
270 }
271
272 pub async fn list_all(&self) -> Vec<ExtensionService> {
274 let services = self.services.read().await;
275 services.values().map(|r| r.service.clone()).collect()
276 }
277
278 pub async fn list(&self, filter: &ServiceFilter) -> Vec<ExtensionService> {
280 let services = self.services.read().await;
281 services
282 .values()
283 .filter(|r| filter.matches(&r.service))
284 .map(|r| r.service.clone())
285 .collect()
286 }
287
288 pub async fn find_by_capability(&self, capability: &str) -> Vec<ExtensionService> {
290 let cap_index = self.capability_index.read().await;
291 let ids = cap_index.get(capability).cloned().unwrap_or_default();
292 drop(cap_index);
293
294 let services = self.services.read().await;
295 ids.iter()
296 .filter_map(|id| services.get(id).map(|r| r.service.clone()))
297 .collect()
298 }
299
300 pub async fn stats(&self) -> RegistryStats {
302 let services = self.services.read().await;
303
304 let mut stats = RegistryStats {
305 total: services.len(),
306 ..Default::default()
307 };
308
309 for registration in services.values() {
310 match registration.service.status {
311 ServiceStatus::Running => stats.running += 1,
312 ServiceStatus::Stopped => stats.stopped += 1,
313 ServiceStatus::Error => stats.errors += 1,
314 ServiceStatus::Reconnecting => stats.reconnecting += 1,
315 _ => {}
316 }
317 }
318
319 stats
320 }
321
322 pub async fn health_check(&self) -> Vec<ServiceId> {
324 let timeout = self.heartbeat_timeout_secs;
325 let mut unhealthy = Vec::new();
326
327 let mut services = self.services.write().await;
328 for (id, registration) in services.iter_mut() {
329 if !registration.service.is_healthy(timeout) {
330 if registration.service.status == ServiceStatus::Running {
331 registration.service.set_status(ServiceStatus::Reconnecting);
332 registration.touch();
333 unhealthy.push(id.clone());
334 }
335 }
336 }
337
338 unhealthy
339 }
340
341 pub async fn clear(&self) {
343 let mut services = self.services.write().await;
344 let mut name_index = self.name_index.write().await;
345 let mut cap_index = self.capability_index.write().await;
346
347 services.clear();
348 name_index.clear();
349 cap_index.clear();
350 }
351
352 pub async fn count(&self) -> usize {
354 let services = self.services.read().await;
355 services.len()
356 }
357}
358
/// Builder for [`RegistryService`] that allows the heartbeat timeout to be
/// configured before the registry is created.
#[derive(Debug)]
pub struct RegistryBuilder {
    /// Forwarded to `RegistryService::with_heartbeat_timeout` by `build`.
    heartbeat_timeout_secs: u64,
}
364
365impl Default for RegistryBuilder {
366 fn default() -> Self {
367 Self {
368 heartbeat_timeout_secs: 60,
369 }
370 }
371}
372
373impl RegistryBuilder {
374 pub fn new() -> Self {
376 Self::default()
377 }
378
379 pub fn heartbeat_timeout(mut self, secs: u64) -> Self {
381 self.heartbeat_timeout_secs = secs;
382 self
383 }
384
385 pub fn build(self) -> RegistryService {
387 RegistryService::new().with_heartbeat_timeout(self.heartbeat_timeout_secs)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::Capability;

    #[tokio::test]
    async fn test_register_service() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("test-service", "1.0.0");

        let id = registry.register(service).await.unwrap();
        assert!(registry.get(&id).await.is_some());
        assert_eq!(registry.count().await, 1);
    }

    #[tokio::test]
    async fn test_register_duplicate_name() {
        let registry = RegistryService::new();
        let service1 = ExtensionService::new("test", "1.0.0");
        let service2 = ExtensionService::new("test", "2.0.0");

        registry.register(service1).await.unwrap();
        let result = registry.register(service2).await;
        assert!(matches!(result, Err(RegistryError::AlreadyExists(_))));
    }

    /// Two registrations of the same name polled concurrently: exactly one
    /// may win, and the registry must hold a single entry afterwards.
    #[tokio::test]
    async fn test_concurrent_duplicate_registration() {
        let registry = RegistryService::new();

        let (first, second) = tokio::join!(
            registry.register(ExtensionService::new("dup", "1.0.0")),
            registry.register(ExtensionService::new("dup", "2.0.0")),
        );

        assert!(first.is_ok() ^ second.is_ok());
        assert_eq!(registry.count().await, 1);
    }

    #[tokio::test]
    async fn test_unregister_service() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("test", "1.0.0");

        let id = registry.register(service).await.unwrap();
        registry.unregister(&id).await.unwrap();
        assert!(registry.get(&id).await.is_none());
    }

    /// Unregistering must also clear the name and capability indexes, so the
    /// name can be reused afterwards.
    #[tokio::test]
    async fn test_unregister_cleans_indexes() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("test", "1.0.0")
            .capability(Capability::new("tools"));

        let id = registry.register(service).await.unwrap();
        registry.unregister(&id).await.unwrap();

        assert!(registry.get_by_name("test").await.is_none());
        assert!(registry.find_by_capability("tools").await.is_empty());
        assert_eq!(registry.count().await, 0);

        let again = ExtensionService::new("test", "2.0.0");
        assert!(registry.register(again).await.is_ok());
    }

    #[tokio::test]
    async fn test_unregister_unknown_id() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("ghost", "1.0.0");
        let id = service.id.clone();

        let result = registry.unregister(&id).await;
        assert!(matches!(result, Err(RegistryError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_get_by_name() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("test", "1.0.0");

        registry.register(service).await.unwrap();

        let found = registry.get_by_name("test").await;
        assert!(found.is_some());
        assert_eq!(found.unwrap().version, "1.0.0");

        assert!(registry.get_by_name("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_update_status() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("test", "1.0.0");

        let id = registry.register(service).await.unwrap();
        registry
            .update_status(&id, ServiceStatus::Running)
            .await
            .unwrap();

        let service = registry.get(&id).await.unwrap();
        assert_eq!(service.status, ServiceStatus::Running);
    }

    #[tokio::test]
    async fn test_update_status_unknown_id() {
        let registry = RegistryService::new();
        let service = ExtensionService::new("ghost", "1.0.0");
        let id = service.id.clone();

        let result = registry.update_status(&id, ServiceStatus::Running).await;
        assert!(matches!(result, Err(RegistryError::NotFound(_))));

        let result = registry.heartbeat(&id).await;
        assert!(matches!(result, Err(RegistryError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_find_by_capability() {
        let registry = RegistryService::new();

        let service1 = ExtensionService::new("service1", "1.0.0")
            .capability(Capability::new("tools"));

        let service2 = ExtensionService::new("service2", "1.0.0")
            .capability(Capability::new("resources"));

        let service3 = ExtensionService::new("service3", "1.0.0")
            .capability(Capability::new("tools"));

        registry.register(service1).await.unwrap();
        registry.register(service2).await.unwrap();
        registry.register(service3).await.unwrap();

        let tools_services = registry.find_by_capability("tools").await;
        assert_eq!(tools_services.len(), 2);

        let resources_services = registry.find_by_capability("resources").await;
        assert_eq!(resources_services.len(), 1);

        let prompts_services = registry.find_by_capability("prompts").await;
        assert!(prompts_services.is_empty());
    }

    #[tokio::test]
    async fn test_registry_stats() {
        let registry = RegistryService::new();

        let mut service1 = ExtensionService::new("s1", "1.0.0");
        service1.set_status(ServiceStatus::Running);

        let mut service2 = ExtensionService::new("s2", "1.0.0");
        service2.set_status(ServiceStatus::Stopped);

        let mut service3 = ExtensionService::new("s3", "1.0.0");
        service3.set_status(ServiceStatus::Error);

        registry.register(service1).await.unwrap();
        registry.register(service2).await.unwrap();
        registry.register(service3).await.unwrap();

        let stats = registry.stats().await;
        assert_eq!(stats.total, 3);
        assert_eq!(stats.running, 1);
        assert_eq!(stats.stopped, 1);
        assert_eq!(stats.errors, 1);
    }

    #[tokio::test]
    async fn test_service_filter() {
        let registry = RegistryService::new();

        let mut service1 = ExtensionService::new("s1", "1.0.0")
            .capability(Capability::new("tools"));
        service1.set_status(ServiceStatus::Running);

        let mut service2 = ExtensionService::new("s2", "1.0.0")
            .capability(Capability::new("tools"));
        service2.set_status(ServiceStatus::Stopped);

        registry.register(service1).await.unwrap();
        registry.register(service2).await.unwrap();

        let filter = ServiceFilter::new().status(ServiceStatus::Running);
        let services = registry.list(&filter).await;
        assert_eq!(services.len(), 1);
        assert_eq!(services[0].name, "s1");
    }
}
525}