rain_engine_ingress/
lib.rs1use rain_engine_core::{
7 AdvanceRequest, AgentEngine, AgentTrigger, ContinueRequest, CoordinationStore, EnginePolicy,
8 ProcessRequest, ProviderRequestConfig,
9};
10use rain_engine_store_valkey::ValkeyCoordinationStore;
11use redis::cmd;
12use serde::{Deserialize, Serialize};
13use std::collections::BTreeSet;
14use std::time::Duration;
15use thiserror::Error;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct IngressEventEnvelope {
19 pub session_id: String,
20 pub trigger: AgentTrigger,
21 #[serde(default)]
22 pub granted_scopes: BTreeSet<String>,
23 #[serde(default)]
24 pub idempotency_key: Option<String>,
25 #[serde(default)]
26 pub policy: Option<EnginePolicy>,
27 #[serde(default)]
28 pub provider: Option<ProviderRequestConfig>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32pub struct ValkeyStreamConfig {
33 pub url: String,
34 pub stream: String,
35 pub group: String,
36 pub consumer: String,
37 pub block_ms: usize,
38}
39
40#[derive(Debug, Error)]
41pub enum IngressError {
42 #[error("{0}")]
43 Message(String),
44}
45
46#[derive(Clone)]
47pub struct ValkeyStreamIngress {
48 store: ValkeyCoordinationStore,
49 config: ValkeyStreamConfig,
50}
51
52impl ValkeyStreamIngress {
53 pub fn new(config: ValkeyStreamConfig) -> Result<Self, IngressError> {
54 let store = ValkeyCoordinationStore::connect(&config.url)
55 .map_err(|err| IngressError::Message(err.message))?;
56 Ok(Self { store, config })
57 }
58
59 pub async fn publish(&self, event: &IngressEventEnvelope) -> Result<String, IngressError> {
60 let client = redis::Client::open(self.config.url.clone())
61 .map_err(|err| IngressError::Message(err.to_string()))?;
62 let stream = self.config.stream.clone();
63 let payload =
64 serde_json::to_string(event).map_err(|err| IngressError::Message(err.to_string()))?;
65 tokio::task::spawn_blocking(move || {
66 let mut connection = client
67 .get_connection()
68 .map_err(|err| IngressError::Message(err.to_string()))?;
69 let id: String = cmd("XADD")
70 .arg(stream)
71 .arg("*")
72 .arg("payload")
73 .arg(payload)
74 .query(&mut connection)
75 .map_err(|err| IngressError::Message(err.to_string()))?;
76 Ok(id)
77 })
78 .await
79 .map_err(|err| IngressError::Message(err.to_string()))?
80 }
81
82 pub async fn ensure_group(&self) -> Result<(), IngressError> {
83 let client = redis::Client::open(self.config.url.clone())
84 .map_err(|err| IngressError::Message(err.to_string()))?;
85 let stream = self.config.stream.clone();
86 let group = self.config.group.clone();
87 tokio::task::spawn_blocking(move || {
88 let mut connection = client
89 .get_connection()
90 .map_err(|err| IngressError::Message(err.to_string()))?;
91 let result: Result<String, redis::RedisError> = cmd("XGROUP")
92 .arg("CREATE")
93 .arg(&stream)
94 .arg(&group)
95 .arg("0")
96 .arg("MKSTREAM")
97 .query(&mut connection);
98 match result {
99 Ok(_) => Ok(()),
100 Err(err) if err.to_string().contains("BUSYGROUP") => Ok(()),
101 Err(err) => Err(IngressError::Message(err.to_string())),
102 }
103 })
104 .await
105 .map_err(|err| IngressError::Message(err.to_string()))?
106 }
107
108 pub async fn run_once(&self, engine: &AgentEngine) -> Result<bool, IngressError> {
109 self.ensure_group().await?;
110 let client = redis::Client::open(self.config.url.clone())
111 .map_err(|err| IngressError::Message(err.to_string()))?;
112 let stream = self.config.stream.clone();
113 let group = self.config.group.clone();
114 let consumer = self.config.consumer.clone();
115 let block_ms = self.config.block_ms;
116 let read = tokio::task::spawn_blocking(move || {
117 let mut connection = client
118 .get_connection()
119 .map_err(|err| IngressError::Message(err.to_string()))?;
120 let value: redis::Value = cmd("XREADGROUP")
121 .arg("GROUP")
122 .arg(&group)
123 .arg(&consumer)
124 .arg("COUNT")
125 .arg(1)
126 .arg("BLOCK")
127 .arg(block_ms)
128 .arg("STREAMS")
129 .arg(&stream)
130 .arg(">")
131 .query(&mut connection)
132 .map_err(|err| IngressError::Message(err.to_string()))?;
133 Ok::<_, IngressError>(value)
134 })
135 .await
136 .map_err(|err| IngressError::Message(err.to_string()))??;
137
138 let Some((entry_id, event)) = parse_xreadgroup_payload(read)? else {
139 return Ok(false);
140 };
141 let trigger_key = event
142 .idempotency_key
143 .clone()
144 .unwrap_or_else(|| entry_id.clone());
145 let Some(claim) = self
146 .store
147 .claim_trigger(&trigger_key, Duration::from_secs(300))
148 .await
149 .map_err(|err| IngressError::Message(err.message))?
150 else {
151 self.ack_entry(entry_id).await?;
152 return Ok(false);
153 };
154
155 let result = run_until_terminal(
156 engine,
157 ProcessRequest {
158 session_id: event.session_id.clone(),
159 trigger: event.trigger,
160 granted_scopes: event.granted_scopes,
161 idempotency_key: event.idempotency_key,
162 policy: event.policy.unwrap_or_default(),
163 provider: event.provider.unwrap_or_default(),
164 cancellation: tokio_util::sync::CancellationToken::new(),
165 },
166 )
167 .await;
168
169 if let Err(err) = result {
170 let _ = self.store.release_claim(&claim.claim_id).await;
171 return Err(IngressError::Message(err.to_string()));
172 }
173
174 self.ack_entry(entry_id).await?;
175 self.store
176 .release_claim(&claim.claim_id)
177 .await
178 .map_err(|err| IngressError::Message(err.message))?;
179
180 Ok(true)
181 }
182
183 async fn ack_entry(&self, entry_id: String) -> Result<(), IngressError> {
184 let client = redis::Client::open(self.config.url.clone())
185 .map_err(|err| IngressError::Message(err.to_string()))?;
186 let stream = self.config.stream.clone();
187 let group = self.config.group.clone();
188 tokio::task::spawn_blocking(move || {
189 let mut connection = client
190 .get_connection()
191 .map_err(|err| IngressError::Message(err.to_string()))?;
192 let _: usize = cmd("XACK")
193 .arg(stream)
194 .arg(group)
195 .arg(entry_id)
196 .query(&mut connection)
197 .map_err(|err| IngressError::Message(err.to_string()))?;
198 Ok(())
199 })
200 .await
201 .map_err(|err| IngressError::Message(err.to_string()))??;
202 Ok(())
203 }
204}
205
206async fn run_until_terminal(
207 engine: &AgentEngine,
208 request: ProcessRequest,
209) -> Result<rain_engine_core::EngineOutcome, rain_engine_core::EngineError> {
210 let mut next = AdvanceRequest::Trigger(request.clone());
211 loop {
212 let result = engine.advance(next).await?;
213 if let Some(outcome) = result.outcome {
214 return Ok(outcome);
215 }
216 next = AdvanceRequest::Continue(ContinueRequest {
217 session_id: request.session_id.clone(),
218 granted_scopes: request.granted_scopes.clone(),
219 policy: request.policy.clone(),
220 provider: request.provider.clone(),
221 cancellation: request.cancellation.clone(),
222 });
223 }
224}
225
226fn parse_xreadgroup_payload(
227 value: redis::Value,
228) -> Result<Option<(String, IngressEventEnvelope)>, IngressError> {
229 use redis::Value;
230 let Value::Array(streams) = value else {
231 return Ok(None);
232 };
233 let Some(Value::Array(stream_record)) = streams.into_iter().next() else {
234 return Ok(None);
235 };
236 let Some(Value::Array(entries)) = stream_record.get(1).cloned() else {
237 return Ok(None);
238 };
239 let Some(Value::Array(entry)) = entries.into_iter().next() else {
240 return Ok(None);
241 };
242 let Some(Value::BulkString(id_bytes)) = entry.first().cloned() else {
243 return Ok(None);
244 };
245 let Some(Value::Array(fields)) = entry.get(1).cloned() else {
246 return Ok(None);
247 };
248 let mut payload = None::<String>;
249 let mut index = 0usize;
250 while index + 1 < fields.len() {
251 let key = match &fields[index] {
252 Value::BulkString(bytes) => String::from_utf8_lossy(bytes).to_string(),
253 _ => String::new(),
254 };
255 if key == "payload" {
256 if let Value::BulkString(bytes) = &fields[index + 1] {
257 payload = Some(String::from_utf8_lossy(bytes).to_string());
258 }
259 break;
260 }
261 index += 2;
262 }
263 let payload =
264 payload.ok_or_else(|| IngressError::Message("missing payload field".to_string()))?;
265 let event =
266 serde_json::from_str(&payload).map_err(|err| IngressError::Message(err.to_string()))?;
267 Ok(Some((
268 String::from_utf8_lossy(&id_bytes).to_string(),
269 event,
270 )))
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding to JSON and decoding back must reproduce the envelope exactly.
    #[test]
    fn round_trips_envelope_json() {
        // Populate every optional field so the round trip exercises all of them.
        let original = IngressEventEnvelope {
            session_id: String::from("s1"),
            trigger: AgentTrigger::Message {
                user_id: String::from("u"),
                content: String::from("hello"),
                attachments: vec![],
            },
            granted_scopes: ["tool:run".to_string()].into_iter().collect(),
            idempotency_key: Some(String::from("abc")),
            policy: Some(EnginePolicy::default()),
            provider: Some(ProviderRequestConfig::default()),
        };

        let json = serde_json::to_string(&original).expect("encode");
        let restored = serde_json::from_str::<IngressEventEnvelope>(&json).expect("decode");

        assert_eq!(restored, original);
    }
}