Skip to main content

osproxy_sink/
memory.rs

1//! An in-memory [`Sink`] (and [`Reader`]) for tests and dry-run routing.
2//!
3//! Records every batch it receives and acknowledges each operation as a
4//! success, without any network. It also keeps the indexed documents by
5//! `(index, id)` so it can serve get-by-id [`Reader`] requests, which lets the
6//! full write→read round-trip be exercised in memory (the real `OpenSearchSink`
7//! is covered by a testcontainer round-trip). Not for production: it persists
8//! nothing.
9//
10// JUSTIFY(file-length): one cohesive in-memory double implementing the full
11// `Sink` + `Reader` surface (write/get/search/search_stream/count) over a single
12// shared store, plus its focused unit tests; splitting the trait impls from the
13// store they share would scatter one small test fixture across files.
14
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Mutex;
18
19use bytes::Bytes;
20
21use crate::ack::{OpResult, WriteAck};
22use crate::batch::{DocOp, WriteBatch};
23use crate::error::SinkError;
24use crate::opensearch::buffered;
25use crate::read::{
26    CountOutcome, ReadOp, ReadOutcome, Reader, SearchOp, SearchOutcome, StreamingSearch,
27};
28use crate::sink::Sink;
29
30/// A non-persistent [`Sink`]/[`Reader`] that records batches, stores indexed
31/// documents, and acknowledges success.
32#[derive(Debug, Default)]
33pub struct MemorySink {
34    recorded: Mutex<Vec<WriteBatch>>,
35    /// Indexed documents keyed by `(physical index, physical id)`.
36    docs: Mutex<HashMap<(String, String), Vec<u8>>>,
37    /// Search operations received, in arrival order (for test assertions on the
38    /// wrapped query the engine dispatched).
39    searches: Mutex<Vec<SearchOp>>,
40    auto_id: AtomicU64,
41}
42
43impl MemorySink {
44    /// Creates an empty recording sink.
45    #[must_use]
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// The batches recorded so far, in arrival order.
51    ///
52    /// Recovers a poisoned lock: the recording is inert data with no invariant
53    /// a panicking writer could tear, and a test asserting on it must not itself
54    /// panic on poisoning.
55    #[must_use]
56    pub fn recorded(&self) -> Vec<WriteBatch> {
57        self.recorded
58            .lock()
59            .unwrap_or_else(std::sync::PoisonError::into_inner)
60            .clone()
61    }
62
63    /// The search operations received so far, in arrival order. Recovers a
64    /// poisoned lock for the same reason as [`MemorySink::recorded`].
65    #[must_use]
66    pub fn recorded_searches(&self) -> Vec<SearchOp> {
67        self.searches
68            .lock()
69            .unwrap_or_else(std::sync::PoisonError::into_inner)
70            .clone()
71    }
72
73    /// Builds the success ack for a batch, assigning ids to auto-id operations.
74    fn ack_for(&self, batch: &WriteBatch) -> WriteAck {
75        let results = batch
76            .ops()
77            .iter()
78            .map(|op| match &op.doc {
79                DocOp::Index { id, .. } | DocOp::Create { id, .. } => {
80                    let id = id.clone().unwrap_or_else(|| self.next_auto_id());
81                    OpResult::new(id, 201, true)
82                }
83                DocOp::Update { id, .. } | DocOp::Delete { id, .. } => {
84                    OpResult::new(id.clone(), 200, false)
85                }
86            })
87            .collect();
88        WriteAck::new(results)
89    }
90
91    /// A deterministic id for an auto-id index op (`auto-1`, `auto-2`, …).
92    fn next_auto_id(&self) -> String {
93        let n = self.auto_id.fetch_add(1, Ordering::SeqCst) + 1;
94        format!("auto-{n}")
95    }
96
97    /// Applies a batch to the document store: index/create store, update merges,
98    /// delete removes. The ack supplies any auto-assigned id.
99    fn store(&self, batch: &WriteBatch, ack: &WriteAck) {
100        let mut docs = self
101            .docs
102            .lock()
103            .unwrap_or_else(std::sync::PoisonError::into_inner);
104        for (op, result) in batch.ops().iter().zip(ack.results()) {
105            let index = op.target.index.as_str().to_owned();
106            match &op.doc {
107                DocOp::Index { body, .. } | DocOp::Create { body, .. } => {
108                    docs.insert((index, result.id.clone()), body.to_vec());
109                }
110                DocOp::Update { id, body, .. } => {
111                    let key = (index, id.clone());
112                    let existing = docs
113                        .get(&key)
114                        .and_then(|b| serde_json::from_slice::<serde_json::Value>(b).ok());
115                    if let Some(bytes) =
116                        apply_update(existing, body).and_then(|m| serde_json::to_vec(&m).ok())
117                    {
118                        docs.insert(key, bytes);
119                    }
120                }
121                DocOp::Delete { id, .. } => {
122                    docs.remove(&(index, id.clone()));
123                }
124            }
125        }
126    }
127}
128
129impl Sink for MemorySink {
130    async fn write(&self, batch: WriteBatch) -> Result<WriteAck, SinkError> {
131        let ack = self.ack_for(&batch);
132        self.store(&batch, &ack);
133        self.recorded
134            .lock()
135            .unwrap_or_else(std::sync::PoisonError::into_inner)
136            .push(batch);
137        Ok(ack)
138    }
139}
140
141impl Reader for MemorySink {
142    async fn get(&self, op: ReadOp) -> Result<ReadOutcome, SinkError> {
143        let index = op.target.index.as_str().to_owned();
144        let doc = self
145            .docs
146            .lock()
147            .unwrap_or_else(std::sync::PoisonError::into_inner)
148            .get(&(index.clone(), op.id.clone()))
149            .cloned();
150        // Emulate the OpenSearch get-by-id envelope so the engine's response
151        // shaping is identical against the memory sink and a real cluster.
152        Ok(match doc {
153            Some(body) => ReadOutcome::found(200, envelope(&index, &op.id, &body, true)),
154            None => ReadOutcome::not_found(404, envelope(&index, &op.id, b"null", false)),
155        })
156    }
157
158    async fn search(&self, op: SearchOp) -> Result<SearchOutcome, SinkError> {
159        // A degenerate match-all: return every stored doc in the target index as
160        // a hit. It does NOT evaluate the DSL (real filtering/isolation is proven
161        // against a live cluster); it exists so the engine's query wrapping and
162        // hit-stripping can be exercised, and it records the wrapped query.
163        let index = op.target.index.as_str().to_owned();
164        let hits: Vec<serde_json::Value> = self
165            .docs
166            .lock()
167            .unwrap_or_else(std::sync::PoisonError::into_inner)
168            .iter()
169            .filter(|((idx, _), _)| idx == &index)
170            .map(|((idx, id), body)| {
171                let source: serde_json::Value =
172                    serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
173                serde_json::json!({ "_index": idx, "_id": id, "_source": source })
174            })
175            .collect();
176        self.searches
177            .lock()
178            .unwrap_or_else(std::sync::PoisonError::into_inner)
179            .push(op);
180        let body = serde_json::json!({
181            "hits": { "total": { "value": hits.len() }, "hits": hits },
182        });
183        Ok(SearchOutcome::new(
184            200,
185            serde_json::to_vec(&body).unwrap_or_else(|_| b"{}".to_vec()),
186        ))
187    }
188
189    async fn search_stream(&self, op: SearchOp) -> Result<StreamingSearch, SinkError> {
190        // Reuse the buffered match-all, then hand the bytes back as a (single-frame)
191        // stream so the engine's streaming hit-transform wiring can be exercised in
192        // memory (multi-frame resumability is covered by the scanner's fuzz test and
193        // the live RSS test).
194        let out = self.search(op).await?;
195        Ok(StreamingSearch {
196            status: out.status,
197            body: buffered(Bytes::from(out.body)),
198            pool_reuse: false,
199        })
200    }
201
202    async fn count(&self, op: SearchOp) -> Result<CountOutcome, SinkError> {
203        // Degenerate match-all count: every stored doc in the target index (the
204        // DSL is not evaluated; real filtering is proven against a live cluster).
205        let index = op.target.index.as_str().to_owned();
206        let count = self
207            .docs
208            .lock()
209            .unwrap_or_else(std::sync::PoisonError::into_inner)
210            .keys()
211            .filter(|(idx, _)| idx == &index)
212            .count();
213        self.searches
214            .lock()
215            .unwrap_or_else(std::sync::PoisonError::into_inner)
216            .push(op);
217        Ok(CountOutcome::new(200, count as u64))
218    }
219}
220
221/// Applies an `_update` body: a partial `doc` is shallow-merged into the
222/// existing source; when absent the `upsert` (or, with `doc_as_upsert`, the
223/// `doc`) becomes the new source. `None` if nothing to write; scripts no-op.
224fn apply_update(existing: Option<serde_json::Value>, body: &[u8]) -> Option<serde_json::Value> {
225    let patch: serde_json::Value = serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
226    let Some(mut source) = existing else {
227        let doc_as_upsert = patch
228            .get("doc_as_upsert")
229            .and_then(serde_json::Value::as_bool)
230            == Some(true);
231        return patch
232            .get("upsert")
233            .or_else(|| doc_as_upsert.then(|| patch.get("doc")).flatten())
234            .cloned();
235    };
236    if let (Some(target), Some(doc)) = (
237        source.as_object_mut(),
238        patch.get("doc").and_then(serde_json::Value::as_object),
239    ) {
240        for (k, v) in doc {
241            target.insert(k.clone(), v.clone());
242        }
243    }
244    Some(source)
245}
246
247/// Builds the OpenSearch get-by-id response envelope around a stored document.
248fn envelope(index: &str, id: &str, source: &[u8], found: bool) -> Vec<u8> {
249    let source: serde_json::Value =
250        serde_json::from_slice(source).unwrap_or(serde_json::Value::Null);
251    let doc = serde_json::json!({
252        "_index": index,
253        "_id": id,
254        "found": found,
255        "_source": source,
256    });
257    serde_json::to_vec(&doc).unwrap_or_else(|_| b"{\"found\":false}".to_vec())
258}
259
260#[cfg(test)]
261mod tests {
262    use super::*;
263    use crate::batch::WriteOp;
264    use osproxy_core::{ClusterId, Epoch, IndexName, Target};
265
266    fn index_op(id: Option<&str>) -> WriteOp {
267        WriteOp::new(
268            Target::new(ClusterId::from("c"), IndexName::from("i")),
269            DocOp::Index {
270                id: id.map(str::to_owned),
271                routing: None,
272                body: bytes::Bytes::from_static(b"{}"),
273            },
274            Epoch::new(1),
275        )
276    }
277
278    #[tokio::test]
279    async fn auto_ids_are_deterministic_and_increment() {
280        let sink = MemorySink::new();
281        let ack = sink
282            .write(WriteBatch::new().with(index_op(None)).with(index_op(None)))
283            .await
284            .unwrap();
285        assert_eq!(ack.results()[0].id, "auto-1");
286        assert_eq!(ack.results()[1].id, "auto-2");
287    }
288
289    #[tokio::test]
290    async fn explicit_id_is_preserved() {
291        let sink = MemorySink::new();
292        let ack = sink
293            .write(WriteBatch::single(index_op(Some("p:7"))))
294            .await
295            .unwrap();
296        assert_eq!(ack.results()[0].id, "p:7");
297    }
298
299    fn target() -> Target {
300        Target::new(ClusterId::from("c"), IndexName::from("i"))
301    }
302
303    #[tokio::test]
304    async fn written_document_is_readable_by_id() {
305        let sink = MemorySink::new();
306        let op = WriteOp::new(
307            target(),
308            DocOp::Index {
309                id: Some("acme:7".to_owned()),
310                routing: Some("acme".to_owned()),
311                body: bytes::Bytes::from_static(br#"{"msg":"hi"}"#),
312            },
313            Epoch::new(1),
314        );
315        sink.write(WriteBatch::single(op)).await.unwrap();
316
317        let hit = sink
318            .get(ReadOp::new(target(), "acme:7", Some("acme".to_owned())))
319            .await
320            .unwrap();
321        assert!(hit.found);
322        // The body is the OpenSearch get-by-id envelope around the stored doc.
323        let doc: serde_json::Value = serde_json::from_slice(&hit.body).unwrap();
324        assert_eq!(doc["found"], true);
325        assert_eq!(doc["_id"], "acme:7");
326        assert_eq!(doc["_source"]["msg"], "hi");
327    }
328
329    #[tokio::test]
330    async fn missing_document_is_a_not_found_outcome() {
331        let sink = MemorySink::new();
332        let miss = sink
333            .get(ReadOp::new(target(), "absent", None))
334            .await
335            .unwrap();
336        assert!(!miss.found);
337        assert_eq!(miss.status, 404);
338    }
339
340    #[tokio::test]
341    async fn search_returns_stored_docs_and_records_the_query() {
342        let sink = MemorySink::new();
343        sink.write(WriteBatch::single(WriteOp::new(
344            target(),
345            DocOp::Index {
346                id: Some("acme:7".to_owned()),
347                routing: None,
348                body: bytes::Bytes::from_static(br#"{"_tenant":"acme","msg":"hi"}"#),
349            },
350            Epoch::new(1),
351        )))
352        .await
353        .unwrap();
354
355        let wrapped = br#"{"query":{"bool":{"filter":[{"term":{"_tenant":"acme"}}]}}}"#.to_vec();
356        let out = sink
357            .search(SearchOp::new(target(), wrapped.clone()))
358            .await
359            .unwrap();
360        assert_eq!(out.status, 200);
361        let doc: serde_json::Value = serde_json::from_slice(&out.body).unwrap();
362        assert_eq!(doc["hits"]["total"]["value"], 1);
363        assert_eq!(doc["hits"]["hits"][0]["_source"]["msg"], "hi");
364        // The wrapped query the engine dispatched was recorded for assertions.
365        assert_eq!(sink.recorded_searches().len(), 1);
366        assert_eq!(sink.recorded_searches()[0].body, wrapped);
367    }
368
369    #[tokio::test]
370    async fn count_returns_the_number_of_stored_docs() {
371        let sink = MemorySink::new();
372        for id in ["acme:1", "acme:2"] {
373            sink.write(WriteBatch::single(WriteOp::new(
374                target(),
375                DocOp::Index {
376                    id: Some(id.to_owned()),
377                    routing: None,
378                    body: bytes::Bytes::from_static(b"{}"),
379                },
380                Epoch::new(1),
381            )))
382            .await
383            .unwrap();
384        }
385        let out = sink
386            .count(SearchOp::new(target(), b"{}".to_vec()))
387            .await
388            .unwrap();
389        assert_eq!(out.status, 200);
390        assert_eq!(out.count, 2);
391    }
392
393    #[tokio::test]
394    async fn delete_removes_a_stored_document() {
395        let sink = MemorySink::new();
396        sink.write(WriteBatch::single(WriteOp::new(
397            target(),
398            DocOp::Index {
399                id: Some("acme:7".to_owned()),
400                routing: None,
401                body: bytes::Bytes::from_static(b"{}"),
402            },
403            Epoch::new(1),
404        )))
405        .await
406        .unwrap();
407        sink.write(WriteBatch::single(WriteOp::new(
408            target(),
409            DocOp::Delete {
410                id: "acme:7".to_owned(),
411                routing: None,
412            },
413            Epoch::new(1),
414        )))
415        .await
416        .unwrap();
417        let miss = sink
418            .get(ReadOp::new(target(), "acme:7", None))
419            .await
420            .unwrap();
421        assert!(!miss.found);
422    }
423}