Skip to main content

batty_cli/team/
context_management.rs

1//! Utilities for proactively tracking context pressure and preserving restart state.
2
3use std::path::{Path, PathBuf};
4use std::process::Command;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use anyhow::{Context, Result, bail};
8
9use crate::task::Task;
10use crate::team::checkpoint::{self, Checkpoint, RestartContext};
11
/// Usage percentage at or above which a graceful handoff is requested by default.
const DEFAULT_THRESHOLD_PCT: u8 = 80;
/// Context-window size, in estimated tokens, that usage percentages are measured against.
const DEFAULT_CONTEXT_LIMIT_TOKENS: usize = 128 * 1_000;
/// Maximum number of `git status --short` lines copied into a checkpoint summary.
const STATUS_LINE_LIMIT: usize = 20;
/// Maximum number of trailing `.batty_test_output` lines copied into a checkpoint summary.
const TEST_OUTPUT_LINE_LIMIT: usize = 50;
16
/// Action recommended by [`check_context_pressure`] once estimated usage crosses the threshold.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ContextAction {
    /// Request a graceful handoff of the current task.
    GracefulHandoff,
}
21
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct ContextPressure {
24    pub message_count: usize,
25    pub estimated_tokens: usize,
26    pub threshold_pct: u8,
27}
28
29impl Default for ContextPressure {
30    fn default() -> Self {
31        Self {
32            message_count: 0,
33            estimated_tokens: 0,
34            threshold_pct: DEFAULT_THRESHOLD_PCT,
35        }
36    }
37}
38
39impl ContextPressure {
40    pub fn new(message_count: usize, estimated_tokens: usize) -> Self {
41        Self {
42            message_count,
43            estimated_tokens,
44            ..Self::default()
45        }
46    }
47
48    fn usage_pct(&self) -> usize {
49        self.estimated_tokens.saturating_mul(100) / DEFAULT_CONTEXT_LIMIT_TOKENS.max(1)
50    }
51}
52
/// Estimates token usage from raw output size, assuming roughly four bytes per token.
///
/// A trailing partial group of bytes counts as a whole token, so `0` maps to `0` and
/// `1..=4` map to `1`.
pub fn estimate_token_usage(output_bytes: usize) -> usize {
    let whole_tokens = output_bytes / 4;
    let has_partial_token = output_bytes % 4 != 0;
    whole_tokens + usize::from(has_partial_token)
}
56
57pub fn check_context_pressure(pressure: &ContextPressure) -> Option<ContextAction> {
58    (pressure.usage_pct() >= pressure.threshold_pct as usize)
59        .then_some(ContextAction::GracefulHandoff)
60}
61
62pub fn create_checkpoint(worktree: &Path, task_id: u32) -> Result<Checkpoint> {
63    let role = worktree_role(worktree)?;
64    let project_root = project_root_from_worktree(worktree)?;
65    let checkpoint = Checkpoint {
66        role,
67        task_id,
68        task_title: format!("Task #{task_id}"),
69        task_description: build_state_summary(worktree, task_id),
70        branch: git_output(worktree, &["rev-parse", "--abbrev-ref", "HEAD"]),
71        last_commit: git_output(worktree, &["log", "-1", "--oneline"]),
72        test_summary: last_test_output(worktree),
73        timestamp: timestamp_now(),
74    };
75    checkpoint::write_checkpoint(&project_root, &checkpoint)?;
76    Ok(checkpoint)
77}
78
79pub fn stage_restart_context(
80    worktree: &Path,
81    role: &str,
82    task: &Task,
83    reason: &str,
84    restart_count: u32,
85    output_bytes: Option<u64>,
86) -> Result<RestartContext> {
87    let context = RestartContext {
88        role: role.to_string(),
89        task_id: task.id,
90        task_title: task.title.clone(),
91        task_description: task.description.clone(),
92        branch: task
93            .branch
94            .clone()
95            .or_else(|| git_output(worktree, &["rev-parse", "--abbrev-ref", "HEAD"])),
96        worktree_path: Some(worktree.display().to_string()),
97        restart_count,
98        reason: reason.to_string(),
99        output_bytes,
100        last_commit: git_output(worktree, &["rev-parse", "HEAD"]),
101        created_at_epoch_secs: Some(epoch_now_secs()),
102        handoff_consumed: false,
103    };
104    checkpoint::write_restart_context(worktree, &context)?;
105    Ok(context)
106}
107
108pub fn consume_restart_context(worktree: &Path) -> Result<Option<RestartContext>> {
109    let Some(mut context) = checkpoint::read_restart_context(worktree) else {
110        return Ok(None);
111    };
112    context.handoff_consumed = true;
113    checkpoint::write_restart_context(worktree, &context)?;
114    Ok(Some(context))
115}
116
117pub fn clear_restart_context(worktree: &Path) {
118    checkpoint::remove_restart_context(worktree);
119}
120
121pub fn clear_proactive_restart_context_if_stale(
122    worktree: &Path,
123    output_bytes: u64,
124    cooldown: Duration,
125) -> bool {
126    let Some(context) = checkpoint::read_restart_context(worktree) else {
127        return false;
128    };
129    if context.reason != "context_pressure" {
130        return false;
131    }
132
133    let commit_changed = git_output(worktree, &["rev-parse", "HEAD"]) != context.last_commit;
134    let cooldown_elapsed = context
135        .created_at_epoch_secs
136        .map(|started| epoch_now_secs().saturating_sub(started) >= cooldown.as_secs())
137        .unwrap_or(true);
138    if output_bytes > 0 || commit_changed || cooldown_elapsed {
139        checkpoint::remove_restart_context(worktree);
140        return true;
141    }
142    false
143}
144
145pub fn proactive_restart_is_suppressed(
146    worktree: &Path,
147    output_bytes: u64,
148    cooldown: Duration,
149) -> bool {
150    let Some(context) = checkpoint::read_restart_context(worktree) else {
151        return false;
152    };
153    if context.reason != "context_pressure" || !context.handoff_consumed {
154        return false;
155    }
156
157    let commit_changed = git_output(worktree, &["rev-parse", "HEAD"]) != context.last_commit;
158    let cooldown_elapsed = context
159        .created_at_epoch_secs
160        .map(|started| epoch_now_secs().saturating_sub(started) >= cooldown.as_secs())
161        .unwrap_or(true);
162    output_bytes == 0 && !commit_changed && !cooldown_elapsed
163}
164
165fn worktree_role(worktree: &Path) -> Result<String> {
166    worktree
167        .file_name()
168        .and_then(|name| name.to_str())
169        .filter(|name| !name.is_empty())
170        .map(ToOwned::to_owned)
171        .context("worktree path must end with the member role")
172}
173
174fn project_root_from_worktree(worktree: &Path) -> Result<PathBuf> {
175    let worktrees_dir = worktree
176        .parent()
177        .context("worktree path must be inside .batty/worktrees/<role>")?;
178    if worktrees_dir.file_name().and_then(|name| name.to_str()) != Some("worktrees") {
179        bail!("worktree path must be inside .batty/worktrees/<role>");
180    }
181
182    let batty_dir = worktrees_dir
183        .parent()
184        .context("worktree path must be inside .batty/worktrees/<role>")?;
185    if batty_dir.file_name().and_then(|name| name.to_str()) != Some(".batty") {
186        bail!("worktree path must be inside .batty/worktrees/<role>");
187    }
188
189    batty_dir
190        .parent()
191        .map(Path::to_path_buf)
192        .context("could not locate project root from worktree path")
193}
194
195fn build_state_summary(worktree: &Path, task_id: u32) -> String {
196    let mut sections = vec![format!(
197        "Resume task #{task_id} from the current worktree state at {}.",
198        worktree.display()
199    )];
200
201    if let Some(status) = git_status_summary(worktree) {
202        sections.push(format!("## Git Status\n\n{status}"));
203    }
204
205    if let Some(test_summary) = last_test_output(worktree) {
206        sections.push(format!("## Recent Test Output\n\n{test_summary}"));
207    }
208
209    sections.join("\n\n")
210}
211
212fn git_status_summary(worktree: &Path) -> Option<String> {
213    let status = git_output(worktree, &["status", "--short"])?;
214    let lines: Vec<&str> = status.lines().take(STATUS_LINE_LIMIT).collect();
215    if lines.is_empty() {
216        Some(String::from("clean working tree"))
217    } else {
218        Some(lines.join("\n"))
219    }
220}
221
/// Runs `git <args>` inside `worktree` and returns its stdout with surrounding whitespace trimmed.
///
/// Returns `None` in any of these cases:
/// - the worktree does not exist;
/// - `git` cannot be spawned;
/// - the command exits unsuccessfully;
/// - the trimmed output is empty or exactly `HEAD` (e.g. `rev-parse --abbrev-ref HEAD` on a
///   detached head).
fn git_output(worktree: &Path, args: &[&str]) -> Option<String> {
    if !worktree.exists() {
        return None;
    }
    let output = Command::new("git")
        .current_dir(worktree)
        .args(args)
        .output()
        .ok()
        .filter(|out| out.status.success())?;
    let text = String::from_utf8_lossy(&output.stdout).trim().to_owned();
    Some(text).filter(|value| !value.is_empty() && value.as_str() != "HEAD")
}
241
242fn last_test_output(worktree: &Path) -> Option<String> {
243    let output_path = worktree.join(".batty_test_output");
244    let content = std::fs::read_to_string(output_path).ok()?;
245    let lines: Vec<&str> = content.lines().collect();
246    let start = lines.len().saturating_sub(TEST_OUTPUT_LINE_LIMIT);
247    let summary = lines[start..].join("\n");
248    (!summary.is_empty()).then_some(summary)
249}
250
251fn timestamp_now() -> String {
252    let secs = epoch_now_secs();
253    let hours = (secs / 3600) % 24;
254    let minutes = (secs / 60) % 60;
255    let seconds = secs % 60;
256    let days_since_epoch = secs / 86400;
257    let (year, month, day) = epoch_days_to_date(days_since_epoch);
258    format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}Z")
259}
260
/// Current Unix time in whole seconds; a clock set before 1970 reads as `0`.
fn epoch_now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
267
/// Converts a count of days since 1970-01-01 into a proleptic Gregorian `(year, month, day)`.
///
/// This is Howard Hinnant's `civil_from_days` algorithm, specialized to non-negative day counts.
fn epoch_days_to_date(days: u64) -> (u64, u64, u64) {
    // Re-base onto 0000-03-01 so each leap day falls at the end of its (March-based) year.
    let shifted = days + 719_468;
    // 400-year eras each contain exactly 146_097 days.
    let era = shifted / 146_097;
    let day_of_era = shifted - era * 146_097;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March, ..., 11 = February).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + u64::from(month <= 2);
    (year, month, day)
}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `git <args>` in `dir`, panicking with the failing command line if it does not succeed.
    fn git(dir: &Path, args: &[&str]) {
        let status = Command::new("git")
            .args(args)
            .current_dir(dir)
            .status()
            .unwrap_or_else(|err| panic!("failed to spawn `git {}`: {err}", args.join(" ")));
        assert!(
            status.success(),
            "`git {}` failed in {}",
            args.join(" "),
            dir.display()
        );
    }

    /// Creates a repository at `path` on a `main` branch with a local committer identity.
    fn init_git_repo(path: &Path) {
        std::fs::create_dir_all(path).unwrap();
        git(path, &["init", "-q"]);
        // Pin the initial branch so assertions do not depend on the user's `init.defaultBranch`.
        git(path, &["symbolic-ref", "HEAD", "refs/heads/main"]);
        git(path, &["config", "user.name", "Batty Test"]);
        git(path, &["config", "user.email", "batty@example.com"]);
        // Keep commits working on machines where commit signing is enabled globally.
        git(path, &["config", "commit.gpgsign", "false"]);
    }

    /// Writes `content` to `rel_path` inside `repo` and commits it with `message`.
    fn commit_file(repo: &Path, rel_path: &str, content: &str, message: &str) {
        let file_path = repo.join(rel_path);
        if let Some(parent) = file_path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(&file_path, content).unwrap();
        git(repo, &["add", rel_path]);
        git(repo, &["commit", "-q", "-m", message]);
    }

    #[test]
    fn estimate_token_usage_uses_four_chars_per_token() {
        assert_eq!(estimate_token_usage(0), 0);
        assert_eq!(estimate_token_usage(4), 1);
        assert_eq!(estimate_token_usage(5), 2);
    }

    #[test]
    fn threshold_detection_returns_graceful_handoff_at_default_threshold() {
        let pressure = ContextPressure::new(24, 102_400);
        assert_eq!(
            check_context_pressure(&pressure),
            Some(ContextAction::GracefulHandoff)
        );
    }

    #[test]
    fn threshold_detection_stays_idle_below_limit() {
        let pressure = ContextPressure::new(8, 90_000);
        assert_eq!(check_context_pressure(&pressure), None);
    }

    #[test]
    fn threshold_detection_respects_custom_threshold() {
        // 64_000 tokens is exactly 50% of the default context limit.
        let at_threshold = ContextPressure {
            threshold_pct: 50,
            ..ContextPressure::new(3, 64_000)
        };
        assert_eq!(
            check_context_pressure(&at_threshold),
            Some(ContextAction::GracefulHandoff)
        );
        let above_threshold = ContextPressure {
            threshold_pct: 51,
            ..ContextPressure::new(3, 64_000)
        };
        assert_eq!(check_context_pressure(&above_threshold), None);
    }

    #[test]
    fn epoch_days_to_date_matches_known_dates() {
        assert_eq!(epoch_days_to_date(0), (1970, 1, 1));
        assert_eq!(epoch_days_to_date(11_016), (2000, 2, 29));
        assert_eq!(epoch_days_to_date(19_723), (2024, 1, 1));
    }

    #[test]
    fn timestamp_now_is_second_precision_utc() {
        let ts = timestamp_now();
        assert_eq!(ts.len(), 20, "unexpected timestamp: {ts}");
        assert_eq!(ts.as_bytes()[4], b'-');
        assert_eq!(ts.as_bytes()[10], b'T');
        assert!(ts.ends_with('Z'), "unexpected timestamp: {ts}");
    }

    #[test]
    fn project_root_is_three_levels_above_worktree() {
        let root = project_root_from_worktree(Path::new("/repo/.batty/worktrees/eng-1")).unwrap();
        assert_eq!(root, PathBuf::from("/repo"));
        assert!(project_root_from_worktree(Path::new("/repo/worktrees/eng-1")).is_err());
    }

    #[test]
    fn create_checkpoint_persists_restart_summary() {
        let tmp = tempfile::tempdir().unwrap();
        let project_root = tmp.path();
        let worktree = project_root
            .join(".batty")
            .join("worktrees")
            .join("eng-1-2");
        init_git_repo(&worktree);
        commit_file(&worktree, "src/lib.rs", "pub fn ready() {}\n", "initial");
        std::fs::write(
            worktree.join(".batty_test_output"),
            "test a ... ok\ntest b ... ok\n",
        )
        .unwrap();
        std::fs::write(worktree.join("notes.txt"), "pending change\n").unwrap();

        let checkpoint = create_checkpoint(&worktree, 453).unwrap();

        assert_eq!(checkpoint.role, "eng-1-2");
        assert_eq!(checkpoint.task_id, 453);
        assert_eq!(checkpoint.branch.as_deref(), Some("main"));
        assert!(
            checkpoint
                .last_commit
                .as_deref()
                .unwrap()
                .contains("initial")
        );
        assert!(
            checkpoint
                .task_description
                .contains("Resume task #453 from the current worktree state")
        );
        assert!(checkpoint.task_description.contains("notes.txt"));

        let stored = checkpoint::read_checkpoint(project_root, "eng-1-2").unwrap();
        assert!(stored.contains("Task #453"));
        assert!(stored.contains("Recent Test Output"));
    }

    #[test]
    fn create_checkpoint_requires_project_root_layout() {
        let tmp = tempfile::tempdir().unwrap();
        let err = create_checkpoint(tmp.path(), 1).unwrap_err();
        assert!(err.to_string().contains(".batty/worktrees"));
    }

    fn make_task(id: u32) -> Task {
        Task {
            id,
            title: format!("Task #{id}"),
            status: "in-progress".to_string(),
            priority: "high".to_string(),
            claimed_by: Some("eng-1-2".to_string()),
            claimed_at: None,
            claim_ttl_secs: None,
            claim_expires_at: None,
            last_progress_at: None,
            claim_warning_sent_at: None,
            claim_extensions: None,
            last_output_bytes: None,
            blocked: None,
            tags: vec![],
            depends_on: vec![],
            review_owner: None,
            blocked_on: None,
            worktree_path: Some("/tmp/worktree".to_string()),
            branch: Some(format!("eng-1-2/{id}")),
            commit: None,
            artifacts: vec![],
            next_action: Some("Finish the implementation.".to_string()),
            scheduled_for: None,
            cron_schedule: None,
            cron_last_run: None,
            completed: None,
            description: "Continue from the saved state.".to_string(),
            batty_config: None,
            source_path: PathBuf::from("/tmp/task.md"),
        }
    }

    #[test]
    fn staged_restart_context_round_trips_and_marks_consumed() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(77);

        let staged =
            stage_restart_context(&repo, "eng-1-2", &task, "context_pressure", 2, Some(256))
                .unwrap();
        assert_eq!(staged.reason, "context_pressure");
        assert_eq!(staged.restart_count, 2);
        assert_eq!(staged.output_bytes, Some(256));
        assert!(!staged.handoff_consumed);

        let consumed = consume_restart_context(&repo).unwrap().unwrap();
        assert!(consumed.handoff_consumed);
        assert_eq!(consumed.last_commit, staged.last_commit);
    }

    #[test]
    fn unconsumed_handoff_does_not_suppress_restart() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(87);

        stage_restart_context(&repo, "eng-1-2", &task, "context_pressure", 1, Some(0)).unwrap();

        assert!(!proactive_restart_is_suppressed(
            &repo,
            0,
            Duration::from_secs(30)
        ));
    }

    #[test]
    fn non_pressure_restart_context_is_ignored_by_guard() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(86);

        stage_restart_context(&repo, "eng-1-2", &task, "stall", 1, Some(0)).unwrap();
        consume_restart_context(&repo).unwrap();

        assert!(!proactive_restart_is_suppressed(
            &repo,
            0,
            Duration::from_secs(30)
        ));
        assert!(!clear_proactive_restart_context_if_stale(
            &repo,
            12,
            Duration::from_secs(30)
        ));
        assert!(checkpoint::read_restart_context(&repo).is_some());
    }

    #[test]
    fn proactive_restart_is_suppressed_until_progress_or_cooldown() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(88);

        stage_restart_context(&repo, "eng-1-2", &task, "context_pressure", 1, Some(0)).unwrap();
        let consumed = consume_restart_context(&repo).unwrap().unwrap();
        assert!(consumed.handoff_consumed);

        assert!(proactive_restart_is_suppressed(
            &repo,
            0,
            Duration::from_secs(30)
        ));
        assert!(!clear_proactive_restart_context_if_stale(
            &repo,
            0,
            Duration::from_secs(30)
        ));
    }

    #[test]
    fn proactive_restart_guard_clears_after_new_output() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(89);

        stage_restart_context(&repo, "eng-1-2", &task, "context_pressure", 1, Some(0)).unwrap();
        consume_restart_context(&repo).unwrap();

        assert!(clear_proactive_restart_context_if_stale(
            &repo,
            12,
            Duration::from_secs(30)
        ));
        assert!(checkpoint::read_restart_context(&repo).is_none());
    }

    #[test]
    fn proactive_restart_guard_clears_after_new_commit() {
        let tmp = tempfile::tempdir().unwrap();
        let repo = tmp.path().join("eng-1-2");
        init_git_repo(&repo);
        commit_file(&repo, "src/lib.rs", "pub fn ready() {}\n", "initial");
        let task = make_task(90);

        stage_restart_context(&repo, "eng-1-2", &task, "context_pressure", 1, Some(0)).unwrap();
        consume_restart_context(&repo).unwrap();
        commit_file(
            &repo,
            "src/lib.rs",
            "pub fn ready() { println!(\"ok\"); }\n",
            "follow-up",
        );

        assert!(clear_proactive_restart_context_if_stale(
            &repo,
            0,
            Duration::from_secs(30)
        ));
        assert!(!proactive_restart_is_suppressed(
            &repo,
            0,
            Duration::from_secs(30)
        ));
    }
}
528}