1use std::sync::Arc;
9
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use tokio::sync::{Semaphore, watch};
13use tokio::task::JoinHandle;
14use tracing::{error, info, warn};
15
16use punch_memory::MemorySubstrate;
17use punch_runtime::{
18 FighterLoopParams, FighterLoopResult, LlmDriver, run_fighter_loop, tools_for_capabilities,
19};
20use punch_types::{
21 FighterId, FighterManifest, GorillaId, GorillaManifest, ModelConfig, PunchResult, WeightClass,
22};
23
/// Number of permits in the semaphore that every gorilla tick must acquire
/// before running, i.e. the default cap on ticks running at the same time.
const DEFAULT_LLM_CONCURRENCY: usize = 3;
26
/// Bookkeeping for one spawned gorilla background task.
struct GorillaTask {
    /// Handle of the spawned tokio task; aborted when the gorilla is stopped.
    handle: JoinHandle<()>,
    /// Wall-clock time at which the task was registered. Not read in this file.
    #[allow(dead_code)]
    started_at: DateTime<Utc>,
}
33
/// Runs gorillas as background tokio tasks, each ticking on its own schedule.
pub struct BackgroundExecutor {
    /// Tracked gorilla tasks, keyed by gorilla id.
    tasks: DashMap<GorillaId, GorillaTask>,
    /// Permit pool shared by all gorilla tasks; a permit is held for the
    /// duration of each tick.
    llm_semaphore: Arc<Semaphore>,
    /// Sending half of the shutdown channel. Nothing in this file sends on it;
    /// presumably it is stored to keep the channel open while the executor lives.
    _shutdown_tx: watch::Sender<bool>,
    /// Receiving half of the shutdown channel; cloned into every spawned task.
    shutdown_rx: watch::Receiver<bool>,
}
45
46pub fn fighter_manifest_from_gorilla(
49 manifest: &GorillaManifest,
50 default_model: &ModelConfig,
51) -> FighterManifest {
52 let model = manifest
53 .model
54 .clone()
55 .unwrap_or_else(|| default_model.clone());
56 let capabilities = manifest.effective_capabilities();
57 let weight_class = manifest.weight_class.unwrap_or(WeightClass::Middleweight);
58 let system_prompt = manifest.effective_system_prompt();
59
60 FighterManifest {
61 name: manifest.name.clone(),
62 description: format!("Autonomous gorilla: {}", manifest.name),
63 model,
64 system_prompt,
65 capabilities,
66 weight_class,
67 tenant_id: None,
68 }
69}
70
71pub async fn run_gorilla_tick(
74 gorilla_id: GorillaId,
75 manifest: &GorillaManifest,
76 default_model: &ModelConfig,
77 memory: &Arc<MemorySubstrate>,
78 driver: &Arc<dyn LlmDriver>,
79) -> PunchResult<FighterLoopResult> {
80 let fighter_manifest = fighter_manifest_from_gorilla(manifest, default_model);
81 let gorilla_name = &manifest.name;
82 let system_prompt = fighter_manifest.system_prompt.clone();
83
84 let autonomous_prompt = format!(
86 "[AUTONOMOUS TICK] You are {}. Review your memory, check your goals, and take the next action. {}",
87 gorilla_name, system_prompt
88 );
89
90 let fighter_id = FighterId::new();
92
93 if let Err(e) = memory
95 .save_fighter(
96 &fighter_id,
97 &fighter_manifest,
98 punch_types::FighterStatus::Idle,
99 )
100 .await
101 {
102 warn!(gorilla_id = %gorilla_id, error = %e, "failed to persist gorilla fighter");
103 }
104
105 let bout_id = memory.create_bout(&fighter_id).await?;
107
108 let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
109
110 let params = FighterLoopParams {
111 manifest: fighter_manifest,
112 user_message: autonomous_prompt,
113 bout_id,
114 fighter_id,
115 memory: Arc::clone(memory),
116 driver: Arc::clone(driver),
117 available_tools,
118 max_iterations: Some(10),
119 context_window: None,
120 tool_timeout_secs: None,
121 coordinator: None,
122 approval_engine: None,
123 sandbox: None,
124 };
125
126 run_fighter_loop(params).await
127}
128
129impl BackgroundExecutor {
130 pub fn new() -> Self {
132 let (shutdown_tx, shutdown_rx) = watch::channel(false);
133 Self {
134 tasks: DashMap::new(),
135 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
136 _shutdown_tx: shutdown_tx,
137 shutdown_rx,
138 }
139 }
140
141 pub fn with_shutdown(
143 shutdown_tx: watch::Sender<bool>,
144 shutdown_rx: watch::Receiver<bool>,
145 ) -> Self {
146 Self {
147 tasks: DashMap::new(),
148 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
149 _shutdown_tx: shutdown_tx,
150 shutdown_rx,
151 }
152 }
153
154 pub fn parse_schedule(schedule: &str) -> Option<std::time::Duration> {
161 let s = schedule.trim().to_lowercase();
162
163 if let Some(duration) = Self::parse_human_schedule(&s) {
165 return Some(duration);
166 }
167
168 if let Some(duration) = Self::parse_cron_schedule(&s) {
170 return Some(duration);
171 }
172
173 s.parse::<u64>().ok().map(std::time::Duration::from_secs)
175 }
176
177 fn parse_human_schedule(s: &str) -> Option<std::time::Duration> {
179 let s = s.strip_prefix("every ").unwrap_or(s);
180 let s = s.trim();
181
182 if let Some(num_str) = s.strip_suffix('s') {
183 num_str
184 .trim()
185 .parse::<u64>()
186 .ok()
187 .map(std::time::Duration::from_secs)
188 } else if let Some(num_str) = s.strip_suffix('m') {
189 num_str
190 .trim()
191 .parse::<u64>()
192 .ok()
193 .map(|m| std::time::Duration::from_secs(m * 60))
194 } else if let Some(num_str) = s.strip_suffix('h') {
195 num_str
196 .trim()
197 .parse::<u64>()
198 .ok()
199 .map(|h| std::time::Duration::from_secs(h * 3600))
200 } else if let Some(num_str) = s.strip_suffix('d') {
201 num_str
202 .trim()
203 .parse::<u64>()
204 .ok()
205 .map(|d| std::time::Duration::from_secs(d * 86400))
206 } else {
207 None
208 }
209 }
210
211 fn parse_cron_schedule(s: &str) -> Option<std::time::Duration> {
219 let fields: Vec<&str> = s.split_whitespace().collect();
220 if fields.len() != 5 {
221 return None;
222 }
223
224 let (minute, hour, day, _month, _dow) =
225 (fields[0], fields[1], fields[2], fields[3], fields[4]);
226
227 if let Some(step) = minute.strip_prefix("*/")
229 && hour == "*"
230 && day == "*"
231 && let Ok(n) = step.parse::<u64>()
232 {
233 return Some(std::time::Duration::from_secs(n * 60));
234 }
235
236 if minute == "0"
238 && let Some(step) = hour.strip_prefix("*/")
239 && day == "*"
240 && let Ok(n) = step.parse::<u64>()
241 {
242 return Some(std::time::Duration::from_secs(n * 3600));
243 }
244
245 if minute == "0"
247 && hour == "0"
248 && let Some(step) = day.strip_prefix("*/")
249 && let Ok(n) = step.parse::<u64>()
250 {
251 return Some(std::time::Duration::from_secs(n * 86400));
252 }
253
254 if minute == "0" && hour == "0" && day == "*" {
256 return Some(std::time::Duration::from_secs(86400));
257 }
258
259 if minute == "0" && day == "*" && hour.parse::<u64>().is_ok() {
261 return Some(std::time::Duration::from_secs(86400));
262 }
263
264 None
265 }
266
267 pub fn start_gorilla(
276 &self,
277 id: GorillaId,
278 manifest: GorillaManifest,
279 default_model: ModelConfig,
280 memory: Arc<MemorySubstrate>,
281 driver: Arc<dyn LlmDriver>,
282 ) -> PunchResult<()> {
283 if self.tasks.contains_key(&id) {
284 return Err(punch_types::PunchError::Gorilla(format!(
285 "gorilla {} is already running",
286 id
287 )));
288 }
289
290 let interval = Self::parse_schedule(&manifest.schedule).unwrap_or_else(|| {
291 warn!(
292 gorilla_id = %id,
293 schedule = %manifest.schedule,
294 "could not parse schedule, defaulting to 5m"
295 );
296 std::time::Duration::from_secs(300)
297 });
298
299 let semaphore = Arc::clone(&self.llm_semaphore);
300 let mut shutdown_rx = self.shutdown_rx.clone();
301 let gorilla_name = manifest.name.clone();
302
303 let handle = tokio::spawn(async move {
304 info!(
305 gorilla_id = %id,
306 name = %gorilla_name,
307 interval_secs = interval.as_secs(),
308 "gorilla background task started"
309 );
310
311 let mut tasks_completed: u64 = 0;
312 let mut error_count: u64 = 0;
313
314 loop {
315 tokio::select! {
317 _ = tokio::time::sleep(interval) => {},
318 _ = shutdown_rx.changed() => {
319 if *shutdown_rx.borrow() {
320 info!(gorilla_id = %id, "gorilla received shutdown signal");
321 break;
322 }
323 }
324 }
325
326 if *shutdown_rx.borrow() {
328 break;
329 }
330
331 let _permit = match semaphore.acquire().await {
333 Ok(permit) => permit,
334 Err(_) => {
335 warn!(gorilla_id = %id, "semaphore closed, stopping gorilla");
336 break;
337 }
338 };
339
340 match run_gorilla_tick(id, &manifest, &default_model, &memory, &driver).await {
341 Ok(result) => {
342 tasks_completed += 1;
343 info!(
344 gorilla_id = %id,
345 tasks_completed,
346 tokens = result.usage.total(),
347 "gorilla tick completed successfully"
348 );
349 }
350 Err(e) => {
351 error_count += 1;
352 error!(
353 gorilla_id = %id,
354 error = %e,
355 error_count,
356 "gorilla tick failed"
357 );
358 }
359 }
360 }
361
362 info!(
363 gorilla_id = %id,
364 tasks_completed,
365 "gorilla background task stopped"
366 );
367 });
368
369 self.tasks.insert(
370 id,
371 GorillaTask {
372 handle,
373 started_at: Utc::now(),
374 },
375 );
376
377 Ok(())
378 }
379
380 pub fn stop_gorilla(&self, id: &GorillaId) -> bool {
382 if let Some((_, task)) = self.tasks.remove(id) {
383 task.handle.abort();
384 info!(gorilla_id = %id, "gorilla task stopped");
385 true
386 } else {
387 false
388 }
389 }
390
    /// Returns `true` if a task is currently tracked for `id`.
    ///
    /// This only consults the task map; it does not check whether the spawned
    /// task has already finished on its own.
    pub fn is_running(&self, id: &GorillaId) -> bool {
        self.tasks.contains_key(id)
    }
395
396 pub fn list_running(&self) -> Vec<GorillaId> {
398 self.tasks.iter().map(|entry| *entry.key()).collect()
399 }
400
401 pub fn shutdown_all(&self) {
403 let ids: Vec<GorillaId> = self.tasks.iter().map(|e| *e.key()).collect();
404 for id in &ids {
405 if let Some((_, task)) = self.tasks.remove(id) {
406 task.handle.abort();
407 }
408 }
409 info!(count = ids.len(), "all gorilla tasks shut down");
410 }
411
    /// Returns the number of gorilla tasks currently tracked.
    pub fn running_count(&self) -> usize {
        self.tasks.len()
    }
416}
417
418impl Default for BackgroundExecutor {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::Provider;
    use std::time::Duration;

    // ---- fixtures -------------------------------------------------------

    /// Expected result for a schedule that parses to `n` seconds.
    fn secs(n: u64) -> Option<Duration> {
        Some(Duration::from_secs(n))
    }

    /// Registers a never-completing task under `id` directly in the task map,
    /// so the bookkeeping can be exercised without a memory substrate or an
    /// LLM driver. Must be called from inside a Tokio runtime.
    fn track_pending(executor: &BackgroundExecutor, id: GorillaId) {
        let handle = tokio::spawn(async {
            futures::future::pending::<()>().await;
        });
        executor.tasks.insert(
            id,
            GorillaTask {
                handle,
                started_at: Utc::now(),
            },
        );
    }

    /// Gorilla manifest with every optional field left unset.
    fn bare_manifest(name: &str, description: &str, schedule: &str) -> GorillaManifest {
        GorillaManifest {
            name: name.to_string(),
            description: description.to_string(),
            schedule: schedule.to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: None,
            model: None,
            capabilities: Vec::new(),
            weight_class: None,
        }
    }

    /// Fallback model handed to `fighter_manifest_from_gorilla` in the tests.
    fn fallback_model() -> ModelConfig {
        ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        }
    }

    // ---- schedule parsing: human-readable -------------------------------

    #[test]
    fn parse_schedule_seconds() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 30s"), secs(30));
    }

    #[test]
    fn parse_schedule_minutes() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 5m"), secs(300));
    }

    #[test]
    fn parse_schedule_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 1h"), secs(3600));
    }

    #[test]
    fn parse_schedule_days() {
        assert_eq!(BackgroundExecutor::parse_schedule("every 1d"), secs(86400));
    }

    #[test]
    fn parse_schedule_invalid() {
        assert_eq!(BackgroundExecutor::parse_schedule("invalid"), None);
    }

    // ---- schedule parsing: cron -----------------------------------------

    #[test]
    fn parse_schedule_cron_every_30_minutes() {
        assert_eq!(BackgroundExecutor::parse_schedule("*/30 * * * *"), secs(1800));
    }

    #[test]
    fn parse_schedule_cron_every_6_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */6 * * *"), secs(21600));
    }

    #[test]
    fn parse_schedule_cron_every_2_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */2 * * *"), secs(7200));
    }

    #[test]
    fn parse_schedule_cron_every_2_days() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 0 */2 * *"), secs(172800));
    }

    #[test]
    fn parse_schedule_cron_daily() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 0 * * *"), secs(86400));
    }

    #[test]
    fn parse_schedule_cron_every_3_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */3 * * *"), secs(10800));
    }

    #[test]
    fn parse_schedule_cron_every_4_hours() {
        assert_eq!(BackgroundExecutor::parse_schedule("0 */4 * * *"), secs(14400));
    }

    // ---- task bookkeeping -----------------------------------------------

    #[tokio::test]
    async fn start_and_stop_gorilla() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        let _manifest = bare_manifest("test-gorilla", "test", "every 30s");

        track_pending(&executor, id);

        assert_eq!(executor.running_count(), 1);
        assert!(executor.list_running().contains(&id));

        assert!(executor.stop_gorilla(&id));
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_all_stops_everything() {
        let executor = BackgroundExecutor::new();

        for _ in 0..3 {
            track_pending(&executor, GorillaId::new());
        }

        assert_eq!(executor.running_count(), 3);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn stop_nonexistent_gorilla_returns_false() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        assert!(!executor.stop_gorilla(&id));
    }

    // ---- schedule parsing: edge cases -----------------------------------

    #[test]
    fn parse_schedule_raw_seconds() {
        assert_eq!(BackgroundExecutor::parse_schedule("60"), secs(60));
    }

    #[test]
    fn parse_schedule_with_whitespace() {
        assert_eq!(BackgroundExecutor::parse_schedule(" every 10s "), secs(10));
    }

    #[test]
    fn parse_schedule_case_insensitive() {
        assert_eq!(BackgroundExecutor::parse_schedule("Every 2H"), secs(7200));
    }

    #[test]
    fn parse_schedule_empty_string() {
        assert_eq!(BackgroundExecutor::parse_schedule(""), None);
    }

    #[test]
    fn parse_schedule_just_prefix() {
        assert_eq!(BackgroundExecutor::parse_schedule("every "), None);
    }

    // ---- executor state -------------------------------------------------

    #[test]
    fn default_creates_executor() {
        let executor = BackgroundExecutor::default();
        assert_eq!(executor.running_count(), 0);
        assert!(executor.list_running().is_empty());
    }

    #[tokio::test]
    async fn is_running_returns_correct_state() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();

        assert!(!executor.is_running(&id));

        track_pending(&executor, id);

        assert!(executor.is_running(&id));
        executor.stop_gorilla(&id);
        assert!(!executor.is_running(&id));
    }

    #[tokio::test]
    async fn multiple_gorillas_tracked_independently() {
        let executor = BackgroundExecutor::new();
        let ids: Vec<GorillaId> = (0..5).map(|_| GorillaId::new()).collect();

        for &id in &ids {
            track_pending(&executor, id);
        }

        assert_eq!(executor.running_count(), 5);

        executor.stop_gorilla(&ids[0]);
        executor.stop_gorilla(&ids[1]);
        assert_eq!(executor.running_count(), 3);

        // The gorillas that were not stopped must still be tracked.
        for &id in &ids[2..] {
            assert!(executor.is_running(&id));
        }

        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn with_shutdown_receives_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::with_shutdown(tx.clone(), rx);

        track_pending(&executor, GorillaId::new());

        assert_eq!(executor.running_count(), 1);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    // ---- manifest conversion --------------------------------------------

    #[test]
    fn fighter_manifest_from_gorilla_uses_default_model() {
        let manifest = GorillaManifest {
            system_prompt: Some("Custom prompt".to_string()),
            ..bare_manifest("test-gorilla", "A test gorilla", "every 30s")
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &fallback_model());

        assert_eq!(fighter.name, "test-gorilla");
        assert_eq!(fighter.model.model, "claude-sonnet-4-20250514");
        assert_eq!(fighter.system_prompt, "Custom prompt");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Middleweight);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_gorilla_model_if_set() {
        let gorilla_model = ModelConfig {
            provider: Provider::OpenAI,
            model: "gpt-4o".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(8192),
            temperature: Some(0.5),
        };

        let manifest = GorillaManifest {
            model: Some(gorilla_model),
            weight_class: Some(punch_types::WeightClass::Heavyweight),
            ..bare_manifest("smart-gorilla", "Uses its own model", "every 1h")
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &fallback_model());

        assert_eq!(fighter.model.model, "gpt-4o");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Heavyweight);
        assert_eq!(fighter.system_prompt, "Uses its own model");
    }

    #[tokio::test]
    async fn list_running_returns_all_ids() {
        let executor = BackgroundExecutor::new();
        let mut expected_ids = Vec::new();

        for _ in 0..3 {
            let id = GorillaId::new();
            expected_ids.push(id);
            track_pending(&executor, id);
        }

        let running = executor.list_running();
        assert_eq!(running.len(), 3);
        for id in &expected_ids {
            assert!(running.contains(id));
        }

        executor.shutdown_all();
    }
}