arkflow_plugin/input/
generate.rs1use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
2use arkflow_core::{Error, MessageBatch};
3use async_trait::async_trait;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::sync::atomic::{AtomicI64, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct GenerateInputConfig {
11 context: String,
12 #[serde(deserialize_with = "deserialize_duration")]
13 interval: Duration,
14 count: Option<usize>,
15 batch_size: Option<usize>,
16}
17
18pub struct GenerateInput {
19 config: GenerateInputConfig,
20 count: AtomicI64,
21 batch_size: usize,
22}
23impl GenerateInput {
24 pub fn new(config: GenerateInputConfig) -> Result<Self, Error> {
25 let batch_size = config.batch_size.unwrap_or(1);
26
27 Ok(Self {
28 config,
29 count: AtomicI64::new(0),
30 batch_size,
31 })
32 }
33}
34
35#[async_trait]
36impl Input for GenerateInput {
37 async fn connect(&self) -> Result<(), Error> {
38 Ok(())
39 }
40
41 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
42 tokio::time::sleep(self.config.interval).await;
43
44 if let Some(count) = self.config.count {
45 let current_count = self.count.load(Ordering::SeqCst);
46 if current_count >= count as i64 {
47 return Err(Error::EOF);
48 }
49 if current_count + self.batch_size as i64 > count as i64 {
51 return Err(Error::EOF);
52 }
53 }
54 let mut msgs = Vec::with_capacity(self.batch_size);
55 for _ in 0..self.batch_size {
56 let s = self.config.context.clone();
57 msgs.push(s.into_bytes())
58 }
59
60 self.count
61 .fetch_add(self.batch_size as i64, Ordering::SeqCst);
62
63 Ok((MessageBatch::new_binary(msgs), Arc::new(NoopAck)))
64 }
65 async fn close(&self) -> Result<(), Error> {
66 Ok(())
67 }
68}
69
70fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
71where
72 D: Deserializer<'de>,
73{
74 let s: String = Deserialize::deserialize(deserializer)?;
75 humantime::parse_duration(&s).map_err(serde::de::Error::custom)
76}
77
78pub(crate) struct GenerateInputBuilder;
79impl InputBuilder for GenerateInputBuilder {
80 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
81 if config.is_none() {
82 return Err(Error::Config(
83 "Generate input configuration is missing".to_string(),
84 ));
85 }
86 let config: GenerateInputConfig =
87 serde_json::from_value::<GenerateInputConfig>(config.clone().unwrap())?;
88 Ok(Arc::new(GenerateInput::new(config)?))
89 }
90}
91
92pub fn init() {
93 register_input_builder("generate", Arc::new(GenerateInputBuilder));
94}
95
#[cfg(test)]
mod tests {
    use super::{GenerateInput, GenerateInputConfig};
    use arkflow_core::input::Input;
    use arkflow_core::Error;
    use std::time::Duration;

    /// Payload emitted by every generator built in these tests.
    const PAYLOAD: &str = "test message";

    /// Builds a generator with the shared payload and a 10 ms interval.
    fn make_input(count: Option<usize>, batch_size: Option<usize>) -> GenerateInput {
        let config = GenerateInputConfig {
            context: PAYLOAD.to_string(),
            interval: Duration::from_millis(10),
            count,
            batch_size,
        };
        GenerateInput::new(config).unwrap()
    }

    #[tokio::test]
    async fn test_generate_input_new() {
        let input = make_input(Some(5), Some(2));
        assert_eq!(input.batch_size, 2);
    }

    #[tokio::test]
    async fn test_generate_input_default_batch_size() {
        // An omitted batch size falls back to one message per batch.
        let input = make_input(Some(5), None);
        assert_eq!(input.batch_size, 1);
    }

    #[tokio::test]
    async fn test_generate_input_connect() {
        let input = make_input(Some(5), Some(2));
        assert!(input.connect().await.is_ok());
    }

    #[tokio::test]
    async fn test_generate_input_read() {
        let input = make_input(Some(5), Some(2));

        // First batch: two copies of the payload.
        let (batch, ack) = input.read().await.unwrap();
        let messages = batch.as_binary();
        assert_eq!(messages.len(), 2);
        for msg in messages {
            assert_eq!(String::from_utf8(msg.to_vec()).unwrap(), PAYLOAD);
        }
        ack.ack().await;

        // Second batch brings the total to four messages.
        let (batch, ack) = input.read().await.unwrap();
        let messages = batch.as_binary();
        assert_eq!(messages.len(), 2);
        ack.ack().await;

        // A third full batch would exceed the limit of five.
        let result = input.read().await;
        assert!(matches!(result, Err(Error::EOF)));
    }

    #[tokio::test]
    async fn test_generate_input_without_count_limit() {
        let input = make_input(None, Some(1));

        // Without a limit every read keeps yielding a single-message batch.
        for _ in 0..10 {
            let result = input.read().await;
            assert!(result.is_ok());
            let (batch, ack) = result.unwrap();
            let messages = batch.as_binary();
            assert_eq!(messages.len(), 1);
            ack.ack().await;
        }
    }

    #[tokio::test]
    async fn test_generate_input_close() {
        let input = make_input(Some(5), Some(2));
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_generate_input_exact_count() {
        let input = make_input(Some(4), Some(2));

        // Two batches of two exactly consume the limit of four.
        let result = input.read().await;
        assert!(result.is_ok());

        let result = input.read().await;
        assert!(result.is_ok());

        let result = input.read().await;
        assert!(matches!(result, Err(Error::EOF)));
    }

    #[tokio::test]
    async fn test_deserialize_duration() {
        let json = r#"{
            "context": "test message",
            "interval": "10ms",
            "count": 5,
            "batch_size": 2
        }"#;

        let config: GenerateInputConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.context, "test message");
        assert_eq!(config.interval, Duration::from_millis(10));
        assert_eq!(config.count, Some(5));
        assert_eq!(config.batch_size, Some(2));
    }
}