use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
use arkflow_core::{Error, MessageBatch};
use async_trait::async_trait;
use serde::{Deserialize, Deserializer, Serialize};
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateInputConfig {
context: String,
#[serde(deserialize_with = "deserialize_duration")]
interval: Duration,
count: Option<usize>,
batch_size: Option<usize>,
}
pub struct GenerateInput {
config: GenerateInputConfig,
count: AtomicI64,
batch_size: usize,
}
impl GenerateInput {
pub fn new(config: GenerateInputConfig) -> Result<Self, Error> {
let batch_size = config.batch_size.unwrap_or(1);
Ok(Self {
config,
count: AtomicI64::new(0),
batch_size,
})
}
}
#[async_trait]
impl Input for GenerateInput {
async fn connect(&self) -> Result<(), Error> {
Ok(())
}
async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
tokio::time::sleep(self.config.interval).await;
if let Some(count) = self.config.count {
let current_count = self.count.load(Ordering::SeqCst);
if current_count >= count as i64 {
return Err(Error::EOF);
}
if current_count + self.batch_size as i64 > count as i64 {
return Err(Error::EOF);
}
}
let mut msgs = Vec::with_capacity(self.batch_size);
for _ in 0..self.batch_size {
let s = self.config.context.clone();
msgs.push(s.into_bytes())
}
self.count
.fetch_add(self.batch_size as i64, Ordering::SeqCst);
Ok((MessageBatch::new_binary(msgs), Arc::new(NoopAck)))
}
async fn close(&self) -> Result<(), Error> {
Ok(())
}
}
fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
humantime::parse_duration(&s).map_err(serde::de::Error::custom)
}
pub(crate) struct GenerateInputBuilder;
impl InputBuilder for GenerateInputBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
if config.is_none() {
return Err(Error::Config(
"Generate input configuration is missing".to_string(),
));
}
let config: GenerateInputConfig =
serde_json::from_value::<GenerateInputConfig>(config.clone().unwrap())?;
Ok(Arc::new(GenerateInput::new(config)?))
}
}
pub fn init() {
register_input_builder("generate", Arc::new(GenerateInputBuilder));
}
#[cfg(test)]
mod tests {
use crate::input::{generate::GenerateInput, generate::GenerateInputConfig};
use arkflow_core::input::Input;
use arkflow_core::Error;
use std::time::Duration;
#[tokio::test]
async fn test_generate_input_new() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(5),
batch_size: Some(2),
};
let input = GenerateInput::new(config).unwrap();
assert_eq!(input.batch_size, 2);
}
#[tokio::test]
async fn test_generate_input_default_batch_size() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(5),
batch_size: None,
};
let input = GenerateInput::new(config).unwrap();
assert_eq!(input.batch_size, 1); }
#[tokio::test]
async fn test_generate_input_connect() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(5),
batch_size: Some(2),
};
let input = GenerateInput::new(config).unwrap();
assert!(input.connect().await.is_ok()); }
#[tokio::test]
async fn test_generate_input_read() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(5),
batch_size: Some(2),
};
let input = GenerateInput::new(config).unwrap();
let (batch, ack) = input.read().await.unwrap();
let messages = batch.as_binary();
assert_eq!(messages.len(), 2); for msg in messages {
assert_eq!(String::from_utf8(msg.to_vec()).unwrap(), "test message");
}
ack.ack().await;
let (batch, ack) = input.read().await.unwrap();
let messages = batch.as_binary();
assert_eq!(messages.len(), 2);
ack.ack().await;
let result = input.read().await;
assert!(matches!(result, Err(Error::EOF)));
}
#[tokio::test]
async fn test_generate_input_without_count_limit() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: None, batch_size: Some(1),
};
let input = GenerateInput::new(config).unwrap();
for _ in 0..10 {
let result = input.read().await;
assert!(result.is_ok());
let (batch, ack) = result.unwrap();
let messages = batch.as_binary();
assert_eq!(messages.len(), 1);
ack.ack().await;
}
}
#[tokio::test]
async fn test_generate_input_close() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(5),
batch_size: Some(2),
};
let input = GenerateInput::new(config).unwrap();
assert!(input.close().await.is_ok()); }
#[tokio::test]
async fn test_generate_input_exact_count() {
let config = GenerateInputConfig {
context: "test message".to_string(),
interval: Duration::from_millis(10),
count: Some(4),
batch_size: Some(2),
};
let input = GenerateInput::new(config).unwrap();
let result = input.read().await;
assert!(result.is_ok());
let result = input.read().await;
assert!(result.is_ok());
let result = input.read().await;
assert!(matches!(result, Err(Error::EOF)));
}
#[tokio::test]
async fn test_deserialize_duration() {
let json = r#"{
"context": "test message",
"interval": "10ms",
"count": 5,
"batch_size": 2
}"#;
let config: GenerateInputConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.context, "test message");
assert_eq!(config.interval, Duration::from_millis(10));
assert_eq!(config.count, Some(5));
assert_eq!(config.batch_size, Some(2));
}
}