Skip to main content

ailake_query/
memory_decay.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
//! Periodic recency-decay job for `EpisodicMemorySchema` tables.
//!
//! Reads the `last_accessed_at` column from each data file (Timestamp in
//! nanoseconds or microseconds, or legacy Utf8), recomputes
//! `recency_weight = exp(-lambda * days_since_access)`, rewrites the column,
//! and commits a new Iceberg snapshot replacing the old files.
//!
//! Integrates with the existing `CompactionExecutor` infrastructure: it reads
//! and rewrites individual data files (not a merge), preserving HNSW indexes.
11
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14
15use tracing::{info, warn};
16
17use ailake_catalog::{
18    new_snapshot_id, CatalogProvider, NewSnapshot, SnapshotOperation, TableIdent,
19};
20use ailake_core::{AilakeError, AilakeResult, VectorStoragePolicy};
21use ailake_file::{AilakeFileReader, AilakeFileWriter};
22use ailake_store::Store;
23use ailake_vec::compute_centroid_and_radius;
24use arrow_array::{
25    Array, Float32Array, RecordBatch, TimestampMicrosecondArray, TimestampNanosecondArray,
26};
27use arrow_schema::{DataType, Field};
28
29const LAST_ACCESSED_COL: &str = "last_accessed_at";
30const RECENCY_WEIGHT_COL: &str = "recency_weight";
31
32/// Periodic job that updates `recency_weight` for all records in a table.
33///
34/// The weight decays exponentially with age:
35/// `recency_weight = exp(-lambda * days_since_last_access)`
36///
37/// Where `days_since_last_access` is computed from the `last_accessed_at`
38/// column (ISO 8601 string or Unix timestamp string in the record).
39///
40/// # Usage
41///
42/// ```ignore
43/// let job = MemoryDecayJob::new(catalog, store, policy, lambda: 0.1);
44/// let updated = job.run(&table).await?;
45/// println!("{updated} files updated");
46/// ```
47pub struct MemoryDecayJob {
48    catalog: Arc<dyn CatalogProvider>,
49    store: Arc<dyn Store>,
50    policy: VectorStoragePolicy,
51    /// Exponential decay rate. Higher lambda → faster decay.
52    /// Typical values: 0.05 (slow) to 0.5 (aggressive).
53    pub decay_lambda: f32,
54}
55
56impl MemoryDecayJob {
57    pub fn new(
58        catalog: Arc<dyn CatalogProvider>,
59        store: Arc<dyn Store>,
60        policy: VectorStoragePolicy,
61        decay_lambda: f32,
62    ) -> Self {
63        Self {
64            catalog,
65            store,
66            policy,
67            decay_lambda,
68        }
69    }
70
71    /// Run decay update across all data files in the table's current snapshot.
72    ///
73    /// Returns the number of files that were rewritten (files missing the
74    /// `last_accessed_at` column are skipped).
75    pub async fn run(&self, table: &TableIdent) -> AilakeResult<usize> {
76        let files = self.catalog.list_files(table, None).await?;
77        if files.is_empty() {
78            return Ok(0);
79        }
80
81        let today_day = current_day_since_epoch();
82        let mut new_entries = Vec::with_capacity(files.len());
83        let mut updated = 0usize;
84
85        for file_entry in &files {
86            let file_bytes = self.store.get(&file_entry.path).await?;
87            let reader =
88                AilakeFileReader::new(file_bytes, &self.policy.column_name, self.policy.dim);
89
90            if !reader.is_ailake_file() {
91                // Not an AI-Lake file — carry forward unchanged.
92                new_entries.push(file_entry.clone());
93                continue;
94            }
95
96            let (batch, embeddings) = match reader.read_parquet() {
97                Ok(pair) => pair,
98                Err(e) => {
99                    warn!(
100                        "ailake: MemoryDecayJob skipping {} — read error: {}",
101                        file_entry.path, e
102                    );
103                    new_entries.push(file_entry.clone());
104                    continue;
105                }
106            };
107
108            if batch.column_by_name(LAST_ACCESSED_COL).is_none() {
109                // Table doesn't have last_accessed_at — nothing to decay.
110                new_entries.push(file_entry.clone());
111                continue;
112            }
113
114            let updated_batch = apply_decay(&batch, today_day, self.decay_lambda)?;
115
116            // Rewrite file with updated recency_weight column.
117            let file_writer = AilakeFileWriter::new(self.policy.clone());
118            let new_bytes = file_writer.write(&updated_batch, &embeddings)?;
119            let new_size = new_bytes.len() as u64;
120            self.store.put(&file_entry.path, new_bytes.clone()).await?;
121
122            let centroid = compute_centroid_and_radius(&embeddings, self.policy.metric);
123            let new_reader =
124                AilakeFileReader::new(new_bytes, &self.policy.column_name, self.policy.dim);
125            let header = new_reader.read_header()?;
126            let ailk_start = new_reader.ailk_offset()?;
127
128            let new_entry = ailake_catalog::make_data_file_entry(
129                &file_entry.path,
130                updated_batch.num_rows() as u64,
131                new_size,
132                &centroid,
133                ailake_catalog::VectorIndexInfo {
134                    column: &self.policy.column_name,
135                    dim: self.policy.dim,
136                    hnsw_offset: ailk_start + header.hnsw_offset,
137                    hnsw_len: header.hnsw_len,
138                },
139            );
140            new_entries.push(new_entry);
141            updated += 1;
142        }
143
144        if updated == 0 {
145            info!(
146                "ailake: MemoryDecayJob — no files with last_accessed_at column; skipping commit"
147            );
148            return Ok(0);
149        }
150
151        let snap = NewSnapshot {
152            snapshot_id: new_snapshot_id(),
153            parent_snapshot_id: None,
154            files: new_entries,
155            operation: SnapshotOperation::Overwrite,
156            iceberg_schema: None,
157            extra_properties: std::collections::HashMap::new(),
158            bloom_filters: vec![],
159            equality_delete_files: vec![],
160        };
161        self.catalog.commit_snapshot(table, snap).await?;
162        info!(
163            "ailake: MemoryDecayJob — updated recency_weight in {} files (lambda={})",
164            updated, self.decay_lambda
165        );
166        Ok(updated)
167    }
168}
169
170/// Extract days-since-access for each row, supporting Timestamp(ns/us) and legacy Utf8.
171fn days_old_vec(col: &Arc<dyn Array>, today_day: i64) -> AilakeResult<Vec<f32>> {
172    if let Some(ts) = col.as_any().downcast_ref::<TimestampNanosecondArray>() {
173        return Ok((0..ts.len())
174            .map(|i| {
175                if !ts.is_valid(i) {
176                    return 0.0f32;
177                }
178                let day = ts.value(i) / (86_400 * 1_000_000_000i64);
179                (today_day - day).max(0) as f32
180            })
181            .collect());
182    }
183    if let Some(ts) = col.as_any().downcast_ref::<TimestampMicrosecondArray>() {
184        return Ok((0..ts.len())
185            .map(|i| {
186                if !ts.is_valid(i) {
187                    return 0.0f32;
188                }
189                let day = ts.value(i) / (86_400 * 1_000_000i64);
190                (today_day - day).max(0) as f32
191            })
192            .collect());
193    }
194    if let Some(sa) = col.as_any().downcast_ref::<arrow_array::StringArray>() {
195        return Ok((0..sa.len())
196            .map(|i| {
197                if !sa.is_valid(i) {
198                    return 0.0f32;
199                }
200                let access_day = parse_iso_date_days(sa.value(i)).unwrap_or(today_day);
201                (today_day - access_day).max(0) as f32
202            })
203            .collect());
204    }
205    Err(AilakeError::Catalog(
206        "last_accessed_at must be Timestamp(Nanosecond/Microsecond) or Utf8".into(),
207    ))
208}
209
210/// Rewrite the `recency_weight` column in `batch` based on `last_accessed_at`.
211fn apply_decay(batch: &RecordBatch, today_day: i64, lambda: f32) -> AilakeResult<RecordBatch> {
212    let col = batch
213        .column_by_name(LAST_ACCESSED_COL)
214        .ok_or_else(|| AilakeError::Catalog("last_accessed_at column not found".into()))?;
215
216    let days_old = days_old_vec(col, today_day)?;
217    let new_weights: Vec<f32> = days_old.into_iter().map(|d| (-lambda * d).exp()).collect();
218
219    let new_weight_array = Arc::new(Float32Array::from(new_weights));
220
221    // Rebuild RecordBatch replacing (or adding) the recency_weight column.
222    let old_schema = batch.schema();
223    let decay_field = Field::new(RECENCY_WEIGHT_COL, DataType::Float32, false);
224
225    let mut new_fields: Vec<arrow_schema::FieldRef> = old_schema.fields().iter().cloned().collect();
226    let mut new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
227        .map(|i| batch.column(i).clone())
228        .collect();
229
230    if let Some(pos) = old_schema
231        .fields()
232        .iter()
233        .position(|f| f.name() == RECENCY_WEIGHT_COL)
234    {
235        new_fields[pos] = Arc::new(decay_field);
236        new_columns[pos] = new_weight_array;
237    } else {
238        new_fields.push(Arc::new(decay_field));
239        new_columns.push(new_weight_array);
240    }
241
242    let new_schema = Arc::new(arrow_schema::Schema::new(new_fields));
243    RecordBatch::try_new(new_schema, new_columns).map_err(|e| AilakeError::Arrow(e.to_string()))
244}
245
/// Parse first 10 chars of an ISO 8601 string as YYYY-MM-DD and return
/// days since Unix epoch (1970-01-01). Returns None on parse failure.
///
/// Anything after the tenth byte (time of day, zone) is ignored. Parsing is
/// done on bytes, so non-ASCII input yields `None` instead of panicking on a
/// char boundary. The year, month and day must be plain ASCII digits and form
/// a real proleptic-Gregorian calendar date. The two separator bytes may be
/// any non-digit (so `2024/01/15` is still accepted), which rejects bare
/// ten-digit numbers that would otherwise be misread as dates.
fn parse_iso_date_days(s: &str) -> Option<i64> {
    /// Parse a run of ASCII digits; `None` if any byte is not `0`-`9`.
    fn digits(bytes: &[u8]) -> Option<i64> {
        bytes.iter().try_fold(0i64, |acc, b| {
            b.is_ascii_digit().then(|| acc * 10 + i64::from(b - b'0'))
        })
    }

    let bytes = s.as_bytes();
    if bytes.len() < 10 {
        return None;
    }
    if bytes[4].is_ascii_digit() || bytes[7].is_ascii_digit() {
        return None;
    }

    let y = digits(&bytes[0..4])?;
    let m = digits(&bytes[5..7])?;
    let d = digits(&bytes[8..10])?;

    if !(1..=12).contains(&m) {
        return None;
    }
    let is_leap = (y % 4 == 0 && y % 100 != 0) || y % 400 == 0;
    let days_in_month = match m {
        2 if is_leap => 29,
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        _ => 31,
    };
    if !(1..=days_in_month).contains(&d) {
        return None;
    }

    // Julian Day Number (Gregorian calendar formula)
    let a = (14 - m) / 12;
    let y2 = y + 4800 - a;
    let m2 = m + 12 * a - 3;
    let jdn = d + (153 * m2 + 2) / 5 + 365 * y2 + y2 / 4 - y2 / 100 + y2 / 400 - 32045;
    // Unix epoch = JDN 2440588
    Some(jdn - 2440588)
}
263
/// Whole days elapsed between the Unix epoch and the current system time.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn current_day_since_epoch() -> i64 {
    const SECS_PER_DAY: i64 = 86400;
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64 / SECS_PER_DAY,
        Err(_) => 0,
    }
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// 2024-01-15 expressed as days since the Unix epoch.
    const DAY_2024_01_15: i64 = 19737;
    /// 2024-01-05 expressed as days since the Unix epoch (10 days earlier).
    const DAY_2024_01_05: i64 = 19727;

    /// Read row 0 of the `recency_weight` column of `batch`.
    fn first_weight(batch: &RecordBatch) -> f32 {
        batch
            .column_by_name(RECENCY_WEIGHT_COL)
            .unwrap()
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap()
            .value(0)
    }

    #[test]
    fn parse_iso_date_unix_epoch() {
        assert_eq!(parse_iso_date_days("1970-01-01T00:00:00"), Some(0));
    }

    #[test]
    fn parse_iso_date_known_date() {
        // 2024-01-15 is 19737 days after 1970-01-01
        assert_eq!(parse_iso_date_days("2024-01-15"), Some(DAY_2024_01_15));
    }

    #[test]
    fn parse_iso_date_returns_none_on_short_string() {
        assert!(parse_iso_date_days("2024").is_none());
        assert!(parse_iso_date_days("").is_none());
    }

    #[test]
    fn apply_decay_updates_recency_weight() {
        use arrow_array::StringArray;
        use arrow_schema::{Field, Schema};

        let schema = Arc::new(Schema::new(vec![
            Field::new(LAST_ACCESSED_COL, DataType::Utf8, true),
            Field::new(RECENCY_WEIGHT_COL, DataType::Float32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["2024-01-05T00:00:00"])),
                Arc::new(Float32Array::from(vec![1.0f32])),
            ],
        )
        .unwrap();

        // 10 days old: exp(-0.1 * 10) = exp(-1) ≈ 0.368
        let result = apply_decay(&batch, DAY_2024_01_15, 0.1).unwrap();
        let w = first_weight(&result);
        let expected = (-0.1f32 * 10.0).exp();
        assert!((w - expected).abs() < 0.001, "expected {expected}, got {w}");
    }

    #[test]
    fn apply_decay_handles_timestamp_nanosecond() {
        use arrow_schema::{Field, Schema, TimeUnit};

        // 2024-01-05 00:00:00 UTC in nanoseconds
        let day_19727_ns: i64 = DAY_2024_01_05 * 86_400 * 1_000_000_000;

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                LAST_ACCESSED_COL,
                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
                true,
            ),
            Field::new(RECENCY_WEIGHT_COL, DataType::Float32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampNanosecondArray::from(vec![day_19727_ns]).with_timezone("UTC")),
                Arc::new(Float32Array::from(vec![1.0f32])),
            ],
        )
        .unwrap();

        let result = apply_decay(&batch, DAY_2024_01_15, 0.1).unwrap();
        let w = first_weight(&result);
        let expected = (-0.1f32 * 10.0).exp();
        assert!((w - expected).abs() < 0.001, "expected {expected}, got {w}");
    }

    #[test]
    fn apply_decay_handles_timestamp_microsecond() {
        use arrow_schema::{Field, Schema, TimeUnit};

        // 2024-01-05 00:00:00 UTC in microseconds
        let day_19727_us: i64 = DAY_2024_01_05 * 86_400 * 1_000_000;

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                LAST_ACCESSED_COL,
                DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
                true,
            ),
            Field::new(RECENCY_WEIGHT_COL, DataType::Float32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMicrosecondArray::from(vec![day_19727_us]).with_timezone("UTC")),
                Arc::new(Float32Array::from(vec![1.0f32])),
            ],
        )
        .unwrap();

        let result = apply_decay(&batch, DAY_2024_01_15, 0.1).unwrap();
        let w = first_weight(&result);
        let expected = (-0.1f32 * 10.0).exp();
        assert!((w - expected).abs() < 0.001, "expected {expected}, got {w}");
    }

    #[test]
    fn apply_decay_appends_column_when_missing() {
        use arrow_array::StringArray;
        use arrow_schema::{Field, Schema};

        let schema = Arc::new(Schema::new(vec![Field::new(
            LAST_ACCESSED_COL,
            DataType::Utf8,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec!["2024-01-05"]))],
        )
        .unwrap();

        let result = apply_decay(&batch, DAY_2024_01_15, 0.1).unwrap();
        assert_eq!(result.num_columns(), 2);
        let w = first_weight(&result);
        let expected = (-0.1f32 * 10.0).exp();
        assert!((w - expected).abs() < 0.001, "expected {expected}, got {w}");
    }

    #[test]
    fn apply_decay_treats_null_as_accessed_today() {
        use arrow_array::StringArray;
        use arrow_schema::{Field, Schema};

        let schema = Arc::new(Schema::new(vec![Field::new(
            LAST_ACCESSED_COL,
            DataType::Utf8,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec![None::<&str>]))],
        )
        .unwrap();

        let result = apply_decay(&batch, DAY_2024_01_15, 0.1).unwrap();
        let w = first_weight(&result);
        assert!((w - 1.0).abs() < f32::EPSILON, "expected 1.0, got {w}");
    }

    #[test]
    fn now_ns_is_recent() {
        // Sanity floor: 55 * 365 days after the epoch (mid-December 2024), in
        // nanoseconds. now_ns() must be later than that.
        let floor_ns: i64 = 55 * 365 * 86_400 * 1_000_000_000i64;
        let t = ailake_core::now_ns();
        assert!(
            t > floor_ns,
            "now_ns() returned suspiciously small value: {t}"
        );
    }
}
381            t > floor_2025_ns,
382            "now_ns() returned suspiciously small value: {t}"
383        );
384    }
385}