shepherd-rs 0.2.0

Shepherd is a resilient, non-blocking orchestrator that persistently transforms and delivers data—built for remote, compute-heavy workloads.
Documentation
//! # In-Memory Database
//!
//! This module provides an in-memory implementation of the `Database` trait.
//!
//! ## Overview
//! - **InMemoryDatabase**: Stores transformation and consumption data in
//!   memory.
//! - **Error Handling**: Defines custom error types for in-memory operations.
//!
//! ## Example
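//! ```rust,ignore
//! // `new` is async, takes the shared config handle, and returns a `Result`.
//! let mut db = InMemoryDatabase::new(config).await?;
//! db.register_transform_request(&request).await?;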
//! ```

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;

use async_trait::async_trait;
use thiserror::Error;
use tokio::sync::Mutex;

use crate::config::Config;
use crate::consumer::ConsumeAttempt;
use crate::consumer::consumer::ConsumeAttemptResult;
use crate::database::Database;
use crate::transform::{TransformAttempt, TransformRequest};
use crate::worker::worker_manager::WorkerManagerResult;

/// In-memory store for transform requests, their transform attempts, and the
/// consume attempts made for each transform attempt.
///
/// All data lives in nested `HashMap`s keyed by the request identifier first,
/// so everything belonging to one request can be dropped together.
///
/// Type parameters:
/// - `TR`: the transform request type.
/// - `TA`: the transform attempt type; its request identifier, call arguments
///   and return type are tied to `TR`'s associated types.
/// - `CA`: the consume attempt type; its request/attempt identifiers and
///   consumed value are tied to `TR` and `TA`.
/// - `C`: the configuration type; it is only tracked at the type level.
#[derive(Debug)]
pub struct InMemoryDatabase<TR, TA, CA, C>
where
    TR: TransformRequest + Send + Sync,
    TA: TransformAttempt<
            TransformRequestIdentifier = TR::Identifier,
            CallArgsType = TR::Input,
            ReturnType = TR::Output,
        > + Send
        + Sync,
    CA: ConsumeAttempt<
            TransformRequestIdentifier = TR::Identifier,
            TransformAttemptIdentifier = TA::Identifier,
            ConsumeVal = TR::Output,
        > + Send
        + Sync,
    C: Config,
{
    /// Registered requests, keyed by request identifier.
    transform_requests: HashMap<TR::Identifier, TR>,
    /// Transform attempts: request identifier -> attempt identifier -> attempt.
    transform_attempts: HashMap<TR::Identifier, HashMap<TA::Identifier, TA>>,
    /// Consume attempts: request identifier -> transform attempt identifier ->
    /// consume attempt identifier -> consume attempt.
    consume_attempts: HashMap<TR::Identifier, HashMap<TA::Identifier, HashMap<CA::Identifier, CA>>>,
    /// Zero-sized marker that records the config type `C` without storing a value.
    _marker: std::marker::PhantomData<C>,
}

/// Errors returned by the in-memory database operations.
#[derive(Debug, Error)]
pub enum InMemoryDatabaseError {
    /// A general database failure described by the message; in this module it
    /// is produced when a request or attempt with the same ID already exists.
    #[error("Database error: {0}")]
    DatabaseError(String),
    /// A referenced request or attempt is not present in the store; the
    /// message says which lookup failed.
    #[error("Not found: {0}")]
    NotFound(String),
}

#[async_trait]
impl<TR, TA, CA, C> Database for InMemoryDatabase<TR, TA, CA, C>
where
    TR: TransformRequest + Send + Sync,
    TA: TransformAttempt<
            TransformRequestIdentifier = TR::Identifier,
            CallArgsType = TR::Input,
            ReturnType = TR::Output,
        > + Send
        + Sync,
    CA: ConsumeAttempt<
            TransformRequestIdentifier = TR::Identifier,
            TransformAttemptIdentifier = TA::Identifier,
            ConsumeVal = TR::Output,
        > + Send
        + Sync,
    C: Config<KeyType = String, ValueType = Vec<u8>>,
    TR::Identifier: Hash,
{
    type Config = C;
    type ConsumeAttempt = CA;
    type DatabaseError = InMemoryDatabaseError;
    type Input = TR::Input;
    type Output = TR::Output;
    type TransformAttempt = TA;
    type TransformRequest = TR;

    /// Creates an empty store.
    ///
    /// The shared configuration handle is accepted to satisfy the trait
    /// signature but is not read; this constructor always returns `Ok`.
    async fn new(_ctx: Arc<Mutex<Self::Config>>) -> Result<Self, Self::DatabaseError>
    where
        Self: Sized,
    {
        let empty = Self {
            transform_requests: HashMap::default(),
            transform_attempts: HashMap::default(),
            consume_attempts: HashMap::default(),
            _marker: std::marker::PhantomData,
        };
        Ok(empty)
    }

    /// Returns the dynamic configuration entries held by the database.
    ///
    /// The in-memory backend keeps none, so the result is always an empty list.
    async fn get_dyn_configs(
        &mut self,
    ) -> Result<
        Vec<(
            <Self::Config as Config>::KeyType,
            <Self::Config as Config>::ValueType,
        )>,
        Self::DatabaseError,
    > {
        // Nothing is persisted for configuration in this backend.
        Ok(vec![])
    }

    /// Stores a clone of `request` under its request ID.
    ///
    /// # Errors
    /// `DatabaseError` if a request with the same ID is already stored.
    async fn register_transform_request(
        &mut self,
        request: &Self::TransformRequest,
    ) -> Result<(), Self::DatabaseError> {
        match self.transform_requests.entry(request.request_id()) {
            std::collections::hash_map::Entry::Occupied(_) => Err(
                InMemoryDatabaseError::DatabaseError("Request already exists".to_string()),
            ),
            std::collections::hash_map::Entry::Vacant(vacant) => {
                vacant.insert(request.clone());
                Ok(())
            }
        }
    }

    /// Stores a clone of `attempt` under its parent request.
    ///
    /// # Errors
    /// - `NotFound` if the parent request is not registered.
    /// - `DatabaseError` if the attempt ID is already registered for that request.
    async fn register_transform_attempt(
        &mut self,
        attempt: &Self::TransformAttempt,
    ) -> Result<(), Self::DatabaseError> {
        let parent_id = attempt.request_id();
        if !self.transform_requests.contains_key(&parent_id) {
            return Err(request_not_found(&parent_id));
        }

        // The per-request map is created on first use.
        let slot = self
            .transform_attempts
            .entry(parent_id)
            .or_default()
            .entry(attempt.attempt_id());
        match slot {
            std::collections::hash_map::Entry::Occupied(_) => Err(
                InMemoryDatabaseError::DatabaseError("Attempt already exists".to_string()),
            ),
            std::collections::hash_map::Entry::Vacant(vacant) => {
                vacant.insert(attempt.clone());
                Ok(())
            }
        }
    }

    /// Records the return package of a finished transform attempt.
    ///
    /// Success and failure results are handled the same way: the package is
    /// cloned onto the stored attempt. The request ID is obtained by
    /// converting the attempt ID.
    ///
    /// # Errors
    /// `NotFound` if the request, its attempt map, or the attempt itself is
    /// missing.
    async fn update_transform_attempt(
        &mut self,
        attempt: &WorkerManagerResult<Self::TransformAttempt>,
    ) -> Result<(), Self::DatabaseError> {
        let (attempt_id, return_pkg) = match attempt {
            WorkerManagerResult::Success(id, pkg) | WorkerManagerResult::Failure(id, pkg) => {
                (id, pkg)
            }
        };

        let request_id: TR::Identifier = attempt_id.clone().into();
        if !self.transform_requests.contains_key(&request_id) {
            return Err(request_not_found(&request_id));
        }

        let stored = self
            .transform_attempts
            .get_mut(&request_id)
            .ok_or_else(|| {
                InMemoryDatabaseError::NotFound(format!(
                    "Transform attempts for request {:?} not found",
                    request_id
                ))
            })?
            .get_mut(attempt_id)
            .ok_or_else(|| {
                InMemoryDatabaseError::NotFound(format!(
                    "Transform attempt with ID {:?} for request {:?} not found",
                    attempt_id, request_id
                ))
            })?;

        stored.set_return_package(return_pkg.clone());
        Ok(())
    }

    /// Stores a clone of `attempt` under its request and transform attempt.
    ///
    /// # Errors
    /// - `NotFound` if the request or the transform attempt is not registered.
    /// - `DatabaseError` if the consume ID is already registered there.
    async fn register_consume_attempt(
        &mut self,
        attempt: &Self::ConsumeAttempt,
    ) -> Result<(), Self::DatabaseError> {
        let request_id = attempt.request_id();
        let attempt_id = attempt.attempt_id();

        if !self.transform_requests.contains_key(&request_id) {
            return Err(request_not_found(&request_id));
        }

        // A consume attempt may only hang off a transform attempt that exists.
        let parent_known = self
            .transform_attempts
            .get(&request_id)
            .map_or(false, |attempts| attempts.contains_key(&attempt_id));
        if !parent_known {
            return Err(InMemoryDatabaseError::NotFound(format!(
                "Transform attempt with ID {:?} for request {:?} not found",
                attempt_id, request_id
            )));
        }

        let slot = self
            .consume_attempts
            .entry(request_id)
            .or_default()
            .entry(attempt_id)
            .or_default()
            .entry(attempt.consume_id());
        match slot {
            std::collections::hash_map::Entry::Occupied(_) => Err(
                InMemoryDatabaseError::DatabaseError("Consume attempt already exists".to_string()),
            ),
            std::collections::hash_map::Entry::Vacant(vacant) => {
                vacant.insert(attempt.clone());
                Ok(())
            }
        }
    }

    /// Records the return context of a finished consume attempt.
    ///
    /// Success and failure results are handled the same way. The request ID
    /// and transform attempt ID are both obtained by converting the consume
    /// attempt ID.
    ///
    /// # Errors
    /// `NotFound` if the request, the consume-attempt map for that
    /// request/attempt pair, or the consume attempt itself is missing.
    async fn update_consume_attempt(
        &mut self,
        attempt: ConsumeAttemptResult<Self::ConsumeAttempt>,
    ) -> Result<(), Self::DatabaseError> {
        let (consume_attempt_id, return_ctx) = match attempt {
            ConsumeAttemptResult::Success(id, ctx) | ConsumeAttemptResult::Failure(id, ctx) => {
                (id, ctx)
            }
        };

        let request_id: TR::Identifier = consume_attempt_id.clone().into();
        let attempt_id: TA::Identifier = consume_attempt_id.clone().into();
        if !self.transform_requests.contains_key(&request_id) {
            return Err(request_not_found(&request_id));
        }

        let stored = self
            .consume_attempts
            .get_mut(&request_id)
            .and_then(|by_attempt| by_attempt.get_mut(&attempt_id))
            .ok_or_else(|| {
                InMemoryDatabaseError::NotFound(format!(
                    "Consume attempts for request {:?} and attempt {:?} not found",
                    request_id, attempt_id
                ))
            })?
            .get_mut(&consume_attempt_id)
            .ok_or_else(|| {
                InMemoryDatabaseError::NotFound(format!(
                    "Consume attempt with ID {:?} for request {:?} and attempt {:?} not found",
                    consume_attempt_id, request_id, attempt_id
                ))
            })?;

        stored.set_return_context(return_ctx);
        Ok(())
    }

    /// Removes a request together with all of its transform and consume attempts.
    ///
    /// # Errors
    /// `NotFound` if no request with that ID is stored; nothing is removed in
    /// that case.
    async fn archive_request_with_id(
        &mut self,
        request: &<Self::TransformRequest as TransformRequest>::Identifier,
    ) -> Result<(), Self::DatabaseError> {
        // Removing the request doubles as the existence check.
        if self.transform_requests.remove(request).is_none() {
            return Err(request_not_found(request));
        }
        self.transform_attempts.remove(request);
        self.consume_attempts.remove(request);
        Ok(())
    }
}

/// Builds the `NotFound` error reported when a request ID is not in the store.
fn request_not_found<I: std::fmt::Debug>(request_id: &I) -> InMemoryDatabaseError {
    InMemoryDatabaseError::NotFound(format!("Request with ID {:?} not found", request_id))
}