Skip to main content

dynamo_runtime/discovery/
mod.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use futures::Stream;
7use serde::{Deserialize, Serialize};
8use std::pin::Pin;
9use tokio_util::sync::CancellationToken;
10
11mod metadata;
12pub use metadata::{DiscoveryMetadata, MetadataSnapshot};
13
14mod mock;
15pub use mock::{MockDiscovery, SharedMockRegistry};
16mod kv_store;
17pub use kv_store::KVStoreDiscovery;
18
19mod kube;
20pub use kube::{KubeDiscoveryClient, hash_pod_name};
21
22pub mod utils;
23use crate::component::{DeviceType, TransportType};
24pub use utils::watch_and_extract_field;
25
26/// Transport kind for event plane - used for configuration and env var selection.
27///
28/// This enum represents the *type* of transport without connection details.
29/// Use `EventTransport` when you need the full transport configuration.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
31#[serde(rename_all = "snake_case")]
32pub enum EventTransportKind {
33    /// NATS Core pub/sub
34    #[default]
35    Nats,
36    /// ZMQ pub/sub
37    Zmq,
38}
39
40impl EventTransportKind {
41    /// Parse from environment variable `DYN_EVENT_PLANE`.
42    ///
43    /// Returns `Nats` if the variable is not set or is empty, which is the correct
44    /// default for distributed deployments (etcd/kubernetes backends). For local-only
45    /// workflows (`--discovery-backend file` or `mem`) this context-unaware default
46    /// may be incorrect — prefer [`DistributedRuntime::default_event_transport_kind`]
47    /// when you have access to a runtime, as it derives the correct default from the
48    /// configured discovery backend.
49    ///
50    /// Returns an error for unrecognised values.
51    pub fn from_env() -> Result<Self> {
52        match std::env::var(crate::config::environment_names::event_plane::DYN_EVENT_PLANE)
53            .as_deref()
54        {
55            Ok("nats") | Ok("") | Err(_) => Ok(Self::Nats),
56            Ok("zmq") => Ok(Self::Zmq),
57            Ok(other) => anyhow::bail!(
58                "Invalid DYN_EVENT_PLANE value '{}'. Valid values: 'nats', 'zmq'",
59                other
60            ),
61        }
62    }
63
64    /// Parse from environment variable, defaulting to NATS when the variable is unset.
65    ///
66    /// This default is suitable for distributed deployments. For local-only workflows
67    /// prefer [`DistributedRuntime::default_event_transport_kind`], which automatically
68    /// selects ZMQ when running with a `file` or `mem` discovery backend.
69    ///
70    /// Logs a warning if an invalid value is encountered.
71    pub fn from_env_or_default() -> Self {
72        Self::from_env().unwrap_or_else(|e| {
73            tracing::warn!("{e}, defaulting to NATS");
74            Self::Nats
75        })
76    }
77
78    /// Get the default codec for this transport kind.
79    /// NATS defaults to JSON, ZMQ defaults to MsgPack.
80    pub fn default_codec(&self) -> EventCodecKind {
81        match self {
82            Self::Nats => EventCodecKind::Json,
83            Self::Zmq => EventCodecKind::Msgpack,
84        }
85    }
86}
87
88/// Codec kind for event plane serialization.
89///
90/// This enum represents the serialization format for event envelopes and payloads.
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
92#[serde(rename_all = "snake_case")]
93pub enum EventCodecKind {
94    /// JSON codec - human-readable, good for debugging
95    Json,
96    /// MessagePack codec - compact binary format
97    Msgpack,
98}
99
100impl EventCodecKind {
101    /// Parse from environment variable `DYN_EVENT_PLANE_CODEC`.
102    /// Returns None if not set, allowing transport to select default.
103    /// Returns error for invalid values.
104    pub fn from_env() -> Result<Option<Self>> {
105        match std::env::var(crate::config::environment_names::event_plane::DYN_EVENT_PLANE_CODEC)
106            .as_deref()
107        {
108            Err(_) => Ok(None), // Not set
109            Ok("") => Ok(None), // Empty
110            Ok("json") => Ok(Some(Self::Json)),
111            Ok("msgpack") => Ok(Some(Self::Msgpack)),
112            Ok(other) => anyhow::bail!(
113                "Invalid DYN_EVENT_PLANE_CODEC value '{}'. Valid values: 'json', 'msgpack'",
114                other
115            ),
116        }
117    }
118
119    /// Parse from environment variable with transport-specific default.
120    /// Logs a warning if an invalid value is encountered.
121    pub fn from_env_or_transport_default(transport: EventTransportKind) -> Self {
122        Self::from_env()
123            .unwrap_or_else(|e| {
124                tracing::warn!(
125                    "{}, defaulting to {:?} for {:?}",
126                    e,
127                    transport.default_codec(),
128                    transport
129                );
130                None
131            })
132            .unwrap_or_else(|| transport.default_codec())
133    }
134}
135
136/// Transport configuration for event plane channels.
137///
138/// This enum carries both the transport kind and its connection configuration.
139/// Kept separate from `TransportType` (request plane) to distinguish event semantics.
140#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
141#[serde(tag = "kind", content = "config")]
142pub enum EventTransport {
143    /// NATS Core pub/sub - subject prefix for the channel
144    Nats {
145        /// Subject prefix (e.g., "namespace.dynamo.component.backend")
146        subject_prefix: String,
147    },
148    /// ZMQ pub/sub - endpoint address (direct mode)
149    Zmq {
150        /// ZMQ endpoint (e.g., "tcp://host:port")
151        endpoint: String,
152    },
153    /// ZMQ broker endpoints (broker mode) - for discovery of brokers
154    ZmqBroker {
155        /// XSUB endpoints (publishers connect here)
156        xsub_endpoints: Vec<String>,
157        /// XPUB endpoints (subscribers connect here)
158        xpub_endpoints: Vec<String>,
159    },
160}
161
162impl EventTransport {
163    /// Get the transport kind
164    pub fn kind(&self) -> EventTransportKind {
165        match self {
166            Self::Nats { .. } => EventTransportKind::Nats,
167            Self::Zmq { .. } | Self::ZmqBroker { .. } => EventTransportKind::Zmq,
168        }
169    }
170
171    /// Create a NATS transport with the given subject prefix
172    pub fn nats(subject_prefix: impl Into<String>) -> Self {
173        Self::Nats {
174            subject_prefix: subject_prefix.into(),
175        }
176    }
177
178    /// Create a ZMQ transport with the given endpoint
179    pub fn zmq(endpoint: impl Into<String>) -> Self {
180        Self::Zmq {
181            endpoint: endpoint.into(),
182        }
183    }
184
185    /// Get the subject prefix (NATS) or endpoint (ZMQ)
186    /// For ZmqBroker, returns the first XSUB endpoint
187    pub fn address(&self) -> &str {
188        match self {
189            Self::Nats { subject_prefix } => subject_prefix,
190            Self::Zmq { endpoint } => endpoint,
191            Self::ZmqBroker { xsub_endpoints, .. } => {
192                xsub_endpoints.first().map(|s| s.as_str()).unwrap_or("")
193            }
194        }
195    }
196}
197
198/// Query key for prefix-based discovery queries
199/// Supports hierarchical queries from all endpoints down to specific endpoints
200#[derive(Debug, Clone, PartialEq, Eq, Hash)]
201pub enum DiscoveryQuery {
202    /// Query all endpoints in the system
203    AllEndpoints,
204    /// Query all endpoints in a specific namespace
205    NamespacedEndpoints {
206        namespace: String,
207    },
208    /// Query all endpoints in a namespace/component
209    ComponentEndpoints {
210        namespace: String,
211        component: String,
212    },
213    /// Query a specific endpoint
214    Endpoint {
215        namespace: String,
216        component: String,
217        endpoint: String,
218    },
219    AllModels,
220    NamespacedModels {
221        namespace: String,
222    },
223    ComponentModels {
224        namespace: String,
225        component: String,
226    },
227    EndpointModels {
228        namespace: String,
229        component: String,
230        endpoint: String,
231    },
232    /// Unified event channel query with optional scope filters
233    EventChannels(EventChannelQuery),
234}
235
/// Single query type for event channels, narrowed by optional scope filters.
///
/// Leaving every field as `None` selects all event channels; each filter that
/// is set narrows the scope further.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventChannelQuery {
    /// Restrict to this namespace, if set.
    pub namespace: Option<String>,
    /// Restrict to this component, if set (only meaningful together with `namespace`).
    pub component: Option<String>,
    /// Restrict to this topic, if set (only meaningful together with `namespace`
    /// and `component`).
    pub topic: Option<String>,
}
246
247impl EventChannelQuery {
248    /// Query all event channels (no filters)
249    pub fn all() -> Self {
250        Self {
251            namespace: None,
252            component: None,
253            topic: None,
254        }
255    }
256
257    /// Query event channels in a specific namespace
258    pub fn namespace(namespace: impl Into<String>) -> Self {
259        Self {
260            namespace: Some(namespace.into()),
261            component: None,
262            topic: None,
263        }
264    }
265
266    /// Query event channels for a specific component
267    pub fn component(namespace: impl Into<String>, component: impl Into<String>) -> Self {
268        Self {
269            namespace: Some(namespace.into()),
270            component: Some(component.into()),
271            topic: None,
272        }
273    }
274
275    /// Query event channels for a specific topic
276    pub fn topic(
277        namespace: impl Into<String>,
278        component: impl Into<String>,
279        topic: impl Into<String>,
280    ) -> Self {
281        Self {
282            namespace: Some(namespace.into()),
283            component: Some(component.into()),
284            topic: Some(topic.into()),
285        }
286    }
287
288    /// Get the scope level (0=all, 1=namespace, 2=component, 3=topic)
289    pub fn scope_level(&self) -> u8 {
290        if self.topic.is_some() {
291            3
292        } else if self.component.is_some() {
293            2
294        } else if self.namespace.is_some() {
295            1
296        } else {
297            0
298        }
299    }
300}
301
302/// Specification for registering objects in the discovery plane
303/// Represents the input to the register() operation
304#[derive(Debug, Clone, PartialEq, Eq)]
305pub enum DiscoverySpec {
306    /// Endpoint specification for registration
307    Endpoint {
308        namespace: String,
309        component: String,
310        endpoint: String,
311        /// Transport type and routing information
312        transport: TransportType,
313        /// Optional execution device for this endpoint instance.
314        /// Used by hetero routing to distinguish CPU and CUDA workers.
315        device_type: Option<DeviceType>,
316    },
317    Model {
318        namespace: String,
319        component: String,
320        endpoint: String,
321        /// ModelDeploymentCard serialized as JSON
322        /// This allows lib/runtime to remain independent of lib/llm types
323        /// DiscoverySpec.from_model() and DiscoveryInstance.deserialize_model() are ergonomic helpers to create and deserialize the model card.
324        card_json: serde_json::Value,
325        /// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
326        /// Key format: {namespace}/{component}/{endpoint}/{instance_id}[/{model_suffix}]
327        model_suffix: Option<String>,
328    },
329    /// Event plane channel specification
330    /// Used for registering event publishers/subscribers for discovery
331    EventChannel {
332        namespace: String,
333        component: String,
334        /// Topic name for this channel (e.g., "kv-events", "kv-metrics")
335        topic: String,
336        /// Event transport type (NATS subject prefix or ZMQ endpoint)
337        transport: EventTransport,
338    },
339}
340
341impl DiscoverySpec {
342    /// Creates a Model discovery spec from a serializable type
343    /// The card will be serialized to JSON to avoid cross-crate dependencies
344    pub fn from_model<T>(
345        namespace: String,
346        component: String,
347        endpoint: String,
348        card: &T,
349    ) -> Result<Self>
350    where
351        T: Serialize,
352    {
353        Self::from_model_with_suffix(namespace, component, endpoint, card, None)
354    }
355
356    /// Creates a Model discovery spec with an optional suffix (e.g., for LoRA adapters)
357    /// The suffix is appended after the instance_id in the key path
358    pub fn from_model_with_suffix<T>(
359        namespace: String,
360        component: String,
361        endpoint: String,
362        card: &T,
363        model_suffix: Option<String>,
364    ) -> Result<Self>
365    where
366        T: Serialize,
367    {
368        let card_json = serde_json::to_value(card)?;
369        Ok(Self::Model {
370            namespace,
371            component,
372            endpoint,
373            card_json,
374            model_suffix,
375        })
376    }
377
378    /// Attaches an instance ID to create a DiscoveryInstance
379    pub fn with_instance_id(self, instance_id: u64) -> DiscoveryInstance {
380        match self {
381            Self::Endpoint {
382                namespace,
383                component,
384                endpoint,
385                transport,
386                device_type,
387            } => DiscoveryInstance::Endpoint(crate::component::Instance {
388                namespace,
389                component,
390                endpoint,
391                instance_id,
392                transport,
393                device_type,
394            }),
395            Self::Model {
396                namespace,
397                component,
398                endpoint,
399                card_json,
400                model_suffix,
401            } => DiscoveryInstance::Model {
402                namespace,
403                component,
404                endpoint,
405                instance_id,
406                card_json,
407                model_suffix,
408            },
409            Self::EventChannel {
410                namespace,
411                component,
412                topic,
413                transport,
414            } => DiscoveryInstance::EventChannel {
415                namespace,
416                component,
417                topic,
418                instance_id,
419                transport,
420            },
421        }
422    }
423}
424
425/// Registered instances in the discovery plane
426/// Represents objects that have been successfully registered with an instance ID
427#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
428#[serde(tag = "type")]
429pub enum DiscoveryInstance {
430    /// Registered endpoint instance - wraps the component::Instance directly
431    Endpoint(crate::component::Instance),
432    Model {
433        namespace: String,
434        component: String,
435        endpoint: String,
436        instance_id: u64,
437        /// ModelDeploymentCard serialized as JSON
438        /// This allows lib/runtime to remain independent of lib/llm types
439        card_json: serde_json::Value,
440        /// Optional suffix appended after instance_id in the key path (e.g., for LoRA adapters)
441        #[serde(default, skip_serializing_if = "Option::is_none")]
442        model_suffix: Option<String>,
443    },
444    /// Registered event channel instance for event plane pub/sub
445    EventChannel {
446        namespace: String,
447        component: String,
448        /// Topic name for this channel (e.g., "kv-events", "kv-metrics")
449        topic: String,
450        instance_id: u64,
451        /// Event transport type (NATS subject prefix or ZMQ endpoint)
452        transport: EventTransport,
453    },
454}
455
456impl DiscoveryInstance {
457    /// Returns the instance ID for this discovery instance
458    pub fn instance_id(&self) -> u64 {
459        match self {
460            Self::Endpoint(inst) => inst.instance_id,
461            Self::Model { instance_id, .. } => *instance_id,
462            Self::EventChannel { instance_id, .. } => *instance_id,
463        }
464    }
465
466    /// Deserializes the model JSON into the specified type T
467    /// Returns an error if this is not a Model instance or if deserialization fails
468    pub fn deserialize_model<T>(&self) -> Result<T>
469    where
470        T: for<'de> Deserialize<'de>,
471    {
472        match self {
473            Self::Model { card_json, .. } => Ok(serde_json::from_value(card_json.clone())?),
474            Self::Endpoint(_) => {
475                anyhow::bail!("Cannot deserialize model from Endpoint instance")
476            }
477            Self::EventChannel { .. } => {
478                anyhow::bail!("Cannot deserialize model from EventChannel instance")
479            }
480        }
481    }
482
483    /// Extracts the unique identifier for this discovery instance
484    /// Used for tracking, diffing, and removal events
485    pub fn id(&self) -> DiscoveryInstanceId {
486        match self {
487            Self::Endpoint(inst) => DiscoveryInstanceId::Endpoint(EndpointInstanceId {
488                namespace: inst.namespace.clone(),
489                component: inst.component.clone(),
490                endpoint: inst.endpoint.clone(),
491                instance_id: inst.instance_id,
492            }),
493            Self::Model {
494                namespace,
495                component,
496                endpoint,
497                instance_id,
498                model_suffix,
499                ..
500            } => DiscoveryInstanceId::Model(ModelCardInstanceId {
501                namespace: namespace.clone(),
502                component: component.clone(),
503                endpoint: endpoint.clone(),
504                instance_id: *instance_id,
505                model_suffix: model_suffix.clone(),
506            }),
507            Self::EventChannel {
508                namespace,
509                component,
510                topic,
511                instance_id,
512                ..
513            } => DiscoveryInstanceId::EventChannel(EventChannelInstanceId {
514                namespace: namespace.clone(),
515                component: component.clone(),
516                topic: topic.clone(),
517                instance_id: *instance_id,
518            }),
519        }
520    }
521}
522
523/// Unique identifier for an endpoint instance
524#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
525pub struct EndpointInstanceId {
526    pub namespace: String,
527    pub component: String,
528    pub endpoint: String,
529    pub instance_id: u64,
530}
531
532impl EndpointInstanceId {
533    /// Converts to a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}`
534    pub fn to_path(&self) -> String {
535        format!(
536            "{}/{}/{}/{:x}",
537            self.namespace, self.component, self.endpoint, self.instance_id
538        )
539    }
540
541    /// Parses from a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}`
542    pub fn from_path(path: &str) -> Result<Self> {
543        let parts: Vec<&str> = path.split('/').collect();
544        if parts.len() != 4 {
545            anyhow::bail!(
546                "Invalid EndpointInstanceId path: expected 4 parts, got {}",
547                parts.len()
548            );
549        }
550        Ok(Self {
551            namespace: parts[0].to_string(),
552            component: parts[1].to_string(),
553            endpoint: parts[2].to_string(),
554            instance_id: u64::from_str_radix(parts[3], 16)
555                .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?,
556        })
557    }
558}
559
560/// Unique identifier for a model card instance
561/// The combination of (namespace, component, endpoint, instance_id, model_suffix) uniquely identifies a model card
562#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
563pub struct ModelCardInstanceId {
564    pub namespace: String,
565    pub component: String,
566    pub endpoint: String,
567    pub instance_id: u64,
568    /// None for base models, Some(slug) for LoRA adapters
569    pub model_suffix: Option<String>,
570}
571
572/// Unique identifier for an event channel instance
573#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
574pub struct EventChannelInstanceId {
575    pub namespace: String,
576    pub component: String,
577    /// Topic name for this channel (e.g., "kv-events", "kv-metrics")
578    pub topic: String,
579    pub instance_id: u64,
580}
581
582impl EventChannelInstanceId {
583    /// Converts to a path string: `{namespace}/{component}/{topic}/{instance_id:x}`
584    pub fn to_path(&self) -> String {
585        format!(
586            "{}/{}/{}/{:x}",
587            self.namespace, self.component, self.topic, self.instance_id
588        )
589    }
590
591    /// Parses from a path string: `{namespace}/{component}/{topic}/{instance_id:x}`
592    pub fn from_path(path: &str) -> Result<Self> {
593        let parts: Vec<&str> = path.split('/').collect();
594        if parts.len() != 4 {
595            anyhow::bail!(
596                "Invalid EventChannelInstanceId path: expected 4 parts, got {}",
597                parts.len()
598            );
599        }
600        Ok(Self {
601            namespace: parts[0].to_string(),
602            component: parts[1].to_string(),
603            topic: parts[2].to_string(),
604            instance_id: u64::from_str_radix(parts[3], 16)
605                .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?,
606        })
607    }
608}
609
610impl ModelCardInstanceId {
611    /// Converts to a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`
612    pub fn to_path(&self) -> String {
613        match &self.model_suffix {
614            Some(suffix) => format!(
615                "{}/{}/{}/{:x}/{}",
616                self.namespace, self.component, self.endpoint, self.instance_id, suffix
617            ),
618            None => format!(
619                "{}/{}/{}/{:x}",
620                self.namespace, self.component, self.endpoint, self.instance_id
621            ),
622        }
623    }
624
625    /// Parses from a path string: `{namespace}/{component}/{endpoint}/{instance_id:x}[/{model_suffix}]`
626    pub fn from_path(path: &str) -> Result<Self> {
627        let parts: Vec<&str> = path.split('/').collect();
628        if parts.len() < 4 || parts.len() > 5 {
629            anyhow::bail!(
630                "Invalid ModelCardInstanceId path: expected 4 or 5 parts, got {}",
631                parts.len()
632            );
633        }
634        Ok(Self {
635            namespace: parts[0].to_string(),
636            component: parts[1].to_string(),
637            endpoint: parts[2].to_string(),
638            instance_id: u64::from_str_radix(parts[3], 16)
639                .map_err(|e| anyhow::anyhow!("Invalid instance_id hex: {}", e))?,
640            model_suffix: parts.get(4).map(|s| s.to_string()),
641        })
642    }
643}
644
645/// Union of instance identifiers for different discovery object types
646#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
647pub enum DiscoveryInstanceId {
648    Endpoint(EndpointInstanceId),
649    Model(ModelCardInstanceId),
650    EventChannel(EventChannelInstanceId),
651}
652
653impl DiscoveryInstanceId {
654    /// Returns the raw instance_id regardless of variant type
655    pub fn instance_id(&self) -> u64 {
656        match self {
657            Self::Endpoint(eid) => eid.instance_id,
658            Self::Model(mid) => mid.instance_id,
659            Self::EventChannel(ecid) => ecid.instance_id,
660        }
661    }
662
663    /// Extracts the EndpointInstanceId, returning an error if this is a Model or EventChannel variant
664    pub fn extract_endpoint_id(&self) -> Result<&EndpointInstanceId> {
665        match self {
666            Self::Endpoint(eid) => Ok(eid),
667            Self::Model(_) => anyhow::bail!("Expected Endpoint variant, got Model"),
668            Self::EventChannel(_) => anyhow::bail!("Expected Endpoint variant, got EventChannel"),
669        }
670    }
671
672    /// Extracts the ModelCardInstanceId, returning an error if this is an Endpoint or EventChannel variant
673    pub fn extract_model_id(&self) -> Result<&ModelCardInstanceId> {
674        match self {
675            Self::Model(mid) => Ok(mid),
676            Self::Endpoint(_) => anyhow::bail!("Expected Model variant, got Endpoint"),
677            Self::EventChannel(_) => anyhow::bail!("Expected Model variant, got EventChannel"),
678        }
679    }
680
681    /// Extracts the EventChannelInstanceId, returning an error if this is an Endpoint or Model variant
682    pub fn extract_event_channel_id(&self) -> Result<&EventChannelInstanceId> {
683        match self {
684            Self::EventChannel(ecid) => Ok(ecid),
685            Self::Endpoint(_) => anyhow::bail!("Expected EventChannel variant, got Endpoint"),
686            Self::Model(_) => anyhow::bail!("Expected EventChannel variant, got Model"),
687        }
688    }
689}
690
691/// Events emitted by the discovery watch stream
692#[derive(Debug, Clone, PartialEq, Eq)]
693pub enum DiscoveryEvent {
694    /// A new instance was added
695    Added(DiscoveryInstance),
696    /// An instance was removed (identified by its unique ID)
697    Removed(DiscoveryInstanceId),
698}
699
700/// Stream type for discovery events
701pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
702
/// The parts of a model card that decide whether two model registrations may
/// share an endpoint.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ModelRegistrationIdentity {
    /// Value of the card's `display_name` field.
    display_name: String,
    /// Value of the card's `source_path` field, when present as a string.
    source_path: Option<String>,
    /// Whether the registration is treated as a LoRA adapter.
    is_lora: bool,
}
709
710impl ModelRegistrationIdentity {
711    fn base_identity(&self) -> &str {
712        self.source_path.as_deref().unwrap_or(&self.display_name)
713    }
714
715    fn is_compatible_with(&self, other: &Self) -> bool {
716        if self.is_lora || other.is_lora {
717            self.base_identity() == other.base_identity()
718        } else {
719            self.display_name == other.display_name
720        }
721    }
722}
723
724fn extract_model_registration_identity(
725    card_json: &serde_json::Value,
726    model_suffix: Option<&str>,
727) -> Result<ModelRegistrationIdentity> {
728    let display_name = card_json
729        .get("display_name")
730        .and_then(serde_json::Value::as_str)
731        .map(str::to_owned)
732        .ok_or_else(|| {
733            anyhow::anyhow!("failed to deserialize model display_name from card_json")
734        })?;
735    let source_path = card_json
736        .get("source_path")
737        .and_then(serde_json::Value::as_str)
738        .map(str::to_owned);
739    let is_lora =
740        model_suffix.is_some() || card_json.get("lora").is_some_and(|value| !value.is_null());
741
742    Ok(ModelRegistrationIdentity {
743        display_name,
744        source_path,
745        is_lora,
746    })
747}
748
749fn find_conflicting_model_name(
750    instances: &[DiscoveryInstance],
751    requested_identity: &ModelRegistrationIdentity,
752) -> Result<Option<String>> {
753    for instance in instances {
754        if let DiscoveryInstance::Model {
755            card_json,
756            model_suffix,
757            ..
758        } = instance
759        {
760            let existing_identity =
761                extract_model_registration_identity(card_json, model_suffix.as_deref())?;
762            if !requested_identity.is_compatible_with(&existing_identity) {
763                return Ok(Some(existing_identity.display_name));
764            }
765        }
766    }
767
768    Ok(None)
769}
770
771/// Discovery trait for service discovery across different backends
772#[async_trait]
773pub trait Discovery: Send + Sync {
774    /// Returns a unique identifier for this worker (e.g lease id if using etcd or generated id for memory store)
775    /// Discovery objects created by this worker will be associated with this id.
776    fn instance_id(&self) -> u64;
777
778    /// Registers an object in the discovery plane with the instance id
779    async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
780        let (namespace, component, endpoint, requested_identity) = match &spec {
781            DiscoverySpec::Model {
782                namespace,
783                component,
784                endpoint,
785                card_json,
786                model_suffix,
787                ..
788            } => (
789                namespace.clone(),
790                component.clone(),
791                endpoint.clone(),
792                extract_model_registration_identity(card_json, model_suffix.as_deref())?,
793            ),
794            _ => return self.register_internal(spec).await,
795        };
796
797        let query = DiscoveryQuery::EndpointModels {
798            namespace: namespace.clone(),
799            component: component.clone(),
800            endpoint: endpoint.clone(),
801        };
802
803        if let Some(conflicting_name) =
804            find_conflicting_model_name(&self.list(query.clone()).await?, &requested_identity)?
805        {
806            let requested_name = &requested_identity.display_name;
807            anyhow::bail!(
808                "Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
809            );
810        }
811
812        let instance = self.register_internal(spec).await?;
813
814        if let Some(conflicting_name) =
815            find_conflicting_model_name(&self.list(query).await?, &requested_identity)?
816        {
817            let requested_name = &requested_identity.display_name;
818            if let Err(unregister_err) = self.unregister(instance.clone()).await {
819                return Err(anyhow::anyhow!(
820                    "Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
821                ))
822                .context(format!(
823                    "failed to roll back conflicting model registration for instance {instance_id}: {unregister_err}",
824                    instance_id = instance.instance_id()
825                ));
826            }
827
828            anyhow::bail!(
829                "Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
830            );
831        }
832
833        Ok(instance)
834    }
835
836    /// Backend-specific raw registration implementation.
837    async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
838
839    /// Unregisters an instance from the discovery plane
840    async fn unregister(&self, instance: DiscoveryInstance) -> Result<()>;
841
842    /// Returns a list of currently registered instances for the given discovery query
843    /// This is a one-time snapshot without watching for changes
844    async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>>;
845
846    /// Returns a stream of discovery events (Added/Removed) for the given discovery query
847    /// The optional cancellation token can be used to stop the watch stream
848    async fn list_and_watch(
849        &self,
850        query: DiscoveryQuery,
851        cancel_token: Option<CancellationToken>,
852    ) -> Result<DiscoveryStream>;
853
854    /// Clean up resources held by this discovery backend.
855    /// For KV store backends, this deletes owned registrations immediately rather than
856    /// waiting for TTL expiry. Default is a no-op for backends that don't need cleanup.
857    fn shutdown(&self) {}
858}