Skip to main content

hirn_engine/operators/
temporal.rs

1//! Temporal expansion operator.
2//!
3//! Given input batches containing a `created_at_ms` column, retrieves
4//! additional memories within a ± time window around each timestamp.
5
6use arrow_array::cast::AsArray;
7use arrow_array::{Array, RecordBatch};
8use async_trait::async_trait;
9use futures::TryStreamExt;
10
11use hirn_core::error::{HirnError, HirnResult};
12use hirn_storage::store::ScanOptions;
13
14use super::{OpContext, Operator};
15
/// Operator that expands results with temporally adjacent memories.
///
/// The operator gathers every non-null `created_at_ms` value from its input
/// and issues a single scan of `dataset` covering the inclusive range
/// `[min - window_ms, max + window_ms]`. Rows that fall between two distant
/// input timestamps are therefore returned as well, and rows already present
/// in the input may be returned again; deduplication is left to the caller.
#[derive(Debug, Clone)]
pub struct TemporalExpand {
    /// Dataset to scan for temporal neighbours.
    pub dataset: String,
    /// Time window in milliseconds (applied in both directions).
    pub window_ms: i64,
}
26
27#[async_trait]
28impl Operator for TemporalExpand {
29    async fn execute(
30        &self,
31        input: Vec<RecordBatch>,
32        ctx: &OpContext,
33    ) -> HirnResult<Vec<RecordBatch>> {
34        let timestamps = extract_timestamps(&input)?;
35        if timestamps.is_empty() {
36            return Ok(input);
37        }
38
39        // Compute the overall [min - window, max + window] range.
40        let min_ts = timestamps.iter().copied().min().unwrap_or(0);
41        let max_ts = timestamps.iter().copied().max().unwrap_or(0);
42        let lo = min_ts.saturating_sub(self.window_ms);
43        let hi = max_ts.saturating_add(self.window_ms);
44
45        let filter = format!("created_at_ms >= {lo} AND created_at_ms <= {hi}");
46        let mut expanded = ctx
47            .store
48            .scan_stream(
49                &self.dataset,
50                ScanOptions {
51                    filter: Some(filter),
52                    exact_filter: None,
53                    columns: None,
54                    order_by: None,
55                    limit: None,
56                    offset: None,
57                },
58            )
59            .await
60            .map_err(|e| HirnError::storage(e))?;
61
62        // Merge input + expanded (caller can deduplicate later).
63        let mut out = input;
64        while let Some(batch) = expanded
65            .try_next()
66            .await
67            .map_err(|e| HirnError::storage(e))?
68        {
69            out.push(batch);
70        }
71        Ok(out)
72    }
73}
74
75/// Extract `created_at_ms` values from all input batches.
76fn extract_timestamps(batches: &[RecordBatch]) -> HirnResult<Vec<i64>> {
77    let mut ts = Vec::new();
78    for batch in batches {
79        if let Some(col) = batch.column_by_name("created_at_ms") {
80            let arr = col.as_primitive::<arrow_array::types::Int64Type>();
81            for i in 0..arr.len() {
82                if !arr.is_null(i) {
83                    ts.push(arr.value(i));
84                }
85            }
86        }
87    }
88    Ok(ts)
89}