1use juncture_core::checkpoint::{
6 Checkpoint, CheckpointError as CoreCheckpointError, CheckpointFilter, CheckpointMetadata,
7 CheckpointSaver, CheckpointTuple, PendingWrite,
8};
9use juncture_core::config::RunnableConfig;
10use juncture_core::info_span;
11#[cfg(target_family = "wasm")]
12use juncture_core::tracing_wasm::WasmInstrument;
13use juncture_tracing::spans::names;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17#[cfg(not(target_family = "wasm"))]
18use tracing::Instrument;
19
20use crate::error::CheckpointError;
21
22#[allow(dead_code, reason = "conversion trait used internally")]
24trait ToCoreCheckpointError<T> {
25 fn map_checkpoint(self) -> Result<T, CoreCheckpointError>;
26}
27
28impl<T> ToCoreCheckpointError<T> for Result<T, CheckpointError> {
29 fn map_checkpoint(self) -> Result<T, CoreCheckpointError> {
30 self.map_err(|e| match e {
31 CheckpointError::Serialize(msg) | CheckpointError::Serialization(msg) => {
32 CoreCheckpointError::Serialize(msg)
33 }
34 CheckpointError::Deserialize(msg) => CoreCheckpointError::Deserialize(msg),
35 CheckpointError::NotFound {
36 thread_id,
37 checkpoint_id,
38 } => CoreCheckpointError::NotFound {
39 thread_id,
40 checkpoint_id,
41 },
42 CheckpointError::Storage(msg) | CheckpointError::Database(msg) => {
43 CoreCheckpointError::Storage(msg)
44 }
45 CheckpointError::SchemaMigration { from, to, reason } => {
46 CoreCheckpointError::Other(format!("Schema migration: {from} -> {to}: {reason}"))
47 }
48 CheckpointError::PoolExhausted => {
49 CoreCheckpointError::Storage("Connection pool exhausted".into())
50 }
51 })
52 }
53}
54
55type StorageMap = HashMap<String, HashMap<String, Vec<CheckpointTuple>>>;
57
58type WritesMap = HashMap<(String, String, String), Vec<PendingWrite>>;
60
61#[derive(Clone, Debug)]
66pub struct MemorySaver {
67 storage: Arc<RwLock<StorageMap>>,
69
70 writes: Arc<RwLock<WritesMap>>,
72
73 ttl_config: Arc<std::sync::RwLock<crate::types::TtlConfig>>,
75}
76
77impl MemorySaver {
78 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 storage: Arc::new(RwLock::new(HashMap::new())),
83 writes: Arc::new(RwLock::new(HashMap::new())),
84 ttl_config: Arc::new(std::sync::RwLock::new(crate::types::TtlConfig::default())),
85 }
86 }
87
88 #[must_use]
94 pub fn with_ttl_config(mut self, ttl_config: crate::types::TtlConfig) -> Self {
95 self.ttl_config = Arc::new(std::sync::RwLock::new(ttl_config));
96 self
97 }
98
99 #[must_use]
108 pub fn ttl_config(&self) -> crate::types::TtlConfig {
109 self.ttl_config.read().unwrap().clone()
110 }
111
112 pub fn set_ttl_config(&self, ttl_config: crate::types::TtlConfig) {
123 *self.ttl_config.write().unwrap() = ttl_config;
124 }
125
126 #[allow(
132 clippy::significant_drop_tightening,
133 reason = "lock scope is already optimized for minimal contention"
134 )]
135 async fn lazy_cleanup(
136 &self,
137 thread_id: &str,
138 checkpoint_ns: &str,
139 ) -> Result<(), CheckpointError> {
140 let ttl_config = self.ttl_config.read().unwrap().clone();
141
142 let (checkpoint_ids, expired_count) = {
144 let mut storage = self.storage.write().await;
145
146 let thread_map = storage
147 .entry(thread_id.to_string())
148 .or_insert_with(HashMap::new);
149 let checkpoints = thread_map
150 .entry(checkpoint_ns.to_string())
151 .or_insert_with(Vec::new);
152
153 let original_len = checkpoints.len();
155 checkpoints.retain(|tuple| !ttl_config.is_expired(&tuple.checkpoint.created_at));
156 let expired_count = original_len - checkpoints.len();
157
158 let Some(max) = ttl_config.max_checkpoints else {
160 return Ok(());
161 };
162
163 if checkpoints.len() > max {
164 let excess = checkpoints.len() - max;
165 checkpoints.truncate(max);
166 tracing::debug!("Deleted {excess} oldest checkpoints (max_checkpoints={max})");
167 }
168
169 let checkpoint_ids: std::collections::HashSet<String> = checkpoints
171 .iter()
172 .map(|t| t.checkpoint.id.clone())
173 .collect();
174
175 (checkpoint_ids, expired_count)
176 };
177
178 if expired_count > 0 {
180 let mut writes = self.writes.write().await;
181
182 writes.retain(|(thread, ns, id), _| {
184 thread == thread_id && ns == checkpoint_ns && checkpoint_ids.contains(id)
185 });
186 }
187
188 Ok(())
189 }
190
191 #[must_use]
193 fn get_checkpoint_ns(config: &RunnableConfig) -> String {
194 config
195 .checkpoint_ns
196 .as_ref()
197 .map_or_else(String::new, juncture_core::CheckpointNamespace::as_str)
198 }
199
200 fn get_thread_id(config: &RunnableConfig) -> Result<String, CheckpointError> {
202 config
203 .thread_id
204 .clone()
205 .ok_or_else(|| CheckpointError::Storage("thread_id is required".into()))
206 }
207
208 fn sort_checkpoints(checkpoints: &mut [CheckpointTuple]) {
210 checkpoints.sort_by(|a, b| {
211 b.checkpoint
212 .created_at
213 .cmp(&a.checkpoint.created_at)
214 .then_with(|| b.checkpoint.id.cmp(&a.checkpoint.id))
215 });
216 }
217}
218
219impl Default for MemorySaver {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225#[async_trait::async_trait]
226impl CheckpointSaver for MemorySaver {
227 async fn get_tuple(
228 &self,
229 config: &RunnableConfig,
230 ) -> Result<Option<CheckpointTuple>, CoreCheckpointError> {
231 let thread_id = Self::get_thread_id(config).map_checkpoint()?;
232 let checkpoint_ns = Self::get_checkpoint_ns(config);
233
234 Self::lazy_cleanup(self, &thread_id, &checkpoint_ns)
236 .await
237 .map_checkpoint()?;
238
239 let storage = self.storage.read().await;
241 let checkpoint_data = storage
242 .get(&thread_id)
243 .and_then(|ns| ns.get(&checkpoint_ns))
244 .cloned();
245 drop(storage);
246
247 let tuple_opt = checkpoint_data.and_then(|checkpoints| {
248 config.checkpoint_id.as_ref().map_or_else(
249 || checkpoints.first().cloned(),
250 |checkpoint_id| {
251 checkpoints
252 .iter()
253 .find(|t| t.checkpoint.id == *checkpoint_id)
254 .cloned()
255 },
256 )
257 });
258
259 if let Some(mut tuple) = tuple_opt {
261 let checkpoint_id = tuple.checkpoint.id.clone();
262 let writes = self.writes.read().await;
263 let pending_writes = writes
264 .get(&(thread_id, checkpoint_id, checkpoint_ns))
265 .cloned()
266 .unwrap_or_default();
267 drop(writes);
268
269 tuple.pending_writes = pending_writes;
270 Ok(Some(tuple))
271 } else {
272 Ok(None)
273 }
274 }
275
276 async fn list(
277 &self,
278 config: &RunnableConfig,
279 filter: Option<CheckpointFilter>,
280 ) -> Result<Vec<CheckpointTuple>, CoreCheckpointError> {
281 let thread_id = Self::get_thread_id(config).map_checkpoint()?;
282 let checkpoint_ns = Self::get_checkpoint_ns(config);
283
284 Self::lazy_cleanup(self, &thread_id, &checkpoint_ns)
286 .await
287 .map_checkpoint()?;
288
289 let namespace = {
290 let storage = self.storage.read().await;
291 storage
292 .get(&thread_id)
293 .and_then(|ns| ns.get(&checkpoint_ns))
294 .cloned()
295 };
296
297 let mut checkpoints = namespace.unwrap_or_default();
298
299 if let Some(f) = filter {
301 if let Some(source) = f.source {
303 checkpoints.retain(|t| t.metadata.source == source);
304 }
305
306 if let Some(min_step) = f.step_gte {
308 checkpoints.retain(|t| t.metadata.step >= min_step);
309 }
310 if let Some(max_step) = f.step_lte {
311 checkpoints.retain(|t| t.metadata.step <= max_step);
312 }
313
314 if let Some(before_id) = f.before {
316 let before_pos = checkpoints
317 .iter()
318 .position(|t| t.checkpoint.id == before_id);
319 if let Some(pos) = before_pos {
320 checkpoints = checkpoints.into_iter().take(pos).collect();
321 }
322 }
323 if let Some(after_id) = f.after {
324 let after_pos = checkpoints.iter().position(|t| t.checkpoint.id == after_id);
325 if let Some(pos) = after_pos {
326 checkpoints = checkpoints.into_iter().skip(pos + 1).collect();
327 }
328 }
329
330 if let Some(limit) = f.limit {
332 checkpoints.truncate(limit);
333 }
334 }
335
336 Ok(checkpoints)
337 }
338
339 async fn put(
340 &self,
341 config: &RunnableConfig,
342 checkpoint: Checkpoint,
343 metadata: CheckpointMetadata,
344 ) -> Result<RunnableConfig, CoreCheckpointError> {
345 let span = info_span!(
347 target: "juncture",
348 names::CHECKPOINT_PUT,
349 "juncture.checkpoint.id" = %checkpoint.id,
350 "juncture.checkpoint.source" = ?metadata.source,
351 "juncture.checkpoint.step" = metadata.step,
352 );
353
354 async move {
355 let thread_id = Self::get_thread_id(config).map_checkpoint()?;
356 let checkpoint_ns = Self::get_checkpoint_ns(config);
357 let checkpoint_id = checkpoint.id.clone();
358 let source = metadata.source.clone();
359
360 let tuple = CheckpointTuple {
362 config: config.clone(),
363 checkpoint,
364 metadata,
365 pending_writes: Vec::new(),
366 parent_config: None,
367 };
368
369 let mut storage = self.storage.write().await;
372 let thread_map = storage
373 .entry(thread_id.clone())
374 .or_insert_with(HashMap::new);
375 let namespace = thread_map
376 .entry(checkpoint_ns.clone())
377 .or_insert_with(Vec::new);
378
379 namespace.push(tuple);
380
381 Self::sort_checkpoints(namespace);
383 drop(storage);
384
385 tracing::debug!(
387 name: "juncture.checkpoint.writes",
388 source = ?source,
389 );
390
391 let mut result_config = config.clone();
393 result_config.checkpoint_id = Some(checkpoint_id);
394
395 Ok(result_config)
396 }
397 .instrument(span)
398 .await
399 }
400
401 async fn put_writes(
402 &self,
403 config: &RunnableConfig,
404 writes: Vec<PendingWrite>,
405 task_id: &str,
406 ) -> Result<(), CoreCheckpointError> {
407 let checkpoint_id_for_span = config.checkpoint_id.clone().unwrap_or_default();
408
409 let span = info_span!(
411 target: "juncture",
412 "juncture.checkpoint.put_writes",
413 "juncture.checkpoint.id" = %checkpoint_id_for_span,
414 "juncture.checkpoint.task_id" = %task_id,
415 "juncture.checkpoint.writes_count" = writes.len(),
416 );
417
418 async move {
419 let thread_id = Self::get_thread_id(config).map_checkpoint()?;
420 let checkpoint_ns = Self::get_checkpoint_ns(config);
421 let checkpoint_id = config
422 .checkpoint_id
423 .clone()
424 .ok_or_else(|| CoreCheckpointError::Storage("checkpoint_id is required".into()))?;
425
426 let key = (thread_id, checkpoint_id, checkpoint_ns);
427
428 let prepared_writes: Vec<PendingWrite> = writes
430 .into_iter()
431 .map(|mut w| {
432 w.task_id = task_id.to_string();
433 w
434 })
435 .collect();
436
437 self.writes
440 .write()
441 .await
442 .entry(key)
443 .or_insert_with(Vec::new)
444 .extend(prepared_writes);
445
446 Ok(())
447 }
448 .instrument(span)
449 .await
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use juncture_core::checkpoint::{CheckpointNamespace, CheckpointSource};
    use serde_json::json;

    /// Builds an otherwise-empty checkpoint with the given id, stamped with the current time.
    fn make_checkpoint(id: &str) -> Checkpoint {
        Checkpoint {
            id: id.to_owned(),
            channel_values: json!({}),
            channel_versions: HashMap::new(),
            versions_seen: HashMap::new(),
            pending_tasks: Vec::new(),
            pending_sends: Vec::new(),
            pending_interrupts: Vec::new(),
            schema_version: 1,
            created_at: Utc::now().to_rfc3339(),
            v: 1,
            new_versions: HashMap::new(),
            counters_since_delta_snapshot: HashMap::new(),
        }
    }

    /// Builds metadata for a checkpoint produced by `source` at `step`.
    fn make_metadata(source: CheckpointSource, step: i64) -> CheckpointMetadata {
        CheckpointMetadata {
            source,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
            run_id: String::from("test-run"),
        }
    }

    /// Config addressing the default namespace of `thread_id`.
    fn thread_config(thread_id: &str) -> RunnableConfig {
        RunnableConfig::default().with_thread_id(thread_id)
    }

    /// Stores `cp0 .. cp{count-1}` on `config`, as `Loop` checkpoints whose step is the index.
    async fn put_numbered(saver: &MemorySaver, config: &RunnableConfig, count: i64) {
        for step in 0..count {
            saver
                .put(
                    config,
                    make_checkpoint(&format!("cp{step}")),
                    make_metadata(CheckpointSource::Loop, step),
                )
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_memory_saver_put_get() {
        let saver = MemorySaver::new();
        let config = thread_config("thread1");

        let stored = saver
            .put(
                &config,
                make_checkpoint("cp1"),
                make_metadata(CheckpointSource::Input, 0),
            )
            .await
            .unwrap();
        assert_eq!(stored.checkpoint_id, Some("cp1".to_string()));

        let fetched = saver.get_tuple(&stored).await.unwrap().unwrap();
        assert_eq!(fetched.checkpoint.id, "cp1");
    }

    #[tokio::test]
    async fn test_memory_saver_get_latest() {
        let saver = MemorySaver::new();
        let config = thread_config("thread1");
        put_numbered(&saver, &config, 3).await;

        let newest = saver.get_tuple(&config).await.unwrap().unwrap();
        assert_eq!(newest.checkpoint.id, "cp2");
    }

    #[tokio::test]
    async fn test_memory_saver_list() {
        let saver = MemorySaver::new();
        let config = thread_config("thread1");
        put_numbered(&saver, &config, 5).await;

        let everything = saver.list(&config, None).await.unwrap();
        assert_eq!(everything.len(), 5);

        let first_three = saver
            .list(
                &config,
                Some(CheckpointFilter {
                    limit: Some(3),
                    ..Default::default()
                }),
            )
            .await
            .unwrap();
        assert_eq!(first_three.len(), 3);

        let from_step_two = saver
            .list(
                &config,
                Some(CheckpointFilter {
                    step_gte: Some(2),
                    ..Default::default()
                }),
            )
            .await
            .unwrap();
        assert_eq!(from_step_two.len(), 3);
    }

    #[tokio::test]
    async fn test_memory_saver_put_writes() {
        let saver = MemorySaver::new();
        let config = thread_config("thread1");
        let stored = saver
            .put(
                &config,
                make_checkpoint("cp1"),
                make_metadata(CheckpointSource::Input, 0),
            )
            .await
            .unwrap();

        let pending = vec![PendingWrite {
            task_id: String::new(),
            channel: "messages".to_string(),
            value: json!("hello"),
        }];
        saver.put_writes(&stored, pending, "task1").await.unwrap();

        let tuple = saver.get_tuple(&stored).await.unwrap().unwrap();
        assert_eq!(tuple.pending_writes.len(), 1);
        let only = &tuple.pending_writes[0];
        assert_eq!(only.channel, "messages");
        assert_eq!(only.task_id, "task1");
    }

    #[tokio::test]
    async fn test_memory_saver_namespace_isolation() {
        let saver = MemorySaver::new();
        let first_ns = RunnableConfig::default()
            .with_thread_id("thread1")
            .with_checkpoint_ns(CheckpointNamespace::parse("ns1"));
        let second_ns = RunnableConfig::default()
            .with_thread_id("thread1")
            .with_checkpoint_ns(CheckpointNamespace::parse("ns2"));
        let meta = make_metadata(CheckpointSource::Input, 0);

        saver
            .put(&first_ns, make_checkpoint("cp1"), meta.clone())
            .await
            .unwrap();
        saver
            .put(&second_ns, make_checkpoint("cp2"), meta)
            .await
            .unwrap();

        let seen = saver.get_tuple(&second_ns).await.unwrap().unwrap();
        assert_eq!(seen.checkpoint.id, "cp2");
    }

    #[tokio::test]
    async fn test_memory_saver_thread_isolation() {
        let saver = MemorySaver::new();
        let first_thread = RunnableConfig::default().with_thread_id("thread1");
        let second_thread = RunnableConfig::default().with_thread_id("thread2");
        let meta = make_metadata(CheckpointSource::Input, 0);

        saver
            .put(&first_thread, make_checkpoint("cp1"), meta.clone())
            .await
            .unwrap();
        saver
            .put(&second_thread, make_checkpoint("cp2"), meta)
            .await
            .unwrap();

        let on_first = saver.get_tuple(&first_thread).await.unwrap().unwrap();
        assert_eq!(on_first.checkpoint.id, "cp1");

        let on_second = saver.get_tuple(&second_thread).await.unwrap().unwrap();
        assert_eq!(on_second.checkpoint.id, "cp2");
    }

    #[tokio::test]
    async fn test_memory_saver_not_found() {
        let saver = MemorySaver::new();
        let missing = RunnableConfig::default()
            .with_thread_id("nonexistent")
            .with_checkpoint_id("missing");

        assert!(saver.get_tuple(&missing).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_memory_saver_filter_by_source() {
        let saver = MemorySaver::new();
        let config = thread_config("thread1");

        saver
            .put(
                &config,
                make_checkpoint("cp1"),
                make_metadata(CheckpointSource::Input, 0),
            )
            .await
            .unwrap();
        saver
            .put(
                &config,
                make_checkpoint("cp2"),
                make_metadata(CheckpointSource::Loop, 1),
            )
            .await
            .unwrap();

        let loops_only = saver
            .list(
                &config,
                Some(CheckpointFilter {
                    source: Some(CheckpointSource::Loop),
                    ..Default::default()
                }),
            )
            .await
            .unwrap();

        assert_eq!(loops_only.len(), 1);
        assert!(matches!(
            loops_only[0].metadata.source,
            CheckpointSource::Loop
        ));
    }

    #[tokio::test]
    async fn test_memory_saver_clone() {
        let saver = MemorySaver::new();
        let twin = saver.clone();
        let config = thread_config("thread1");

        saver
            .put(
                &config,
                make_checkpoint("cp1"),
                make_metadata(CheckpointSource::Input, 0),
            )
            .await
            .unwrap();

        // The clone shares storage, so it sees what the original stored.
        let seen = twin.get_tuple(&config).await.unwrap();
        assert!(seen.is_some());
        assert_eq!(seen.unwrap().checkpoint.id, "cp1");
    }

    #[tokio::test]
    async fn test_memory_saver_ttl_expiration() {
        use crate::types::TtlConfig;
        use std::time::Duration;

        let saver = MemorySaver::new().with_ttl_config(TtlConfig {
            default_ttl: Some(Duration::from_millis(100)),
            sweep_interval: Duration::from_secs(3600),
            max_checkpoints: None,
        });
        let config = thread_config("thread1");
        put_numbered(&saver, &config, 3).await;

        let fresh = saver.list(&config, None).await.unwrap();
        assert_eq!(fresh.len(), 3);

        // Wait past the 100 ms TTL so every checkpoint expires.
        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(saver.get_tuple(&config).await.unwrap().is_none());

        let stale = saver.list(&config, None).await.unwrap();
        assert_eq!(stale.len(), 0);
    }

    #[tokio::test]
    async fn test_memory_saver_max_checkpoints() {
        use crate::types::TtlConfig;

        let saver = MemorySaver::new().with_ttl_config(TtlConfig {
            default_ttl: None,
            sweep_interval: std::time::Duration::from_secs(3600),
            max_checkpoints: Some(2),
        });
        let config = thread_config("thread1");
        put_numbered(&saver, &config, 5).await;

        let kept = saver.list(&config, None).await.unwrap();
        assert_eq!(kept.len(), 2);

        // Only the two newest checkpoints survive, newest first.
        let ids: Vec<&str> = kept.iter().map(|t| t.checkpoint.id.as_str()).collect();
        assert_eq!(ids, ["cp4", "cp3"]);
    }
}
779
780