use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::types::{Action, ActionParams, WorkerId};
use std::sync::Arc;
use super::batch::DecisionResponse;
use super::manager::{
BatchDecisionRequest, ManagementDecision, ManagerAgent, ManagerId, WorkerDecisionRequest,
};
use super::worker::{FixedScopeStrategy, Guidance, ScopeStrategy, WorkerScope};
use crate::context::{
ContextResolver, ContextStore, GlobalContext, ManagerContext, TaskContext,
WorkerContext as WorkerCtx,
};
/// Tunable settings for [`DefaultBatchManagerAgent`].
///
/// `Debug` is implemented by hand further down in this file because the
/// `scope_strategy` trait object is printed as a placeholder.
#[derive(Clone)]
pub struct DefaultManagerConfig {
/// Number of ticks that must have elapsed since the last processed tick
/// before the manager schedules another round of decisions.
pub process_interval_ticks: u64,
/// When `true`, a task context that reports escalations is processed
/// right away, regardless of `process_interval_ticks`.
pub immediate_on_escalation: bool,
/// Responses whose confidence is strictly below this value have their
/// action name replaced by `"Continue"`.
pub confidence_threshold: f64,
/// Strategy asked for the scope of each worker when requests are prepared
/// and when guidances are produced.
pub scope_strategy: Arc<dyn ScopeStrategy>,
}
impl std::fmt::Debug for DefaultManagerConfig {
    /// Formats the configuration as a struct; the scope strategy is a trait
    /// object and is rendered as the fixed placeholder `"<dyn ScopeStrategy>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited.
        let Self {
            process_interval_ticks,
            immediate_on_escalation,
            confidence_threshold,
            scope_strategy: _,
        } = self;

        let mut out = f.debug_struct("DefaultManagerConfig");
        out.field("process_interval_ticks", process_interval_ticks);
        out.field("immediate_on_escalation", immediate_on_escalation);
        out.field("confidence_threshold", confidence_threshold);
        out.field("scope_strategy", &"<dyn ScopeStrategy>");
        out.finish()
    }
}
impl Default for DefaultManagerConfig {
fn default() -> Self {
Self {
process_interval_ticks: 5,
immediate_on_escalation: true,
confidence_threshold: 0.3,
scope_strategy: Arc::new(FixedScopeStrategy::minimal()),
}
}
}
/// Default manager that batches one decision request per worker and turns
/// the responses into guidances.
///
/// `Debug` is derived so the agent can be logged and used in assertion
/// messages; every field already implements it.
#[derive(Debug)]
pub struct DefaultBatchManagerAgent {
    /// Identifier reported through `ManagerAgent::id`.
    id: ManagerId,
    /// Human-readable name reported through `ManagerAgent::name`.
    name: String,
    /// Processing interval, escalation handling, confidence threshold and
    /// scope strategy.
    config: DefaultManagerConfig,
    /// Tick at which responses were last turned into guidances. Atomic
    /// because it is updated through `&self`.
    last_process_tick: AtomicU64,
}
impl DefaultBatchManagerAgent {
pub fn new(id: ManagerId) -> Self {
Self {
id,
name: format!("DefaultManager_{}", id.0),
config: DefaultManagerConfig::default(),
last_process_tick: AtomicU64::new(0),
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
pub fn with_config(mut self, config: DefaultManagerConfig) -> Self {
self.config = config;
self
}
pub fn with_interval(mut self, ticks: u64) -> Self {
self.config.process_interval_ticks = ticks;
self
}
fn build_context_store(&self, context: &TaskContext) -> ContextStore {
let candidates = self.build_candidates(context);
let mut global = GlobalContext::new(context.tick)
.with_max_ticks(context.get_i64("max_ticks").unwrap_or(100) as u64)
.with_progress(context.progress)
.with_success_rate(context.success_rate);
if let Some(task) = context.get_str("task") {
global = global.with_task(task);
}
if let Some(hint) = context.get_str("hint") {
global = global.with_hint(hint);
}
let mut store = ContextStore::new(context.tick);
store.global = global;
for (&worker_id, summary) in &context.workers {
let mut worker_ctx = WorkerCtx::new(worker_id)
.with_failures(summary.consecutive_failures)
.with_history_len(summary.history_len)
.with_escalation(summary.has_escalation)
.with_candidates(candidates.clone());
if let Some(ref action) = summary.last_action {
worker_ctx =
worker_ctx.with_last_action(action, summary.last_success.unwrap_or(false));
}
if let Some(ref output) = summary.last_output {
worker_ctx.metadata.insert(
"last_output".to_string(),
serde_json::Value::String(output.clone()),
);
}
store.workers.insert(worker_id, worker_ctx);
}
store.managers.insert(
self.id,
ManagerContext::new(self.id)
.with_name(&self.name)
.with_last_tick(self.last_process_tick.load(Ordering::Relaxed)),
);
for (worker_id, escalation) in &context.escalations {
store.escalations.push((*worker_id, escalation.clone()));
}
if let Some(ref actions) = context.available_actions {
store.actions = Some(actions.clone());
}
if let Some(task) = context.get_str("task") {
store = store.insert("task", task);
}
if let Some(hint) = context.get_str("hint") {
store = store.insert("hint", hint);
}
store
}
fn should_process(&self, context: &TaskContext) -> bool {
let tick = context.tick;
let last_tick = self.last_process_tick.load(Ordering::Relaxed);
if self.config.immediate_on_escalation && context.has_escalations() {
return true;
}
tick >= last_tick + self.config.process_interval_ticks
}
fn build_candidates(&self, context: &TaskContext) -> Vec<String> {
let all_actions = context
.available_actions
.as_ref()
.map(|cfg| cfg.all_action_names())
.unwrap_or_else(|| vec!["Continue".to_string()]);
let excluded_names: std::collections::HashSet<String> = context
.excluded_actions
.iter()
.filter_map(|s| s.split('(').next().map(|n| n.to_string()))
.collect();
let filtered: Vec<String> = all_actions
.into_iter()
.filter(|name| !excluded_names.contains(name))
.collect();
if filtered.is_empty() {
vec!["Continue".to_string()]
} else {
filtered
}
}
fn response_to_guidance(&self, response: &DecisionResponse) -> Guidance {
let action_name = if response.confidence < self.config.confidence_threshold {
"Continue"
} else {
&response.tool
};
let action = Action {
name: action_name.to_string(),
params: ActionParams {
target: if response.target.is_empty() {
None
} else {
Some(response.target.clone())
},
args: response.args.clone(),
data: Vec::new(),
},
};
Guidance {
actions: vec![action],
content: response.reasoning.clone(),
props: HashMap::new(),
exploration_target: None,
scope: WorkerScope::default(),
}
}
fn default_guidance(&self) -> Guidance {
Guidance {
actions: vec![Action {
name: "Continue".to_string(),
params: ActionParams::default(),
}],
content: None,
props: HashMap::new(),
exploration_target: None,
scope: WorkerScope::default(),
}
}
}
impl ManagerAgent for DefaultBatchManagerAgent {
    /// Builds the batch of per-worker decision requests for this tick.
    ///
    /// The batch is empty when the context already carries `v2_guidances`,
    /// when processing is not due yet, or when every worker is done.
    /// Otherwise each worker that is not done gets one request whose query
    /// is the context's `"task"` string (or `"Continue current work"`).
    fn prepare(&self, context: &TaskContext) -> BatchDecisionRequest {
        let empty_batch = || BatchDecisionRequest {
            manager_id: self.id,
            requests: Vec::new(),
        };

        if context.v2_guidances.is_some() || !self.should_process(context) {
            return empty_batch();
        }

        let store = self.build_context_store(context);
        let active_workers: Vec<WorkerId> = context
            .worker_ids()
            .into_iter()
            .filter(|id| !context.done_workers.contains(id))
            .collect();
        if active_workers.is_empty() {
            return empty_batch();
        }

        let query = match context.get_str("task") {
            Some(task) => task.to_string(),
            None => "Continue current work".to_string(),
        };

        let mut requests: Vec<WorkerDecisionRequest> = Vec::with_capacity(active_workers.len());
        for &worker_id in &active_workers {
            let scope = self
                .config
                .scope_strategy
                .determine_scope(context, worker_id);
            let mut resolved = ContextResolver::resolve_with_scope(&store, worker_id, &scope);

            // Carry the previous guidance forward as an instruction, but only
            // attach it when it actually has content.
            let instruction = match context.previous_guidances.get(&worker_id) {
                Some(previous) => super::worker::ManagerInstruction::from_guidance(previous),
                None => super::worker::ManagerInstruction::new(),
            };
            if instruction.has_content() {
                resolved.manager_instruction = Some(instruction);
            }

            requests.push(WorkerDecisionRequest {
                worker_id,
                query: query.clone(),
                context: resolved,
                lora: None,
            });
        }

        BatchDecisionRequest {
            manager_id: self.id,
            requests,
        }
    }

    /// Converts the outcome of a batch into per-worker guidances.
    ///
    /// * With `v2_guidances` in the context, each worker that is not done
    ///   takes the guidance at its position in `worker_ids()`, falling back
    ///   to the default guidance when the list is too short.
    /// * With no responses, every worker receives the default guidance.
    /// * Otherwise the current tick is recorded as the last processed tick
    ///   and each response is translated into a guidance.
    ///
    /// In all cases the scope of each guidance is set from the configured
    /// scope strategy.
    fn finalize(
        &self,
        context: &TaskContext,
        responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        let strategy = &self.config.scope_strategy;
        let mut guidances = HashMap::new();

        if let Some(precomputed) = &context.v2_guidances {
            for (index, worker_id) in context.worker_ids().iter().enumerate() {
                if context.done_workers.contains(worker_id) {
                    continue;
                }
                let mut guidance = match precomputed.get(index) {
                    Some(found) => found.clone(),
                    None => self.default_guidance(),
                };
                guidance.scope = strategy.determine_scope(context, *worker_id);
                guidances.insert(*worker_id, guidance);
            }
        } else if responses.is_empty() {
            for worker_id in context.worker_ids().iter() {
                let mut guidance = self.default_guidance();
                guidance.scope = strategy.determine_scope(context, *worker_id);
                guidances.insert(*worker_id, guidance);
            }
        } else {
            // Only a round that produced responses counts as processed.
            self.last_process_tick.store(context.tick, Ordering::Relaxed);
            for (worker_id, response) in &responses {
                let mut guidance = self.response_to_guidance(response);
                guidance.scope = strategy.determine_scope(context, *worker_id);
                guidances.insert(*worker_id, guidance);
            }
        }

        ManagementDecision {
            guidances,
            strategy_update: None,
            async_tasks: Vec::new(),
        }
    }

    /// Returns this manager's identifier.
    fn id(&self) -> ManagerId {
        self.id
    }

    /// Returns this manager's display name.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Step-by-step builder for [`DefaultBatchManagerAgent`].
///
/// `Debug` and `Clone` are derived so a partially configured builder can be
/// inspected and reused as a template; every field already supports both.
#[derive(Debug, Clone)]
pub struct DefaultBatchManagerAgentBuilder {
    /// Identifier given to the built agent.
    id: ManagerId,
    /// Optional custom name; when `None` the agent keeps its generated name.
    name: Option<String>,
    /// Configuration handed to the built agent.
    config: DefaultManagerConfig,
}
impl DefaultBatchManagerAgentBuilder {
    /// Starts a builder for the given manager id with no custom name and the
    /// default configuration.
    pub fn new(id: ManagerId) -> Self {
        let config = DefaultManagerConfig::default();
        Self {
            id,
            name: None,
            config,
        }
    }

    /// Sets a custom name for the agent.
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Replaces the whole configuration, discarding earlier per-field tweaks.
    pub fn config(self, config: DefaultManagerConfig) -> Self {
        Self { config, ..self }
    }

    /// Sets the processing interval, in ticks.
    pub fn interval(self, ticks: u64) -> Self {
        self.update_config(|cfg| cfg.process_interval_ticks = ticks)
    }

    /// Enables or disables immediate processing when escalations are present.
    pub fn immediate_on_escalation(self, enabled: bool) -> Self {
        self.update_config(|cfg| cfg.immediate_on_escalation = enabled)
    }

    /// Sets the confidence threshold.
    pub fn confidence_threshold(self, threshold: f64) -> Self {
        self.update_config(|cfg| cfg.confidence_threshold = threshold)
    }

    /// Consumes the builder and produces the agent; the generated name is
    /// kept unless a custom one was set.
    pub fn build(self) -> DefaultBatchManagerAgent {
        let Self { id, name, config } = self;
        let agent = DefaultBatchManagerAgent::new(id).with_config(config);
        match name {
            Some(custom) => agent.with_name(custom),
            None => agent,
        }
    }

    /// Applies one in-place change to the pending configuration.
    fn update_config(mut self, apply: impl FnOnce(&mut DefaultManagerConfig)) -> Self {
        apply(&mut self.config);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::WorkerSummary;

    /// Context at tick 10 with two workers, the second of which has an
    /// escalation flag set.
    fn sample_context() -> TaskContext {
        let calm_worker = WorkerSummary::new(WorkerId(0));
        let escalated_worker = WorkerSummary::new(WorkerId(1)).with_escalation(true);
        TaskContext::new(10)
            .with_worker(calm_worker)
            .with_worker(escalated_worker)
            .with_success_rate(0.8)
            .with_progress(0.5)
    }

    /// A `Read` response with the given target, reasoning and confidence.
    fn read_response(target: &str, reasoning: Option<&str>, confidence: f64) -> DecisionResponse {
        DecisionResponse {
            tool: "Read".to_string(),
            target: target.to_string(),
            args: HashMap::new(),
            reasoning: reasoning.map(String::from),
            confidence,
            prompt: None,
            raw_response: None,
        }
    }

    #[test]
    fn test_default_manager_new() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(0));
        assert_eq!(agent.id(), ManagerId(0));
        assert_eq!(agent.name(), "DefaultManager_0");
    }

    #[test]
    fn test_default_manager_with_name() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(1)).with_name("TestManager");
        assert_eq!(agent.name(), "TestManager");
    }

    #[test]
    fn test_prepare_with_context() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(0));
        let batch = agent.prepare(&sample_context());
        assert_eq!(batch.manager_id, ManagerId(0));
        // One request per worker in the sample context.
        assert_eq!(batch.requests.len(), 2);
    }

    #[test]
    fn test_finalize_empty_responses() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(0));
        let decision = agent.finalize(&sample_context(), Vec::new());
        assert_eq!(decision.guidances.len(), 2);
        for guidance in decision.guidances.values() {
            let names: Vec<&str> = guidance
                .actions
                .iter()
                .map(|action| action.name.as_str())
                .collect();
            assert_eq!(names, ["Continue"]);
        }
    }

    #[test]
    fn test_response_to_guidance() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(0));
        let response = read_response("/path/to/file", Some("Need to read file"), 0.8);
        let guidance = agent.response_to_guidance(&response);
        assert_eq!(guidance.actions.len(), 1);
        let action = &guidance.actions[0];
        assert_eq!(action.name, "Read");
        assert_eq!(action.params.target, Some("/path/to/file".to_string()));
    }

    #[test]
    fn test_low_confidence_falls_back_to_continue() {
        let agent = DefaultBatchManagerAgent::new(ManagerId(0));
        // 0.1 is below the default threshold of 0.3.
        let response = read_response("/path", None, 0.1);
        let guidance = agent.response_to_guidance(&response);
        assert_eq!(guidance.actions[0].name, "Continue");
    }

    #[test]
    fn test_builder() {
        let agent = DefaultBatchManagerAgentBuilder::new(ManagerId(2))
            .name("CustomManager")
            .interval(10)
            .confidence_threshold(0.5)
            .build();
        assert_eq!(agent.id(), ManagerId(2));
        assert_eq!(agent.name(), "CustomManager");
        assert_eq!(agent.config.process_interval_ticks, 10);
        let threshold_error = (agent.config.confidence_threshold - 0.5).abs();
        assert!(threshold_error < 0.001);
    }
}