1use std::collections::HashMap;
34use std::time::{Duration, Instant};
35
36use parking_lot::RwLock;
37use serde::{Deserialize, Serialize};
38use tracing::{debug, info};
39
/// Reusable blueprint describing how an agent should be configured.
///
/// A template is registered with an [`AgentFactory`] under its `name` and is
/// combined with an [`AgentConfig`] to build an [`AgentInstance`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentTemplate {
    /// Template name; the factory uses it as the lookup key.
    pub name: String,
    /// Base system prompt used when the per-instance config does not override it.
    pub system_prompt: String,
    /// Names of the tools granted to instances built from this template.
    pub tools: Vec<String>,
    /// Model to use; when `None` the factory falls back to its default model.
    pub model: Option<String>,
    /// Sampling temperature; when `None` the factory falls back to its default.
    pub temperature: Option<f32>,
    /// Optional token limit, copied unchanged onto each instance.
    pub max_tokens: Option<usize>,
    /// Stop sequences, copied unchanged onto each instance.
    pub stop_sequences: Vec<String>,
    /// Capability labels, copied onto each instance and searchable through the factory.
    pub capabilities: Vec<String>,
    /// Free-form key/value settings; they seed the metadata of each instance.
    pub config: HashMap<String, String>,
    /// Free-form description of the template.
    pub description: String,
}
64
65impl AgentTemplate {
66 pub fn new(name: impl Into<String>) -> Self {
67 Self {
68 name: name.into(),
69 system_prompt: String::new(),
70 tools: Vec::new(),
71 model: None,
72 temperature: None,
73 max_tokens: None,
74 stop_sequences: Vec::new(),
75 capabilities: Vec::new(),
76 config: HashMap::new(),
77 description: String::new(),
78 }
79 }
80
81 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
82 self.system_prompt = prompt.into();
83 self
84 }
85
86 pub fn tool(mut self, tool_name: impl Into<String>) -> Self {
87 self.tools.push(tool_name.into());
88 self
89 }
90
91 pub fn tools<I, S>(mut self, tools: I) -> Self
92 where
93 I: IntoIterator<Item = S>,
94 S: Into<String>,
95 {
96 self.tools.extend(tools.into_iter().map(|s| s.into()));
97 self
98 }
99
100 pub fn model(mut self, model: impl Into<String>) -> Self {
101 self.model = Some(model.into());
102 self
103 }
104
105 pub fn temperature(mut self, temp: f32) -> Self {
106 self.temperature = Some(temp.clamp(0.0, 2.0));
107 self
108 }
109
110 pub fn max_tokens(mut self, max: usize) -> Self {
111 self.max_tokens = Some(max);
112 self
113 }
114
115 pub fn stop_sequence(mut self, seq: impl Into<String>) -> Self {
116 self.stop_sequences.push(seq.into());
117 self
118 }
119
120 pub fn capability(mut self, cap: impl Into<String>) -> Self {
121 self.capabilities.push(cap.into());
122 self
123 }
124
125 pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
126 self.config.insert(key.into(), value.into());
127 self
128 }
129
130 pub fn description(mut self, desc: impl Into<String>) -> Self {
131 self.description = desc.into();
132 self
133 }
134
135 pub fn extend(mut self, other: &AgentTemplate) -> Self {
137 if self.system_prompt.is_empty() {
138 self.system_prompt = other.system_prompt.clone();
139 }
140 self.tools.extend(other.tools.clone());
141 if self.model.is_none() {
142 self.model = other.model.clone();
143 }
144 if self.temperature.is_none() {
145 self.temperature = other.temperature;
146 }
147 if self.max_tokens.is_none() {
148 self.max_tokens = other.max_tokens;
149 }
150 self.stop_sequences.extend(other.stop_sequences.clone());
151 self.capabilities.extend(other.capabilities.clone());
152 for (k, v) in &other.config {
153 self.config.entry(k.clone()).or_insert_with(|| v.clone());
154 }
155 self
156 }
157}
158
/// Per-instance adjustments applied on top of an [`AgentTemplate`] when an
/// instance is built.
#[derive(Debug, Clone, Default)]
pub struct AgentConfig {
    /// Replaces the template's system prompt entirely; when set, `context_prefix` is not used.
    pub system_prompt_override: Option<String>,
    /// Text placed before the template's system prompt, separated by a blank line.
    pub context_prefix: Option<String>,
    /// Model that takes priority over the template's model and the factory default.
    pub model_override: Option<String>,
    /// Temperature that takes priority over the template's value and the factory default.
    pub temperature_override: Option<f32>,
    /// Tool names appended after the template's tools.
    pub additional_tools: Vec<String>,
    /// Template tool names to leave out; entries in `additional_tools` are not filtered.
    pub excluded_tools: Vec<String>,
    /// Session identifier copied onto the instance.
    pub session_id: Option<String>,
    /// User identifier copied onto the instance.
    pub user_id: Option<String>,
    /// Extra metadata merged over the template's `config`; these entries win on key clashes.
    pub metadata: HashMap<String, String>,
}
181
182impl AgentConfig {
183 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn override_system_prompt(mut self, prompt: impl Into<String>) -> Self {
188 self.system_prompt_override = Some(prompt.into());
189 self
190 }
191
192 pub fn context_prefix(mut self, prefix: impl Into<String>) -> Self {
193 self.context_prefix = Some(prefix.into());
194 self
195 }
196
197 pub fn override_model(mut self, model: impl Into<String>) -> Self {
198 self.model_override = Some(model.into());
199 self
200 }
201
202 pub fn override_temperature(mut self, temp: f32) -> Self {
203 self.temperature_override = Some(temp);
204 self
205 }
206
207 pub fn add_tool(mut self, tool: impl Into<String>) -> Self {
208 self.additional_tools.push(tool.into());
209 self
210 }
211
212 pub fn exclude_tool(mut self, tool: impl Into<String>) -> Self {
213 self.excluded_tools.push(tool.into());
214 self
215 }
216
217 pub fn session_id(mut self, id: impl Into<String>) -> Self {
218 self.session_id = Some(id.into());
219 self
220 }
221
222 pub fn user_id(mut self, id: impl Into<String>) -> Self {
223 self.user_id = Some(id.into());
224 self
225 }
226
227 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
228 self.metadata.insert(key.into(), value.into());
229 self
230 }
231}
232
/// Fully resolved agent configuration produced by an [`AgentFactory`] from a
/// template plus per-instance config.
#[derive(Debug, Clone)]
pub struct AgentInstance {
    /// Instance identifier; the factory assigns a UUID v4 string.
    pub id: String,
    /// Name of the template this instance was built from.
    pub template_name: String,
    /// Final system prompt after applying any override or context prefix.
    pub system_prompt: String,
    /// Final tool list: template tools minus exclusions, followed by additional tools.
    pub tools: Vec<String>,
    /// Resolved model name.
    pub model: String,
    /// Resolved sampling temperature.
    pub temperature: f32,
    /// Token limit copied from the template.
    pub max_tokens: Option<usize>,
    /// Stop sequences copied from the template.
    pub stop_sequences: Vec<String>,
    /// Capability labels copied from the template.
    pub capabilities: Vec<String>,
    /// Session identifier taken from the per-instance config.
    pub session_id: Option<String>,
    /// User identifier taken from the per-instance config.
    pub user_id: Option<String>,
    /// Moment the instance was built; used to compute its age against the pool TTL.
    pub created_at: Instant,
    /// Template `config` entries merged with the per-instance metadata.
    pub metadata: HashMap<String, String>,
}
263
264impl AgentInstance {
265 pub fn has_capability(&self, capability: &str) -> bool {
267 self.capabilities.contains(&capability.to_string())
268 }
269
270 pub fn has_tool(&self, tool: &str) -> bool {
272 self.tools.contains(&tool.to_string())
273 }
274
275 pub fn age(&self) -> Duration {
277 self.created_at.elapsed()
278 }
279}
280
/// Errors reported by [`AgentFactory`] operations.
#[derive(Debug, thiserror::Error)]
pub enum FactoryError {
    /// No template is registered under the requested name (payload: that name).
    #[error("Template not found: {0}")]
    TemplateNotFound(String),

    /// A configuration was rejected (payload: explanation).
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// The pool for a template has no capacity left (payload: template name).
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Pool exhausted for template: {0}")]
    PoolExhausted(String),

    /// No instance exists with the given id (payload: that id).
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Instance not found: {0}")]
    InstanceNotFound(String),
}
296
/// Counters describing an [`AgentFactory`]'s activity; obtained as a snapshot
/// through `AgentFactory::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FactoryStats {
    /// Count of templates registered with the factory.
    pub templates_registered: usize,
    /// Number of instances freshly built (requests served from the pool are not counted).
    pub total_instances_created: u64,
    /// Number of instances the factory currently counts as active.
    pub active_instances: usize,
    /// Number of creation requests served from the pool.
    pub pool_hits: u64,
    /// Number of creation requests that found no usable pooled instance while pooling was enabled.
    pub pool_misses: u64,
}
306
/// Settings that control how an [`AgentFactory`] pools released instances.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of idle instances kept per template.
    pub max_per_template: usize,
    /// Cap on idle instances across all templates combined.
    /// NOTE(review): verify that the factory's release path actually checks this limit.
    pub max_total: usize,
    /// Maximum age (measured from creation) at which an instance may still be pooled or reused.
    pub instance_ttl: Duration,
    /// Turns pooling on or off; when `false` released instances are not kept.
    pub enabled: bool,
}
319
320impl Default for PoolConfig {
321 fn default() -> Self {
322 Self {
323 max_per_template: 10,
324 max_total: 100,
325 instance_ttl: Duration::from_secs(3600),
326 enabled: true,
327 }
328 }
329}
330
/// Registry of agent templates that builds, tracks and pools [`AgentInstance`]s.
///
/// All mutable state sits behind `RwLock`s, so the factory's methods take `&self`.
pub struct AgentFactory {
    /// Registered templates, keyed by template name.
    templates: RwLock<HashMap<String, AgentTemplate>>,
    /// Instances tracked by the factory, keyed by instance id.
    instances: RwLock<HashMap<String, AgentInstance>>,
    /// Idle instances available for reuse, keyed by template name.
    pool: RwLock<HashMap<String, Vec<AgentInstance>>>,
    /// Pooling limits and TTL.
    pool_config: PoolConfig,
    /// Activity counters.
    stats: RwLock<FactoryStats>,
    /// Model used when neither the config nor the template names one.
    default_model: String,
    /// Temperature used when neither the config nor the template sets one.
    default_temperature: f32,
}
341
342impl AgentFactory {
343 pub fn new() -> Self {
344 Self {
345 templates: RwLock::new(HashMap::new()),
346 instances: RwLock::new(HashMap::new()),
347 pool: RwLock::new(HashMap::new()),
348 pool_config: PoolConfig::default(),
349 stats: RwLock::new(FactoryStats::default()),
350 default_model: "gpt-4".to_string(),
351 default_temperature: 0.7,
352 }
353 }
354
355 pub fn with_pool_config(mut self, config: PoolConfig) -> Self {
356 self.pool_config = config;
357 self
358 }
359
360 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
361 self.default_model = model.into();
362 self
363 }
364
365 pub fn with_default_temperature(mut self, temp: f32) -> Self {
366 self.default_temperature = temp.clamp(0.0, 2.0);
367 self
368 }
369
370 pub fn register_template(&self, template: AgentTemplate) -> &Self {
372 let name = template.name.clone();
373 self.templates.write().insert(name.clone(), template);
374 self.stats.write().templates_registered += 1;
375 info!(template = %name, "Template registered");
376 self
377 }
378
379 pub fn register_templates<I>(&self, templates: I) -> &Self
381 where
382 I: IntoIterator<Item = AgentTemplate>,
383 {
384 for template in templates {
385 self.register_template(template);
386 }
387 self
388 }
389
390 pub fn get_template(&self, name: &str) -> Option<AgentTemplate> {
392 self.templates.read().get(name).cloned()
393 }
394
395 pub fn list_templates(&self) -> Vec<String> {
397 self.templates.read().keys().cloned().collect()
398 }
399
400 pub fn create(&self, template_name: &str) -> Result<AgentInstance, FactoryError> {
402 self.create_with_config(template_name, AgentConfig::default())
403 }
404
405 pub fn create_with_config(
407 &self,
408 template_name: &str,
409 config: AgentConfig,
410 ) -> Result<AgentInstance, FactoryError> {
411 if self.pool_config.enabled {
413 if let Some(instance) = self.get_from_pool(template_name) {
414 self.stats.write().pool_hits += 1;
415 debug!(template = %template_name, "Got agent from pool");
416 return Ok(instance);
417 }
418 self.stats.write().pool_misses += 1;
419 }
420
421 let template = self
423 .templates
424 .read()
425 .get(template_name)
426 .cloned()
427 .ok_or_else(|| FactoryError::TemplateNotFound(template_name.to_string()))?;
428
429 let instance = self.build_instance(&template, config);
431
432 self.instances
434 .write()
435 .insert(instance.id.clone(), instance.clone());
436
437 {
439 let mut stats = self.stats.write();
440 stats.total_instances_created += 1;
441 stats.active_instances += 1;
442 }
443
444 info!(
445 instance_id = %instance.id,
446 template = %template_name,
447 "Agent instance created"
448 );
449
450 Ok(instance)
451 }
452
453 fn build_instance(&self, template: &AgentTemplate, config: AgentConfig) -> AgentInstance {
454 let system_prompt =
456 config
457 .system_prompt_override
458 .unwrap_or_else(|| match &config.context_prefix {
459 Some(prefix) => format!("{}\n\n{}", prefix, template.system_prompt),
460 None => template.system_prompt.clone(),
461 });
462
463 let mut tools: Vec<String> = template
465 .tools
466 .iter()
467 .filter(|t| !config.excluded_tools.contains(t))
468 .cloned()
469 .collect();
470 tools.extend(config.additional_tools);
471
472 let model = config
474 .model_override
475 .or_else(|| template.model.clone())
476 .unwrap_or_else(|| self.default_model.clone());
477
478 let temperature = config
480 .temperature_override
481 .or(template.temperature)
482 .unwrap_or(self.default_temperature);
483
484 let mut metadata = template.config.clone();
486 metadata.extend(config.metadata);
487
488 AgentInstance {
489 id: uuid::Uuid::new_v4().to_string(),
490 template_name: template.name.clone(),
491 system_prompt,
492 tools,
493 model,
494 temperature,
495 max_tokens: template.max_tokens,
496 stop_sequences: template.stop_sequences.clone(),
497 capabilities: template.capabilities.clone(),
498 session_id: config.session_id,
499 user_id: config.user_id,
500 created_at: Instant::now(),
501 metadata,
502 }
503 }
504
505 fn get_from_pool(&self, template_name: &str) -> Option<AgentInstance> {
506 let mut pool = self.pool.write();
507 if let Some(instances) = pool.get_mut(template_name) {
508 while let Some(instance) = instances.pop() {
510 if instance.age() < self.pool_config.instance_ttl {
511 return Some(instance);
512 }
513 }
514 }
515 None
516 }
517
518 pub fn release(&self, instance: AgentInstance) {
520 if !self.pool_config.enabled {
521 return;
522 }
523
524 if instance.age() >= self.pool_config.instance_ttl {
526 debug!(instance_id = %instance.id, "Instance expired, not returning to pool");
527 return;
528 }
529
530 let template_name = instance.template_name.clone();
531
532 let mut pool = self.pool.write();
533 let template_pool = pool.entry(template_name).or_default();
534
535 if template_pool.len() >= self.pool_config.max_per_template {
537 debug!(instance_id = %instance.id, "Pool full for template");
538 return;
539 }
540
541 template_pool.push(instance);
542 let mut stats = self.stats.write();
543 stats.active_instances = stats.active_instances.saturating_sub(1);
544 }
545
546 pub fn get_instance(&self, id: &str) -> Option<AgentInstance> {
548 self.instances.read().get(id).cloned()
549 }
550
551 pub fn remove_instance(&self, id: &str) -> bool {
553 let removed = self.instances.write().remove(id).is_some();
554 if removed {
555 let mut stats = self.stats.write();
556 stats.active_instances = stats.active_instances.saturating_sub(1);
557 }
558 removed
559 }
560
561 pub fn stats(&self) -> FactoryStats {
563 self.stats.read().clone()
564 }
565
566 pub fn find_by_capability(&self, capability: &str) -> Vec<AgentTemplate> {
568 self.templates
569 .read()
570 .values()
571 .filter(|t| t.capabilities.contains(&capability.to_string()))
572 .cloned()
573 .collect()
574 }
575
576 pub fn find_by_tool(&self, tool: &str) -> Vec<AgentTemplate> {
578 self.templates
579 .read()
580 .values()
581 .filter(|t| t.tools.contains(&tool.to_string()))
582 .cloned()
583 .collect()
584 }
585
586 pub fn clear_pools(&self) {
588 self.pool.write().clear();
589 debug!("All pools cleared");
590 }
591
592 pub fn cleanup_expired(&self) {
594 let mut pool = self.pool.write();
595 let ttl = self.pool_config.instance_ttl;
596
597 for (_, instances) in pool.iter_mut() {
598 instances.retain(|i| i.age() < ttl);
599 }
600 }
601}
602
603impl Default for AgentFactory {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
/// Builds an `AgentTemplate` from a compact, declarative description.
///
/// Every field is optional, but the ones supplied must appear in this order,
/// each followed by a comma: `system_prompt`, `model`, `temperature`, `tools`,
/// `capabilities`, `description`. The macro expands to calls of the matching
/// `AgentTemplate` builder methods, so `AgentTemplate` must be in scope where
/// the macro is invoked.
///
/// ```ignore
/// let template = agent_template!("researcher" => {
///     system_prompt: "You are a researcher",
///     tools: ["web_search"],
/// });
/// ```
#[macro_export]
macro_rules! agent_template {
    ($name:expr => {
        $(system_prompt: $prompt:expr,)?
        $(model: $model:expr,)?
        $(temperature: $temp:expr,)?
        $(tools: [$($tool:expr),* $(,)?],)?
        $(capabilities: [$($cap:expr),* $(,)?],)?
        $(description: $desc:expr,)?
    }) => {{
        // When no optional field is supplied, nothing reassigns `template`, and
        // the caller would otherwise get an `unused_mut` warning.
        #[allow(unused_mut)]
        let mut template = AgentTemplate::new($name);
        $(template = template.system_prompt($prompt);)?
        $(template = template.model($model);)?
        $(template = template.temperature($temp);)?
        $($(template = template.tool($tool);)*)?
        $($(template = template.capability($cap);)*)?
        $(template = template.description($desc);)?
        template
    }};
}
630
/// Unit tests for the template builder, per-instance config and factory.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters store what they are given.
    #[test]
    fn test_template_builder() {
        let template = AgentTemplate::new("test")
            .system_prompt("You are a test agent")
            .tool("search")
            .tool("calculate")
            .model("gpt-4")
            .temperature(0.5)
            .capability("math")
            .config("key", "value");

        assert_eq!(template.name, "test");
        assert_eq!(template.tools.len(), 2);
        assert_eq!(template.model, Some("gpt-4".to_string()));
        assert_eq!(template.temperature, Some(0.5));
    }

    /// `extend` fills unset fields from the base and appends list fields.
    #[test]
    fn test_template_extend() {
        let base = AgentTemplate::new("base")
            .system_prompt("Base prompt")
            .tool("base_tool")
            .model("gpt-3.5");

        let extended = AgentTemplate::new("extended")
            .tool("extended_tool")
            .extend(&base);

        assert_eq!(extended.system_prompt, "Base prompt");
        assert!(extended.tools.contains(&"base_tool".to_string()));
        assert!(extended.tools.contains(&"extended_tool".to_string()));
        assert_eq!(extended.model, Some("gpt-3.5".to_string()));
    }

    /// A registered template can be instantiated by name.
    #[test]
    fn test_factory_register_and_create() {
        let factory = AgentFactory::new();

        let template = AgentTemplate::new("researcher")
            .system_prompt("You are a researcher")
            .tool("web_search");

        factory.register_template(template);

        let instance = factory.create("researcher").unwrap();
        assert_eq!(instance.template_name, "researcher");
        assert!(instance.tools.contains(&"web_search".to_string()));
    }

    /// Overrides, added tools and excluded tools are all reflected in the instance.
    #[test]
    fn test_factory_create_with_config() {
        let factory = AgentFactory::new();

        let template = AgentTemplate::new("agent")
            .system_prompt("Original prompt")
            .tool("tool1")
            .tool("tool2");

        factory.register_template(template);

        let config = AgentConfig::new()
            .override_system_prompt("New prompt")
            .add_tool("tool3")
            .exclude_tool("tool1")
            .session_id("session123");

        let instance = factory.create_with_config("agent", config).unwrap();

        assert_eq!(instance.system_prompt, "New prompt");
        assert!(!instance.tools.contains(&"tool1".to_string()));
        assert!(instance.tools.contains(&"tool2".to_string()));
        assert!(instance.tools.contains(&"tool3".to_string()));
        assert_eq!(instance.session_id, Some("session123".to_string()));
    }

    /// Creating from an unknown template name is an error.
    #[test]
    fn test_factory_template_not_found() {
        let factory = AgentFactory::new();
        let result = factory.create("nonexistent");
        assert!(matches!(result, Err(FactoryError::TemplateNotFound(_))));
    }

    /// Every registered template name is listed.
    #[test]
    fn test_factory_list_templates() {
        let factory = AgentFactory::new();

        factory.register_template(AgentTemplate::new("t1"));
        factory.register_template(AgentTemplate::new("t2"));
        factory.register_template(AgentTemplate::new("t3"));

        let templates = factory.list_templates();
        assert_eq!(templates.len(), 3);
    }

    /// Capability search returns every template carrying the label.
    #[test]
    fn test_factory_find_by_capability() {
        let factory = AgentFactory::new();

        factory.register_template(AgentTemplate::new("math_agent").capability("math"));
        factory.register_template(AgentTemplate::new("text_agent").capability("text"));
        factory.register_template(
            AgentTemplate::new("multi_agent")
                .capability("math")
                .capability("text"),
        );

        let math_agents = factory.find_by_capability("math");
        assert_eq!(math_agents.len(), 2);
    }

    /// Tool search returns every template carrying the tool.
    #[test]
    fn test_factory_find_by_tool() {
        let factory = AgentFactory::new();

        factory.register_template(AgentTemplate::new("a1").tool("search"));
        factory.register_template(AgentTemplate::new("a2").tool("calculate"));
        factory.register_template(AgentTemplate::new("a3").tool("search").tool("calculate"));

        let search_agents = factory.find_by_tool("search");
        assert_eq!(search_agents.len(), 2);
    }

    /// Membership helpers on a hand-built instance.
    #[test]
    fn test_agent_instance_methods() {
        let instance = AgentInstance {
            id: "test".to_string(),
            template_name: "test".to_string(),
            system_prompt: "prompt".to_string(),
            tools: vec!["tool1".to_string(), "tool2".to_string()],
            model: "gpt-4".to_string(),
            temperature: 0.7,
            max_tokens: None,
            stop_sequences: Vec::new(),
            capabilities: vec!["cap1".to_string()],
            session_id: None,
            user_id: None,
            created_at: Instant::now(),
            metadata: HashMap::new(),
        };

        assert!(instance.has_tool("tool1"));
        assert!(!instance.has_tool("tool3"));
        assert!(instance.has_capability("cap1"));
        assert!(!instance.has_capability("cap2"));
    }

    /// A released instance is handed out again on the next `create`.
    #[test]
    fn test_factory_pooling() {
        let pool_config = PoolConfig {
            max_per_template: 2,
            max_total: 10,
            instance_ttl: Duration::from_secs(3600),
            enabled: true,
        };

        let factory = AgentFactory::new().with_pool_config(pool_config);

        factory.register_template(AgentTemplate::new("pooled"));

        // First request: the pool is empty, so this is a miss and a fresh build.
        let instance1 = factory.create("pooled").unwrap();
        let id1 = instance1.id.clone();
        factory.release(instance1);

        // Second request: served from the pool, reusing the released identity.
        let instance2 = factory.create("pooled").unwrap();
        let stats = factory.stats();

        assert_eq!(stats.pool_misses, 1);
        assert_eq!(stats.pool_hits, 1);
        assert_eq!(stats.total_instances_created, 1);
        assert_eq!(instance2.id, id1);
    }

    /// Registration and creation counters are tracked.
    #[test]
    fn test_factory_stats() {
        let factory = AgentFactory::new();

        factory.register_template(AgentTemplate::new("t1"));
        factory.register_template(AgentTemplate::new("t2"));

        factory.create("t1").unwrap();
        factory.create("t1").unwrap();
        factory.create("t2").unwrap();

        let stats = factory.stats();
        assert_eq!(stats.templates_registered, 2);
        assert_eq!(stats.total_instances_created, 3);
    }

    /// A context prefix is combined with the template prompt.
    #[test]
    fn test_context_prefix() {
        let factory = AgentFactory::new();

        factory.register_template(AgentTemplate::new("agent").system_prompt("Base instructions"));

        let config = AgentConfig::new().context_prefix("User context: VIP customer");

        let instance = factory.create_with_config("agent", config).unwrap();
        assert!(instance.system_prompt.contains("VIP customer"));
        assert!(instance.system_prompt.contains("Base instructions"));
    }
}