Skip to main content

swink_agent/
context_cache.rs

1//! Context caching abstractions for provider-side cache control.
2//!
3//! Provides [`CacheConfig`] for opt-in caching, [`CacheHint`] for annotating
4//! messages with write/read intent, and [`CacheState`] for tracking the
5//! cache lifecycle across turns.
6
7#![forbid(unsafe_code)]
8
9use std::time::Duration;
10
11use serde::{Deserialize, Serialize};
12
13// ─── CacheConfig ───────────────────────────────────────────────────────────
14
/// Configuration for provider-side context caching.
///
/// When attached to [`AgentOptions`](crate::AgentOptions), the framework
/// annotates cacheable prefix messages with [`CacheHint`] markers that
/// adapters translate to provider-specific cache control headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Time-to-live for the cached prefix on the provider side.
    ///
    /// Copied into every [`CacheHint::Write`] returned by
    /// [`CacheState::advance_turn`].
    pub ttl: Duration,
    /// Minimum token count for the cached prefix; caching is suppressed
    /// when the prefix is smaller than this threshold.
    ///
    /// [`CacheState`] does not read this field; the threshold is expected to
    /// be enforced by the caller that drives the state.
    pub min_tokens: usize,
    /// Length, in turns, of one write/read cycle.
    ///
    /// [`CacheState::advance_turn`] emits `Write` on the first turn of a
    /// cycle and `Read` on the remaining `cache_intervals - 1` turns, so a
    /// value of `3` yields Write → Read → Read → Write → …. Values of `0`
    /// and `1` both result in a `Write` on every turn.
    pub cache_intervals: usize,
}
30
31impl CacheConfig {
32    /// Create a new cache configuration.
33    pub const fn new(ttl: Duration, min_tokens: usize, cache_intervals: usize) -> Self {
34        Self {
35            ttl,
36            min_tokens,
37            cache_intervals,
38        }
39    }
40}
41
42// ─── CacheHint ─────────────────────────────────────────────────────────────
43
44/// Hint attached to messages indicating the desired cache action.
45///
46/// Adapters inspect this during message conversion to translate into
47/// provider-specific cache control (e.g., Anthropic's `cache_control` field).
48#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
49#[serde(tag = "action", rename_all = "snake_case")]
50pub enum CacheHint {
51    /// Write (or refresh) the cached prefix with the given TTL.
52    Write {
53        #[serde(with = "duration_secs")]
54        ttl: Duration,
55    },
56    /// Read from an existing cached prefix.
57    Read,
58}
59
60/// Serde helper: serialize/deserialize `Duration` as integer seconds.
61mod duration_secs {
62    use std::time::Duration;
63
64    use serde::{Deserialize, Deserializer, Serializer};
65
66    pub fn serialize<S: Serializer>(dur: &Duration, s: S) -> Result<S::Ok, S::Error> {
67        s.serialize_u64(dur.as_secs())
68    }
69
70    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
71        let secs = u64::deserialize(d)?;
72        Ok(Duration::from_secs(secs))
73    }
74}
75
76// ─── CacheState ────────────────────────────────────────────────────────────
77
/// Bookkeeping for the write/read cache cycle across turns.
///
/// Drive it with [`advance_turn`](Self::advance_turn) once per turn to obtain
/// that turn's [`CacheHint`]. After the adapter reports a cache miss, call
/// [`reset`](Self::reset) so that the following turn emits a `Write` again.
#[derive(Debug, Clone)]
pub struct CacheState {
    /// Turns consumed in the current cycle, counting the `Write` itself.
    /// `0` means no write is in effect (fresh or reset state).
    turns_since_write: usize,
    /// Number of tokens in the cached prefix (set after annotation).
    /// Zeroed by [`reset`](Self::reset).
    pub cached_prefix_len: usize,
}
89
90impl CacheState {
91    /// Create a new cache state (first turn will emit `Write`).
92    pub const fn new() -> Self {
93        Self {
94            turns_since_write: 0,
95            cached_prefix_len: 0,
96        }
97    }
98
99    /// Advance the turn counter and return the cache hint for this turn.
100    ///
101    /// - First turn (or after reset/refresh): returns `Write { ttl }`.
102    /// - Subsequent turns within `cache_intervals`: returns `Read`.
103    /// - After `cache_intervals` turns: returns `Write` (refresh).
104    pub const fn advance_turn(&mut self, config: &CacheConfig) -> CacheHint {
105        if self.turns_since_write == 0 {
106            // First turn or just after reset — write.
107            self.turns_since_write = 1;
108            CacheHint::Write { ttl: config.ttl }
109        } else if self.turns_since_write >= config.cache_intervals {
110            // Refresh cycle reached — write again.
111            self.turns_since_write = 1;
112            CacheHint::Write { ttl: config.ttl }
113        } else {
114            self.turns_since_write += 1;
115            CacheHint::Read
116        }
117    }
118
119    /// Reset the cache state, forcing the next turn to emit `Write`.
120    ///
121    /// Called when the adapter reports a provider cache miss.
122    pub const fn reset(&mut self) {
123        self.turns_since_write = 0;
124        self.cached_prefix_len = 0;
125    }
126}
127
128impl Default for CacheState {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL shared by every test configuration.
    const TTL: Duration = Duration::from_secs(600);

    /// Config with the shared TTL, a 4096-token minimum and the given cycle length.
    fn test_config(intervals: usize) -> CacheConfig {
        CacheConfig::new(TTL, 4096, intervals)
    }

    /// The `Write` hint every test configuration is expected to produce.
    fn write_hint() -> CacheHint {
        CacheHint::Write { ttl: TTL }
    }

    #[test]
    fn first_turn_emits_write() {
        let mut state = CacheState::new();
        let config = test_config(3);
        assert_eq!(state.advance_turn(&config), write_hint());
    }

    #[test]
    fn subsequent_turns_emit_read() {
        let mut state = CacheState::new();
        let config = test_config(3);
        assert_eq!(state.advance_turn(&config), write_hint()); // turn 1
        assert_eq!(state.advance_turn(&config), CacheHint::Read); // turn 2
        assert_eq!(state.advance_turn(&config), CacheHint::Read); // turn 3
    }

    #[test]
    fn refresh_after_cache_intervals() {
        let mut state = CacheState::new();
        let config = test_config(3);
        assert_eq!(state.advance_turn(&config), write_hint()); // turn 1
        assert_eq!(state.advance_turn(&config), CacheHint::Read); // turn 2
        assert_eq!(state.advance_turn(&config), CacheHint::Read); // turn 3
        // turn 4: the three-turn cycle is used up, so the cache is refreshed.
        assert_eq!(state.advance_turn(&config), write_hint());
    }

    #[test]
    fn cycle_repeats_after_refresh() {
        let mut state = CacheState::new();
        let config = test_config(2);
        // Two full cycles of Write → Read, then the start of a third.
        for _ in 0..2 {
            assert_eq!(state.advance_turn(&config), write_hint());
            assert_eq!(state.advance_turn(&config), CacheHint::Read);
        }
        assert_eq!(state.advance_turn(&config), write_hint());
    }

    #[test]
    fn interval_of_zero_or_one_writes_every_turn() {
        for intervals in [0, 1] {
            let mut state = CacheState::new();
            let config = test_config(intervals);
            for _ in 0..4 {
                assert_eq!(state.advance_turn(&config), write_hint());
            }
        }
    }

    #[test]
    fn reset_forces_write_on_next_turn() {
        let mut state = CacheState::new();
        let config = test_config(5);
        assert_eq!(state.advance_turn(&config), write_hint());
        assert_eq!(state.advance_turn(&config), CacheHint::Read);
        state.reset(); // adapter-reported cache miss
        assert_eq!(state.advance_turn(&config), write_hint());
    }

    #[test]
    fn cached_prefix_len_tracks_correctly() {
        let mut state = CacheState::new();
        assert_eq!(state.cached_prefix_len, 0);
        state.cached_prefix_len = 5;
        assert_eq!(state.cached_prefix_len, 5);
        state.reset();
        assert_eq!(state.cached_prefix_len, 0);
    }

    #[test]
    fn config_carries_min_tokens() {
        // `CacheState` itself doesn't enforce `min_tokens` — that check
        // happens in the turn pipeline. Verify the config carries it.
        let config = CacheConfig::new(Duration::from_secs(300), 8192, 2);
        assert_eq!(config.min_tokens, 8192);
    }

    #[test]
    fn serde_round_trip_write_hint() {
        let hint = write_hint();
        let json = serde_json::to_string(&hint).unwrap();
        // Pin the wire format adapters and persisted data depend on.
        assert_eq!(json, r#"{"action":"write","ttl":600}"#);
        let back: CacheHint = serde_json::from_str(&json).unwrap();
        assert_eq!(hint, back);
    }

    #[test]
    fn serde_round_trip_read_hint() {
        let hint = CacheHint::Read;
        let json = serde_json::to_string(&hint).unwrap();
        assert_eq!(json, r#"{"action":"read"}"#);
        let back: CacheHint = serde_json::from_str(&json).unwrap();
        assert_eq!(hint, back);
    }
}