use crate::application::services::paladin::error::PaladinError;
use crate::core::base::entity::node::Node;
use crate::core::platform::container::prompt::{
PromptData, PromptItem, PromptParameters, PromptType, UserPrompt,
};
use log::{debug, info};
use paladin_ports::output::llm_port::{LlmPort, LlmRequest};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use uuid::Uuid;
/// Broad category of work an agent performs.
///
/// Each variant selects one of the temperatures held in `TemperatureConfig`.
/// The enum is fieldless, so it is cheap to copy and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Writing, brainstorming, ideation, storytelling, artistic work.
    Creative,
    /// Math, logic, code analysis, fact extraction, technical work, debugging.
    Analytical,
    /// General conversation, Q&A, summarization, information retrieval.
    Standard,
}
/// Tunables for `TemperatureService`.
///
/// There is one sampling temperature per `TaskType`, plus a switch that chooses how the task
/// type is detected. Only `PartialEq` is derived; `Eq` is not, because the temperatures are
/// `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct TemperatureConfig {
    /// Temperature returned for `TaskType::Creative`.
    pub creative_temp: f32,
    /// Temperature returned for `TaskType::Analytical`.
    pub analytical_temp: f32,
    /// Temperature returned for `TaskType::Standard`.
    pub standard_temp: f32,
    /// When `true`, the task type is classified with an LLM call.
    /// When `false`, a local keyword heuristic is used instead.
    pub enable_llm_detection: bool,
}
impl Default for TemperatureConfig {
fn default() -> Self {
Self {
creative_temp: 0.85,
analytical_temp: 0.2,
standard_temp: 0.6,
enable_llm_detection: true,
}
}
}
/// Picks sampling temperatures for agents by classifying the kind of task they perform.
///
/// Classification goes through `llm_port` or a keyword heuristic, depending on
/// `config.enable_llm_detection`.
pub struct TemperatureService {
    /// Per-task-type temperatures and the detection-mode switch.
    config: TemperatureConfig,
    /// Port used for LLM-based task classification.
    llm_port: Arc<dyn LlmPort>,
}
impl TemperatureService {
pub fn new(llm_port: Arc<dyn LlmPort>) -> Self {
Self {
llm_port,
config: TemperatureConfig::default(),
}
}
pub fn with_config(llm_port: Arc<dyn LlmPort>, config: TemperatureConfig) -> Self {
Self { llm_port, config }
}
pub async fn calculate_optimal_temperature(
&self,
agent_description: &str,
task_context: Option<&str>,
) -> Result<f32, PaladinError> {
if agent_description.trim().is_empty() {
return Err(PaladinError::ConfigurationError(
"Agent description cannot be empty".into(),
));
}
info!(
"Calculating optimal temperature: agent_description={}, has_task_context={}",
agent_description,
task_context.is_some()
);
let task_type = if self.config.enable_llm_detection {
self.detect_task_type_with_llm(agent_description, task_context)
.await?
} else {
self.detect_task_type_heuristic(agent_description, task_context)
};
let temperature = match task_type {
TaskType::Creative => self.config.creative_temp,
TaskType::Analytical => self.config.analytical_temp,
TaskType::Standard => self.config.standard_temp,
};
debug!(
"Calculated optimal temperature: task_type={:?}, temperature={}",
task_type, temperature
);
Ok(temperature)
}
async fn detect_task_type_with_llm(
&self,
agent_description: &str,
task_context: Option<&str>,
) -> Result<TaskType, PaladinError> {
let prompt = self.build_detection_prompt(agent_description, task_context);
let request = LlmRequest {
id: Uuid::new_v4(),
model: "gpt-4".to_string(), prompt: PromptItem {
node: Node::new(
PromptData {
prompt_type: PromptType::User(UserPrompt {
query: prompt.clone(),
context: None,
}),
content_attachments: vec![],
parameters: PromptParameters {
max_tokens: Some(500),
temperature: Some(0.0),
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop_sequences: None,
},
context: None,
expected_output: None,
tags: None,
category: None,
author: None,
metadata: BTreeMap::new(),
},
Some("temperature_detection".to_string()),
),
},
attachments: vec![],
stream: false,
metadata: HashMap::new(),
};
let response = self
.llm_port
.generate(request)
.await
.map_err(|e| PaladinError::LlmError(e.to_string()))?;
let task_type = self.parse_task_type(&response.content)?;
debug!(
"LLM detected task type: response={}, detected_type={:?}",
response.content, task_type
);
Ok(task_type)
}
fn detect_task_type_heuristic(
&self,
agent_description: &str,
task_context: Option<&str>,
) -> TaskType {
let combined =
format!("{} {}", agent_description, task_context.unwrap_or_default()).to_lowercase();
let creative_keywords = [
"creative",
"writing",
"story",
"brainstorm",
"idea",
"imaginative",
"novel",
"poetry",
"artistic",
"design",
];
let analytical_keywords = [
"analytical",
"analyze",
"math",
"calculation",
"logic",
"code",
"debug",
"precise",
"fact",
"data",
"research",
"technical",
];
let creative_score = creative_keywords
.iter()
.filter(|kw| combined.contains(*kw))
.count();
let analytical_score = analytical_keywords
.iter()
.filter(|kw| combined.contains(*kw))
.count();
if analytical_score > creative_score {
TaskType::Analytical
} else if creative_score > 0 {
TaskType::Creative
} else {
TaskType::Standard
}
}
fn build_detection_prompt(
&self,
agent_description: &str,
task_context: Option<&str>,
) -> String {
let context_part = task_context
.map(|c| format!("\n\nTask Context:\n{}", c))
.unwrap_or_default();
format!(
r#"You are a task classifier. Analyze the following agent description and determine the task type.
Agent Description:
{}{}
Task Types:
- CREATIVE: Writing, brainstorming, ideation, storytelling, artistic work
- ANALYTICAL: Math, logic, code analysis, fact extraction, technical work, debugging
- STANDARD: General conversation, Q&A, summarization, information retrieval
Respond with ONLY one word: CREATIVE, ANALYTICAL, or STANDARD"#,
agent_description, context_part
)
}
fn parse_task_type(&self, response: &str) -> Result<TaskType, PaladinError> {
let normalized = response.trim().to_uppercase();
if normalized.contains("CREATIVE") {
Ok(TaskType::Creative)
} else if normalized.contains("ANALYTICAL") {
Ok(TaskType::Analytical)
} else if normalized.contains("STANDARD") {
Ok(TaskType::Standard)
} else {
debug!(
"Could not parse task type from '{}', defaulting to Standard",
response
);
Ok(TaskType::Standard)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use chrono::Utc;
    use futures::stream;
    use paladin_ports::output::llm_port::{
        FinishReason, LlmError, LlmPort, LlmResponse, ProviderCapabilities, StreamingResponse,
        TokenUsage,
    };
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// `LlmPort` test double.
    ///
    /// `generate` always answers with a fixed string, and the mock counts how many times
    /// `generate` was called. The response never changes, so it is stored directly rather than
    /// behind a lock.
    struct MockLlmPort {
        response: String,
        generate_calls: AtomicUsize,
    }

    impl MockLlmPort {
        fn new(response: &str) -> Arc<Self> {
            Arc::new(Self {
                response: response.to_string(),
                generate_calls: AtomicUsize::new(0),
            })
        }

        /// Number of `generate` calls received so far.
        fn generate_calls(&self) -> usize {
            self.generate_calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmPort for MockLlmPort {
        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.generate_calls.fetch_add(1, Ordering::SeqCst);
            Ok(LlmResponse {
                id: Uuid::new_v4(),
                request_id: Uuid::new_v4(),
                model: "gpt-4".to_string(),
                content: self.response.clone(),
                finish_reason: FinishReason::Stop,
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
                created_at: Utc::now(),
                metadata: HashMap::new(),
                function_call: None,
            })
        }

        async fn generate_stream(
            &self,
            _request: LlmRequest,
        ) -> Result<
            Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>,
            LlmError,
        > {
            Ok(Box::new(stream::empty()))
        }

        async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
            Ok(true)
        }

        async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
            Ok(vec!["gpt-4".to_string()])
        }

        fn get_provider_name(&self) -> &'static str {
            "mock"
        }

        fn get_capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                supports_streaming: false,
                supports_tool_calling: false,
                supports_function_calling: false,
                supports_vision: false,
                supports_embeddings: false,
                max_context_tokens: Some(8192),
                supports_system_messages: true,
            }
        }
    }

    /// Default config with LLM detection disabled, forcing the keyword heuristic.
    fn heuristic_config() -> TemperatureConfig {
        TemperatureConfig {
            enable_llm_detection: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_temperature_service_creation() {
        let service = TemperatureService::new(MockLlmPort::new("STANDARD"));
        assert_eq!(service.config.creative_temp, 0.85);
        assert_eq!(service.config.analytical_temp, 0.2);
        assert_eq!(service.config.standard_temp, 0.6);
        assert!(service.config.enable_llm_detection);
    }

    #[test]
    fn test_temperature_service_with_custom_config() {
        let config = TemperatureConfig {
            creative_temp: 0.9,
            analytical_temp: 0.1,
            standard_temp: 0.5,
            enable_llm_detection: false,
        };
        let service = TemperatureService::with_config(MockLlmPort::new("STANDARD"), config);
        assert_eq!(service.config.creative_temp, 0.9);
        assert_eq!(service.config.analytical_temp, 0.1);
        assert_eq!(service.config.standard_temp, 0.5);
        assert!(!service.config.enable_llm_detection);
    }

    #[tokio::test]
    async fn test_calculate_temperature_creative_task() {
        let llm_port = MockLlmPort::new("CREATIVE");
        let service = TemperatureService::new(llm_port.clone());
        let temperature = service
            .calculate_optimal_temperature("A creative writing assistant", None)
            .await
            .expect("creative description should yield a temperature");
        assert_eq!(temperature, 0.85);
        assert_eq!(llm_port.generate_calls(), 1);
    }

    #[tokio::test]
    async fn test_calculate_temperature_analytical_task() {
        let llm_port = MockLlmPort::new("ANALYTICAL");
        let service = TemperatureService::new(llm_port.clone());
        let temperature = service
            .calculate_optimal_temperature("A code analysis and debugging assistant", None)
            .await
            .expect("analytical description should yield a temperature");
        assert_eq!(temperature, 0.2);
        assert_eq!(llm_port.generate_calls(), 1);
    }

    #[tokio::test]
    async fn test_calculate_temperature_standard_task() {
        let llm_port = MockLlmPort::new("STANDARD");
        let service = TemperatureService::new(llm_port.clone());
        let temperature = service
            .calculate_optimal_temperature("A general Q&A assistant", None)
            .await
            .expect("standard description should yield a temperature");
        assert_eq!(temperature, 0.6);
        assert_eq!(llm_port.generate_calls(), 1);
    }

    #[tokio::test]
    async fn test_calculate_temperature_with_task_context() {
        let llm_port = MockLlmPort::new("CREATIVE");
        let service = TemperatureService::new(llm_port.clone());
        let temperature = service
            .calculate_optimal_temperature(
                "A writing assistant",
                Some("Write a short story about a robot"),
            )
            .await
            .expect("description with context should yield a temperature");
        assert_eq!(temperature, 0.85);
        assert_eq!(llm_port.generate_calls(), 1);
    }

    #[tokio::test]
    async fn test_calculate_temperature_empty_description() {
        let llm_port = MockLlmPort::new("STANDARD");
        let service = TemperatureService::new(llm_port.clone());
        let result = service.calculate_optimal_temperature("", None).await;
        match result.unwrap_err() {
            PaladinError::ConfigurationError(msg) => {
                assert!(msg.contains("cannot be empty"));
            }
            _ => panic!("Expected ConfigurationError"),
        }
        // Validation must fail before any LLM round-trip.
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[tokio::test]
    async fn test_calculate_temperature_whitespace_description() {
        let llm_port = MockLlmPort::new("STANDARD");
        let service = TemperatureService::new(llm_port.clone());
        let result = service.calculate_optimal_temperature("   \n\t", None).await;
        assert!(matches!(result, Err(PaladinError::ConfigurationError(_))));
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[tokio::test]
    async fn test_heuristic_detection_creative() {
        let llm_port = MockLlmPort::new("ignored");
        let service = TemperatureService::with_config(llm_port.clone(), heuristic_config());
        let temperature = service
            .calculate_optimal_temperature("A creative writing and brainstorming assistant", None)
            .await
            .expect("heuristic detection should not fail");
        assert_eq!(temperature, 0.85);
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[tokio::test]
    async fn test_heuristic_detection_analytical() {
        let llm_port = MockLlmPort::new("ignored");
        let service = TemperatureService::with_config(llm_port.clone(), heuristic_config());
        let temperature = service
            .calculate_optimal_temperature("A code analysis and math problem solver", None)
            .await
            .expect("heuristic detection should not fail");
        assert_eq!(temperature, 0.2);
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[tokio::test]
    async fn test_heuristic_detection_standard() {
        let llm_port = MockLlmPort::new("ignored");
        let service = TemperatureService::with_config(llm_port.clone(), heuristic_config());
        let temperature = service
            .calculate_optimal_temperature("A helpful assistant", None)
            .await
            .expect("heuristic detection should not fail");
        assert_eq!(temperature, 0.6);
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[tokio::test]
    async fn test_heuristic_detection_uses_task_context() {
        let llm_port = MockLlmPort::new("ignored");
        let service = TemperatureService::with_config(llm_port.clone(), heuristic_config());
        let temperature = service
            .calculate_optimal_temperature("A helpful assistant", Some("Brainstorm story ideas"))
            .await
            .expect("heuristic detection should not fail");
        assert_eq!(temperature, 0.85);
        assert_eq!(llm_port.generate_calls(), 0);
    }

    #[test]
    fn test_parse_task_type_creative() {
        let service = TemperatureService::new(MockLlmPort::new("CREATIVE"));
        assert_eq!(
            service.parse_task_type("CREATIVE").unwrap(),
            TaskType::Creative
        );
    }

    #[test]
    fn test_parse_task_type_analytical() {
        let service = TemperatureService::new(MockLlmPort::new("ANALYTICAL"));
        assert_eq!(
            service.parse_task_type("ANALYTICAL").unwrap(),
            TaskType::Analytical
        );
    }

    #[test]
    fn test_parse_task_type_standard() {
        let service = TemperatureService::new(MockLlmPort::new("STANDARD"));
        assert_eq!(
            service.parse_task_type("STANDARD").unwrap(),
            TaskType::Standard
        );
    }

    #[test]
    fn test_parse_task_type_is_case_insensitive() {
        let service = TemperatureService::new(MockLlmPort::new("dummy"));
        assert_eq!(
            service.parse_task_type("  analytical\n").unwrap(),
            TaskType::Analytical
        );
    }

    #[test]
    fn test_parse_task_type_with_explanation() {
        let service = TemperatureService::new(MockLlmPort::new("dummy"));
        let result = service.parse_task_type(
            "Based on the description, this is clearly a CREATIVE task involving storytelling.",
        );
        assert_eq!(result.unwrap(), TaskType::Creative);
    }

    #[test]
    fn test_parse_task_type_ambiguous_defaults_to_standard() {
        let service = TemperatureService::new(MockLlmPort::new("dummy"));
        let result = service.parse_task_type("I'm not sure about this one");
        assert_eq!(result.unwrap(), TaskType::Standard);
    }

    #[test]
    fn test_build_detection_prompt_without_context() {
        let service = TemperatureService::new(MockLlmPort::new("STANDARD"));
        let prompt = service.build_detection_prompt("A helpful assistant", None);
        assert!(prompt.contains("A helpful assistant"));
        assert!(!prompt.contains("Task Context:"));
        assert!(prompt.contains("CREATIVE"));
        assert!(prompt.contains("ANALYTICAL"));
        assert!(prompt.contains("STANDARD"));
    }

    #[test]
    fn test_build_detection_prompt_with_context() {
        let service = TemperatureService::new(MockLlmPort::new("STANDARD"));
        let prompt =
            service.build_detection_prompt("A helpful assistant", Some("Solve math problems"));
        assert!(prompt.contains("A helpful assistant"));
        assert!(prompt.contains("Task Context:"));
        assert!(prompt.contains("Solve math problems"));
    }
}