use std::collections::VecDeque;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tokio::sync::Mutex;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
use arkflow_core::{Error, MessageBatch};
/// Configuration for the in-memory input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MemoryInputConfig {
    /// Strings to enqueue when the input is created. Each entry becomes its own
    /// `MessageBatch`, in order. `None` starts the input with an empty queue.
    pub messages: Option<Vec<String>>,
}
/// An input that serves `MessageBatch`es from an in-memory FIFO queue.
pub struct MemoryInput {
    /// Pending batches. New batches are appended at the back, and reads pop from the front.
    queue: Arc<Mutex<VecDeque<MessageBatch>>>,
    /// Whether `connect` has been called (and `close` has not since).
    /// Reads are rejected while this is `false`.
    connected: AtomicBool,
}
impl MemoryInput {
    /// Creates a memory input whose queue is seeded from `config.messages`.
    ///
    /// Each configured string becomes its own `MessageBatch`, queued in the
    /// order given. The returned input starts disconnected.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    pub fn new(config: MemoryInputConfig) -> Result<Self, Error> {
        let seeded: VecDeque<MessageBatch> = config
            .messages
            .iter()
            .flatten()
            .map(|text| MessageBatch::from_string(text))
            .collect();
        Ok(Self {
            queue: Arc::new(Mutex::new(seeded)),
            connected: AtomicBool::new(false),
        })
    }

    /// Appends `msg` to the back of the queue, behind any batches already waiting.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    pub async fn push(&self, msg: MessageBatch) -> Result<(), Error> {
        // The guard is a temporary of this statement and is released right after the push.
        self.queue.lock().await.push_back(msg);
        Ok(())
    }
}
#[async_trait]
impl Input for MemoryInput {
async fn connect(&self) -> Result<(), Error> {
self.connected
.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
if !self.connected.load(std::sync::atomic::Ordering::SeqCst) {
return Err(Error::Connection("The input is not connected".to_string()));
}
let msg_option;
{
let mut queue = self.queue.lock().await;
msg_option = queue.pop_front();
}
if let Some(msg) = msg_option {
Ok((msg, Arc::new(NoopAck)))
} else {
Err(Error::EOF)
}
}
async fn close(&self) -> Result<(), Error> {
self.connected
.store(false, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
}
pub(crate) struct MemoryInputBuilder;
impl InputBuilder for MemoryInputBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
if config.is_none() {
return Err(Error::Config(
"Memory input configuration is missing".to_string(),
));
}
let config: MemoryInputConfig = serde_json::from_value(config.clone().unwrap())?;
Ok(Arc::new(MemoryInput::new(config)?))
}
}
/// Registers `MemoryInputBuilder` under the input type name `"memory"`.
pub fn init() {
    let builder = Arc::new(MemoryInputBuilder);
    register_input_builder("memory", builder);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Creates a memory input with no preloaded messages.
    fn empty_input() -> MemoryInput {
        MemoryInput::new(MemoryInputConfig { messages: None }).unwrap()
    }

    /// Reads the next batch, checks that it holds exactly `expected`, and acks it.
    async fn assert_next(input: &MemoryInput, expected: &str) {
        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec![expected]);
        ack.ack().await;
    }

    /// Asserts that the queue is drained and `read` reports `Error::EOF`.
    async fn assert_eof(input: &MemoryInput) {
        assert!(
            matches!(input.read().await, Err(Error::EOF)),
            "Expected EOF error"
        );
    }

    /// Asserts that `read` is rejected because the input is not connected.
    async fn assert_not_connected(input: &MemoryInput) {
        assert!(
            matches!(input.read().await, Err(Error::Connection(_))),
            "Expected Connection error"
        );
    }

    #[tokio::test]
    async fn test_memory_input_new() {
        assert!(MemoryInput::new(MemoryInputConfig { messages: None }).is_ok());
        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        assert!(MemoryInput::new(config).is_ok());
    }

    #[tokio::test]
    async fn test_memory_input_connect() {
        let input = empty_input();
        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_memory_input_read_without_connect() {
        let input = empty_input();
        assert_not_connected(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_read_empty_queue() {
        let input = empty_input();
        assert!(input.connect().await.is_ok());
        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_read_with_initial_messages() {
        let config = MemoryInputConfig {
            messages: Some(vec!["message1".to_string(), "message2".to_string()]),
        };
        let input = MemoryInput::new(config).unwrap();
        assert!(input.connect().await.is_ok());
        assert_next(&input, "message1").await;
        assert_next(&input, "message2").await;
        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_push() {
        let input = empty_input();
        assert!(input.connect().await.is_ok());
        let msg = MessageBatch::from_string("pushed message");
        assert!(input.push(msg).await.is_ok());
        assert_next(&input, "pushed message").await;
        assert_eof(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_close() {
        let input = empty_input();
        assert!(input.connect().await.is_ok());
        assert!(input.connected.load(Ordering::SeqCst));
        assert!(input.close().await.is_ok());
        assert!(!input.connected.load(Ordering::SeqCst));
        assert_not_connected(&input).await;
    }

    #[tokio::test]
    async fn test_memory_input_multiple_push_read() {
        let input = empty_input();
        assert!(input.connect().await.is_ok());
        let expected = ["message1", "message2", "message3"];
        for text in expected {
            assert!(input.push(MessageBatch::from_string(text)).await.is_ok());
        }
        // Batches must come back in the same FIFO order they were pushed.
        for text in expected {
            assert_next(&input, text).await;
        }
        assert_eof(&input).await;
    }

    #[test]
    fn test_memory_input_builder_missing_config() {
        let result = MemoryInputBuilder.build(&None);
        assert!(
            matches!(result, Err(Error::Config(_))),
            "Expected Config error"
        );
    }

    #[tokio::test]
    async fn test_memory_input_builder_with_messages() {
        let config = serde_json::json!({ "messages": ["from builder"] });
        let input = MemoryInputBuilder.build(&Some(config)).unwrap();
        assert!(input.connect().await.is_ok());
        let (batch, ack) = input.read().await.unwrap();
        assert_eq!(batch.as_string().unwrap(), vec!["from builder"]);
        ack.ack().await;
        assert!(
            matches!(input.read().await, Err(Error::EOF)),
            "Expected EOF error"
        );
    }
}