use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use thiserror::Error;
/// Error type for agent configuration problems.
#[derive(Debug, Error)]
pub enum AgentError {
/// The configuration failed a validation check; the payload is a
/// human-readable description of the failing rule and offending value.
#[error("Invalid agent configuration: {0}")]
Invalid(String),
}
/// Configuration profile for a single agent, (de)serializable via serde.
///
/// Every field carries a serde default, so any key may be omitted from the
/// serialized form. Range and consistency rules are enforced separately by
/// `AgentProfile::validate`, not during deserialization.
///
/// NOTE(review): a bare `#[serde(default)]` falls back to the field type's own
/// default (`false`, `None`, empty `Vec`), which for `enable_graph`,
/// `graph_memory`, `auto_graph`, `graph_steering`, `fast_reasoning`,
/// `fast_model_provider` and `fast_model_name` differs from what
/// `AgentProfile::default()` produces — confirm the mismatch is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentProfile {
/// Optional prompt text (presumably the agent's system prompt — confirm).
#[serde(default)]
pub prompt: Option<String>,
/// Optional style string; not interpreted anywhere in this file.
#[serde(default)]
pub style: Option<String>,
/// Optional temperature override; when set, `validate` requires `0.0..=2.0`.
#[serde(default)]
pub temperature: Option<f32>,
/// Optional provider override; `validate` accepts only `mock`, `openai`,
/// `anthropic`, `ollama`, `mlx` and `lmstudio`.
#[serde(default)]
pub model_provider: Option<String>,
/// Optional model name override.
#[serde(default)]
pub model_name: Option<String>,
/// Optional allowlist consulted by `is_tool_allowed`; `None` means no
/// allowlist restriction.
#[serde(default)]
pub allowed_tools: Option<Vec<String>>,
/// Optional denylist consulted by `is_tool_allowed`; a denied tool is always
/// rejected, even if it is otherwise always allowed.
#[serde(default)]
pub denied_tools: Option<Vec<String>>,
/// Defaults to 10 (presumably the number of memory items recalled — confirm).
#[serde(default = "AgentProfile::default_memory_k")]
pub memory_k: usize,
/// Defaults to 0.9; `validate` requires `0.0..=1.0`.
#[serde(default = "AgentProfile::default_top_p")]
pub top_p: f32,
/// Optional context-size limit in tokens; not checked by `validate`.
#[serde(default)]
pub max_context_tokens: Option<usize>,
/// Graph feature flag; `false` when omitted from serialized input.
#[serde(default)]
pub enable_graph: bool,
/// Graph memory flag; `false` when omitted from serialized input.
#[serde(default)]
pub graph_memory: bool,
/// Defaults to 3 (presumably a graph traversal depth — confirm).
#[serde(default = "AgentProfile::default_graph_depth")]
pub graph_depth: usize,
/// Defaults to 0.5; `validate` requires `0.0..=1.0`.
#[serde(default = "AgentProfile::default_graph_weight")]
pub graph_weight: f32,
/// Automatic graph flag; `false` when omitted from serialized input.
#[serde(default)]
pub auto_graph: bool,
/// Defaults to 0.7; `validate` requires `0.0..=1.0`.
#[serde(default = "AgentProfile::default_graph_threshold")]
pub graph_threshold: f32,
/// Graph steering flag; `false` when omitted from serialized input.
#[serde(default)]
pub graph_steering: bool,
/// Fast-reasoning flag; `false` when omitted from serialized input.
#[serde(default)]
pub fast_reasoning: bool,
/// Optional provider for the fast model; not checked by `validate`.
#[serde(default)]
pub fast_model_provider: Option<String>,
/// Optional name of the fast model.
#[serde(default)]
pub fast_model_name: Option<String>,
/// Defaults to 0.3; not range-checked by `validate`.
#[serde(default = "AgentProfile::default_fast_temperature")]
pub fast_model_temperature: f32,
/// Task names associated with the fast model; defaults to
/// `entity_extraction`, `graph_analysis`, `decision_routing`,
/// `tool_selection` and `confidence_scoring`.
#[serde(default = "AgentProfile::default_fast_tasks")]
pub fast_model_tasks: Vec<String>,
/// Defaults to 0.6; not range-checked by `validate`.
#[serde(default = "AgentProfile::default_escalation_threshold")]
pub escalation_threshold: f32,
/// Reasoning display flag; `false` when omitted.
#[serde(default)]
pub show_reasoning: bool,
/// Audio transcription flag; `false` when omitted.
#[serde(default)]
pub enable_audio_transcription: bool,
/// Defaults to `"immediate"`; other accepted values are not defined in this
/// file.
#[serde(default = "AgentProfile::default_audio_response_mode")]
pub audio_response_mode: String,
/// Optional audio scenario identifier.
#[serde(default)]
pub audio_scenario: Option<String>,
/// Collective feature flag; `false` when omitted.
#[serde(default)]
pub enable_collective: bool,
/// Defaults to `true`.
#[serde(default = "AgentProfile::default_accept_delegations")]
pub accept_delegations: bool,
/// Preferred domain names; empty when omitted.
#[serde(default)]
pub preferred_domains: Vec<String>,
/// Defaults to 3.
#[serde(default = "AgentProfile::default_max_concurrent_tasks")]
pub max_concurrent_tasks: usize,
/// Defaults to 0.3; not range-checked by `validate`.
#[serde(default = "AgentProfile::default_min_delegation_score")]
pub min_delegation_score: f32,
/// Learning-sharing flag; `false` when omitted.
#[serde(default)]
pub share_learnings: bool,
/// Defaults to `true`.
#[serde(default = "AgentProfile::default_participate_in_voting")]
pub participate_in_voting: bool,
}
impl AgentProfile {
    /// Tools that bypass the allowlist. An explicit denylist entry still
    /// blocks them (see [`AgentProfile::is_tool_allowed`]).
    const ALWAYS_ALLOWED_TOOLS: [&'static str; 1] = ["prompt_user"];

    // ---- serde fallback values, referenced by `#[serde(default = "...")]` ----

    fn default_memory_k() -> usize {
        10
    }

    fn default_top_p() -> f32 {
        0.9
    }

    fn default_graph_depth() -> usize {
        3
    }

    fn default_graph_weight() -> f32 {
        0.5
    }

    fn default_graph_threshold() -> f32 {
        0.7
    }

    fn default_fast_temperature() -> f32 {
        0.3
    }

    fn default_fast_tasks() -> Vec<String> {
        vec![
            "entity_extraction".to_string(),
            "graph_analysis".to_string(),
            "decision_routing".to_string(),
            "tool_selection".to_string(),
            "confidence_scoring".to_string(),
        ]
    }

    fn default_escalation_threshold() -> f32 {
        0.6
    }

    fn default_audio_response_mode() -> String {
        "immediate".to_string()
    }

    fn default_accept_delegations() -> bool {
        true
    }

    fn default_max_concurrent_tasks() -> usize {
        3
    }

    fn default_min_delegation_score() -> f32 {
        0.3
    }

    fn default_participate_in_voting() -> bool {
        true
    }

    /// Fails unless `value` lies in `0.0..=1.0`.
    ///
    /// The check is written as a positive `contains` test so that NaN, for
    /// which every comparison is false, is rejected rather than slipping
    /// through a pair of `<` / `>` comparisons.
    fn ensure_unit_interval(field: &str, value: f32) -> Result<()> {
        if (0.0..=1.0).contains(&value) {
            Ok(())
        } else {
            Err(AgentError::Invalid(format!(
                "{} must be between 0.0 and 1.0, got {}",
                field, value
            ))
            .into())
        }
    }

    /// Checks the profile for internally inconsistent or out-of-range values.
    ///
    /// Rules, applied in order (the first violation is reported):
    /// 1. `temperature`, when set, must be within `0.0..=2.0`.
    /// 2. `top_p`, `graph_weight` and `graph_threshold` must each be within
    ///    `0.0..=1.0`.
    /// 3. No tool may appear in both `allowed_tools` and `denied_tools`.
    /// 4. `model_provider`, when set, must be one of the known providers.
    ///
    /// NaN fails every range check. Other numeric fields
    /// (`fast_model_temperature`, `escalation_threshold`,
    /// `min_delegation_score`) and `fast_model_provider` are not checked here.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Invalid`] (wrapped in `anyhow::Error`) describing
    /// the first rule that failed.
    pub fn validate(&self) -> Result<()> {
        if let Some(temp) = self.temperature {
            if !(0.0..=2.0).contains(&temp) {
                return Err(AgentError::Invalid(format!(
                    "temperature must be between 0.0 and 2.0, got {}",
                    temp
                ))
                .into());
            }
        }

        Self::ensure_unit_interval("top_p", self.top_p)?;
        Self::ensure_unit_interval("graph_weight", self.graph_weight)?;
        Self::ensure_unit_interval("graph_threshold", self.graph_threshold)?;

        if let (Some(allowed), Some(denied)) = (&self.allowed_tools, &self.denied_tools) {
            let allowed_set: HashSet<&String> = allowed.iter().collect();
            let denied_set: HashSet<&String> = denied.iter().collect();
            let mut overlap: Vec<&String> =
                allowed_set.intersection(&denied_set).copied().collect();
            if !overlap.is_empty() {
                // Hash-set iteration order is arbitrary; sort so the message
                // is stable across runs.
                overlap.sort_unstable();
                return Err(AgentError::Invalid(format!(
                    "tools cannot be both allowed and denied: {:?}",
                    overlap
                ))
                .into());
            }
        }

        if let Some(provider) = &self.model_provider {
            let valid_providers = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
            if !valid_providers.contains(&provider.as_str()) {
                return Err(AgentError::Invalid(format!(
                    "model_provider must be one of: {}. Got: {}",
                    valid_providers.join(", "),
                    provider
                ))
                .into());
            }
        }

        Ok(())
    }

    /// Reports whether the agent may use `tool_name`.
    ///
    /// Precedence, highest first:
    /// 1. A tool listed in `denied_tools` is rejected.
    /// 2. A tool in `ALWAYS_ALLOWED_TOOLS` is accepted.
    /// 3. If `allowed_tools` is set, only tools listed there are accepted.
    /// 4. Otherwise every tool is accepted.
    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
        if let Some(denied) = &self.denied_tools {
            if denied.iter().any(|t| t == tool_name) {
                return false;
            }
        }

        if Self::ALWAYS_ALLOWED_TOOLS.contains(&tool_name) {
            return true;
        }

        if let Some(allowed) = &self.allowed_tools {
            return allowed.iter().any(|t| t == tool_name);
        }

        true
    }

    /// Returns the profile's `temperature`, or `default` when none is set.
    pub fn effective_temperature(&self, default: f32) -> f32 {
        self.temperature.unwrap_or(default)
    }

    /// Returns the profile's `model_provider`, or `default` when none is set.
    pub fn effective_provider<'a>(&'a self, default: &'a str) -> &'a str {
        self.model_provider.as_deref().unwrap_or(default)
    }

    /// Returns the profile's `model_name`, falling back to `default`
    /// (which may itself be `None`).
    pub fn effective_model_name<'a>(&'a self, default: Option<&'a str>) -> Option<&'a str> {
        self.model_name.as_deref().or(default)
    }
}
impl Default for AgentProfile {
fn default() -> Self {
Self {
prompt: None,
style: None,
temperature: None,
model_provider: None,
model_name: None,
allowed_tools: None,
denied_tools: None,
memory_k: Self::default_memory_k(),
top_p: Self::default_top_p(),
max_context_tokens: None,
enable_graph: true, graph_memory: true, graph_depth: Self::default_graph_depth(),
graph_weight: Self::default_graph_weight(),
auto_graph: true, graph_threshold: Self::default_graph_threshold(),
graph_steering: true, fast_reasoning: true, fast_model_provider: Some("lmstudio".to_string()), fast_model_name: Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string()),
fast_model_temperature: Self::default_fast_temperature(),
fast_model_tasks: Self::default_fast_tasks(),
escalation_threshold: Self::default_escalation_threshold(),
show_reasoning: false, enable_audio_transcription: false, audio_response_mode: Self::default_audio_response_mode(),
audio_scenario: None,
enable_collective: false,
accept_delegations: Self::default_accept_delegations(),
preferred_domains: Vec::new(),
max_concurrent_tasks: Self::default_max_concurrent_tasks(),
min_delegation_score: Self::default_min_delegation_score(),
share_learnings: false, participate_in_voting: Self::default_participate_in_voting(),
}
}
}
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;

    /// Builds an owned tool list from string literals.
    fn tools(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| String::from(*name)).collect()
    }

    /// `Default` yields the documented values and passes validation.
    #[test]
    fn test_default_agent_profile() {
        let defaults = AgentProfile::default();

        assert_eq!(defaults.memory_k, 10);
        assert_eq!(defaults.top_p, 0.9);

        assert!(defaults.fast_reasoning);
        assert_eq!(defaults.fast_model_provider, Some("lmstudio".to_string()));
        assert_eq!(
            defaults.fast_model_name,
            Some("lmstudio-community/Llama-3.2-3B-Instruct".to_string())
        );
        assert_eq!(defaults.fast_model_temperature, 0.3);
        assert_eq!(defaults.escalation_threshold, 0.6);

        assert!(defaults.enable_graph);
        assert!(defaults.graph_memory);
        assert!(defaults.auto_graph);
        assert!(defaults.graph_steering);

        assert!(defaults.validate().is_ok());
    }

    /// A temperature above 2.0 is rejected.
    #[test]
    fn test_validate_invalid_temperature() {
        let too_hot = AgentProfile {
            temperature: Some(3.0),
            ..AgentProfile::default()
        };
        assert!(too_hot.validate().is_err());
    }

    /// A `top_p` above 1.0 is rejected.
    #[test]
    fn test_validate_invalid_top_p() {
        let out_of_range = AgentProfile {
            top_p: 1.5,
            ..AgentProfile::default()
        };
        assert!(out_of_range.validate().is_err());
    }

    /// A tool present in both lists is rejected.
    #[test]
    fn test_validate_tool_overlap() {
        let conflicting = AgentProfile {
            allowed_tools: Some(tools(&["tool1", "tool2"])),
            denied_tools: Some(tools(&["tool2", "tool3"])),
            ..AgentProfile::default()
        };
        assert!(conflicting.validate().is_err());
    }

    /// With neither list set, every tool is allowed.
    #[test]
    fn test_is_tool_allowed_no_restrictions() {
        let unrestricted = AgentProfile::default();
        for tool in ["any_tool", "prompt_user"] {
            assert!(unrestricted.is_tool_allowed(tool));
        }
    }

    /// An allowlist admits only its entries plus the always-allowed tool.
    #[test]
    fn test_is_tool_allowed_with_allowlist() {
        let restricted = AgentProfile {
            allowed_tools: Some(tools(&["tool1", "tool2"])),
            ..AgentProfile::default()
        };
        assert!(restricted.is_tool_allowed("tool1"));
        assert!(restricted.is_tool_allowed("tool2"));
        assert!(!restricted.is_tool_allowed("tool3"));
        assert!(restricted.is_tool_allowed("prompt_user"));
    }

    /// A denylist blocks its entries, including the always-allowed tool.
    #[test]
    fn test_is_tool_allowed_with_denylist() {
        let restricted = AgentProfile {
            denied_tools: Some(tools(&["tool1", "prompt_user"])),
            ..AgentProfile::default()
        };
        assert!(!restricted.is_tool_allowed("tool1"));
        assert!(restricted.is_tool_allowed("tool2"));
        assert!(!restricted.is_tool_allowed("prompt_user"));
    }

    /// The caller's default is used only when no temperature is set.
    #[test]
    fn test_effective_temperature() {
        let unset = AgentProfile::default();
        assert_eq!(unset.effective_temperature(0.7), 0.7);

        let pinned = AgentProfile {
            temperature: Some(0.5),
            ..unset
        };
        assert_eq!(pinned.effective_temperature(0.7), 0.5);
    }

    /// The caller's default is used only when no provider is set.
    #[test]
    fn test_effective_provider() {
        let unset = AgentProfile::default();
        assert_eq!(unset.effective_provider("mock"), "mock");

        let pinned = AgentProfile {
            model_provider: Some("openai".to_string()),
            ..unset
        };
        assert_eq!(pinned.effective_provider("mock"), "openai");
    }
}