Skip to main content

awaken_runtime/registry/
lifecycle.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use awaken_contract::contract::executor::LlmExecutor;
5use awaken_contract::registry_spec::{AgentSpec, ModelBindingSpec};
6use serde::Serialize;
7
8#[cfg(feature = "a2a")]
9use super::MapBackendRegistry;
10use super::diagnostics::{RegistryValidationError, validate_registry_set};
11use super::memory::{
12    MapAgentSpecRegistry, MapModelRegistry, MapPluginSource, MapProviderRegistry, MapToolRegistry,
13};
14use super::snapshot::RegistryHandle;
15#[cfg(feature = "a2a")]
16use super::traits::BackendRegistry;
17use super::traits::{ModelBinding, RegistrySet};
18
/// Strategy used by [`RegistryHandle::remove_provider`] to decide whether a
/// provider that still has model bindings pointing at it may be removed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderRemovalPolicy {
    /// Refuse the removal whenever any model binding references the provider.
    BlockIfReferenced,
    /// Remove the provider together with its model bindings, but refuse if a
    /// local agent (one without an `endpoint`) still uses one of them.
    CascadeUnusedModelBindings,
}
24
25#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
26pub struct ProviderRemovalPreview {
27    pub provider_id: String,
28    pub model_ids: Vec<String>,
29    pub agent_ids: Vec<String>,
30    pub block_if_referenced_allowed: bool,
31    pub cascade_unused_model_bindings_allowed: bool,
32}
33
34impl ProviderRemovalPreview {
35    pub fn new(
36        provider_id: impl Into<String>,
37        mut model_ids: Vec<String>,
38        mut agent_ids: Vec<String>,
39    ) -> Self {
40        model_ids.sort();
41        model_ids.dedup();
42        agent_ids.sort();
43        agent_ids.dedup();
44        Self {
45            provider_id: provider_id.into(),
46            block_if_referenced_allowed: model_ids.is_empty(),
47            cascade_unused_model_bindings_allowed: agent_ids.is_empty(),
48            model_ids,
49            agent_ids,
50        }
51    }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
55pub struct ProviderRemovalImpact {
56    pub provider_id: String,
57    pub removed_model_ids: Vec<String>,
58    pub affected_agent_ids: Vec<String>,
59}
60
61#[derive(Debug, thiserror::Error)]
62pub enum RegistryUpdateError {
63    #[error("provider already registered: {0}")]
64    ProviderAlreadyExists(String),
65    #[error("provider not found: {0}")]
66    ProviderNotFound(String),
67    #[error(
68        "provider '{provider_id}' is still referenced by models {model_ids:?} and agents {agent_ids:?}"
69    )]
70    ProviderInUse {
71        provider_id: String,
72        model_ids: Vec<String>,
73        agent_ids: Vec<String>,
74    },
75    #[error("registry build failed: {0}")]
76    Build(String),
77    #[error("{0}")]
78    Validation(#[from] RegistryValidationError),
79}
80
81pub struct RuntimeRegistryUpdate {
82    pub providers: HashMap<String, Arc<dyn LlmExecutor>>,
83    pub models: Vec<ModelBindingSpec>,
84    pub agents: Vec<AgentSpec>,
85}
86
87impl RegistryHandle {
88    pub fn preview_remove_provider(
89        &self,
90        id: &str,
91    ) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
92        let snapshot = self.snapshot();
93        preview_provider_removal(snapshot.registries(), id)
94    }
95
96    pub fn register_provider(
97        &self,
98        id: impl Into<String>,
99        executor: Arc<dyn LlmExecutor>,
100    ) -> Result<u64, RegistryUpdateError> {
101        let id = id.into();
102        self.update(|registries| {
103            let mut draft = RegistrySetDraft::from_set(registries)?;
104            if draft.providers.contains_key(&id) {
105                return Err(RegistryUpdateError::ProviderAlreadyExists(id));
106            }
107            draft
108                .providers
109                .register_provider(id, executor)
110                .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
111            draft.into_validated_set()
112        })
113    }
114
115    pub fn replace_provider(
116        &self,
117        id: impl Into<String>,
118        executor: Arc<dyn LlmExecutor>,
119    ) -> Result<u64, RegistryUpdateError> {
120        let id = id.into();
121        self.update(|registries| {
122            let mut draft = RegistrySetDraft::from_set(registries)?;
123            if !draft.providers.contains_key(&id) {
124                return Err(RegistryUpdateError::ProviderNotFound(id));
125            }
126            draft.providers.replace_provider(id, executor);
127            draft.into_validated_set()
128        })
129    }
130
131    pub fn remove_provider(
132        &self,
133        id: &str,
134        policy: ProviderRemovalPolicy,
135    ) -> Result<ProviderRemovalImpact, RegistryUpdateError> {
136        let mut impact = None;
137        self.update(|registries| {
138            let mut draft = RegistrySetDraft::from_set(registries)?;
139            if !draft.providers.contains_key(id) {
140                return Err(RegistryUpdateError::ProviderNotFound(id.to_string()));
141            }
142
143            let preview = preview_provider_removal_from_draft(&draft, id)?;
144
145            match policy {
146                ProviderRemovalPolicy::BlockIfReferenced if !preview.model_ids.is_empty() => {
147                    return Err(RegistryUpdateError::ProviderInUse {
148                        provider_id: id.to_string(),
149                        model_ids: preview.model_ids,
150                        agent_ids: preview.agent_ids,
151                    });
152                }
153                ProviderRemovalPolicy::CascadeUnusedModelBindings
154                    if !preview.agent_ids.is_empty() =>
155                {
156                    return Err(RegistryUpdateError::ProviderInUse {
157                        provider_id: id.to_string(),
158                        model_ids: preview.model_ids,
159                        agent_ids: preview.agent_ids,
160                    });
161                }
162                _ => {}
163            }
164
165            for model_id in &preview.model_ids {
166                draft.models.remove(model_id);
167            }
168            draft.providers.remove_provider(id);
169
170            impact = Some(ProviderRemovalImpact {
171                provider_id: preview.provider_id,
172                removed_model_ids: preview.model_ids,
173                affected_agent_ids: preview.agent_ids,
174            });
175            draft.into_validated_set()
176        })?;
177        impact.ok_or_else(|| RegistryUpdateError::Build("provider removal did not run".into()))
178    }
179}
180
181pub fn preview_provider_removal(
182    registries: &RegistrySet,
183    id: &str,
184) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
185    if registries.providers.get_provider(id).is_none() {
186        return Err(RegistryUpdateError::ProviderNotFound(id.to_string()));
187    }
188    let model_ids = provider_model_ids_from_set(registries, id);
189    let agent_ids = agents_using_models_from_set(registries, &model_ids);
190    Ok(ProviderRemovalPreview::new(id, model_ids, agent_ids))
191}
192
193pub fn rebuild_agent_model_provider_registries(
194    base: &RegistrySet,
195    update: RuntimeRegistryUpdate,
196) -> Result<RegistrySet, RegistryUpdateError> {
197    let mut draft = RegistrySetDraft::from_set(base)?;
198
199    draft.providers = MapProviderRegistry::new();
200    for (id, executor) in update.providers {
201        draft
202            .providers
203            .register_provider(id, executor)
204            .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
205    }
206
207    draft.models = MapModelRegistry::new();
208    for model in update.models {
209        draft
210            .models
211            .register_model(model.id.clone(), ModelBinding::from(&model))
212            .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
213    }
214
215    draft.agents = MapAgentSpecRegistry::new();
216    for agent in update.agents {
217        draft
218            .agents
219            .register_spec(agent)
220            .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
221    }
222
223    let registries = draft.into_set();
224    validate_registry_set(&registries)?;
225    Ok(registries)
226}
227
228struct RegistrySetDraft {
229    agents: MapAgentSpecRegistry,
230    tools: MapToolRegistry,
231    models: MapModelRegistry,
232    providers: MapProviderRegistry,
233    plugins: MapPluginSource,
234    #[cfg(feature = "a2a")]
235    backends: MapBackendRegistry,
236}
237
238impl RegistrySetDraft {
239    fn from_set(set: &RegistrySet) -> Result<Self, RegistryUpdateError> {
240        let mut agents = MapAgentSpecRegistry::new();
241        for id in set.agents.agent_ids() {
242            if let Some(agent) = set.agents.get_agent(&id) {
243                agents
244                    .register(id, agent, |msg| {
245                        crate::builder::BuildError::AgentRegistryConflict(format!("agent {msg}"))
246                    })
247                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
248            }
249        }
250
251        let mut tools = MapToolRegistry::new();
252        for id in set.tools.tool_ids() {
253            if let Some(tool) = set.tools.get_tool(&id) {
254                tools
255                    .register_tool(id, tool)
256                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
257            }
258        }
259
260        let mut models = MapModelRegistry::new();
261        for id in set.models.model_ids() {
262            if let Some(model) = set.models.get_model(&id) {
263                models
264                    .register_model(id, model)
265                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
266            }
267        }
268
269        let mut providers = MapProviderRegistry::new();
270        for id in set.providers.provider_ids() {
271            if let Some(provider) = set.providers.get_provider(&id) {
272                providers
273                    .register_provider(id, provider)
274                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
275            }
276        }
277
278        let mut plugins = MapPluginSource::new();
279        for id in set.plugins.plugin_ids() {
280            if let Some(plugin) = set.plugins.get_plugin(&id) {
281                plugins
282                    .register_plugin(id, plugin)
283                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
284            }
285        }
286
287        #[cfg(feature = "a2a")]
288        let mut backends = MapBackendRegistry::new();
289        #[cfg(feature = "a2a")]
290        for id in set.backends.backend_ids() {
291            if let Some(factory) = set.backends.get_backend_factory(&id) {
292                backends
293                    .register_backend_factory(factory)
294                    .map_err(|error| RegistryUpdateError::Build(error.to_string()))?;
295            }
296        }
297
298        Ok(Self {
299            agents,
300            tools,
301            models,
302            providers,
303            plugins,
304            #[cfg(feature = "a2a")]
305            backends,
306        })
307    }
308
309    fn into_set(self) -> RegistrySet {
310        RegistrySet {
311            agents: Arc::new(self.agents),
312            tools: Arc::new(self.tools),
313            models: Arc::new(self.models),
314            providers: Arc::new(self.providers),
315            plugins: Arc::new(self.plugins),
316            #[cfg(feature = "a2a")]
317            backends: Arc::new(self.backends) as Arc<dyn BackendRegistry>,
318        }
319    }
320
321    fn into_validated_set(self) -> Result<RegistrySet, RegistryUpdateError> {
322        let registries = self.into_set();
323        validate_registry_set(&registries)?;
324        Ok(registries)
325    }
326}
327
328fn provider_model_ids(models: &MapModelRegistry, provider_id: &str) -> Vec<String> {
329    models
330        .ids()
331        .into_iter()
332        .filter(|model_id| {
333            models
334                .get(model_id)
335                .is_some_and(|model| model.provider_id == provider_id)
336        })
337        .collect()
338}
339
340fn preview_provider_removal_from_draft(
341    draft: &RegistrySetDraft,
342    provider_id: &str,
343) -> Result<ProviderRemovalPreview, RegistryUpdateError> {
344    if !draft.providers.contains_key(provider_id) {
345        return Err(RegistryUpdateError::ProviderNotFound(
346            provider_id.to_string(),
347        ));
348    }
349    let model_ids = provider_model_ids(&draft.models, provider_id);
350    let agent_ids = agents_using_models(&draft.agents, &model_ids);
351    Ok(ProviderRemovalPreview::new(
352        provider_id,
353        model_ids,
354        agent_ids,
355    ))
356}
357
358fn provider_model_ids_from_set(registries: &RegistrySet, provider_id: &str) -> Vec<String> {
359    registries
360        .models
361        .model_ids()
362        .into_iter()
363        .filter(|model_id| {
364            registries
365                .models
366                .get_model(model_id)
367                .is_some_and(|model| model.provider_id == provider_id)
368        })
369        .collect()
370}
371
372fn agents_using_models(agents: &MapAgentSpecRegistry, model_ids: &[String]) -> Vec<String> {
373    let model_ids: HashSet<_> = model_ids.iter().map(String::as_str).collect();
374    agents
375        .ids()
376        .into_iter()
377        .filter(|agent_id| {
378            agents.get(agent_id).is_some_and(|agent| {
379                agent.endpoint.is_none() && model_ids.contains(agent.model_id.as_str())
380            })
381        })
382        .collect()
383}
384
385fn agents_using_models_from_set(registries: &RegistrySet, model_ids: &[String]) -> Vec<String> {
386    let model_ids: HashSet<_> = model_ids.iter().map(String::as_str).collect();
387    registries
388        .agents
389        .agent_ids()
390        .into_iter()
391        .filter(|agent_id| {
392            registries.agents.get_agent(agent_id).is_some_and(|agent| {
393                agent.endpoint.is_none() && model_ids.contains(agent.model_id.as_str())
394            })
395        })
396        .collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use awaken_contract::contract::executor::{InferenceExecutionError, InferenceRequest};
    use awaken_contract::contract::inference::{StopReason, StreamResult, TokenUsage};

    /// Executor that immediately ends the turn with an empty response.
    struct StubExecutor;

    #[async_trait]
    impl LlmExecutor for StubExecutor {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Ok(StreamResult {
                content: vec![],
                tool_calls: vec![],
                usage: Some(TokenUsage::default()),
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "stub"
        }
    }

    fn executor() -> Arc<dyn LlmExecutor> {
        Arc::new(StubExecutor)
    }

    /// Registry set with provider `p`, model `m` bound to `p`, and local agent `a` using `m`.
    fn registry_set() -> RegistrySet {
        let mut agents = MapAgentSpecRegistry::new();
        agents
            .register_spec(AgentSpec {
                id: "a".into(),
                model_id: "m".into(),
                system_prompt: "s".into(),
                ..Default::default()
            })
            .unwrap();

        let mut models = MapModelRegistry::new();
        models
            .register_model(
                "m",
                ModelBinding {
                    provider_id: "p".into(),
                    upstream_model: "upstream".into(),
                },
            )
            .unwrap();

        let mut providers = MapProviderRegistry::new();
        providers.register_provider("p", executor()).unwrap();

        RegistrySet {
            agents: Arc::new(agents),
            tools: Arc::new(MapToolRegistry::new()),
            models: Arc::new(models),
            providers: Arc::new(providers),
            plugins: Arc::new(MapPluginSource::new()),
            #[cfg(feature = "a2a")]
            backends: Arc::new(MapBackendRegistry::new()),
        }
    }

    #[test]
    fn remove_provider_blocks_when_model_and_agent_depend_on_it() {
        let handle = RegistryHandle::new(registry_set());
        let preview = handle
            .preview_remove_provider("p")
            .expect("provider exists");
        assert_eq!(
            preview,
            ProviderRemovalPreview {
                provider_id: "p".into(),
                model_ids: vec!["m".into()],
                agent_ids: vec!["a".into()],
                block_if_referenced_allowed: false,
                cascade_unused_model_bindings_allowed: false,
            }
        );

        let err = handle
            .remove_provider("p", ProviderRemovalPolicy::CascadeUnusedModelBindings)
            .expect_err("agent dependency must block removal");
        assert!(err.to_string().contains("agents [\"a\"]"));
    }

    #[test]
    fn remove_provider_block_policy_rejects_bound_provider() {
        let handle = RegistryHandle::new(registry_set());
        let err = handle
            .remove_provider("p", ProviderRemovalPolicy::BlockIfReferenced)
            .expect_err("model binding must block removal");
        match err {
            RegistryUpdateError::ProviderInUse {
                provider_id,
                model_ids,
                agent_ids,
            } => {
                assert_eq!(provider_id, "p");
                assert_eq!(model_ids, vec!["m"]);
                assert_eq!(agent_ids, vec!["a"]);
            }
            other => panic!("unexpected error: {other}"),
        }

        // A refused removal must leave the registries untouched.
        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("p").is_some());
        assert!(snapshot.registries().models.get_model("m").is_some());
    }

    #[test]
    fn remove_provider_block_policy_allows_unreferenced_provider() {
        let handle = RegistryHandle::new(registry_set());
        handle
            .register_provider("q", executor())
            .expect("new provider id must register");

        let impact = handle
            .remove_provider("q", ProviderRemovalPolicy::BlockIfReferenced)
            .expect("unreferenced provider can be removed");
        assert_eq!(
            impact,
            ProviderRemovalImpact {
                provider_id: "q".into(),
                removed_model_ids: vec![],
                affected_agent_ids: vec![],
            }
        );

        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("q").is_none());
        assert!(snapshot.registries().providers.get_provider("p").is_some());
    }

    #[test]
    fn remove_provider_cascades_unused_model_bindings() {
        let mut update = RuntimeRegistryUpdate {
            providers: HashMap::new(),
            models: vec![ModelBindingSpec {
                id: "m".into(),
                provider_id: "p".into(),
                upstream_model: "upstream".into(),
            }],
            agents: Vec::new(),
        };
        update.providers.insert("p".into(), executor());
        let base = registry_set();
        let registries = rebuild_agent_model_provider_registries(&base, update).unwrap();
        let handle = RegistryHandle::new(registries);

        let impact = handle
            .remove_provider("p", ProviderRemovalPolicy::CascadeUnusedModelBindings)
            .expect("unused model can be removed with provider");

        assert_eq!(impact.removed_model_ids, vec!["m"]);
        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("p").is_none());
        assert!(snapshot.registries().models.get_model("m").is_none());
    }

    #[test]
    fn unknown_provider_is_reported_as_not_found() {
        let handle = RegistryHandle::new(registry_set());
        assert!(matches!(
            handle.preview_remove_provider("missing"),
            Err(RegistryUpdateError::ProviderNotFound(id)) if id == "missing"
        ));
        assert!(matches!(
            handle.remove_provider("missing", ProviderRemovalPolicy::CascadeUnusedModelBindings),
            Err(RegistryUpdateError::ProviderNotFound(id)) if id == "missing"
        ));
        assert!(matches!(
            handle.replace_provider("missing", executor()),
            Err(RegistryUpdateError::ProviderNotFound(id)) if id == "missing"
        ));
    }

    #[test]
    fn register_provider_rejects_duplicate_id() {
        let handle = RegistryHandle::new(registry_set());
        assert!(matches!(
            handle.register_provider("p", executor()),
            Err(RegistryUpdateError::ProviderAlreadyExists(id)) if id == "p"
        ));
    }

    #[test]
    fn preview_new_sorts_dedups_and_derives_flags() {
        let preview = ProviderRemovalPreview::new(
            "p",
            vec!["b".into(), "a".into(), "b".into()],
            Vec::new(),
        );
        assert_eq!(preview.provider_id, "p");
        assert_eq!(preview.model_ids, vec!["a", "b"]);
        assert!(preview.agent_ids.is_empty());
        assert!(!preview.block_if_referenced_allowed);
        assert!(preview.cascade_unused_model_bindings_allowed);
    }

    #[test]
    fn replace_provider_keeps_model_binding_and_agent() {
        let handle = RegistryHandle::new(registry_set());
        let version = handle
            .replace_provider("p", executor())
            .expect("provider exists");
        assert_eq!(version, 2);
        let snapshot = handle.snapshot();
        assert!(snapshot.registries().providers.get_provider("p").is_some());
        assert!(snapshot.registries().models.get_model("m").is_some());
        assert!(snapshot.registries().agents.get_agent("a").is_some());
    }

    #[test]
    fn concurrent_provider_registration_preserves_all_updates() {
        let handle = Arc::new(RegistryHandle::new(registry_set()));
        let mut threads = Vec::new();

        for index in 0..16 {
            let handle = Arc::clone(&handle);
            threads.push(std::thread::spawn(move || {
                handle
                    .register_provider(format!("p-{index}"), executor())
                    .expect("provider registration must succeed");
            }));
        }

        for thread in threads {
            thread.join().expect("thread must not panic");
        }

        let snapshot = handle.snapshot();
        for index in 0..16 {
            let provider_id = format!("p-{index}");
            assert!(
                snapshot
                    .registries()
                    .providers
                    .get_provider(&provider_id)
                    .is_some(),
                "provider {provider_id} must survive concurrent updates"
            );
        }
    }
}