liminal-rs 0.2.0

A conversation-based messaging bus built on beamr
Documentation
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, MutexGuard};

use super::*;
use crate::aion::codec::DispatchResponse;
use crate::aion::{ActivityRequest, Payload, dispatch_channel};
use crate::conversation::ParticipantPid;
use crate::routing::{ConsumerId, ConsumerStateView};

#[derive(Clone)]
pub(super) struct TestSetup {
    pub(super) pool: Arc<TestWorkerPool>,
    pub(super) router: Arc<RecordingRouter>,
    pub(super) factory: Arc<TestConversationFactory>,
    pub(super) conversation: Arc<TestConversationState>,
    pub(super) recorder: Arc<TestRecorder>,
}

impl TestSetup {
    pub(super) fn new(
        workers: Vec<DispatchWorker>,
        events: Vec<DispatchConversationEvent>,
    ) -> Self {
        Self::with_recorder(workers, events, TestRecorder::recording())
    }

    pub(super) fn replaying(outcome: RecordedDispatchOutcome) -> Self {
        Self::with_recorder(Vec::new(), Vec::new(), TestRecorder::replaying(outcome))
    }

    fn with_recorder(
        workers: Vec<DispatchWorker>,
        events: Vec<DispatchConversationEvent>,
        recorder: TestRecorder,
    ) -> Self {
        let conversation = Arc::new(TestConversationState::new(events));
        Self {
            pool: Arc::new(TestWorkerPool::new(workers)),
            router: Arc::new(RecordingRouter::new()),
            factory: Arc::new(TestConversationFactory::new(Arc::clone(&conversation))),
            conversation,
            recorder: Arc::new(recorder),
        }
    }

    pub(super) fn context(&self) -> DispatchContext {
        DispatchContext::new(
            "wf-1",
            self.pool.clone(),
            self.router.clone(),
            self.factory.clone(),
            self.recorder.clone(),
            Arc::new(FixedIds),
        )
    }
}

#[derive(Debug)]
pub(super) struct TestWorkerPool {
    workers: Mutex<Vec<DispatchWorker>>,
    calls: Mutex<usize>,
}

impl TestWorkerPool {
    fn new(workers: Vec<DispatchWorker>) -> Self {
        Self {
            workers: Mutex::new(workers),
            calls: Mutex::new(0),
        }
    }

    pub(super) fn set_workers(&self, workers: Vec<DispatchWorker>) -> Result<(), AionSurfaceError> {
        *lock(&self.workers)? = workers;
        Ok(())
    }

    pub(super) fn calls(&self) -> Result<usize, AionSurfaceError> {
        Ok(*lock(&self.calls)?)
    }
}

impl DispatchWorkerPool for TestWorkerPool {
    fn workers_for(
        &self,
        channel_name: &ChannelName,
        request: &ActivityRequest,
    ) -> Result<Vec<DispatchWorker>, AionSurfaceError> {
        let _ = (channel_name, request);
        *lock(&self.calls)? += 1;
        Ok(lock(&self.workers)?.clone())
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct RouterCall {
    pub(super) candidates: Vec<String>,
    pub(super) excluded: Vec<String>,
}

#[derive(Debug)]
pub(super) struct RecordingRouter {
    calls: Mutex<Vec<RouterCall>>,
    decisions: Mutex<VecDeque<Option<String>>>,
}

impl RecordingRouter {
    fn new() -> Self {
        Self {
            calls: Mutex::new(Vec::new()),
            decisions: Mutex::new(VecDeque::new()),
        }
    }

    pub(super) fn push_decision(&self, decision: Option<&str>) -> Result<(), AionSurfaceError> {
        lock(&self.decisions)?.push_back(decision.map(str::to_owned));
        Ok(())
    }

    pub(super) fn calls(&self) -> Result<Vec<RouterCall>, AionSurfaceError> {
        Ok(lock(&self.calls)?.clone())
    }
}

impl DispatchRouter for RecordingRouter {
    fn select_worker(
        &self,
        workflow_id: &str,
        channel_name: &ChannelName,
        request: &ActivityRequest,
        candidates: &[DispatchWorker],
        excluded_worker_ids: &[String],
    ) -> Result<Option<DispatchWorker>, AionSurfaceError> {
        let _ = (workflow_id, channel_name, request);
        lock(&self.calls)?.push(RouterCall {
            candidates: candidates
                .iter()
                .map(|worker| worker.worker_id.clone())
                .collect(),
            excluded: excluded_worker_ids.to_vec(),
        });
        let selected = lock(&self.decisions)?.pop_front().flatten();
        Ok(candidates
            .iter()
            .find(|worker| {
                let requested = selected.as_ref().is_none_or(|id| id == &worker.worker_id);
                requested && !excluded_worker_ids.iter().any(|id| id == &worker.worker_id)
            })
            .cloned())
    }
}

#[derive(Debug)]
pub(super) struct TestConversationFactory {
    state: Arc<TestConversationState>,
    opens: Mutex<Vec<String>>,
}

impl TestConversationFactory {
    fn new(state: Arc<TestConversationState>) -> Self {
        Self {
            state,
            opens: Mutex::new(Vec::new()),
        }
    }

    pub(super) fn opens(&self) -> Result<Vec<String>, AionSurfaceError> {
        Ok(lock(&self.opens)?.clone())
    }
}

impl DispatchConversationFactory for TestConversationFactory {
    fn open(
        &self,
        workflow_id: &str,
        channel_name: &ChannelName,
        conversation_id: &str,
    ) -> Result<Box<dyn DispatchConversation>, AionSurfaceError> {
        let _ = (workflow_id, conversation_id);
        lock(&self.opens)?.push(String::from(channel_name.clone()));
        Ok(Box::new(TestConversation {
            state: Arc::clone(&self.state),
        }))
    }
}

#[derive(Debug)]
pub(super) struct TestConversationState {
    events: Mutex<VecDeque<DispatchConversationEvent>>,
    links: Mutex<Vec<String>>,
    sends: Mutex<Vec<DispatchRequest>>,
    closes: Mutex<usize>,
}

impl TestConversationState {
    fn new(events: Vec<DispatchConversationEvent>) -> Self {
        Self {
            events: Mutex::new(events.into()),
            links: Mutex::new(Vec::new()),
            sends: Mutex::new(Vec::new()),
            closes: Mutex::new(0),
        }
    }

    pub(super) fn links(&self) -> Result<Vec<String>, AionSurfaceError> {
        Ok(lock(&self.links)?.clone())
    }

    pub(super) fn sent_requests(&self) -> Result<Vec<DispatchRequest>, AionSurfaceError> {
        Ok(lock(&self.sends)?.clone())
    }

    pub(super) fn closes(&self) -> Result<usize, AionSurfaceError> {
        Ok(*lock(&self.closes)?)
    }
}

#[derive(Debug)]
struct TestConversation {
    state: Arc<TestConversationState>,
}

impl DispatchConversation for TestConversation {
    fn link_worker(&mut self, worker: &DispatchWorker) -> Result<(), AionSurfaceError> {
        lock(&self.state.links)?.push(worker.worker_id.clone());
        Ok(())
    }

    fn send(&mut self, request: DispatchRequest) -> Result<(), AionSurfaceError> {
        lock(&self.state.sends)?.push(request);
        Ok(())
    }

    fn receive(&mut self) -> Result<DispatchConversationEvent, AionSurfaceError> {
        lock(&self.state.events)?
            .pop_front()
            .ok_or_else(|| test_error("no scripted conversation event"))
    }

    fn close(&mut self) -> Result<(), AionSurfaceError> {
        *lock(&self.state.closes)? += 1;
        Ok(())
    }
}

#[derive(Debug)]
pub(super) struct TestRecorder {
    replay: Option<RecordedDispatchOutcome>,
    operations: Mutex<Vec<DispatchOperation>>,
}

impl TestRecorder {
    fn recording() -> Self {
        Self {
            replay: None,
            operations: Mutex::new(Vec::new()),
        }
    }

    fn replaying(outcome: RecordedDispatchOutcome) -> Self {
        Self {
            replay: Some(outcome),
            operations: Mutex::new(Vec::new()),
        }
    }

    pub(super) fn operations(&self) -> Result<Vec<DispatchOperation>, AionSurfaceError> {
        Ok(lock(&self.operations)?.clone())
    }

    pub(super) fn kinds(&self) -> Result<Vec<DispatchOperationKind>, AionSurfaceError> {
        Ok(self
            .operations()?
            .into_iter()
            .map(|operation| operation.kind)
            .collect())
    }

    pub(super) fn states(&self) -> Result<Vec<ActivityDispatchState>, AionSurfaceError> {
        Ok(self
            .operations()?
            .into_iter()
            .filter_map(|operation| operation.activity_state)
            .collect())
    }
}

impl DispatchRecorder for TestRecorder {
    fn replay_outcome(
        &self,
        channel_name: &str,
        request: &ActivityRequest,
    ) -> Result<Option<RecordedDispatchOutcome>, AionSurfaceError> {
        let _ = (channel_name, request);
        Ok(self.replay.clone())
    }

    fn record(&self, operation: DispatchOperation) -> Result<(), AionSurfaceError> {
        lock(&self.operations)?.push(operation);
        Ok(())
    }
}

#[derive(Debug)]
struct FixedIds;

impl ConversationIdProvider for FixedIds {
    fn next_conversation_id(&self) -> String {
        "conversation-1".to_owned()
    }
}

pub(super) fn request(task_queue: &str) -> ActivityRequest {
    ActivityRequest {
        activity_type: "send-email".to_owned(),
        input: Payload {
            data: b"input".to_vec(),
            content_type: "application/json".to_owned(),
        },
        task_queue: task_queue.to_owned(),
        schedule_to_close_timeout: None,
        start_to_close_timeout: None,
    }
}

pub(super) fn worker(worker_id: &str, pid: u64) -> DispatchWorker {
    let consumer = ConsumerStateView::new(ConsumerId::new(worker_id), 0, 1, 0, Vec::new());
    DispatchWorker::with_consumer_state(worker_id, ParticipantPid::new(pid), consumer)
}

pub(super) fn response(worker_id: &str, result: ActivityResult) -> DispatchConversationEvent {
    DispatchConversationEvent::Response(DispatchResponse::new(worker_id.to_owned(), result))
}

pub(super) fn completed(data: &[u8]) -> ActivityResult {
    ActivityResult::Completed {
        output: Payload {
            data: data.to_vec(),
            content_type: "application/octet-stream".to_owned(),
        },
    }
}

pub(super) fn test_error(message: &str) -> AionSurfaceError {
    let channel_name =
        dispatch_channel("prod", "email").map_or_else(|error| error.to_string(), String::from);
    AionSurfaceError::DispatchFailed {
        channel_name,
        workflow_id: "wf-1".to_owned(),
        message: message.to_owned(),
    }
}

fn lock<T>(mutex: &Mutex<T>) -> Result<MutexGuard<'_, T>, AionSurfaceError> {
    mutex.lock().map_err(|error| test_error(&error.to_string()))
}