1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use async_trait::async_trait;
8use tokio::sync::{Mutex, mpsc, oneshot};
9use tower::Service;
10
11use camel_component_api::UriConfig;
12use camel_component_api::{BoxProcessor, CamelError, Exchange};
13use camel_component_api::{Component, Consumer, ConsumerContext, Endpoint, ProducerContext};
14
15type DirectSender = mpsc::Sender<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>;
23type DirectRegistry = Arc<Mutex<HashMap<String, DirectSender>>>;
24
/// Configuration parsed from a `direct:<name>` URI.
///
/// URI parsing (`DirectConfig::from_uri`) is generated by the `UriConfig`
/// derive, driven by the `uri_scheme` / `uri_config` attributes below.
#[derive(Debug, Clone, UriConfig)]
#[uri_scheme = "direct"]
#[uri_config(crate = "camel_component_api")]
pub struct DirectConfig {
    /// Endpoint name: the part after `direct:` (e.g. `orders` for `direct:orders`).
    pub name: String,
}
41
42pub struct DirectComponent {
54 registry: DirectRegistry,
55}
56
57impl DirectComponent {
58 pub fn new() -> Self {
59 Self {
60 registry: Arc::new(Mutex::new(HashMap::new())),
61 }
62 }
63}
64
65impl Default for DirectComponent {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl Component for DirectComponent {
72 fn scheme(&self) -> &str {
73 "direct"
74 }
75
76 fn create_endpoint(
77 &self,
78 uri: &str,
79 _ctx: &dyn camel_component_api::ComponentContext,
80 ) -> Result<Box<dyn Endpoint>, CamelError> {
81 let config = DirectConfig::from_uri(uri)?;
82 Ok(Box::new(DirectEndpoint {
83 uri: uri.to_string(),
84 name: config.name,
85 registry: Arc::clone(&self.registry),
86 }))
87 }
88}
89
90struct DirectEndpoint {
95 uri: String,
96 name: String,
97 registry: DirectRegistry,
98}
99
100impl Endpoint for DirectEndpoint {
101 fn uri(&self) -> &str {
102 &self.uri
103 }
104
105 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
106 Ok(Box::new(DirectConsumer {
107 name: self.name.clone(),
108 registry: Arc::clone(&self.registry),
109 }))
110 }
111
112 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
113 Ok(BoxProcessor::new(DirectProducer {
114 name: self.name.clone(),
115 registry: Arc::clone(&self.registry),
116 }))
117 }
118}
119
120struct DirectConsumer {
127 name: String,
128 registry: DirectRegistry,
129}
130
131#[async_trait]
132impl Consumer for DirectConsumer {
133 async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
134 let (tx, mut rx) =
136 mpsc::channel::<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>(32);
137
138 {
140 let mut reg = self.registry.lock().await;
141 reg.insert(self.name.clone(), tx);
142 }
143
144 loop {
146 tokio::select! {
147 _ = context.cancelled() => {
148 tracing::debug!("Direct '{}' received cancellation, stopping", self.name);
149 break;
150 }
151 msg = rx.recv() => {
152 match msg {
153 Some((exchange, reply_tx)) => {
154 let result = context.send_and_wait(exchange).await;
155 let _ = reply_tx.send(result);
156 }
157 None => break,
158 }
159 }
160 }
161 }
162
163 {
165 let mut reg = self.registry.lock().await;
166 reg.remove(&self.name);
167 }
168
169 Ok(())
170 }
171
172 async fn stop(&mut self) -> Result<(), CamelError> {
173 let mut reg = self.registry.lock().await;
175 reg.remove(&self.name);
176 Ok(())
177 }
178}
179
180#[derive(Clone)]
187struct DirectProducer {
188 name: String,
189 registry: DirectRegistry,
190}
191
192impl Service<Exchange> for DirectProducer {
193 type Response = Exchange;
194 type Error = CamelError;
195 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
196
197 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
198 Poll::Ready(Ok(()))
199 }
200
201 fn call(&mut self, exchange: Exchange) -> Self::Future {
202 let name = self.name.clone();
203 let registry = Arc::clone(&self.registry);
204
205 Box::pin(async move {
206 let reg = registry.lock().await;
207 let sender = reg.get(&name).ok_or_else(|| {
208 CamelError::EndpointCreationFailed(format!(
209 "no consumer registered for direct:{name}"
210 ))
211 })?;
212
213 let (reply_tx, reply_rx) = oneshot::channel();
214 sender
215 .send((exchange, reply_tx))
216 .await
217 .map_err(|_| CamelError::ChannelClosed)?;
218
219 drop(reg);
221
222 reply_rx.await.map_err(|_| CamelError::ChannelClosed)?
224 })
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::Message;
    use camel_component_api::NoOpComponentContext;
    use tower::ServiceExt;

    /// Default producer context shared by tests that create producers.
    fn test_producer_ctx() -> ProducerContext {
        ProducerContext::new()
    }

    #[test]
    fn test_direct_component_scheme() {
        let component = DirectComponent::new();
        assert_eq!(component.scheme(), "direct");
    }

    #[test]
    fn test_direct_component_default() {
        let component = DirectComponent::default();
        assert_eq!(component.scheme(), "direct");
    }

    #[test]
    fn test_direct_config_from_uri() {
        let config = DirectConfig::from_uri("direct:orders").unwrap();
        assert_eq!(config.name, "orders");
    }

    #[test]
    fn test_direct_endpoint_uri() {
        let component = DirectComponent::new();
        let endpoint = component
            .create_endpoint("direct:uri-check", &NoOpComponentContext)
            .unwrap();
        // The endpoint reports the exact URI it was created from.
        assert_eq!(endpoint.uri(), "direct:uri-check");
    }

    #[test]
    fn test_direct_creates_endpoint() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo", &NoOpComponentContext);
        assert!(endpoint.is_ok());
    }

    #[test]
    fn test_direct_wrong_scheme() {
        // A non-`direct` URI must be rejected by the component.
        let component = DirectComponent::new();
        let result = component.create_endpoint("timer:tick", &NoOpComponentContext);
        assert!(result.is_err());
    }

    #[test]
    fn test_direct_endpoint_creates_consumer() {
        let component = DirectComponent::new();
        let endpoint = component
            .create_endpoint("direct:foo", &NoOpComponentContext)
            .unwrap();
        assert!(endpoint.create_consumer().is_ok());
    }

    #[test]
    fn test_direct_endpoint_creates_producer() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component
            .create_endpoint("direct:foo", &NoOpComponentContext)
            .unwrap();
        assert!(endpoint.create_producer(&ctx).is_ok());
    }

    #[tokio::test]
    async fn test_direct_producer_no_consumer_registered() {
        // Sending to a name nobody consumes must fail rather than hang.
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component
            .create_endpoint("direct:missing", &NoOpComponentContext)
            .unwrap();
        let producer = endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_direct_producer_consumer_roundtrip() {
        let component = DirectComponent::new();

        let consumer_endpoint = component
            .create_endpoint("direct:test", &NoOpComponentContext)
            .unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        // `route_rx` plays the role of the route the consumer feeds into.
        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, tokio_util::sync::CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        // Give the spawned consumer time to register (timing-based; may be flaky under load).
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Stand-in route: echo each exchange back unchanged.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                let ExchangeEnvelope { exchange, reply_tx } = envelope;
                if let Some(tx) = reply_tx {
                    let _ = tx.send(Ok(exchange));
                }
            }
        });

        let ctx = test_producer_ctx();
        // A separate endpoint for the same name still reaches the consumer,
        // because both endpoints share the component's registry.
        let producer_endpoint = component
            .create_endpoint("direct:test", &NoOpComponentContext)
            .unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("hello direct"));
        let result = producer.oneshot(exchange).await;

        assert!(result.is_ok());
        let reply = result.unwrap();
        assert_eq!(reply.input.body.as_text(), Some("hello direct"));
    }

    #[tokio::test]
    async fn test_direct_propagates_error_when_no_handler() {
        let component = DirectComponent::new();

        let consumer_endpoint = component
            .create_endpoint("direct:err-test", &NoOpComponentContext)
            .unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, tokio_util::sync::CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        // Give the spawned consumer time to register (timing-based).
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Stand-in route that fails every exchange.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                if let Some(tx) = envelope.reply_tx {
                    let _ = tx.send(Err(CamelError::ProcessorError("subroute failed".into())));
                }
            }
        });

        let ctx = test_producer_ctx();
        let producer_endpoint = component
            .create_endpoint("direct:err-test", &NoOpComponentContext)
            .unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        // The route's error variant must reach the producer unchanged.
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CamelError::ProcessorError(_)));
    }

    #[tokio::test]
    async fn test_direct_consumer_stop_unregisters() {
        let component = DirectComponent::new();
        let endpoint = component
            .create_endpoint("direct:cleanup", &NoOpComponentContext)
            .unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();

        let (route_tx, _route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, tokio_util::sync::CancellationToken::new());

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        // Give the spawned consumer time to register (timing-based).
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        {
            let reg = component.registry.lock().await;
            assert!(reg.contains_key("cleanup"));
        }

        // The running consumer was moved into the task, so a second
        // DirectConsumer with the same name and registry issues the stop().
        let mut stop_consumer = DirectConsumer {
            name: "cleanup".to_string(),
            registry: Arc::clone(&component.registry),
        };
        stop_consumer.stop().await.unwrap();

        {
            let reg = component.registry.lock().await;
            assert!(!reg.contains_key("cleanup"));
        }

        handle.abort();
    }

    #[tokio::test]
    async fn test_direct_consumer_respects_cancellation() {
        use tokio_util::sync::CancellationToken;

        let registry: DirectRegistry = Arc::new(Mutex::new(HashMap::new()));
        let token = CancellationToken::new();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, token.clone());

        let mut consumer = DirectConsumer {
            name: "cancel-test".to_string(),
            registry: registry.clone(),
        };

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        // Give the spawned consumer time to register (timing-based).
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(registry.lock().await.contains_key("cancel-test"));

        // Cancelling must make start() return promptly and unregister the name.
        token.cancel();
        let result = tokio::time::timeout(std::time::Duration::from_secs(1), handle).await;
        assert!(
            result.is_ok(),
            "Consumer should have stopped after cancellation"
        );

        assert!(!registry.lock().await.contains_key("cancel-test"));
    }

    #[tokio::test]
    async fn test_direct_consumer_stop_missing_entry_is_ok() {
        // stop() on a consumer that never registered is a no-op, not an error.
        let registry: DirectRegistry = Arc::new(Mutex::new(HashMap::new()));
        let mut consumer = DirectConsumer {
            name: "never-registered".to_string(),
            registry,
        };
        let result = consumer.stop().await;
        assert!(result.is_ok());
    }
}