use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use camel_api::security_policy::SecurityPolicy;
use camel_api::{CamelError, Exchange};
use camel_auth::{CredentialSource, TokenAuthenticator};
/// An [`Exchange`] bundled with an optional one-shot reply channel.
///
/// [`ConsumerContext::send`] builds envelopes with `reply_tx` set to `None`,
/// while [`ConsumerContext::send_and_wait`] sets it to `Some` and awaits the
/// paired receiver.
pub struct ExchangeEnvelope {
/// The exchange being handed over through the channel.
pub exchange: Exchange,
/// Sender half of the reply channel; `None` when no reply is awaited.
// NOTE(review): presumably the receiving side sends the processed exchange
// (or the error) back on this sender — confirm against the channel receiver.
pub reply_tx: Option<oneshot::Sender<Result<Exchange, CamelError>>>,
}
/// Handle given to a [`Consumer`] when it starts.
///
/// It pairs the channel used to submit [`ExchangeEnvelope`]s with a
/// [`CancellationToken`]. The type is `Clone`.
#[derive(Clone)]
pub struct ConsumerContext {
// Sending half of the envelope channel.
sender: mpsc::Sender<ExchangeEnvelope>,
// Token observed through `cancelled()` / `is_cancelled()`.
cancel_token: CancellationToken,
}
impl ConsumerContext {
pub fn new(sender: mpsc::Sender<ExchangeEnvelope>, cancel_token: CancellationToken) -> Self {
Self {
sender,
cancel_token,
}
}
pub async fn cancelled(&self) {
self.cancel_token.cancelled().await
}
pub fn is_cancelled(&self) -> bool {
self.cancel_token.is_cancelled()
}
pub fn cancel_token(&self) -> CancellationToken {
self.cancel_token.clone()
}
pub fn sender(&self) -> mpsc::Sender<ExchangeEnvelope> {
self.sender.clone()
}
pub async fn send(&self, exchange: Exchange) -> Result<(), CamelError> {
self.sender
.send(ExchangeEnvelope {
exchange,
reply_tx: None,
})
.await
.map_err(|_| CamelError::ChannelClosed)
}
pub async fn send_and_wait(&self, exchange: Exchange) -> Result<Exchange, CamelError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.sender
.send(ExchangeEnvelope {
exchange,
reply_tx: Some(reply_tx),
})
.await
.map_err(|_| CamelError::ChannelClosed)?;
reply_rx.await.map_err(|_| CamelError::ChannelClosed)?
}
}
/// Security configuration that can be handed to a [`Consumer`] via
/// [`Consumer::set_security_context`].
///
/// The policy and authenticator are shared trait objects behind `Arc`.
pub struct SecurityContext {
/// Shared security policy trait object.
pub policy: Arc<dyn SecurityPolicy>,
/// Shared token authenticator trait object.
pub authenticator: Arc<dyn TokenAuthenticator>,
/// Credential sources; the constructors default this to
/// `[CredentialSource::AuthorizationHeader]`.
// NOTE(review): presumably the places a credential is looked up, in order —
// confirm with the code that consumes this list.
pub credential_sources: Vec<CredentialSource>,
}
impl SecurityContext {
    /// Creates a context from an owned policy value, wrapping it in an `Arc`.
    ///
    /// `credential_sources` starts as `[CredentialSource::AuthorizationHeader]`.
    pub fn new(
        policy: impl SecurityPolicy + 'static,
        authenticator: Arc<dyn TokenAuthenticator>,
    ) -> Self {
        Self::from_arc(Arc::new(policy), authenticator)
    }

    /// Creates a context from an already shared policy and authenticator.
    ///
    /// `credential_sources` starts as `[CredentialSource::AuthorizationHeader]`.
    pub fn from_arc(
        policy: Arc<dyn SecurityPolicy>,
        authenticator: Arc<dyn TokenAuthenticator>,
    ) -> Self {
        let credential_sources = vec![CredentialSource::AuthorizationHeader];
        Self {
            policy,
            authenticator,
            credential_sources,
        }
    }

    /// Replaces the credential source list and returns the updated context.
    pub fn with_credential_sources(self, sources: Vec<CredentialSource>) -> Self {
        Self {
            credential_sources: sources,
            ..self
        }
    }
}
/// Cloning shares the policy and authenticator (the `Arc`s are cloned) and
/// copies the credential source list.
impl Clone for SecurityContext {
    fn clone(&self) -> Self {
        let Self {
            policy,
            authenticator,
            credential_sources,
        } = self;
        Self {
            policy: Arc::clone(policy),
            authenticator: Arc::clone(authenticator),
            credential_sources: credential_sources.clone(),
        }
    }
}
/// Debug output prints fixed placeholder strings for the two trait-object
/// fields and the real value of `credential_sources`.
impl std::fmt::Debug for SecurityContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SecurityContext");
        builder.field("policy", &"<SecurityPolicy>");
        builder.field("authenticator", &"<TokenAuthenticator>");
        builder.field("credential_sources", &self.credential_sources);
        builder.finish()
    }
}
/// Concurrency model reported by [`Consumer::concurrency_model`].
///
/// The trait's default implementation returns [`ConcurrencyModel::Sequential`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConcurrencyModel {
/// Sequential processing.
Sequential,
/// Concurrent processing with an optional `max` value.
// NOTE(review): presumably `max` caps in-flight exchanges and `None` means
// no cap — confirm against the code that consumes this value.
Concurrent { max: Option<usize> },
}
/// Lifecycle interface of a consumer.
///
/// Implementors must provide `start` and `stop`; every other method has a
/// default implementation.
#[async_trait]
pub trait Consumer: Send + Sync {
/// Starts the consumer with the given [`ConsumerContext`].
async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError>;
/// Stops the consumer.
async fn stop(&mut self) -> Result<(), CamelError>;
/// Suspends the consumer. The default implementation does nothing and returns `Ok(())`.
async fn suspend(&self) -> Result<(), CamelError> {
Ok(())
}
/// Resumes the consumer. The default implementation does nothing and returns `Ok(())`.
async fn resume(&self) -> Result<(), CamelError> {
Ok(())
}
/// Reports the consumer's concurrency model. Defaults to
/// [`ConcurrencyModel::Sequential`].
fn concurrency_model(&self) -> ConcurrencyModel {
ConcurrencyModel::Sequential
}
/// Returns a join handle for a background task, if any. Defaults to `None`.
// NOTE(review): taking `&mut self` and returning an owned handle suggests the
// handle is moved out of the consumer — confirm with implementors.
fn background_task_handle(&mut self) -> Option<JoinHandle<Result<(), CamelError>>> {
None
}
/// Supplies a [`SecurityContext`] to the consumer. The default implementation
/// ignores it.
fn set_security_context(&mut self, _ctx: SecurityContext) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::security_policy::{AuthorizationDecision, Principal};

    // ---- shared fixtures -------------------------------------------------

    /// Consumer that implements only the required methods, so every default
    /// method of the trait is exercised as-is.
    struct DummyConsumer;

    #[async_trait::async_trait]
    impl Consumer for DummyConsumer {
        async fn start(&mut self, _ctx: ConsumerContext) -> Result<(), CamelError> {
            Ok(())
        }

        async fn stop(&mut self) -> Result<(), CamelError> {
            Ok(())
        }
    }

    /// Principal returned by both stubs below.
    fn stub_principal() -> Principal {
        Principal {
            subject: "stub".into(),
            issuer: "stub".into(),
            audience: vec![],
            scopes: vec![],
            roles: vec![],
            claims: serde_json::json!({}),
        }
    }

    /// Policy that grants every exchange to the stub principal.
    struct StubPolicy;

    #[async_trait::async_trait]
    impl SecurityPolicy for StubPolicy {
        async fn evaluate(
            &self,
            _exchange: &mut Exchange,
        ) -> Result<AuthorizationDecision, CamelError> {
            Ok(AuthorizationDecision::Granted {
                principal: stub_principal(),
            })
        }
    }

    /// Authenticator that accepts every bearer token as the stub principal.
    struct StubAuthenticator;

    #[async_trait::async_trait]
    impl TokenAuthenticator for StubAuthenticator {
        async fn authenticate_bearer(&self, _token: &str) -> Result<Principal, CamelError> {
            Ok(stub_principal())
        }
    }

    /// Context built through `SecurityContext::new` from the two stubs.
    fn stub_context() -> SecurityContext {
        SecurityContext::new(StubPolicy, Arc::new(StubAuthenticator))
    }

    /// Credential sources both constructors are expected to install.
    fn default_sources() -> Vec<CredentialSource> {
        vec![CredentialSource::AuthorizationHeader]
    }

    // ---- ConsumerContext -------------------------------------------------

    #[tokio::test]
    async fn test_consumer_context_cancelled() {
        let (sender, _receiver) = mpsc::channel(16);
        let cancel = CancellationToken::new();
        let context = ConsumerContext::new(sender, cancel.clone());

        assert!(!context.is_cancelled());

        cancel.cancel();
        context.cancelled().await;

        assert!(context.is_cancelled());
    }

    // ---- Consumer defaults -----------------------------------------------

    #[test]
    fn test_concurrency_model_default_is_sequential() {
        let model = DummyConsumer.concurrency_model();
        assert_eq!(model, ConcurrencyModel::Sequential);
    }

    #[test]
    fn test_concurrency_model_concurrent_override() {
        struct ConcurrentConsumer;

        #[async_trait::async_trait]
        impl Consumer for ConcurrentConsumer {
            async fn start(&mut self, _ctx: ConsumerContext) -> Result<(), CamelError> {
                Ok(())
            }

            async fn stop(&mut self) -> Result<(), CamelError> {
                Ok(())
            }

            fn concurrency_model(&self) -> ConcurrencyModel {
                ConcurrencyModel::Concurrent { max: Some(16) }
            }
        }

        let expected = ConcurrencyModel::Concurrent { max: Some(16) };
        assert_eq!(ConcurrentConsumer.concurrency_model(), expected);
    }

    #[tokio::test]
    async fn test_consumer_default_suspend_resume() {
        let consumer = DummyConsumer;
        assert!(consumer.suspend().await.is_ok());
        assert!(consumer.resume().await.is_ok());
    }

    // ---- SecurityContext -------------------------------------------------

    #[test]
    fn test_security_context_new() {
        let context = stub_context();

        assert_eq!(Arc::strong_count(&context.policy), 1);
        assert_eq!(Arc::strong_count(&context.authenticator), 1);
        assert_eq!(context.credential_sources, default_sources());
    }

    #[test]
    fn test_security_context_from_arc() {
        let policy: Arc<dyn SecurityPolicy> = Arc::new(StubPolicy);
        let authenticator: Arc<dyn TokenAuthenticator> = Arc::new(StubAuthenticator);

        let context = SecurityContext::from_arc(Arc::clone(&policy), Arc::clone(&authenticator));

        assert!(Arc::ptr_eq(&context.policy, &policy));
        assert!(Arc::ptr_eq(&context.authenticator, &authenticator));
        assert_eq!(context.credential_sources, default_sources());
    }

    #[test]
    fn test_security_context_clone_independent() {
        let original = stub_context();
        let copy = original.clone();

        // The clone points at the same policy/authenticator allocations and
        // carries an equal credential source list.
        assert!(Arc::ptr_eq(&original.policy, &copy.policy));
        assert!(Arc::ptr_eq(&original.authenticator, &copy.authenticator));
        assert_eq!(original.credential_sources, copy.credential_sources);
    }

    #[test]
    fn test_security_context_debug_redacts_traits() {
        let rendered = format!("{:?}", stub_context());

        for needle in ["<SecurityPolicy>", "<TokenAuthenticator>", "credential_sources"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_security_context_with_credential_sources() {
        let sources = vec![
            CredentialSource::Cookie {
                name: "session".into(),
            },
            CredentialSource::AuthorizationHeader,
        ];

        let context = stub_context().with_credential_sources(sources);

        assert_eq!(context.credential_sources.len(), 2);
        assert!(matches!(
            &context.credential_sources[0],
            CredentialSource::Cookie { .. }
        ));
    }
}