Skip to main content

dstack_memory/
lib.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4
5pub mod file;
6#[cfg(feature = "eruka")]
7pub mod eruka;
8
9/// A single memory field — the atomic unit of persistent context.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Field {
12    pub path: String,
13    pub value: String,
14    pub confidence: f64,
15    pub source: String,
16    pub updated_at: DateTime<Utc>,
17}
18
19impl Field {
20    pub fn new(path: impl Into<String>, value: impl Into<String>, source: impl Into<String>) -> Self {
21        Self {
22            path: path.into(),
23            value: value.into(),
24            confidence: 0.5,
25            source: source.into(),
26            updated_at: Utc::now(),
27        }
28    }
29
30    pub fn with_confidence(mut self, confidence: f64) -> Self {
31        self.confidence = confidence.clamp(0.0, 1.0);
32        self
33    }
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum MemoryError {
38    #[error("IO error: {0}")]
39    Io(#[from] std::io::Error),
40    #[error("Serialization error: {0}")]
41    Serde(#[from] serde_json::Error),
42    #[error("HTTP error: {0}")]
43    Http(String),
44    #[error("Not found: {0}")]
45    NotFound(String),
46}
47
48pub type Result<T> = std::result::Result<T, MemoryError>;
49
50/// Pluggable memory backend. Implement this for custom storage.
51#[async_trait]
52pub trait MemoryProvider: Send + Sync {
53    /// Load all fields matching a path prefix
54    async fn load(&self, path: &str) -> Result<Vec<Field>>;
55    /// Write a single field (upsert by path)
56    async fn write(&self, field: &Field) -> Result<()>;
57    /// Search fields by keyword in path or value
58    async fn search(&self, query: &str) -> Result<Vec<Field>>;
59    /// Delete a field by exact path
60    async fn delete(&self, path: &str) -> Result<()>;
61    /// Export all fields
62    async fn export_all(&self) -> Result<Vec<Field>>;
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn field_serialization_roundtrip() {
        let field = Field {
            path: "projects/myapp/learnings/auth-fix".into(),
            value: "JWT validation needs rust_crypto feature flag".into(),
            confidence: 0.95,
            source: "user_correction".into(),
            updated_at: Utc::now(),
        };
        let json = serde_json::to_string(&field).unwrap();
        let parsed: Field = serde_json::from_str(&json).unwrap();
        // Check every serialized member so a dropped or renamed serde key is caught.
        assert_eq!(parsed.path, field.path);
        assert_eq!(parsed.value, field.value);
        assert_eq!(parsed.confidence, 0.95);
        assert_eq!(parsed.source, field.source);
        assert_eq!(parsed.updated_at, field.updated_at);
    }

    #[test]
    fn field_default_confidence() {
        let field = Field::new("test/path", "value", "test");
        assert_eq!(field.path, "test/path");
        assert_eq!(field.value, "value");
        assert_eq!(field.confidence, 0.5);
        assert_eq!(field.source, "test");
    }

    #[test]
    fn field_confidence_clamped() {
        // (requested, stored): out-of-range values hit the bounds, in-range ones are kept.
        let cases = [(1.5, 1.0), (-0.5, 0.0), (0.7, 0.7)];
        for (requested, stored) in cases {
            let field = Field::new("a", "b", "c").with_confidence(requested);
            assert_eq!(field.confidence, stored, "with_confidence({requested})");
        }
    }

    #[test]
    fn field_new_sets_timestamp() {
        let before = Utc::now();
        let field = Field::new("x", "y", "z");
        let after = Utc::now();
        assert!(field.updated_at >= before && field.updated_at <= after);
    }
}