arkflow_plugin/input/
memory.rs1use std::collections::VecDeque;
6use std::sync::atomic::AtomicBool;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12
13use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
14use arkflow_core::{Error, MessageBatch};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryInputConfig {
19 pub messages: Option<Vec<String>>,
21}
22
23pub struct MemoryInput {
25 queue: Arc<Mutex<VecDeque<MessageBatch>>>,
26 connected: AtomicBool,
27}
28
29impl MemoryInput {
30 pub fn new(config: MemoryInputConfig) -> Result<Self, Error> {
32 let mut queue = VecDeque::new();
33
34 if let Some(messages) = &config.messages {
36 for msg_str in messages {
37 queue.push_back(MessageBatch::from_string(msg_str));
38 }
39 }
40
41 Ok(Self {
42 queue: Arc::new(Mutex::new(queue)),
43 connected: AtomicBool::new(false),
44 })
45 }
46
47 pub async fn push(&self, msg: MessageBatch) -> Result<(), Error> {
49 let mut queue = self.queue.lock().await;
50 queue.push_back(msg);
51 Ok(())
52 }
53}
54
55#[async_trait]
56impl Input for MemoryInput {
57 async fn connect(&self) -> Result<(), Error> {
58 self.connected
59 .store(true, std::sync::atomic::Ordering::SeqCst);
60 Ok(())
61 }
62
63 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
64 if !self.connected.load(std::sync::atomic::Ordering::SeqCst) {
65 return Err(Error::Connection("The input is not connected".to_string()));
66 }
67
68 let msg_option;
70 {
71 let mut queue = self.queue.lock().await;
72 msg_option = queue.pop_front();
73 }
74
75 if let Some(msg) = msg_option {
76 Ok((msg, Arc::new(NoopAck)))
77 } else {
78 Err(Error::EOF)
79 }
80 }
81
82 async fn close(&self) -> Result<(), Error> {
83 self.connected
84 .store(false, std::sync::atomic::Ordering::SeqCst);
85 Ok(())
86 }
87}
88
89pub(crate) struct MemoryInputBuilder;
90impl InputBuilder for MemoryInputBuilder {
91 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
92 if config.is_none() {
93 return Err(Error::Config(
94 "Memory input configuration is missing".to_string(),
95 ));
96 }
97 let config: MemoryInputConfig = serde_json::from_value(config.clone().unwrap())?;
98 Ok(Arc::new(MemoryInput::new(config)?))
99 }
100}
101
102pub fn init() {
103 register_input_builder("memory", Arc::new(MemoryInputBuilder));
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is the `Error::EOF` returned once the queue is drained.
    #[track_caller]
    fn assert_eof<T>(result: Result<T, Error>) {
        assert!(matches!(result, Err(Error::EOF)), "Expected EOF error");
    }

    /// Asserts that `result` is the `Error::Connection` returned while disconnected.
    #[track_caller]
    fn assert_not_connected<T>(result: Result<T, Error>) {
        assert!(
            matches!(result, Err(Error::Connection(_))),
            "Expected Connection error"
        );
    }

    #[tokio::test]
    async fn test_memory_input_new() {
        let config = MemoryInputConfig { messages: None };
        assert!(MemoryInput::new(config).is_ok());

        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        assert!(MemoryInput::new(config).is_ok());
    }

    #[tokio::test]
    async fn test_memory_input_connect() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(std::sync::atomic::Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_memory_input_read_without_connect() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        assert_not_connected(input.read().await);
    }

    #[tokio::test]
    async fn test_memory_input_read_empty_queue() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        assert_eof(input.read().await);
    }

    #[tokio::test]
    async fn test_memory_input_read_with_initial_messages() {
        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        let input = MemoryInput::new(config).unwrap();
        assert!(input.connect().await.is_ok());

        // Configured messages come back one batch each, in configuration order.
        for expected in ["message1", "message2"] {
            let (batch, ack) = input.read().await.unwrap();
            assert_eq!(batch.as_string().unwrap(), vec![expected]);
            ack.ack().await;
        }

        assert_eof(input.read().await);
    }

    #[tokio::test]
    async fn test_memory_input_push() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        let msg = MessageBatch::from_string("pushed message");
        assert!(input.push(msg).await.is_ok());

        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["pushed message"]);
        ack.ack().await;

        assert_eof(input.read().await);
    }

    #[tokio::test]
    async fn test_memory_input_close() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();

        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(std::sync::atomic::Ordering::SeqCst));

        assert!(input.close().await.is_ok());
        assert!(!input.connected.load(std::sync::atomic::Ordering::SeqCst));

        assert_not_connected(input.read().await);
    }

    #[tokio::test]
    async fn test_memory_input_multiple_push_read() {
        let input = MemoryInput::new(MemoryInputConfig { messages: None }).unwrap();
        assert!(input.connect().await.is_ok());

        let expected = ["message1", "message2", "message3"];
        for text in expected {
            assert!(input.push(MessageBatch::from_string(text)).await.is_ok());
        }

        // Pushed batches are read back in FIFO order.
        for text in expected {
            let (batch, ack) = input.read().await.unwrap();
            assert_eq!(batch.as_string().unwrap(), vec![text]);
            ack.ack().await;
        }

        assert_eof(input.read().await);
    }

    #[test]
    fn test_memory_input_builder_missing_config() {
        let result = MemoryInputBuilder.build(&None);
        assert!(
            matches!(result, Err(Error::Config(_))),
            "Expected Config error"
        );
    }

    #[tokio::test]
    async fn test_memory_input_builder_with_messages() {
        let config = serde_json::json!({ "messages": ["from config"] });
        let input = MemoryInputBuilder.build(&Some(config)).unwrap();
        assert!(input.connect().await.is_ok());

        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["from config"]);
        ack.ack().await;

        assert_eof(input.read().await);
    }
}