Skip to main content

punch_kernel/
background.rs

//! Background executor for autonomous gorilla tasks.
//!
//! The [`BackgroundExecutor`] manages tokio tasks that run gorillas on their
//! configured schedules. Each gorilla gets its own spawned task that sleeps
//! for the configured interval, acquires a global LLM concurrency semaphore,
//! and then runs the fighter loop with an autonomous prompt.

8use std::sync::Arc;
9
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use tokio::sync::{Semaphore, watch};
13use tokio::task::JoinHandle;
14use tracing::{error, info, warn};
15
16use punch_memory::MemorySubstrate;
17use punch_runtime::{
18    FighterLoopParams, FighterLoopResult, LlmDriver, run_fighter_loop, tools_for_capabilities,
19};
20use punch_types::{
21    FighterId, FighterManifest, GorillaId, GorillaManifest, ModelConfig, PunchResult, WeightClass,
22};
23
/// Maximum concurrent LLM calls across all gorillas.
///
/// Used as the permit count of the executor's shared semaphore: a gorilla tick
/// only runs the fighter loop while holding one permit, so at most this many
/// ticks are in flight at once per executor.
const DEFAULT_LLM_CONCURRENCY: usize = 3;
26
/// A running gorilla background task.
///
/// Pairs the handle of the spawned scheduler loop (aborted by `stop_gorilla`
/// and `shutdown_all`) with the wall-clock time the task was registered.
struct GorillaTask {
    /// Handle of the spawned scheduler loop.
    handle: JoinHandle<()>,
    // `started_at` is written when the task is registered but is not read
    // anywhere in this file; the allow silences the unused-field lint.
    #[allow(dead_code)]
    started_at: DateTime<Utc>,
}
33
/// Manages background gorilla tasks that run autonomously on schedules.
///
/// Each started gorilla gets one spawned tokio task. All tasks share a single
/// semaphore, so at most `DEFAULT_LLM_CONCURRENCY` ticks run at the same time.
pub struct BackgroundExecutor {
    /// Running gorilla tasks, keyed by gorilla id.
    ///
    /// Entries are removed only by `stop_gorilla` and `shutdown_all`; a task
    /// whose loop exits on its own stays registered until one of those runs.
    tasks: DashMap<GorillaId, GorillaTask>,
    /// Global LLM concurrency limiter shared by every gorilla task.
    llm_semaphore: Arc<Semaphore>,
    /// Shutdown signal sender (kept alive to prevent channel closure).
    _shutdown_tx: watch::Sender<bool>,
    /// Shutdown signal receiver (cloned for each gorilla task).
    shutdown_rx: watch::Receiver<bool>,
}
45
46/// Build a [`FighterManifest`] from a [`GorillaManifest`], using the provided
47/// `default_model` as a fallback when the gorilla does not specify its own model.
48pub fn fighter_manifest_from_gorilla(
49    manifest: &GorillaManifest,
50    default_model: &ModelConfig,
51) -> FighterManifest {
52    let model = manifest
53        .model
54        .clone()
55        .unwrap_or_else(|| default_model.clone());
56    let capabilities = manifest.effective_capabilities();
57    let weight_class = manifest.weight_class.unwrap_or(WeightClass::Middleweight);
58    let system_prompt = manifest.effective_system_prompt();
59
60    FighterManifest {
61        name: manifest.name.clone(),
62        description: format!("Autonomous gorilla: {}", manifest.name),
63        model,
64        system_prompt,
65        capabilities,
66        weight_class,
67        tenant_id: None,
68    }
69}
70
71/// Run a single autonomous tick for a gorilla. This is the reusable core that
72/// both the background scheduler and the CLI `gorilla test` command invoke.
73pub async fn run_gorilla_tick(
74    gorilla_id: GorillaId,
75    manifest: &GorillaManifest,
76    default_model: &ModelConfig,
77    memory: &Arc<MemorySubstrate>,
78    driver: &Arc<dyn LlmDriver>,
79) -> PunchResult<FighterLoopResult> {
80    let fighter_manifest = fighter_manifest_from_gorilla(manifest, default_model);
81    let gorilla_name = &manifest.name;
82    let system_prompt = fighter_manifest.system_prompt.clone();
83
84    // Build the autonomous prompt.
85    let autonomous_prompt = format!(
86        "[AUTONOMOUS TICK] You are {}. Review your memory, check your goals, and take the next action. {}",
87        gorilla_name, system_prompt
88    );
89
90    // Create a temporary fighter identity for this gorilla tick.
91    let fighter_id = FighterId::new();
92
93    // Save the fighter first (required for FK constraint on bout creation).
94    if let Err(e) = memory
95        .save_fighter(
96            &fighter_id,
97            &fighter_manifest,
98            punch_types::FighterStatus::Idle,
99        )
100        .await
101    {
102        warn!(gorilla_id = %gorilla_id, error = %e, "failed to persist gorilla fighter");
103    }
104
105    // Create a bout for this tick.
106    let bout_id = memory.create_bout(&fighter_id).await?;
107
108    let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
109
110    let params = FighterLoopParams {
111        manifest: fighter_manifest,
112        user_message: autonomous_prompt,
113        bout_id,
114        fighter_id,
115        memory: Arc::clone(memory),
116        driver: Arc::clone(driver),
117        available_tools,
118        max_iterations: Some(10),
119        context_window: None,
120        tool_timeout_secs: None,
121        coordinator: None,
122        approval_engine: None,
123        sandbox: None,
124    };
125
126    run_fighter_loop(params).await
127}
128
129impl BackgroundExecutor {
130    /// Create a new background executor.
131    pub fn new() -> Self {
132        let (shutdown_tx, shutdown_rx) = watch::channel(false);
133        Self {
134            tasks: DashMap::new(),
135            llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
136            _shutdown_tx: shutdown_tx,
137            shutdown_rx,
138        }
139    }
140
141    /// Create a new background executor with a custom shutdown channel.
142    pub fn with_shutdown(
143        shutdown_tx: watch::Sender<bool>,
144        shutdown_rx: watch::Receiver<bool>,
145    ) -> Self {
146        Self {
147            tasks: DashMap::new(),
148            llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
149            _shutdown_tx: shutdown_tx,
150            shutdown_rx,
151        }
152    }
153
154    /// Parse a schedule string like "every 30s", "every 5m", "every 1h", "every 1d"
155    /// into a [`std::time::Duration`].
156    pub fn parse_schedule(schedule: &str) -> Option<std::time::Duration> {
157        let s = schedule.trim().to_lowercase();
158        let s = s.strip_prefix("every ").unwrap_or(&s);
159        let s = s.trim();
160
161        if let Some(num_str) = s.strip_suffix('s') {
162            num_str
163                .trim()
164                .parse::<u64>()
165                .ok()
166                .map(std::time::Duration::from_secs)
167        } else if let Some(num_str) = s.strip_suffix('m') {
168            num_str
169                .trim()
170                .parse::<u64>()
171                .ok()
172                .map(|m| std::time::Duration::from_secs(m * 60))
173        } else if let Some(num_str) = s.strip_suffix('h') {
174            num_str
175                .trim()
176                .parse::<u64>()
177                .ok()
178                .map(|h| std::time::Duration::from_secs(h * 3600))
179        } else if let Some(num_str) = s.strip_suffix('d') {
180            num_str
181                .trim()
182                .parse::<u64>()
183                .ok()
184                .map(|d| std::time::Duration::from_secs(d * 86400))
185        } else {
186            // Try to parse as raw seconds.
187            s.parse::<u64>().ok().map(std::time::Duration::from_secs)
188        }
189    }
190
191    /// Start a gorilla's autonomous background task.
192    ///
193    /// The task will loop on the gorilla's schedule, acquiring the LLM
194    /// semaphore before each run, and executing the fighter loop with an
195    /// autonomous prompt derived from the gorilla's manifest.
196    ///
197    /// `default_model` is used as a fallback when the gorilla manifest does
198    /// not specify its own `model` configuration.
199    pub fn start_gorilla(
200        &self,
201        id: GorillaId,
202        manifest: GorillaManifest,
203        default_model: ModelConfig,
204        memory: Arc<MemorySubstrate>,
205        driver: Arc<dyn LlmDriver>,
206    ) -> PunchResult<()> {
207        if self.tasks.contains_key(&id) {
208            return Err(punch_types::PunchError::Gorilla(format!(
209                "gorilla {} is already running",
210                id
211            )));
212        }
213
214        let interval = Self::parse_schedule(&manifest.schedule).unwrap_or_else(|| {
215            warn!(
216                gorilla_id = %id,
217                schedule = %manifest.schedule,
218                "could not parse schedule, defaulting to 5m"
219            );
220            std::time::Duration::from_secs(300)
221        });
222
223        let semaphore = Arc::clone(&self.llm_semaphore);
224        let mut shutdown_rx = self.shutdown_rx.clone();
225        let gorilla_name = manifest.name.clone();
226
227        let handle = tokio::spawn(async move {
228            info!(
229                gorilla_id = %id,
230                name = %gorilla_name,
231                interval_secs = interval.as_secs(),
232                "gorilla background task started"
233            );
234
235            let mut tasks_completed: u64 = 0;
236            let mut error_count: u64 = 0;
237
238            loop {
239                // Sleep for the interval, checking shutdown signal.
240                tokio::select! {
241                    _ = tokio::time::sleep(interval) => {},
242                    _ = shutdown_rx.changed() => {
243                        if *shutdown_rx.borrow() {
244                            info!(gorilla_id = %id, "gorilla received shutdown signal");
245                            break;
246                        }
247                    }
248                }
249
250                // Check shutdown before proceeding.
251                if *shutdown_rx.borrow() {
252                    break;
253                }
254
255                // Acquire semaphore permit.
256                let _permit = match semaphore.acquire().await {
257                    Ok(permit) => permit,
258                    Err(_) => {
259                        warn!(gorilla_id = %id, "semaphore closed, stopping gorilla");
260                        break;
261                    }
262                };
263
264                match run_gorilla_tick(id, &manifest, &default_model, &memory, &driver).await {
265                    Ok(result) => {
266                        tasks_completed += 1;
267                        info!(
268                            gorilla_id = %id,
269                            tasks_completed,
270                            tokens = result.usage.total(),
271                            "gorilla tick completed successfully"
272                        );
273                    }
274                    Err(e) => {
275                        error_count += 1;
276                        error!(
277                            gorilla_id = %id,
278                            error = %e,
279                            error_count,
280                            "gorilla tick failed"
281                        );
282                    }
283                }
284            }
285
286            info!(
287                gorilla_id = %id,
288                tasks_completed,
289                "gorilla background task stopped"
290            );
291        });
292
293        self.tasks.insert(
294            id,
295            GorillaTask {
296                handle,
297                started_at: Utc::now(),
298            },
299        );
300
301        Ok(())
302    }
303
304    /// Stop a gorilla's background task by aborting it.
305    pub fn stop_gorilla(&self, id: &GorillaId) -> bool {
306        if let Some((_, task)) = self.tasks.remove(id) {
307            task.handle.abort();
308            info!(gorilla_id = %id, "gorilla task stopped");
309            true
310        } else {
311            false
312        }
313    }
314
315    /// Check whether a gorilla is currently running.
316    pub fn is_running(&self, id: &GorillaId) -> bool {
317        self.tasks.contains_key(id)
318    }
319
320    /// List all currently running gorilla IDs.
321    pub fn list_running(&self) -> Vec<GorillaId> {
322        self.tasks.iter().map(|entry| *entry.key()).collect()
323    }
324
325    /// Shutdown all running gorilla tasks.
326    pub fn shutdown_all(&self) {
327        let ids: Vec<GorillaId> = self.tasks.iter().map(|e| *e.key()).collect();
328        for id in &ids {
329            if let Some((_, task)) = self.tasks.remove(id) {
330                task.handle.abort();
331            }
332        }
333        info!(count = ids.len(), "all gorilla tasks shut down");
334    }
335
336    /// Returns the number of currently running gorilla tasks.
337    pub fn running_count(&self) -> usize {
338        self.tasks.len()
339    }
340}
341
342impl Default for BackgroundExecutor {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Register a never-completing task under `id`. It stands in for a real
    /// gorilla loop, which would need a live driver and memory substrate.
    fn insert_pending_task(executor: &BackgroundExecutor, id: GorillaId) {
        let handle = tokio::spawn(std::future::pending::<()>());
        executor.tasks.insert(
            id,
            GorillaTask {
                handle,
                started_at: Utc::now(),
            },
        );
    }

    /// Fallback model shared by the manifest-conversion tests.
    fn default_anthropic_model() -> ModelConfig {
        ModelConfig {
            provider: punch_types::Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        }
    }

    #[test]
    fn parse_schedule_seconds() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 30s"),
            Some(std::time::Duration::from_secs(30))
        );
    }

    #[test]
    fn parse_schedule_minutes() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 5m"),
            Some(std::time::Duration::from_secs(300))
        );
    }

    #[test]
    fn parse_schedule_hours() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 1h"),
            Some(std::time::Duration::from_secs(3600))
        );
    }

    #[test]
    fn parse_schedule_days() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 1d"),
            Some(std::time::Duration::from_secs(86400))
        );
    }

    #[test]
    fn parse_schedule_invalid() {
        assert_eq!(BackgroundExecutor::parse_schedule("invalid"), None);
    }

    #[tokio::test]
    async fn start_and_stop_gorilla() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();

        insert_pending_task(&executor, id);

        assert_eq!(executor.running_count(), 1);
        assert!(executor.list_running().contains(&id));

        assert!(executor.stop_gorilla(&id));
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_all_stops_everything() {
        let executor = BackgroundExecutor::new();
        for _ in 0..3 {
            insert_pending_task(&executor, GorillaId::new());
        }

        assert_eq!(executor.running_count(), 3);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn stop_nonexistent_gorilla_returns_false() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        assert!(!executor.stop_gorilla(&id));
    }

    #[test]
    fn parse_schedule_raw_seconds() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("60"),
            Some(std::time::Duration::from_secs(60))
        );
    }

    #[test]
    fn parse_schedule_with_whitespace() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("  every  10s  "),
            Some(std::time::Duration::from_secs(10))
        );
    }

    #[test]
    fn parse_schedule_case_insensitive() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("Every 2H"),
            Some(std::time::Duration::from_secs(7200))
        );
    }

    #[test]
    fn parse_schedule_empty_string() {
        assert_eq!(BackgroundExecutor::parse_schedule(""), None);
    }

    #[test]
    fn parse_schedule_just_prefix() {
        assert_eq!(BackgroundExecutor::parse_schedule("every "), None);
    }

    #[test]
    fn default_creates_executor() {
        let executor = BackgroundExecutor::default();
        assert_eq!(executor.running_count(), 0);
        assert!(executor.list_running().is_empty());
    }

    #[tokio::test]
    async fn is_running_returns_correct_state() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();

        assert!(!executor.is_running(&id));

        insert_pending_task(&executor, id);
        assert!(executor.is_running(&id));

        executor.stop_gorilla(&id);
        assert!(!executor.is_running(&id));
    }

    #[tokio::test]
    async fn multiple_gorillas_tracked_independently() {
        let executor = BackgroundExecutor::new();
        let ids: Vec<GorillaId> = (0..5).map(|_| GorillaId::new()).collect();
        for &id in &ids {
            insert_pending_task(&executor, id);
        }

        assert_eq!(executor.running_count(), 5);

        // Stop the first two.
        executor.stop_gorilla(&ids[0]);
        executor.stop_gorilla(&ids[1]);
        assert_eq!(executor.running_count(), 3);

        // The remaining three should still be running.
        for id in &ids[2..] {
            assert!(executor.is_running(id));
        }

        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn with_shutdown_receives_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::with_shutdown(tx.clone(), rx);

        // Gorilla tasks watch clones of the executor's receiver, so a signal
        // sent through the caller's sender must reach such a clone.
        let mut task_rx = executor.shutdown_rx.clone();
        assert!(!*task_rx.borrow());
        tx.send(true).expect("executor keeps a receiver alive");
        task_rx
            .changed()
            .await
            .expect("shutdown sender is still alive");
        assert!(*task_rx.borrow());

        insert_pending_task(&executor, GorillaId::new());
        assert_eq!(executor.running_count(), 1);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_default_model() {
        let manifest = GorillaManifest {
            name: "test-gorilla".to_string(),
            description: "A test gorilla".to_string(),
            schedule: "every 30s".to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: Some("Custom prompt".to_string()),
            model: None,
            capabilities: Vec::new(),
            weight_class: None,
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_anthropic_model());
        assert_eq!(fighter.name, "test-gorilla");
        assert_eq!(fighter.model.model, "claude-sonnet-4-20250514");
        assert_eq!(fighter.system_prompt, "Custom prompt");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Middleweight);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_gorilla_model_if_set() {
        let gorilla_model = ModelConfig {
            provider: punch_types::Provider::OpenAI,
            model: "gpt-4o".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(8192),
            temperature: Some(0.5),
        };

        let manifest = GorillaManifest {
            name: "smart-gorilla".to_string(),
            description: "Uses its own model".to_string(),
            schedule: "every 1h".to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: None,
            model: Some(gorilla_model),
            capabilities: Vec::new(),
            weight_class: Some(punch_types::WeightClass::Heavyweight),
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_anthropic_model());
        assert_eq!(fighter.model.model, "gpt-4o");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Heavyweight);
        // system_prompt falls back to description when None.
        assert_eq!(fighter.system_prompt, "Uses its own model");
    }

    #[tokio::test]
    async fn list_running_returns_all_ids() {
        let executor = BackgroundExecutor::new();
        let expected_ids: Vec<GorillaId> = (0..3).map(|_| GorillaId::new()).collect();
        for &id in &expected_ids {
            insert_pending_task(&executor, id);
        }

        let running = executor.list_running();
        assert_eq!(running.len(), 3);
        for id in &expected_ids {
            assert!(running.contains(id));
        }

        executor.shutdown_all();
    }
}