1use std::pin::Pin;
5use std::sync::Arc;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use futures::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11
12use super::{
13 Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
14 DiscoverySpec, DiscoveryStream, EndpointInstanceId, ModelCardInstanceId,
15};
16use crate::storage::kv;
17
/// Bucket holding endpoint instance registrations, keyed by
/// `namespace/component/endpoint/{instance_id:x}`.
const INSTANCES_BUCKET: &str = "v1/instances";

/// Bucket holding model registrations, keyed like `INSTANCES_BUCKET` with an
/// optional trailing model-suffix segment.
const MODELS_BUCKET: &str = "v1/mdc";
20
21pub struct KVStoreDiscovery {
23 store: Arc<kv::Manager>,
24 cancel_token: CancellationToken,
25}
26
27impl KVStoreDiscovery {
28 pub fn new(store: kv::Manager, cancel_token: CancellationToken) -> Self {
29 Self {
30 store: Arc::new(store),
31 cancel_token,
32 }
33 }
34
35 fn endpoint_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
37 format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
38 }
39
40 fn model_key(namespace: &str, component: &str, endpoint: &str, instance_id: u64) -> String {
42 format!("{}/{}/{}/{:x}", namespace, component, endpoint, instance_id)
43 }
44
45 fn query_prefix(query: &DiscoveryQuery) -> String {
47 match query {
48 DiscoveryQuery::AllEndpoints => INSTANCES_BUCKET.to_string(),
49 DiscoveryQuery::NamespacedEndpoints { namespace } => {
50 format!("{}/{}", INSTANCES_BUCKET, namespace)
51 }
52 DiscoveryQuery::ComponentEndpoints {
53 namespace,
54 component,
55 } => {
56 format!("{}/{}/{}", INSTANCES_BUCKET, namespace, component)
57 }
58 DiscoveryQuery::Endpoint {
59 namespace,
60 component,
61 endpoint,
62 } => {
63 format!(
64 "{}/{}/{}/{}",
65 INSTANCES_BUCKET, namespace, component, endpoint
66 )
67 }
68 DiscoveryQuery::AllModels => MODELS_BUCKET.to_string(),
69 DiscoveryQuery::NamespacedModels { namespace } => {
70 format!("{}/{}", MODELS_BUCKET, namespace)
71 }
72 DiscoveryQuery::ComponentModels {
73 namespace,
74 component,
75 } => {
76 format!("{}/{}/{}", MODELS_BUCKET, namespace, component)
77 }
78 DiscoveryQuery::EndpointModels {
79 namespace,
80 component,
81 endpoint,
82 } => {
83 format!("{}/{}/{}/{}", MODELS_BUCKET, namespace, component, endpoint)
84 }
85 }
86 }
87
88 fn strip_bucket_prefix<'a>(key: &'a str, bucket_name: &str) -> &'a str {
92 if let Some(stripped) = key.strip_prefix(bucket_name) {
94 stripped.strip_prefix('/').unwrap_or(stripped)
96 } else {
97 key
99 }
100 }
101
102 fn matches_prefix(key_str: &str, prefix: &str, bucket_name: &str) -> bool {
105 let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
107 let relative_prefix = Self::strip_bucket_prefix(prefix, bucket_name);
108
109 if relative_prefix.is_empty() {
111 return true;
112 }
113
114 relative_key.starts_with(relative_prefix)
116 }
117
118 fn parse_instance(value: &[u8]) -> Result<DiscoveryInstance> {
120 let instance: DiscoveryInstance = serde_json::from_slice(value)?;
121 Ok(instance)
122 }
123}
124
125#[async_trait]
126impl Discovery for KVStoreDiscovery {
127 fn instance_id(&self) -> u64 {
128 self.store.connection_id()
129 }
130
131 async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
132 let instance_id = self.instance_id();
133 let instance = spec.with_instance_id(instance_id);
134
135 let (bucket_name, key_path) = match &instance {
136 DiscoveryInstance::Endpoint(inst) => {
137 let key = Self::endpoint_key(
138 &inst.namespace,
139 &inst.component,
140 &inst.endpoint,
141 inst.instance_id,
142 );
143 tracing::debug!(
144 "KVStoreDiscovery::register: Registering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
145 inst.instance_id,
146 inst.namespace,
147 inst.component,
148 inst.endpoint,
149 key
150 );
151 (INSTANCES_BUCKET, key)
152 }
153 DiscoveryInstance::Model {
154 namespace,
155 component,
156 endpoint,
157 instance_id,
158 model_suffix,
159 ..
160 } => {
161 let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
162
163 if let Some(suffix) = model_suffix
166 && !suffix.is_empty()
167 {
168 key = format!("{}/{}", key, suffix);
169 tracing::debug!(
170 "KVStoreDiscovery::register: Registering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
171 suffix,
172 instance_id,
173 namespace,
174 component,
175 endpoint,
176 key
177 );
178 }
179
180 if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
182 tracing::debug!(
183 "KVStoreDiscovery::register: Registering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
184 instance_id,
185 namespace,
186 component,
187 endpoint,
188 key
189 );
190 }
191 (MODELS_BUCKET, key)
192 }
193 };
194
195 let instance_json = serde_json::to_vec(&instance)?;
197 tracing::debug!(
198 "KVStoreDiscovery::register: Serialized instance to {} bytes for key={}",
199 instance_json.len(),
200 key_path
201 );
202
203 tracing::debug!(
205 "KVStoreDiscovery::register: Getting/creating bucket={} for key={}",
206 bucket_name,
207 key_path
208 );
209 let bucket = self.store.get_or_create_bucket(bucket_name, None).await?;
210 let key = kv::Key::new(key_path.clone());
211
212 tracing::debug!(
213 "KVStoreDiscovery::register: Inserting into bucket={}, key={}",
214 bucket_name,
215 key_path
216 );
217 let outcome = bucket.insert(&key, instance_json.into(), 0).await?;
219 tracing::debug!(
220 "KVStoreDiscovery::register: Successfully registered instance_id={}, key={}, outcome={:?}",
221 instance_id,
222 key_path,
223 outcome
224 );
225
226 Ok(instance)
227 }
228
229 async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
230 let (bucket_name, key_path) = match &instance {
231 DiscoveryInstance::Endpoint(inst) => {
232 let key = Self::endpoint_key(
233 &inst.namespace,
234 &inst.component,
235 &inst.endpoint,
236 inst.instance_id,
237 );
238 tracing::debug!(
239 "Unregistering endpoint instance_id={}, namespace={}, component={}, endpoint={}, key={}",
240 inst.instance_id,
241 inst.namespace,
242 inst.component,
243 inst.endpoint,
244 key
245 );
246 (INSTANCES_BUCKET, key)
247 }
248 DiscoveryInstance::Model {
249 namespace,
250 component,
251 endpoint,
252 instance_id,
253 model_suffix,
254 ..
255 } => {
256 let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
257
258 if let Some(suffix) = model_suffix
260 && !suffix.is_empty()
261 {
262 key = format!("{}/{}", key, suffix);
263 tracing::debug!(
264 "KVStoreDiscovery::unregister: Unregistering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
265 suffix,
266 instance_id,
267 namespace,
268 component,
269 endpoint,
270 key
271 );
272 }
273
274 if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
276 tracing::debug!(
277 "Unregistering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
278 instance_id,
279 namespace,
280 component,
281 endpoint,
282 key
283 );
284 }
285 (MODELS_BUCKET, key)
286 }
287 };
288
289 let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
291 tracing::warn!(
292 "Bucket {} does not exist, instance already removed",
293 bucket_name
294 );
295 return Ok(());
296 };
297
298 let key = kv::Key::new(key_path.clone());
299
300 bucket.delete(&key).await?;
302
303 Ok(())
304 }
305
306 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
307 let prefix = Self::query_prefix(&query);
308 let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
309 INSTANCES_BUCKET
310 } else {
311 MODELS_BUCKET
312 };
313
314 let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
316 return Ok(Vec::new());
317 };
318
319 let entries = bucket.entries().await?;
321
322 let mut instances = Vec::new();
324 for (key, value) in entries {
325 if Self::matches_prefix(key.as_ref(), &prefix, bucket_name) {
326 match Self::parse_instance(&value) {
327 Ok(instance) => instances.push(instance),
328 Err(e) => {
329 tracing::warn!(%key, error = %e, "Failed to parse discovery instance");
330 }
331 }
332 }
333 }
334
335 Ok(instances)
336 }
337
338 async fn list_and_watch(
339 &self,
340 query: DiscoveryQuery,
341 cancel_token: Option<CancellationToken>,
342 ) -> Result<DiscoveryStream> {
343 let prefix = Self::query_prefix(&query);
344 let bucket_name = if prefix.starts_with(INSTANCES_BUCKET) {
345 INSTANCES_BUCKET
346 } else {
347 MODELS_BUCKET
348 };
349
350 tracing::trace!(
351 "KVStoreDiscovery::list_and_watch: Starting watch for query={:?}, prefix={}, bucket={}",
352 query,
353 prefix,
354 bucket_name
355 );
356
357 let cancel_token = cancel_token.unwrap_or_else(|| self.cancel_token.clone());
359
360 let (_, mut rx) = self.store.clone().watch(
362 bucket_name,
363 None, cancel_token,
365 );
366
367 let stream = async_stream::stream! {
369 while let Some(event) = rx.recv().await {
370 let discovery_event = match event {
371 kv::WatchEvent::Put(kv) => {
372 if !Self::matches_prefix(kv.key_str(), &prefix, bucket_name) {
374 continue;
375 }
376
377 match Self::parse_instance(kv.value()) {
378 Ok(instance) => {
379 Some(DiscoveryEvent::Added(instance))
380 },
381 Err(e) => {
382 tracing::warn!(
383 key = %kv.key_str(),
384 error = %e,
385 "Failed to parse discovery instance from watch event"
386 );
387 None
388 }
389 }
390 }
391 kv::WatchEvent::Delete(kv) => {
392 let key_str = kv.as_ref();
393 if !Self::matches_prefix(key_str, &prefix, bucket_name) {
395 continue;
396 }
397
398 let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
408 let key_parts: Vec<&str> = relative_key.split('/').collect();
409
410 if key_parts.len() < 4 {
413 tracing::warn!(
414 key = %key_str,
415 relative_key = %relative_key,
416 actual_parts = key_parts.len(),
417 "Delete event key doesn't have enough parts"
418 );
419 continue;
420 }
421
422 let namespace = key_parts[0].to_string();
423 let component = key_parts[1].to_string();
424 let endpoint = key_parts[2].to_string();
425 let instance_id_hex = key_parts[3];
426
427 match u64::from_str_radix(instance_id_hex, 16) {
428 Ok(instance_id) => {
429 let id = if bucket_name == INSTANCES_BUCKET {
431 DiscoveryInstanceId::Endpoint(EndpointInstanceId {
432 namespace,
433 component,
434 endpoint,
435 instance_id,
436 })
437 } else {
438 let model_suffix = key_parts.get(4).map(|s| s.to_string());
440 DiscoveryInstanceId::Model(ModelCardInstanceId {
441 namespace,
442 component,
443 endpoint,
444 instance_id,
445 model_suffix,
446 })
447 };
448
449 tracing::debug!(
450 "KVStoreDiscovery::list_and_watch: Emitting Removed event for {:?}, key={}",
451 id,
452 key_str
453 );
454 Some(DiscoveryEvent::Removed(id))
455 }
456 Err(e) => {
457 tracing::warn!(
458 key = %key_str,
459 relative_key = %relative_key,
460 error = %e,
461 instance_id_hex = %instance_id_hex,
462 "Failed to parse instance_id hex from deleted key"
463 );
464 None
465 }
466 }
467 }
468 };
469
470 if let Some(event) = discovery_event {
471 yield Ok(event);
472 }
473 }
474 };
475 Ok(Box::pin(stream))
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::component::TransportType;

    /// Upper bound on how long the watch test waits for an event. A missing event
    /// then fails the test instead of hanging the test run.
    const WATCH_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);

    /// Builds a NATS-transport endpoint spec for the given coordinates.
    fn endpoint_spec(namespace: &str, component: &str, endpoint: &str) -> DiscoverySpec {
        DiscoverySpec::Endpoint {
            namespace: namespace.to_string(),
            component: component.to_string(),
            endpoint: endpoint.to_string(),
            transport: TransportType::Nats("nats://localhost:4222".to_string()),
        }
    }

    #[tokio::test]
    async fn test_kv_store_discovery_register_endpoint() {
        let client = KVStoreDiscovery::new(kv::Manager::memory(), CancellationToken::new());

        let instance = client
            .register(endpoint_spec("test", "comp1", "ep1"))
            .await
            .unwrap();

        let DiscoveryInstance::Endpoint(inst) = instance else {
            panic!("Expected Endpoint instance");
        };
        assert_eq!(inst.namespace, "test");
        assert_eq!(inst.component, "comp1");
        assert_eq!(inst.endpoint, "ep1");
    }

    #[tokio::test]
    async fn test_kv_store_discovery_list() {
        let client = KVStoreDiscovery::new(kv::Manager::memory(), CancellationToken::new());

        for (namespace, component, endpoint) in
            [("ns1", "comp1", "ep1"), ("ns1", "comp1", "ep2"), ("ns2", "comp2", "ep1")]
        {
            client
                .register(endpoint_spec(namespace, component, endpoint))
                .await
                .unwrap();
        }

        let all = client.list(DiscoveryQuery::AllEndpoints).await.unwrap();
        assert_eq!(all.len(), 3);

        let ns1 = client
            .list(DiscoveryQuery::NamespacedEndpoints {
                namespace: "ns1".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(ns1.len(), 2);

        let comp1 = client
            .list(DiscoveryQuery::ComponentEndpoints {
                namespace: "ns1".to_string(),
                component: "comp1".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(comp1.len(), 2);

        let ep1 = client
            .list(DiscoveryQuery::Endpoint {
                namespace: "ns1".to_string(),
                component: "comp1".to_string(),
                endpoint: "ep1".to_string(),
            })
            .await
            .unwrap();
        assert_eq!(ep1.len(), 1);
    }

    #[tokio::test]
    async fn test_kv_store_discovery_unregister_endpoint() {
        let client = KVStoreDiscovery::new(kv::Manager::memory(), CancellationToken::new());

        let instance = client
            .register(endpoint_spec("test", "comp1", "ep1"))
            .await
            .unwrap();
        assert_eq!(
            client.list(DiscoveryQuery::AllEndpoints).await.unwrap().len(),
            1
        );

        client.unregister(instance).await.unwrap();
        assert!(client
            .list(DiscoveryQuery::AllEndpoints)
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_kv_store_discovery_watch() {
        let cancel_token = CancellationToken::new();
        let client = Arc::new(KVStoreDiscovery::new(
            kv::Manager::memory(),
            cancel_token.clone(),
        ));

        let mut stream = client
            .list_and_watch(DiscoveryQuery::AllEndpoints, None)
            .await
            .unwrap();

        let registrar = Arc::clone(&client);
        let register_task = tokio::spawn(async move {
            // Give the watch a moment to start before the write happens.
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            registrar
                .register(endpoint_spec("test", "comp1", "ep1"))
                .await
                .unwrap();
        });

        let event = tokio::time::timeout(WATCH_TIMEOUT, stream.next())
            .await
            .expect("timed out waiting for discovery event")
            .expect("discovery stream ended unexpectedly")
            .expect("discovery stream yielded an error");

        let DiscoveryEvent::Added(instance) = event else {
            panic!("Expected Added event");
        };
        let DiscoveryInstance::Endpoint(inst) = instance else {
            panic!("Expected Endpoint instance");
        };
        assert_eq!(inst.namespace, "test");
        assert_eq!(inst.component, "comp1");
        assert_eq!(inst.endpoint, "ep1");

        register_task.await.unwrap();
        cancel_token.cancel();
    }
}