use crate::types::SonaConfig;
use serde::{Deserialize, Serialize};
/// Kind of agent a training template targets.
///
/// The `Display` impl below renders each variant as a kebab-case
/// identifier (e.g. `code-agent`); `Custom` carries a caller-supplied
/// label for agent kinds not covered by the built-in variants.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AgentType {
CodeAgent,
ChatAgent,
/// Retrieval-augmented generation agent.
RagAgent,
TaskPlanner,
DomainExpert,
CodebaseHelper,
DataAnalyst,
CreativeWriter,
ReasoningAgent,
MultiModal,
/// User-defined agent kind; the string is the custom name.
Custom(String),
}
impl std::fmt::Display for AgentType {
    /// Writes the kebab-case identifier for this agent type,
    /// e.g. `code-agent`, or `custom-<name>` for `Custom`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            AgentType::CodeAgent => "code-agent",
            AgentType::ChatAgent => "chat-agent",
            AgentType::RagAgent => "rag-agent",
            AgentType::TaskPlanner => "task-planner",
            AgentType::DomainExpert => "domain-expert",
            AgentType::CodebaseHelper => "codebase-helper",
            AgentType::DataAnalyst => "data-analyst",
            AgentType::CreativeWriter => "creative-writer",
            AgentType::ReasoningAgent => "reasoning-agent",
            AgentType::MultiModal => "multi-modal",
            // Only variant that needs formatting; bail out early.
            AgentType::Custom(name) => return write!(f, "custom-{}", name),
        };
        f.write_str(label)
    }
}
/// Business vertical a template can be specialized for.
///
/// Consumed by `VerticalConfig`; `TrainingTemplate::domain_expert` maps
/// some domains to compliance levels (e.g. `Healthcare` -> HIPAA).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskDomain {
SoftwareDevelopment,
CustomerSupport,
Healthcare,
Finance,
Legal,
Education,
Research,
Marketing,
General,
/// Free-form domain label for verticals not listed above.
Custom(String),
}
/// Training algorithm, with its method-specific hyper-parameters.
///
/// The `Default` impl picks `Online` (see below).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TrainingMethod {
/// Supervised fine-tuning over a fixed dataset.
Supervised {
batch_size: usize,
epochs: usize,
},
/// Reinforcement learning from human feedback.
RLHF {
// Weight of the reward term in the objective — assumed; confirm.
reward_weight: f32,
// KL-penalty strength — assumed; confirm against trainer code.
kl_penalty: f32,
},
/// Direct Preference Optimization.
DPO {
beta: f32,
ref_weight: f32,
},
/// Continual online learning.
Online {
// Presumably a multiplicative per-step LR decay (values ~0.999 here).
lr_decay: f32,
window_size: usize,
},
/// Few-shot / meta-learning adaptation.
FewShot {
k_shot: usize,
meta_lr: f32,
},
}
impl Default for TrainingMethod {
fn default() -> Self {
TrainingMethod::Online {
lr_decay: 0.999,
window_size: 1000,
}
}
}
/// Domain-specialization settings attached to a template.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VerticalConfig {
/// Vertical this configuration targets.
pub domain: TaskDomain,
// Presumably the number of extra domain vocabulary entries (0 = none)
// — confirm against the consumer of this field.
pub vocab_boost: usize,
/// Names of the quality metrics tracked for this vertical.
pub quality_metrics: Vec<String>,
/// Regulatory regime the deployment must satisfy.
pub compliance_level: ComplianceLevel,
}
/// Regulatory regime a vertical deployment must comply with.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum ComplianceLevel {
/// No compliance requirements (the default).
#[default]
None,
Basic,
/// US healthcare data protection (HIPAA).
Hipaa,
/// SOC 2 controls (assigned to the finance domain here).
Soc2,
/// EU data protection (GDPR).
Gdpr,
/// Custom/other regime named by the caller.
Custom(String),
}
/// Named starting points consumed by `TrainingTemplate::from_preset`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TemplatePreset {
/// Smallest footprint: edge config, 10 MB budget, tiny data.
Minimal,
/// Default SONA config with a 100 MB budget.
Balanced,
/// Throughput-oriented config, 200 MB budget, auto-export enabled.
Production,
/// Quality-oriented config, 500 MB budget, large data expected.
MaxQuality,
/// Edge deployment: 5 MB budget, 500 us latency target.
Edge,
/// Quality config with a 50k trajectory buffer and 1 GB budget.
Research,
}
/// A reusable recipe describing how to configure and train an agent.
///
/// Construct via `TrainingTemplate::new`, `from_preset`, or one of the
/// agent-specific factories, then refine with the `with_*` builders.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingTemplate {
/// Human-readable template name.
pub name: String,
/// Agent kind this template targets.
pub agent_type: AgentType,
/// Core SONA engine configuration.
pub sona_config: SonaConfig,
/// Training algorithm and its hyper-parameters.
pub training_method: TrainingMethod,
/// Optional domain specialization (vocabulary, metrics, compliance).
pub vertical: Option<VerticalConfig>,
/// Coarse hint for expected training-data volume.
pub expected_data_size: DataSizeHint,
/// Memory budget in megabytes (`validate` requires >= 1).
pub memory_budget_mb: usize,
/// Target latency in microseconds.
pub target_latency_us: u64,
// Presumably: keep learning after deployment — confirm with consumer.
pub continuous_learning: bool,
// Presumably: export trained artifacts automatically — confirm.
pub auto_export: bool,
/// Free-form labels for search/categorization.
pub tags: Vec<String>,
}
/// Coarse hint for how much training data is expected.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DataSizeHint {
Tiny,
Small,
/// Default assumption when nothing is known.
#[default]
Medium,
Large,
Massive,
}
impl TrainingTemplate {
/// Creates a template with baseline settings for the given agent type.
///
/// Starts from `SonaConfig::default()` and the default (online)
/// training method, a 100 MB budget, a 1000 us latency target, and
/// continuous learning on / auto-export off.
pub fn new(name: impl Into<String>, agent_type: AgentType) -> Self {
    Self {
        agent_type,
        name: name.into(),
        tags: Vec::new(),
        vertical: None,
        sona_config: SonaConfig::default(),
        training_method: TrainingMethod::default(),
        expected_data_size: DataSizeHint::default(),
        memory_budget_mb: 100,
        target_latency_us: 1000,
        continuous_learning: true,
        auto_export: false,
    }
}
/// Builds a template from a coarse preset, then layers agent-specific
/// tuning on top via `apply_agent_optimizations` (so per-agent values
/// win for the fields they touch).
///
/// The template name is `"<PresetDebug>-<agent Display>"`,
/// e.g. `"Production-chat-agent"`.
pub fn from_preset(preset: TemplatePreset, agent_type: AgentType) -> Self {
    // Derive the name before moving `agent_type` into the template.
    let name = format!("{:?}-{}", preset, agent_type);
    let mut t = Self::new(name, agent_type);
    match preset {
        TemplatePreset::Minimal => {
            t.expected_data_size = DataSizeHint::Tiny;
            t.memory_budget_mb = 10;
            t.sona_config = SonaConfig::edge_deployment();
        }
        TemplatePreset::Balanced => {
            t.memory_budget_mb = 100;
            t.sona_config = SonaConfig::default();
        }
        TemplatePreset::Production => {
            t.auto_export = true;
            t.memory_budget_mb = 200;
            t.sona_config = SonaConfig::max_throughput();
        }
        TemplatePreset::MaxQuality => {
            t.expected_data_size = DataSizeHint::Large;
            t.memory_budget_mb = 500;
            t.sona_config = SonaConfig::max_quality();
        }
        TemplatePreset::Edge => {
            t.memory_budget_mb = 5;
            t.target_latency_us = 500;
            t.sona_config = SonaConfig::edge_deployment();
        }
        TemplatePreset::Research => {
            // Assign the config first, then enlarge its trajectory buffer.
            t.sona_config = SonaConfig::max_quality();
            t.sona_config.trajectory_capacity = 50000;
            t.expected_data_size = DataSizeHint::Massive;
            t.memory_budget_mb = 1000;
        }
    }
    t.apply_agent_optimizations();
    t
}
/// Preset template for a code completion / development agent.
///
/// Larger base LoRA rank and trajectory buffer, permissive quality
/// threshold, and slow-decay online learning over a 5000-sample window.
pub fn code_agent() -> Self {
    let mut t = Self::new("code-agent", AgentType::CodeAgent);
    t.sona_config.base_lora_rank = 16;
    t.sona_config.pattern_clusters = 200;
    t.sona_config.trajectory_capacity = 10000;
    t.sona_config.quality_threshold = 0.2;
    t.training_method = TrainingMethod::Online {
        lr_decay: 0.9995,
        window_size: 5000,
    };
    t.tags = vec!["code".into(), "development".into(), "completion".into()];
    t
}
/// Preset template for a conversational agent.
///
/// Latency-sensitive (500 us target), small adapter, RLHF training.
pub fn chat_agent() -> Self {
    let mut t = Self::new("chat-agent", AgentType::ChatAgent);
    t.target_latency_us = 500;
    t.sona_config.base_lora_rank = 8;
    t.sona_config.pattern_clusters = 50;
    t.sona_config.quality_threshold = 0.4;
    t.training_method = TrainingMethod::RLHF {
        reward_weight: 0.5,
        kl_penalty: 0.1,
    };
    t.tags = vec!["chat".into(), "conversation".into(), "support".into()];
    t
}
/// Preset template for a retrieval-augmented generation agent.
///
/// Wider 512-dim embedding/hidden space, large pattern/trajectory
/// stores, supervised training.
pub fn rag_agent() -> Self {
    let mut t = Self::new("rag-agent", AgentType::RagAgent);
    t.sona_config.embedding_dim = 512;
    t.sona_config.hidden_dim = 512;
    t.sona_config.pattern_clusters = 200;
    t.sona_config.trajectory_capacity = 10000;
    t.training_method = TrainingMethod::Supervised {
        batch_size: 32,
        epochs: 10,
    };
    t.tags = vec!["rag".into(), "retrieval".into(), "documents".into()];
    t
}
/// Preset template for a task planning / decomposition agent.
///
/// Higher EWC lambda and DPO preference training.
pub fn task_planner() -> Self {
    let mut t = Self::new("task-planner", AgentType::TaskPlanner);
    t.sona_config.base_lora_rank = 16;
    t.sona_config.pattern_clusters = 100;
    t.sona_config.ewc_lambda = 2000.0;
    t.training_method = TrainingMethod::DPO {
        beta: 0.1,
        ref_weight: 0.5,
    };
    t.tags = vec!["planning".into(), "tasks".into(), "decomposition".into()];
    t
}
/// Preset template for a vertical-specific expert agent.
///
/// Attaches a `VerticalConfig` whose compliance level is derived from
/// the domain: HIPAA for healthcare, SOC 2 for finance, basic for
/// legal, none otherwise.
pub fn domain_expert(domain: TaskDomain) -> Self {
    let domain_name = format!("{:?}", domain).to_lowercase();
    let mut t = Self::new(
        format!("domain-expert-{}", domain_name),
        AgentType::DomainExpert,
    );
    t.sona_config.base_lora_rank = 16;
    t.sona_config.quality_threshold = 0.1;
    t.sona_config.trajectory_capacity = 20000;
    // Resolve compliance before moving `domain` into the config.
    let compliance_level = match domain {
        TaskDomain::Healthcare => ComplianceLevel::Hipaa,
        TaskDomain::Finance => ComplianceLevel::Soc2,
        TaskDomain::Legal => ComplianceLevel::Basic,
        _ => ComplianceLevel::None,
    };
    t.vertical = Some(VerticalConfig {
        domain,
        vocab_boost: 10000,
        quality_metrics: vec!["accuracy".into(), "relevance".into(), "compliance".into()],
        compliance_level,
    });
    t.tags = vec!["domain".into(), "expert".into(), domain_name];
    t
}
/// Preset template for a repository navigation / codebase assistant.
///
/// Expects large data and learns online over a 10000-sample window.
pub fn codebase_helper() -> Self {
    let mut t = Self::new("codebase-helper", AgentType::CodebaseHelper);
    t.expected_data_size = DataSizeHint::Large;
    t.sona_config.base_lora_rank = 16;
    t.sona_config.pattern_clusters = 200;
    t.sona_config.trajectory_capacity = 10000;
    t.sona_config.quality_threshold = 0.2;
    t.training_method = TrainingMethod::Online {
        lr_decay: 0.999,
        window_size: 10000,
    };
    t.tags = vec!["codebase".into(), "repository".into(), "navigation".into()];
    t
}
/// Preset template for a data-analysis agent.
///
/// Small adapter plus a research-domain vertical with a modest
/// vocabulary boost.
pub fn data_analyst() -> Self {
    let mut t = Self::new("data-analyst", AgentType::DataAnalyst);
    t.sona_config.base_lora_rank = 8;
    t.sona_config.pattern_clusters = 100;
    t.vertical = Some(VerticalConfig {
        domain: TaskDomain::Research,
        vocab_boost: 5000,
        quality_metrics: vec!["accuracy".into(), "insight_quality".into()],
        compliance_level: ComplianceLevel::None,
    });
    t.tags = vec!["data".into(), "analysis".into(), "insights".into()];
    t
}
/// Preset template for a creative/content writing agent.
///
/// Strict quality threshold, RLHF with a high reward weight and a weak
/// KL penalty, marketing vertical with no vocabulary boost.
pub fn creative_writer() -> Self {
    let mut t = Self::new("creative-writer", AgentType::CreativeWriter);
    t.sona_config.base_lora_rank = 8;
    t.sona_config.pattern_clusters = 50;
    t.sona_config.quality_threshold = 0.5;
    t.training_method = TrainingMethod::RLHF {
        reward_weight: 0.7,
        kl_penalty: 0.05,
    };
    t.vertical = Some(VerticalConfig {
        domain: TaskDomain::Marketing,
        vocab_boost: 0,
        quality_metrics: vec!["creativity".into(), "engagement".into(), "clarity".into()],
        compliance_level: ComplianceLevel::None,
    });
    t.tags = vec!["creative".into(), "writing".into(), "content".into()];
    t
}
/// Preset template for a logic/math reasoning agent.
///
/// Largest EWC lambda of the factories plus DPO preference training.
pub fn reasoning_agent() -> Self {
    let mut t = Self::new("reasoning-agent", AgentType::ReasoningAgent);
    t.sona_config.base_lora_rank = 16;
    t.sona_config.pattern_clusters = 150;
    t.sona_config.quality_threshold = 0.3;
    t.sona_config.ewc_lambda = 3000.0;
    t.training_method = TrainingMethod::DPO {
        beta: 0.15,
        ref_weight: 0.4,
    };
    t.tags = vec!["reasoning".into(), "logic".into(), "math".into()];
    t
}
/// Replaces the entire SONA configuration (builder style).
pub fn with_sona_config(mut self, config: SonaConfig) -> Self {
self.sona_config = config;
self
}
/// Sets the training method (builder style).
pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
self.training_method = method;
self
}
/// Attaches a domain-specialization config (builder style).
pub fn with_vertical(mut self, vertical: VerticalConfig) -> Self {
self.vertical = Some(vertical);
self
}
/// Sets the memory budget in megabytes (builder style).
pub fn with_memory_budget(mut self, mb: usize) -> Self {
self.memory_budget_mb = mb;
self
}
/// Sets the target latency in microseconds (builder style).
pub fn with_target_latency(mut self, us: u64) -> Self {
self.target_latency_us = us;
self
}
/// Enables or disables continuous learning (builder style).
pub fn with_continuous_learning(mut self, enabled: bool) -> Self {
self.continuous_learning = enabled;
self
}
/// Enables or disables automatic export (builder style).
pub fn with_auto_export(mut self, enabled: bool) -> Self {
self.auto_export = enabled;
self
}
/// Replaces the tag list (builder style).
pub fn with_tags(mut self, tags: Vec<String>) -> Self {
self.tags = tags;
self
}
/// Sets the hidden dimension (builder style).
///
/// Note: keeps `embedding_dim` in lockstep with `hidden_dim` — both
/// fields are assigned the same value.
pub fn with_hidden_dim(mut self, dim: usize) -> Self {
self.sona_config.hidden_dim = dim;
self.sona_config.embedding_dim = dim;
self
}
/// Sets the MicroLoRA and base LoRA ranks (builder style).
///
/// `micro` is silently clamped to at most 2 (the cap `validate`
/// enforces). A `micro` of 0 is passed through unchanged —
/// NOTE(review): confirm whether a zero MicroLoRA rank is meaningful.
pub fn with_lora_ranks(mut self, micro: usize, base: usize) -> Self {
self.sona_config.micro_lora_rank = micro.min(2); self.sona_config.base_lora_rank = base;
self
}
/// Overrides selected knobs based on the template's agent type.
///
/// Called by `from_preset` after preset values are applied, so these
/// per-agent settings win over the preset for the fields they touch.
fn apply_agent_optimizations(&mut self) {
match &self.agent_type {
// Code-centric agents: more pattern clusters, larger base rank.
AgentType::CodeAgent | AgentType::CodebaseHelper => {
self.sona_config.pattern_clusters = 200;
self.sona_config.base_lora_rank = 16;
}
// Chat is latency-sensitive: fewer clusters, tighter target.
AgentType::ChatAgent => {
self.sona_config.pattern_clusters = 50;
self.target_latency_us = 500;
}
AgentType::RagAgent => {
self.sona_config.pattern_clusters = 200;
self.sona_config.trajectory_capacity = 10000;
}
// Reasoning: larger EWC lambda and base rank.
AgentType::ReasoningAgent => {
self.sona_config.ewc_lambda = 3000.0;
self.sona_config.base_lora_rank = 16;
}
AgentType::DomainExpert => {
self.sona_config.quality_threshold = 0.1;
}
// All other agent types keep the preset values unchanged.
_ => {}
}
}
/// Checks the template's configuration for internal consistency.
///
/// # Errors
/// Returns a human-readable message when a constraint is violated:
/// MicroLoRA rank outside `1..=2`, a zero hidden or embedding
/// dimension, or a memory budget below 1 MB.
pub fn validate(&self) -> Result<(), String> {
    // The old `> 2` check contradicted its own message ("must be 1 or
    // 2") by letting a rank of 0 through; enforce the stated contract.
    if !(1..=2).contains(&self.sona_config.micro_lora_rank) {
        return Err("MicroLoRA rank must be 1 or 2".into());
    }
    if self.sona_config.hidden_dim == 0 {
        return Err("Hidden dimension must be > 0".into());
    }
    // `embedding_dim` was never validated even though it is used the
    // same way (see `estimated_memory_mb` and `with_hidden_dim`).
    if self.sona_config.embedding_dim == 0 {
        return Err("Embedding dimension must be > 0".into());
    }
    if self.memory_budget_mb < 1 {
        return Err("Memory budget must be >= 1 MB".into());
    }
    Ok(())
}
/// Rough estimate of the template's runtime memory footprint, in MB.
///
/// Sums a fixed ~5 MB engine cost, the LoRA weights, the trajectory
/// buffer (~800 bytes per entry), and the pattern-cluster store, plus
/// 1 MB of slack. Byte counts are converted with ceiling division;
/// the previous truncating division dropped every sub-megabyte
/// component to zero, undercounting by up to 1 MB per component.
pub fn estimated_memory_mb(&self) -> usize {
    const MB: usize = 1024 * 1024;
    // Round a byte count up to whole megabytes.
    let mb_ceil = |bytes: usize| (bytes + MB - 1) / MB;
    let config = &self.sona_config;
    let engine_mb = 5;
    // Factors `* 2 * 4 * 2`: presumably two adapter matrices of f32
    // (4 bytes), double-buffered — TODO confirm against the engine.
    let lora_bytes =
        config.hidden_dim * (config.micro_lora_rank + config.base_lora_rank) * 2 * 4 * 2;
    let lora_mb = mb_ceil(lora_bytes);
    let traj_mb = mb_ceil(config.trajectory_capacity * 800);
    let pattern_mb = mb_ceil(config.pattern_clusters * config.embedding_dim * 4);
    engine_mb + lora_mb + traj_mb + pattern_mb + 1
}
}
#[cfg(test)]
mod tests {
use super::*;
// Factory sets the agent type and code-specific SONA knobs.
#[test]
fn test_template_creation() {
let template = TrainingTemplate::code_agent();
assert_eq!(template.agent_type, AgentType::CodeAgent);
assert_eq!(template.sona_config.base_lora_rank, 16);
assert_eq!(template.sona_config.pattern_clusters, 200);
}
// Presets drive auto-export (Production) and memory budget (Edge).
#[test]
fn test_preset_templates() {
let production =
TrainingTemplate::from_preset(TemplatePreset::Production, AgentType::ChatAgent);
assert!(production.auto_export);
let edge = TrainingTemplate::from_preset(TemplatePreset::Edge, AgentType::ChatAgent);
assert_eq!(edge.memory_budget_mb, 5);
}
// Healthcare domain experts get a HIPAA compliance level.
#[test]
fn test_domain_expert() {
let medical = TrainingTemplate::domain_expert(TaskDomain::Healthcare);
assert!(medical.vertical.is_some());
if let Some(v) = &medical.vertical {
assert!(matches!(v.compliance_level, ComplianceLevel::Hipaa));
}
}
// Chained with_* builders write through to the SONA config.
#[test]
fn test_builder_pattern() {
let template = TrainingTemplate::new("custom", AgentType::Custom("test".into()))
.with_hidden_dim(512)
.with_lora_ranks(2, 16)
.with_memory_budget(200)
.with_continuous_learning(true);
assert_eq!(template.sona_config.hidden_dim, 512);
assert_eq!(template.sona_config.micro_lora_rank, 2);
assert_eq!(template.sona_config.base_lora_rank, 16);
}
// validate() accepts a factory template and rejects an oversized
// MicroLoRA rank.
#[test]
fn test_validation() {
let mut template = TrainingTemplate::code_agent();
assert!(template.validate().is_ok());
template.sona_config.micro_lora_rank = 5;
assert!(template.validate().is_err());
}
// Estimate is positive and within 2x of the declared budget.
#[test]
fn test_memory_estimation() {
let template = TrainingTemplate::code_agent();
let mem = template.estimated_memory_mb();
assert!(mem > 0);
assert!(mem < template.memory_budget_mb * 2);
}
}