systemprompt_extension/registry/
queries.rs1use super::ExtensionRegistry;
2use crate::Extension;
3use crate::asset::{AssetDefinition, AssetPaths};
4use std::sync::Arc;
5use systemprompt_provider_contracts::Job;
6use tracing::warn;
7
8impl ExtensionRegistry {
9 #[must_use]
10 pub fn get(&self, id: &str) -> Option<&Arc<dyn Extension>> {
11 self.extensions.get(id)
12 }
13
14 #[must_use]
15 pub fn has(&self, id: &str) -> bool {
16 self.extensions.contains_key(id)
17 }
18
19 #[must_use]
20 pub fn ids(&self) -> Vec<&str> {
21 self.extensions.keys().map(String::as_str).collect()
22 }
23
24 #[must_use]
25 pub fn extensions(&self) -> &[Arc<dyn Extension>] {
26 &self.sorted_extensions
27 }
28
29 #[must_use]
30 pub fn schema_extensions(&self) -> Vec<Arc<dyn Extension>> {
31 let mut exts: Vec<_> = self
32 .sorted_extensions
33 .iter()
34 .filter(|e| e.has_schemas())
35 .cloned()
36 .collect();
37 exts.sort_by_key(|e| e.migration_weight());
38 exts
39 }
40
41 #[must_use]
42 pub fn enabled_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
43 self.sorted_extensions
44 .iter()
45 .filter(|ext| {
46 let id = ext.id();
47 if ext.is_required() {
48 if disabled_ids.iter().any(|d| d == id) {
49 warn!(
50 extension = %id,
51 "Cannot disable required extension - ignoring disabled flag"
52 );
53 }
54 return true;
55 }
56 !disabled_ids.iter().any(|d| d == id)
57 })
58 .cloned()
59 .collect()
60 }
61
62 #[must_use]
63 pub fn enabled_schema_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
64 let mut exts: Vec<_> = self
65 .enabled_extensions(disabled_ids)
66 .into_iter()
67 .filter(|e| e.has_schemas() || e.has_migrations())
68 .collect();
69 exts.sort_by_key(|e| e.migration_weight());
70 exts
71 }
72
73 #[must_use]
74 pub fn enabled_api_extensions(
75 &self,
76 ctx: &dyn crate::ExtensionContext,
77 disabled_ids: &[String],
78 ) -> Vec<Arc<dyn Extension>> {
79 self.enabled_extensions(disabled_ids)
80 .into_iter()
81 .filter(|e| e.has_router(ctx))
82 .collect()
83 }
84
85 #[must_use]
86 pub fn enabled_job_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
87 self.enabled_extensions(disabled_ids)
88 .into_iter()
89 .filter(|e| e.has_jobs())
90 .collect()
91 }
92
93 #[must_use]
94 pub fn api_extensions(&self, ctx: &dyn crate::ExtensionContext) -> Vec<Arc<dyn Extension>> {
95 self.sorted_extensions
96 .iter()
97 .filter(|e| e.has_router(ctx))
98 .cloned()
99 .collect()
100 }
101
102 #[must_use]
103 pub fn job_extensions(&self) -> Vec<Arc<dyn Extension>> {
104 self.sorted_extensions
105 .iter()
106 .filter(|e| e.has_jobs())
107 .cloned()
108 .collect()
109 }
110
111 #[must_use]
112 pub fn config_extensions(&self) -> Vec<Arc<dyn Extension>> {
113 self.sorted_extensions
114 .iter()
115 .filter(|e| e.has_config())
116 .cloned()
117 .collect()
118 }
119
120 #[must_use]
121 pub fn llm_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
122 self.sorted_extensions
123 .iter()
124 .filter(|e| e.has_llm_providers())
125 .cloned()
126 .collect()
127 }
128
129 #[must_use]
130 pub fn tool_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
131 self.sorted_extensions
132 .iter()
133 .filter(|e| e.has_tool_providers())
134 .cloned()
135 .collect()
136 }
137
138 #[must_use]
139 pub fn storage_extensions(&self) -> Vec<Arc<dyn Extension>> {
140 self.sorted_extensions
141 .iter()
142 .filter(|e| e.has_storage_paths())
143 .cloned()
144 .collect()
145 }
146
147 pub fn all_required_storage_paths(&self) -> Vec<&'static str> {
148 self.sorted_extensions
149 .iter()
150 .flat_map(|e| e.required_storage_paths())
151 .collect()
152 }
153
154 #[must_use]
155 pub fn asset_extensions(&self) -> Vec<Arc<dyn Extension>> {
156 self.sorted_extensions
157 .iter()
158 .filter(|e| e.declares_assets())
159 .cloned()
160 .collect()
161 }
162
163 pub fn all_required_assets(
164 &self,
165 paths: &dyn AssetPaths,
166 ) -> Vec<(&'static str, AssetDefinition)> {
167 self.sorted_extensions
168 .iter()
169 .flat_map(|e| {
170 let id = e.id();
171 e.required_assets(paths)
172 .into_iter()
173 .map(move |asset| (id, asset))
174 })
175 .collect()
176 }
177
178 #[must_use]
179 pub fn all_jobs(&self) -> Vec<Arc<dyn Job>> {
180 self.sorted_extensions
181 .iter()
182 .flat_map(|ext| ext.jobs())
183 .collect()
184 }
185
186 #[must_use]
187 pub fn job_by_name(&self, name: &str) -> Option<Arc<dyn Job>> {
188 self.sorted_extensions
189 .iter()
190 .flat_map(|ext| ext.jobs())
191 .find(|job| job.name() == name)
192 }
193
194 #[must_use]
195 pub fn jobs_by_tag(&self, tag: &str) -> Vec<Arc<dyn Job>> {
196 self.sorted_extensions
197 .iter()
198 .flat_map(|ext| ext.jobs())
199 .filter(|job| job.tags().contains(&tag))
200 .collect()
201 }
202}