1use super::{
5 Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
6 DiscoverySpec, DiscoveryStream,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use tokio_util::sync::CancellationToken;
12
13#[derive(Clone, Default)]
15pub struct SharedMockRegistry {
16 instances: Arc<Mutex<Vec<DiscoveryInstance>>>,
17}
18
19impl SharedMockRegistry {
20 pub fn new() -> Self {
21 Self::default()
22 }
23}
24
25pub struct MockDiscovery {
28 instance_id: u64,
29 registry: SharedMockRegistry,
30}
31
32impl MockDiscovery {
33 pub fn new(instance_id: Option<u64>, registry: SharedMockRegistry) -> Self {
34 let instance_id = instance_id.unwrap_or_else(|| {
35 use std::sync::atomic::{AtomicU64, Ordering};
36 static COUNTER: AtomicU64 = AtomicU64::new(1);
37 COUNTER.fetch_add(1, Ordering::SeqCst)
38 });
39
40 Self {
41 instance_id,
42 registry,
43 }
44 }
45}
46
47fn matches_query(instance: &DiscoveryInstance, query: &DiscoveryQuery) -> bool {
49 match (instance, query) {
50 (DiscoveryInstance::Endpoint(_), DiscoveryQuery::AllEndpoints) => true,
52 (DiscoveryInstance::Endpoint(inst), DiscoveryQuery::NamespacedEndpoints { namespace }) => {
53 &inst.namespace == namespace
54 }
55 (
56 DiscoveryInstance::Endpoint(inst),
57 DiscoveryQuery::ComponentEndpoints {
58 namespace,
59 component,
60 },
61 ) => &inst.namespace == namespace && &inst.component == component,
62 (
63 DiscoveryInstance::Endpoint(inst),
64 DiscoveryQuery::Endpoint {
65 namespace,
66 component,
67 endpoint,
68 },
69 ) => {
70 &inst.namespace == namespace
71 && &inst.component == component
72 && &inst.endpoint == endpoint
73 }
74
75 (DiscoveryInstance::Model { .. }, DiscoveryQuery::AllModels) => true,
77 (
78 DiscoveryInstance::Model {
79 namespace: inst_ns, ..
80 },
81 DiscoveryQuery::NamespacedModels { namespace },
82 ) => inst_ns == namespace,
83 (
84 DiscoveryInstance::Model {
85 namespace: inst_ns,
86 component: inst_comp,
87 ..
88 },
89 DiscoveryQuery::ComponentModels {
90 namespace,
91 component,
92 },
93 ) => inst_ns == namespace && inst_comp == component,
94 (
95 DiscoveryInstance::Model {
96 namespace: inst_ns,
97 component: inst_comp,
98 endpoint: inst_ep,
99 ..
100 },
101 DiscoveryQuery::EndpointModels {
102 namespace,
103 component,
104 endpoint,
105 },
106 ) => inst_ns == namespace && inst_comp == component && inst_ep == endpoint,
107
108 (
110 DiscoveryInstance::EventChannel {
111 namespace: inst_ns,
112 component: inst_comp,
113 topic: inst_topic,
114 ..
115 },
116 DiscoveryQuery::EventChannels(query),
117 ) => {
118 query.namespace.as_ref().is_none_or(|ns| ns == inst_ns)
119 && query.component.as_ref().is_none_or(|c| c == inst_comp)
120 && query.topic.as_ref().is_none_or(|t| t == inst_topic)
121 }
122
123 (
125 DiscoveryInstance::Endpoint(_),
126 DiscoveryQuery::AllModels
127 | DiscoveryQuery::NamespacedModels { .. }
128 | DiscoveryQuery::ComponentModels { .. }
129 | DiscoveryQuery::EndpointModels { .. }
130 | DiscoveryQuery::EventChannels(_),
131 ) => false,
132 (
133 DiscoveryInstance::Model { .. },
134 DiscoveryQuery::AllEndpoints
135 | DiscoveryQuery::NamespacedEndpoints { .. }
136 | DiscoveryQuery::ComponentEndpoints { .. }
137 | DiscoveryQuery::Endpoint { .. }
138 | DiscoveryQuery::EventChannels(_),
139 ) => false,
140 (
141 DiscoveryInstance::EventChannel { .. },
142 DiscoveryQuery::AllEndpoints
143 | DiscoveryQuery::NamespacedEndpoints { .. }
144 | DiscoveryQuery::ComponentEndpoints { .. }
145 | DiscoveryQuery::Endpoint { .. }
146 | DiscoveryQuery::AllModels
147 | DiscoveryQuery::NamespacedModels { .. }
148 | DiscoveryQuery::ComponentModels { .. }
149 | DiscoveryQuery::EndpointModels { .. },
150 ) => false,
151 }
152}
153
154#[async_trait]
155impl Discovery for MockDiscovery {
156 fn instance_id(&self) -> u64 {
157 self.instance_id
158 }
159
160 async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
161 let instance = spec.with_instance_id(self.instance_id);
162
163 self.registry
164 .instances
165 .lock()
166 .unwrap()
167 .push(instance.clone());
168
169 Ok(instance)
170 }
171
172 async fn unregister(&self, instance: DiscoveryInstance) -> Result<()> {
173 let target_id = instance.id();
174
175 self.registry
176 .instances
177 .lock()
178 .unwrap()
179 .retain(|i| i.id() != target_id);
180
181 Ok(())
182 }
183
184 async fn list(&self, query: DiscoveryQuery) -> Result<Vec<DiscoveryInstance>> {
185 let instances = self.registry.instances.lock().unwrap();
186 Ok(instances
187 .iter()
188 .filter(|instance| matches_query(instance, &query))
189 .cloned()
190 .collect())
191 }
192
193 async fn list_and_watch(
194 &self,
195 query: DiscoveryQuery,
196 _cancel_token: Option<CancellationToken>,
197 ) -> Result<DiscoveryStream> {
198 use std::collections::HashSet;
199
200 let registry = self.registry.clone();
201
202 let stream = async_stream::stream! {
203 let mut known_instances: HashSet<DiscoveryInstanceId> = HashSet::new();
204
205 loop {
206 let current: Vec<_> = {
207 let instances = registry.instances.lock().unwrap();
208 instances
209 .iter()
210 .filter(|instance| matches_query(instance, &query))
211 .cloned()
212 .collect()
213 };
214
215 let current_ids: HashSet<DiscoveryInstanceId> = current.iter().map(|i| i.id()).collect();
216
217 for instance in current {
219 let id = instance.id();
220 if known_instances.insert(id) {
221 yield Ok(DiscoveryEvent::Added(instance));
222 }
223 }
224
225 for id in known_instances.difference(¤t_ids).cloned().collect::<Vec<_>>() {
227 known_instances.remove(&id);
228 yield Ok(DiscoveryEvent::Removed(id));
229 }
230
231 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
232 }
233 };
234
235 Ok(Box::pin(stream))
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Builds a model spec whose card carries only `display_name` and which
    /// has no model suffix.
    fn model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
    ) -> DiscoverySpec {
        let card_json = serde_json::json!({
            "display_name": model_name,
        });

        DiscoverySpec::Model {
            namespace: namespace.to_owned(),
            component: component.to_owned(),
            endpoint: endpoint.to_owned(),
            card_json,
            model_suffix: None,
        }
    }

    /// Builds a model spec for a LoRA adapter: the card names the adapter's
    /// source path and LoRA name, and the model suffix is the LoRA name.
    fn lora_model_spec(
        namespace: &str,
        component: &str,
        endpoint: &str,
        model_name: &str,
        source_path: &str,
        lora_name: &str,
    ) -> DiscoverySpec {
        let card_json = serde_json::json!({
            "display_name": model_name,
            "source_path": source_path,
            "lora": {
                "name": lora_name,
            },
        });

        DiscoverySpec::Model {
            namespace: namespace.to_owned(),
            component: component.to_owned(),
            endpoint: endpoint.to_owned(),
            card_json,
            model_suffix: Some(lora_name.to_owned()),
        }
    }

    /// Builds the spec for `base-model` (source `base-repo`) on
    /// `ns/comp/generate`, with no model suffix.
    fn base_model_spec() -> DiscoverySpec {
        DiscoverySpec::Model {
            namespace: "ns".to_owned(),
            component: "comp".to_owned(),
            endpoint: "generate".to_owned(),
            card_json: serde_json::json!({
                "display_name": "base-model",
                "source_path": "base-repo",
            }),
            model_suffix: None,
        }
    }

    /// Query selecting every model registered on `ns/comp/generate`.
    fn generate_models_query() -> DiscoveryQuery {
        DiscoveryQuery::EndpointModels {
            namespace: "ns".to_owned(),
            component: "comp".to_owned(),
            endpoint: "generate".to_owned(),
        }
    }

    /// Two clients with instance ids 1 and 2 that share one fresh registry.
    fn client_pair() -> (MockDiscovery, MockDiscovery) {
        let registry = SharedMockRegistry::new();
        let first = MockDiscovery::new(Some(1), registry.clone());
        let second = MockDiscovery::new(Some(2), registry);
        (first, second)
    }

    /// A watch opened before any registration reports each registration as
    /// `Added` and an unregistration as `Removed`.
    #[tokio::test]
    async fn test_mock_discovery_add_and_remove() {
        let (client1, client2) = client_pair();

        let spec = DiscoverySpec::Endpoint {
            namespace: "test-ns".to_owned(),
            component: "test-comp".to_owned(),
            endpoint: "test-ep".to_owned(),
            transport: crate::component::TransportType::Nats("test-subject".to_owned()),
            device_type: None,
        };

        let query = DiscoveryQuery::Endpoint {
            namespace: "test-ns".to_owned(),
            component: "test-comp".to_owned(),
            endpoint: "test-ep".to_owned(),
        };

        // The watch is opened while the registry is still empty.
        let mut stream = client1.list_and_watch(query, None).await.unwrap();

        // First registration, made through client 1.
        let instance1 = client1.register(spec.clone()).await.unwrap();

        let first_event = stream.next().await.unwrap().unwrap();
        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(added)) = first_event else {
            panic!("Expected Added event for instance-1");
        };
        assert_eq!(added.instance_id, 1);

        // Second registration, made through client 2 on the same registry.
        client2.register(spec).await.unwrap();

        let second_event = stream.next().await.unwrap().unwrap();
        let DiscoveryEvent::Added(DiscoveryInstance::Endpoint(added)) = second_event else {
            panic!("Expected Added event for instance-2");
        };
        assert_eq!(added.instance_id, 2);

        // Dropping the first registration must surface as a removal.
        client1.unregister(instance1).await.unwrap();

        let third_event = stream.next().await.unwrap().unwrap();
        let DiscoveryEvent::Removed(id) = third_event else {
            panic!("Expected Removed event for instance-1");
        };
        let endpoint_id = id.extract_endpoint_id().expect("Expected endpoint removal");
        assert_eq!(endpoint_id.instance_id, 1);
    }

    /// Two clients may register the same model name on the same endpoint.
    #[tokio::test]
    async fn register_allows_same_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();
        let spec = model_spec("ns", "comp", "generate", "model-a");

        discovery1.register(spec.clone()).await.unwrap();
        discovery2.register(spec).await.unwrap();

        let instances = discovery1.list(generate_models_query()).await.unwrap();
        assert_eq!(instances.len(), 2);
    }

    /// Registering a second, differently named model on an occupied endpoint
    /// fails and leaves the registry with the first model only.
    #[tokio::test]
    async fn register_rejects_different_model_name_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();

        discovery1
            .register(model_spec("ns", "comp", "generate", "model-a"))
            .await
            .unwrap();

        let conflict = model_spec("ns", "comp", "generate", "model-b");
        let err = discovery2.register(conflict).await.unwrap_err();

        let message = err.to_string();
        assert!(message.contains(
            "Cannot register model 'model-b' on endpoint 'ns/comp/generate': a different model 'model-a' is already registered there"
        ));

        let instances = discovery1.list(generate_models_query()).await.unwrap();
        assert_eq!(instances.len(), 1);
    }

    /// Differently named models on different endpoints do not conflict.
    #[tokio::test]
    async fn register_allows_different_model_names_on_different_endpoints() {
        let (discovery1, discovery2) = client_pair();

        let first = model_spec("ns", "comp", "generate-a", "model-a");
        discovery1.register(first).await.unwrap();

        let second = model_spec("ns", "comp", "generate-b", "model-b");
        discovery2.register(second).await.unwrap();
    }

    /// A LoRA adapter whose source path equals the base model's may share the
    /// base model's endpoint.
    #[tokio::test]
    async fn register_allows_lora_adapter_on_same_endpoint() {
        let (discovery1, discovery2) = client_pair();

        discovery1.register(base_model_spec()).await.unwrap();

        let adapter = lora_model_spec(
            "ns",
            "comp",
            "generate",
            "adapter-a",
            "base-repo",
            "adapter-a",
        );
        discovery2.register(adapter).await.unwrap();
    }

    /// A LoRA adapter whose source path differs from the registered base
    /// model's is rejected with the conflict message.
    #[tokio::test]
    async fn register_rejects_lora_adapter_for_different_base_model() {
        let (discovery1, discovery2) = client_pair();

        discovery1.register(base_model_spec()).await.unwrap();

        let adapter = lora_model_spec(
            "ns",
            "comp",
            "generate",
            "adapter-a",
            "other-base-repo",
            "adapter-a",
        );
        let err = discovery2.register(adapter).await.unwrap_err();

        let message = err.to_string();
        assert!(message.contains(
            "Cannot register model 'adapter-a' on endpoint 'ns/comp/generate': a different model 'base-model' is already registered there"
        ));
    }
}