Skip to main content

rain_engine_ingress/
lib.rs

//! Event ingress adapters for RainEngine workers.
//!
//! This crate provides a shared event envelope and Valkey Streams utilities
//! that drive the kernel through explicit advance loops.

6use rain_engine_core::{
7    AdvanceRequest, AgentEngine, AgentTrigger, ContinueRequest, CoordinationStore, EnginePolicy,
8    ProcessRequest, ProviderRequestConfig,
9};
10use rain_engine_store_valkey::ValkeyCoordinationStore;
11use redis::cmd;
12use serde::{Deserialize, Serialize};
13use std::collections::BTreeSet;
14use std::time::Duration;
15use thiserror::Error;
16
/// Wire format for one ingress event, carried as JSON in a stream entry's `payload` field.
///
/// Only `session_id` and `trigger` are required when deserializing; every other field falls
/// back to its default when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IngressEventEnvelope {
    /// Session the trigger is delivered to.
    pub session_id: String,
    /// Event handed to the engine as the initial `ProcessRequest` trigger.
    pub trigger: AgentTrigger,
    /// Scopes granted to the run; empty when omitted.
    #[serde(default)]
    pub granted_scopes: BTreeSet<String>,
    /// Key used to claim the trigger for deduplication; when `None`, the ingress worker
    /// falls back to the stream entry id.
    #[serde(default)]
    pub idempotency_key: Option<String>,
    /// Engine policy for the run; `None` means `EnginePolicy::default()` is used.
    #[serde(default)]
    pub policy: Option<EnginePolicy>,
    /// Provider request settings; `None` means `ProviderRequestConfig::default()` is used.
    #[serde(default)]
    pub provider: Option<ProviderRequestConfig>,
}
30
/// Connection and consumer-group settings for [`ValkeyStreamIngress`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ValkeyStreamConfig {
    /// Server URL, used both for the coordination store and for the stream commands.
    pub url: String,
    /// Stream key that events are published to (`XADD`) and read from (`XREADGROUP`).
    pub stream: String,
    /// Consumer group name; created on demand (with `MKSTREAM`) if it does not exist.
    pub group: String,
    /// Consumer name within `group` used when reading entries.
    pub consumer: String,
    /// Milliseconds passed as the `XREADGROUP ... BLOCK` argument while waiting for an entry.
    pub block_ms: usize,
}
39
/// Errors raised by the ingress adapters.
///
/// Underlying failures (connection, command, serialization, task join, engine) are
/// flattened into their display text.
#[derive(Debug, Error)]
pub enum IngressError {
    /// Human-readable description of the failure; displayed verbatim.
    #[error("{0}")]
    Message(String),
}
45
/// Valkey Streams producer/consumer that feeds [`IngressEventEnvelope`]s into an engine.
#[derive(Clone)]
pub struct ValkeyStreamIngress {
    /// Coordination store used to claim trigger keys while an event is being processed.
    store: ValkeyCoordinationStore,
    /// Stream, group, and consumer settings.
    config: ValkeyStreamConfig,
}
51
52impl ValkeyStreamIngress {
53    pub fn new(config: ValkeyStreamConfig) -> Result<Self, IngressError> {
54        let store = ValkeyCoordinationStore::connect(&config.url)
55            .map_err(|err| IngressError::Message(err.message))?;
56        Ok(Self { store, config })
57    }
58
59    pub async fn publish(&self, event: &IngressEventEnvelope) -> Result<String, IngressError> {
60        let client = redis::Client::open(self.config.url.clone())
61            .map_err(|err| IngressError::Message(err.to_string()))?;
62        let stream = self.config.stream.clone();
63        let payload =
64            serde_json::to_string(event).map_err(|err| IngressError::Message(err.to_string()))?;
65        tokio::task::spawn_blocking(move || {
66            let mut connection = client
67                .get_connection()
68                .map_err(|err| IngressError::Message(err.to_string()))?;
69            let id: String = cmd("XADD")
70                .arg(stream)
71                .arg("*")
72                .arg("payload")
73                .arg(payload)
74                .query(&mut connection)
75                .map_err(|err| IngressError::Message(err.to_string()))?;
76            Ok(id)
77        })
78        .await
79        .map_err(|err| IngressError::Message(err.to_string()))?
80    }
81
82    pub async fn ensure_group(&self) -> Result<(), IngressError> {
83        let client = redis::Client::open(self.config.url.clone())
84            .map_err(|err| IngressError::Message(err.to_string()))?;
85        let stream = self.config.stream.clone();
86        let group = self.config.group.clone();
87        tokio::task::spawn_blocking(move || {
88            let mut connection = client
89                .get_connection()
90                .map_err(|err| IngressError::Message(err.to_string()))?;
91            let result: Result<String, redis::RedisError> = cmd("XGROUP")
92                .arg("CREATE")
93                .arg(&stream)
94                .arg(&group)
95                .arg("0")
96                .arg("MKSTREAM")
97                .query(&mut connection);
98            match result {
99                Ok(_) => Ok(()),
100                Err(err) if err.to_string().contains("BUSYGROUP") => Ok(()),
101                Err(err) => Err(IngressError::Message(err.to_string())),
102            }
103        })
104        .await
105        .map_err(|err| IngressError::Message(err.to_string()))?
106    }
107
108    pub async fn run_once(&self, engine: &AgentEngine) -> Result<bool, IngressError> {
109        self.ensure_group().await?;
110        let client = redis::Client::open(self.config.url.clone())
111            .map_err(|err| IngressError::Message(err.to_string()))?;
112        let stream = self.config.stream.clone();
113        let group = self.config.group.clone();
114        let consumer = self.config.consumer.clone();
115        let block_ms = self.config.block_ms;
116        let read = tokio::task::spawn_blocking(move || {
117            let mut connection = client
118                .get_connection()
119                .map_err(|err| IngressError::Message(err.to_string()))?;
120            let value: redis::Value = cmd("XREADGROUP")
121                .arg("GROUP")
122                .arg(&group)
123                .arg(&consumer)
124                .arg("COUNT")
125                .arg(1)
126                .arg("BLOCK")
127                .arg(block_ms)
128                .arg("STREAMS")
129                .arg(&stream)
130                .arg(">")
131                .query(&mut connection)
132                .map_err(|err| IngressError::Message(err.to_string()))?;
133            Ok::<_, IngressError>(value)
134        })
135        .await
136        .map_err(|err| IngressError::Message(err.to_string()))??;
137
138        let Some((entry_id, event)) = parse_xreadgroup_payload(read)? else {
139            return Ok(false);
140        };
141        let trigger_key = event
142            .idempotency_key
143            .clone()
144            .unwrap_or_else(|| entry_id.clone());
145        let Some(claim) = self
146            .store
147            .claim_trigger(&trigger_key, Duration::from_secs(300))
148            .await
149            .map_err(|err| IngressError::Message(err.message))?
150        else {
151            self.ack_entry(entry_id).await?;
152            return Ok(false);
153        };
154
155        let result = run_until_terminal(
156            engine,
157            ProcessRequest {
158                session_id: event.session_id.clone(),
159                trigger: event.trigger,
160                granted_scopes: event.granted_scopes,
161                idempotency_key: event.idempotency_key,
162                policy: event.policy.unwrap_or_default(),
163                provider: event.provider.unwrap_or_default(),
164                cancellation: tokio_util::sync::CancellationToken::new(),
165            },
166        )
167        .await;
168
169        if let Err(err) = result {
170            let _ = self.store.release_claim(&claim.claim_id).await;
171            return Err(IngressError::Message(err.to_string()));
172        }
173
174        self.ack_entry(entry_id).await?;
175        self.store
176            .release_claim(&claim.claim_id)
177            .await
178            .map_err(|err| IngressError::Message(err.message))?;
179
180        Ok(true)
181    }
182
183    async fn ack_entry(&self, entry_id: String) -> Result<(), IngressError> {
184        let client = redis::Client::open(self.config.url.clone())
185            .map_err(|err| IngressError::Message(err.to_string()))?;
186        let stream = self.config.stream.clone();
187        let group = self.config.group.clone();
188        tokio::task::spawn_blocking(move || {
189            let mut connection = client
190                .get_connection()
191                .map_err(|err| IngressError::Message(err.to_string()))?;
192            let _: usize = cmd("XACK")
193                .arg(stream)
194                .arg(group)
195                .arg(entry_id)
196                .query(&mut connection)
197                .map_err(|err| IngressError::Message(err.to_string()))?;
198            Ok(())
199        })
200        .await
201        .map_err(|err| IngressError::Message(err.to_string()))??;
202        Ok(())
203    }
204}
205
206async fn run_until_terminal(
207    engine: &AgentEngine,
208    request: ProcessRequest,
209) -> Result<rain_engine_core::EngineOutcome, rain_engine_core::EngineError> {
210    let mut next = AdvanceRequest::Trigger(request.clone());
211    loop {
212        let result = engine.advance(next).await?;
213        if let Some(outcome) = result.outcome {
214            return Ok(outcome);
215        }
216        next = AdvanceRequest::Continue(ContinueRequest {
217            session_id: request.session_id.clone(),
218            granted_scopes: request.granted_scopes.clone(),
219            policy: request.policy.clone(),
220            provider: request.provider.clone(),
221            cancellation: request.cancellation.clone(),
222        });
223    }
224}
225
226fn parse_xreadgroup_payload(
227    value: redis::Value,
228) -> Result<Option<(String, IngressEventEnvelope)>, IngressError> {
229    use redis::Value;
230    let Value::Array(streams) = value else {
231        return Ok(None);
232    };
233    let Some(Value::Array(stream_record)) = streams.into_iter().next() else {
234        return Ok(None);
235    };
236    let Some(Value::Array(entries)) = stream_record.get(1).cloned() else {
237        return Ok(None);
238    };
239    let Some(Value::Array(entry)) = entries.into_iter().next() else {
240        return Ok(None);
241    };
242    let Some(Value::BulkString(id_bytes)) = entry.first().cloned() else {
243        return Ok(None);
244    };
245    let Some(Value::Array(fields)) = entry.get(1).cloned() else {
246        return Ok(None);
247    };
248    let mut payload = None::<String>;
249    let mut index = 0usize;
250    while index + 1 < fields.len() {
251        let key = match &fields[index] {
252            Value::BulkString(bytes) => String::from_utf8_lossy(bytes).to_string(),
253            _ => String::new(),
254        };
255        if key == "payload" {
256            if let Value::BulkString(bytes) = &fields[index + 1] {
257                payload = Some(String::from_utf8_lossy(bytes).to_string());
258            }
259            break;
260        }
261        index += 2;
262    }
263    let payload =
264        payload.ok_or_else(|| IngressError::Message("missing payload field".to_string()))?;
265    let event =
266        serde_json::from_str(&payload).map_err(|err| IngressError::Message(err.to_string()))?;
267    Ok(Some((
268        String::from_utf8_lossy(&id_bytes).to_string(),
269        event,
270    )))
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative `Message` trigger shared by the envelope tests.
    fn sample_trigger() -> AgentTrigger {
        AgentTrigger::Message {
            user_id: "u".to_string(),
            content: "hello".to_string(),
            attachments: Vec::new(),
        }
    }

    #[test]
    fn round_trips_envelope_json() {
        let envelope = IngressEventEnvelope {
            session_id: "s1".to_string(),
            trigger: sample_trigger(),
            granted_scopes: BTreeSet::from(["tool:run".to_string()]),
            idempotency_key: Some("abc".to_string()),
            policy: Some(EnginePolicy::default()),
            provider: Some(ProviderRequestConfig::default()),
        };
        let encoded = serde_json::to_string(&envelope).expect("encode");
        let decoded: IngressEventEnvelope = serde_json::from_str(&encoded).expect("decode");
        assert_eq!(decoded, envelope);
    }

    /// Producers may omit every optional field; each must fall back to its default.
    #[test]
    fn missing_optional_fields_deserialize_to_defaults() {
        let expected = IngressEventEnvelope {
            session_id: "s1".to_string(),
            trigger: sample_trigger(),
            granted_scopes: BTreeSet::new(),
            idempotency_key: None,
            policy: None,
            provider: None,
        };
        let mut json = serde_json::to_value(&expected).expect("encode");
        let object = json
            .as_object_mut()
            .expect("envelope encodes as a JSON object");
        for key in ["granted_scopes", "idempotency_key", "policy", "provider"] {
            object.remove(key);
        }
        let decoded: IngressEventEnvelope = serde_json::from_value(json).expect("decode");
        assert_eq!(decoded, expected);
    }
}