Skip to main content

camel_component_direct/
lib.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use async_trait::async_trait;
8use tokio::sync::{Mutex, mpsc, oneshot};
9use tower::Service;
10
11use camel_api::{BoxProcessor, CamelError, Exchange};
12use camel_component::{Component, Consumer, ConsumerContext, Endpoint, ProducerContext};
13use camel_endpoint::parse_uri;
14
15// ---------------------------------------------------------------------------
16// Shared state: maps endpoint names to senders that deliver exchanges to the
17// consumer side.  Each entry holds a sender of `(Exchange, oneshot::Sender)`
18// so the producer can wait for the consumer's pipeline to finish processing
19// and receive the (possibly transformed) exchange back.
20// ---------------------------------------------------------------------------
21
22type DirectSender = mpsc::Sender<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>;
23type DirectRegistry = Arc<Mutex<HashMap<String, DirectSender>>>;
24
25// ---------------------------------------------------------------------------
26// DirectComponent
27// ---------------------------------------------------------------------------
28
29/// The Direct component provides in-memory synchronous communication between
30/// routes.
31///
32/// URI format: `direct:name`
33///
34/// A producer sending to `direct:foo` will block until the consumer on
35/// `direct:foo` has finished processing the exchange.
36pub struct DirectComponent {
37    registry: DirectRegistry,
38}
39
40impl DirectComponent {
41    pub fn new() -> Self {
42        Self {
43            registry: Arc::new(Mutex::new(HashMap::new())),
44        }
45    }
46}
47
48impl Default for DirectComponent {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl Component for DirectComponent {
55    fn scheme(&self) -> &str {
56        "direct"
57    }
58
59    fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
60        let parts = parse_uri(uri)?;
61        if parts.scheme != "direct" {
62            return Err(CamelError::InvalidUri(format!(
63                "expected scheme 'direct', got '{}'",
64                parts.scheme
65            )));
66        }
67
68        Ok(Box::new(DirectEndpoint {
69            uri: uri.to_string(),
70            name: parts.path,
71            registry: Arc::clone(&self.registry),
72        }))
73    }
74}
75
76// ---------------------------------------------------------------------------
77// DirectEndpoint
78// ---------------------------------------------------------------------------
79
80struct DirectEndpoint {
81    uri: String,
82    name: String,
83    registry: DirectRegistry,
84}
85
86impl Endpoint for DirectEndpoint {
87    fn uri(&self) -> &str {
88        &self.uri
89    }
90
91    fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
92        Ok(Box::new(DirectConsumer {
93            name: self.name.clone(),
94            registry: Arc::clone(&self.registry),
95        }))
96    }
97
98    fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
99        Ok(BoxProcessor::new(DirectProducer {
100            name: self.name.clone(),
101            registry: Arc::clone(&self.registry),
102        }))
103    }
104}
105
106// ---------------------------------------------------------------------------
107// DirectConsumer
108// ---------------------------------------------------------------------------
109
110/// The Direct consumer registers itself in the shared registry and forwards
111/// incoming exchanges to the route pipeline via `ConsumerContext`.
112struct DirectConsumer {
113    name: String,
114    registry: DirectRegistry,
115}
116
117#[async_trait]
118impl Consumer for DirectConsumer {
119    async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
120        // Create a channel for producers to send exchanges to this consumer.
121        let (tx, mut rx) =
122            mpsc::channel::<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>(32);
123
124        // Register ourselves so producers can find us.
125        {
126            let mut reg = self.registry.lock().await;
127            reg.insert(self.name.clone(), tx);
128        }
129
130        // Process incoming exchanges with cooperative cancellation.
131        loop {
132            tokio::select! {
133                _ = context.cancelled() => {
134                    tracing::debug!("Direct '{}' received cancellation, stopping", self.name);
135                    break;
136                }
137                msg = rx.recv() => {
138                    match msg {
139                        Some((exchange, reply_tx)) => {
140                            let result = context.send_and_wait(exchange).await;
141                            let _ = reply_tx.send(result);
142                        }
143                        None => break,
144                    }
145                }
146            }
147        }
148
149        // Cleanup: remove from registry on exit
150        {
151            let mut reg = self.registry.lock().await;
152            reg.remove(&self.name);
153        }
154
155        Ok(())
156    }
157
158    async fn stop(&mut self) -> Result<(), CamelError> {
159        // Remove from registry so no new producers can send to us.
160        let mut reg = self.registry.lock().await;
161        reg.remove(&self.name);
162        Ok(())
163    }
164}
165
166// ---------------------------------------------------------------------------
167// DirectProducer
168// ---------------------------------------------------------------------------
169
170/// The Direct producer sends an exchange to the named direct endpoint and
171/// waits for a reply (synchronous in-memory call).
172#[derive(Clone)]
173struct DirectProducer {
174    name: String,
175    registry: DirectRegistry,
176}
177
178impl Service<Exchange> for DirectProducer {
179    type Response = Exchange;
180    type Error = CamelError;
181    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
182
183    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        Poll::Ready(Ok(()))
185    }
186
187    fn call(&mut self, exchange: Exchange) -> Self::Future {
188        let name = self.name.clone();
189        let registry = Arc::clone(&self.registry);
190
191        Box::pin(async move {
192            let reg = registry.lock().await;
193            let sender = reg.get(&name).ok_or_else(|| {
194                CamelError::EndpointCreationFailed(format!(
195                    "no consumer registered for direct:{name}"
196                ))
197            })?;
198
199            let (reply_tx, reply_rx) = oneshot::channel();
200            sender
201                .send((exchange, reply_tx))
202                .await
203                .map_err(|_| CamelError::ChannelClosed)?;
204
205            // Drop the lock before awaiting the reply to avoid deadlocks.
206            drop(reg);
207
208            // Propagate Ok or Err from the subroute pipeline.
209            reply_rx.await.map_err(|_| CamelError::ChannelClosed)?
210        })
211    }
212}
213
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use camel_component::ExchangeEnvelope;
    use std::time::Duration;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Upper bound on how long a test waits for a spawned consumer to register.
    const REGISTRATION_TIMEOUT: Duration = Duration::from_secs(2);

    // NullRouteController for testing
    struct NullRouteController;
    #[async_trait::async_trait]
    impl camel_api::RouteController for NullRouteController {
        async fn start_route(&mut self, _: &str) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        async fn stop_route(&mut self, _: &str) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        async fn restart_route(&mut self, _: &str) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        async fn suspend_route(&mut self, _: &str) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        async fn resume_route(&mut self, _: &str) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        fn route_status(&self, _: &str) -> Option<camel_api::RouteStatus> {
            None
        }
        async fn start_all_routes(&mut self) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
        async fn stop_all_routes(&mut self) -> Result<(), camel_api::CamelError> {
            Ok(())
        }
    }

    fn test_producer_ctx() -> ProducerContext {
        ProducerContext::new(Arc::new(Mutex::new(NullRouteController)))
    }

    /// Waits until a consumer has registered `name` in `registry`.
    ///
    /// Polling, rather than sleeping a fixed interval, keeps the tests
    /// deterministic when the spawned consumer task is scheduled late.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not registered within `REGISTRATION_TIMEOUT`.
    async fn wait_until_registered(registry: &DirectRegistry, name: &str) {
        let deadline = tokio::time::Instant::now() + REGISTRATION_TIMEOUT;
        // The lock guard is a temporary of the loop condition, so it is
        // released before each sleep.
        while !registry.lock().await.contains_key(name) {
            assert!(
                tokio::time::Instant::now() < deadline,
                "no consumer registered for direct:{} within {:?}",
                name,
                REGISTRATION_TIMEOUT
            );
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
    }

    #[test]
    fn test_direct_component_scheme() {
        let component = DirectComponent::new();
        assert_eq!(component.scheme(), "direct");
    }

    #[test]
    fn test_direct_creates_endpoint() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo");
        assert!(endpoint.is_ok());
    }

    #[test]
    fn test_direct_wrong_scheme() {
        let component = DirectComponent::new();
        let result = component.create_endpoint("timer:tick");
        assert!(result.is_err());
    }

    #[test]
    fn test_direct_endpoint_creates_consumer() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_consumer().is_ok());
    }

    #[test]
    fn test_direct_endpoint_creates_producer() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_producer(&ctx).is_ok());
    }

    #[tokio::test]
    async fn test_direct_producer_no_consumer_registered() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:missing").unwrap();
        let producer = endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_direct_producer_consumer_roundtrip() {
        let component = DirectComponent::new();

        // Create consumer endpoint and start it
        let consumer_endpoint = component.create_endpoint("direct:test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        // The route channel now carries ExchangeEnvelope (request-reply support).
        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        // Start the consumer in a background task
        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&component.registry, "test").await;

        // Spawn a pipeline simulator that reads envelopes and replies Ok.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                let ExchangeEnvelope { exchange, reply_tx } = envelope;
                if let Some(tx) = reply_tx {
                    let _ = tx.send(Ok(exchange));
                }
            }
        });

        // Now send an exchange via the producer
        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("hello direct"));
        let result = producer.oneshot(exchange).await;

        assert!(result.is_ok());
        let reply = result.unwrap();
        assert_eq!(reply.input.body.as_text(), Some("hello direct"));
    }

    #[tokio::test]
    async fn test_direct_propagates_error_when_no_handler() {
        let component = DirectComponent::new();

        let consumer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&component.registry, "err-test").await;

        // Pipeline simulator that replies with Err (simulates no error handler).
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                if let Some(tx) = envelope.reply_tx {
                    let _ = tx.send(Err(CamelError::ProcessorError("subroute failed".into())));
                }
            }
        });

        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CamelError::ProcessorError(_)));
    }

    #[tokio::test]
    async fn test_direct_consumer_stop_unregisters() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:cleanup").unwrap();

        // We need a consumer to register
        let mut consumer = endpoint.create_consumer().unwrap();

        let (route_tx, _route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        // Start consumer in background
        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        // Verify the name is registered
        wait_until_registered(&component.registry, "cleanup").await;

        // Create a new consumer just to call stop (stop removes from registry)
        let mut stop_consumer = DirectConsumer {
            name: "cleanup".to_string(),
            registry: Arc::clone(&component.registry),
        };
        stop_consumer.stop().await.unwrap();

        // Verify removed from registry
        {
            let reg = component.registry.lock().await;
            assert!(!reg.contains_key("cleanup"));
        }

        // The registry held the running consumer's only sender. Removing it
        // closes the consumer's inbox, so the consumer returns on its own.
        let joined = tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("consumer should exit once it is unregistered");
        joined.expect("consumer task panicked");
    }

    #[tokio::test]
    async fn test_direct_consumer_respects_cancellation() {
        let registry: DirectRegistry = Arc::new(Mutex::new(HashMap::new()));
        let token = CancellationToken::new();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, token.clone());

        let mut consumer = DirectConsumer {
            name: "cancel-test".to_string(),
            registry: registry.clone(),
        };

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&registry, "cancel-test").await;

        token.cancel();
        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(
            result.is_ok(),
            "Consumer should have stopped after cancellation"
        );

        // After cancellation, the consumer should have cleaned up the registry
        assert!(!registry.lock().await.contains_key("cancel-test"));
    }
}