1use std::sync::Arc;
9
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use tokio::sync::{Semaphore, watch};
13use tokio::task::JoinHandle;
14use tracing::{error, info, warn};
15
16use punch_memory::MemorySubstrate;
17use punch_runtime::{
18 FighterLoopParams, FighterLoopResult, LlmDriver, run_fighter_loop, tools_for_capabilities,
19};
20use punch_types::{
21 FighterId, FighterManifest, GorillaId, GorillaManifest, ModelConfig, PunchResult, WeightClass,
22};
23
/// Number of permits in the LLM semaphore each `BackgroundExecutor` creates.
///
/// Every gorilla tick holds one permit while it runs, so at most this many
/// ticks execute concurrently per executor, across all of its gorillas.
const DEFAULT_LLM_CONCURRENCY: usize = 3;
26
27struct GorillaTask {
29 handle: JoinHandle<()>,
30 #[allow(dead_code)]
31 started_at: DateTime<Utc>,
32}
33
/// Runs gorilla loops as background tokio tasks.
///
/// Tracks one spawned task per [`GorillaId`]. A semaphore shared by all of the
/// executor's gorillas caps how many ticks run at once. Every task gets a clone
/// of a `watch` receiver that acts as a shutdown flag.
// NOTE: field order is also drop order; keep it stable.
pub struct BackgroundExecutor {
    /// Spawned gorilla tasks, keyed by gorilla id.
    tasks: DashMap<GorillaId, GorillaTask>,
    /// Shared permit pool; a tick holds one permit for as long as it runs.
    llm_semaphore: Arc<Semaphore>,
    /// Sender half of the shutdown channel. Nothing in this module sends on it;
    /// it is kept so the channel is not closed while the executor lives
    /// (a closed channel makes `watch::Receiver::changed` return an error).
    _shutdown_tx: watch::Sender<bool>,
    /// Receiver cloned into every spawned task; a value of `true` means "stop".
    shutdown_rx: watch::Receiver<bool>,
}
45
46pub fn fighter_manifest_from_gorilla(
49 manifest: &GorillaManifest,
50 default_model: &ModelConfig,
51) -> FighterManifest {
52 let model = manifest
53 .model
54 .clone()
55 .unwrap_or_else(|| default_model.clone());
56 let capabilities = manifest.effective_capabilities();
57 let weight_class = manifest.weight_class.unwrap_or(WeightClass::Middleweight);
58 let system_prompt = manifest.effective_system_prompt();
59
60 FighterManifest {
61 name: manifest.name.clone(),
62 description: format!("Autonomous gorilla: {}", manifest.name),
63 model,
64 system_prompt,
65 capabilities,
66 weight_class,
67 tenant_id: None,
68 }
69}
70
71pub async fn run_gorilla_tick(
74 gorilla_id: GorillaId,
75 manifest: &GorillaManifest,
76 default_model: &ModelConfig,
77 memory: &Arc<MemorySubstrate>,
78 driver: &Arc<dyn LlmDriver>,
79) -> PunchResult<FighterLoopResult> {
80 let fighter_manifest = fighter_manifest_from_gorilla(manifest, default_model);
81 let gorilla_name = &manifest.name;
82 let system_prompt = fighter_manifest.system_prompt.clone();
83
84 let autonomous_prompt = format!(
86 "[AUTONOMOUS TICK] You are {}. Review your memory, check your goals, and take the next action. {}",
87 gorilla_name, system_prompt
88 );
89
90 let fighter_id = FighterId::new();
92
93 if let Err(e) = memory
95 .save_fighter(
96 &fighter_id,
97 &fighter_manifest,
98 punch_types::FighterStatus::Idle,
99 )
100 .await
101 {
102 warn!(gorilla_id = %gorilla_id, error = %e, "failed to persist gorilla fighter");
103 }
104
105 let bout_id = memory.create_bout(&fighter_id).await?;
107
108 let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
109
110 let params = FighterLoopParams {
111 manifest: fighter_manifest,
112 user_message: autonomous_prompt,
113 bout_id,
114 fighter_id,
115 memory: Arc::clone(memory),
116 driver: Arc::clone(driver),
117 available_tools,
118 max_iterations: Some(10),
119 context_window: None,
120 tool_timeout_secs: None,
121 coordinator: None,
122 approval_engine: None,
123 sandbox: None,
124 mcp_clients: None,
125 };
126
127 run_fighter_loop(params).await
128}
129
130impl BackgroundExecutor {
131 pub fn new() -> Self {
133 let (shutdown_tx, shutdown_rx) = watch::channel(false);
134 Self {
135 tasks: DashMap::new(),
136 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
137 _shutdown_tx: shutdown_tx,
138 shutdown_rx,
139 }
140 }
141
142 pub fn with_shutdown(
144 shutdown_tx: watch::Sender<bool>,
145 shutdown_rx: watch::Receiver<bool>,
146 ) -> Self {
147 Self {
148 tasks: DashMap::new(),
149 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
150 _shutdown_tx: shutdown_tx,
151 shutdown_rx,
152 }
153 }
154
155 pub fn parse_schedule(schedule: &str) -> Option<std::time::Duration> {
162 let s = schedule.trim().to_lowercase();
163
164 if let Some(duration) = Self::parse_human_schedule(&s) {
166 return Some(duration);
167 }
168
169 if let Some(duration) = Self::parse_cron_schedule(&s) {
171 return Some(duration);
172 }
173
174 s.parse::<u64>().ok().map(std::time::Duration::from_secs)
176 }
177
178 fn parse_human_schedule(s: &str) -> Option<std::time::Duration> {
180 let s = s.strip_prefix("every ").unwrap_or(s);
181 let s = s.trim();
182
183 if let Some(num_str) = s.strip_suffix('s') {
184 num_str
185 .trim()
186 .parse::<u64>()
187 .ok()
188 .map(std::time::Duration::from_secs)
189 } else if let Some(num_str) = s.strip_suffix('m') {
190 num_str
191 .trim()
192 .parse::<u64>()
193 .ok()
194 .map(|m| std::time::Duration::from_secs(m * 60))
195 } else if let Some(num_str) = s.strip_suffix('h') {
196 num_str
197 .trim()
198 .parse::<u64>()
199 .ok()
200 .map(|h| std::time::Duration::from_secs(h * 3600))
201 } else if let Some(num_str) = s.strip_suffix('d') {
202 num_str
203 .trim()
204 .parse::<u64>()
205 .ok()
206 .map(|d| std::time::Duration::from_secs(d * 86400))
207 } else {
208 None
209 }
210 }
211
212 fn parse_cron_schedule(s: &str) -> Option<std::time::Duration> {
220 let fields: Vec<&str> = s.split_whitespace().collect();
221 if fields.len() != 5 {
222 return None;
223 }
224
225 let (minute, hour, day, _month, _dow) =
226 (fields[0], fields[1], fields[2], fields[3], fields[4]);
227
228 if let Some(step) = minute.strip_prefix("*/")
230 && hour == "*"
231 && day == "*"
232 && let Ok(n) = step.parse::<u64>()
233 {
234 return Some(std::time::Duration::from_secs(n * 60));
235 }
236
237 if minute == "0"
239 && let Some(step) = hour.strip_prefix("*/")
240 && day == "*"
241 && let Ok(n) = step.parse::<u64>()
242 {
243 return Some(std::time::Duration::from_secs(n * 3600));
244 }
245
246 if minute == "0"
248 && hour == "0"
249 && let Some(step) = day.strip_prefix("*/")
250 && let Ok(n) = step.parse::<u64>()
251 {
252 return Some(std::time::Duration::from_secs(n * 86400));
253 }
254
255 if minute == "0" && hour == "0" && day == "*" {
257 return Some(std::time::Duration::from_secs(86400));
258 }
259
260 if minute == "0" && day == "*" && hour.parse::<u64>().is_ok() {
262 return Some(std::time::Duration::from_secs(86400));
263 }
264
265 None
266 }
267
268 pub fn start_gorilla(
277 &self,
278 id: GorillaId,
279 manifest: GorillaManifest,
280 default_model: ModelConfig,
281 memory: Arc<MemorySubstrate>,
282 driver: Arc<dyn LlmDriver>,
283 ) -> PunchResult<()> {
284 if self.tasks.contains_key(&id) {
285 return Err(punch_types::PunchError::Gorilla(format!(
286 "gorilla {} is already running",
287 id
288 )));
289 }
290
291 let interval = Self::parse_schedule(&manifest.schedule).unwrap_or_else(|| {
292 warn!(
293 gorilla_id = %id,
294 schedule = %manifest.schedule,
295 "could not parse schedule, defaulting to 5m"
296 );
297 std::time::Duration::from_secs(300)
298 });
299
300 let semaphore = Arc::clone(&self.llm_semaphore);
301 let mut shutdown_rx = self.shutdown_rx.clone();
302 let gorilla_name = manifest.name.clone();
303
304 let handle = tokio::spawn(async move {
305 info!(
306 gorilla_id = %id,
307 name = %gorilla_name,
308 interval_secs = interval.as_secs(),
309 "gorilla background task started"
310 );
311
312 let mut tasks_completed: u64 = 0;
313 let mut error_count: u64 = 0;
314
315 loop {
316 tokio::select! {
318 _ = tokio::time::sleep(interval) => {},
319 _ = shutdown_rx.changed() => {
320 if *shutdown_rx.borrow() {
321 info!(gorilla_id = %id, "gorilla received shutdown signal");
322 break;
323 }
324 }
325 }
326
327 if *shutdown_rx.borrow() {
329 break;
330 }
331
332 let _permit = match semaphore.acquire().await {
334 Ok(permit) => permit,
335 Err(_) => {
336 warn!(gorilla_id = %id, "semaphore closed, stopping gorilla");
337 break;
338 }
339 };
340
341 match run_gorilla_tick(id, &manifest, &default_model, &memory, &driver).await {
342 Ok(result) => {
343 tasks_completed += 1;
344 info!(
345 gorilla_id = %id,
346 tasks_completed,
347 tokens = result.usage.total(),
348 "gorilla tick completed successfully"
349 );
350 }
351 Err(e) => {
352 error_count += 1;
353 error!(
354 gorilla_id = %id,
355 error = %e,
356 error_count,
357 "gorilla tick failed"
358 );
359 }
360 }
361 }
362
363 info!(
364 gorilla_id = %id,
365 tasks_completed,
366 "gorilla background task stopped"
367 );
368 });
369
370 self.tasks.insert(
371 id,
372 GorillaTask {
373 handle,
374 started_at: Utc::now(),
375 },
376 );
377
378 Ok(())
379 }
380
381 pub fn stop_gorilla(&self, id: &GorillaId) -> bool {
383 if let Some((_, task)) = self.tasks.remove(id) {
384 task.handle.abort();
385 info!(gorilla_id = %id, "gorilla task stopped");
386 true
387 } else {
388 false
389 }
390 }
391
392 pub fn is_running(&self, id: &GorillaId) -> bool {
394 self.tasks.contains_key(id)
395 }
396
397 pub fn list_running(&self) -> Vec<GorillaId> {
399 self.tasks.iter().map(|entry| *entry.key()).collect()
400 }
401
402 pub fn shutdown_all(&self) {
404 let ids: Vec<GorillaId> = self.tasks.iter().map(|e| *e.key()).collect();
405 for id in &ids {
406 if let Some((_, task)) = self.tasks.remove(id) {
407 task.handle.abort();
408 }
409 }
410 info!(count = ids.len(), "all gorilla tasks shut down");
411 }
412
413 pub fn running_count(&self) -> usize {
415 self.tasks.len()
416 }
417}
418
419impl Default for BackgroundExecutor {
420 fn default() -> Self {
421 Self::new()
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::Provider;

    /// Shorthand for the `Some(Duration)` that `parse_schedule` returns.
    fn secs(n: u64) -> Option<std::time::Duration> {
        Some(std::time::Duration::from_secs(n))
    }

    /// Registers a never-completing task under `id`. This lets executor
    /// bookkeeping be exercised without a real LLM driver or memory substrate.
    fn insert_pending_task(executor: &BackgroundExecutor, id: GorillaId) {
        let handle = tokio::spawn(std::future::pending::<()>());
        executor.tasks.insert(
            id,
            GorillaTask {
                handle,
                started_at: Utc::now(),
            },
        );
    }

    /// A gorilla manifest with every optional field left empty.
    fn test_manifest(name: &str, description: &str, schedule: &str) -> GorillaManifest {
        GorillaManifest {
            name: name.to_string(),
            description: description.to_string(),
            schedule: schedule.to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: None,
            model: None,
            capabilities: Vec::new(),
            weight_class: None,
        }
    }

    /// The executor-wide default model used by the manifest tests.
    fn default_model() -> ModelConfig {
        ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        }
    }

    #[test]
    fn parse_schedule_seconds() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 30s"), secs(30));
    }

    #[test]
    fn parse_schedule_minutes() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 5m"), secs(300));
    }

    #[test]
    fn parse_schedule_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 1h"), secs(3600));
    }

    #[test]
    fn parse_schedule_days() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 1d"), secs(86400));
    }

    #[test]
    fn parse_schedule_invalid() {
        assert_eq!(BackgroundExecutor::parse_schedule("invalid"), None);
    }

    #[test]
    fn parse_schedule_cron_every_30_minutes() {
        assert_eq!(BackgroundExecutor::parse_schedule("*/30 * * * *"), secs(1800));
    }

    #[test]
    fn parse_schedule_cron_every_6_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */6 * * *"), secs(21600));
    }

    #[test]
    fn parse_schedule_cron_every_2_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */2 * * *"), secs(7200));
    }

    #[test]
    fn parse_schedule_cron_every_2_days() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 0 */2 * *"), secs(172800));
    }

    #[test]
    fn parse_schedule_cron_daily() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 0 * * *"), secs(86400));
    }

    #[test]
    fn parse_schedule_cron_every_3_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */3 * * *"), secs(10800));
    }

    #[test]
    fn parse_schedule_cron_every_4_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */4 * * *"), secs(14400));
    }

    #[test]
    fn parse_schedule_raw_seconds() {
        assert_eq!(BackgroundExecutor::parse_schedule("60"), secs(60));
    }

    #[test]
    fn parse_schedule_with_whitespace() {
        assert_eq!(BackgroundExecutor::parse_schedule("  every 10s  "), secs(10));
    }

    #[test]
    fn parse_schedule_case_insensitive() {
        assert_eq!(BackgroundExecutor::parse_schedule("Every 2H"), secs(7200));
    }

    #[test]
    fn parse_schedule_empty_string() {
        assert_eq!(BackgroundExecutor::parse_schedule(""), None);
    }

    #[test]
    fn parse_schedule_just_prefix() {
        assert_eq!(BackgroundExecutor::parse_schedule("every "), None);
    }

    #[tokio::test]
    async fn stop_gorilla_removes_registered_task() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        insert_pending_task(&executor, id);

        assert_eq!(executor.running_count(), 1);
        assert!(executor.list_running().contains(&id));

        assert!(executor.stop_gorilla(&id));
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_all_stops_everything() {
        let executor = BackgroundExecutor::new();
        for _ in 0..3 {
            insert_pending_task(&executor, GorillaId::new());
        }

        assert_eq!(executor.running_count(), 3);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[test]
    fn stop_nonexistent_gorilla_returns_false() {
        let executor = BackgroundExecutor::new();
        assert!(!executor.stop_gorilla(&GorillaId::new()));
    }

    #[test]
    fn default_creates_executor() {
        let executor = BackgroundExecutor::default();
        assert_eq!(executor.running_count(), 0);
        assert!(executor.list_running().is_empty());
    }

    #[tokio::test]
    async fn is_running_returns_correct_state() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        assert!(!executor.is_running(&id));

        insert_pending_task(&executor, id);
        assert!(executor.is_running(&id));

        executor.stop_gorilla(&id);
        assert!(!executor.is_running(&id));
    }

    #[tokio::test]
    async fn multiple_gorillas_tracked_independently() {
        let executor = BackgroundExecutor::new();
        let ids: Vec<GorillaId> = (0..5).map(|_| GorillaId::new()).collect();
        for &id in &ids {
            insert_pending_task(&executor, id);
        }
        assert_eq!(executor.running_count(), 5);

        executor.stop_gorilla(&ids[0]);
        executor.stop_gorilla(&ids[1]);
        assert_eq!(executor.running_count(), 3);
        assert!(ids[2..].iter().all(|id| executor.is_running(id)));

        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn with_shutdown_receives_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::with_shutdown(tx.clone(), rx);
        assert!(!*executor.shutdown_rx.borrow());

        // The executor must observe values published on the caller's channel.
        tx.send(true).expect("executor holds a live receiver");
        assert!(*executor.shutdown_rx.borrow());

        insert_pending_task(&executor, GorillaId::new());
        assert_eq!(executor.running_count(), 1);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_default_model() {
        let manifest = GorillaManifest {
            system_prompt: Some("Custom prompt".to_string()),
            ..test_manifest("test-gorilla", "A test gorilla", "every 30s")
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_model());
        assert_eq!(fighter.name, "test-gorilla");
        assert_eq!(fighter.model.model, "claude-sonnet-4-20250514");
        assert_eq!(fighter.system_prompt, "Custom prompt");
        assert_eq!(fighter.weight_class, WeightClass::Middleweight);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_gorilla_model_if_set() {
        let gorilla_model = ModelConfig {
            provider: Provider::OpenAI,
            model: "gpt-4o".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(8192),
            temperature: Some(0.5),
        };
        let manifest = GorillaManifest {
            model: Some(gorilla_model),
            weight_class: Some(WeightClass::Heavyweight),
            ..test_manifest("smart-gorilla", "Uses its own model", "every 1h")
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_model());
        assert_eq!(fighter.model.model, "gpt-4o");
        assert_eq!(fighter.weight_class, WeightClass::Heavyweight);
        // With no explicit prompt, the effective prompt falls back to the description.
        assert_eq!(fighter.system_prompt, "Uses its own model");
    }

    #[tokio::test]
    async fn list_running_returns_all_ids() {
        let executor = BackgroundExecutor::new();
        let expected_ids: Vec<GorillaId> = (0..3).map(|_| GorillaId::new()).collect();
        for &id in &expected_ids {
            insert_pending_task(&executor, id);
        }

        let running = executor.list_running();
        assert_eq!(running.len(), 3);
        assert!(expected_ids.iter().all(|id| running.contains(id)));

        executor.shutdown_all();
    }
}
817}