1use std::collections::HashMap;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Mutex;
18
19use bytes::Bytes;
20
21use crate::ack::{OpResult, WriteAck};
22use crate::batch::{DocOp, WriteBatch};
23use crate::error::SinkError;
24use crate::opensearch::buffered;
25use crate::read::{
26 CountOutcome, ReadOp, ReadOutcome, Reader, SearchOp, SearchOutcome, StreamingSearch,
27};
28use crate::sink::Sink;
29
30#[derive(Debug, Default)]
33pub struct MemorySink {
34 recorded: Mutex<Vec<WriteBatch>>,
35 docs: Mutex<HashMap<(String, String), Vec<u8>>>,
37 searches: Mutex<Vec<SearchOp>>,
40 auto_id: AtomicU64,
41}
42
43impl MemorySink {
44 #[must_use]
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 #[must_use]
56 pub fn recorded(&self) -> Vec<WriteBatch> {
57 self.recorded
58 .lock()
59 .unwrap_or_else(std::sync::PoisonError::into_inner)
60 .clone()
61 }
62
63 #[must_use]
66 pub fn recorded_searches(&self) -> Vec<SearchOp> {
67 self.searches
68 .lock()
69 .unwrap_or_else(std::sync::PoisonError::into_inner)
70 .clone()
71 }
72
73 fn ack_for(&self, batch: &WriteBatch) -> WriteAck {
75 let results = batch
76 .ops()
77 .iter()
78 .map(|op| match &op.doc {
79 DocOp::Index { id, .. } | DocOp::Create { id, .. } => {
80 let id = id.clone().unwrap_or_else(|| self.next_auto_id());
81 OpResult::new(id, 201, true)
82 }
83 DocOp::Update { id, .. } | DocOp::Delete { id, .. } => {
84 OpResult::new(id.clone(), 200, false)
85 }
86 })
87 .collect();
88 WriteAck::new(results)
89 }
90
91 fn next_auto_id(&self) -> String {
93 let n = self.auto_id.fetch_add(1, Ordering::SeqCst) + 1;
94 format!("auto-{n}")
95 }
96
97 fn store(&self, batch: &WriteBatch, ack: &WriteAck) {
100 let mut docs = self
101 .docs
102 .lock()
103 .unwrap_or_else(std::sync::PoisonError::into_inner);
104 for (op, result) in batch.ops().iter().zip(ack.results()) {
105 let index = op.target.index.as_str().to_owned();
106 match &op.doc {
107 DocOp::Index { body, .. } | DocOp::Create { body, .. } => {
108 docs.insert((index, result.id.clone()), body.to_vec());
109 }
110 DocOp::Update { id, body, .. } => {
111 let key = (index, id.clone());
112 let existing = docs
113 .get(&key)
114 .and_then(|b| serde_json::from_slice::<serde_json::Value>(b).ok());
115 if let Some(bytes) =
116 apply_update(existing, body).and_then(|m| serde_json::to_vec(&m).ok())
117 {
118 docs.insert(key, bytes);
119 }
120 }
121 DocOp::Delete { id, .. } => {
122 docs.remove(&(index, id.clone()));
123 }
124 }
125 }
126 }
127}
128
129impl Sink for MemorySink {
130 async fn write(&self, batch: WriteBatch) -> Result<WriteAck, SinkError> {
131 let ack = self.ack_for(&batch);
132 self.store(&batch, &ack);
133 self.recorded
134 .lock()
135 .unwrap_or_else(std::sync::PoisonError::into_inner)
136 .push(batch);
137 Ok(ack)
138 }
139}
140
141impl Reader for MemorySink {
142 async fn get(&self, op: ReadOp) -> Result<ReadOutcome, SinkError> {
143 let index = op.target.index.as_str().to_owned();
144 let doc = self
145 .docs
146 .lock()
147 .unwrap_or_else(std::sync::PoisonError::into_inner)
148 .get(&(index.clone(), op.id.clone()))
149 .cloned();
150 Ok(match doc {
153 Some(body) => ReadOutcome::found(200, envelope(&index, &op.id, &body, true)),
154 None => ReadOutcome::not_found(404, envelope(&index, &op.id, b"null", false)),
155 })
156 }
157
158 async fn search(&self, op: SearchOp) -> Result<SearchOutcome, SinkError> {
159 let index = op.target.index.as_str().to_owned();
164 let hits: Vec<serde_json::Value> = self
165 .docs
166 .lock()
167 .unwrap_or_else(std::sync::PoisonError::into_inner)
168 .iter()
169 .filter(|((idx, _), _)| idx == &index)
170 .map(|((idx, id), body)| {
171 let source: serde_json::Value =
172 serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
173 serde_json::json!({ "_index": idx, "_id": id, "_source": source })
174 })
175 .collect();
176 self.searches
177 .lock()
178 .unwrap_or_else(std::sync::PoisonError::into_inner)
179 .push(op);
180 let body = serde_json::json!({
181 "hits": { "total": { "value": hits.len() }, "hits": hits },
182 });
183 Ok(SearchOutcome::new(
184 200,
185 serde_json::to_vec(&body).unwrap_or_else(|_| b"{}".to_vec()),
186 ))
187 }
188
189 async fn search_stream(&self, op: SearchOp) -> Result<StreamingSearch, SinkError> {
190 let out = self.search(op).await?;
195 Ok(StreamingSearch {
196 status: out.status,
197 body: buffered(Bytes::from(out.body)),
198 pool_reuse: false,
199 })
200 }
201
202 async fn count(&self, op: SearchOp) -> Result<CountOutcome, SinkError> {
203 let index = op.target.index.as_str().to_owned();
206 let count = self
207 .docs
208 .lock()
209 .unwrap_or_else(std::sync::PoisonError::into_inner)
210 .keys()
211 .filter(|(idx, _)| idx == &index)
212 .count();
213 self.searches
214 .lock()
215 .unwrap_or_else(std::sync::PoisonError::into_inner)
216 .push(op);
217 Ok(CountOutcome::new(200, count as u64))
218 }
219}
220
221fn apply_update(existing: Option<serde_json::Value>, body: &[u8]) -> Option<serde_json::Value> {
225 let patch: serde_json::Value = serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
226 let Some(mut source) = existing else {
227 let doc_as_upsert = patch
228 .get("doc_as_upsert")
229 .and_then(serde_json::Value::as_bool)
230 == Some(true);
231 return patch
232 .get("upsert")
233 .or_else(|| doc_as_upsert.then(|| patch.get("doc")).flatten())
234 .cloned();
235 };
236 if let (Some(target), Some(doc)) = (
237 source.as_object_mut(),
238 patch.get("doc").and_then(serde_json::Value::as_object),
239 ) {
240 for (k, v) in doc {
241 target.insert(k.clone(), v.clone());
242 }
243 }
244 Some(source)
245}
246
247fn envelope(index: &str, id: &str, source: &[u8], found: bool) -> Vec<u8> {
249 let source: serde_json::Value =
250 serde_json::from_slice(source).unwrap_or(serde_json::Value::Null);
251 let doc = serde_json::json!({
252 "_index": index,
253 "_id": id,
254 "found": found,
255 "_source": source,
256 });
257 serde_json::to_vec(&doc).unwrap_or_else(|_| b"{\"found\":false}".to_vec())
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::WriteOp;
    use osproxy_core::{ClusterId, Epoch, IndexName, Target};

    /// Target shared by every test: cluster `c`, index `i`.
    fn target() -> Target {
        Target::new(ClusterId::from("c"), IndexName::from("i"))
    }

    /// Builds an index operation against [`target`] at epoch 1.
    fn index_doc(id: Option<&str>, routing: Option<&str>, body: &'static [u8]) -> WriteOp {
        WriteOp::new(
            target(),
            DocOp::Index {
                id: id.map(str::to_owned),
                routing: routing.map(str::to_owned),
                body: bytes::Bytes::from_static(body),
            },
            Epoch::new(1),
        )
    }

    /// Index operation with an empty-object body and no routing.
    fn index_op(id: Option<&str>) -> WriteOp {
        index_doc(id, None, b"{}")
    }

    #[tokio::test]
    async fn auto_ids_are_deterministic_and_increment() {
        let sink = MemorySink::new();
        let batch = WriteBatch::new().with(index_op(None)).with(index_op(None));
        let ack = sink.write(batch).await.unwrap();
        let results = ack.results();
        assert_eq!(results[0].id, "auto-1");
        assert_eq!(results[1].id, "auto-2");
    }

    #[tokio::test]
    async fn explicit_id_is_preserved() {
        let sink = MemorySink::new();
        let batch = WriteBatch::single(index_op(Some("p:7")));
        let ack = sink.write(batch).await.unwrap();
        assert_eq!(ack.results()[0].id, "p:7");
    }

    #[tokio::test]
    async fn written_document_is_readable_by_id() {
        let sink = MemorySink::new();
        let op = index_doc(Some("acme:7"), Some("acme"), br#"{"msg":"hi"}"#);
        sink.write(WriteBatch::single(op)).await.unwrap();

        let read = ReadOp::new(target(), "acme:7", Some("acme".to_owned()));
        let hit = sink.get(read).await.unwrap();
        assert!(hit.found);

        // The body is the response envelope, with the stored document under `_source`.
        let doc: serde_json::Value = serde_json::from_slice(&hit.body).unwrap();
        assert_eq!(doc["found"], true);
        assert_eq!(doc["_id"], "acme:7");
        assert_eq!(doc["_source"]["msg"], "hi");
    }

    #[tokio::test]
    async fn missing_document_is_a_not_found_outcome() {
        let sink = MemorySink::new();
        let read = ReadOp::new(target(), "absent", None);
        let miss = sink.get(read).await.unwrap();
        assert!(!miss.found);
        assert_eq!(miss.status, 404);
    }

    #[tokio::test]
    async fn search_returns_stored_docs_and_records_the_query() {
        let sink = MemorySink::new();
        let op = index_doc(Some("acme:7"), None, br#"{"_tenant":"acme","msg":"hi"}"#);
        sink.write(WriteBatch::single(op)).await.unwrap();

        let wrapped = br#"{"query":{"bool":{"filter":[{"term":{"_tenant":"acme"}}]}}}"#.to_vec();
        let query = SearchOp::new(target(), wrapped.clone());
        let out = sink.search(query).await.unwrap();
        assert_eq!(out.status, 200);

        let doc: serde_json::Value = serde_json::from_slice(&out.body).unwrap();
        assert_eq!(doc["hits"]["total"]["value"], 1);
        assert_eq!(doc["hits"]["hits"][0]["_source"]["msg"], "hi");

        // The query body must be recorded exactly as it was submitted.
        assert_eq!(sink.recorded_searches().len(), 1);
        assert_eq!(sink.recorded_searches()[0].body, wrapped);
    }

    #[tokio::test]
    async fn count_returns_the_number_of_stored_docs() {
        let sink = MemorySink::new();
        for id in ["acme:1", "acme:2"] {
            let batch = WriteBatch::single(index_op(Some(id)));
            sink.write(batch).await.unwrap();
        }

        let query = SearchOp::new(target(), b"{}".to_vec());
        let out = sink.count(query).await.unwrap();
        assert_eq!(out.status, 200);
        assert_eq!(out.count, 2);
    }

    #[tokio::test]
    async fn delete_removes_a_stored_document() {
        let sink = MemorySink::new();
        sink.write(WriteBatch::single(index_op(Some("acme:7"))))
            .await
            .unwrap();

        let delete = WriteOp::new(
            target(),
            DocOp::Delete {
                id: "acme:7".to_owned(),
                routing: None,
            },
            Epoch::new(1),
        );
        sink.write(WriteBatch::single(delete)).await.unwrap();

        let read = ReadOp::new(target(), "acme:7", None);
        let miss = sink.get(read).await.unwrap();
        assert!(!miss.found);
    }
}