Skip to main content

cognee_cognify/memify/
config.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4
5use super::error::MemifyError;
6
7/// Opaque wrapper around an async task callback for the memify pipeline.
8///
9/// Mirrors Python's `extraction_tasks` / `enrichment_tasks` parameters.
10/// The callback receives a JSON array of input data and returns a JSON array
11/// of output data.
12#[derive(Clone)]
13#[allow(clippy::type_complexity)]
14pub struct MemifyTask(
15    pub  Arc<
16        dyn Fn(
17                Vec<serde_json::Value>,
18            ) -> std::pin::Pin<
19                Box<
20                    dyn std::future::Future<Output = Result<Vec<serde_json::Value>, MemifyError>>
21                        + Send,
22                >,
23            > + Send
24            + Sync,
25    >,
26);
27
28impl std::fmt::Debug for MemifyTask {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.write_str("MemifyTask(…)")
31    }
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct MemifyConfig {
36    /// Batch size for reading triplets from the graph.
37    /// Python: triplets_batch_size=100 in get_triplet_datapoints()
38    ///
39    /// In the Rust implementation this controls the chunk size used when
40    /// batching embedding calls, NOT graph pagination (since Rust loads
41    /// the full graph via get_graph_data()). Kept for future use if
42    /// get_triplets_batch() is added to GraphDBTrait.
43    pub triplet_batch_size: usize,
44
45    /// Optional filter: only process nodes of this type.
46    /// Python: memify(node_type=NodeSet) default
47    /// When set along with node_name_filter, uses
48    /// GraphDBTrait::get_nodeset_subgraph().
49    pub node_type_filter: Option<String>,
50
51    /// Optional filter: only process nodes with these names.
52    /// Python: memify(node_name=None) default
53    pub node_name_filter: Option<Vec<String>>,
54
55    /// Operator for node_name filtering ("OR" or "AND").
56    /// Python: get_memory_fragment(node_name_filter_operator="OR")
57    pub node_name_filter_operator: String,
58
59    /// Custom extraction tasks to run instead of (or in addition to)
60    /// the default triplet extraction.
61    /// Python: memify(extraction_tasks=...)
62    #[serde(skip)]
63    pub extraction_tasks: Option<Vec<MemifyTask>>,
64
65    /// Custom enrichment tasks to run after extraction.
66    /// Python: memify(enrichment_tasks=...)
67    #[serde(skip)]
68    pub enrichment_tasks: Option<Vec<MemifyTask>>,
69
70    /// Custom input data. When provided, skip reading from the graph
71    /// and use this data directly as input to the pipeline.
72    /// Python: memify(data=...)
73    #[serde(skip)]
74    pub custom_data: Option<Vec<serde_json::Value>>,
75}
76
77impl Default for MemifyConfig {
78    fn default() -> Self {
79        Self {
80            triplet_batch_size: 100,
81            node_type_filter: None,
82            node_name_filter: None,
83            node_name_filter_operator: "OR".to_string(),
84            extraction_tasks: None,
85            enrichment_tasks: None,
86            custom_data: None,
87        }
88    }
89}
90
91impl MemifyConfig {
92    pub fn with_triplet_batch_size(mut self, size: usize) -> Self {
93        self.triplet_batch_size = size;
94        self
95    }
96
97    pub fn with_node_type_filter(mut self, node_type: String) -> Self {
98        self.node_type_filter = Some(node_type);
99        self
100    }
101
102    pub fn with_node_name_filter(mut self, names: Vec<String>) -> Self {
103        self.node_name_filter = Some(names);
104        self
105    }
106
107    pub fn with_node_name_filter_operator(mut self, op: String) -> Self {
108        self.node_name_filter_operator = op;
109        self
110    }
111
112    pub fn with_extraction_tasks(mut self, tasks: Vec<MemifyTask>) -> Self {
113        self.extraction_tasks = Some(tasks);
114        self
115    }
116
117    pub fn with_enrichment_tasks(mut self, tasks: Vec<MemifyTask>) -> Self {
118        self.enrichment_tasks = Some(tasks);
119        self
120    }
121
122    pub fn with_custom_data(mut self, data: Vec<serde_json::Value>) -> Self {
123        self.custom_data = Some(data);
124        self
125    }
126
127    pub fn validate(&self) -> Result<(), MemifyError> {
128        if self.triplet_batch_size == 0 {
129            return Err(MemifyError::ConfigError(
130                "triplet_batch_size must be > 0".into(),
131            ));
132        }
133        let op = self.node_name_filter_operator.as_str();
134        if op != "OR" && op != "AND" {
135            return Err(MemifyError::ConfigError(format!(
136                "node_name_filter_operator must be \"OR\" or \"AND\", got \"{op}\""
137            )));
138        }
139        Ok(())
140    }
141}
142
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Boxed future type returned by the callback wrapped in `MemifyTask`.
    type TaskFuture = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Vec<serde_json::Value>, MemifyError>> + Send>,
    >;

    /// Builds a task that echoes its input. It is never polled in these
    /// tests; it only needs to be a valid `MemifyTask` value.
    fn echo_task() -> MemifyTask {
        MemifyTask(Arc::new(|input: Vec<serde_json::Value>| -> TaskFuture {
            Box::pin(async move { Ok::<_, MemifyError>(input) })
        }))
    }

    #[test]
    fn test_default_config() {
        let config = MemifyConfig::default();
        assert_eq!(config.triplet_batch_size, 100);
        assert!(config.node_type_filter.is_none());
        assert!(config.node_name_filter.is_none());
        assert_eq!(config.node_name_filter_operator, "OR");
        assert!(config.extraction_tasks.is_none());
        assert!(config.enrichment_tasks.is_none());
        assert!(config.custom_data.is_none());
    }

    #[test]
    fn test_validate_valid_config() {
        let config = MemifyConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_zero_batch_size() {
        let config = MemifyConfig::default().with_triplet_batch_size(0);
        let err = config.validate().unwrap_err();
        assert!(
            err.to_string().contains("triplet_batch_size must be > 0"),
            "expected batch size error, got: {err}"
        );
    }

    #[test]
    fn test_validate_invalid_operator() {
        let config = MemifyConfig::default().with_node_name_filter_operator("XOR".to_string());
        let err = config.validate().unwrap_err();
        assert!(
            err.to_string().contains("node_name_filter_operator"),
            "expected operator error, got: {err}"
        );
    }

    #[test]
    fn test_validate_and_operator_accepted() {
        let config = MemifyConfig::default().with_node_name_filter_operator("AND".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_empty_node_names_vec_passes_validation() {
        // ANCHOR: pins the current behavior that an empty `node_name_filter`
        // vec passes `validate()` as-is (it is NOT coerced to `None`). If
        // future work adds such coercion, this test must be updated.
        let config = MemifyConfig::default().with_node_name_filter(vec![]);
        assert_eq!(config.node_name_filter, Some(vec![]));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_operator_case_sensitive() {
        for op in ["or", "Or", "aNd", "and", "xor", ""] {
            let config = MemifyConfig::default().with_node_name_filter_operator(op.to_string());
            let err = config.validate().unwrap_err();
            assert!(
                matches!(err, MemifyError::ConfigError(_)),
                "expected ConfigError for operator {op:?}, got: {err}"
            );
        }
        for op in ["OR", "AND"] {
            let config = MemifyConfig::default().with_node_name_filter_operator(op.to_string());
            assert!(
                config.validate().is_ok(),
                "expected operator {op:?} to pass validation"
            );
        }
    }

    #[test]
    fn test_config_large_batch_size_accepted() {
        let config = MemifyConfig::default().with_triplet_batch_size(10_000);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_builder_methods() {
        let config = MemifyConfig::default()
            .with_triplet_batch_size(50)
            .with_node_type_filter("Entity".to_string())
            .with_node_name_filter(vec!["Alice".to_string(), "Bob".to_string()])
            .with_node_name_filter_operator("AND".to_string());

        assert_eq!(config.triplet_batch_size, 50);
        assert_eq!(config.node_type_filter, Some("Entity".to_string()));
        assert_eq!(
            config.node_name_filter,
            Some(vec!["Alice".to_string(), "Bob".to_string()])
        );
        assert_eq!(config.node_name_filter_operator, "AND");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_memify_task_debug_is_opaque() {
        assert_eq!(format!("{:?}", echo_task()), "MemifyTask(…)");
    }

    #[test]
    fn test_task_and_data_builders() {
        let config = MemifyConfig::default()
            .with_extraction_tasks(vec![echo_task(), echo_task()])
            .with_enrichment_tasks(vec![echo_task()])
            .with_custom_data(vec![serde_json::Value::Null]);

        assert_eq!(config.extraction_tasks.as_ref().map(Vec::len), Some(2));
        assert_eq!(config.enrichment_tasks.as_ref().map(Vec::len), Some(1));
        assert_eq!(config.custom_data, Some(vec![serde_json::Value::Null]));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_serde_round_trip_skips_runtime_fields() {
        let config = MemifyConfig::default()
            .with_triplet_batch_size(25)
            .with_node_type_filter("Entity".to_string())
            .with_node_name_filter_operator("AND".to_string())
            .with_extraction_tasks(vec![echo_task()])
            .with_custom_data(vec![serde_json::Value::Null]);

        let json = serde_json::to_string(&config).expect("config serializes");
        let restored: MemifyConfig = serde_json::from_str(&json).expect("config deserializes");

        assert_eq!(restored.triplet_batch_size, 25);
        assert_eq!(restored.node_type_filter, Some("Entity".to_string()));
        assert!(restored.node_name_filter.is_none());
        assert_eq!(restored.node_name_filter_operator, "AND");
        // `#[serde(skip)]` fields are dropped on the way out and come back as None.
        assert!(restored.extraction_tasks.is_none());
        assert!(restored.enrichment_tasks.is_none());
        assert!(restored.custom_data.is_none());
    }
}