1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use awaken_contract::contract::executor::LlmExecutor;
5use awaken_contract::registry_spec::{AgentSpec, ModelBindingSpec};
6use serde::Serialize;
7
8#[cfg(feature = "a2a")]
9use super::MapBackendRegistry;
10use super::diagnostics::{RegistryValidationError, validate_registry_set};
11use super::memory::{
12 MapAgentSpecRegistry, MapModelRegistry, MapPluginSource, MapProviderRegistry, MapToolRegistry,
13};
14use super::snapshot::RegistryHandle;
15#[cfg(feature = "a2a")]
16use super::traits::BackendRegistry;
17use super::traits::{ModelBinding, RegistrySet};
18
/// How `RegistryHandle::remove_provider` treats model bindings that still
/// reference the provider being removed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderRemovalPolicy {
    /// Refuse the removal while any model binding references the provider.
    BlockIfReferenced,
    /// Remove the provider together with the model bindings that reference it,
    /// but refuse if any agent without an `endpoint` uses one of those bindings.
    CascadeUnusedModelBindings,
}
24
/// Dry-run report of what removing a provider would affect.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProviderRemovalPreview {
    /// Provider the preview was computed for.
    pub provider_id: String,
    /// Ids of model bindings whose `provider_id` is this provider
    /// (sorted and deduplicated when built via `ProviderRemovalPreview::new`).
    pub model_ids: Vec<String>,
    /// Ids of agents without an `endpoint` whose `model_id` is one of
    /// `model_ids` (sorted and deduplicated when built via `new`).
    pub agent_ids: Vec<String>,
    /// Whether `ProviderRemovalPolicy::BlockIfReferenced` would allow the
    /// removal; `new` sets this to `model_ids.is_empty()`.
    pub block_if_referenced_allowed: bool,
    /// Whether `ProviderRemovalPolicy::CascadeUnusedModelBindings` would allow
    /// the removal; `new` sets this to `agent_ids.is_empty()`.
    pub cascade_unused_model_bindings_allowed: bool,
}
33
34impl ProviderRemovalPreview {
35 pub fn new(
36 provider_id: impl Into<String>,
37 mut model_ids: Vec<String>,
38 mut agent_ids: Vec<String>,
39 ) -> Self {
40 model_ids.sort();
41 model_ids.dedup();
42 agent_ids.sort();
43 agent_ids.dedup();
44 Self {
45 provider_id: provider_id.into(),
46 block_if_referenced_allowed: model_ids.is_empty(),
47 cascade_unused_model_bindings_allowed: agent_ids.is_empty(),
48 model_ids,
49 agent_ids,
50 }
51 }
52}
53
/// Outcome of a successful `RegistryHandle::remove_provider` call.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProviderRemovalImpact {
    /// Provider that was removed.
    pub provider_id: String,
    /// Model bindings removed together with the provider.
    pub removed_model_ids: Vec<String>,
    /// Agents that referenced the removed bindings, as reported by the removal
    /// preview. NOTE(review): with the current `remove_provider` policy checks
    /// this is always empty on success.
    pub affected_agent_ids: Vec<String>,
}
60
61#[derive(Debug, thiserror::Error)]
62pub enum RegistryUpdateError {
63 #[error("provider already registered: {0}")]
64 ProviderAlreadyExists(String),
65 #[error("provider not found: {0}")]
66 ProviderNotFound(String),
67 #[error(
68 "provider '{provider_id}' is still referenced by models {model_ids:?} and agents {agent_ids:?}"
69 )]
70 ProviderInUse {
71 provider_id: String,
72 model_ids: Vec<String>,
73 agent_ids: Vec<String>,
74 },
75 #[error("registry build failed: {0}")]
76 Build(String),
77 #[error("{0}")]
78 Validation(#[from] RegistryValidationError),
79}
80
/// Replacement contents for the agent, model, and provider registries,
/// consumed by `rebuild_agent_model_provider_registries`.
pub struct RuntimeRegistryUpdate {
    /// Provider executors keyed by provider id.
    pub providers: HashMap<String, Arc<dyn LlmExecutor>>,
    /// Model bindings; each is registered under its own `id`.
    pub models: Vec<ModelBindingSpec>,
    /// Agent specs to register.
    pub agents: Vec<AgentSpec>,
}
86
87impl RegistryHandle {
88 pub fn preview_remove_provider(
89 &self,
90 id: &str,
91 ) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
92 let snapshot = self.snapshot();
93 preview_provider_removal(snapshot.registries(), id)
94 }
95
96 pub fn register_provider(
97 &self,
98 id: impl Into<String>,
99 executor: Arc<dyn LlmExecutor>,
100 ) -> Result<u64, RegistryUpdateError> {
101 let id = id.into();
102 self.update(|registries| {
103 let mut draft = RegistrySetDraft::from_set(registries)?;
104 if draft.providers.contains_key(&id) {
105 return Err(RegistryUpdateError::ProviderAlreadyExists(id));
106 }
107 draft
108 .providers
109 .register_provider(id, executor)
110 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
111 draft.into_validated_set()
112 })
113 }
114
115 pub fn replace_provider(
116 &self,
117 id: impl Into<String>,
118 executor: Arc<dyn LlmExecutor>,
119 ) -> Result<u64, RegistryUpdateError> {
120 let id = id.into();
121 self.update(|registries| {
122 let mut draft = RegistrySetDraft::from_set(registries)?;
123 if !draft.providers.contains_key(&id) {
124 return Err(RegistryUpdateError::ProviderNotFound(id));
125 }
126 draft.providers.replace_provider(id, executor);
127 draft.into_validated_set()
128 })
129 }
130
131 pub fn remove_provider(
132 &self,
133 id: &str,
134 policy: ProviderRemovalPolicy,
135 ) -> Result<ProviderRemovalImpact, RegistryUpdateError> {
136 let mut impact = None;
137 self.update(|registries| {
138 let mut draft = RegistrySetDraft::from_set(registries)?;
139 if !draft.providers.contains_key(id) {
140 return Err(RegistryUpdateError::ProviderNotFound(id.to_string()));
141 }
142
143 let preview = preview_provider_removal_from_draft(&draft, id)?;
144
145 match policy {
146 ProviderRemovalPolicy::BlockIfReferenced if !preview.model_ids.is_empty() => {
147 return Err(RegistryUpdateError::ProviderInUse {
148 provider_id: id.to_string(),
149 model_ids: preview.model_ids,
150 agent_ids: preview.agent_ids,
151 });
152 }
153 ProviderRemovalPolicy::CascadeUnusedModelBindings
154 if !preview.agent_ids.is_empty() =>
155 {
156 return Err(RegistryUpdateError::ProviderInUse {
157 provider_id: id.to_string(),
158 model_ids: preview.model_ids,
159 agent_ids: preview.agent_ids,
160 });
161 }
162 _ => {}
163 }
164
165 for model_id in &preview.model_ids {
166 draft.models.remove(model_id);
167 }
168 draft.providers.remove_provider(id);
169
170 impact = Some(ProviderRemovalImpact {
171 provider_id: preview.provider_id,
172 removed_model_ids: preview.model_ids,
173 affected_agent_ids: preview.agent_ids,
174 });
175 draft.into_validated_set()
176 })?;
177 impact.ok_or_else(|| RegistryUpdateError::Build("provider removal did not run".into()))
178 }
179}
180
181pub fn preview_provider_removal(
182 registries: &RegistrySet,
183 id: &str,
184) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
185 if registries.providers.get_provider(id).is_none() {
186 return Err(RegistryUpdateError::ProviderNotFound(id.to_string()));
187 }
188 let model_ids = provider_model_ids_from_set(registries, id);
189 let agent_ids = agents_using_models_from_set(registries, &model_ids);
190 Ok(ProviderRemovalPreview::new(id, model_ids, agent_ids))
191}
192
193pub fn rebuild_agent_model_provider_registries(
194 base: &RegistrySet,
195 update: RuntimeRegistryUpdate,
196) -> Result<RegistrySet, RegistryUpdateError> {
197 let mut draft = RegistrySetDraft::from_set(base)?;
198
199 draft.providers = MapProviderRegistry::new();
200 for (id, executor) in update.providers {
201 draft
202 .providers
203 .register_provider(id, executor)
204 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
205 }
206
207 draft.models = MapModelRegistry::new();
208 for model in update.models {
209 draft
210 .models
211 .register_model(model.id.clone(), ModelBinding::from(&model))
212 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
213 }
214
215 draft.agents = MapAgentSpecRegistry::new();
216 for agent in update.agents {
217 draft
218 .agents
219 .register_spec(agent)
220 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
221 }
222
223 let registries = draft.into_set();
224 validate_registry_set(®istries)?;
225 Ok(registries)
226}
227
/// Mutable, map-backed copy of a [`RegistrySet`]. A change is staged here, then
/// frozen back into a `RegistrySet` (optionally validated).
struct RegistrySetDraft {
    agents: MapAgentSpecRegistry,
    tools: MapToolRegistry,
    models: MapModelRegistry,
    providers: MapProviderRegistry,
    plugins: MapPluginSource,
    // Only present when the `a2a` feature is enabled.
    #[cfg(feature = "a2a")]
    backends: MapBackendRegistry,
}
237
238impl RegistrySetDraft {
239 fn from_set(set: &RegistrySet) -> Result<Self, RegistryUpdateError> {
240 let mut agents = MapAgentSpecRegistry::new();
241 for id in set.agents.agent_ids() {
242 if let Some(agent) = set.agents.get_agent(&id) {
243 agents
244 .register(id, agent, |msg| {
245 crate::builder::BuildError::AgentRegistryConflict(format!("agent {msg}"))
246 })
247 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
248 }
249 }
250
251 let mut tools = MapToolRegistry::new();
252 for id in set.tools.tool_ids() {
253 if let Some(tool) = set.tools.get_tool(&id) {
254 tools
255 .register_tool(id, tool)
256 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
257 }
258 }
259
260 let mut models = MapModelRegistry::new();
261 for id in set.models.model_ids() {
262 if let Some(model) = set.models.get_model(&id) {
263 models
264 .register_model(id, model)
265 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
266 }
267 }
268
269 let mut providers = MapProviderRegistry::new();
270 for id in set.providers.provider_ids() {
271 if let Some(provider) = set.providers.get_provider(&id) {
272 providers
273 .register_provider(id, provider)
274 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
275 }
276 }
277
278 let mut plugins = MapPluginSource::new();
279 for id in set.plugins.plugin_ids() {
280 if let Some(plugin) = set.plugins.get_plugin(&id) {
281 plugins
282 .register_plugin(id, plugin)
283 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
284 }
285 }
286
287 #[cfg(feature = "a2a")]
288 let mut backends = MapBackendRegistry::new();
289 #[cfg(feature = "a2a")]
290 for id in set.backends.backend_ids() {
291 if let Some(factory) = set.backends.get_backend_factory(&id) {
292 backends
293 .register_backend_factory(factory)
294 .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
295 }
296 }
297
298 Ok(Self {
299 agents,
300 tools,
301 models,
302 providers,
303 plugins,
304 #[cfg(feature = "a2a")]
305 backends,
306 })
307 }
308
309 fn into_set(self) -> RegistrySet {
310 RegistrySet {
311 agents: Arc::new(self.agents),
312 tools: Arc::new(self.tools),
313 models: Arc::new(self.models),
314 providers: Arc::new(self.providers),
315 plugins: Arc::new(self.plugins),
316 #[cfg(feature = "a2a")]
317 backends: Arc::new(self.backends) as Arc<dyn BackendRegistry>,
318 }
319 }
320
321 fn into_validated_set(self) -> Result<RegistrySet, RegistryUpdateError> {
322 let registries = self.into_set();
323 validate_registry_set(®istries)?;
324 Ok(registries)
325 }
326}
327
328fn provider_model_ids(models: &MapModelRegistry, provider_id: &str) -> Vec<String> {
329 models
330 .ids()
331 .into_iter()
332 .filter(|model_id| {
333 models
334 .get(model_id)
335 .is_some_and(|model| model.provider_id == provider_id)
336 })
337 .collect()
338}
339
340fn preview_provider_removal_from_draft(
341 draft: &RegistrySetDraft,
342 provider_id: &str,
343) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
344 if !draft.providers.contains_key(provider_id) {
345 return Err(RegistryUpdateError::ProviderNotFound(
346 provider_id.to_string(),
347 ));
348 }
349 let model_ids = provider_model_ids(&draft.models, provider_id);
350 let agent_ids = agents_using_models(&draft.agents, &model_ids);
351 Ok(ProviderRemovalPreview::new(
352 provider_id,
353 model_ids,
354 agent_ids,
355 ))
356}
357
358fn provider_model_ids_from_set(registries: &RegistrySet, provider_id: &str) -> Vec<String> {
359 registries
360 .models
361 .model_ids()
362 .into_iter()
363 .filter(|model_id| {
364 registries
365 .models
366 .get_model(model_id)
367 .is_some_and(|model| model.provider_id == provider_id)
368 })
369 .collect()
370}
371
372fn agents_using_models(agents: &MapAgentSpecRegistry, model_ids: &[String]) -> Vec<String> {
373 let model_ids: HashSet<_> = model_ids.iter().map(String::as_str).collect();
374 agents
375 .ids()
376 .into_iter()
377 .filter(|agent_id| {
378 agents.get(agent_id).is_some_and(|agent| {
379 agent.endpoint.is_none() && model_ids.contains(agent.model_id.as_str())
380 })
381 })
382 .collect()
383}
384
385fn agents_using_models_from_set(registries: &RegistrySet, model_ids: &[String]) -> Vec<String> {
386 let model_ids: HashSet<_> = model_ids.iter().map(String::as_str).collect();
387 registries
388 .agents
389 .agent_ids()
390 .into_iter()
391 .filter(|agent_id| {
392 registries.agents.get_agent(agent_id).is_some_and(|agent| {
393 agent.endpoint.is_none() && model_ids.contains(agent.model_id.as_str())
394 })
395 })
396 .collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use awaken_contract::contract::executor::{InferenceExecutionError, InferenceRequest};
    use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};

    /// Executor that immediately returns an empty result ending with
    /// `StopReason::EndTurn`; used only as a value to register as a provider.
    struct StubExecutor;

    #[async_trait]
    impl LlmExecutor for StubExecutor {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Ok(StreamResult {
                content: vec![],
                tool_calls: vec![],
                usage: Some(TokenUsage::default()),
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "stub"
        }
    }

    /// Returns a fresh `StubExecutor` as a shareable `LlmExecutor`.
    fn executor() -> Arc<dyn LlmExecutor> {
        Arc::new(StubExecutor)
    }

    /// Fixture: agent `a` uses model `m`, which is bound to provider `p`.
    /// Tools and plugins (and backends with `a2a`) are empty.
    fn registry_set() -> RegistrySet {
        let mut agents = MapAgentSpecRegistry::new();
        agents
            .register_spec(AgentSpec {
                id: "a".into(),
                model_id: "m".into(),
                system_prompt: "s".into(),
                ..Default::default()
            })
            .unwrap();

        let mut models = MapModelRegistry::new();
        models
            .register_model(
                "m",
                ModelBinding {
                    provider_id: "p".into(),
                    upstream_model: "upstream".into(),
                },
            )
            .unwrap();

        let mut providers = MapProviderRegistry::new();
        providers.register_provider("p", executor()).unwrap();

        RegistrySet {
            agents: Arc::new(agents),
            tools: Arc::new(MapToolRegistry::new()),
            models: Arc::new(models),
            providers: Arc::new(providers),
            plugins: Arc::new(MapPluginSource::new()),
            #[cfg(feature = "a2a")]
            backends: Arc::new(MapBackendRegistry::new()),
        }
    }

    // Both a model binding and an agent depend on `p`. The preview therefore
    // disallows both policies, and the cascade policy reports the blocking agent.
    #[test]
    fn remove_provider_blocks_when_model_and_agent_depend_on_it() {
        let handle = RegistryHandle::new(registry_set());
        let preview = handle
            .preview_remove_provider("p")
            .expect("provider exists");
        assert_eq!(
            preview,
            ProviderRemovalPreview {
                provider_id: "p".into(),
                model_ids: vec!["m".into()],
                agent_ids: vec!["a".into()],
                block_if_referenced_allowed: false,
                cascade_unused_model_bindings_allowed: false,
            }
        );

        let err = handle
            .remove_provider("p", ProviderRemovalPolicy::CascadeUnusedModelBindings)
            .expect_err("agent dependency must block removal");
        assert!(err.to_string().contains("agents [\"a\"]"));
    }

    // Rebuilds the registries with model `m` still bound to `p` but no agents.
    // `m` is then unused and is removed together with the provider.
    #[test]
    fn remove_provider_cascades_unused_model_bindings() {
        let mut update = RuntimeRegistryUpdate {
            providers: HashMap::new(),
            models: vec![ModelBindingSpec {
                id: "m".into(),
                provider_id: "p".into(),
                upstream_model: "upstream".into(),
            }],
            agents: Vec::new(),
        };
        update.providers.insert("p".into(), executor());
        let base = registry_set();
        let registries = rebuild_agent_model_provider_registries(&base, update).unwrap();
        let handle = RegistryHandle::new(registries);

        let impact = handle
            .remove_provider("p", ProviderRemovalPolicy::CascadeUnusedModelBindings)
            .expect("unused model can be removed with provider");

        assert_eq!(impact.removed_model_ids, vec!["m"]);
        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("p").is_none());
        assert!(snapshot.registries().models.get_model("m").is_none());
    }

    // Replacing an existing provider leaves dependent bindings and agents intact.
    #[test]
    fn replace_provider_keeps_model_binding_and_agent() {
        let handle = RegistryHandle::new(registry_set());
        let version = handle
            .replace_provider("p", executor())
            .expect("provider exists");
        // NOTE(review): assumes a fresh handle starts at version 1, so one
        // successful update yields 2 — confirm against `RegistryHandle`.
        assert_eq!(version, 2);
        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("p").is_some());
        assert!(snapshot.registries().models.get_model("m").is_some());
        assert!(snapshot.registries().agents.get_agent("a").is_some());
    }

    // Sixteen threads each register a distinct provider. Every one must be
    // present afterwards, meaning no concurrent update was lost.
    #[test]
    fn concurrent_provider_registration_preserves_all_updates() {
        let handle = Arc::new(RegistryHandle::new(registry_set()));
        let mut threads = Vec::new();

        for index in 0..16 {
            let handle = Arc::clone(&handle);
            threads.push(std::thread::spawn(move || {
                handle
                    .register_provider(format!("p-{index}"), executor())
                    .expect("provider registration must succeed");
            }));
        }

        for thread in threads {
            thread.join().expect("thread must not panic");
        }

        let snapshot = handle.snapshot();
        for index in 0..16 {
            let provider_id = format!("p-{index}");
            assert!(
                snapshot
                    .registries()
                    .providers
                    .get_provider(&provider_id)
                    .is_some(),
                "provider {provider_id} must survive concurrent updates"
            );
        }
    }
}