Skip to main content

camel_component_direct/
lib.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use async_trait::async_trait;
8use tokio::sync::{Mutex, mpsc, oneshot};
9use tower::Service;
10
11use camel_api::{BoxProcessor, CamelError, Exchange};
12use camel_component::{Component, Consumer, ConsumerContext, Endpoint, ProducerContext};
13use camel_endpoint::UriConfig;
14
15// ---------------------------------------------------------------------------
16// Shared state: maps endpoint names to senders that deliver exchanges to the
17// consumer side.  Each entry holds a sender of `(Exchange, oneshot::Sender)`
18// so the producer can wait for the consumer's pipeline to finish processing
19// and receive the (possibly transformed) exchange back.
20// ---------------------------------------------------------------------------
21
22type DirectSender = mpsc::Sender<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>;
23type DirectRegistry = Arc<Mutex<HashMap<String, DirectSender>>>;
24
25// ---------------------------------------------------------------------------
26// DirectConfig
27// ---------------------------------------------------------------------------
28
29/// Configuration for Direct endpoints parsed from URIs.
30///
31/// URI format: `direct:name`
32///
33/// Example: `direct:foo` creates an endpoint named "foo"
34#[derive(Debug, Clone, UriConfig)]
35#[uri_scheme = "direct"]
36pub struct DirectConfig {
37    /// Endpoint name (path portion).
38    pub name: String,
39}
40
41// ---------------------------------------------------------------------------
42// DirectComponent
43// ---------------------------------------------------------------------------
44
45/// The Direct component provides in-memory synchronous communication between
46/// routes.
47///
48/// URI format: `direct:name`
49///
50/// A producer sending to `direct:foo` will block until the consumer on
51/// `direct:foo` has finished processing the exchange.
52pub struct DirectComponent {
53    registry: DirectRegistry,
54}
55
56impl DirectComponent {
57    pub fn new() -> Self {
58        Self {
59            registry: Arc::new(Mutex::new(HashMap::new())),
60        }
61    }
62}
63
64impl Default for DirectComponent {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl Component for DirectComponent {
71    fn scheme(&self) -> &str {
72        "direct"
73    }
74
75    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
76        let config = DirectConfig::from_uri(uri)?;
77        Ok(Box::new(DirectEndpoint {
78            uri: uri.to_string(),
79            name: config.name,
80            registry: Arc::clone(&self.registry),
81        }))
82    }
83}
84
85// ---------------------------------------------------------------------------
86// DirectEndpoint
87// ---------------------------------------------------------------------------
88
89struct DirectEndpoint {
90    uri: String,
91    name: String,
92    registry: DirectRegistry,
93}
94
95impl Endpoint for DirectEndpoint {
96    fn uri(&self) -> &str {
97        &self.uri
98    }
99
100    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
101        Ok(Box::new(DirectConsumer {
102            name: self.name.clone(),
103            registry: Arc::clone(&self.registry),
104        }))
105    }
106
107    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
108        Ok(BoxProcessor::new(DirectProducer {
109            name: self.name.clone(),
110            registry: Arc::clone(&self.registry),
111        }))
112    }
113}
114
115// ---------------------------------------------------------------------------
116// DirectConsumer
117// ---------------------------------------------------------------------------
118
119/// The Direct consumer registers itself in the shared registry and forwards
120/// incoming exchanges to the route pipeline via `ConsumerContext`.
121struct DirectConsumer {
122    name: String,
123    registry: DirectRegistry,
124}
125
126#[async_trait]
127impl Consumer for DirectConsumer {
128    async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
129        // Create a channel for producers to send exchanges to this consumer.
130        let (tx, mut rx) =
131            mpsc::channel::<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>(32);
132
133        // Register ourselves so producers can find us.
134        {
135            let mut reg = self.registry.lock().await;
136            reg.insert(self.name.clone(), tx);
137        }
138
139        // Process incoming exchanges with cooperative cancellation.
140        loop {
141            tokio::select! {
142                _ = context.cancelled() => {
143                    tracing::debug!("Direct '{}' received cancellation, stopping", self.name);
144                    break;
145                }
146                msg = rx.recv() => {
147                    match msg {
148                        Some((exchange, reply_tx)) => {
149                            let result = context.send_and_wait(exchange).await;
150                            let _ = reply_tx.send(result);
151                        }
152                        None => break,
153                    }
154                }
155            }
156        }
157
158        // Cleanup: remove from registry on exit
159        {
160            let mut reg = self.registry.lock().await;
161            reg.remove(&self.name);
162        }
163
164        Ok(())
165    }
166
167    async fn stop(&mut self) -> Result<(), CamelError> {
168        // Remove from registry so no new producers can send to us.
169        let mut reg = self.registry.lock().await;
170        reg.remove(&self.name);
171        Ok(())
172    }
173}
174
175// ---------------------------------------------------------------------------
176// DirectProducer
177// ---------------------------------------------------------------------------
178
179/// The Direct producer sends an exchange to the named direct endpoint and
180/// waits for a reply (synchronous in-memory call).
181#[derive(Clone)]
182struct DirectProducer {
183    name: String,
184    registry: DirectRegistry,
185}
186
187impl Service<Exchange> for DirectProducer {
188    type Response = Exchange;
189    type Error = CamelError;
190    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
191
192    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193        Poll::Ready(Ok(()))
194    }
195
196    fn call(&mut self, exchange: Exchange) -> Self::Future {
197        let name = self.name.clone();
198        let registry = Arc::clone(&self.registry);
199
200        Box::pin(async move {
201            let reg = registry.lock().await;
202            let sender = reg.get(&name).ok_or_else(|| {
203                CamelError::EndpointCreationFailed(format!(
204                    "no consumer registered for direct:{name}"
205                ))
206            })?;
207
208            let (reply_tx, reply_rx) = oneshot::channel();
209            sender
210                .send((exchange, reply_tx))
211                .await
212                .map_err(|_| CamelError::ChannelClosed)?;
213
214            // Drop the lock before awaiting the reply to avoid deadlocks.
215            drop(reg);
216
217            // Propagate Ok or Err from the subroute pipeline.
218            reply_rx.await.map_err(|_| CamelError::ChannelClosed)?
219        })
220    }
221}
222
223// ---------------------------------------------------------------------------
224// Tests
225// ---------------------------------------------------------------------------
226
#[cfg(test)]
mod tests {
    //! Unit tests for the Direct component.
    //!
    //! Tests that need a running consumer wait for it by polling the registry
    //! (see `wait_until_registered`) rather than sleeping for a fixed time.

    use super::*;
    use camel_api::Message;
    use camel_component::ExchangeEnvelope;
    use std::time::{Duration, Instant};
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Upper bound on how long a test waits for a consumer to register.
    const REGISTRATION_TIMEOUT: Duration = Duration::from_secs(1);

    fn test_producer_ctx() -> ProducerContext {
        ProducerContext::new()
    }

    /// Waits until `name` shows up in `registry`.
    ///
    /// Polls with a short pause so the spawned consumer task gets to run, and
    /// panics if the entry has not appeared within `REGISTRATION_TIMEOUT`.
    async fn wait_until_registered(registry: &DirectRegistry, name: &str) {
        let deadline = Instant::now() + REGISTRATION_TIMEOUT;
        // The guard taken in the condition is released before the body runs.
        while !registry.lock().await.contains_key(name) {
            assert!(
                Instant::now() < deadline,
                "consumer for direct:{} never registered",
                name
            );
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

    #[test]
    fn test_direct_component_scheme() {
        let component = DirectComponent::new();
        assert_eq!(component.scheme(), "direct");
    }

    #[test]
    fn test_direct_creates_endpoint() {
        let component = DirectComponent::new();
        assert!(component.create_endpoint("direct:foo").is_ok());
    }

    #[test]
    fn test_direct_wrong_scheme() {
        let component = DirectComponent::new();
        assert!(component.create_endpoint("timer:tick").is_err());
    }

    #[test]
    fn test_direct_endpoint_creates_consumer() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_consumer().is_ok());
    }

    #[test]
    fn test_direct_endpoint_creates_producer() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_producer(&ctx).is_ok());
    }

    #[tokio::test]
    async fn test_direct_producer_no_consumer_registered() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:missing").unwrap();
        let producer = endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_direct_producer_consumer_roundtrip() {
        let component = DirectComponent::new();

        // Create the consumer endpoint and its consumer.
        let consumer_endpoint = component.create_endpoint("direct:test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        // The route channel carries ExchangeEnvelope (request-reply support).
        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        // Run the consumer in a background task and wait until it is reachable.
        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });
        wait_until_registered(&component.registry, "test").await;

        // Pipeline simulator: echoes every exchange back as Ok.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                let ExchangeEnvelope { exchange, reply_tx } = envelope;
                if let Some(tx) = reply_tx {
                    let _ = tx.send(Ok(exchange));
                }
            }
        });

        // Send an exchange through a producer on the same endpoint name.
        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("hello direct"));
        let result = producer.oneshot(exchange).await;

        assert!(result.is_ok());
        let reply = result.unwrap();
        assert_eq!(reply.input.body.as_text(), Some("hello direct"));
    }

    #[tokio::test]
    async fn test_direct_propagates_error_when_no_handler() {
        let component = DirectComponent::new();

        let consumer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });
        wait_until_registered(&component.registry, "err-test").await;

        // Pipeline simulator that replies with Err (simulates no error handler).
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                if let Some(tx) = envelope.reply_tx {
                    let _ = tx.send(Err(CamelError::ProcessorError("subroute failed".into())));
                }
            }
        });

        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(matches!(result, Err(CamelError::ProcessorError(_))));
    }

    #[tokio::test]
    async fn test_direct_consumer_stop_unregisters() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:cleanup").unwrap();

        // A started consumer is needed so that there is something to remove.
        let mut consumer = endpoint.create_consumer().unwrap();

        let (route_tx, _route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });
        wait_until_registered(&component.registry, "cleanup").await;

        // Verify the name is registered.
        assert!(component.registry.lock().await.contains_key("cleanup"));

        // A second consumer with the same name is enough to exercise `stop`,
        // which removes the registry entry by name.
        let mut stop_consumer = DirectConsumer {
            name: "cleanup".to_string(),
            registry: Arc::clone(&component.registry),
        };
        stop_consumer.stop().await.unwrap();

        // Verify the entry is gone.
        assert!(!component.registry.lock().await.contains_key("cleanup"));

        handle.abort();
    }

    #[tokio::test]
    async fn test_direct_consumer_respects_cancellation() {
        let registry: DirectRegistry = Arc::new(Mutex::new(HashMap::new()));
        let token = CancellationToken::new();
        let (tx, _rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(tx, token.clone());

        let mut consumer = DirectConsumer {
            name: "cancel-test".to_string(),
            registry: Arc::clone(&registry),
        };

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&registry, "cancel-test").await;
        assert!(registry.lock().await.contains_key("cancel-test"));

        token.cancel();
        let joined = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("Consumer should have stopped after cancellation");
        // A panic inside the consumer task must fail the test too.
        joined.expect("consumer task should finish without panicking");

        // After cancellation, the consumer should have cleaned up the registry.
        assert!(!registry.lock().await.contains_key("cancel-test"));
    }
}