use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::agent::{BatchInvoker, ManagerAgent, WorkerAgent};
use crate::error::SwarmError;
use crate::events::{ActionEventPublisher, LearningEventChannel, LifecycleHook, TraceSubscriber};
use crate::exploration::{DependencyGraph, DependencyGraphProvider, NodeRules, OperatorProvider};
use crate::extensions::Extensions;
use crate::orchestrator::{Orchestrator, OrchestratorBuilder, SwarmConfig, SwarmResult};
use crate::types::SwarmTask;
use super::daemon::{
ActionEventSubscriber, DaemonConfig, DaemonError, EventSubscriberConfig, LearningDaemon,
LearningEventSubscriber,
};
use super::profile_adapter::profile_to_offline_model;
use super::scenario_profile::ScenarioProfile;
use super::snapshot::LearningStore;
use super::trigger::{TrainTrigger, TriggerBuilder};
use super::LearningSnapshot;
use super::OfflineModel;
/// Join handle for the background learning daemon task.
type DaemonHandle = JoinHandle<Result<(), DaemonError>>;
/// Handles produced when learning components are spawned; every field is
/// `None`/empty when learning is disabled (hence the `Default` derive).
#[derive(Default)]
struct LearningSetupResult {
// Daemon task handle, when a daemon was started.
daemon_handle: Option<DaemonHandle>,
// Event-subscriber forwarding tasks (action + learning channels).
subscriber_handles: Vec<JoinHandle<()>>,
// Sender used to request a graceful daemon shutdown.
shutdown_tx: Option<mpsc::Sender<()>>,
}
/// Configuration for a learnable swarm: which scenario it runs, where
/// learning data lives, and how the event subscribers batch records.
#[derive(Debug, Clone)]
pub struct LearnableSwarmConfig {
/// Scenario key used to load/store snapshots and offline models.
pub scenario: String,
/// Base directory for learning data (defaults to the platform data dir).
pub data_dir: PathBuf,
/// Whether learning components (daemon + subscribers) are started on build.
pub learning_enabled: bool,
/// Maximum events a subscriber buffers before flushing a batch.
pub subscriber_batch_size: usize,
/// Maximum time (ms) a subscriber waits before flushing a partial batch.
pub subscriber_flush_interval_ms: u64,
/// How often the learning daemon polls its train trigger.
pub daemon_check_interval: Duration,
}
impl Default for LearnableSwarmConfig {
/// Learning disabled, empty scenario, platform data dir, 100-event /
/// 100 ms subscriber batching, 10 s daemon check interval.
fn default() -> Self {
Self {
scenario: String::new(),
data_dir: default_learning_dir(),
learning_enabled: false,
subscriber_batch_size: 100,
subscriber_flush_interval_ms: 100,
daemon_check_interval: Duration::from_secs(10),
}
}
}
impl LearnableSwarmConfig {
/// Creates a config for `scenario` with all other fields defaulted.
pub fn new(scenario: impl Into<String>) -> Self {
Self {
scenario: scenario.into(),
..Default::default()
}
}
/// Toggles learning on or off (builder style).
pub fn with_learning(mut self, enabled: bool) -> Self {
self.learning_enabled = enabled;
self
}
/// Overrides the directory used for learning data (builder style).
pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.data_dir = path.into();
self
}
}
/// Default base directory for learning data:
/// `<platform data dir>/swarm-engine/learning`, falling back to the current
/// working directory when no platform data dir is known.
fn default_learning_dir() -> PathBuf {
    let base = match dirs::data_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    base.join("swarm-engine").join("learning")
}
/// Builder for a learnable swarm; collects agents, providers, learning
/// state, and configuration before assembling the orchestrator in `build`.
pub struct LearnableSwarmBuilder {
// Runtime on which daemon/subscriber tasks are spawned.
runtime: Handle,
config: LearnableSwarmConfig,
swarm_config: Option<SwarmConfig>,
workers: Vec<Box<dyn WorkerAgent>>,
managers: Vec<Box<dyn ManagerAgent>>,
batch_invoker: Option<Box<dyn BatchInvoker>>,
dependency_provider: Option<Box<dyn DependencyGraphProvider>>,
operator_provider: Option<Box<dyn OperatorProvider<NodeRules>>>,
extensions: Option<Extensions>,
dependency_graph: Option<DependencyGraph>,
offline_model: Option<OfflineModel>,
// Snapshot from a previous run, injected into orchestrator extensions.
prior_snapshot: Option<LearningSnapshot>,
learning_store: Option<LearningStore>,
train_trigger: Option<Arc<dyn TrainTrigger>>,
lifecycle_hook: Option<Box<dyn LifecycleHook>>,
enable_exploration: bool,
// Errors recorded by infallible setters, surfaced when `build` runs.
deferred_error: Option<SwarmError>,
trace_subscriber: Option<Arc<dyn TraceSubscriber>>,
}
impl LearnableSwarmBuilder {
/// Creates an empty builder bound to `runtime`; components are supplied
/// via the chained setters and assembled by `build`.
pub fn new(runtime: Handle) -> Self {
Self {
runtime,
config: LearnableSwarmConfig::default(),
swarm_config: None,
workers: Vec::new(),
managers: Vec::new(),
batch_invoker: None,
dependency_provider: None,
operator_provider: None,
extensions: None,
dependency_graph: None,
offline_model: None,
prior_snapshot: None,
learning_store: None,
train_trigger: None,
lifecycle_hook: None,
enable_exploration: false,
deferred_error: None,
trace_subscriber: None,
}
}
/// Sets the scenario key used for learning storage lookups.
pub fn scenario(mut self, name: impl Into<String>) -> Self {
self.config.scenario = name.into();
self
}
/// Enables or disables learning components.
pub fn with_learning(mut self, enabled: bool) -> Self {
self.config.learning_enabled = enabled;
self
}
/// Overrides the directory used for learning data.
pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.config.data_dir = path.into();
self
}
/// Sets the orchestrator's swarm configuration.
pub fn swarm_config(mut self, config: SwarmConfig) -> Self {
self.swarm_config = Some(config);
self
}
/// Appends a single worker agent.
pub fn add_worker(mut self, worker: Box<dyn WorkerAgent>) -> Self {
self.workers.push(worker);
self
}
/// Replaces the entire worker list.
pub fn workers(mut self, workers: Vec<Box<dyn WorkerAgent>>) -> Self {
self.workers = workers;
self
}
/// Appends a single manager agent.
pub fn add_manager(mut self, manager: Box<dyn ManagerAgent>) -> Self {
self.managers.push(manager);
self
}
/// Replaces the entire manager list.
pub fn managers(mut self, managers: Vec<Box<dyn ManagerAgent>>) -> Self {
self.managers = managers;
self
}
/// Sets the batch invoker forwarded to the orchestrator.
pub fn batch_invoker(mut self, invoker: Box<dyn BatchInvoker>) -> Self {
self.batch_invoker = Some(invoker);
self
}
/// Sets the dependency-graph provider.
pub fn dependency_provider(mut self, provider: Box<dyn DependencyGraphProvider>) -> Self {
self.dependency_provider = Some(provider);
self
}
/// Sets the exploration operator provider.
pub fn operator_provider(mut self, provider: Box<dyn OperatorProvider<NodeRules>>) -> Self {
self.operator_provider = Some(provider);
self
}
/// Seeds the orchestrator's extension map.
pub fn extensions(mut self, extensions: Extensions) -> Self {
self.extensions = Some(extensions);
self
}
/// Provides a pre-built dependency graph (inserted into extensions).
pub fn dependency_graph(mut self, graph: DependencyGraph) -> Self {
self.dependency_graph = Some(graph);
self
}
/// Sets the offline model used to seed the orchestrator.
pub fn offline_model(mut self, model: OfflineModel) -> Self {
self.offline_model = Some(model);
self
}
/// Derives the offline model from `profile`; if no scenario name has
/// been set yet, adopts the profile's id as the scenario key.
pub fn with_scenario_profile(mut self, profile: &ScenarioProfile) -> Self {
let model = profile_to_offline_model(profile);
self.offline_model = Some(model);
if self.config.scenario.is_empty() {
self.config.scenario = profile.id.0.clone();
}
self
}
/// The offline model currently configured, if any.
pub fn offline_model_ref(&self) -> Option<&OfflineModel> {
self.offline_model.as_ref()
}
/// Injects a snapshot from a previous run (goes into extensions).
pub fn prior_snapshot(mut self, snapshot: LearningSnapshot) -> Self {
self.prior_snapshot = Some(snapshot);
self
}
/// Attaches a learning store without loading state or enabling learning.
pub fn learning_store(mut self, store: LearningStore) -> Self {
self.learning_store = Some(store);
self
}
/// Attaches `store`, loads prior snapshot/model for the configured
/// scenario, syncs `data_dir` to the store's base dir, and enables
/// learning. Load errors are deferred and surfaced by `build`.
///
/// NOTE(review): the load keys off `config.scenario`, so `scenario(..)`
/// should be called before this — confirm callers respect that ordering.
pub fn with_learning_store(mut self, store: LearningStore) -> Self {
if let Err(e) = self.load_from_store(&store) {
self.deferred_error = Some(e);
}
self.config.data_dir = store.storage().base_dir().to_path_buf();
self.config.learning_enabled = true;
self.learning_store = Some(store);
self
}
/// Sets the trigger deciding when the daemon runs training.
pub fn train_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
self.train_trigger = Some(trigger);
self
}
/// Sets the lifecycle hook forwarded to the orchestrator.
pub fn lifecycle_hook(mut self, hook: Box<dyn LifecycleHook>) -> Self {
self.lifecycle_hook = Some(hook);
self
}
/// Enables exploration mode on the orchestrator.
pub fn enable_exploration(mut self, enabled: bool) -> Self {
self.enable_exploration = enabled;
self
}
/// Attaches a trace subscriber that receives every action event.
pub fn with_trace_subscriber(mut self, subscriber: Arc<dyn TraceSubscriber>) -> Self {
self.trace_subscriber = Some(subscriber);
self
}
/// Creates a [`LearningStore`] at `path`, loads any prior snapshot and
/// offline model for the configured scenario, and enables learning.
///
/// Store-creation or load failures are deferred and surfaced by `build`,
/// keeping this setter infallible for chaining.
pub fn with_learning_store_path(mut self, path: impl AsRef<std::path::Path>) -> Self {
    match LearningStore::new(&path) {
        Ok(store) => {
            if let Err(e) = self.load_from_store(&store) {
                self.deferred_error = Some(e);
            }
            // Keep `config.data_dir` pointed at the store, mirroring
            // `with_learning_store`; previously this was skipped, so the
            // daemon wrote to the default data dir while snapshots were
            // read from `path`.
            self.config.data_dir = store.storage().base_dir().to_path_buf();
            self.config.learning_enabled = true;
            self.learning_store = Some(store);
        }
        Err(e) => {
            self.deferred_error = Some(SwarmError::Config {
                message: format!(
                    "Failed to create LearningStore at '{}': {}",
                    path.as_ref().display(),
                    e
                ),
            });
        }
    }
    self
}
/// Loads the prior snapshot and offline model for the configured scenario
/// from `store`; `NotFound` is treated as a normal first run, while any
/// other I/O error aborts with a config error.
///
/// NOTE(review): this keys off `self.config.scenario`, so the scenario
/// must be set before the store setters call this — confirm call order.
fn load_from_store(&mut self, store: &LearningStore) -> Result<(), SwarmError> {
let scenario_key = &self.config.scenario;
match store.load_scenario(scenario_key) {
Ok(snapshot) => {
self.prior_snapshot = Some(snapshot);
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::debug!(scenario = %scenario_key, "No prior snapshot found (first run)");
}
Err(e) => {
return Err(SwarmError::Config {
message: format!(
"Failed to load prior snapshot for '{}': {}",
scenario_key, e
),
});
}
}
match store.load_offline_model(scenario_key) {
Ok(model) => {
tracing::debug!(
ucb1_c = model.parameters.ucb1_c,
strategy = %model.strategy_config.initial_strategy,
action_order = model.action_order.is_some(),
"Offline model loaded"
);
self.offline_model = Some(model);
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
tracing::debug!(scenario = %scenario_key, "No offline model found (first run)");
}
Err(e) => {
return Err(SwarmError::Config {
message: format!("Failed to load offline model for '{}': {}", scenario_key, e),
});
}
}
Ok(())
}
/// Consumes the builder and constructs the [`LearnableSwarm`].
///
/// Wires the action-event publisher into the orchestrator, spawns the
/// optional trace subscriber, and — when learning is enabled — spawns the
/// learning daemon plus its event subscribers.
///
/// # Errors
/// Returns any error deferred by the store setters, or a config error
/// when learning is enabled without a scenario name.
pub fn build(mut self) -> Result<LearnableSwarm, SwarmError> {
// Surface failures recorded by the infallible chained setters first.
if let Some(err) = self.deferred_error.take() {
return Err(err);
}
if self.config.learning_enabled && self.config.scenario.is_empty() {
return Err(SwarmError::Config {
message: "scenario is required when learning is enabled".into(),
});
}
// Enable the global channel before any component can emit on it.
if self.config.learning_enabled {
LearningEventChannel::global().enable();
}
let swarm_config = self.swarm_config.take().unwrap_or_default();
// The initial receiver is dropped on purpose; every consumer calls
// `subscribe()` to get its own receiver before events start flowing.
let (action_publisher, _initial_rx) = ActionEventPublisher::new(1024);
let mut subscriber_handles_trace = Vec::new();
if let Some(trace_subscriber) = self.trace_subscriber.take() {
let rx = action_publisher.subscribe();
let handle = self.runtime.spawn(async move {
run_trace_subscriber(rx, trace_subscriber).await;
});
subscriber_handles_trace.push(handle);
}
let LearningSetupResult {
daemon_handle,
mut subscriber_handles,
shutdown_tx,
} = if self.config.learning_enabled {
self.setup_learning_components(&action_publisher)?
} else {
LearningSetupResult::default()
};
subscriber_handles.extend(subscriber_handles_trace);
let extensions = self.build_extensions();
let mut orch_builder = OrchestratorBuilder::new()
.config(swarm_config)
.extensions(extensions);
for worker in self.workers {
orch_builder = orch_builder.add_worker_boxed(worker);
}
for manager in self.managers {
orch_builder = orch_builder.add_manager_boxed(manager);
}
if let Some(invoker) = self.batch_invoker {
orch_builder = orch_builder.batch_invoker_boxed(invoker);
}
if let Some(provider) = self.dependency_provider {
orch_builder = orch_builder.dependency_provider_boxed(provider);
}
if let Some(provider) = self.operator_provider {
orch_builder = orch_builder.operator_provider_boxed(provider);
}
if self.enable_exploration {
orch_builder = orch_builder.with_exploration();
}
if let Some(ref model) = self.offline_model {
orch_builder = orch_builder.with_offline_model(model.clone());
}
if let Some(hook) = self.lifecycle_hook {
orch_builder = orch_builder.lifecycle_hook(hook);
}
// The orchestrator takes ownership of the publisher; events it emits
// fan out to the subscriber tasks spawned above.
orch_builder = orch_builder.action_collector(action_publisher);
let orchestrator = orch_builder.build(self.runtime.clone());
Ok(LearnableSwarm {
orchestrator,
runtime: self.runtime,
config: self.config,
learning_store: self.learning_store,
offline_model: self.offline_model,
daemon_handle,
subscriber_handles,
shutdown_tx,
})
}
fn build_extensions(&mut self) -> Extensions {
let mut ext = self.extensions.take().unwrap_or_default();
if let Some(graph) = self.dependency_graph.take() {
ext.insert(graph);
}
if let Some(snapshot) = self.prior_snapshot.take() {
ext.insert(snapshot);
}
ext
}
/// Spawns the learning daemon plus the two event-subscriber tasks
/// (action events and learning events) on the builder's runtime.
///
/// Both subscribers feed the daemon through clones of its record channel;
/// the daemon's shutdown sender is returned so `shutdown()` can stop it.
fn setup_learning_components(
&self,
action_publisher: &ActionEventPublisher,
) -> Result<LearningSetupResult, SwarmError> {
let daemon_config = DaemonConfig::new(&self.config.scenario)
.data_dir(&self.config.data_dir)
.check_interval(self.config.daemon_check_interval);
// With no explicit trigger, training never fires; the daemon still
// records incoming events.
let trigger = self
.train_trigger
.clone()
.unwrap_or_else(|| TriggerBuilder::never());
let mut daemon =
LearningDaemon::new(daemon_config, trigger).map_err(|e| SwarmError::Config {
message: format!("Failed to create LearningDaemon: {}", e),
})?;
// Grab the channels before `daemon` is moved into its task below.
let record_tx = daemon.record_sender();
let shutdown_tx = daemon.shutdown_sender();
let sub_config = EventSubscriberConfig::new()
.batch_size(self.config.subscriber_batch_size)
.flush_interval_ms(self.config.subscriber_flush_interval_ms);
let mut subscriber_handles = Vec::new();
// Subscribe before spawning the daemon so no early events are missed.
let action_sub = ActionEventSubscriber::with_config(
action_publisher.subscribe(),
record_tx.clone(),
sub_config.clone(),
);
let action_handle = self.runtime.spawn(async move {
action_sub.run().await;
});
subscriber_handles.push(action_handle);
let learning_channel = LearningEventChannel::global();
let learning_sub = LearningEventSubscriber::with_config(
learning_channel.subscribe(),
record_tx,
sub_config,
);
let learning_handle = self.runtime.spawn(async move {
learning_sub.run().await;
});
subscriber_handles.push(learning_handle);
let daemon_handle = self.runtime.spawn(async move { daemon.run().await });
Ok(LearningSetupResult {
daemon_handle: Some(daemon_handle),
subscriber_handles,
shutdown_tx: Some(shutdown_tx),
})
}
}
/// A swarm orchestrator bundled with optional learning infrastructure
/// (daemon, event subscribers, persistent store). Built via
/// [`LearnableSwarmBuilder`].
pub struct LearnableSwarm {
orchestrator: Orchestrator,
// Runtime handle, kept for `shutdown_blocking`.
runtime: Handle,
config: LearnableSwarmConfig,
// Persistent snapshot/model storage, when learning is enabled.
learning_store: Option<LearningStore>,
// Offline model the orchestrator was seeded with, if any.
offline_model: Option<OfflineModel>,
// Background learning daemon task.
daemon_handle: Option<DaemonHandle>,
// Event-forwarding tasks (action/learning/trace subscribers).
subscriber_handles: Vec<JoinHandle<()>>,
// Graceful-shutdown signal for the daemon.
shutdown_tx: Option<mpsc::Sender<()>>,
}
impl LearnableSwarm {
/// Runs a single task on the underlying orchestrator.
pub fn run_task(&mut self, task: SwarmTask) -> Result<SwarmResult, SwarmError> {
self.orchestrator.run_task(task)
}
/// Runs the orchestrator's main loop to completion.
pub fn run(&mut self) -> SwarmResult {
self.orchestrator.run()
}
/// Shared access to the underlying orchestrator.
pub fn orchestrator(&self) -> &Orchestrator {
&self.orchestrator
}
/// Mutable access to the underlying orchestrator.
pub fn orchestrator_mut(&mut self) -> &mut Orchestrator {
&mut self.orchestrator
}
/// The dependency graph currently held by the orchestrator, if any.
pub fn dependency_graph(&self) -> Option<&DependencyGraph> {
self.orchestrator.dependency_graph()
}
/// The configuration this swarm was built with.
pub fn config(&self) -> &LearnableSwarmConfig {
&self.config
}
/// The learning store, when one was attached at build time.
pub fn learning_store(&self) -> Option<&LearningStore> {
self.learning_store.as_ref()
}
/// The offline model the swarm was seeded with, if any.
pub fn offline_model(&self) -> Option<&OfflineModel> {
self.offline_model.as_ref()
}
/// Whether learning components were started for this swarm.
pub fn is_learning_enabled(&self) -> bool {
self.config.learning_enabled
}
/// Emits a `LearnStatsSnapshot` learning event summarizing the current
/// orchestrator state (tick count, action totals, serialized provider
/// stats) on the global learning channel.
///
/// The outcome is a success with score 1.0 when the environment reports
/// done, otherwise a timeout with no partial score.
pub fn emit_stats_snapshot(&self) {
    use crate::events::LearningEvent;
    use crate::util::epoch_millis;
    let state = self.orchestrator.state();
    let tick = state.shared.tick;
    let total_actions = state.shared.stats.total_visits() as u64;
    // Serialize learned-provider stats when available; empty string
    // otherwise (downstream treats "" as "no stats").
    let stats_json = if let Some(provider) = self.orchestrator.learned_provider() {
        provider
            .stats()
            .map(|stats| serde_json::to_string(stats).unwrap_or_default())
            .unwrap_or_default()
    } else {
        String::new()
    };
    // Millisecond epoch doubles as a (coarse) unique session id.
    let session_id = epoch_millis().to_string();
    let event = LearningEvent::learn_stats_snapshot(&self.config.scenario)
        .session_id(session_id)
        .stats_json(stats_json)
        .total_ticks(tick)
        .total_actions(total_actions);
    // Apply the outcome directly; the previous code built a
    // LearnStatsOutcome only to immediately re-match it, including an
    // unreachable Failure arm.
    let event = if state.shared.environment_done {
        event.success(1.0)
    } else {
        event.timeout(None)
    };
    LearningEventChannel::global().emit(event.build());
    tracing::debug!(
        scenario = %self.config.scenario,
        tick = tick,
        total_actions = total_actions,
        "LearnStatsSnapshot emitted"
    );
}
/// Takes the daemon's shutdown sender, leaving `None`; useful when
/// shutdown is coordinated externally instead of via [`shutdown`].
pub fn take_shutdown_tx(&mut self) -> Option<mpsc::Sender<()>> {
self.shutdown_tx.take()
}
/// Gracefully stops the learning components: emits a final stats
/// snapshot, signals the daemon, then awaits the daemon and all
/// subscriber tasks.
///
/// NOTE(review): subscriber tasks presumably exit when their event
/// channels close; the orchestrator (which owns the action publisher) is
/// still alive while these handles are awaited — confirm the subscriber
/// `run()` loops terminate, or this await could stall.
pub async fn shutdown(self) {
if self.config.learning_enabled {
self.emit_stats_snapshot();
}
// Best-effort: the daemon may already be gone, so ignore send errors.
if let Some(tx) = self.shutdown_tx {
let _ = tx.send(()).await;
}
if let Some(handle) = self.daemon_handle {
match handle.await {
Ok(Ok(())) => {
tracing::debug!("LearningDaemon shutdown completed");
}
Ok(Err(e)) => {
tracing::warn!("LearningDaemon error on shutdown: {}", e);
}
Err(e) => {
tracing::warn!("LearningDaemon join error: {}", e);
}
}
}
for handle in self.subscriber_handles {
let _ = handle.await;
}
tracing::debug!("LearnableSwarm shutdown completed");
}
/// Blocking wrapper around [`shutdown`].
///
/// NOTE(review): `Handle::block_on` panics when called from within the
/// runtime's async context — call this from synchronous code only.
pub fn shutdown_blocking(self) {
let runtime = self.runtime.clone();
runtime.block_on(self.shutdown());
}
}
async fn run_trace_subscriber(
mut rx: tokio::sync::broadcast::Receiver<crate::events::ActionEvent>,
subscriber: Arc<dyn TraceSubscriber>,
) {
while let Ok(event) = rx.recv().await {
subscriber.on_event(&event);
}
subscriber.finish();
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::GenericWorker;
// A current-thread runtime suffices: these tests only inspect builder
// state and never drive the spawned tasks.
fn make_test_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
}
// Default config: learning off, no scenario.
#[test]
fn test_config_default() {
let config = LearnableSwarmConfig::default();
assert!(!config.learning_enabled);
assert!(config.scenario.is_empty());
}
// Chained config setters take effect.
#[test]
fn test_config_builder() {
let config = LearnableSwarmConfig::new("test-scenario")
.with_learning(true)
.data_dir("/tmp/test");
assert_eq!(config.scenario, "test-scenario");
assert!(config.learning_enabled);
assert_eq!(config.data_dir, PathBuf::from("/tmp/test"));
}
// Builder records scenario and workers.
#[test]
fn test_builder_basic() {
let rt = make_test_runtime();
let builder = LearnableSwarmBuilder::new(rt.handle().clone())
.scenario("test")
.add_worker(Box::new(GenericWorker::new(0)));
assert_eq!(builder.config.scenario, "test");
assert_eq!(builder.workers.len(), 1);
}
// with_learning flips the config flag.
#[test]
fn test_builder_with_learning() {
let rt = make_test_runtime();
let builder = LearnableSwarmBuilder::new(rt.handle().clone())
.scenario("test")
.with_learning(true)
.add_worker(Box::new(GenericWorker::new(0)));
assert!(builder.config.learning_enabled);
}
// build() rejects learning without a scenario name.
#[test]
fn test_builder_learning_without_scenario_fails() {
let rt = make_test_runtime();
let result = LearnableSwarmBuilder::new(rt.handle().clone())
.with_learning(true)
.add_worker(Box::new(GenericWorker::new(0)))
.build();
assert!(result.is_err());
if let Err(err) = result {
assert!(err.to_string().contains("scenario is required"));
}
}
// No scenario is fine as long as learning stays disabled.
#[test]
fn test_builder_learning_disabled_without_scenario_ok() {
let rt = make_test_runtime();
let result = LearnableSwarmBuilder::new(rt.handle().clone())
.add_worker(Box::new(GenericWorker::new(0)))
.build();
assert!(result.is_ok());
}
// A scenario profile supplies both the scenario name and offline model.
#[test]
fn test_builder_with_scenario_profile() {
use crate::learn::learned_component::LearnedExploration;
use crate::learn::scenario_profile::{ScenarioProfile, ScenarioSource};
let rt = make_test_runtime();
let mut profile =
ScenarioProfile::new("test-profile", ScenarioSource::from_path("/test.toml"));
profile.exploration = Some(LearnedExploration::new(2.5, 0.4, 1.2));
let builder = LearnableSwarmBuilder::new(rt.handle().clone())
.with_scenario_profile(&profile)
.add_worker(Box::new(GenericWorker::new(0)));
assert_eq!(builder.config.scenario, "test-profile");
assert!(builder.offline_model.is_some());
let model = builder.offline_model.as_ref().unwrap();
assert_eq!(model.parameters.ucb1_c, 2.5);
}
}