arkflow_plugin/input/
memory.rs

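//! Memory input component
//!
//! Reads message batches from an in-memory FIFO queue. The queue can be seeded
//! from the `messages` configuration field or filled at runtime with
//! [`MemoryInput::push`]; once it is empty, reads return `Error::EOF`.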
4
5use std::collections::VecDeque;
6use std::sync::atomic::AtomicBool;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
14use arkflow_core::{Error, MessageBatch};
15
16/// Memory input configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryInputConfig {
19    /// Initial message for the memory queue
20    pub messages: Option<Vec<String>>,
21}
22
23/// Memory input component
24pub struct MemoryInput {
25    queue: Arc<Mutex<VecDeque<MessageBatch>>>,
26    connected: AtomicBool,
27}
28
29impl MemoryInput {
30    /// Create a new memory input component
31    pub fn new(config: MemoryInputConfig) -> Result<Self, Error> {
32        let mut queue = VecDeque::new();
33
34        // If there is an initial message in the configuration, it is added to the queue
35        if let Some(messages) = &config.messages {
36            for msg_str in messages {
37                queue.push_back(MessageBatch::from_string(msg_str));
38            }
39        }
40
41        Ok(Self {
42            queue: Arc::new(Mutex::new(queue)),
43            connected: AtomicBool::new(false),
44        })
45    }
46
47    /// Add a message to the memory input
48    pub async fn push(&self, msg: MessageBatch) -> Result<(), Error> {
49        let mut queue = self.queue.lock().await;
50        queue.push_back(msg);
51        Ok(())
52    }
53}
54
55#[async_trait]
56impl Input for MemoryInput {
57    async fn connect(&self) -> Result<(), Error> {
58        self.connected
59            .store(true, std::sync::atomic::Ordering::SeqCst);
60        Ok(())
61    }
62
63    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
64        if !self.connected.load(std::sync::atomic::Ordering::SeqCst) {
65            return Err(Error::Connection("The input is not connected".to_string()));
66        }
67
68        // Try to get a message from the queue
69        let msg_option;
70        {
71            let mut queue = self.queue.lock().await;
72            msg_option = queue.pop_front();
73        }
74
75        if let Some(msg) = msg_option {
76            Ok((msg, Arc::new(NoopAck)))
77        } else {
78            Err(Error::EOF)
79        }
80    }
81
82    async fn close(&self) -> Result<(), Error> {
83        self.connected
84            .store(false, std::sync::atomic::Ordering::SeqCst);
85        Ok(())
86    }
87}
88
89pub(crate) struct MemoryInputBuilder;
90impl InputBuilder for MemoryInputBuilder {
91    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
92        if config.is_none() {
93            return Err(Error::Config(
94                "Memory input configuration is missing".to_string(),
95            ));
96        }
97        let config: MemoryInputConfig = serde_json::from_value(config.clone().unwrap())?;
98        Ok(Arc::new(MemoryInput::new(config)?))
99    }
100}
101
102pub fn init() {
103    register_input_builder("memory", Arc::new(MemoryInputBuilder));
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads once and asserts the queue is exhausted (`Error::EOF`).
    async fn assert_eof(input: &dyn Input) {
        match input.read().await {
            Err(Error::EOF) => {}
            _ => panic!("Expected EOF error"),
        }
    }

    /// Reads once and asserts the read is rejected as not connected.
    async fn assert_not_connected(input: &dyn Input) {
        match input.read().await {
            Err(Error::Connection(_)) => {}
            _ => panic!("Expected Connection error"),
        }
    }

    #[tokio::test]
    async fn test_memory_input_new() {
        // Without initial messages the queue starts empty and disconnected.
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.queue.lock().await.is_empty());
        assert!(!input.connected.load(std::sync::atomic::Ordering::SeqCst));

        // With initial messages, one batch is queued per message.
        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        let input = MemoryInput::new(config).unwrap();
        assert_eq!(input.queue.lock().await.len(), 2);
    }

    #[tokio::test]
    async fn test_memory_input_connect() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_memory_input_read_without_connect() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        // Reading without connecting must fail with a connection error.
        assert_not_connected(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_read_empty_queue() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        // An empty queue reports EOF.
        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_read_with_initial_messages() {
        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        let input = MemoryInput::new(config).unwrap();
        assert!(input.connect().await.is_ok());

        // Initial messages come back in configuration order.
        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["message1"]);
        ack.ack().await;

        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["message2"]);
        ack.ack().await;

        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_push() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        let msg = MessageBatch::from_string("pushed message");
        assert!(input.push(msg).await.is_ok());

        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["pushed message"]);
        ack.ack().await;

        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_close() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(std::sync::atomic::Ordering::SeqCst));

        assert!(input.close().await.is_ok());
        assert!(!input.connected.load(std::sync::atomic::Ordering::SeqCst));

        // Reading after close must fail with a connection error.
        assert_not_connected(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_multiple_push_read() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        for text in ["message1", "message2", "message3"] {
            assert!(input.push(MessageBatch::from_string(text)).await.is_ok());
        }

        // Pushed messages are read back in FIFO order.
        for expected in ["message1", "message2", "message3"] {
            let (batch, ack) = input.read().await.unwrap();
            assert_eq!(batch.as_string().unwrap(), vec![expected]);
            ack.ack().await;
        }

        assert_eof(&input).await;
    }

    #[test]
    fn test_memory_input_builder_missing_config() {
        let result = MemoryInputBuilder.build(&None);
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[tokio::test]
    async fn test_memory_input_builder_with_messages() {
        let config = Some(serde_json::json!({ "messages": ["from config"] }));
        let input = MemoryInputBuilder.build(&config).unwrap();
        assert!(input.connect().await.is_ok());

        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["from config"]);
        ack.ack().await;

        assert_eof(&*input).await;
    }

    #[tokio::test]
    async fn test_memory_input_builder_without_messages() {
        // An object without `messages` yields an input with an empty queue.
        let config = Some(serde_json::json!({}));
        let input = MemoryInputBuilder.build(&config).unwrap();
        assert!(input.connect().await.is_ok());

        assert_eof(&*input).await;
    }
}