arkflow_plugin/input/
generate.rs

1use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
2use arkflow_core::{Error, MessageBatch};
3use async_trait::async_trait;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::sync::atomic::{AtomicI64, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct GenerateInputConfig {
11    context: String,
12    #[serde(deserialize_with = "deserialize_duration")]
13    interval: Duration,
14    count: Option<usize>,
15    batch_size: Option<usize>,
16}
17
18pub struct GenerateInput {
19    config: GenerateInputConfig,
20    count: AtomicI64,
21    batch_size: usize,
22}
23impl GenerateInput {
24    pub fn new(config: GenerateInputConfig) -> Result<Self, Error> {
25        let batch_size = config.batch_size.unwrap_or(1);
26
27        Ok(Self {
28            config,
29            count: AtomicI64::new(0),
30            batch_size,
31        })
32    }
33}
34
35#[async_trait]
36impl Input for GenerateInput {
37    async fn connect(&self) -> Result<(), Error> {
38        Ok(())
39    }
40
41    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
42        tokio::time::sleep(self.config.interval).await;
43
44        if let Some(count) = self.config.count {
45            let current_count = self.count.load(Ordering::SeqCst);
46            if current_count >= count as i64 {
47                return Err(Error::EOF);
48            }
49            // Check if adding the current batch would exceed the total count limit
50            if current_count + self.batch_size as i64 > count as i64 {
51                return Err(Error::EOF);
52            }
53        }
54        let mut msgs = Vec::with_capacity(self.batch_size);
55        for _ in 0..self.batch_size {
56            let s = self.config.context.clone();
57            msgs.push(s.into_bytes())
58        }
59
60        self.count
61            .fetch_add(self.batch_size as i64, Ordering::SeqCst);
62
63        Ok((MessageBatch::new_binary(msgs), Arc::new(NoopAck)))
64    }
65    async fn close(&self) -> Result<(), Error> {
66        Ok(())
67    }
68}
69
70fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
71where
72    D: Deserializer<'de>,
73{
74    let s: String = Deserialize::deserialize(deserializer)?;
75    humantime::parse_duration(&s).map_err(serde::de::Error::custom)
76}
77
78pub(crate) struct GenerateInputBuilder;
79impl InputBuilder for GenerateInputBuilder {
80    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
81        if config.is_none() {
82            return Err(Error::Config(
83                "Generate input configuration is missing".to_string(),
84            ));
85        }
86        let config: GenerateInputConfig =
87            serde_json::from_value::<GenerateInputConfig>(config.clone().unwrap())?;
88        Ok(Arc::new(GenerateInput::new(config)?))
89    }
90}
91
92pub fn init() {
93    register_input_builder("generate", Arc::new(GenerateInputBuilder));
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with a fixed payload and a short interval so the tests run quickly.
    fn config(count: Option<usize>, batch_size: Option<usize>) -> GenerateInputConfig {
        GenerateInputConfig {
            context: "test message".to_string(),
            interval: Duration::from_millis(10),
            count,
            batch_size,
        }
    }

    #[test]
    fn test_generate_input_new() {
        let input = GenerateInput::new(config(Some(5), Some(2))).unwrap();
        assert_eq!(input.batch_size, 2);
    }

    #[test]
    fn test_generate_input_default_batch_size() {
        let input = GenerateInput::new(config(Some(5), None)).unwrap();
        assert_eq!(input.batch_size, 1); // An unset batch size defaults to 1.
    }

    #[tokio::test]
    async fn test_generate_input_connect() {
        let input = GenerateInput::new(config(Some(5), Some(2))).unwrap();
        assert!(input.connect().await.is_ok());
    }

    #[tokio::test]
    async fn test_generate_input_read() {
        let input = GenerateInput::new(config(Some(5), Some(2))).unwrap();

        // First batch: two copies of the payload.
        let (batch, ack) = input.read().await.unwrap();
        let messages = batch.as_binary();
        assert_eq!(messages.len(), 2);
        for msg in messages {
            assert_eq!(String::from_utf8(msg.to_vec()).unwrap(), "test message");
        }
        ack.ack().await;

        // Second batch brings the total to 4.
        let (batch, ack) = input.read().await.unwrap();
        let messages = batch.as_binary();
        assert_eq!(messages.len(), 2);
        ack.ack().await;

        // A third full batch would bring the total to 6 > 5. No partial batch is
        // emitted, so EOF is returned instead.
        let result = input.read().await;
        assert!(matches!(result, Err(Error::EOF)));
    }

    #[tokio::test]
    async fn test_generate_input_without_count_limit() {
        let input = GenerateInput::new(config(None, Some(1))).unwrap();

        // Without a count limit, reads keep succeeding.
        for _ in 0..10 {
            let result = input.read().await;
            assert!(result.is_ok());
            let (batch, ack) = result.unwrap();
            let messages = batch.as_binary();
            assert_eq!(messages.len(), 1);
            ack.ack().await;
        }
    }

    #[tokio::test]
    async fn test_generate_input_close() {
        let input = GenerateInput::new(config(Some(5), Some(2))).unwrap();
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_generate_input_exact_count() {
        let input = GenerateInput::new(config(Some(4), Some(2))).unwrap();

        // Two batches of two exactly exhaust a count of 4.
        assert!(input.read().await.is_ok());
        assert!(input.read().await.is_ok());

        // The limit is reached, so the next read reports EOF.
        let result = input.read().await;
        assert!(matches!(result, Err(Error::EOF)));
    }

    #[test]
    fn test_deserialize_duration() {
        let json = r#"{
            "context": "test message",
            "interval": "10ms",
            "count": 5,
            "batch_size": 2
        }"#;

        let config: GenerateInputConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.context, "test message");
        assert_eq!(config.interval, Duration::from_millis(10));
        assert_eq!(config.count, Some(5));
        assert_eq!(config.batch_size, Some(2));
    }

    #[test]
    fn test_deserialize_duration_rejects_invalid_interval() {
        let json = r#"{ "context": "test message", "interval": "not a duration" }"#;
        assert!(serde_json::from_str::<GenerateInputConfig>(json).is_err());
    }

    #[test]
    fn test_builder_rejects_missing_config() {
        let result = GenerateInputBuilder.build(&None);
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[test]
    fn test_builder_accepts_valid_config() {
        let value = serde_json::json!({ "context": "test message", "interval": "10ms" });
        assert!(GenerateInputBuilder.build(&Some(value)).is_ok());
    }
}