Skip to main content

distri_types/
prompt.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use handlebars::Handlebars;
6use handlebars::handlebars_helper;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
10use crate::{AgentError, Message, Part};
11
/// A registry for prompt templates that can be used across the system.
///
/// Holds two name-keyed maps: full templates and partials. Each map sits
/// behind an `Arc<RwLock<..>>`, so cloning the registry yields a handle to the
/// same underlying maps rather than an independent copy of their contents.
#[derive(Debug, Clone)]
pub struct PromptRegistry {
    /// Registered templates, keyed by template name.
    templates: Arc<RwLock<HashMap<String, PromptTemplate>>>,
    /// Registered partial sources, keyed by partial name.
    partials: Arc<RwLock<HashMap<String, String>>>,
}
18
/// A prompt template with metadata.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Name of the template; the registry uses it as the lookup key.
    pub name: String,
    /// Template source text (the built-in ones are loaded from `.hbs` files).
    pub content: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional version label (the built-in templates use `"1.0.0"`).
    pub version: Option<String>,
}
27
28#[derive(Debug, Clone, Default, Serialize)]
29pub struct TemplateData<'a> {
30    pub description: String,
31    pub instructions: String,
32    pub available_tools: String,
33    pub task: String,
34    pub scratchpad: String,
35    pub dynamic_sections: Vec<PromptSection>,
36    pub dynamic_values: std::collections::HashMap<String, serde_json::Value>,
37    /// Session values fetched from the session store - available in templates as {{session.key}}
38    pub session_values: std::collections::HashMap<String, serde_json::Value>,
39    pub reasoning_depth: &'a str,
40    pub execution_mode: &'a str,
41    pub tool_format: &'a str,
42    pub show_examples: bool,
43    pub max_steps: usize,
44    pub current_steps: usize,
45    pub remaining_steps: usize,
46    pub todos: Option<String>,
47    pub json_tools: bool,
48    /// Formatted list of available skills the agent can load on demand
49    #[serde(default)]
50    pub available_skills: Option<String>,
51}
52
/// A keyed piece of prompt text; `TemplateData::dynamic_sections` carries a
/// list of these.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PromptSection {
    /// Identifier of the section.
    pub key: String,
    /// Text of the section.
    pub content: String,
}
58
59impl PromptRegistry {
60    pub fn new() -> Self {
61        Self {
62            templates: Arc::new(RwLock::new(HashMap::new())),
63            partials: Arc::new(RwLock::new(HashMap::new())),
64        }
65    }
66
67    /// Create a registry preloaded with the built-in templates/partials.
68    pub async fn with_defaults() -> Result<Self, AgentError> {
69        let registry = Self::new();
70        registry.register_static_templates().await?;
71        registry.register_static_partials().await?;
72        Ok(registry)
73    }
74
75    async fn register_static_templates(&self) -> Result<(), AgentError> {
76        let templates = vec![
77            PromptTemplate {
78                name: "planning".to_string(),
79                content: include_str!("../prompt_templates/planning.hbs").to_string(),
80                description: Some("Default system message template".to_string()),
81                version: Some("1.0.0".to_string()),
82            },
83            PromptTemplate {
84                name: "user".to_string(),
85                content: include_str!("../prompt_templates/user.hbs").to_string(),
86                description: Some("Default user message template".to_string()),
87                version: Some("1.0.0".to_string()),
88            },
89            PromptTemplate {
90                name: "code".to_string(),
91                content: include_str!("../prompt_templates/code.hbs").to_string(),
92                description: Some("Code generation template".to_string()),
93                version: Some("1.0.0".to_string()),
94            },
95            PromptTemplate {
96                name: "reflection".to_string(),
97                content: include_str!("../prompt_templates/reflection.hbs").to_string(),
98                description: Some("Reflection and improvement template".to_string()),
99                version: Some("1.0.0".to_string()),
100            },
101            PromptTemplate {
102                name: "standard_user_message".to_string(),
103                content: include_str!("../prompt_templates/user.hbs").to_string(),
104                description: Some("Standard user message template".to_string()),
105                version: Some("1.0.0".to_string()),
106            },
107        ];
108
109        let mut templates_lock = self.templates.write().await;
110        for template in templates {
111            templates_lock.insert(template.name.clone(), template);
112        }
113
114        Ok(())
115    }
116
117    async fn register_static_partials(&self) -> Result<(), AgentError> {
118        let partials = vec![
119            (
120                "core_instructions",
121                include_str!("../prompt_templates/partials/core_instructions.hbs"),
122            ),
123            (
124                "communication",
125                include_str!("../prompt_templates/partials/communication.hbs"),
126            ),
127            (
128                "todo_instructions",
129                include_str!("../prompt_templates/partials/todo_instructions.hbs"),
130            ),
131            (
132                "tools_xml",
133                include_str!("../prompt_templates/partials/tools_xml.hbs"),
134            ),
135            (
136                "tools_json",
137                include_str!("../prompt_templates/partials/tools_json.hbs"),
138            ),
139            (
140                "reasoning",
141                include_str!("../prompt_templates/partials/reasoning.hbs"),
142            ),
143            (
144                "skills",
145                include_str!("../prompt_templates/partials/skills.hbs"),
146            ),
147            (
148                "connections",
149                include_str!("../prompt_templates/partials/connections.hbs"),
150            ),
151            (
152                "sub_agents",
153                include_str!("../prompt_templates/partials/sub_agents.hbs"),
154            ),
155        ];
156
157        let mut partials_lock = self.partials.write().await;
158        for (name, content) in partials {
159            partials_lock.insert(name.to_string(), content.to_string());
160        }
161
162        Ok(())
163    }
164
165    pub async fn register_template(&self, template: PromptTemplate) -> Result<(), AgentError> {
166        let mut templates = self.templates.write().await;
167        templates.insert(template.name.clone(), template);
168        Ok(())
169    }
170
171    pub async fn register_template_string(
172        &self,
173        name: String,
174        content: String,
175        description: Option<String>,
176        version: Option<String>,
177    ) -> Result<(), AgentError> {
178        let template = PromptTemplate {
179            name: name.clone(),
180            content,
181            description,
182            version,
183        };
184        self.register_template(template).await
185    }
186
187    pub fn get_default_templates() -> Vec<crate::stores::NewPromptTemplate> {
188        vec![
189            crate::stores::NewPromptTemplate {
190                name: "planning".to_string(),
191                template: include_str!("../prompt_templates/planning.hbs").to_string(),
192                description: Some("Default system message template".to_string()),
193                version: Some("1.0.0".to_string()),
194                is_system: true,
195            },
196            crate::stores::NewPromptTemplate {
197                name: "user".to_string(),
198                template: include_str!("../prompt_templates/user.hbs").to_string(),
199                description: Some("Default user message template".to_string()),
200                version: Some("1.0.0".to_string()),
201                is_system: true,
202            },
203            crate::stores::NewPromptTemplate {
204                name: "code".to_string(),
205                template: include_str!("../prompt_templates/code.hbs").to_string(),
206                description: Some("Code generation template".to_string()),
207                version: Some("1.0.0".to_string()),
208                is_system: true,
209            },
210            crate::stores::NewPromptTemplate {
211                name: "reflection".to_string(),
212                template: include_str!("../prompt_templates/reflection.hbs").to_string(),
213                description: Some("Reflection and improvement template".to_string()),
214                version: Some("1.0.0".to_string()),
215                is_system: true,
216            },
217            crate::stores::NewPromptTemplate {
218                name: "standard_user_message".to_string(),
219                template: include_str!("../prompt_templates/user.hbs").to_string(),
220                description: Some("Standard user message template".to_string()),
221                version: Some("1.0.0".to_string()),
222                is_system: true,
223            },
224        ]
225    }
226
227    pub async fn register_template_file<P: AsRef<Path>>(
228        &self,
229        name: String,
230        file_path: P,
231        description: Option<String>,
232        version: Option<String>,
233    ) -> Result<(), AgentError> {
234        let path = file_path.as_ref();
235        let content = tokio::fs::read_to_string(path).await.map_err(|e| {
236            AgentError::Planning(format!(
237                "Failed to read template file '{}': {}",
238                path.display(),
239                e
240            ))
241        })?;
242
243        let template = PromptTemplate {
244            name: name.clone(),
245            content,
246            description,
247            version,
248        };
249        self.register_template(template).await
250    }
251
252    pub async fn register_partial(&self, name: String, content: String) -> Result<(), AgentError> {
253        let mut partials = self.partials.write().await;
254        partials.insert(name, content);
255        Ok(())
256    }
257
258    pub async fn register_partial_file<P: AsRef<Path>>(
259        &self,
260        name: String,
261        file_path: P,
262    ) -> Result<(), AgentError> {
263        let path = file_path.as_ref();
264        let content = tokio::fs::read_to_string(path).await.map_err(|e| {
265            AgentError::Planning(format!(
266                "Failed to read partial file '{}': {}",
267                path.display(),
268                e
269            ))
270        })?;
271        self.register_partial(name, content).await
272    }
273
274    pub async fn register_templates_from_directory<P: AsRef<Path>>(
275        &self,
276        dir_path: P,
277    ) -> Result<(), AgentError> {
278        let path = dir_path.as_ref();
279        if !path.exists() {
280            return Ok(());
281        }
282
283        let mut entries = tokio::fs::read_dir(path).await.map_err(|e| {
284            AgentError::Planning(format!(
285                "Failed to read directory '{}': {}",
286                path.display(),
287                e
288            ))
289        })?;
290
291        while let Some(entry) = entries
292            .next_entry()
293            .await
294            .map_err(|e| AgentError::Planning(format!("Failed to read directory entry: {}", e)))?
295        {
296            let entry_path = entry.path();
297            if entry_path.is_file()
298                && let Some(extension) = entry_path.extension()
299                && (extension == "hbs" || extension == "handlebars")
300                && let Some(stem) = entry_path.file_stem() {
301                    let name = stem.to_string_lossy().to_string();
302                    tracing::debug!(
303                        "Registering template '{}' from '{}'",
304                        name,
305                        entry_path.display()
306                    );
307                    self.register_template_file(name, &entry_path, None, None)
308                        .await?;
309                }
310        }
311
312        Ok(())
313    }
314
315    pub async fn register_partials_from_directory<P: AsRef<Path>>(
316        &self,
317        dir_path: P,
318    ) -> Result<(), AgentError> {
319        let path = dir_path.as_ref();
320        if !path.exists() {
321            return Ok(());
322        }
323
324        let mut entries = tokio::fs::read_dir(path).await.map_err(|e| {
325            AgentError::Planning(format!(
326                "Failed to read directory '{}': {}",
327                path.display(),
328                e
329            ))
330        })?;
331
332        while let Some(entry) = entries
333            .next_entry()
334            .await
335            .map_err(|e| AgentError::Planning(format!("Failed to read directory entry: {}", e)))?
336        {
337            let entry_path = entry.path();
338            if entry_path.is_file()
339                && let Some(extension) = entry_path.extension()
340                && (extension == "hbs" || extension == "handlebars")
341                && let Some(stem) = entry_path.file_stem() {
342                    let name = stem.to_string_lossy().to_string();
343                    tracing::debug!(
344                        "Registering partial '{}' from '{}'",
345                        name,
346                        entry_path.display()
347                    );
348                    self.register_partial_file(name, &entry_path).await?;
349                }
350        }
351
352        Ok(())
353    }
354
355    pub async fn get_template(&self, name: &str) -> Option<PromptTemplate> {
356        let templates = self.templates.read().await;
357        templates.get(name).cloned()
358    }
359
360    pub async fn get_partial(&self, name: &str) -> Option<String> {
361        let partials = self.partials.read().await;
362        partials.get(name).cloned()
363    }
364
365    pub async fn list_templates(&self) -> Vec<String> {
366        let templates = self.templates.read().await;
367        templates.keys().cloned().collect()
368    }
369
370    pub async fn list_partials(&self) -> Vec<String> {
371        let partials = self.partials.read().await;
372        partials.keys().cloned().collect()
373    }
374
375    pub async fn get_all_templates(&self) -> HashMap<String, PromptTemplate> {
376        let templates = self.templates.read().await;
377        templates.clone()
378    }
379
380    pub async fn get_all_partials(&self) -> HashMap<String, String> {
381        let partials = self.partials.read().await;
382        partials.clone()
383    }
384
385    pub async fn clear(&self) {
386        let mut templates = self.templates.write().await;
387        let mut partials = self.partials.write().await;
388        templates.clear();
389        partials.clear();
390    }
391
392    pub async fn remove_template(&self, name: &str) -> Option<PromptTemplate> {
393        let mut templates = self.templates.write().await;
394        templates.remove(name)
395    }
396
397    pub async fn remove_partial(&self, name: &str) -> Option<String> {
398        let mut partials = self.partials.write().await;
399        partials.remove(name)
400    }
401
402    pub async fn configure_handlebars(
403        &self,
404        handlebars: &mut handlebars::Handlebars<'_>,
405    ) -> Result<(), AgentError> {
406        handlebars_helper!(eq: |x: str, y: str| x == y);
407        handlebars.register_helper("eq", Box::new(eq));
408        let partials = self.partials.read().await;
409        for (name, content) in partials.iter() {
410            handlebars.register_partial(name, content).map_err(|e| {
411                AgentError::Planning(format!("Failed to register partial '{}': {}", name, e))
412            })?;
413        }
414        Ok(())
415    }
416
417    pub async fn render_template<'a>(
418        &self,
419        template: &str,
420        template_data: &TemplateData<'a>,
421    ) -> Result<String, AgentError> {
422        let mut handlebars = Handlebars::new();
423        handlebars.set_strict_mode(true);
424
425        self.configure_handlebars(&mut handlebars).await?;
426        let rendered = handlebars
427            .render_template(template, &template_data)
428            .map_err(|e| AgentError::Planning(format!("Failed to render template: {}", e)))?;
429        Ok(rendered)
430    }
431
432    pub async fn validate_template(&self, template: &str) -> Result<(), AgentError> {
433        let mut handlebars = Handlebars::new();
434        handlebars.set_strict_mode(true);
435        self.configure_handlebars(&mut handlebars).await?;
436        let sample_template_data = TemplateData::default();
437        handlebars
438            .render_template(template, &sample_template_data)
439            .map(|_| ())
440            .map_err(|e| AgentError::Planning(format!("Failed to render template: {}", e)))
441    }
442}
443
444impl Default for PromptRegistry {
445    fn default() -> Self {
446        Self::new()
447    }
448}
449
450/// Render a system/user prompt pair into model-ready messages.
451pub async fn build_prompt_messages<'a>(
452    registry: &PromptRegistry,
453    system_template: &str,
454    user_template: &str,
455    template_data: &TemplateData<'a>,
456    user_message: &Message,
457) -> Result<Vec<Message>, AgentError> {
458    let rendered_system = registry
459        .render_template(system_template, template_data)
460        .await?;
461    let rendered_user = registry
462        .render_template(user_template, template_data)
463        .await?;
464
465    let system_msg = Message::system(rendered_system, None);
466
467    let mut user_msg = user_message.clone();
468    if user_msg.parts.is_empty()
469        && let Some(text) = user_message.as_text() {
470            user_msg.parts.push(Part::Text(text));
471        }
472    if !rendered_user.is_empty() {
473        user_msg.parts.push(Part::Text(rendered_user));
474    }
475
476    Ok(vec![system_msg, user_msg])
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders inline system/user templates with the default registry and
    /// checks that both resulting messages carry the expected text.
    #[tokio::test]
    async fn renders_templates_and_messages() {
        let registry = PromptRegistry::with_defaults().await.unwrap();
        // Fields not listed here keep their `Default` values (empty / zero /
        // `false` / `None`).
        let data = TemplateData {
            description: "desc".into(),
            instructions: "be nice".into(),
            available_tools: "none".into(),
            task: "task".into(),
            reasoning_depth: "standard",
            execution_mode: "tools",
            tool_format: "json",
            max_steps: 5,
            remaining_steps: 5,
            json_tools: true,
            ..TemplateData::default()
        };
        let incoming = Message::user("hello".into(), None);

        let msgs = build_prompt_messages(
            &registry,
            "{{instructions}}",
            "task: {{task}}",
            &data,
            &incoming,
        )
        .await
        .unwrap();

        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].as_text().unwrap().contains("be nice"));
        assert!(msgs[1].as_text().unwrap().contains("task"));
    }
}
519}