1use std::pin::Pin;
5use std::sync::Arc;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13 Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
14 DiscoverySpec, DiscoveryStream, EndpointInstanceId, EventChannelInstanceId,
15 ModelCardInstanceId,
16};
17use crate::storage::kv;
18
/// Bucket in which endpoint instances are registered.
const INSTANCES_BUCKET: &str = "v1/instances";
/// Bucket in which model instances are registered.
const MODELS_BUCKET: &str = "v1/mdc";
/// Bucket in which event channel instances are registered.
const EVENT_CHANNELS_BUCKET: &str = "v1/event_channels";
22
/// Discovery implementation that keeps instance records in a key-value store, one bucket per
/// instance kind (endpoints, models, event channels).
pub struct KVStoreDiscovery {
    /// Shared handle to the key-value store manager used for every bucket operation.
    store: Arc<kv::Manager>,
    /// Default cancellation token, used by `list_and_watch` when the caller passes none.
    cancel_token: CancellationToken,
}
28
29impl KVStoreDiscovery {
30 pub fn new(store: kv::Manager, cancel_token: CancellationToken) -> Self {
31 Self {
32 store: Arc::new(store),
33 cancel_token,
34 }
35 }
36
37 fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
39 format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
40 }
41
42 fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
44 format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
45 }
46
47 fn event_channel_key(
49 namespace: &str,
50 component: &str,
51 topic: &str,
52 instance_id: u64,
53 ) -> String {
54 format!("{}/{}/{}/{:x}", namespace, component, topic, instance_id)
55 }
56
57 fn query_prefix(query: &DiscoveryQuery) -> String {
59 match query {
60 DiscoveryQuery::AllEndpoints => INSTANCES_BUCKET.to_string(),
61 DiscoveryQuery::NamespacedEndpoints { namespace } => {
62 format!("{}/{}", INSTANCES_BUCKET, namespace)
63 }
64 DiscoveryQuery::ComponentEndpoints {
65 namespace,
66 component,
67 } => {
68 format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component)
69 }
70 DiscoveryQuery::Endpoint {
71 namespace,
72 component,
73 endpoint,
74 } => {
75 format!(
76 "{}/{}/{}/{}",
77 INSTANCES_BUCKET, namespace, component, endpoint
78 )
79 }
80 DiscoveryQuery::AllModels => MODELS_BUCKET.to_string(),
81 DiscoveryQuery::NamespacedModels { namespace } => {
82 format!("{}/{}", MODELS_BUCKET, namespace)
83 }
84 DiscoveryQuery::ComponentModels {
85 namespace,
86 component,
87 } => {
88 format!("{}/{}/{}", MODELS_BUCKET, namespace, component)
89 }
90 DiscoveryQuery::EndpointModels {
91 namespace,
92 component,
93 endpoint,
94 } => {
95 format!("{}/{}/{}/{}", MODELS_BUCKET, namespace, component, endpoint)
96 }
97 DiscoveryQuery::EventChannels(query) => {
98 let mut path = EVENT_CHANNELS_BUCKET.to_string();
99 if let Some(ns) = &query.namespace {
100 path.push('/');
101 path.push_str(ns);
102 if let Some(comp) = &query.component {
103 path.push('/');
104 path.push_str(comp);
105 if let Some(topic) = &query.topic {
106 path.push('/');
107 path.push_str(topic);
108 }
109 }
110 }
111 path
112 }
113 }
114 }
115
116 fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
120 if let Some(stripped) = key.strip_prefix(bucket_name) {
122 stripped.strip_prefix('/').unwrap_or(stripped)
124 } else {
125 key
127 }
128 }
129
130 fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
133 let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
135 let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
136
137 if relative_prefix.is_empty() {
139 return true;
140 }
141
142 relative_key.starts_with(relative_prefix)
144 }
145
146 fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
148 let instance: DiscoveryInstance = serde_json::from_slice(value)?;
149 Ok(instance)
150 }
151}
152
153#[async_trait]
154impl Discovery for KVStoreDiscovery {
    /// Returns the connection id reported by the underlying store. `register` stamps this
    /// value onto every instance it creates.
    fn instance_id(&self) -> u64 {
        self.store.connection_id()
    }
158
159 async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
160 let instance_id = self.instance_id();
161 let instance = spec.with_instance_id(instance_id);
162
163 let (bucket_name, key_path) = match &instance {
164 DiscoveryInstance::Endpoint(inst) => {
165 let key = Self::endpoint_key(
166 &inst.namespace,
167 &inst.component,
168 &inst.endpoint,
169 inst.instance_id,
170 );
171 tracing::debug!(
172 "KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
173 inst.instance_id,
174 inst.namespace,
175 inst.component,
176 inst.endpoint,
177 key
178 );
179 (INSTANCES_BUCKET, key)
180 }
181 DiscoveryInstance::Model {
182 namespace,
183 component,
184 endpoint,
185 instance_id,
186 model_suffix,
187 ..
188 } => {
189 let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
190
191 if let Some(suffix) = model_suffix
194 && !suffix.is_empty()
195 {
196 key = format!("{}/{}", key, suffix);
197 tracing::debug!(
198 "KVStoreDiscovery::register: Registering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
199 suffix,
200 instance_id,
201 namespace,
202 component,
203 endpoint,
204 key
205 );
206 }
207
208 if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
210 tracing::debug!(
211 "KVStoreDiscovery::register: Registering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
212 instance_id,
213 namespace,
214 component,
215 endpoint,
216 key
217 );
218 }
219 (MODELS_BUCKET, key)
220 }
221 DiscoveryInstance::EventChannel {
222 namespace,
223 component,
224 topic,
225 instance_id,
226 ..
227 } => {
228 let key = Self::event_channel_key(namespace, component, topic, *instance_id);
229 tracing::info!(
231 "KVStoreDiscovery::register: EventChannel bucket={}, key={}",
232 EVENT_CHANNELS_BUCKET,
233 key
234 );
235 tracing::debug!(
236 "KVStoreDiscovery::register: Registering event channel instance_id={}, namespace={}, component={}, topic={}, key={}",
237 instance_id,
238 namespace,
239 component,
240 topic,
241 key
242 );
243 (EVENT_CHANNELS_BUCKET, key)
244 }
245 };
246
247 let instance_json = serde_json::to_vec(&instance)?;
249 tracing::debug!(
250 "KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
251 instance_json.len(),
252 key_path
253 );
254
255 tracing::debug!(
257 "KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
258 bucket_name,
259 key_path
260 );
261 let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
262 let key = kv::Key::new(key_path.clone());
263
264 tracing::debug!(
265 "KVStoreDiscovery::register: Inserting into bucket={}, key={}",
266 bucket_name,
267 key_path
268 );
269 let outcome = bucket.insert(&key, instance_json.into(), 0).await?;
271 tracing::debug!(
272 "KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
273 instance_id,
274 key_path,
275 outcome
276 );
277
278 Ok(instance)
279 }
280
281 async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
282 let (bucket_name, key_path) = match &instance {
283 DiscoveryInstance::Endpoint(inst) => {
284 let key = Self::endpoint_key(
285 &inst.namespace,
286 &inst.component,
287 &inst.endpoint,
288 inst.instance_id,
289 );
290 tracing::debug!(
291 "Unregistering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
292 inst.instance_id,
293 inst.namespace,
294 inst.component,
295 inst.endpoint,
296 key
297 );
298 (INSTANCES_BUCKET, key)
299 }
300 DiscoveryInstance::Model {
301 namespace,
302 component,
303 endpoint,
304 instance_id,
305 model_suffix,
306 ..
307 } => {
308 let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
309
310 if let Some(suffix) = model_suffix
312 && !suffix.is_empty()
313 {
314 key = format!("{}/{}", key, suffix);
315 tracing::debug!(
316 "KVStoreDiscovery::unregister: Unregistering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
317 suffix,
318 instance_id,
319 namespace,
320 component,
321 endpoint,
322 key
323 );
324 }
325
326 if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
328 tracing::debug!(
329 "Unregistering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
330 instance_id,
331 namespace,
332 component,
333 endpoint,
334 key
335 );
336 }
337 (MODELS_BUCKET, key)
338 }
339 DiscoveryInstance::EventChannel {
340 namespace,
341 component,
342 topic,
343 instance_id,
344 ..
345 } => {
346 let key = Self::event_channel_key(namespace, component, topic, *instance_id);
347 tracing::debug!(
348 "KVStoreDiscovery::unregister: Unregistering event channel instance_id={}, namespace={}, component={}, topic={}, key={}",
349 instance_id,
350 namespace,
351 component,
352 topic,
353 key
354 );
355 (EVENT_CHANNELS_BUCKET, key)
356 }
357 };
358
359 let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
361 tracing::warn!(
362 "Bucket {} does not exist, instance already removed",
363 bucket_name
364 );
365 return Ok(());
366 };
367
368 let key = kv::Key::new(key_path.clone());
369
370 bucket.delete(&key).await?;
372
373 Ok(())
374 }
375
376 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
377 let prefix = Self::query_prefix(&query);
378 let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
379 INSTANCES_BUCKET
380 } else if prefix.starts_with(EVENT_CHANNELS_BUCKET) {
381 EVENT_CHANNELS_BUCKET
382 } else {
383 MODELS_BUCKET
384 };
385
386 let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
388 tracing::info!(
389 "KVStoreDiscovery::list: bucket missing for query={:?}, prefix={}, bucket={}",
390 query,
391 prefix,
392 bucket_name
393 );
394 return Ok(Vec::new());
395 };
396
397 let entries = bucket.entries().await?;
399 tracing::info!(
400 "KVStoreDiscovery::list: query={:?}, prefix={}, bucket={}, entries={}",
401 query,
402 prefix,
403 bucket_name,
404 entries.len()
405 );
406
407 let mut instances = Vec::new();
409 for (key, value) in entries {
410 if Self::matches_prefix(key.as_ref(), &prefix, bucket_name) {
411 match Self::parse_instance(&value) {
412 Ok(instance) => instances.push(instance),
413 Err(e) => {
414 tracing::warn!(%key, error = %e, "Failed to parse discovery instance");
415 }
416 }
417 }
418 }
419
420 Ok(instances)
421 }
422
423 async fn list_and_watch(
424 &self,
425 query: DiscoveryQuery,
426 cancel_token: Option<CancellationToken>,
427 ) -> Result<DiscoveryStream> {
428 let prefix = Self::query_prefix(&query);
429 let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
430 INSTANCES_BUCKET
431 } else if prefix.starts_with(EVENT_CHANNELS_BUCKET) {
432 EVENT_CHANNELS_BUCKET
433 } else {
434 MODELS_BUCKET
435 };
436
437 tracing::trace!(
438 "KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
439 query,
440 prefix,
441 bucket_name
442 );
443
444 let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
446
447 let (_, mut rx) = self.store.clone().watch(
449 bucket_name,
450 None, cancel_token,
452 );
453
454 let stream = async_stream::stream! {
456 while let Some(event) = rx.recv().await {
457 let discovery_event = match event {
458 kv::WatchEvent::Put(kv) => {
459 if !Self::matches_prefix(kv.key_str(), &prefix, bucket_name) {
461 continue;
462 }
463
464 match Self::parse_instance(kv.value()) {
465 Ok(instance) => {
466 Some(DiscoveryEvent::Added(instance))
467 },
468 Err(e) => {
469 tracing::warn!(
470 key = %kv.key_str(),
471 error = %e,
472 "Failed to parse discovery instance from watch event"
473 );
474 None
475 }
476 }
477 }
478 kv::WatchEvent::Delete(kv) => {
479 let key_str = kv.as_ref();
480 if !Self::matches_prefix(key_str, &prefix, bucket_name) {
482 continue;
483 }
484
485 let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
496 let key_parts: Vec<&str> = relative_key.split('/').collect();
497
498 let min_parts = 4;
501 if key_parts.len() < min_parts {
502 tracing::warn!(
503 key = %key_str,
504 relative_key = %relative_key,
505 actual_parts = key_parts.len(),
506 expected_min = min_parts,
507 bucket = bucket_name,
508 "Delete event key doesn't have enough parts"
509 );
510 continue;
511 }
512
513 let namespace = key_parts[0].to_string();
514 let component = key_parts[1].to_string();
515
516 let id = if bucket_name == EVENT_CHANNELS_BUCKET {
518 let topic = key_parts[2].to_string();
520 let instance_id_hex = key_parts[3];
521 match u64::from_str_radix(instance_id_hex, 16) {
522 Ok(instance_id) => {
523 DiscoveryInstanceId::EventChannel(EventChannelInstanceId {
524 namespace,
525 component,
526 topic,
527 instance_id,
528 })
529 }
530 Err(e) => {
531 tracing::warn!(
532 key = %key_str,
533 error = %e,
534 instance_id_hex = %instance_id_hex,
535 "Failed to parse event channel instance_id hex"
536 );
537 continue;
538 }
539 }
540 } else {
541 let endpoint = key_parts[2].to_string();
542 let instance_id_hex = key_parts[3];
543
544 match u64::from_str_radix(instance_id_hex, 16) {
545 Ok(instance_id) => {
546 if bucket_name == INSTANCES_BUCKET {
548 DiscoveryInstanceId::Endpoint(EndpointInstanceId {
549 namespace,
550 component,
551 endpoint,
552 instance_id,
553 })
554 } else {
555 let model_suffix = key_parts.get(4).map(|s| s.to_string());
557 DiscoveryInstanceId::Model(ModelCardInstanceId {
558 namespace,
559 component,
560 endpoint,
561 instance_id,
562 model_suffix,
563 })
564 }
565 }
566 Err(e) => {
567 tracing::warn!(
568 key = %key_str,
569 error = %e,
570 instance_id_hex = %instance_id_hex,
571 "Failed to parse instance_id hex from deleted key"
572 );
573 continue;
574 }
575 }
576 };
577
578 tracing::debug!(
579 "KVStoreDiscovery::list_and_watch: Emitting Removed event for {:?}, key={}",
580 id,
581 key_str
582 );
583 Some(DiscoveryEvent::Removed(id))
584 }
585 };
586
587 if let Some(event) = discovery_event {
588 yield Ok(event);
589 }
590 }
591 };
592 Ok(Box::pin(stream))
593 }
594}
595
596#[cfg(test)]
597mod tests {
598 use super::*;
599 use crate::component::TransportType;
600
601 #[tokio::test]
602 async fn test_kv_store_discovery_register_endpoint() {
603 let store = kv::Manager::memory();
604 let cancel_token = CancellationToken::new();
605 let client = KVStoreDiscovery::new(store, cancel_token);
606
607 let spec = DiscoverySpec::Endpoint {
608 namespace: "test".to_string(),
609 component: "comp1".to_string(),
610 endpoint: "ep1".to_string(),
611 transport: TransportType::Nats("nats://localhost:4222".to_string()),
612 };
613
614 let instance = client.register(spec).await.unwrap();
615
616 match instance {
617 DiscoveryInstance::Endpoint(inst) => {
618 assert_eq!(inst.namespace, "test");
619 assert_eq!(inst.component, "comp1");
620 assert_eq!(inst.endpoint, "ep1");
621 }
622 _ => panic!("Expected Endpoint instance"),
623 }
624 }
625
626 #[tokio::test]
627 async fn test_kv_store_discovery_list() {
628 let store = kv::Manager::memory();
629 let cancel_token = CancellationToken::new();
630 let client = KVStoreDiscovery::new(store, cancel_token);
631
632 let spec1 = DiscoverySpec::Endpoint {
634 namespace: "ns1".to_string(),
635 component: "comp1".to_string(),
636 endpoint: "ep1".to_string(),
637 transport: TransportType::Nats("nats://localhost:4222".to_string()),
638 };
639 client.register(spec1).await.unwrap();
640
641 let spec2 = DiscoverySpec::Endpoint {
642 namespace: "ns1".to_string(),
643 component: "comp1".to_string(),
644 endpoint: "ep2".to_string(),
645 transport: TransportType::Nats("nats://localhost:4222".to_string()),
646 };
647 client.register(spec2).await.unwrap();
648
649 let spec3 = DiscoverySpec::Endpoint {
650 namespace: "ns2".to_string(),
651 component: "comp2".to_string(),
652 endpoint: "ep1".to_string(),
653 transport: TransportType::Nats("nats://localhost:4222".to_string()),
654 };
655 client.register(spec3).await.unwrap();
656
657 let all = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
659 assert_eq!(all.len(), 3);
660
661 let ns1 = client
663 .list(DiscoveryQuery::NamespacedEndpoints {
664 namespace: "ns1".to_string(),
665 })
666 .await
667 .unwrap();
668 assert_eq!(ns1.len(), 2);
669
670 let comp1 = client
672 .list(DiscoveryQuery::ComponentEndpoints {
673 namespace: "ns1".to_string(),
674 component: "comp1".to_string(),
675 })
676 .await
677 .unwrap();
678 assert_eq!(comp1.len(), 2);
679 }
680
681 #[tokio::test]
682 async fn test_kv_store_discovery_watch() {
683 let store = kv::Manager::memory();
684 let cancel_token = CancellationToken::new();
685 let client = Arc::new(KVStoreDiscovery::new(store, cancel_token.clone()));
686
687 let mut stream = client
689 .list_and_watch(DiscoveryQuery::AllEndpoints, None)
690 .await
691 .unwrap();
692
693 let client_clone = client.clone();
694 let register_task = tokio::spawn(async move {
695 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
696
697 let spec = DiscoverySpec::Endpoint {
698 namespace: "test".to_string(),
699 component: "comp1".to_string(),
700 endpoint: "ep1".to_string(),
701 transport: TransportType::Nats("nats://localhost:4222".to_string()),
702 };
703 client_clone.register(spec).await.unwrap();
704 });
705
706 let event = stream.next().await.unwrap().unwrap();
708 match event {
709 DiscoveryEvent::Added(instance) => match instance {
710 DiscoveryInstance::Endpoint(inst) => {
711 assert_eq!(inst.namespace, "test");
712 assert_eq!(inst.component, "comp1");
713 assert_eq!(inst.endpoint, "ep1");
714 }
715 _ => panic!("Expected Endpoint instance"),
716 },
717 _ => panic!("Expected Added event"),
718 }
719
720 register_task.await.unwrap();
721 cancel_token.cancel();
722 }
723}