Skip to main content

systemprompt_extension/registry/
queries.rs

1use super::ExtensionRegistry;
2use crate::Extension;
3use crate::asset::{AssetDefinition, AssetPaths};
4use std::sync::Arc;
5use systemprompt_provider_contracts::Job;
6use tracing::warn;
7
8impl ExtensionRegistry {
9    #[must_use]
10    pub fn get(&self, id: &str) -> Option<&Arc<dyn Extension>> {
11        self.extensions.get(id)
12    }
13
14    #[must_use]
15    pub fn has(&self, id: &str) -> bool {
16        self.extensions.contains_key(id)
17    }
18
19    #[must_use]
20    pub fn ids(&self) -> Vec<&str> {
21        self.extensions.keys().map(String::as_str).collect()
22    }
23
24    #[must_use]
25    pub fn extensions(&self) -> &[Arc<dyn Extension>] {
26        &self.sorted_extensions
27    }
28
29    #[must_use]
30    pub fn schema_extensions(&self) -> Vec<Arc<dyn Extension>> {
31        self.sorted_extensions
32            .iter()
33            .filter(|e| e.has_schemas())
34            .cloned()
35            .collect()
36    }
37
38    #[must_use]
39    pub fn enabled_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
40        self.sorted_extensions
41            .iter()
42            .filter(|ext| {
43                let id = ext.id();
44                if ext.is_required() {
45                    if disabled_ids.iter().any(|d| d == id) {
46                        warn!(
47                            extension = %id,
48                            "Cannot disable required extension - ignoring disabled flag"
49                        );
50                    }
51                    return true;
52                }
53                !disabled_ids.iter().any(|d| d == id)
54            })
55            .cloned()
56            .collect()
57    }
58
59    /// Schema-bearing extensions in dependency (topological) order — the
60    /// single ordering authority for schema installation. `enabled_extensions`
61    /// already preserves `sorted_extensions` topo order, so this only filters.
62    #[must_use]
63    pub fn enabled_schema_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
64        self.enabled_extensions(disabled_ids)
65            .into_iter()
66            .filter(|e| e.has_schemas() || e.has_migrations())
67            .collect()
68    }
69
70    #[must_use]
71    pub fn enabled_api_extensions(
72        &self,
73        ctx: &dyn crate::ExtensionContext,
74        disabled_ids: &[String],
75    ) -> Vec<Arc<dyn Extension>> {
76        self.enabled_extensions(disabled_ids)
77            .into_iter()
78            .filter(|e| e.has_router(ctx))
79            .collect()
80    }
81
82    #[must_use]
83    pub fn enabled_job_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
84        self.enabled_extensions(disabled_ids)
85            .into_iter()
86            .filter(|e| e.has_jobs())
87            .collect()
88    }
89
90    #[must_use]
91    pub fn api_extensions(&self, ctx: &dyn crate::ExtensionContext) -> Vec<Arc<dyn Extension>> {
92        self.sorted_extensions
93            .iter()
94            .filter(|e| e.has_router(ctx))
95            .cloned()
96            .collect()
97    }
98
99    #[must_use]
100    pub fn job_extensions(&self) -> Vec<Arc<dyn Extension>> {
101        self.sorted_extensions
102            .iter()
103            .filter(|e| e.has_jobs())
104            .cloned()
105            .collect()
106    }
107
108    #[must_use]
109    pub fn config_extensions(&self) -> Vec<Arc<dyn Extension>> {
110        self.sorted_extensions
111            .iter()
112            .filter(|e| e.has_config())
113            .cloned()
114            .collect()
115    }
116
117    #[must_use]
118    pub fn llm_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
119        self.sorted_extensions
120            .iter()
121            .filter(|e| e.has_llm_providers())
122            .cloned()
123            .collect()
124    }
125
126    #[must_use]
127    pub fn tool_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
128        self.sorted_extensions
129            .iter()
130            .filter(|e| e.has_tool_providers())
131            .cloned()
132            .collect()
133    }
134
135    #[must_use]
136    pub fn storage_extensions(&self) -> Vec<Arc<dyn Extension>> {
137        self.sorted_extensions
138            .iter()
139            .filter(|e| e.has_storage_paths())
140            .cloned()
141            .collect()
142    }
143
144    pub fn all_required_storage_paths(&self) -> Vec<&'static str> {
145        self.sorted_extensions
146            .iter()
147            .flat_map(|e| e.required_storage_paths())
148            .collect()
149    }
150
151    #[must_use]
152    pub fn asset_extensions(&self) -> Vec<Arc<dyn Extension>> {
153        self.sorted_extensions
154            .iter()
155            .filter(|e| e.declares_assets())
156            .cloned()
157            .collect()
158    }
159
160    pub fn all_required_assets(
161        &self,
162        paths: &dyn AssetPaths,
163    ) -> Vec<(&'static str, AssetDefinition)> {
164        self.sorted_extensions
165            .iter()
166            .flat_map(|e| {
167                let id = e.id();
168                e.required_assets(paths)
169                    .into_iter()
170                    .map(move |asset| (id, asset))
171            })
172            .collect()
173    }
174
175    #[must_use]
176    pub fn all_jobs(&self) -> Vec<Arc<dyn Job>> {
177        self.sorted_extensions
178            .iter()
179            .flat_map(|ext| ext.jobs())
180            .collect()
181    }
182
183    #[must_use]
184    pub fn job_by_name(&self, name: &str) -> Option<Arc<dyn Job>> {
185        self.sorted_extensions
186            .iter()
187            .flat_map(|ext| ext.jobs())
188            .find(|job| job.name() == name)
189    }
190
191    #[must_use]
192    pub fn jobs_by_tag(&self, tag: &str) -> Vec<Arc<dyn Job>> {
193        self.sorted_extensions
194            .iter()
195            .flat_map(|ext| ext.jobs())
196            .filter(|job| job.tags().contains(&tag))
197            .collect()
198    }
199}