rain-engine-ingress 0.1.0

Ingress utilities for RainEngine event delivery
Documentation
//! Event ingress adapters for RainEngine workers.
//!
//! This crate provides a shared event envelope and Valkey Streams utilities
//! that drive the kernel through explicit advance loops.

use rain_engine_core::{
    AdvanceRequest, AgentEngine, AgentTrigger, ContinueRequest, CoordinationStore, EnginePolicy,
    ProcessRequest, ProviderRequestConfig,
};
use rain_engine_store_valkey::ValkeyCoordinationStore;
use redis::cmd;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::time::Duration;
use thiserror::Error;

/// Serializable description of one event to deliver to the engine.
///
/// `ValkeyStreamIngress::publish` writes this as JSON into the stream and
/// `ValkeyStreamIngress::run_once` turns a decoded envelope into a
/// `ProcessRequest`. Every field marked `#[serde(default)]` may be omitted
/// from the JSON and then takes its `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IngressEventEnvelope {
    /// Session the event belongs to; copied into the `ProcessRequest`.
    pub session_id: String,
    /// Trigger forwarded unchanged to the engine.
    pub trigger: AgentTrigger,
    /// Scopes forwarded with the request; an empty set when omitted.
    #[serde(default)]
    pub granted_scopes: BTreeSet<String>,
    /// Optional deduplication key. When present, `run_once` uses it as the
    /// claim key; otherwise the stream entry ID is used instead.
    #[serde(default)]
    pub idempotency_key: Option<String>,
    /// Optional engine policy; `run_once` substitutes `EnginePolicy::default()` for `None`.
    #[serde(default)]
    pub policy: Option<EnginePolicy>,
    /// Optional provider configuration; `run_once` substitutes the default for `None`.
    #[serde(default)]
    pub provider: Option<ProviderRequestConfig>,
}

/// Connection and consumer-group settings for [`ValkeyStreamIngress`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ValkeyStreamConfig {
    /// Connection URL handed to both `ValkeyCoordinationStore::connect` and
    /// `redis::Client::open`.
    pub url: String,
    /// Key of the stream that events are appended to and read from.
    pub stream: String,
    /// Consumer group created by `ensure_group` and used for reads and acknowledgements.
    pub group: String,
    /// Consumer name passed to `XREADGROUP`.
    pub consumer: String,
    /// Value sent as the `BLOCK` argument of `XREADGROUP`
    /// (milliseconds, per the command reference).
    pub block_ms: usize,
}

/// Error type returned by every fallible operation in this file.
///
/// All underlying failures (connection, command, serialization, task join,
/// coordination store, engine) are flattened into their message text.
#[derive(Debug, Error)]
pub enum IngressError {
    /// A failure described only by its message; displayed verbatim.
    #[error("{0}")]
    Message(String),
}

/// Publishes envelopes to, and consumes envelopes from, a Valkey stream.
#[derive(Clone)]
pub struct ValkeyStreamIngress {
    /// Coordination store used by `run_once` to claim and release triggers.
    store: ValkeyCoordinationStore,
    /// Stream, group and consumer settings; also the URL for per-call connections.
    config: ValkeyStreamConfig,
}

impl ValkeyStreamIngress {
    pub fn new(config: ValkeyStreamConfig) -> Result<Self, IngressError> {
        let store = ValkeyCoordinationStore::connect(&config.url)
            .map_err(|err| IngressError::Message(err.message))?;
        Ok(Self { store, config })
    }

    pub async fn publish(&self, event: &IngressEventEnvelope) -> Result<String, IngressError> {
        let client = redis::Client::open(self.config.url.clone())
            .map_err(|err| IngressError::Message(err.to_string()))?;
        let stream = self.config.stream.clone();
        let payload =
            serde_json::to_string(event).map_err(|err| IngressError::Message(err.to_string()))?;
        tokio::task::spawn_blocking(move || {
            let mut connection = client
                .get_connection()
                .map_err(|err| IngressError::Message(err.to_string()))?;
            let id: String = cmd("XADD")
                .arg(stream)
                .arg("*")
                .arg("payload")
                .arg(payload)
                .query(&mut connection)
                .map_err(|err| IngressError::Message(err.to_string()))?;
            Ok(id)
        })
        .await
        .map_err(|err| IngressError::Message(err.to_string()))?
    }

    pub async fn ensure_group(&self) -> Result<(), IngressError> {
        let client = redis::Client::open(self.config.url.clone())
            .map_err(|err| IngressError::Message(err.to_string()))?;
        let stream = self.config.stream.clone();
        let group = self.config.group.clone();
        tokio::task::spawn_blocking(move || {
            let mut connection = client
                .get_connection()
                .map_err(|err| IngressError::Message(err.to_string()))?;
            let result: Result<String, redis::RedisError> = cmd("XGROUP")
                .arg("CREATE")
                .arg(&stream)
                .arg(&group)
                .arg("0")
                .arg("MKSTREAM")
                .query(&mut connection);
            match result {
                Ok(_) => Ok(()),
                Err(err) if err.to_string().contains("BUSYGROUP") => Ok(()),
                Err(err) => Err(IngressError::Message(err.to_string())),
            }
        })
        .await
        .map_err(|err| IngressError::Message(err.to_string()))?
    }

    pub async fn run_once(&self, engine: &AgentEngine) -> Result<bool, IngressError> {
        self.ensure_group().await?;
        let client = redis::Client::open(self.config.url.clone())
            .map_err(|err| IngressError::Message(err.to_string()))?;
        let stream = self.config.stream.clone();
        let group = self.config.group.clone();
        let consumer = self.config.consumer.clone();
        let block_ms = self.config.block_ms;
        let read = tokio::task::spawn_blocking(move || {
            let mut connection = client
                .get_connection()
                .map_err(|err| IngressError::Message(err.to_string()))?;
            let value: redis::Value = cmd("XREADGROUP")
                .arg("GROUP")
                .arg(&group)
                .arg(&consumer)
                .arg("COUNT")
                .arg(1)
                .arg("BLOCK")
                .arg(block_ms)
                .arg("STREAMS")
                .arg(&stream)
                .arg(">")
                .query(&mut connection)
                .map_err(|err| IngressError::Message(err.to_string()))?;
            Ok::<_, IngressError>(value)
        })
        .await
        .map_err(|err| IngressError::Message(err.to_string()))??;

        let Some((entry_id, event)) = parse_xreadgroup_payload(read)? else {
            return Ok(false);
        };
        let trigger_key = event
            .idempotency_key
            .clone()
            .unwrap_or_else(|| entry_id.clone());
        let Some(claim) = self
            .store
            .claim_trigger(&trigger_key, Duration::from_secs(300))
            .await
            .map_err(|err| IngressError::Message(err.message))?
        else {
            self.ack_entry(entry_id).await?;
            return Ok(false);
        };

        let result = run_until_terminal(
            engine,
            ProcessRequest {
                session_id: event.session_id.clone(),
                trigger: event.trigger,
                granted_scopes: event.granted_scopes,
                idempotency_key: event.idempotency_key,
                policy: event.policy.unwrap_or_default(),
                provider: event.provider.unwrap_or_default(),
                cancellation: tokio_util::sync::CancellationToken::new(),
            },
        )
        .await;

        if let Err(err) = result {
            let _ = self.store.release_claim(&claim.claim_id).await;
            return Err(IngressError::Message(err.to_string()));
        }

        self.ack_entry(entry_id).await?;
        self.store
            .release_claim(&claim.claim_id)
            .await
            .map_err(|err| IngressError::Message(err.message))?;

        Ok(true)
    }

    async fn ack_entry(&self, entry_id: String) -> Result<(), IngressError> {
        let client = redis::Client::open(self.config.url.clone())
            .map_err(|err| IngressError::Message(err.to_string()))?;
        let stream = self.config.stream.clone();
        let group = self.config.group.clone();
        tokio::task::spawn_blocking(move || {
            let mut connection = client
                .get_connection()
                .map_err(|err| IngressError::Message(err.to_string()))?;
            let _: usize = cmd("XACK")
                .arg(stream)
                .arg(group)
                .arg(entry_id)
                .query(&mut connection)
                .map_err(|err| IngressError::Message(err.to_string()))?;
            Ok(())
        })
        .await
        .map_err(|err| IngressError::Message(err.to_string()))??;
        Ok(())
    }
}

/// Advances `engine` repeatedly until it reports an outcome.
///
/// The first step submits `request` as a trigger. Every later step submits a
/// continue request rebuilt from the same session, scopes, policy, provider
/// and cancellation token. The first engine error ends the loop and is returned.
async fn run_until_terminal(
    engine: &AgentEngine,
    request: ProcessRequest,
) -> Result<rain_engine_core::EngineOutcome, rain_engine_core::EngineError> {
    let mut step = AdvanceRequest::Trigger(request.clone());
    loop {
        let advanced = engine.advance(step).await?;
        match advanced.outcome {
            Some(outcome) => break Ok(outcome),
            None => {
                // No outcome yet: ask the engine to keep going for this session.
                step = AdvanceRequest::Continue(ContinueRequest {
                    session_id: request.session_id.clone(),
                    granted_scopes: request.granted_scopes.clone(),
                    policy: request.policy.clone(),
                    provider: request.provider.clone(),
                    cancellation: request.cancellation.clone(),
                });
            }
        }
    }
}

/// Extracts the first entry's ID and decoded envelope from an `XREADGROUP` reply.
///
/// The reply is expected to be nested arrays: streams → `[name, entries]` →
/// `[id, fields]` → alternating field names and values. Any other shape along
/// that path yields `Ok(None)`.
///
/// # Errors
///
/// Returns an error when the entry has no usable `payload` field, or when the
/// payload is not valid JSON for an [`IngressEventEnvelope`].
fn parse_xreadgroup_payload(
    value: redis::Value,
) -> Result<Option<(String, IngressEventEnvelope)>, IngressError> {
    use redis::Value;

    // Walk down to the first entry of the first stream, bailing out on any surprise.
    let Value::Array(streams) = value else {
        return Ok(None);
    };
    let Some(Value::Array(stream_record)) = streams.into_iter().next() else {
        return Ok(None);
    };
    let Some(Value::Array(entries)) = stream_record.into_iter().nth(1) else {
        return Ok(None);
    };
    let Some(Value::Array(entry)) = entries.into_iter().next() else {
        return Ok(None);
    };
    let mut entry_parts = entry.into_iter();
    let Some(Value::BulkString(id_bytes)) = entry_parts.next() else {
        return Ok(None);
    };
    let Some(Value::Array(fields)) = entry_parts.next() else {
        return Ok(None);
    };

    // Fields come as name/value pairs; only the first pair named `payload` counts,
    // and it must hold a bulk string.
    let raw_payload = fields
        .chunks_exact(2)
        .find(|pair| {
            matches!(
                &pair[0],
                Value::BulkString(name) if String::from_utf8_lossy(name) == "payload"
            )
        })
        .and_then(|pair| match &pair[1] {
            Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()),
            _ => None,
        });
    let Some(raw_payload) = raw_payload else {
        return Err(IngressError::Message("missing payload field".to_string()));
    };

    let event: IngressEventEnvelope = serde_json::from_str(&raw_payload)
        .map_err(|err| IngressError::Message(err.to_string()))?;
    let entry_id = String::from_utf8_lossy(&id_bytes).into_owned();
    Ok(Some((entry_id, event)))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An envelope with every optional field populated survives a JSON round trip.
    #[test]
    fn round_trips_envelope_json() {
        let scopes: BTreeSet<String> = std::iter::once(String::from("tool:run")).collect();
        let original = IngressEventEnvelope {
            session_id: String::from("s1"),
            trigger: AgentTrigger::Message {
                user_id: String::from("u"),
                content: String::from("hello"),
                attachments: Vec::new(),
            },
            granted_scopes: scopes,
            idempotency_key: Some(String::from("abc")),
            policy: Some(EnginePolicy::default()),
            provider: Some(ProviderRequestConfig::default()),
        };

        let json = serde_json::to_string(&original).expect("encode");
        let restored = serde_json::from_str::<IngressEventEnvelope>(&json).expect("decode");

        assert_eq!(restored, original);
    }
}