1use std::collections::HashMap;
2use std::path::Path;
3
4use handlebars::Handlebars;
5use serde_json::Value;
6use systemprompt_template_provider::{
7 DynComponentRenderer, DynPageDataProvider, DynPagePrerenderer, DynTemplateDataExtender,
8 DynTemplateLoader, DynTemplateProvider, PartialSource, TemplateDefinition,
9};
10use tracing::{debug, info, warn};
11
12use crate::TemplateError;
13
/// Central registry that collects template providers, loaders, data
/// extenders, component renderers and page-level hooks, together with the
/// Handlebars engine that holds the compiled templates.
///
/// The `register_*` methods only collect the pieces; templates are loaded
/// and compiled when `initialize` is called.
pub struct TemplateRegistry {
    /// Template providers, re-sorted by ascending `priority()` on every registration.
    providers: Vec<DynTemplateProvider>,
    /// Loaders in registration order; the first one whose `can_load` accepts a
    /// template source is the one used to load it.
    loaders: Vec<DynTemplateLoader>,
    /// Data extenders, re-sorted by ascending `priority()` on every registration.
    extenders: Vec<DynTemplateDataExtender>,
    /// Component renderers, re-sorted by ascending `priority()` on every registration.
    components: Vec<DynComponentRenderer>,
    /// Page data providers, re-sorted by ascending `priority()` on every registration.
    page_providers: Vec<DynPageDataProvider>,
    /// Page prerenderers, re-sorted by ascending `priority()` on every registration.
    page_prerenderers: Vec<DynPagePrerenderer>,
    /// Template name -> definition, for each template that loaded and compiled successfully.
    resolved_templates: HashMap<String, TemplateDefinition>,
    /// Handlebars engine holding the compiled templates and component partials.
    handlebars: Handlebars<'static>,
    /// Template name -> id of the provider whose definition was resolved for that name.
    template_sources: HashMap<String, String>,
}
25
26impl Default for TemplateRegistry {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl TemplateRegistry {
33 #[must_use]
34 pub fn new() -> Self {
35 Self {
36 providers: Vec::new(),
37 loaders: Vec::new(),
38 extenders: Vec::new(),
39 components: Vec::new(),
40 page_providers: Vec::new(),
41 page_prerenderers: Vec::new(),
42 resolved_templates: HashMap::new(),
43 handlebars: Handlebars::new(),
44 template_sources: HashMap::new(),
45 }
46 }
47
48 pub fn register_provider(&mut self, provider: DynTemplateProvider) {
49 debug!(
50 provider_id = %provider.provider_id(),
51 priority = provider.priority(),
52 "Registering template provider"
53 );
54 self.providers.push(provider);
55 self.providers.sort_by_key(|p| p.priority());
56 }
57
58 pub fn register_loader(&mut self, loader: DynTemplateLoader) {
59 self.loaders.push(loader);
60 }
61
62 pub fn register_extender(&mut self, extender: DynTemplateDataExtender) {
63 debug!(
64 extender_id = %extender.extender_id(),
65 priority = extender.priority(),
66 "Registering template data extender"
67 );
68 self.extenders.push(extender);
69 self.extenders.sort_by_key(|e| e.priority());
70 }
71
72 pub fn register_component(&mut self, component: DynComponentRenderer) {
73 debug!(
74 component_id = %component.component_id(),
75 variable_name = %component.variable_name(),
76 priority = component.priority(),
77 "Registering component renderer"
78 );
79 self.components.push(component);
80 self.components.sort_by_key(|c| c.priority());
81 }
82
83 pub fn register_page_provider(&mut self, provider: DynPageDataProvider) {
84 debug!(
85 provider_id = %provider.provider_id(),
86 pages = ?provider.applies_to_pages(),
87 "Registering page data provider"
88 );
89 self.page_providers.push(provider);
90 self.page_providers.sort_by_key(|p| p.priority());
91 }
92
93 pub fn register_page_prerenderer(&mut self, prerenderer: DynPagePrerenderer) {
94 debug!(
95 page_type = %prerenderer.page_type(),
96 priority = prerenderer.priority(),
97 "Registering page prerenderer"
98 );
99 self.page_prerenderers.push(prerenderer);
100 self.page_prerenderers.sort_by_key(|p| p.priority());
101 }
102
103 pub async fn initialize(&mut self) -> Result<(), TemplateError> {
104 info!(
105 providers = self.providers.len(),
106 loaders = self.loaders.len(),
107 "Initializing template registry"
108 );
109
110 if self.loaders.is_empty() {
111 return Err(TemplateError::NotInitialized);
112 }
113
114 let mut all_templates: Vec<(TemplateDefinition, &str)> = Vec::new();
115
116 for provider in &self.providers {
117 for template in provider.templates() {
118 all_templates.push((template, provider.provider_id()));
119 }
120 }
121
122 all_templates.sort_by(|a, b| a.0.priority.cmp(&b.0.priority));
123
124 for (template, provider_id) in all_templates {
125 if self.resolved_templates.contains_key(&template.name) {
126 debug!(
127 template = %template.name,
128 provider = %provider_id,
129 "Template already registered, skipping"
130 );
131 continue;
132 }
133
134 debug!(
135 template = %template.name,
136 provider = %provider_id,
137 priority = template.priority,
138 "Registering template"
139 );
140
141 match self.load_template(&template).await {
142 Ok(content) => {
143 if let Err(e) = self
144 .handlebars
145 .register_template_string(&template.name, content)
146 {
147 warn!(
148 template = %template.name,
149 error = %e,
150 "Failed to compile template"
151 );
152 continue;
153 }
154 self.template_sources
155 .insert(template.name.clone(), provider_id.to_string());
156 self.resolved_templates
157 .insert(template.name.clone(), template);
158 },
159 Err(e) => {
160 warn!(
161 template = %template.name,
162 error = %e,
163 "Failed to load template"
164 );
165 },
166 }
167 }
168
169 self.register_partial_templates().await;
170
171 info!(
172 templates = self.resolved_templates.len(),
173 "Template registry initialized"
174 );
175
176 Ok(())
177 }
178
179 async fn register_partial_templates(&mut self) {
180 for component in &self.components {
181 let Some(partial) = component.partial_template() else {
182 continue;
183 };
184
185 let content = match &partial.source {
186 PartialSource::Embedded(s) => (*s).to_string(),
187 PartialSource::File(path) => match self.load_partial_file(path).await {
188 Ok(c) => c,
189 Err(e) => {
190 warn!(
191 component_id = %component.component_id(),
192 path = %path.display(),
193 error = %e,
194 "Failed to load partial template file"
195 );
196 continue;
197 },
198 },
199 };
200
201 debug!(
202 component_id = %component.component_id(),
203 partial_name = %partial.name,
204 "Registering partial template"
205 );
206
207 if let Err(e) = self
208 .handlebars
209 .register_template_string(&partial.name, content)
210 {
211 warn!(
212 component_id = %component.component_id(),
213 partial_name = %partial.name,
214 error = %e,
215 "Failed to compile partial template"
216 );
217 }
218 }
219 }
220
221 async fn load_partial_file(&self, path: &Path) -> Result<String, TemplateError> {
222 tokio::fs::read_to_string(path)
223 .await
224 .map_err(|e| TemplateError::LoadError {
225 name: path.display().to_string(),
226 source: e.into(),
227 })
228 }
229
230 async fn load_template(
231 &self,
232 definition: &TemplateDefinition,
233 ) -> Result<String, TemplateError> {
234 for loader in &self.loaders {
235 if loader.can_load(&definition.source) {
236 return loader.load(&definition.source).await.map_err(|e| {
237 TemplateError::LoadError {
238 name: definition.name.clone(),
239 source: e.into(),
240 }
241 });
242 }
243 }
244 Err(TemplateError::NoLoader(definition.name.clone()))
245 }
246
247 pub fn render(&self, template_name: &str, data: &Value) -> Result<String, TemplateError> {
248 self.handlebars
249 .render(template_name, data)
250 .map_err(|e| TemplateError::RenderError {
251 name: template_name.to_string(),
252 source: e.into(),
253 })
254 }
255
256 pub fn render_partial(
257 &self,
258 partial_name: &str,
259 data: &Value,
260 ) -> Result<String, TemplateError> {
261 self.handlebars
262 .render(partial_name, data)
263 .map_err(|e| TemplateError::RenderError {
264 name: partial_name.to_string(),
265 source: e.into(),
266 })
267 }
268
269 #[must_use]
270 pub fn has_partial(&self, partial_name: &str) -> bool {
271 self.handlebars.has_template(partial_name)
272 }
273
274 #[must_use]
275 pub fn has_template(&self, name: &str) -> bool {
276 self.resolved_templates.contains_key(name)
277 }
278
    /// Returns the resolved definition for the template named `name`, or
    /// `None` when no template with that name has been resolved.
    #[must_use]
    pub fn find_template(&self, name: &str) -> Option<&TemplateDefinition> {
        self.resolved_templates.get(name)
    }
283
284 #[must_use]
285 pub fn find_template_for_content_type(&self, content_type: &str) -> Option<&str> {
286 let content_type_owned = content_type.to_string();
287 self.resolved_templates
288 .iter()
289 .find(|(_, def)| def.content_types.contains(&content_type_owned))
290 .map(|(name, _)| name.as_str())
291 }
292
293 #[must_use]
294 pub fn extenders_for(&self, content_type: &str) -> Vec<&DynTemplateDataExtender> {
295 let content_type_owned = content_type.to_string();
296 self.extenders
297 .iter()
298 .filter(|e| {
299 let types = e.applies_to();
300 types.is_empty() || types.contains(&content_type_owned)
301 })
302 .collect()
303 }
304
305 #[must_use]
306 pub fn components_for(&self, content_type: &str) -> Vec<&DynComponentRenderer> {
307 let content_type_owned = content_type.to_string();
308 self.components
309 .iter()
310 .filter(|c| {
311 let types = c.applies_to();
312 types.is_empty() || types.contains(&content_type_owned)
313 })
314 .collect()
315 }
316
317 #[must_use]
318 pub fn page_providers_for(&self, page_type: &str) -> Vec<&DynPageDataProvider> {
319 let page_type_owned = page_type.to_string();
320 self.page_providers
321 .iter()
322 .filter(|p| {
323 let pages = p.applies_to_pages();
324 pages.is_empty() || pages.contains(&page_type_owned)
325 })
326 .collect()
327 }
328
329 #[must_use]
330 pub fn page_prerenderers(&self) -> &[DynPagePrerenderer] {
331 &self.page_prerenderers
332 }
333
334 #[must_use]
335 pub fn find_template_provider(&self, name: &str) -> Option<&str> {
336 self.template_sources.get(name).map(String::as_str)
337 }
338
339 #[must_use]
340 pub fn template_names(&self) -> Vec<&str> {
341 self.resolved_templates.keys().map(String::as_str).collect()
342 }
343
344 #[must_use]
345 pub fn available_content_types(&self) -> Vec<String> {
346 self.resolved_templates
347 .values()
348 .flat_map(|def| def.content_types.iter().cloned())
349 .collect()
350 }
351
352 #[must_use]
353 pub fn stats(&self) -> RegistryStats {
354 RegistryStats {
355 providers: self.providers.len(),
356 templates: self.resolved_templates.len(),
357 loaders: self.loaders.len(),
358 extenders: self.extenders.len(),
359 components: self.components.len(),
360 page_providers: self.page_providers.len(),
361 page_prerenderers: self.page_prerenderers.len(),
362 }
363 }
364}
365
/// Counts of the items held by a template registry at one point in time.
///
/// Derives `Default` (all counts zero) and `PartialEq`/`Eq` so snapshots can
/// be compared directly, for example in tests or change detection.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of registered template providers.
    pub providers: usize,
    /// Number of resolved templates.
    pub templates: usize,
    /// Number of registered template loaders.
    pub loaders: usize,
    /// Number of registered template data extenders.
    pub extenders: usize,
    /// Number of registered component renderers.
    pub components: usize,
    /// Number of registered page data providers.
    pub page_providers: usize,
    /// Number of registered page prerenderers.
    pub page_prerenderers: usize,
}
376
377impl std::fmt::Debug for TemplateRegistry {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("TemplateRegistry")
380 .field("providers", &self.providers.len())
381 .field(
382 "templates",
383 &self.resolved_templates.keys().collect::<Vec<_>>(),
384 )
385 .field("loaders", &self.loaders.len())
386 .field("extenders", &self.extenders.len())
387 .field("components", &self.components.len())
388 .field("page_providers", &self.page_providers.len())
389 .field("page_prerenderers", &self.page_prerenderers.len())
390 .finish_non_exhaustive()
391 }
392}