1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use handlebars::Handlebars;
6use handlebars::handlebars_helper;
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9
10use crate::{AgentError, Message, Part};
11
12#[derive(Debug, Clone)]
14pub struct PromptRegistry {
15 templates: Arc<RwLock<HashMap<String, PromptTemplate>>>,
16 partials: Arc<RwLock<HashMap<String, String>>>,
17}
18
/// A named template source together with optional descriptive metadata.
#[derive(Clone, Debug)]
pub struct PromptTemplate {
    /// Name of the template; the registry uses it as the lookup key.
    pub name: String,
    /// Raw template source text.
    pub content: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional version label (the built-in templates use `"1.0.0"`).
    pub version: Option<String>,
}
27
28#[derive(Debug, Clone, Default, Serialize)]
29pub struct TemplateData<'a> {
30 pub description: String,
31 pub instructions: String,
32 pub available_tools: String,
33 pub task: String,
34 pub scratchpad: String,
35 pub dynamic_sections: Vec<PromptSection>,
36 pub dynamic_values: std::collections::HashMap<String, serde_json::Value>,
37 pub session_values: std::collections::HashMap<String, serde_json::Value>,
39 pub reasoning_depth: &'a str,
40 pub execution_mode: &'a str,
41 pub tool_format: &'a str,
42 pub show_examples: bool,
43 pub max_steps: usize,
44 pub current_steps: usize,
45 pub remaining_steps: usize,
46 pub todos: Option<String>,
47 pub json_tools: bool,
48 #[serde(default)]
50 pub available_skills: Option<String>,
51}
52
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct PromptSection {
55 pub key: String,
56 pub content: String,
57}
58
59impl PromptRegistry {
60 pub fn new() -> Self {
61 Self {
62 templates: Arc::new(RwLock::new(HashMap::new())),
63 partials: Arc::new(RwLock::new(HashMap::new())),
64 }
65 }
66
67 pub async fn with_defaults() -> Result<Self, AgentError> {
69 let registry = Self::new();
70 registry.register_static_templates().await?;
71 registry.register_static_partials().await?;
72 Ok(registry)
73 }
74
75 async fn register_static_templates(&self) -> Result<(), AgentError> {
76 let templates = vec![
77 PromptTemplate {
78 name: "planning".to_string(),
79 content: include_str!("../prompt_templates/planning.hbs").to_string(),
80 description: Some("Default system message template".to_string()),
81 version: Some("1.0.0".to_string()),
82 },
83 PromptTemplate {
84 name: "user".to_string(),
85 content: include_str!("../prompt_templates/user.hbs").to_string(),
86 description: Some("Default user message template".to_string()),
87 version: Some("1.0.0".to_string()),
88 },
89 PromptTemplate {
90 name: "code".to_string(),
91 content: include_str!("../prompt_templates/code.hbs").to_string(),
92 description: Some("Code generation template".to_string()),
93 version: Some("1.0.0".to_string()),
94 },
95 PromptTemplate {
96 name: "reflection".to_string(),
97 content: include_str!("../prompt_templates/reflection.hbs").to_string(),
98 description: Some("Reflection and improvement template".to_string()),
99 version: Some("1.0.0".to_string()),
100 },
101 PromptTemplate {
102 name: "standard_user_message".to_string(),
103 content: include_str!("../prompt_templates/user.hbs").to_string(),
104 description: Some("Standard user message template".to_string()),
105 version: Some("1.0.0".to_string()),
106 },
107 ];
108
109 let mut templates_lock = self.templates.write().await;
110 for template in templates {
111 templates_lock.insert(template.name.clone(), template);
112 }
113
114 Ok(())
115 }
116
117 async fn register_static_partials(&self) -> Result<(), AgentError> {
118 let partials = vec![
119 (
120 "core_instructions",
121 include_str!("../prompt_templates/partials/core_instructions.hbs"),
122 ),
123 (
124 "communication",
125 include_str!("../prompt_templates/partials/communication.hbs"),
126 ),
127 (
128 "todo_instructions",
129 include_str!("../prompt_templates/partials/todo_instructions.hbs"),
130 ),
131 (
132 "tools_xml",
133 include_str!("../prompt_templates/partials/tools_xml.hbs"),
134 ),
135 (
136 "tools_json",
137 include_str!("../prompt_templates/partials/tools_json.hbs"),
138 ),
139 (
140 "reasoning",
141 include_str!("../prompt_templates/partials/reasoning.hbs"),
142 ),
143 (
144 "skills",
145 include_str!("../prompt_templates/partials/skills.hbs"),
146 ),
147 (
148 "connections",
149 include_str!("../prompt_templates/partials/connections.hbs"),
150 ),
151 (
152 "sub_agents",
153 include_str!("../prompt_templates/partials/sub_agents.hbs"),
154 ),
155 ];
156
157 let mut partials_lock = self.partials.write().await;
158 for (name, content) in partials {
159 partials_lock.insert(name.to_string(), content.to_string());
160 }
161
162 Ok(())
163 }
164
165 pub async fn register_template(&self, template: PromptTemplate) -> Result<(), AgentError> {
166 let mut templates = self.templates.write().await;
167 templates.insert(template.name.clone(), template);
168 Ok(())
169 }
170
171 pub async fn register_template_string(
172 &self,
173 name: String,
174 content: String,
175 description: Option<String>,
176 version: Option<String>,
177 ) -> Result<(), AgentError> {
178 let template = PromptTemplate {
179 name: name.clone(),
180 content,
181 description,
182 version,
183 };
184 self.register_template(template).await
185 }
186
187 pub fn get_default_templates() -> Vec<crate::stores::NewPromptTemplate> {
188 vec![
189 crate::stores::NewPromptTemplate {
190 name: "planning".to_string(),
191 template: include_str!("../prompt_templates/planning.hbs").to_string(),
192 description: Some("Default system message template".to_string()),
193 version: Some("1.0.0".to_string()),
194 is_system: true,
195 },
196 crate::stores::NewPromptTemplate {
197 name: "user".to_string(),
198 template: include_str!("../prompt_templates/user.hbs").to_string(),
199 description: Some("Default user message template".to_string()),
200 version: Some("1.0.0".to_string()),
201 is_system: true,
202 },
203 crate::stores::NewPromptTemplate {
204 name: "code".to_string(),
205 template: include_str!("../prompt_templates/code.hbs").to_string(),
206 description: Some("Code generation template".to_string()),
207 version: Some("1.0.0".to_string()),
208 is_system: true,
209 },
210 crate::stores::NewPromptTemplate {
211 name: "reflection".to_string(),
212 template: include_str!("../prompt_templates/reflection.hbs").to_string(),
213 description: Some("Reflection and improvement template".to_string()),
214 version: Some("1.0.0".to_string()),
215 is_system: true,
216 },
217 crate::stores::NewPromptTemplate {
218 name: "standard_user_message".to_string(),
219 template: include_str!("../prompt_templates/user.hbs").to_string(),
220 description: Some("Standard user message template".to_string()),
221 version: Some("1.0.0".to_string()),
222 is_system: true,
223 },
224 ]
225 }
226
227 pub async fn register_template_file<P: AsRef<Path>>(
228 &self,
229 name: String,
230 file_path: P,
231 description: Option<String>,
232 version: Option<String>,
233 ) -> Result<(), AgentError> {
234 let path = file_path.as_ref();
235 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
236 AgentError::Planning(format!(
237 "Failed to read template file '{}': {}",
238 path.display(),
239 e
240 ))
241 })?;
242
243 let template = PromptTemplate {
244 name: name.clone(),
245 content,
246 description,
247 version,
248 };
249 self.register_template(template).await
250 }
251
252 pub async fn register_partial(&self, name: String, content: String) -> Result<(), AgentError> {
253 let mut partials = self.partials.write().await;
254 partials.insert(name, content);
255 Ok(())
256 }
257
258 pub async fn register_partial_file<P: AsRef<Path>>(
259 &self,
260 name: String,
261 file_path: P,
262 ) -> Result<(), AgentError> {
263 let path = file_path.as_ref();
264 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
265 AgentError::Planning(format!(
266 "Failed to read partial file '{}': {}",
267 path.display(),
268 e
269 ))
270 })?;
271 self.register_partial(name, content).await
272 }
273
274 pub async fn register_templates_from_directory<P: AsRef<Path>>(
275 &self,
276 dir_path: P,
277 ) -> Result<(), AgentError> {
278 let path = dir_path.as_ref();
279 if !path.exists() {
280 return Ok(());
281 }
282
283 let mut entries = tokio::fs::read_dir(path).await.map_err(|e| {
284 AgentError::Planning(format!(
285 "Failed to read directory '{}': {}",
286 path.display(),
287 e
288 ))
289 })?;
290
291 while let Some(entry) = entries
292 .next_entry()
293 .await
294 .map_err(|e| AgentError::Planning(format!("Failed to read directory entry: {}", e)))?
295 {
296 let entry_path = entry.path();
297 if entry_path.is_file()
298 && let Some(extension) = entry_path.extension()
299 && (extension == "hbs" || extension == "handlebars")
300 && let Some(stem) = entry_path.file_stem() {
301 let name = stem.to_string_lossy().to_string();
302 tracing::debug!(
303 "Registering template '{}' from '{}'",
304 name,
305 entry_path.display()
306 );
307 self.register_template_file(name, &entry_path, None, None)
308 .await?;
309 }
310 }
311
312 Ok(())
313 }
314
315 pub async fn register_partials_from_directory<P: AsRef<Path>>(
316 &self,
317 dir_path: P,
318 ) -> Result<(), AgentError> {
319 let path = dir_path.as_ref();
320 if !path.exists() {
321 return Ok(());
322 }
323
324 let mut entries = tokio::fs::read_dir(path).await.map_err(|e| {
325 AgentError::Planning(format!(
326 "Failed to read directory '{}': {}",
327 path.display(),
328 e
329 ))
330 })?;
331
332 while let Some(entry) = entries
333 .next_entry()
334 .await
335 .map_err(|e| AgentError::Planning(format!("Failed to read directory entry: {}", e)))?
336 {
337 let entry_path = entry.path();
338 if entry_path.is_file()
339 && let Some(extension) = entry_path.extension()
340 && (extension == "hbs" || extension == "handlebars")
341 && let Some(stem) = entry_path.file_stem() {
342 let name = stem.to_string_lossy().to_string();
343 tracing::debug!(
344 "Registering partial '{}' from '{}'",
345 name,
346 entry_path.display()
347 );
348 self.register_partial_file(name, &entry_path).await?;
349 }
350 }
351
352 Ok(())
353 }
354
355 pub async fn get_template(&self, name: &str) -> Option<PromptTemplate> {
356 let templates = self.templates.read().await;
357 templates.get(name).cloned()
358 }
359
360 pub async fn get_partial(&self, name: &str) -> Option<String> {
361 let partials = self.partials.read().await;
362 partials.get(name).cloned()
363 }
364
365 pub async fn list_templates(&self) -> Vec<String> {
366 let templates = self.templates.read().await;
367 templates.keys().cloned().collect()
368 }
369
370 pub async fn list_partials(&self) -> Vec<String> {
371 let partials = self.partials.read().await;
372 partials.keys().cloned().collect()
373 }
374
375 pub async fn get_all_templates(&self) -> HashMap<String, PromptTemplate> {
376 let templates = self.templates.read().await;
377 templates.clone()
378 }
379
380 pub async fn get_all_partials(&self) -> HashMap<String, String> {
381 let partials = self.partials.read().await;
382 partials.clone()
383 }
384
385 pub async fn clear(&self) {
386 let mut templates = self.templates.write().await;
387 let mut partials = self.partials.write().await;
388 templates.clear();
389 partials.clear();
390 }
391
392 pub async fn remove_template(&self, name: &str) -> Option<PromptTemplate> {
393 let mut templates = self.templates.write().await;
394 templates.remove(name)
395 }
396
397 pub async fn remove_partial(&self, name: &str) -> Option<String> {
398 let mut partials = self.partials.write().await;
399 partials.remove(name)
400 }
401
402 pub async fn configure_handlebars(
403 &self,
404 handlebars: &mut handlebars::Handlebars<'_>,
405 ) -> Result<(), AgentError> {
406 handlebars_helper!(eq: |x: str, y: str| x == y);
407 handlebars.register_helper("eq", Box::new(eq));
408 let partials = self.partials.read().await;
409 for (name, content) in partials.iter() {
410 handlebars.register_partial(name, content).map_err(|e| {
411 AgentError::Planning(format!("Failed to register partial '{}': {}", name, e))
412 })?;
413 }
414 Ok(())
415 }
416
417 pub async fn render_template<'a>(
418 &self,
419 template: &str,
420 template_data: &TemplateData<'a>,
421 ) -> Result<String, AgentError> {
422 let mut handlebars = Handlebars::new();
423 handlebars.set_strict_mode(true);
424
425 self.configure_handlebars(&mut handlebars).await?;
426 let rendered = handlebars
427 .render_template(template, &template_data)
428 .map_err(|e| AgentError::Planning(format!("Failed to render template: {}", e)))?;
429 Ok(rendered)
430 }
431
432 pub async fn validate_template(&self, template: &str) -> Result<(), AgentError> {
433 let mut handlebars = Handlebars::new();
434 handlebars.set_strict_mode(true);
435 self.configure_handlebars(&mut handlebars).await?;
436 let sample_template_data = TemplateData::default();
437 handlebars
438 .render_template(template, &sample_template_data)
439 .map(|_| ())
440 .map_err(|e| AgentError::Planning(format!("Failed to render template: {}", e)))
441 }
442}
443
444impl Default for PromptRegistry {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
450pub async fn build_prompt_messages<'a>(
452 registry: &PromptRegistry,
453 system_template: &str,
454 user_template: &str,
455 template_data: &TemplateData<'a>,
456 user_message: &Message,
457) -> Result<Vec<Message>, AgentError> {
458 let rendered_system = registry
459 .render_template(system_template, template_data)
460 .await?;
461 let rendered_user = registry
462 .render_template(user_template, template_data)
463 .await?;
464
465 let system_msg = Message::system(rendered_system, None);
466
467 let mut user_msg = user_message.clone();
468 if user_msg.parts.is_empty()
469 && let Some(text) = user_message.as_text() {
470 user_msg.parts.push(Part::Text(text));
471 }
472 if !rendered_user.is_empty() {
473 user_msg.parts.push(Part::Text(rendered_user));
474 }
475
476 Ok(vec![system_msg, user_msg])
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders inline system and user templates through a default registry and
    /// checks that both resulting messages carry the rendered text.
    #[tokio::test]
    async fn renders_templates_and_messages() {
        let registry = PromptRegistry::with_defaults().await.unwrap();
        let data = TemplateData {
            description: "desc".into(),
            instructions: "be nice".into(),
            available_tools: "none".into(),
            task: "task".into(),
            scratchpad: String::new(),
            dynamic_sections: vec![],
            dynamic_values: HashMap::new(),
            session_values: HashMap::new(),
            reasoning_depth: "standard",
            execution_mode: "tools",
            tool_format: "json",
            show_examples: false,
            max_steps: 5,
            current_steps: 0,
            remaining_steps: 5,
            todos: None,
            json_tools: true,
            available_skills: None,
        };
        let msgs = build_prompt_messages(
            &registry,
            "{{instructions}}",
            "task: {{task}}",
            &data,
            &Message::user("hello".into(), None),
        )
        .await
        .unwrap();
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].as_text().unwrap().contains("be nice"));
        assert!(msgs[1].as_text().unwrap().contains("task"));
    }
}