use crate::application::errors::planning_error::PlanningError;
use crate::core::platform::container::planning::{Subtask, TaskPlan};
use crate::core::platform::container::prompt::{PromptItem, PromptType, UserPrompt};
use log::info;
use paladin_ports::output::llm_port::{LlmPort, LlmRequest};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Service that decomposes a task into subtasks with an LLM, executes those
/// subtasks in dependency order, and synthesizes their results into one answer.
pub struct PlanningService {
    /// Port through which every prompt issued by this service is sent.
    llm_port: Arc<dyn LlmPort>,
}
/// Top-level JSON object the LLM is asked to return by `build_planning_prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct LlmPlanResponse {
    /// Task description as echoed back by the LLM; becomes the plan's task.
    task: String,
    /// Subtasks in the order the LLM listed them.
    subtasks: Vec<LlmSubtask>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmSubtask {
id: String,
description: String,
dependencies: Vec<String>,
}
impl PlanningService {
pub fn new(llm_port: Arc<dyn LlmPort>) -> Self {
info!("Creating PlanningService");
Self { llm_port }
}
pub async fn create_plan(
&self,
task_description: &str,
max_subtasks: u32,
model: &str,
) -> Result<TaskPlan, PlanningError> {
info!(
"Creating plan for task: '{}' (max {} subtasks)",
task_description, max_subtasks
);
let prompt = self.build_planning_prompt(task_description, max_subtasks);
let user_prompt = UserPrompt {
query: prompt,
context: None,
};
let prompt_item = PromptItem::new(PromptType::User(user_prompt))
.map_err(|e| PlanningError::GenerationFailed(e.to_string()))?;
let request = LlmRequest {
id: Uuid::new_v4(),
model: model.to_string(),
prompt: prompt_item,
attachments: vec![],
stream: false,
metadata: HashMap::new(),
};
let response = self
.llm_port
.generate(request)
.await
.map_err(|e| PlanningError::LlmError(e.to_string()))?;
let plan = self.parse_plan_from_llm(&response.content, max_subtasks)?;
info!("Created plan with {} subtasks", plan.subtask_count());
Ok(plan)
}
pub async fn execute_subtasks(
&self,
plan: &TaskPlan,
original_input: &str,
model: &str,
) -> Result<TaskPlan, PlanningError> {
info!(
"Executing {} subtasks for task: '{}'",
plan.subtasks.len(),
plan.original_task
);
let mut executed_plan = plan.clone();
let mut completed_ids: Vec<String> = Vec::new();
while completed_ids.len() < executed_plan.subtasks.len() {
let mut made_progress = false;
let mut next_subtask_idx = None;
let mut next_dependencies = Vec::new();
for (idx, subtask) in executed_plan.subtasks.iter().enumerate() {
if subtask.completed {
continue;
}
let dependencies = executed_plan
.dependencies
.get(&subtask.id)
.cloned()
.unwrap_or_default();
let can_execute = dependencies
.iter()
.all(|dep_id| completed_ids.contains(dep_id));
if can_execute {
next_subtask_idx = Some(idx);
next_dependencies = dependencies;
break;
}
}
if let Some(idx) = next_subtask_idx {
let subtask_id = executed_plan.subtasks[idx].id.clone();
info!(
"Executing subtask: {} - {}",
subtask_id, executed_plan.subtasks[idx].description
);
let context =
self.build_subtask_context(&executed_plan, &next_dependencies, original_input);
let result = self
.execute_subtask(&executed_plan.subtasks[idx], &context, model)
.await?;
executed_plan.subtasks[idx].complete(result);
completed_ids.push(subtask_id.clone());
made_progress = true;
info!("Completed subtask: {}", subtask_id);
}
if !made_progress && completed_ids.len() < executed_plan.subtasks.len() {
return Err(PlanningError::InvalidPlan(
"Circular dependencies or invalid dependency graph detected".to_string(),
));
}
}
info!("All {} subtasks completed", completed_ids.len());
Ok(executed_plan)
}
pub async fn synthesize_results(
&self,
plan: &TaskPlan,
original_task: &str,
model: &str,
) -> Result<String, PlanningError> {
info!("Synthesizing results for task: '{}'", original_task);
let incomplete: Vec<&Subtask> = plan.subtasks.iter().filter(|st| !st.completed).collect();
if !incomplete.is_empty() {
return Err(PlanningError::InvalidPlan(format!(
"Cannot synthesize results: {} subtasks incomplete",
incomplete.len()
)));
}
let prompt = self.build_synthesis_prompt(plan, original_task);
let user_prompt = UserPrompt {
query: prompt,
context: None,
};
let mut prompt_item = PromptItem::new(PromptType::User(user_prompt))
.map_err(|e| PlanningError::GenerationFailed(e.to_string()))?;
use crate::core::platform::container::prompt::PromptParameters;
prompt_item.set_parameters(PromptParameters {
max_tokens: None,
temperature: Some(0.7), top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop_sequences: None,
});
let request = LlmRequest {
id: Uuid::new_v4(),
model: model.to_string(),
prompt: prompt_item,
attachments: vec![],
stream: false,
metadata: HashMap::new(),
};
let response = self
.llm_port
.generate(request)
.await
.map_err(|e| PlanningError::LlmError(e.to_string()))?;
info!("Synthesis complete");
Ok(response.content)
}
fn build_synthesis_prompt(&self, plan: &TaskPlan, original_task: &str) -> String {
let mut subtask_results = String::new();
for (i, subtask) in plan.subtasks.iter().enumerate() {
if let Some(result) = &subtask.result {
subtask_results.push_str(&format!(
"{}. {}\n Result: {}\n\n",
i + 1,
subtask.description,
result
));
}
}
format!(
r#"You are synthesizing the results of multiple subtasks into a cohesive response.
ORIGINAL TASK: {}
COMPLETED SUBTASKS AND RESULTS:
{}
Synthesize these results into a clear, comprehensive response that directly addresses the original task. Provide a cohesive summary that:
1. Integrates information from all subtasks
2. Presents results in a logical flow
3. Highlights key findings or accomplishments
4. Provides clear next steps or conclusions if applicable
Write the synthesized response now:"#,
original_task, subtask_results
)
}
fn build_planning_prompt(&self, task_description: &str, max_subtasks: u32) -> String {
format!(
r#"You are a task planning assistant. Decompose the following task into subtasks.
TASK: {}
INSTRUCTIONS:
- Break down the task into {} or fewer subtasks
- Each subtask should be concrete and actionable
- Identify dependencies between subtasks
- Return your response as JSON in the following format:
{{
"task": "original task description",
"subtasks": [
{{
"id": "1",
"description": "description of subtask",
"dependencies": ["id1", "id2"]
}}
]
}}
Return ONLY the JSON, no additional text."#,
task_description, max_subtasks
)
}
fn parse_plan_from_llm(
&self,
llm_response: &str,
max_subtasks: u32,
) -> Result<TaskPlan, PlanningError> {
let json_str = self.extract_json(llm_response)?;
let llm_plan: LlmPlanResponse = serde_json::from_str(&json_str)
.map_err(|e| PlanningError::GenerationFailed(format!("JSON parse error: {}", e)))?;
if llm_plan.subtasks.len() as u32 > max_subtasks {
return Err(PlanningError::MaxSubtasksExceeded {
max: max_subtasks,
attempted: llm_plan.subtasks.len() as u32,
});
}
let mut plan = TaskPlan::new(llm_plan.task, max_subtasks);
for llm_subtask in llm_plan.subtasks {
let subtask = Subtask::new(
llm_subtask.id.clone(),
llm_subtask.description,
"Expected output from subtask execution".to_string(), );
plan.add_subtask(subtask)
.map_err(PlanningError::InvalidPlan)?;
if !llm_subtask.dependencies.is_empty() {
plan.dependencies
.insert(llm_subtask.id, llm_subtask.dependencies);
}
}
plan.validate().map_err(PlanningError::InvalidPlan)?;
Ok(plan)
}
fn extract_json(&self, response: &str) -> Result<String, PlanningError> {
let trimmed = response.trim();
if let Some(start) = trimmed.find("```json")
&& let Some(end) = trimmed[start + 7..].find("```")
{
return Ok(trimmed[start + 7..start + 7 + end].trim().to_string());
}
if let Some(start) = trimmed.find("```")
&& let Some(end) = trimmed[start + 3..].find("```")
{
return Ok(trimmed[start + 3..start + 3 + end].trim().to_string());
}
Ok(trimmed.to_string())
}
fn build_subtask_context(
&self,
plan: &TaskPlan,
dependencies: &[String],
original_input: &str,
) -> String {
if dependencies.is_empty() {
return original_input.to_string();
}
let mut context = format!("Original Task: {}\n\n", original_input);
context.push_str("Results from prerequisite subtasks:\n\n");
for dep_id in dependencies {
if let Some(dep_subtask) = plan.subtasks.iter().find(|st| st.id == *dep_id) {
context.push_str(&format!(
"Subtask {}: {}\nResult: {}\n\n",
dep_subtask.id,
dep_subtask.description,
dep_subtask
.result
.as_ref()
.unwrap_or(&"No result".to_string())
));
}
}
context
}
async fn execute_subtask(
&self,
subtask: &Subtask,
context: &str,
model: &str,
) -> Result<String, PlanningError> {
let prompt = format!(
r#"You are executing a subtask as part of a larger plan.
SUBTASK: {}
EXPECTED OUTPUT: {}
CONTEXT:
{}
Execute this subtask and provide the result. Be concise and focused on the expected output."#,
subtask.description, subtask.expected_output, context
);
let user_prompt = UserPrompt {
query: prompt,
context: None,
};
let mut prompt_item = PromptItem::new(PromptType::User(user_prompt))
.map_err(|e| PlanningError::GenerationFailed(e.to_string()))?;
use crate::core::platform::container::prompt::PromptParameters;
prompt_item.set_parameters(PromptParameters {
max_tokens: None,
temperature: Some(0.3), top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop_sequences: None,
});
let request = LlmRequest {
id: Uuid::new_v4(),
model: model.to_string(),
prompt: prompt_item,
attachments: vec![],
stream: false,
metadata: HashMap::new(),
};
let response = self
.llm_port
.generate(request)
.await
.map_err(|e| PlanningError::LlmError(e.to_string()))?;
Ok(response.content)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use chrono::Utc;
    use paladin_ports::output::llm_port::{
        FinishReason, LlmError, LlmResponse, ProviderCapabilities, TokenUsage,
    };
    use std::sync::Mutex;

    /// Test double for `LlmPort`.
    ///
    /// Returns the same canned `response` for every `generate` call and records
    /// the model name of each request, so tests can check what the service
    /// actually sent.
    struct MockLlmPort {
        response: String,
        requested_models: Mutex<Vec<String>>,
    }

    impl MockLlmPort {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
                requested_models: Mutex::new(Vec::new()),
            }
        }

        /// Model names of all `generate` requests received so far, in call order.
        fn requested_models(&self) -> Vec<String> {
            self.requested_models
                .lock()
                .expect("mock mutex poisoned")
                .clone()
        }
    }

    #[async_trait]
    impl LlmPort for MockLlmPort {
        async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            // The guard is a statement temporary, so it is released immediately.
            self.requested_models
                .lock()
                .expect("mock mutex poisoned")
                .push(request.model);
            Ok(LlmResponse {
                id: Uuid::new_v4(),
                request_id: Uuid::new_v4(),
                model: "test-model".to_string(),
                content: self.response.clone(),
                finish_reason: FinishReason::Stop,
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                },
                created_at: Utc::now(),
                metadata: HashMap::new(),
                function_call: None,
            })
        }

        async fn generate_stream(
            &self,
            _request: LlmRequest,
        ) -> Result<
            Box<
                dyn futures::Stream<
                        Item = Result<paladin_ports::output::llm_port::StreamingResponse, LlmError>,
                    > + Send,
            >,
            LlmError,
        > {
            unimplemented!("Streaming not needed for tests")
        }

        async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
            Ok(true)
        }

        async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
            Ok(vec!["test-model".to_string()])
        }

        fn get_provider_name(&self) -> &'static str {
            "mock"
        }

        fn get_capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                supports_streaming: false,
                supports_function_calling: false,
                supports_tool_calling: false,
                supports_vision: false,
                supports_embeddings: false,
                supports_system_messages: true,
                max_context_tokens: Some(4096),
            }
        }
    }

    #[test]
    fn test_planning_service_new() {
        let llm_port = Arc::new(MockLlmPort::new("test"));
        let _service = PlanningService::new(llm_port.clone());
        // The service holds exactly one extra handle to the port.
        assert_eq!(Arc::strong_count(&llm_port), 2);
    }

    #[tokio::test]
    async fn test_create_plan_basic() {
        let plan_json = r#"{
"task": "Analyze security vulnerabilities",
"subtasks": [
{
"id": "1",
"description": "Scan for SQL injection vulnerabilities",
"dependencies": []
},
{
"id": "2",
"description": "Check for XSS vulnerabilities",
"dependencies": []
},
{
"id": "3",
"description": "Generate security report",
"dependencies": ["1", "2"]
}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let plan = service
            .create_plan("Analyze security vulnerabilities", 10, "gpt-4")
            .await
            .expect("plan should be created");
        assert_eq!(plan.subtask_count(), 3);
        assert_eq!(
            plan.dependencies.get("3"),
            Some(&vec!["1".to_string(), "2".to_string()])
        );
        assert_eq!(llm_port.requested_models(), vec!["gpt-4"]);
    }

    #[tokio::test]
    async fn test_create_plan_enforces_max_subtasks() {
        let plan_json = r#"{
"task": "Complex task",
"subtasks": [
{"id": "1", "description": "Task 1", "dependencies": []},
{"id": "2", "description": "Task 2", "dependencies": []},
{"id": "3", "description": "Task 3", "dependencies": []},
{"id": "4", "description": "Task 4", "dependencies": []},
{"id": "5", "description": "Task 5", "dependencies": []},
{"id": "6", "description": "Task 6", "dependencies": []}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port);
        match service.create_plan("Complex task", 3, "gpt-4").await {
            Err(PlanningError::MaxSubtasksExceeded { max, attempted }) => {
                assert_eq!(max, 3);
                assert_eq!(attempted, 6);
            }
            Err(other) => panic!("Expected MaxSubtasksExceeded, got: {:?}", other),
            Ok(_) => panic!("Expected MaxSubtasksExceeded, got a plan"),
        }
    }

    #[tokio::test]
    async fn test_execute_subtasks_with_dependencies() {
        let plan_json = r#"{
"task": "Build and test application",
"subtasks": [
{
"id": "1",
"description": "Install dependencies",
"dependencies": []
},
{
"id": "2",
"description": "Build application",
"dependencies": ["1"]
},
{
"id": "3",
"description": "Run tests",
"dependencies": ["2"]
}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let plan = service
            .create_plan("Build and test application", 10, "gpt-4")
            .await
            .expect("Failed to create plan");
        let executed_plan = service
            .execute_subtasks(&plan, "Build and test application", "gpt-4")
            .await
            .expect("subtasks should execute");
        assert_eq!(executed_plan.subtasks.len(), 3);
        for subtask in &executed_plan.subtasks {
            assert!(
                subtask.completed,
                "Subtask {} should be completed",
                subtask.id
            );
            assert!(
                subtask.result.is_some(),
                "Subtask {} should have a result",
                subtask.id
            );
        }
        // One planning call plus one call per subtask, all on the requested model.
        assert_eq!(llm_port.requested_models(), vec!["gpt-4"; 4]);
    }

    #[tokio::test]
    async fn test_synthesize_results() {
        let mut plan = TaskPlan::new("Build and deploy application".to_string(), 10);
        let mut subtask1 = Subtask::new(
            "1".to_string(),
            "Install dependencies".to_string(),
            "Dependencies installed".to_string(),
        );
        subtask1.complete(
            "Successfully installed all dependencies: express, react, typescript".to_string(),
        );
        let mut subtask2 = Subtask::new(
            "2".to_string(),
            "Build application".to_string(),
            "Build output".to_string(),
        );
        subtask2
            .complete("Build completed successfully. Output: dist/bundle.js (245 KB)".to_string());
        let mut subtask3 = Subtask::new(
            "3".to_string(),
            "Run tests".to_string(),
            "Test results".to_string(),
        );
        subtask3.complete("All tests passed: 42 passed, 0 failed".to_string());
        plan.add_subtask(subtask1).unwrap();
        plan.add_subtask(subtask2).unwrap();
        plan.add_subtask(subtask3).unwrap();
        let synthesis_response = r#"Successfully built and tested the application:
1. Installed all required dependencies (express, react, typescript)
2. Built the application successfully (output: dist/bundle.js, 245 KB)
3. Verified functionality with complete test suite (42 tests passed)
The application is ready for deployment."#;
        let llm_port = Arc::new(MockLlmPort::new(synthesis_response));
        let service = PlanningService::new(llm_port.clone());
        let synthesized = service
            .synthesize_results(&plan, "Build and deploy application", "gpt-4")
            .await
            .expect("synthesis should succeed");
        assert!(synthesized.contains("dependencies"));
        assert!(synthesized.contains("Built"));
        assert!(synthesized.contains("tests passed"));
        assert!(synthesized.contains("ready for deployment"));
        assert_eq!(llm_port.requested_models(), vec!["gpt-4"]);
    }

    #[tokio::test]
    async fn test_planning_failure_invalid_json() {
        let invalid_json = "This is not valid JSON at all!";
        let llm_port = Arc::new(MockLlmPort::new(invalid_json));
        let service = PlanningService::new(llm_port);
        match service.create_plan("Some task", 10, "gpt-4").await {
            Err(PlanningError::GenerationFailed(_)) => {}
            Err(other) => panic!("Expected GenerationFailed, got: {:?}", other),
            Ok(_) => panic!("Expected GenerationFailed, got a plan"),
        }
    }

    #[tokio::test]
    async fn test_synthesis_with_incomplete_subtasks() {
        let mut plan = TaskPlan::new("Test task".to_string(), 10);
        let subtask1 = Subtask::new(
            "1".to_string(),
            "Incomplete task".to_string(),
            "Output".to_string(),
        );
        plan.add_subtask(subtask1).unwrap();
        let llm_port = Arc::new(MockLlmPort::new("Some response"));
        let service = PlanningService::new(llm_port.clone());
        let result = service
            .synthesize_results(&plan, "Test task", "gpt-4")
            .await;
        match result {
            Err(PlanningError::InvalidPlan(msg)) => assert!(msg.contains("incomplete")),
            other => panic!("Expected InvalidPlan, got: {:?}", other),
        }
        // The precondition check must reject the plan before any LLM call.
        assert!(llm_port.requested_models().is_empty());
    }

    /// Smoke test for the logged plan-then-execute path. Log output is not
    /// captured; this only checks that the path completes.
    #[tokio::test]
    async fn test_planning_logs_progress() {
        let plan_json = r#"{
"task": "Simple task",
"subtasks": [
{"id": "1", "description": "Do something", "dependencies": []}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let plan = service
            .create_plan("Simple task", 10, "gpt-4")
            .await
            .expect("plan should be created");
        let result = service
            .execute_subtasks(&plan, "Simple task", "gpt-4")
            .await;
        assert!(result.is_ok());
        assert_eq!(llm_port.requested_models().len(), 2);
    }

    #[tokio::test]
    async fn test_planning_service_uses_configured_model() {
        let plan_json = r#"{
"task": "Test task",
"subtasks": [
{"id": "1", "description": "Test subtask", "dependencies": []}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let plan = service
            .create_plan("Test task", 5, "claude-3")
            .await
            .expect("plan should be created");
        assert_eq!(plan.subtask_count(), 1);
        assert_eq!(llm_port.requested_models(), vec!["claude-3"]);
    }

    /// The service does not validate model names itself; each one is forwarded
    /// to the port unchanged.
    #[tokio::test]
    async fn test_planning_service_accepts_any_model_name() {
        let plan_json = r#"{
"task": "Test task",
"subtasks": [
{"id": "1", "description": "Test subtask", "dependencies": []}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let gpt4_result = service.create_plan("Task 1", 5, "gpt-4").await;
        let claude_result = service.create_plan("Task 2", 5, "claude-3").await;
        let custom_result = service.create_plan("Task 3", 5, "custom-model").await;
        assert!(gpt4_result.is_ok());
        assert!(claude_result.is_ok());
        assert!(custom_result.is_ok());
        assert_eq!(
            llm_port.requested_models(),
            vec!["gpt-4", "claude-3", "custom-model"]
        );
    }

    /// No fallback happens at this layer. An empty model name reaches the port
    /// as-is, and any fallback is the port's responsibility.
    #[tokio::test]
    async fn test_planning_service_forwards_empty_model_unchanged() {
        let plan_json = r#"{
"task": "Test task",
"subtasks": [
{"id": "1", "description": "Test subtask", "dependencies": []}
]
}"#;
        let llm_port = Arc::new(MockLlmPort::new(plan_json));
        let service = PlanningService::new(llm_port.clone());
        let result = service.create_plan("Test task", 5, "").await;
        assert!(result.is_ok());
        assert_eq!(llm_port.requested_models(), vec![""]);
    }
}