Skip to main content

systemprompt_templates/
registry.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use handlebars::Handlebars;
5use serde_json::Value;
6use systemprompt_template_provider::{
7    DynComponentRenderer, DynPageDataProvider, DynPagePrerenderer, DynTemplateDataExtender,
8    DynTemplateLoader, DynTemplateProvider, PartialSource, TemplateDefinition,
9};
10use tracing::{debug, info, warn};
11
12use crate::TemplateError;
13
14pub struct TemplateRegistry {
15    providers: Vec<DynTemplateProvider>,
16    loaders: Vec<DynTemplateLoader>,
17    extenders: Vec<DynTemplateDataExtender>,
18    components: Vec<DynComponentRenderer>,
19    page_providers: Vec<DynPageDataProvider>,
20    page_prerenderers: Vec<DynPagePrerenderer>,
21    resolved_templates: HashMap<String, TemplateDefinition>,
22    handlebars: Handlebars<'static>,
23    template_sources: HashMap<String, String>,
24}
25
26impl Default for TemplateRegistry {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl TemplateRegistry {
33    #[must_use]
34    pub fn new() -> Self {
35        Self {
36            providers: Vec::new(),
37            loaders: Vec::new(),
38            extenders: Vec::new(),
39            components: Vec::new(),
40            page_providers: Vec::new(),
41            page_prerenderers: Vec::new(),
42            resolved_templates: HashMap::new(),
43            handlebars: Handlebars::new(),
44            template_sources: HashMap::new(),
45        }
46    }
47
48    pub fn register_provider(&mut self, provider: DynTemplateProvider) {
49        debug!(
50            provider_id = %provider.provider_id(),
51            priority = provider.priority(),
52            "Registering template provider"
53        );
54        self.providers.push(provider);
55        self.providers.sort_by_key(|p| p.priority());
56    }
57
58    pub fn register_loader(&mut self, loader: DynTemplateLoader) {
59        self.loaders.push(loader);
60    }
61
62    pub fn register_extender(&mut self, extender: DynTemplateDataExtender) {
63        debug!(
64            extender_id = %extender.extender_id(),
65            priority = extender.priority(),
66            "Registering template data extender"
67        );
68        self.extenders.push(extender);
69        self.extenders.sort_by_key(|e| e.priority());
70    }
71
72    pub fn register_component(&mut self, component: DynComponentRenderer) {
73        debug!(
74            component_id = %component.component_id(),
75            variable_name = %component.variable_name(),
76            priority = component.priority(),
77            "Registering component renderer"
78        );
79        self.components.push(component);
80        self.components.sort_by_key(|c| c.priority());
81    }
82
83    pub fn register_page_provider(&mut self, provider: DynPageDataProvider) {
84        debug!(
85            provider_id = %provider.provider_id(),
86            pages = ?provider.applies_to_pages(),
87            "Registering page data provider"
88        );
89        self.page_providers.push(provider);
90        self.page_providers.sort_by_key(|p| p.priority());
91    }
92
93    pub fn register_page_prerenderer(&mut self, prerenderer: DynPagePrerenderer) {
94        debug!(
95            page_type = %prerenderer.page_type(),
96            priority = prerenderer.priority(),
97            "Registering page prerenderer"
98        );
99        self.page_prerenderers.push(prerenderer);
100        self.page_prerenderers.sort_by_key(|p| p.priority());
101    }
102
103    pub async fn initialize(&mut self) -> Result<(), TemplateError> {
104        info!(
105            providers = self.providers.len(),
106            loaders = self.loaders.len(),
107            "Initializing template registry"
108        );
109
110        if self.loaders.is_empty() {
111            return Err(TemplateError::NotInitialized);
112        }
113
114        let mut all_templates: Vec<(TemplateDefinition, &str)> = Vec::new();
115
116        for provider in &self.providers {
117            for template in provider.templates() {
118                all_templates.push((template, provider.provider_id()));
119            }
120        }
121
122        all_templates.sort_by(|a, b| a.0.priority.cmp(&b.0.priority));
123
124        for (template, provider_id) in all_templates {
125            if self.resolved_templates.contains_key(&template.name) {
126                debug!(
127                    template = %template.name,
128                    provider = %provider_id,
129                    "Template already registered, skipping"
130                );
131                continue;
132            }
133
134            debug!(
135                template = %template.name,
136                provider = %provider_id,
137                priority = template.priority,
138                "Registering template"
139            );
140
141            match self.load_template(&template).await {
142                Ok(content) => {
143                    if let Err(e) = self
144                        .handlebars
145                        .register_template_string(&template.name, content)
146                    {
147                        warn!(
148                            template = %template.name,
149                            error = %e,
150                            "Failed to compile template"
151                        );
152                        continue;
153                    }
154                    self.template_sources
155                        .insert(template.name.clone(), provider_id.to_string());
156                    self.resolved_templates
157                        .insert(template.name.clone(), template);
158                },
159                Err(e) => {
160                    warn!(
161                        template = %template.name,
162                        error = %e,
163                        "Failed to load template"
164                    );
165                },
166            }
167        }
168
169        self.register_partial_templates().await;
170
171        info!(
172            templates = self.resolved_templates.len(),
173            "Template registry initialized"
174        );
175
176        Ok(())
177    }
178
179    async fn register_partial_templates(&mut self) {
180        for component in &self.components {
181            let Some(partial) = component.partial_template() else {
182                continue;
183            };
184
185            let content = match &partial.source {
186                PartialSource::Embedded(s) => (*s).to_string(),
187                PartialSource::File(path) => match self.load_partial_file(path).await {
188                    Ok(c) => c,
189                    Err(e) => {
190                        warn!(
191                            component_id = %component.component_id(),
192                            path = %path.display(),
193                            error = %e,
194                            "Failed to load partial template file"
195                        );
196                        continue;
197                    },
198                },
199            };
200
201            debug!(
202                component_id = %component.component_id(),
203                partial_name = %partial.name,
204                "Registering partial template"
205            );
206
207            if let Err(e) = self
208                .handlebars
209                .register_template_string(&partial.name, content)
210            {
211                warn!(
212                    component_id = %component.component_id(),
213                    partial_name = %partial.name,
214                    error = %e,
215                    "Failed to compile partial template"
216                );
217            }
218        }
219    }
220
221    async fn load_partial_file(&self, path: &Path) -> Result<String, TemplateError> {
222        tokio::fs::read_to_string(path)
223            .await
224            .map_err(|e| TemplateError::LoadError {
225                name: path.display().to_string(),
226                source: e.into(),
227            })
228    }
229
230    async fn load_template(
231        &self,
232        definition: &TemplateDefinition,
233    ) -> Result<String, TemplateError> {
234        for loader in &self.loaders {
235            if loader.can_load(&definition.source) {
236                return loader.load(&definition.source).await.map_err(|e| {
237                    TemplateError::LoadError {
238                        name: definition.name.clone(),
239                        source: e.into(),
240                    }
241                });
242            }
243        }
244        Err(TemplateError::NoLoader(definition.name.clone()))
245    }
246
247    pub fn render(&self, template_name: &str, data: &Value) -> Result<String, TemplateError> {
248        self.handlebars
249            .render(template_name, data)
250            .map_err(|e| TemplateError::RenderError {
251                name: template_name.to_string(),
252                source: e.into(),
253            })
254    }
255
256    pub fn render_partial(
257        &self,
258        partial_name: &str,
259        data: &Value,
260    ) -> Result<String, TemplateError> {
261        self.handlebars
262            .render(partial_name, data)
263            .map_err(|e| TemplateError::RenderError {
264                name: partial_name.to_string(),
265                source: e.into(),
266            })
267    }
268
269    #[must_use]
270    pub fn has_partial(&self, partial_name: &str) -> bool {
271        self.handlebars.has_template(partial_name)
272    }
273
274    #[must_use]
275    pub fn has_template(&self, name: &str) -> bool {
276        self.resolved_templates.contains_key(name)
277    }
278
279    #[must_use]
280    pub fn find_template(&self, name: &str) -> Option<&TemplateDefinition> {
281        self.resolved_templates.get(name)
282    }
283
284    #[must_use]
285    pub fn find_template_for_content_type(&self, content_type: &str) -> Option<&str> {
286        let content_type_owned = content_type.to_string();
287        self.resolved_templates
288            .iter()
289            .find(|(_, def)| def.content_types.contains(&content_type_owned))
290            .map(|(name, _)| name.as_str())
291    }
292
293    #[must_use]
294    pub fn extenders_for(&self, content_type: &str) -> Vec<&DynTemplateDataExtender> {
295        let content_type_owned = content_type.to_string();
296        self.extenders
297            .iter()
298            .filter(|e| {
299                let types = e.applies_to();
300                types.is_empty() || types.contains(&content_type_owned)
301            })
302            .collect()
303    }
304
305    #[must_use]
306    pub fn components_for(&self, content_type: &str) -> Vec<&DynComponentRenderer> {
307        let content_type_owned = content_type.to_string();
308        self.components
309            .iter()
310            .filter(|c| {
311                let types = c.applies_to();
312                types.is_empty() || types.contains(&content_type_owned)
313            })
314            .collect()
315    }
316
317    #[must_use]
318    pub fn page_providers_for(&self, page_type: &str) -> Vec<&DynPageDataProvider> {
319        let page_type_owned = page_type.to_string();
320        self.page_providers
321            .iter()
322            .filter(|p| {
323                let pages = p.applies_to_pages();
324                pages.is_empty() || pages.contains(&page_type_owned)
325            })
326            .collect()
327    }
328
329    #[must_use]
330    pub fn page_prerenderers(&self) -> &[DynPagePrerenderer] {
331        &self.page_prerenderers
332    }
333
334    #[must_use]
335    pub fn find_template_provider(&self, name: &str) -> Option<&str> {
336        self.template_sources.get(name).map(String::as_str)
337    }
338
339    #[must_use]
340    pub fn template_names(&self) -> Vec<&str> {
341        self.resolved_templates.keys().map(String::as_str).collect()
342    }
343
344    #[must_use]
345    pub fn available_content_types(&self) -> Vec<String> {
346        self.resolved_templates
347            .values()
348            .flat_map(|def| def.content_types.iter().cloned())
349            .collect()
350    }
351
352    #[must_use]
353    pub fn stats(&self) -> RegistryStats {
354        RegistryStats {
355            providers: self.providers.len(),
356            templates: self.resolved_templates.len(),
357            loaders: self.loaders.len(),
358            extenders: self.extenders.len(),
359            components: self.components.len(),
360            page_providers: self.page_providers.len(),
361            page_prerenderers: self.page_prerenderers.len(),
362        }
363    }
364}
365
/// Point-in-time counts of what a `TemplateRegistry` holds, as returned by
/// `TemplateRegistry::stats`.
///
/// `PartialEq`/`Eq` allow comparing snapshots (e.g. before and after
/// initialization). `Default` yields an all-zero snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct RegistryStats {
    /// Number of registered template providers.
    pub providers: usize,
    /// Number of templates resolved (loaded and compiled) by `initialize`.
    pub templates: usize,
    /// Number of registered template loaders.
    pub loaders: usize,
    /// Number of registered template data extenders.
    pub extenders: usize,
    /// Number of registered component renderers.
    pub components: usize,
    /// Number of registered page data providers.
    pub page_providers: usize,
    /// Number of registered page prerenderers.
    pub page_prerenderers: usize,
}
376
377impl std::fmt::Debug for TemplateRegistry {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        f.debug_struct("TemplateRegistry")
380            .field("providers", &self.providers.len())
381            .field(
382                "templates",
383                &self.resolved_templates.keys().collect::<Vec<_>>(),
384            )
385            .field("loaders", &self.loaders.len())
386            .field("extenders", &self.extenders.len())
387            .field("components", &self.components.len())
388            .field("page_providers", &self.page_providers.len())
389            .field("page_prerenderers", &self.page_prerenderers.len())
390            .finish_non_exhaustive()
391    }
392}