1use std::fmt;
33
34use crate::{
35 config::HealthStatus,
36 distributed::RequestPlaneMode,
37 metrics::{MetricsHierarchy, MetricsRegistry, prometheus_names},
38 service::ServiceClient,
39 service::ServiceSet,
40};
41
42use super::{DistributedRuntime, Runtime, traits::*, transports::nats::Slug, utils::Duration};
43
44use crate::pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint};
45use crate::protocols::EndpointId;
46use async_nats::{
47 rustls::quic,
48 service::{Service, ServiceExt},
49};
50use dashmap::DashMap;
51use derive_builder::Builder;
52use derive_getters::Getters;
53use educe::Educe;
54use serde::{Deserialize, Serialize};
55use std::{collections::HashMap, hash::Hash, sync::Arc};
56use validator::{Validate, ValidationError};
57
58mod client;
59#[allow(clippy::module_inception)]
60mod component;
61mod endpoint;
62mod namespace;
63mod registry;
64pub mod service;
65
66pub(crate) use client::EndpointDiscoverySource;
67pub(crate) use client::RoutingInstances;
68pub(crate) use client::RoutingOccupancyState;
69pub(crate) use client::get_or_create_routing_occupancy_state;
70pub use client::{Client, RoutingInstanceCounts};
71pub use endpoint::build_transport_type;
72
/// Transport over which an [`Instance`] can be reached, carrying that transport's
/// address (the string returned by [`TransportType::address`]).
///
/// Serialized with serde's default (externally tagged) enum representation.
/// Variant tags are snake_case, except `Nats`, which is explicitly renamed to
/// `"nats_tcp"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum TransportType {
    /// NATS transport; serialized under the tag `"nats_tcp"`.
    #[serde(rename = "nats_tcp")]
    Nats(String),
    /// TCP transport; serialized under the tag `"tcp"`.
    Tcp(String),
}
80
81impl TransportType {
82 pub fn address(&self) -> &str {
83 match self {
84 TransportType::Nats(address) | TransportType::Tcp(address) => address,
85 }
86 }
87}
88
/// Kind of device an [`Instance`] reports (see `Instance::device_type`).
///
/// Serialized as the snake_case strings `"cpu"` and `"cuda"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum DeviceType {
    /// Serialized as `"cpu"`.
    Cpu,
    /// Serialized as `"cuda"`.
    Cuda,
}
95
/// Mutable state behind [`Registry`]: async-nats `Service` handles keyed by string.
// NOTE(review): the key is presumably the component's service name (cf.
// `Component::service_name`) — confirm against the code that inserts entries.
#[derive(Default)]
pub struct RegistryInner {
    pub(crate) services: HashMap<String, Service>,
}

/// Cheaply cloneable handle to a shared [`RegistryInner`].
///
/// All clones share the same state through an `Arc`. Access goes through a
/// `tokio::sync::Mutex`, which is an async-aware lock.
#[derive(Clone)]
pub struct Registry {
    pub(crate) inner: Arc<tokio::sync::Mutex<RegistryInner>>,
}
105
/// One discovered instance of an endpoint.
///
/// Holds the instance's `namespace/component/endpoint` path, its numeric id and
/// the transport used to reach it. Equality is derived over every field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Instance {
    /// Component name.
    pub component: String,
    /// Endpoint name within the component.
    pub endpoint: String,
    /// Namespace name.
    // NOTE(review): for nested namespaces this is presumably the dotted form
    // produced by `Namespace::name` — TODO confirm against discovery.
    pub namespace: String,
    /// Numeric id of this instance; also returned by [`Instance::id`].
    pub instance_id: u64,
    /// Transport (and address) used to reach this instance.
    pub transport: TransportType,
    /// Optional device kind.
    ///
    /// Omitted from serialized output when `None`, and defaults to `None` when the
    /// field is absent during deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub device_type: Option<DeviceType>,
}
116
117impl Instance {
118 pub fn id(&self) -> u64 {
119 self.instance_id
120 }
121
122 pub fn endpoint_id(&self) -> EndpointId {
123 EndpointId {
124 namespace: self.namespace.clone(),
125 component: self.component.clone(),
126 name: self.endpoint.clone(),
127 }
128 }
129
130 pub fn endpoint_instance_id(&self) -> crate::discovery::EndpointInstanceId {
131 crate::discovery::EndpointInstanceId {
132 namespace: self.namespace.clone(),
133 component: self.component.clone(),
134 endpoint: self.endpoint.clone(),
135 instance_id: self.instance_id,
136 }
137 }
138}
139
140impl fmt::Display for Instance {
141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142 write!(
143 f,
144 "{}/{}/{}/{}",
145 self.namespace, self.component, self.endpoint, self.instance_id
146 )
147 }
148}
149
150impl std::cmp::Ord for Instance {
152 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
153 self.to_string().cmp(&other.to_string())
154 }
155}
156
157impl PartialOrd for Instance {
158 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
159 Some(self.cmp(other))
161 }
162}
163
/// A named component inside a [`Namespace`].
///
/// Built through the derived `ComponentBuilder`. Its public `build` wraps the
/// private `build_internal` and performs NATS service registration when the
/// request plane is NATS. Equality and hashing consider only the namespace name
/// and the component name.
#[derive(Educe, Builder, Clone, Validate)]
#[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct Component {
    /// Shared handle to the distributed runtime.
    ///
    /// Excluded from `Debug` output. Its builder setter is private;
    /// `ComponentBuilder::from_runtime` supplies it.
    #[builder(private)]
    #[educe(Debug(ignore))]
    drt: Arc<DistributedRuntime>,

    /// Component name.
    ///
    /// Declares the `validate_allowed_chars` rule (lowercase ASCII letters,
    /// digits, `-`, `_`). Note that `ComponentBuilder::build` does not call
    /// `validate()` itself.
    #[builder(setter(into))]
    #[validate(custom(function = "validate_allowed_chars"))]
    name: String,

    /// Key/value labels; empty by default.
    #[builder(default = "Vec::new()")]
    labels: Vec<(String, String)>,

    /// Namespace this component belongs to.
    #[builder(setter(into))]
    namespace: Namespace,

    /// Metrics registry for this component; a fresh registry by default.
    #[builder(default = "crate::MetricsRegistry::new()")]
    metrics_registry: crate::MetricsRegistry,
}
195
196impl Hash for Component {
197 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
198 self.namespace.name().hash(state);
199 self.name.hash(state);
200 }
201}
202
203impl PartialEq for Component {
204 fn eq(&self, other: &Self) -> bool {
205 self.namespace.name() == other.namespace.name() && self.name == other.name
206 }
207}
208
209impl Eq for Component {}
210
211impl std::fmt::Display for Component {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 write!(f, "{}.{}", self.namespace.name(), self.name)
214 }
215}
216
217impl DistributedRuntimeProvider for Component {
218 fn drt(&self) -> &DistributedRuntime {
219 &self.drt
220 }
221}
222
223impl RuntimeProvider for Component {
224 fn rt(&self) -> &Runtime {
225 self.drt.rt()
226 }
227}
228
229impl MetricsHierarchy for Component {
230 fn basename(&self) -> String {
231 self.name.clone()
232 }
233
234 fn parent_hierarchies(&self) -> Vec<&dyn MetricsHierarchy> {
235 let mut parents = vec![];
236
237 parents.extend(self.namespace.parent_hierarchies());
239
240 parents.push(&self.namespace as &dyn MetricsHierarchy);
242
243 parents
244 }
245
246 fn get_metrics_registry(&self) -> &MetricsRegistry {
247 &self.metrics_registry
248 }
249
250 fn connection_id(&self) -> Option<u64> {
251 Some(self.drt.connection_id())
252 }
253}
254
255impl Component {
256 pub fn service_name(&self) -> String {
257 let service_name = format!("{}_{}", self.namespace.name(), self.name);
258 Slug::slugify(&service_name).to_string()
259 }
260
261 pub fn namespace(&self) -> &Namespace {
262 &self.namespace
263 }
264
265 pub fn name(&self) -> &str {
266 &self.name
267 }
268
269 pub fn labels(&self) -> &[(String, String)] {
270 &self.labels
271 }
272
273 pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
274 let endpoint = Endpoint {
275 component: self.clone(),
276 name: endpoint.into(),
277 labels: Vec::new(),
278 metrics_registry: crate::MetricsRegistry::new(),
279 };
280 self.get_metrics_registry()
282 .add_child_registry(endpoint.get_metrics_registry());
283 endpoint
284 }
285
286 pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
287 let discovery = self.drt.discovery();
288
289 let discovery_query = crate::discovery::DiscoveryQuery::ComponentEndpoints {
290 namespace: self.namespace.name(),
291 component: self.name.clone(),
292 };
293
294 let discovery_instances = discovery.list(discovery_query).await?;
295
296 let mut instances: Vec<Instance> = discovery_instances
298 .into_iter()
299 .filter_map(|di| match di {
300 crate::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
301 _ => None, })
303 .collect();
304
305 instances.sort();
306 Ok(instances)
307 }
308}
309
310impl ComponentBuilder {
311 pub fn from_runtime(drt: Arc<DistributedRuntime>) -> Self {
312 Self::default().drt(drt)
313 }
314
315 pub fn build(self) -> Result<Component, anyhow::Error> {
316 let component = self.build_internal()?;
317 let drt = component.drt();
321 if drt.request_plane().is_nats() {
322 let mut rx = drt.register_nats_service(component.clone());
323 let result = tokio::task::block_in_place(|| rx.blocking_recv());
328 match result {
329 Some(Ok(())) => {
330 tracing::debug!(
331 component = component.service_name(),
332 "NATS service registration completed"
333 );
334 }
335 Some(Err(e)) => {
336 return Err(anyhow::anyhow!(
337 "NATS service registration failed for component '{}': {}",
338 component.service_name(),
339 e
340 ));
341 }
342 None => {
343 return Err(anyhow::anyhow!(
344 "NATS service registration channel closed unexpectedly for component '{}'",
345 component.service_name()
346 ));
347 }
348 }
349 }
350 Ok(component)
351 }
352}
353
/// A named endpoint belonging to a [`Component`].
///
/// Created by `Component::endpoint`. Equality and hashing consider only the
/// owning component and the endpoint name.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Component that owns this endpoint.
    component: Component,

    /// Endpoint name.
    name: String,

    /// Key/value labels; `Component::endpoint` initializes this to empty.
    labels: Vec<(String, String)>,

    /// Metrics registry for this endpoint.
    ///
    /// `Component::endpoint` registers it as a child of the component's registry.
    metrics_registry: crate::MetricsRegistry,
}
368
369impl Hash for Endpoint {
370 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
371 self.component.hash(state);
372 self.name.hash(state);
373 }
374}
375
376impl PartialEq for Endpoint {
377 fn eq(&self, other: &Self) -> bool {
378 self.component == other.component && self.name == other.name
379 }
380}
381
382impl Eq for Endpoint {}
383
384impl DistributedRuntimeProvider for Endpoint {
385 fn drt(&self) -> &DistributedRuntime {
386 self.component.drt()
387 }
388}
389
390impl RuntimeProvider for Endpoint {
391 fn rt(&self) -> &Runtime {
392 self.component.rt()
393 }
394}
395
396impl MetricsHierarchy for Endpoint {
397 fn basename(&self) -> String {
398 self.name.clone()
399 }
400
401 fn parent_hierarchies(&self) -> Vec<&dyn MetricsHierarchy> {
402 let mut parents = vec![];
403
404 parents.extend(self.component.parent_hierarchies());
406
407 parents.push(&self.component as &dyn MetricsHierarchy);
409
410 parents
411 }
412
413 fn get_metrics_registry(&self) -> &MetricsRegistry {
414 &self.metrics_registry
415 }
416
417 fn connection_id(&self) -> Option<u64> {
418 Some(self.component.drt().connection_id())
419 }
420}
421
422impl Endpoint {
423 pub fn id(&self) -> EndpointId {
424 EndpointId {
425 namespace: self.component.namespace().name().to_string(),
426 component: self.component.name().to_string(),
427 name: self.name().to_string(),
428 }
429 }
430
431 pub fn name(&self) -> &str {
432 &self.name
433 }
434
435 pub fn component(&self) -> &Component {
436 &self.component
437 }
438
439 pub async fn client(&self) -> anyhow::Result<client::Client> {
440 client::Client::new(self.clone()).await
441 }
442
443 pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder {
444 endpoint::EndpointConfigBuilder::from_endpoint(self.clone())
445 }
446}
447
/// A (possibly nested) namespace that groups components.
///
/// A nested namespace keeps a reference to its parent, and `Namespace::name`
/// renders the full dotted path. Built through the derived `NamespaceBuilder`.
#[derive(Builder, Clone, Validate)]
#[builder(pattern = "owned")]
pub struct Namespace {
    /// Shared handle to the distributed runtime (private builder setter).
    #[builder(private)]
    runtime: Arc<DistributedRuntime>,

    /// This namespace's own, unqualified name segment.
    ///
    /// Declares the `validate_allowed_chars` rule; the derived builder does not
    /// run that validation.
    #[validate(custom(function = "validate_allowed_chars"))]
    name: String,

    /// Parent namespace when this one is nested; `None` by default.
    #[builder(default = "None")]
    parent: Option<Arc<Namespace>>,

    /// Key/value labels; empty by default.
    #[builder(default = "Vec::new()")]
    labels: Vec<(String, String)>,

    /// Metrics registry for this namespace; a fresh registry by default.
    #[builder(default = "crate::MetricsRegistry::new()")]
    metrics_registry: crate::MetricsRegistry,

    /// Components created through `Namespace::component`, keyed by component name.
    ///
    /// Held behind an `Arc` so that clones of this `Namespace` share the cache.
    #[builder(default = "Arc::new(DashMap::new())")]
    component_cache: Arc<DashMap<String, Component>>,
}
475
476impl DistributedRuntimeProvider for Namespace {
477 fn drt(&self) -> &DistributedRuntime {
478 &self.runtime
479 }
480}
481
482impl std::fmt::Debug for Namespace {
483 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484 write!(
485 f,
486 "Namespace {{ name: {}; parent: {:?} }}",
487 self.name, self.parent
488 )
489 }
490}
491
492impl RuntimeProvider for Namespace {
493 fn rt(&self) -> &Runtime {
494 self.runtime.rt()
495 }
496}
497
498impl std::fmt::Display for Namespace {
499 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500 write!(f, "{}", self.name)
501 }
502}
503
504impl Namespace {
505 pub(crate) fn new(runtime: DistributedRuntime, name: String) -> anyhow::Result<Self> {
506 let ns = NamespaceBuilder::default()
507 .runtime(Arc::new(runtime))
508 .name(name)
509 .build()?;
510 ns.drt()
512 .get_metrics_registry()
513 .add_child_registry(ns.get_metrics_registry());
514 Ok(ns)
515 }
516
517 pub fn component(&self, name: impl Into<String>) -> anyhow::Result<Component> {
523 let name = name.into();
524
525 if let Some(cached) = self.component_cache.get(&name) {
528 return Ok(cached.value().clone());
529 }
530
531 let component = ComponentBuilder::from_runtime(self.runtime.clone())
533 .name(&name)
534 .namespace(self.clone())
535 .build()?;
536
537 self.get_metrics_registry()
539 .add_child_registry(component.get_metrics_registry());
540
541 self.component_cache.insert(name, component.clone());
545
546 Ok(component)
547 }
548
549 pub fn namespace(&self, name: impl Into<String>) -> anyhow::Result<Namespace> {
551 let child = NamespaceBuilder::default()
552 .runtime(self.runtime.clone())
553 .name(name.into())
554 .parent(Some(Arc::new(self.clone())))
555 .build()?;
556 self.get_metrics_registry()
558 .add_child_registry(child.get_metrics_registry());
559 Ok(child)
560 }
561
562 pub fn name(&self) -> String {
563 match &self.parent {
564 Some(parent) => format!("{}.{}", parent.name(), self.name),
565 None => self.name.clone(),
566 }
567 }
568}
569
570fn validate_allowed_chars(input: &str) -> Result<(), ValidationError> {
572 let regex = regex::Regex::new(r"^[a-z0-9-_]+$").unwrap();
574
575 if regex.is_match(input) {
576 Ok(())
577 } else {
578 Err(ValidationError::new("invalid_characters"))
579 }
580}