dynamo_runtime/discovery/
kv_store.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::pin::Pin;
5use std::sync::Arc;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, DiscoveryStream,
14};
15use crate::storage::key_value_store::{KeyValueStoreManager, WatchEvent};
16
/// Bucket holding endpoint registrations (`DiscoveryInstance::Endpoint`).
const INSTANCES_BUCKET: &str = "v1/instances";
/// Bucket holding model registrations (`DiscoveryInstance::Model`).
const MODELS_BUCKET: &str = "v1/mdc";
19
20/// Discovery implementation backed by a KeyValueStore
21pub struct KVStoreDiscovery {
22    store: Arc<KeyValueStoreManager>,
23    cancel_token: CancellationToken,
24}
25
26impl KVStoreDiscovery {
27    pub fn new(store: KeyValueStoreManager, cancel_token: CancellationToken) -> Self {
28        Self {
29            store: Arc::new(store),
30            cancel_token,
31        }
32    }
33
34    /// Build the key path for an endpoint (relative to bucket, not absolute)
35    fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
36        format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
37    }
38
39    /// Build the key path for a model (relative to bucket, not absolute)
40    fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
41        format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
42    }
43
44    /// Extract prefix for querying based on discovery query
45    fn query_prefix(query: &DiscoveryQuery) -> String {
46        match query {
47            DiscoveryQuery::AllEndpoints => INSTANCES_BUCKET.to_string(),
48            DiscoveryQuery::NamespacedEndpoints { namespace } => {
49                format!("{}/{}", INSTANCES_BUCKET, namespace)
50            }
51            DiscoveryQuery::ComponentEndpoints {
52                namespace,
53                component,
54            } => {
55                format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component)
56            }
57            DiscoveryQuery::Endpoint {
58                namespace,
59                component,
60                endpoint,
61            } => {
62                format!(
63                    "{}/{}/{}/{}",
64                    INSTANCES_BUCKET, namespace, component, endpoint
65                )
66            }
67            DiscoveryQuery::AllModels => MODELS_BUCKET.to_string(),
68            DiscoveryQuery::NamespacedModels { namespace } => {
69                format!("{}/{}", MODELS_BUCKET, namespace)
70            }
71            DiscoveryQuery::ComponentModels {
72                namespace,
73                component,
74            } => {
75                format!("{}/{}/{}", MODELS_BUCKET, namespace, component)
76            }
77            DiscoveryQuery::EndpointModels {
78                namespace,
79                component,
80                endpoint,
81            } => {
82                format!("{}/{}/{}/{}", MODELS_BUCKET, namespace, component, endpoint)
83            }
84        }
85    }
86
87    /// Strip bucket prefix from a key if present, returning the relative path within the bucket
88    /// For example: "v1/instances/ns/comp/ep" -> "ns/comp/ep"
89    /// Or if already relative: "ns/comp/ep" -> "ns/comp/ep"
90    fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
91        // Try to strip "bucket_name/" from the beginning
92        if let Some(stripped) = key.strip_prefix(bucket_name) {
93            // Strip the leading slash if present
94            stripped.strip_prefix('/').unwrap_or(stripped)
95        } else {
96            // Key is already relative to bucket
97            key
98        }
99    }
100
101    /// Check if a key matches the given prefix, handling both absolute and relative key formats
102    /// This works regardless of whether keys include the bucket prefix (etcd) or not (memory)
103    fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
104        // Normalize both the key and prefix to relative paths (without bucket prefix)
105        let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
106        let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
107
108        // Empty prefix matches everything in the bucket
109        if relative_prefix.is_empty() {
110            return true;
111        }
112
113        // Check if the relative key starts with the relative prefix
114        relative_key.starts_with(relative_prefix)
115    }
116
117    /// Parse and deserialize a discovery instance from KV store entry
118    fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
119        let instance: DiscoveryInstance = serde_json::from_slice(value)?;
120        Ok(instance)
121    }
122}
123
124#[async_trait]
125impl Discovery for KVStoreDiscovery {
126    fn instance_id(&self) -> u64 {
127        self.store.connection_id()
128    }
129
130    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
131        let instance_id = self.instance_id();
132        let instance = spec.with_instance_id(instance_id);
133
134        let (bucket_name, key_path) = match &instance {
135            DiscoveryInstance::Endpoint(inst) => {
136                let key = Self::endpoint_key(
137                    &inst.namespace,
138                    &inst.component,
139                    &inst.endpoint,
140                    inst.instance_id,
141                );
142                tracing::debug!(
143                    "KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
144                    inst.instance_id,
145                    inst.namespace,
146                    inst.component,
147                    inst.endpoint,
148                    key
149                );
150                (INSTANCES_BUCKET, key)
151            }
152            DiscoveryInstance::Model {
153                namespace,
154                component,
155                endpoint,
156                instance_id,
157                ..
158            } => {
159                let key = Self::model_key(namespace, component, endpoint, *instance_id);
160                tracing::debug!(
161                    "KVStoreDiscovery::register: Registering model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
162                    instance_id,
163                    namespace,
164                    component,
165                    endpoint,
166                    key
167                );
168                (MODELS_BUCKET, key)
169            }
170        };
171
172        // Serialize the instance
173        let instance_json = serde_json::to_vec(&instance)?;
174        tracing::debug!(
175            "KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
176            instance_json.len(),
177            key_path
178        );
179
180        // Store in the KV store with no TTL (instances persist until explicitly removed)
181        tracing::debug!(
182            "KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
183            bucket_name,
184            key_path
185        );
186        let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
187        let key = crate::storage::key_value_store::Key::from_raw(key_path.clone());
188
189        tracing::debug!(
190            "KVStoreDiscovery::register: Inserting into bucket={}, key={}",
191            bucket_name,
192            key_path
193        );
194        // Use revision 0 for initial registration
195        let outcome = bucket.insert(&key, instance_json.into(), 0).await?;
196        tracing::debug!(
197            "KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
198            instance_id,
199            key_path,
200            outcome
201        );
202
203        Ok(instance)
204    }
205
206    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
207        let prefix = Self::query_prefix(&query);
208        let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
209            INSTANCES_BUCKET
210        } else {
211            MODELS_BUCKET
212        };
213
214        // Get bucket - if it doesn't exist, return empty list
215        let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
216            return Ok(Vec::new());
217        };
218
219        // Get all entries from the bucket
220        let entries = bucket.entries().await?;
221
222        // Filter by prefix and deserialize
223        let mut instances = Vec::new();
224        for (key_str, value) in entries {
225            if Self::matches_prefix(&key_str, &prefix, bucket_name) {
226                match Self::parse_instance(&value) {
227                    Ok(instance) => instances.push(instance),
228                    Err(e) => {
229                        tracing::warn!(key = %key_str, error = %e, "Failed to parse discovery instance");
230                    }
231                }
232            }
233        }
234
235        Ok(instances)
236    }
237
238    async fn list_and_watch(
239        &self,
240        query: DiscoveryQuery,
241        cancel_token: Option<CancellationToken>,
242    ) -> Result<DiscoveryStream> {
243        let prefix = Self::query_prefix(&query);
244        let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
245            INSTANCES_BUCKET
246        } else {
247            MODELS_BUCKET
248        };
249
250        tracing::debug!(
251            "KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
252            query,
253            prefix,
254            bucket_name
255        );
256
257        // Use the provided cancellation token, or fall back to the default token
258        let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
259
260        // Use the KeyValueStoreManager's watch mechanism
261        let (_, mut rx) = self.store.clone().watch(
262            bucket_name,
263            None, // No TTL
264            cancel_token,
265        );
266
267        tracing::debug!(
268            "KVStoreDiscovery::list_and_watch: Got watch receiver for bucket={}",
269            bucket_name
270        );
271
272        // Create a stream that filters and transforms WatchEvents to DiscoveryEvents
273        let stream = async_stream::stream! {
274            let mut event_count = 0;
275            tracing::debug!("KVStoreDiscovery::list_and_watch: Stream started, waiting for events on prefix={}", prefix);
276            while let Some(event) = rx.recv().await {
277                event_count += 1;
278                tracing::debug!(
279                    "KVStoreDiscovery::list_and_watch: Received event #{} for prefix={}",
280                    event_count,
281                    prefix
282                );
283                let discovery_event = match event {
284                    WatchEvent::Put(kv) => {
285                        tracing::debug!(
286                            "KVStoreDiscovery::list_and_watch: Put event, key={}, prefix={}, matches={}",
287                            kv.key_str(),
288                            prefix,
289                            Self::matches_prefix(kv.key_str(), &prefix, bucket_name)
290                        );
291                        // Check if this key matches our prefix
292                        if !Self::matches_prefix(kv.key_str(), &prefix, bucket_name) {
293                            tracing::debug!(
294                                "KVStoreDiscovery::list_and_watch: Skipping key {} (doesn't match prefix {})",
295                                kv.key_str(),
296                                prefix
297                            );
298                            continue;
299                        }
300
301                        match Self::parse_instance(kv.value()) {
302                            Ok(instance) => {
303                                tracing::debug!(
304                                    "KVStoreDiscovery::list_and_watch: Emitting Added event for instance_id={}, key={}",
305                                    instance.instance_id(),
306                                    kv.key_str()
307                                );
308                                Some(DiscoveryEvent::Added(instance))
309                            },
310                            Err(e) => {
311                                tracing::warn!(
312                                    key = %kv.key_str(),
313                                    error = %e,
314                                    "Failed to parse discovery instance from watch event"
315                                );
316                                None
317                            }
318                        }
319                    }
320                    WatchEvent::Delete(kv) => {
321                        let key_str = kv.as_ref();
322                        tracing::debug!(
323                            "KVStoreDiscovery::list_and_watch: Delete event, key={}, prefix={}",
324                            key_str,
325                            prefix
326                        );
327                        // Check if this key matches our prefix
328                        if !Self::matches_prefix(key_str, &prefix, bucket_name) {
329                            tracing::debug!(
330                                "KVStoreDiscovery::list_and_watch: Skipping deleted key {} (doesn't match prefix {})",
331                                key_str,
332                                prefix
333                            );
334                            continue;
335                        }
336
337                        // Extract instance_id from the key path, not the value
338                        // Delete events have empty values in etcd, so we parse the instance_id from the key
339                        // Key format: "v1/instances/namespace/component/endpoint/{instance_id:x}"
340                        let key_parts: Vec<&str> = key_str.split('/').collect();
341                        match key_parts.last() {
342                            Some(instance_id_hex) => {
343                                match u64::from_str_radix(instance_id_hex, 16) {
344                                    Ok(instance_id) => {
345                                        tracing::debug!(
346                                            "KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={}, key={}",
347                                            instance_id,
348                                            key_str
349                                        );
350                                        Some(DiscoveryEvent::Removed(instance_id))
351                                    }
352                                    Err(e) => {
353                                        tracing::warn!(
354                                            key = %key_str,
355                                            error = %e,
356                                            "Failed to parse instance_id hex from deleted key"
357                                        );
358                                        None
359                                    }
360                                }
361                            }
362                            None => {
363                                tracing::warn!(
364                                    key = %key_str,
365                                    "Delete event key has no path components"
366                                );
367                                None
368                            }
369                        }
370                    }
371                };
372
373                if let Some(event) = discovery_event {
374                    tracing::debug!("KVStoreDiscovery::list_and_watch: Yielding event: {:?}", event);
375                    yield Ok(event);
376                } else {
377                    tracing::debug!("KVStoreDiscovery::list_and_watch: Event was filtered out (None)");
378                }
379            }
380            tracing::debug!("KVStoreDiscovery::list_and_watch: Stream ended after {} events for prefix={}", event_count, prefix);
381        };
382
383        tracing::debug!(
384            "KVStoreDiscovery::list_and_watch: Returning stream for query={:?}",
385            query
386        );
387        Ok(Box::pin(stream))
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::TransportType;

    /// Upper bound on how long the watch test waits for an event before failing.
    const WATCH_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);

    /// Build an endpoint spec that uses a fixed local NATS transport.
    fn endpoint_spec(namespace: &str, component: &str, endpoint: &str) -> DiscoverySpec {
        DiscoverySpec::Endpoint {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        }
    }

    /// Create a discovery client over a fresh in-memory store.
    fn memory_discovery(cancel_token: CancellationToken) -> KVStoreDiscovery {
        KVStoreDiscovery::new(KeyValueStoreManager::memory(), cancel_token)
    }

    #[tokio::test]
    async fn test_kv_store_discovery_register_endpoint() {
        let client = memory_discovery(CancellationToken::new());

        let instance = client
            .register(endpoint_spec("test", "comp1", "ep1"))
            .await
            .unwrap();

        let DiscoveryInstance::Endpoint(inst) = instance else {
            panic!("Expected Endpoint instance");
        };
        assert_eq!(inst.namespace, "test");
        assert_eq!(inst.component, "comp1");
        assert_eq!(inst.endpoint, "ep1");
    }

    #[tokio::test]
    async fn test_kv_store_discovery_list() {
        let client = memory_discovery(CancellationToken::new());

        // Register multiple endpoints
        for (namespace, component, endpoint) in
            [("ns1", "comp1", "ep1"), ("ns1", "comp1", "ep2"), ("ns2", "comp2", "ep1")]
        {
            client
                .register(endpoint_spec(namespace, component, endpoint))
                .await
                .unwrap();
        }

        // List all endpoints
        let all = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
        assert_eq!(all.len(), 3);

        // List namespaced endpoints
        let ns1 = client
            .list(DiscoveryQuery::NamespacedEndpoints {
                namespace: "ns1".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(ns1.len(), 2);

        // List component endpoints
        let comp1 = client
            .list(DiscoveryQuery::ComponentEndpoints {
                namespace: "ns1".to_string(),
                component: "comp1".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(comp1.len(), 2);
    }

    #[tokio::test]
    async fn test_kv_store_discovery_watch() {
        let cancel_token = CancellationToken::new();
        let client = Arc::new(memory_discovery(cancel_token.clone()));

        // Start watching before registering
        let mut stream = client
            .list_and_watch(DiscoveryQuery::AllEndpoints, None)
            .await
            .unwrap();

        let client_clone = client.clone();
        let register_task = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            client_clone
                .register(endpoint_spec("test", "comp1", "ep1"))
                .await
                .unwrap();
        });

        // Wait for the added event. Bound the wait so a regression fails instead of
        // hanging the test run forever.
        let event = tokio::time::timeout(WATCH_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for a watch event")
            .expect("watch stream ended before yielding an event")
            .unwrap();

        let DiscoveryEvent::Added(instance) = event else {
            panic!("Expected Added event");
        };
        let DiscoveryInstance::Endpoint(inst) = instance else {
            panic!("Expected Endpoint instance");
        };
        assert_eq!(inst.namespace, "test");
        assert_eq!(inst.component, "comp1");
        assert_eq!(inst.endpoint, "ep1");

        register_task.await.unwrap();
        cancel_token.cancel();
    }
}