use std::sync::Arc;
use futures_util::StreamExt;
use sha2::{Digest, Sha256};
use taquba::object_store::{Error as ObjectStoreError, ObjectStore, path::Path};
use crate::error::{Error, Result};
/// Memo storage kept inside an [`ObjectStore`], namespaced under a caller-chosen prefix.
///
/// Clones share the same backing store (the store is held behind an [`Arc`]).
#[derive(Clone)]
pub struct MemoStore {
    /// Backing object store that holds memo entries and terminal markers.
    store: Arc<dyn ObjectStore>,
    /// Leading path text prepended to every object this store reads or writes.
    prefix: String,
}

impl std::fmt::Debug for MemoStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the prefix is printed; `finish_non_exhaustive` renders the omitted store as `..`.
        let mut builder = f.debug_struct("MemoStore");
        builder.field("prefix", &self.prefix);
        builder.finish_non_exhaustive()
    }
}
impl MemoStore {
    /// Creates a memo store that keeps all of its objects in `store` under `prefix`.
    ///
    /// Memo entries are written below `{prefix}/memos/` and terminal markers below
    /// `{prefix}/terminals/`.
    pub fn new(store: Arc<dyn ObjectStore>, prefix: impl Into<String>) -> Self {
        Self {
            store,
            prefix: prefix.into(),
        }
    }

    /// Fetches the memoized value stored for `key` under (`run_id`, `step_number`).
    ///
    /// Returns `Ok(None)` when the backing store reports the entry as not found.
    ///
    /// # Errors
    ///
    /// Any other object-store error, including a failure while reading the object body.
    async fn get(&self, run_id: &str, step_number: u32, key: &str) -> Result<Option<Vec<u8>>> {
        let path = self.memo_path(run_id, step_number, key);
        match self.store.get(&path).await {
            Ok(result) => {
                let bytes = result.bytes().await?;
                Ok(Some(bytes.to_vec()))
            }
            Err(ObjectStoreError::NotFound { .. }) => Ok(None),
            Err(err) => Err(Error::Store(err)),
        }
    }

    /// Stores `value` for `key` under (`run_id`, `step_number`), replacing any previous value.
    ///
    /// # Errors
    ///
    /// Propagates the object-store error if the write fails.
    async fn put(&self, run_id: &str, step_number: u32, key: &str, value: &[u8]) -> Result<()> {
        let path = self.memo_path(run_id, step_number, key);
        self.store.put(&path, value.to_vec().into()).await?;
        Ok(())
    }

    /// Returns a [`Memo`] handle bound to `run_id` and `step_number` that shares this store.
    pub fn new_memo(&self, run_id: impl Into<String>, step_number: u32) -> Memo {
        Memo::new(self.clone(), run_id, step_number)
    }

    /// Deletes every memo entry recorded for `run_id`, across all of its step numbers.
    ///
    /// Deletion is best-effort. Entries that are already gone (`NotFound`) are ignored, and any
    /// other per-entry delete failure is logged and skipped. Returns the number of entries this
    /// call actually deleted.
    ///
    /// A `run_id` with no non-empty `/`-separated segment (e.g. `""` or `"/"`) does not name a
    /// run directory. It would resolve to the shared `{prefix}/memos` root and wipe every run's
    /// memos, so it is refused: a warning is logged, nothing is deleted and `Ok(0)` is returned.
    ///
    /// NOTE(review): `run_id` is interpolated into the path unescaped. A run id containing `/`
    /// therefore spans several path segments, and clearing run `a` would also remove the memos
    /// of a run named `a/b`. Confirm run ids never contain `/`.
    ///
    /// # Errors
    ///
    /// Returns an error if listing the run's entries fails.
    pub async fn clear_memos_for_run(&self, run_id: &str) -> Result<usize> {
        // Object-store paths drop empty segments. If `run_id` is made only of `/` characters (or
        // is empty), the run component vanishes and the list prefix becomes the memos root.
        if run_id.split('/').all(str::is_empty) {
            tracing::warn!(
                run_id = %run_id,
                "refusing to clear memos: run id does not name a single run",
            );
            return Ok(0);
        }
        let prefix = self.memos_run_prefix(run_id);
        let mut stream = self.store.list(Some(&prefix));
        let mut deleted = 0usize;
        while let Some(item) = stream.next().await {
            let meta = item.map_err(Error::Store)?;
            match self.store.delete(&meta.location).await {
                Ok(()) => deleted += 1,
                // Already removed (e.g. by a concurrent clear): nothing left to do.
                Err(ObjectStoreError::NotFound { .. }) => {}
                Err(err) => {
                    tracing::warn!(
                        run_id = %run_id,
                        path = %meta.location,
                        error = %err,
                        "failed to delete memo entry",
                    );
                }
            }
        }
        Ok(deleted)
    }

    /// Writes an empty marker object recording `run_id` together with `terminal_at_ms`.
    ///
    /// # Errors
    ///
    /// Propagates the object-store error if the write fails.
    pub async fn write_terminal_marker(&self, run_id: &str, terminal_at_ms: u64) -> Result<()> {
        let path = self.terminal_marker_path(run_id, terminal_at_ms);
        self.store.put(&path, Vec::new().into()).await?;
        Ok(())
    }

    /// Lists every terminal marker currently stored under `{prefix}/terminals`.
    ///
    /// Objects whose name cannot be parsed as a marker are logged and skipped. The order of
    /// the result is whatever order the backing store lists objects in.
    ///
    /// # Errors
    ///
    /// Returns an error if listing fails.
    pub async fn list_terminal_markers(&self) -> Result<Vec<TerminalMarker>> {
        let prefix = self.terminals_prefix();
        let mut stream = self.store.list(Some(&prefix));
        let mut out = Vec::new();
        while let Some(item) = stream.next().await {
            let meta = item.map_err(Error::Store)?;
            let Some(name) = meta.location.filename() else {
                continue;
            };
            match parse_terminal_marker_name(name) {
                Some((terminal_at_ms, run_id)) => out.push(TerminalMarker {
                    run_id,
                    terminal_at_ms,
                }),
                None => {
                    tracing::warn!(
                        path = %meta.location,
                        "unparseable terminal marker; skipping",
                    );
                }
            }
        }
        Ok(out)
    }

    /// Removes `marker`'s object. A marker that is already missing counts as success, so the
    /// call can safely be repeated.
    ///
    /// # Errors
    ///
    /// Returns any object-store error other than `NotFound`.
    pub async fn delete_terminal_marker(&self, marker: &TerminalMarker) -> Result<()> {
        let path = self.terminal_marker_path(&marker.run_id, marker.terminal_at_ms);
        match self.store.delete(&path).await {
            Ok(()) | Err(ObjectStoreError::NotFound { .. }) => Ok(()),
            Err(err) => Err(Error::Store(err)),
        }
    }

    /// `{prefix}/memos/{run_id}/{step_number}/{sha256(key)}`.
    ///
    /// The user key is hashed so that arbitrary keys (slashes, very long strings) always map
    /// to a single fixed-length path segment.
    fn memo_path(&self, run_id: &str, step_number: u32, key: &str) -> Path {
        let key_hash = hex_sha256(key.as_bytes());
        Path::from(format!(
            "{}/memos/{}/{}/{}",
            self.prefix, run_id, step_number, key_hash
        ))
    }

    /// `{prefix}/memos/{run_id}`: the listing prefix for all of one run's memo entries.
    fn memos_run_prefix(&self, run_id: &str) -> Path {
        Path::from(format!("{}/memos/{}", self.prefix, run_id))
    }

    /// `{prefix}/terminals/{terminal_at_ms:020}_{run_id}`.
    ///
    /// The timestamp is zero-padded to 20 digits, which is enough for any `u64`;
    /// `parse_terminal_marker_name` relies on this fixed width.
    fn terminal_marker_path(&self, run_id: &str, terminal_at_ms: u64) -> Path {
        Path::from(format!(
            "{}/terminals/{:020}_{}",
            self.prefix, terminal_at_ms, run_id
        ))
    }

    /// `{prefix}/terminals`: the listing prefix for all terminal markers.
    fn terminals_prefix(&self) -> Path {
        Path::from(format!("{}/terminals", self.prefix))
    }
}
/// A terminal-marker entry as written by `MemoStore::write_terminal_marker` and reported by
/// `MemoStore::list_terminal_markers`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TerminalMarker {
    /// Run the marker was written for.
    pub run_id: String,
    /// Timestamp supplied by the writer. It is presumably milliseconds since the Unix epoch
    /// (judging by the name); TODO confirm with callers.
    pub terminal_at_ms: u64,
}
/// Handle for reading and writing memo entries of one (`run_id`, `step_number`) pair.
///
/// Clones refer to the same backing store as the original.
#[derive(Clone)]
pub struct Memo {
    /// Store the entries are read from and written to.
    store: MemoStore,
    /// Run the entries belong to.
    run_id: String,
    /// Step within the run the entries belong to.
    step_number: u32,
}

impl Memo {
    /// Builds a handle; used by `MemoStore::new_memo`.
    fn new(store: MemoStore, run_id: impl Into<String>, step_number: u32) -> Self {
        let run_id: String = run_id.into();
        Self {
            store,
            run_id,
            step_number,
        }
    }

    /// Run id this handle is scoped to.
    pub fn run_id(&self) -> &str {
        self.run_id.as_str()
    }

    /// Step number this handle is scoped to.
    pub fn step_number(&self) -> u32 {
        let Self { step_number, .. } = self;
        *step_number
    }

    /// Looks up the memoized value for `key`, returning `Ok(None)` if nothing was stored.
    ///
    /// # Errors
    ///
    /// Returns the underlying store error if the lookup fails.
    pub async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
        let Self {
            store,
            run_id,
            step_number,
        } = self;
        store.get(run_id, *step_number, key).await
    }

    /// Records `value` for `key`, replacing any previously stored value.
    ///
    /// # Errors
    ///
    /// Returns the underlying store error if the write fails.
    pub async fn put(&self, key: &str, value: &[u8]) -> Result<()> {
        let Self {
            store,
            run_id,
            step_number,
        } = self;
        store.put(run_id, *step_number, key, value).await
    }
}

impl std::fmt::Debug for Memo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The store handle is left out; `finish_non_exhaustive` renders it as `..`.
        let mut out = f.debug_struct("Memo");
        out.field("run_id", &self.run_id);
        out.field("step_number", &self.step_number);
        out.finish_non_exhaustive()
    }
}
/// Lower-case hexadecimal SHA-256 digest of `bytes` (always 64 characters).
fn hex_sha256(bytes: &[u8]) -> String {
    use std::fmt::Write;
    Sha256::digest(bytes)
        .iter()
        .fold(String::with_capacity(64), |mut hex, byte| {
            // Formatting into a `String` cannot fail, so the `fmt::Result` is discarded.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Parses a terminal-marker object name of the form `{terminal_at_ms:020}_{run_id}`.
///
/// `name` is the last path segment as reported by the object store listing. That text is in
/// object_store's percent-encoded segment form, e.g. a run id `a~b` is stored as `a%7Eb`. The
/// run id is therefore percent-decoded here, so it matches the id originally written. The
/// decoded id also rebuilds the same path when the marker is deleted again.
///
/// Returns `None` when:
/// - the name is shorter than the 20-character timestamp;
/// - the timestamp is not a `u64`;
/// - the `_` separator is missing;
/// - the run id contains a malformed `%` escape or decodes to invalid UTF-8.
fn parse_terminal_marker_name(name: &str) -> Option<(u64, String)> {
    let (ts_str, rest) = name.split_at_checked(20)?;
    let ts: u64 = ts_str.parse().ok()?;
    let encoded_run_id = rest.strip_prefix('_')?;
    let run_id = percent_decode_segment(encoded_run_id)?;
    Some((ts, run_id))
}

/// Reverses `%XX` percent-encoding of a single path segment.
///
/// Returns `None` if a `%` is not followed by two hex digits, or if the decoded bytes are not
/// valid UTF-8.
fn percent_decode_segment(segment: &str) -> Option<String> {
    let bytes = segment.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let hi = hex_digit_value(*bytes.get(i + 1)?)?;
            let lo = hex_digit_value(*bytes.get(i + 2)?)?;
            decoded.push((hi << 4) | lo);
            i += 3;
        } else {
            decoded.push(bytes[i]);
            i += 1;
        }
    }
    String::from_utf8(decoded).ok()
}

/// Value of one ASCII hex digit (either case), or `None` for any other byte.
fn hex_digit_value(byte: u8) -> Option<u8> {
    char::from(byte)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    use taquba::object_store::memory::InMemory;

    /// A memo store over a brand-new, empty in-memory backend.
    fn fresh_store() -> MemoStore {
        MemoStore::new(Arc::new(InMemory::new()), "memo")
    }

    /// A memo handle for `run-1` / step 0 over a fresh backend.
    fn make_memo() -> Memo {
        fresh_store().new_memo("run-1", 0)
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_key() {
        let found = make_memo().get("missing").await.unwrap();
        assert_eq!(found, None);
    }

    #[tokio::test]
    async fn put_then_get_round_trips() {
        let memo = make_memo();
        memo.put("k", b"hello").await.unwrap();
        let found = memo.get("k").await.unwrap();
        assert_eq!(found, Some(b"hello".to_vec()));
    }

    #[tokio::test]
    async fn put_overwrites_prior_value() {
        let memo = make_memo();
        let values: [&[u8]; 2] = [b"first", b"second"];
        for value in values {
            memo.put("k", value).await.unwrap();
        }
        assert_eq!(memo.get("k").await.unwrap(), Some(b"second".to_vec()));
    }

    #[tokio::test]
    async fn run_id_namespaces_are_isolated() {
        let store = fresh_store();
        let (in_run_a, in_run_b) = (store.new_memo("run-a", 0), store.new_memo("run-b", 0));
        in_run_a.put("k", b"value-a").await.unwrap();
        in_run_b.put("k", b"value-b").await.unwrap();
        assert_eq!(in_run_a.get("k").await.unwrap(), Some(b"value-a".to_vec()));
        assert_eq!(in_run_b.get("k").await.unwrap(), Some(b"value-b".to_vec()));
    }

    #[tokio::test]
    async fn step_number_namespaces_are_isolated() {
        let store = fresh_store();
        let cases = [
            (store.new_memo("run-1", 0), b"step-0"),
            (store.new_memo("run-1", 1), b"step-1"),
        ];
        for (memo, value) in &cases {
            memo.put("k", *value).await.unwrap();
        }
        for (memo, value) in &cases {
            assert_eq!(memo.get("k").await.unwrap(), Some(value.to_vec()));
        }
    }

    #[tokio::test]
    async fn distinct_user_keys_map_to_distinct_entries() {
        let memo = make_memo();
        let entries = [("k1", b"one"), ("k2", b"two")];
        for (key, value) in entries {
            memo.put(key, value).await.unwrap();
        }
        for (key, value) in entries {
            assert_eq!(memo.get(key).await.unwrap(), Some(value.to_vec()));
        }
    }

    #[tokio::test]
    async fn awkward_user_keys_round_trip() {
        let memo = make_memo();
        let long_key = "a".repeat(10_000);
        let keys: [&str; 5] = ["", "with/slash", "with spaces", "üñÃçødé", &long_key];
        for (index, key) in keys.into_iter().enumerate() {
            let expected = format!("v{index}").into_bytes();
            memo.put(key, &expected).await.unwrap();
            assert_eq!(memo.get(key).await.unwrap(), Some(expected));
        }
    }

    #[tokio::test]
    async fn empty_value_round_trips() {
        let memo = make_memo();
        memo.put("k", b"").await.unwrap();
        let found = memo.get("k").await.unwrap();
        assert_eq!(found, Some(Vec::new()));
    }

    #[tokio::test]
    async fn instances_sharing_a_backing_store_see_the_same_entries() {
        let shared: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        let writer = MemoStore::new(Arc::clone(&shared), "memo").new_memo("run-1", 0);
        let reader = MemoStore::new(shared, "memo").new_memo("run-1", 0);
        writer.put("k", b"shared").await.unwrap();
        let seen = reader.get("k").await.unwrap();
        assert_eq!(seen, Some(b"shared".to_vec()));
    }

    #[tokio::test]
    async fn clear_memos_for_run_removes_only_that_runs_entries() {
        let store = fresh_store();
        let run_a_step_0 = store.new_memo("run-a", 0);
        let run_a_step_1 = store.new_memo("run-a", 1);
        let run_b_step_0 = store.new_memo("run-b", 0);
        run_a_step_0.put("k", b"a-0").await.unwrap();
        run_a_step_1.put("k", b"a-1").await.unwrap();
        run_b_step_0.put("k", b"b-0").await.unwrap();

        let removed = store.clear_memos_for_run("run-a").await.unwrap();

        assert_eq!(removed, 2);
        assert_eq!(run_a_step_0.get("k").await.unwrap(), None);
        assert_eq!(run_a_step_1.get("k").await.unwrap(), None);
        assert_eq!(run_b_step_0.get("k").await.unwrap(), Some(b"b-0".to_vec()));
    }

    #[tokio::test]
    async fn clear_memos_for_run_returns_zero_when_nothing_to_delete() {
        let removed = fresh_store()
            .clear_memos_for_run("run-with-no-memos")
            .await
            .unwrap();
        assert_eq!(removed, 0);
    }

    #[tokio::test]
    async fn clear_memos_for_run_does_not_match_run_id_as_prefix() {
        let store = fresh_store();
        store.new_memo("run", 0).put("k", b"short").await.unwrap();
        store
            .new_memo("run-suffix", 0)
            .put("k", b"long")
            .await
            .unwrap();

        let removed = store.clear_memos_for_run("run").await.unwrap();
        assert_eq!(removed, 1);

        let survivor = store.new_memo("run-suffix", 0).get("k").await.unwrap();
        assert_eq!(survivor, Some(b"long".to_vec()));
    }

    #[tokio::test]
    async fn write_terminal_marker_then_list_returns_it() {
        let store = fresh_store();
        store
            .write_terminal_marker("run-1", 1_700_000_000_000)
            .await
            .unwrap();
        let listed = store.list_terminal_markers().await.unwrap();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.run_id, "run-1");
        assert_eq!(only.terminal_at_ms, 1_700_000_000_000);
    }

    #[tokio::test]
    async fn list_terminal_markers_is_empty_when_none_written() {
        let listed = fresh_store().list_terminal_markers().await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn list_terminal_markers_returns_all() {
        let store = fresh_store();
        let written = [("run-a", 1_000), ("run-b", 2_000), ("run-c", 3_000)];
        for (run_id, at_ms) in written {
            store.write_terminal_marker(run_id, at_ms).await.unwrap();
        }

        let mut listed = store.list_terminal_markers().await.unwrap();
        listed.sort_by_key(|marker| marker.terminal_at_ms);

        let expected: Vec<TerminalMarker> = written
            .iter()
            .map(|&(run_id, terminal_at_ms)| TerminalMarker {
                run_id: run_id.into(),
                terminal_at_ms,
            })
            .collect();
        assert_eq!(listed, expected);
    }

    #[tokio::test]
    async fn delete_terminal_marker_removes_only_the_named_one() {
        let store = fresh_store();
        store.write_terminal_marker("run-a", 1_000).await.unwrap();
        store.write_terminal_marker("run-b", 2_000).await.unwrap();

        let doomed = TerminalMarker {
            run_id: "run-a".into(),
            terminal_at_ms: 1_000,
        };
        store.delete_terminal_marker(&doomed).await.unwrap();

        let listed = store.list_terminal_markers().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].run_id, "run-b");
    }

    #[tokio::test]
    async fn delete_terminal_marker_succeeds_on_missing() {
        let absent = TerminalMarker {
            run_id: "nope".into(),
            terminal_at_ms: 1_000,
        };
        fresh_store().delete_terminal_marker(&absent).await.unwrap();
    }

    #[tokio::test]
    async fn delete_terminal_marker_is_idempotent() {
        let store = fresh_store();
        let marker = TerminalMarker {
            run_id: "run-1".into(),
            terminal_at_ms: 1_000,
        };
        store
            .write_terminal_marker(&marker.run_id, marker.terminal_at_ms)
            .await
            .unwrap();
        for _ in 0..2 {
            store.delete_terminal_marker(&marker).await.unwrap();
        }
    }

    #[tokio::test]
    async fn terminal_markers_and_memos_do_not_collide() {
        let store = fresh_store();
        store.new_memo("run-1", 0).put("k", b"v").await.unwrap();
        store.write_terminal_marker("run-1", 1_000).await.unwrap();

        let memo_value = store.new_memo("run-1", 0).get("k").await.unwrap();
        assert_eq!(memo_value, Some(b"v".to_vec()));
        assert_eq!(store.list_terminal_markers().await.unwrap().len(), 1);

        store.clear_memos_for_run("run-1").await.unwrap();
        assert_eq!(store.list_terminal_markers().await.unwrap().len(), 1);
    }
}