Skip to main content

systemprompt_extension/registry/
queries.rs

1use super::ExtensionRegistry;
2use crate::Extension;
3use crate::asset::{AssetDefinition, AssetPaths};
4use std::sync::Arc;
5use systemprompt_provider_contracts::Job;
6use tracing::warn;
7
8impl ExtensionRegistry {
9    #[must_use]
10    pub fn get(&self, id: &str) -> Option<&Arc<dyn Extension>> {
11        self.extensions.get(id)
12    }
13
14    #[must_use]
15    pub fn has(&self, id: &str) -> bool {
16        self.extensions.contains_key(id)
17    }
18
19    #[must_use]
20    pub fn ids(&self) -> Vec<&str> {
21        self.extensions.keys().map(String::as_str).collect()
22    }
23
24    #[must_use]
25    pub fn extensions(&self) -> &[Arc<dyn Extension>] {
26        &self.sorted_extensions
27    }
28
29    #[must_use]
30    pub fn schema_extensions(&self) -> Vec<Arc<dyn Extension>> {
31        let mut exts: Vec<_> = self
32            .sorted_extensions
33            .iter()
34            .filter(|e| e.has_schemas())
35            .cloned()
36            .collect();
37        exts.sort_by_key(|e| e.migration_weight());
38        exts
39    }
40
41    #[must_use]
42    pub fn enabled_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
43        self.sorted_extensions
44            .iter()
45            .filter(|ext| {
46                let id = ext.id();
47                if ext.is_required() {
48                    if disabled_ids.iter().any(|d| d == id) {
49                        warn!(
50                            extension = %id,
51                            "Cannot disable required extension - ignoring disabled flag"
52                        );
53                    }
54                    return true;
55                }
56                !disabled_ids.iter().any(|d| d == id)
57            })
58            .cloned()
59            .collect()
60    }
61
62    #[must_use]
63    pub fn enabled_schema_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
64        let mut exts: Vec<_> = self
65            .enabled_extensions(disabled_ids)
66            .into_iter()
67            .filter(|e| e.has_schemas() || e.has_migrations())
68            .collect();
69        exts.sort_by_key(|e| e.migration_weight());
70        exts
71    }
72
73    #[must_use]
74    pub fn enabled_api_extensions(
75        &self,
76        ctx: &dyn crate::ExtensionContext,
77        disabled_ids: &[String],
78    ) -> Vec<Arc<dyn Extension>> {
79        self.enabled_extensions(disabled_ids)
80            .into_iter()
81            .filter(|e| e.has_router(ctx))
82            .collect()
83    }
84
85    #[must_use]
86    pub fn enabled_job_extensions(&self, disabled_ids: &[String]) -> Vec<Arc<dyn Extension>> {
87        self.enabled_extensions(disabled_ids)
88            .into_iter()
89            .filter(|e| e.has_jobs())
90            .collect()
91    }
92
93    #[must_use]
94    pub fn api_extensions(&self, ctx: &dyn crate::ExtensionContext) -> Vec<Arc<dyn Extension>> {
95        self.sorted_extensions
96            .iter()
97            .filter(|e| e.has_router(ctx))
98            .cloned()
99            .collect()
100    }
101
102    #[must_use]
103    pub fn job_extensions(&self) -> Vec<Arc<dyn Extension>> {
104        self.sorted_extensions
105            .iter()
106            .filter(|e| e.has_jobs())
107            .cloned()
108            .collect()
109    }
110
111    #[must_use]
112    pub fn config_extensions(&self) -> Vec<Arc<dyn Extension>> {
113        self.sorted_extensions
114            .iter()
115            .filter(|e| e.has_config())
116            .cloned()
117            .collect()
118    }
119
120    #[must_use]
121    pub fn llm_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
122        self.sorted_extensions
123            .iter()
124            .filter(|e| e.has_llm_providers())
125            .cloned()
126            .collect()
127    }
128
129    #[must_use]
130    pub fn tool_provider_extensions(&self) -> Vec<Arc<dyn Extension>> {
131        self.sorted_extensions
132            .iter()
133            .filter(|e| e.has_tool_providers())
134            .cloned()
135            .collect()
136    }
137
138    #[must_use]
139    pub fn storage_extensions(&self) -> Vec<Arc<dyn Extension>> {
140        self.sorted_extensions
141            .iter()
142            .filter(|e| e.has_storage_paths())
143            .cloned()
144            .collect()
145    }
146
147    pub fn all_required_storage_paths(&self) -> Vec<&'static str> {
148        self.sorted_extensions
149            .iter()
150            .flat_map(|e| e.required_storage_paths())
151            .collect()
152    }
153
154    #[must_use]
155    pub fn asset_extensions(&self) -> Vec<Arc<dyn Extension>> {
156        self.sorted_extensions
157            .iter()
158            .filter(|e| e.declares_assets())
159            .cloned()
160            .collect()
161    }
162
163    pub fn all_required_assets(
164        &self,
165        paths: &dyn AssetPaths,
166    ) -> Vec<(&'static str, AssetDefinition)> {
167        self.sorted_extensions
168            .iter()
169            .flat_map(|e| {
170                let id = e.id();
171                e.required_assets(paths)
172                    .into_iter()
173                    .map(move |asset| (id, asset))
174            })
175            .collect()
176    }
177
178    #[must_use]
179    pub fn all_jobs(&self) -> Vec<Arc<dyn Job>> {
180        self.sorted_extensions
181            .iter()
182            .flat_map(|ext| ext.jobs())
183            .collect()
184    }
185
186    #[must_use]
187    pub fn job_by_name(&self, name: &str) -> Option<Arc<dyn Job>> {
188        self.sorted_extensions
189            .iter()
190            .flat_map(|ext| ext.jobs())
191            .find(|job| job.name() == name)
192    }
193
194    #[must_use]
195    pub fn jobs_by_tag(&self, tag: &str) -> Vec<Arc<dyn Job>> {
196        self.sorted_extensions
197            .iter()
198            .flat_map(|ext| ext.jobs())
199            .filter(|job| job.tags().contains(&tag))
200            .collect()
201    }
202}