1use std::sync::Arc;
9
10use chrono::{DateTime, Utc};
11use dashmap::DashMap;
12use tokio::sync::{Semaphore, watch};
13use tokio::task::JoinHandle;
14use tracing::{error, info, warn};
15
16use punch_memory::MemorySubstrate;
17use punch_runtime::{
18 FighterLoopParams, FighterLoopResult, LlmDriver, run_fighter_loop, tools_for_capabilities,
19};
20use punch_types::{
21 FighterId, FighterManifest, GorillaId, GorillaManifest, ModelConfig, PunchResult, WeightClass,
22};
23
/// Number of permits in the LLM semaphore every executor is constructed with,
/// i.e. how many gorilla ticks may be in flight at the same time.
const DEFAULT_LLM_CONCURRENCY: usize = 3;
26
27struct GorillaTask {
29 handle: JoinHandle<()>,
30 #[allow(dead_code)]
31 started_at: DateTime<Utc>,
32}
33
/// Runs gorilla loops as background tokio tasks and keeps a registry of them.
///
/// All loops started by one executor share a single LLM semaphore (limiting how
/// many ticks run at once) and watch the same shutdown channel.
pub struct BackgroundExecutor {
    /// Registered gorilla loops, keyed by gorilla id.
    tasks: DashMap<GorillaId, GorillaTask>,
    /// Shared permit pool; a loop holds one permit for the duration of each tick.
    llm_semaphore: Arc<Semaphore>,
    /// Sender half of the shutdown channel. It is only stored here, never sent on;
    /// presumably held so the channel stays open while the executor lives — confirm.
    _shutdown_tx: watch::Sender<bool>,
    /// Receiver cloned into every spawned loop; a value of `true` asks loops to stop.
    shutdown_rx: watch::Receiver<bool>,
}
45
46pub fn fighter_manifest_from_gorilla(
49 manifest: &GorillaManifest,
50 default_model: &ModelConfig,
51) -> FighterManifest {
52 let model = manifest
53 .model
54 .clone()
55 .unwrap_or_else(|| default_model.clone());
56 let capabilities = manifest.effective_capabilities();
57 let weight_class = manifest.weight_class.unwrap_or(WeightClass::Middleweight);
58 let system_prompt = manifest.effective_system_prompt();
59
60 FighterManifest {
61 name: manifest.name.clone(),
62 description: format!("Autonomous gorilla: {}", manifest.name),
63 model,
64 system_prompt,
65 capabilities,
66 weight_class,
67 tenant_id: None,
68 }
69}
70
71pub async fn run_gorilla_tick(
74 gorilla_id: GorillaId,
75 manifest: &GorillaManifest,
76 default_model: &ModelConfig,
77 memory: &Arc<MemorySubstrate>,
78 driver: &Arc<dyn LlmDriver>,
79) -> PunchResult<FighterLoopResult> {
80 let fighter_manifest = fighter_manifest_from_gorilla(manifest, default_model);
81 let gorilla_name = &manifest.name;
82 let system_prompt = fighter_manifest.system_prompt.clone();
83
84 let autonomous_prompt = format!(
86 "[AUTONOMOUS TICK] You are {}. Review your memory, check your goals, and take the next action. {}",
87 gorilla_name, system_prompt
88 );
89
90 let fighter_id = FighterId::new();
92
93 if let Err(e) = memory
95 .save_fighter(
96 &fighter_id,
97 &fighter_manifest,
98 punch_types::FighterStatus::Idle,
99 )
100 .await
101 {
102 warn!(gorilla_id = %gorilla_id, error = %e, "failed to persist gorilla fighter");
103 }
104
105 let bout_id = memory.create_bout(&fighter_id).await?;
107
108 let available_tools = tools_for_capabilities(&fighter_manifest.capabilities);
109
110 let params = FighterLoopParams {
111 manifest: fighter_manifest,
112 user_message: autonomous_prompt,
113 bout_id,
114 fighter_id,
115 memory: Arc::clone(memory),
116 driver: Arc::clone(driver),
117 available_tools,
118 max_iterations: Some(10),
119 context_window: None,
120 tool_timeout_secs: None,
121 coordinator: None,
122 approval_engine: None,
123 sandbox: None,
124 };
125
126 run_fighter_loop(params).await
127}
128
129impl BackgroundExecutor {
130 pub fn new() -> Self {
132 let (shutdown_tx, shutdown_rx) = watch::channel(false);
133 Self {
134 tasks: DashMap::new(),
135 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
136 _shutdown_tx: shutdown_tx,
137 shutdown_rx,
138 }
139 }
140
141 pub fn with_shutdown(
143 shutdown_tx: watch::Sender<bool>,
144 shutdown_rx: watch::Receiver<bool>,
145 ) -> Self {
146 Self {
147 tasks: DashMap::new(),
148 llm_semaphore: Arc::new(Semaphore::new(DEFAULT_LLM_CONCURRENCY)),
149 _shutdown_tx: shutdown_tx,
150 shutdown_rx,
151 }
152 }
153
154 pub fn parse_schedule(schedule: &str) -> Option<std::time::Duration> {
157 let s = schedule.trim().to_lowercase();
158 let s = s.strip_prefix("every ").unwrap_or(&s);
159 let s = s.trim();
160
161 if let Some(num_str) = s.strip_suffix('s') {
162 num_str
163 .trim()
164 .parse::<u64>()
165 .ok()
166 .map(std::time::Duration::from_secs)
167 } else if let Some(num_str) = s.strip_suffix('m') {
168 num_str
169 .trim()
170 .parse::<u64>()
171 .ok()
172 .map(|m| std::time::Duration::from_secs(m * 60))
173 } else if let Some(num_str) = s.strip_suffix('h') {
174 num_str
175 .trim()
176 .parse::<u64>()
177 .ok()
178 .map(|h| std::time::Duration::from_secs(h * 3600))
179 } else if let Some(num_str) = s.strip_suffix('d') {
180 num_str
181 .trim()
182 .parse::<u64>()
183 .ok()
184 .map(|d| std::time::Duration::from_secs(d * 86400))
185 } else {
186 s.parse::<u64>().ok().map(std::time::Duration::from_secs)
188 }
189 }
190
191 pub fn start_gorilla(
200 &self,
201 id: GorillaId,
202 manifest: GorillaManifest,
203 default_model: ModelConfig,
204 memory: Arc<MemorySubstrate>,
205 driver: Arc<dyn LlmDriver>,
206 ) -> PunchResult<()> {
207 if self.tasks.contains_key(&id) {
208 return Err(punch_types::PunchError::Gorilla(format!(
209 "gorilla {} is already running",
210 id
211 )));
212 }
213
214 let interval = Self::parse_schedule(&manifest.schedule).unwrap_or_else(|| {
215 warn!(
216 gorilla_id = %id,
217 schedule = %manifest.schedule,
218 "could not parse schedule, defaulting to 5m"
219 );
220 std::time::Duration::from_secs(300)
221 });
222
223 let semaphore = Arc::clone(&self.llm_semaphore);
224 let mut shutdown_rx = self.shutdown_rx.clone();
225 let gorilla_name = manifest.name.clone();
226
227 let handle = tokio::spawn(async move {
228 info!(
229 gorilla_id = %id,
230 name = %gorilla_name,
231 interval_secs = interval.as_secs(),
232 "gorilla background task started"
233 );
234
235 let mut tasks_completed: u64 = 0;
236 let mut error_count: u64 = 0;
237
238 loop {
239 tokio::select! {
241 _ = tokio::time::sleep(interval) => {},
242 _ = shutdown_rx.changed() => {
243 if *shutdown_rx.borrow() {
244 info!(gorilla_id = %id, "gorilla received shutdown signal");
245 break;
246 }
247 }
248 }
249
250 if *shutdown_rx.borrow() {
252 break;
253 }
254
255 let _permit = match semaphore.acquire().await {
257 Ok(permit) => permit,
258 Err(_) => {
259 warn!(gorilla_id = %id, "semaphore closed, stopping gorilla");
260 break;
261 }
262 };
263
264 match run_gorilla_tick(id, &manifest, &default_model, &memory, &driver).await {
265 Ok(result) => {
266 tasks_completed += 1;
267 info!(
268 gorilla_id = %id,
269 tasks_completed,
270 tokens = result.usage.total(),
271 "gorilla tick completed successfully"
272 );
273 }
274 Err(e) => {
275 error_count += 1;
276 error!(
277 gorilla_id = %id,
278 error = %e,
279 error_count,
280 "gorilla tick failed"
281 );
282 }
283 }
284 }
285
286 info!(
287 gorilla_id = %id,
288 tasks_completed,
289 "gorilla background task stopped"
290 );
291 });
292
293 self.tasks.insert(
294 id,
295 GorillaTask {
296 handle,
297 started_at: Utc::now(),
298 },
299 );
300
301 Ok(())
302 }
303
304 pub fn stop_gorilla(&self, id: &GorillaId) -> bool {
306 if let Some((_, task)) = self.tasks.remove(id) {
307 task.handle.abort();
308 info!(gorilla_id = %id, "gorilla task stopped");
309 true
310 } else {
311 false
312 }
313 }
314
315 pub fn is_running(&self, id: &GorillaId) -> bool {
317 self.tasks.contains_key(id)
318 }
319
320 pub fn list_running(&self) -> Vec<GorillaId> {
322 self.tasks.iter().map(|entry| *entry.key()).collect()
323 }
324
325 pub fn shutdown_all(&self) {
327 let ids: Vec<GorillaId> = self.tasks.iter().map(|e| *e.key()).collect();
328 for id in &ids {
329 if let Some((_, task)) = self.tasks.remove(id) {
330 task.handle.abort();
331 }
332 }
333 info!(count = ids.len(), "all gorilla tasks shut down");
334 }
335
336 pub fn running_count(&self) -> usize {
338 self.tasks.len()
339 }
340}
341
342impl Default for BackgroundExecutor {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Registers a never-completing task under `id`. It stands in for a real gorilla
    /// loop, so the registry methods can be tested without an LLM driver or memory store.
    fn insert_idle_task(executor: &BackgroundExecutor, id: GorillaId) {
        let handle = tokio::spawn(std::future::pending::<()>());
        executor.tasks.insert(
            id,
            GorillaTask {
                handle,
                started_at: Utc::now(),
            },
        );
    }

    #[test]
    fn parse_schedule_seconds() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 30s"),
            Some(Duration::from_secs(30))
        );
    }

    #[test]
    fn parse_schedule_minutes() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 5m"),
            Some(Duration::from_secs(300))
        );
    }

    #[test]
    fn parse_schedule_hours() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 1h"),
            Some(Duration::from_secs(3600))
        );
    }

    #[test]
    fn parse_schedule_days() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("every 1d"),
            Some(Duration::from_secs(86400))
        );
    }

    #[test]
    fn parse_schedule_invalid() {
        assert_eq!(BackgroundExecutor::parse_schedule("invalid"), None);
    }

    #[test]
    fn parse_schedule_raw_seconds() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("60"),
            Some(Duration::from_secs(60))
        );
    }

    #[test]
    fn parse_schedule_with_whitespace() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("  every 10s  "),
            Some(Duration::from_secs(10))
        );
    }

    #[test]
    fn parse_schedule_case_insensitive() {
        assert_eq!(
            BackgroundExecutor::parse_schedule("Every 2H"),
            Some(Duration::from_secs(7200))
        );
    }

    #[test]
    fn parse_schedule_empty_string() {
        assert_eq!(BackgroundExecutor::parse_schedule(""), None);
    }

    #[test]
    fn parse_schedule_just_prefix() {
        assert_eq!(BackgroundExecutor::parse_schedule("every "), None);
    }

    #[tokio::test]
    async fn start_and_stop_gorilla() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        insert_idle_task(&executor, id);

        assert_eq!(executor.running_count(), 1);
        assert!(executor.list_running().contains(&id));

        assert!(executor.stop_gorilla(&id));
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn shutdown_all_stops_everything() {
        let executor = BackgroundExecutor::new();
        for _ in 0..3 {
            insert_idle_task(&executor, GorillaId::new());
        }

        assert_eq!(executor.running_count(), 3);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn stop_nonexistent_gorilla_returns_false() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();
        assert!(!executor.stop_gorilla(&id));
    }

    #[test]
    fn default_creates_executor() {
        let executor = BackgroundExecutor::default();
        assert_eq!(executor.running_count(), 0);
        assert!(executor.list_running().is_empty());
    }

    #[tokio::test]
    async fn is_running_returns_correct_state() {
        let executor = BackgroundExecutor::new();
        let id = GorillaId::new();

        assert!(!executor.is_running(&id));
        insert_idle_task(&executor, id);
        assert!(executor.is_running(&id));

        executor.stop_gorilla(&id);
        assert!(!executor.is_running(&id));
    }

    #[tokio::test]
    async fn multiple_gorillas_tracked_independently() {
        let executor = BackgroundExecutor::new();
        let ids: Vec<GorillaId> = (0..5).map(|_| GorillaId::new()).collect();
        for &id in &ids {
            insert_idle_task(&executor, id);
        }

        assert_eq!(executor.running_count(), 5);

        // Stopping some gorillas must leave the others registered.
        executor.stop_gorilla(&ids[0]);
        executor.stop_gorilla(&ids[1]);
        assert_eq!(executor.running_count(), 3);
        for id in &ids[2..] {
            assert!(executor.is_running(id));
        }

        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[tokio::test]
    async fn with_shutdown_receives_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let executor = BackgroundExecutor::with_shutdown(tx.clone(), rx);

        // The executor's receiver must observe values sent on the caller's channel,
        // since that receiver is what every spawned loop clones and watches.
        assert!(!*executor.shutdown_rx.borrow());
        tx.send(true)
            .expect("executor holds a receiver, so the channel is open");
        assert!(*executor.shutdown_rx.borrow());

        insert_idle_task(&executor, GorillaId::new());
        assert_eq!(executor.running_count(), 1);
        executor.shutdown_all();
        assert_eq!(executor.running_count(), 0);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_default_model() {
        use punch_types::{ModelConfig, Provider};

        let manifest = GorillaManifest {
            name: "test-gorilla".to_string(),
            description: "A test gorilla".to_string(),
            schedule: "every 30s".to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: Some("Custom prompt".to_string()),
            model: None,
            capabilities: Vec::new(),
            weight_class: None,
        };

        let default_model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_model);
        assert_eq!(fighter.name, "test-gorilla");
        assert_eq!(fighter.model.model, "claude-sonnet-4-20250514");
        assert_eq!(fighter.system_prompt, "Custom prompt");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Middleweight);
    }

    #[test]
    fn fighter_manifest_from_gorilla_uses_gorilla_model_if_set() {
        use punch_types::{ModelConfig, Provider};

        let gorilla_model = ModelConfig {
            provider: Provider::OpenAI,
            model: "gpt-4o".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(8192),
            temperature: Some(0.5),
        };

        let manifest = GorillaManifest {
            name: "smart-gorilla".to_string(),
            description: "Uses its own model".to_string(),
            schedule: "every 1h".to_string(),
            moves_required: Vec::new(),
            settings_schema: None,
            dashboard_metrics: Vec::new(),
            system_prompt: None,
            model: Some(gorilla_model),
            capabilities: Vec::new(),
            weight_class: Some(punch_types::WeightClass::Heavyweight),
        };

        let default_model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".to_string(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        let fighter = fighter_manifest_from_gorilla(&manifest, &default_model);
        assert_eq!(fighter.model.model, "gpt-4o");
        assert_eq!(fighter.weight_class, punch_types::WeightClass::Heavyweight);
        assert_eq!(fighter.system_prompt, "Uses its own model");
    }

    #[tokio::test]
    async fn list_running_returns_all_ids() {
        let executor = BackgroundExecutor::new();
        let expected_ids: Vec<GorillaId> = (0..3).map(|_| GorillaId::new()).collect();
        for &id in &expected_ids {
            insert_idle_task(&executor, id);
        }

        let running = executor.list_running();
        assert_eq!(running.len(), 3);
        for id in &expected_ids {
            assert!(running.contains(id));
        }

        executor.shutdown_all();
    }
}