1use std::collections::HashSet;
33use std::sync::RwLock;
34
35use crate::error::PolicyError;
36
/// 32-byte key identifying one record for de-duplication.
///
/// Values are produced by [`compute_key`] and stored in a [`DedupSet`].
pub type DedupKey = [u8; 32];
39
40pub fn compute_key(
45 kind: &str,
46 conversation_id: &str,
47 role: &str,
48 scope: Option<&str>,
49 text: &str,
50) -> DedupKey {
51 let mut hasher = blake3::Hasher::new();
52 hasher.update(kind.as_bytes());
53 hasher.update(&[0]);
54 hasher.update(conversation_id.as_bytes());
55 hasher.update(&[0]);
56 hasher.update(role.as_bytes());
57 hasher.update(&[0]);
58 hasher.update(scope.unwrap_or("").as_bytes());
59 hasher.update(&[0]);
60 hasher.update(text.as_bytes());
61 *hasher.finalize().as_bytes()
62}
63
/// Set of [`DedupKey`]s that have already been recorded.
///
/// The keys live behind an [`RwLock`], so every method takes `&self`.
#[derive(Default)]
pub struct DedupSet {
    /// Keys recorded so far.
    seen: RwLock<HashSet<DedupKey>>,
}
70
71impl DedupSet {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn contains(&self, key: &DedupKey) -> Result<bool, PolicyError> {
82 let guard = self.seen.read().map_err(|_| PolicyError::Poisoned)?;
83 Ok(guard.contains(key))
84 }
85
86 pub fn insert(&self, key: DedupKey) -> Result<(), PolicyError> {
88 let mut guard = self.seen.write().map_err(|_| PolicyError::Poisoned)?;
89 guard.insert(key);
90 Ok(())
91 }
92
93 pub fn snapshot(&self) -> Result<Vec<String>, PolicyError> {
96 let guard = self.seen.read().map_err(|_| PolicyError::Poisoned)?;
97 let mut out: Vec<String> = guard.iter().map(hex_encode).collect();
98 out.sort();
99 Ok(out)
100 }
101
102 pub fn extend_from_snapshot(&self, hexes: &[String]) -> Result<(), PolicyError> {
105 let mut guard = self.seen.write().map_err(|_| PolicyError::Poisoned)?;
106 for hex in hexes {
107 match hex_decode(hex) {
108 Some(key) => {
109 guard.insert(key);
110 }
111 None => {
112 tracing::warn!(
113 target: "rig_memory_policy::dedup",
114 invalid = %hex,
115 "skipping malformed dedup snapshot entry",
116 );
117 }
118 }
119 }
120 Ok(())
121 }
122
123 #[cfg(test)]
125 pub(crate) fn len(&self) -> Result<usize, PolicyError> {
126 let guard = self.seen.read().map_err(|_| PolicyError::Poisoned)?;
127 Ok(guard.len())
128 }
129}
130
131impl std::fmt::Debug for DedupSet {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 let count = self.seen.read().map(|g| g.len()).unwrap_or_default();
134 f.debug_struct("DedupSet").field("entries", &count).finish()
135 }
136}
137
138fn hex_encode(key: &DedupKey) -> String {
139 let mut out = String::with_capacity(64);
140 for b in key {
141 out.push(nibble_to_hex(b >> 4));
142 out.push(nibble_to_hex(b & 0x0f));
143 }
144 out
145}
146
/// Public entry point for hex-encoding a [`DedupKey`].
///
/// Delegates to the private `hex_encode` helper and returns its result unchanged:
/// a 64-character lowercase hex string.
pub fn hex_encode_key(key: &DedupKey) -> String {
    hex_encode(key)
}
152
/// Maps the low four bits of `n` to the matching lowercase hex digit.
///
/// The high four bits are ignored, so any `u8` is a valid argument.
fn nibble_to_hex(n: u8) -> char {
    match n & 0x0f {
        decimal @ 0..=9 => char::from(b'0' + decimal),
        // Remaining values are 10..=15 and map onto 'a'..='f'.
        alpha => char::from(b'a' + (alpha - 10)),
    }
}
164
165fn hex_decode(hex: &str) -> Option<DedupKey> {
166 if hex.len() != 64 {
167 return None;
168 }
169 let mut out = [0u8; 32];
170 let bytes = hex.as_bytes();
171 for i in 0..32 {
172 let hi = nibble(bytes.get(i * 2).copied()?)?;
173 let lo = nibble(bytes.get(i * 2 + 1).copied()?)?;
174 if let Some(slot) = out.get_mut(i) {
175 *slot = (hi << 4) | lo;
176 }
177 }
178 Some(out)
179}
180
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Returns `None` for any other byte.
fn nibble(b: u8) -> Option<u8> {
    // Case-fold first so upper- and lowercase letters share one branch; digits and
    // non-letter bytes are unaffected by the fold.
    let folded = b.to_ascii_lowercase();
    if folded.is_ascii_digit() {
        Some(folded - b'0')
    } else if (b'a'..=b'f').contains(&folded) {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
189
// Unit tests for key derivation and for `DedupSet` snapshot handling.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;

    // Changing any single input field must change the key.
    #[test]
    fn distinct_inputs_produce_distinct_keys() {
        let a = compute_key("demoted_message", "c1", "user", None, "hello");
        let b = compute_key("demoted_message", "c1", "user", None, "hello world");
        let c = compute_key("compaction_summary", "c1", "user", None, "hello");
        let d = compute_key("demoted_message", "c2", "user", None, "hello");
        let e = compute_key("demoted_message", "c1", "assistant", None, "hello");
        let f = compute_key("demoted_message", "c1", "user", Some("s"), "hello");
        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(a, d);
        assert_ne!(a, e);
        assert_ne!(a, f);
    }

    // Key derivation is deterministic.
    #[test]
    fn identical_inputs_produce_identical_keys() {
        let a = compute_key("demoted_message", "c1", "user", None, "hello");
        let b = compute_key("demoted_message", "c1", "user", None, "hello");
        assert_eq!(a, b);
    }

    // The first two fields concatenate to "abc" in both calls; only the position of
    // the field boundary differs, and the separator byte must keep the keys apart.
    #[test]
    fn boundary_collision_resistance() {
        let a = compute_key("ab", "c", "user", None, "");
        let b = compute_key("a", "bc", "user", None, "");
        assert_ne!(a, b);
    }

    // A snapshot loaded into a fresh set restores every key that was inserted.
    #[test]
    fn set_round_trips_via_snapshot() {
        let set = DedupSet::new();
        let k1 = compute_key("kind", "conv", "user", None, "one");
        let k2 = compute_key("kind", "conv", "user", None, "two");
        set.insert(k1).unwrap();
        set.insert(k2).unwrap();
        let snap = set.snapshot().unwrap();
        assert_eq!(snap.len(), 2);

        let restored = DedupSet::new();
        restored.extend_from_snapshot(&snap).unwrap();
        assert!(restored.contains(&k1).unwrap());
        assert!(restored.contains(&k2).unwrap());
    }

    // A malformed entry is skipped without failing the load; valid entries are kept.
    #[test]
    fn malformed_snapshot_entries_are_skipped() {
        let set = DedupSet::new();
        let good = compute_key("k", "c", "user", None, "x");
        let bad = "not-hex".to_string();
        let snap = vec![hex_encode(&good), bad];
        set.extend_from_snapshot(&snap).unwrap();
        assert_eq!(set.len().unwrap(), 1);
        assert!(set.contains(&good).unwrap());
    }
}
251}