camel_component_direct/
lib.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use async_trait::async_trait;
8use tokio::sync::{Mutex, mpsc, oneshot};
9use tower::Service;
10
11use camel_api::{BoxProcessor, CamelError, Exchange};
12use camel_component::{Component, Consumer, ConsumerContext, Endpoint, ProducerContext};
13use camel_endpoint::UriConfig;
14
15type DirectSender = mpsc::Sender<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>;
23type DirectRegistry = Arc<Mutex<HashMap<String, DirectSender>>>;
24
25#[derive(Debug, Clone, UriConfig)]
35#[uri_scheme = "direct"]
36pub struct DirectConfig {
37 pub name: String,
39}
40
41pub struct DirectComponent {
53 registry: DirectRegistry,
54}
55
56impl DirectComponent {
57 pub fn new() -> Self {
58 Self {
59 registry: Arc::new(Mutex::new(HashMap::new())),
60 }
61 }
62}
63
64impl Default for DirectComponent {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl Component for DirectComponent {
71 fn scheme(&self) -> &str {
72 "direct"
73 }
74
75 fn create_endpoint(&self, uri: &str) -> Result<Box<dyn Endpoint>, CamelError> {
76 let config = DirectConfig::from_uri(uri)?;
77 Ok(Box::new(DirectEndpoint {
78 uri: uri.to_string(),
79 name: config.name,
80 registry: Arc::clone(&self.registry),
81 }))
82 }
83}
84
85struct DirectEndpoint {
90 uri: String,
91 name: String,
92 registry: DirectRegistry,
93}
94
95impl Endpoint for DirectEndpoint {
96 fn uri(&self) -> &str {
97 &self.uri
98 }
99
100 fn create_consumer(&self) -> Result<Box<dyn Consumer>, CamelError> {
101 Ok(Box::new(DirectConsumer {
102 name: self.name.clone(),
103 registry: Arc::clone(&self.registry),
104 }))
105 }
106
107 fn create_producer(&self, _ctx: &ProducerContext) -> Result<BoxProcessor, CamelError> {
108 Ok(BoxProcessor::new(DirectProducer {
109 name: self.name.clone(),
110 registry: Arc::clone(&self.registry),
111 }))
112 }
113}
114
115struct DirectConsumer {
122 name: String,
123 registry: DirectRegistry,
124}
125
126#[async_trait]
127impl Consumer for DirectConsumer {
128 async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
129 let (tx, mut rx) =
131 mpsc::channel::<(Exchange, oneshot::Sender<Result<Exchange, CamelError>>)>(32);
132
133 {
135 let mut reg = self.registry.lock().await;
136 reg.insert(self.name.clone(), tx);
137 }
138
139 loop {
141 tokio::select! {
142 _ = context.cancelled() => {
143 tracing::debug!("Direct '{}' received cancellation, stopping", self.name);
144 break;
145 }
146 msg = rx.recv() => {
147 match msg {
148 Some((exchange, reply_tx)) => {
149 let result = context.send_and_wait(exchange).await;
150 let _ = reply_tx.send(result);
151 }
152 None => break,
153 }
154 }
155 }
156 }
157
158 {
160 let mut reg = self.registry.lock().await;
161 reg.remove(&self.name);
162 }
163
164 Ok(())
165 }
166
167 async fn stop(&mut self) -> Result<(), CamelError> {
168 let mut reg = self.registry.lock().await;
170 reg.remove(&self.name);
171 Ok(())
172 }
173}
174
175#[derive(Clone)]
182struct DirectProducer {
183 name: String,
184 registry: DirectRegistry,
185}
186
187impl Service<Exchange> for DirectProducer {
188 type Response = Exchange;
189 type Error = CamelError;
190 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
191
192 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193 Poll::Ready(Ok(()))
194 }
195
196 fn call(&mut self, exchange: Exchange) -> Self::Future {
197 let name = self.name.clone();
198 let registry = Arc::clone(&self.registry);
199
200 Box::pin(async move {
201 let reg = registry.lock().await;
202 let sender = reg.get(&name).ok_or_else(|| {
203 CamelError::EndpointCreationFailed(format!(
204 "no consumer registered for direct:{name}"
205 ))
206 })?;
207
208 let (reply_tx, reply_rx) = oneshot::channel();
209 sender
210 .send((exchange, reply_tx))
211 .await
212 .map_err(|_| CamelError::ChannelClosed)?;
213
214 drop(reg);
216
217 reply_rx.await.map_err(|_| CamelError::ChannelClosed)?
219 })
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use camel_component::ExchangeEnvelope;
    use std::time::Duration;
    use tokio_util::sync::CancellationToken;
    use tower::ServiceExt;

    /// Upper bound for waiting on asynchronous consumer state changes.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    fn test_producer_ctx() -> ProducerContext {
        ProducerContext::new()
    }

    /// Waits (at most `WAIT_LIMIT`) until a consumer is registered under `name`.
    ///
    /// Polling is used instead of a fixed sleep, which is slower in the common
    /// case and can still lose the race against the spawned consumer on a
    /// loaded machine.
    async fn wait_until_registered(registry: &DirectRegistry, name: &str) {
        let poll = async {
            // The lock guard is a temporary of the loop condition, so it is
            // released before sleeping.
            while !registry.lock().await.contains_key(name) {
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        };
        assert!(
            tokio::time::timeout(WAIT_LIMIT, poll).await.is_ok(),
            "direct consumer '{name}' was never registered"
        );
    }

    #[test]
    fn test_direct_component_scheme() {
        let component = DirectComponent::new();
        assert_eq!(component.scheme(), "direct");
    }

    #[test]
    fn test_direct_creates_endpoint() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo");
        assert!(endpoint.is_ok());
    }

    #[test]
    fn test_direct_wrong_scheme() {
        let component = DirectComponent::new();
        let result = component.create_endpoint("timer:tick");
        assert!(result.is_err());
    }

    #[test]
    fn test_direct_endpoint_creates_consumer() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_consumer().is_ok());
    }

    #[test]
    fn test_direct_endpoint_creates_producer() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:foo").unwrap();
        assert!(endpoint.create_producer(&ctx).is_ok());
    }

    #[tokio::test]
    async fn test_direct_producer_no_consumer_registered() {
        let ctx = test_producer_ctx();
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:missing").unwrap();
        let producer = endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_direct_producer_consumer_roundtrip() {
        let component = DirectComponent::new();

        let consumer_endpoint = component.create_endpoint("direct:test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });
        wait_until_registered(&component.registry, "test").await;

        // Stand-in route: echo every exchange back unchanged.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                let ExchangeEnvelope { exchange, reply_tx } = envelope;
                if let Some(tx) = reply_tx {
                    let _ = tx.send(Ok(exchange));
                }
            }
        });

        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("hello direct"));
        let result = producer.oneshot(exchange).await;

        assert!(result.is_ok());
        let reply = result.unwrap();
        assert_eq!(reply.input.body.as_text(), Some("hello direct"));
    }

    #[tokio::test]
    async fn test_direct_propagates_error_when_no_handler() {
        let component = DirectComponent::new();

        let consumer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let mut consumer = consumer_endpoint.create_consumer().unwrap();

        let (route_tx, mut route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });
        wait_until_registered(&component.registry, "err-test").await;

        // Stand-in route: fail every exchange.
        tokio::spawn(async move {
            while let Some(envelope) = route_rx.recv().await {
                if let Some(tx) = envelope.reply_tx {
                    let _ = tx.send(Err(CamelError::ProcessorError("subroute failed".into())));
                }
            }
        });

        let ctx = test_producer_ctx();
        let producer_endpoint = component.create_endpoint("direct:err-test").unwrap();
        let producer = producer_endpoint.create_producer(&ctx).unwrap();

        let exchange = Exchange::new(Message::new("test"));
        let result = producer.oneshot(exchange).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CamelError::ProcessorError(_)));
    }

    #[tokio::test]
    async fn test_direct_consumer_stop_unregisters() {
        let component = DirectComponent::new();
        let endpoint = component.create_endpoint("direct:cleanup").unwrap();

        let mut consumer = endpoint.create_consumer().unwrap();

        let (route_tx, _route_rx) = mpsc::channel::<ExchangeEnvelope>(16);
        let ctx = ConsumerContext::new(route_tx, CancellationToken::new());

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&component.registry, "cleanup").await;

        // The running consumer is owned by the spawned task, so stop through a
        // second consumer bound to the same name and registry.
        let mut stop_consumer = DirectConsumer {
            name: "cleanup".to_string(),
            registry: Arc::clone(&component.registry),
        };
        stop_consumer.stop().await.unwrap();

        assert!(!component.registry.lock().await.contains_key("cleanup"));

        handle.abort();
    }

    #[tokio::test]
    async fn test_direct_consumer_respects_cancellation() {
        let registry: DirectRegistry = Arc::new(Mutex::new(HashMap::new()));
        let token = CancellationToken::new();
        let (tx, _rx) = mpsc::channel(16);
        let ctx = ConsumerContext::new(tx, token.clone());

        let mut consumer = DirectConsumer {
            name: "cancel-test".to_string(),
            registry: Arc::clone(&registry),
        };

        let handle = tokio::spawn(async move {
            consumer.start(ctx).await.unwrap();
        });

        wait_until_registered(&registry, "cancel-test").await;

        token.cancel();
        let joined = tokio::time::timeout(WAIT_LIMIT, handle).await;
        // Ok(Ok(())): finished in time, and the task did not panic.
        assert!(
            matches!(joined, Ok(Ok(()))),
            "Consumer should have stopped cleanly after cancellation"
        );

        assert!(!registry.lock().await.contains_key("cancel-test"));
    }
}
434}