dynamo_runtime/discovery/
kv_store.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::pin::Pin;
5use std::sync::Arc;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
14    DiscoverySpec, DiscoveryStream, EndpointInstanceId, ModelCardInstanceId,
15};
16use crate::storage::kv;
17
/// Bucket that holds endpoint instance registrations.
const INSTANCES_BUCKET: &str = "v1/instances";
/// Bucket that holds model registrations (the `DiscoveryInstance::Model` variant).
const MODELS_BUCKET: &str = "v1/mdc";
20
/// Discovery implementation backed by a `kv::Manager`.
///
/// Endpoint instances are stored in the `v1/instances` bucket and model
/// instances in the `v1/mdc` bucket, each serialized as JSON.
pub struct KVStoreDiscovery {
    /// Shared handle to the key-value store used for registering, listing and watching.
    store: Arc<kv::Manager>,
    /// Default cancellation token for watches when the caller does not supply one.
    cancel_token: CancellationToken,
}
26
27impl KVStoreDiscovery {
28    pub fn new(store: kv::Manager, cancel_token: CancellationToken) -> Self {
29        Self {
30            store: Arc::new(store),
31            cancel_token,
32        }
33    }
34
35    /// Build the key path for an endpoint (relative to bucket, not absolute)
36    fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
37        format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
38    }
39
40    /// Build the key path for a model (relative to bucket, not absolute)
41    fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
42        format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
43    }
44
45    /// Extract prefix for querying based on discovery query
46    fn query_prefix(query: &DiscoveryQuery) -> String {
47        match query {
48            DiscoveryQuery::AllEndpoints => INSTANCES_BUCKET.to_string(),
49            DiscoveryQuery::NamespacedEndpoints { namespace } => {
50                format!("{}/{}", INSTANCES_BUCKET, namespace)
51            }
52            DiscoveryQuery::ComponentEndpoints {
53                namespace,
54                component,
55            } => {
56                format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component)
57            }
58            DiscoveryQuery::Endpoint {
59                namespace,
60                component,
61                endpoint,
62            } => {
63                format!(
64                    "{}/{}/{}/{}",
65                    INSTANCES_BUCKET, namespace, component, endpoint
66                )
67            }
68            DiscoveryQuery::AllModels => MODELS_BUCKET.to_string(),
69            DiscoveryQuery::NamespacedModels { namespace } => {
70                format!("{}/{}", MODELS_BUCKET, namespace)
71            }
72            DiscoveryQuery::ComponentModels {
73                namespace,
74                component,
75            } => {
76                format!("{}/{}/{}", MODELS_BUCKET, namespace, component)
77            }
78            DiscoveryQuery::EndpointModels {
79                namespace,
80                component,
81                endpoint,
82            } => {
83                format!("{}/{}/{}/{}", MODELS_BUCKET, namespace, component, endpoint)
84            }
85        }
86    }
87
88    /// Strip bucket prefix from a key if present, returning the relative path within the bucket
89    /// For example: "v1/instances/ns/comp/ep" -> "ns/comp/ep"
90    /// Or if already relative: "ns/comp/ep" -> "ns/comp/ep"
91    fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
92        // Try to strip "bucket_name/" from the beginning
93        if let Some(stripped) = key.strip_prefix(bucket_name) {
94            // Strip the leading slash if present
95            stripped.strip_prefix('/').unwrap_or(stripped)
96        } else {
97            // Key is already relative to bucket
98            key
99        }
100    }
101
102    /// Check if a key matches the given prefix, handling both absolute and relative key formats
103    /// This works regardless of whether keys include the bucket prefix (etcd) or not (memory)
104    fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
105        // Normalize both the key and prefix to relative paths (without bucket prefix)
106        let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
107        let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
108
109        // Empty prefix matches everything in the bucket
110        if relative_prefix.is_empty() {
111            return true;
112        }
113
114        // Check if the relative key starts with the relative prefix
115        relative_key.starts_with(relative_prefix)
116    }
117
118    /// Parse and deserialize a discovery instance from KV store entry
119    fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
120        let instance: DiscoveryInstance = serde_json::from_slice(value)?;
121        Ok(instance)
122    }
123}
124
125#[async_trait]
126impl Discovery for KVStoreDiscovery {
    /// Returns the connection id reported by the underlying store.
    ///
    /// `register` stamps this value onto every instance it creates.
    fn instance_id(&self) -> u64 {
        self.store.connection_id()
    }
130
131    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
132        let instance_id = self.instance_id();
133        let instance = spec.with_instance_id(instance_id);
134
135        let (bucket_name, key_path) = match &instance {
136            DiscoveryInstance::Endpoint(inst) => {
137                let key = Self::endpoint_key(
138                    &inst.namespace,
139                    &inst.component,
140                    &inst.endpoint,
141                    inst.instance_id,
142                );
143                tracing::debug!(
144                    "KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
145                    inst.instance_id,
146                    inst.namespace,
147                    inst.component,
148                    inst.endpoint,
149                    key
150                );
151                (INSTANCES_BUCKET, key)
152            }
153            DiscoveryInstance::Model {
154                namespace,
155                component,
156                endpoint,
157                instance_id,
158                model_suffix,
159                ..
160            } => {
161                let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
162
163                // If there's a model_suffix (e.g., for LoRA adapters), append it after the instance_id
164                // Key format: {namespace}/{component}/{endpoint}/{instance_id:x}/{model_suffix}
165                if let Some(suffix) = model_suffix
166                    && !suffix.is_empty()
167                {
168                    key = format!("{}/{}", key, suffix);
169                    tracing::debug!(
170                        "KVStoreDiscovery::register: Registering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
171                        suffix,
172                        instance_id,
173                        namespace,
174                        component,
175                        endpoint,
176                        key
177                    );
178                }
179
180                // Log for base models (no suffix or empty suffix)
181                if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
182                    tracing::debug!(
183                        "KVStoreDiscovery::register: Registering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
184                        instance_id,
185                        namespace,
186                        component,
187                        endpoint,
188                        key
189                    );
190                }
191                (MODELS_BUCKET, key)
192            }
193        };
194
195        // Serialize the instance
196        let instance_json = serde_json::to_vec(&instance)?;
197        tracing::debug!(
198            "KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
199            instance_json.len(),
200            key_path
201        );
202
203        // Store in the KV store with no TTL (instances persist until explicitly removed)
204        tracing::debug!(
205            "KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
206            bucket_name,
207            key_path
208        );
209        let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
210        let key = kv::Key::new(key_path.clone());
211
212        tracing::debug!(
213            "KVStoreDiscovery::register: Inserting into bucket={}, key={}",
214            bucket_name,
215            key_path
216        );
217        // Use revision 0 for initial registration
218        let outcome = bucket.insert(&key, instance_json.into(), 0).await?;
219        tracing::debug!(
220            "KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
221            instance_id,
222            key_path,
223            outcome
224        );
225
226        Ok(instance)
227    }
228
229    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
230        let (bucket_name, key_path) = match &instance {
231            DiscoveryInstance::Endpoint(inst) => {
232                let key = Self::endpoint_key(
233                    &inst.namespace,
234                    &inst.component,
235                    &inst.endpoint,
236                    inst.instance_id,
237                );
238                tracing::debug!(
239                    "Unregistering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
240                    inst.instance_id,
241                    inst.namespace,
242                    inst.component,
243                    inst.endpoint,
244                    key
245                );
246                (INSTANCES_BUCKET, key)
247            }
248            DiscoveryInstance::Model {
249                namespace,
250                component,
251                endpoint,
252                instance_id,
253                model_suffix,
254                ..
255            } => {
256                let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
257
258                // If there's a model_suffix (e.g., for LoRA adapters), append it after the instance_id
259                if let Some(suffix) = model_suffix
260                    && !suffix.is_empty()
261                {
262                    key = format!("{}/{}", key, suffix);
263                    tracing::debug!(
264                        "KVStoreDiscovery::unregister: Unregistering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
265                        suffix,
266                        instance_id,
267                        namespace,
268                        component,
269                        endpoint,
270                        key
271                    );
272                }
273
274                // Log for base models (no suffix or empty suffix)
275                if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
276                    tracing::debug!(
277                        "Unregistering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
278                        instance_id,
279                        namespace,
280                        component,
281                        endpoint,
282                        key
283                    );
284                }
285                (MODELS_BUCKET, key)
286            }
287        };
288
289        // Get the bucket - if it doesn't exist, the instance is already removed from the KV store
290        let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
291            tracing::warn!(
292                "Bucket {} does not exist, instance already removed",
293                bucket_name
294            );
295            return Ok(());
296        };
297
298        let key = kv::Key::new(key_path.clone());
299
300        // Delete the entry from the bucket
301        bucket.delete(&key).await?;
302
303        Ok(())
304    }
305
306    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
307        let prefix = Self::query_prefix(&query);
308        let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
309            INSTANCES_BUCKET
310        } else {
311            MODELS_BUCKET
312        };
313
314        // Get bucket - if it doesn't exist, return empty list
315        let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
316            return Ok(Vec::new());
317        };
318
319        // Get all entries from the bucket
320        let entries = bucket.entries().await?;
321
322        // Filter by prefix and deserialize
323        let mut instances = Vec::new();
324        for (key, value) in entries {
325            if Self::matches_prefix(key.as_ref(), &prefix, bucket_name) {
326                match Self::parse_instance(&value) {
327                    Ok(instance) => instances.push(instance),
328                    Err(e) => {
329                        tracing::warn!(%key, error = %e, "Failed to parse discovery instance");
330                    }
331                }
332            }
333        }
334
335        Ok(instances)
336    }
337
338    async fn list_and_watch(
339        &self,
340        query: DiscoveryQuery,
341        cancel_token: Option<CancellationToken>,
342    ) -> Result<DiscoveryStream> {
343        let prefix = Self::query_prefix(&query);
344        let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
345            INSTANCES_BUCKET
346        } else {
347            MODELS_BUCKET
348        };
349
350        tracing::trace!(
351            "KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
352            query,
353            prefix,
354            bucket_name
355        );
356
357        // Use the provided cancellation token, or fall back to the default token
358        let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
359
360        // Use the kv::Manager's watch mechanism
361        let (_, mut rx) = self.store.clone().watch(
362            bucket_name,
363            None, // No TTL
364            cancel_token,
365        );
366
367        // Create a stream that filters and transforms WatchEvents to DiscoveryEvents
368        let stream = async_stream::stream! {
369            while let Some(event) = rx.recv().await {
370                let discovery_event = match event {
371                    kv::WatchEvent::Put(kv) => {
372                        // Check if this key matches our prefix
373                        if !Self::matches_prefix(kv.key_str(), &prefix, bucket_name) {
374                            continue;
375                        }
376
377                        match Self::parse_instance(kv.value()) {
378                            Ok(instance) => {
379                                Some(DiscoveryEvent::Added(instance))
380                            },
381                            Err(e) => {
382                                tracing::warn!(
383                                    key = %kv.key_str(),
384                                    error = %e,
385                                    "Failed to parse discovery instance from watch event"
386                                );
387                                None
388                            }
389                        }
390                    }
391                    kv::WatchEvent::Delete(kv) => {
392                        let key_str = kv.as_ref();
393                        // Check if this key matches our prefix
394                        if !Self::matches_prefix(key_str, &prefix, bucket_name) {
395                            continue;
396                        }
397
398                        // Extract DiscoveryInstanceId from the key path
399                        // Delete events have empty values in etcd, so we reconstruct the ID from the key
400                        //
401                        // Key format (relative to bucket, after stripping bucket prefix):
402                        // - Endpoints: "namespace/component/endpoint/{instance_id:x}"
403                        // - Models: "namespace/component/endpoint/{instance_id:x}"
404                        // - LoRA models: "namespace/component/endpoint/{instance_id:x}/{lora_slug}"
405                        //
406                        // Use strip_bucket_prefix for consistency with matches_prefix().
407                        let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
408                        let key_parts: Vec<&str> = relative_key.split('/').collect();
409
410                        // In relative key: namespace/component/endpoint/{instance_id}[/{lora_slug}]
411                        // We need at least 4 parts: namespace, component, endpoint, instance_id
412                        if key_parts.len() < 4 {
413                            tracing::warn!(
414                                key = %key_str,
415                                relative_key = %relative_key,
416                                actual_parts = key_parts.len(),
417                                "Delete event key doesn't have enough parts"
418                            );
419                            continue;
420                        }
421
422                        let namespace = key_parts[0].to_string();
423                        let component = key_parts[1].to_string();
424                        let endpoint = key_parts[2].to_string();
425                        let instance_id_hex = key_parts[3];
426
427                        match u64::from_str_radix(instance_id_hex, 16) {
428                            Ok(instance_id) => {
429                                // Construct the appropriate DiscoveryInstanceId based on bucket type
430                                let id = if bucket_name == INSTANCES_BUCKET {
431                                    DiscoveryInstanceId::Endpoint(EndpointInstanceId {
432                                        namespace,
433                                        component,
434                                        endpoint,
435                                        instance_id,
436                                    })
437                                } else {
438                                    // Model - check for LoRA suffix (5th part if present)
439                                    let model_suffix = key_parts.get(4).map(|s| s.to_string());
440                                    DiscoveryInstanceId::Model(ModelCardInstanceId {
441                                        namespace,
442                                        component,
443                                        endpoint,
444                                        instance_id,
445                                        model_suffix,
446                                    })
447                                };
448
449                                tracing::debug!(
450                                    "KVStoreDiscovery::list_and_watch: Emitting Removed event for {:?}, key={}",
451                                    id,
452                                    key_str
453                                );
454                                Some(DiscoveryEvent::Removed(id))
455                            }
456                            Err(e) => {
457                                tracing::warn!(
458                                    key = %key_str,
459                                    relative_key = %relative_key,
460                                    error = %e,
461                                    instance_id_hex = %instance_id_hex,
462                                    "Failed to parse instance_id hex from deleted key"
463                                );
464                                None
465                            }
466                        }
467                    }
468                };
469
470                if let Some(event) = discovery_event {
471                    yield Ok(event);
472                }
473            }
474        };
475        Ok(Box::pin(stream))
476    }
477}
478
479#[cfg(test)]
480mod tests {
481    use super::*;
482    use crate::component::TransportType;
483
484    #[tokio::test]
485    async fn test_kv_store_discovery_register_endpoint() {
486        let store = kv::Manager::memory();
487        let cancel_token = CancellationToken::new();
488        let client = KVStoreDiscovery::new(store, cancel_token);
489
490        let spec = DiscoverySpec::Endpoint {
491            namespace: "test".to_string(),
492            component: "comp1".to_string(),
493            endpoint: "ep1".to_string(),
494            transport: TransportType::Nats("nats://localhost:4222".to_string()),
495        };
496
497        let instance = client.register(spec).await.unwrap();
498
499        match instance {
500            DiscoveryInstance::Endpoint(inst) => {
501                assert_eq!(inst.namespace, "test");
502                assert_eq!(inst.component, "comp1");
503                assert_eq!(inst.endpoint, "ep1");
504            }
505            _ => panic!("Expected Endpoint instance"),
506        }
507    }
508
509    #[tokio::test]
510    async fn test_kv_store_discovery_list() {
511        let store = kv::Manager::memory();
512        let cancel_token = CancellationToken::new();
513        let client = KVStoreDiscovery::new(store, cancel_token);
514
515        // Register multiple endpoints
516        let spec1 = DiscoverySpec::Endpoint {
517            namespace: "ns1".to_string(),
518            component: "comp1".to_string(),
519            endpoint: "ep1".to_string(),
520            transport: TransportType::Nats("nats://localhost:4222".to_string()),
521        };
522        client.register(spec1).await.unwrap();
523
524        let spec2 = DiscoverySpec::Endpoint {
525            namespace: "ns1".to_string(),
526            component: "comp1".to_string(),
527            endpoint: "ep2".to_string(),
528            transport: TransportType::Nats("nats://localhost:4222".to_string()),
529        };
530        client.register(spec2).await.unwrap();
531
532        let spec3 = DiscoverySpec::Endpoint {
533            namespace: "ns2".to_string(),
534            component: "comp2".to_string(),
535            endpoint: "ep1".to_string(),
536            transport: TransportType::Nats("nats://localhost:4222".to_string()),
537        };
538        client.register(spec3).await.unwrap();
539
540        // List all endpoints
541        let all = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
542        assert_eq!(all.len(), 3);
543
544        // List namespaced endpoints
545        let ns1 = client
546            .list(DiscoveryQuery::NamespacedEndpoints {
547                namespace: "ns1".to_string(),
548            })
549            .await
550            .unwrap();
551        assert_eq!(ns1.len(), 2);
552
553        // List component endpoints
554        let comp1 = client
555            .list(DiscoveryQuery::ComponentEndpoints {
556                namespace: "ns1".to_string(),
557                component: "comp1".to_string(),
558            })
559            .await
560            .unwrap();
561        assert_eq!(comp1.len(), 2);
562    }
563
564    #[tokio::test]
565    async fn test_kv_store_discovery_watch() {
566        let store = kv::Manager::memory();
567        let cancel_token = CancellationToken::new();
568        let client = Arc::new(KVStoreDiscovery::new(store, cancel_token.clone()));
569
570        // Start watching before registering
571        let mut stream = client
572            .list_and_watch(DiscoveryQuery::AllEndpoints, None)
573            .await
574            .unwrap();
575
576        let client_clone = client.clone();
577        let register_task = tokio::spawn(async move {
578            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
579
580            let spec = DiscoverySpec::Endpoint {
581                namespace: "test".to_string(),
582                component: "comp1".to_string(),
583                endpoint: "ep1".to_string(),
584                transport: TransportType::Nats("nats://localhost:4222".to_string()),
585            };
586            client_clone.register(spec).await.unwrap();
587        });
588
589        // Wait for the added event
590        let event = stream.next().await.unwrap().unwrap();
591        match event {
592            DiscoveryEvent::Added(instance) => match instance {
593                DiscoveryInstance::Endpoint(inst) => {
594                    assert_eq!(inst.namespace, "test");
595                    assert_eq!(inst.component, "comp1");
596                    assert_eq!(inst.endpoint, "ep1");
597                }
598                _ => panic!("Expected Endpoint instance"),
599            },
600            _ => panic!("Expected Added event"),
601        }
602
603        register_task.await.unwrap();
604        cancel_token.cancel();
605    }
606}