1use crate::convert::{
4 ProtoWorkflowId, ProtoWorkflowStatus, WireEnvelope, decode_core_value, encode_core_value,
5};
6use crate::error::WireError;
7
/// Client request to open an event subscription.
///
/// At most one subscription shape is carried in the `subscription` oneof
/// (proto tags 1–3); see [`subscription_request::Subscription`].
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct SubscriptionRequest {
    /// The requested subscription shape, or `None` when no oneof variant is set.
    #[prost(oneof = "subscription_request::Subscription", tags = "1, 2, 3")]
    pub subscription: Option<subscription_request::Subscription>,
}
15
/// Oneof types nested under [`SubscriptionRequest`].
pub mod subscription_request {
    /// The subscription shapes a client can request.
    ///
    /// With serde's default (externally tagged) enum representation, a variant
    /// appears in JSON as `{ "PerWorkflow": { ... } }`, `{ "Filtered": { ... } }`
    /// or `{ "Firehose": { ... } }`.
    #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Oneof)]
    pub enum Subscription {
        /// Subscription addressed by namespace + workflow id (proto tag 1).
        #[prost(message, tag = "1")]
        PerWorkflow(super::PerWorkflowSubscription),
        /// Subscription narrowed by optional filter fields (proto tag 2).
        #[prost(message, tag = "2")]
        Filtered(super::FilteredSubscription),
        /// Subscription addressed by namespace only (proto tag 3).
        #[prost(message, tag = "3")]
        Firehose(super::FirehoseSubscription),
    }
}
32
/// Subscription to the events of a single workflow.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct PerWorkflowSubscription {
    /// Namespace the subscription is scoped to.
    #[prost(string, tag = "1")]
    pub namespace: String,
    /// Target workflow. Optional at the wire level.
    /// NOTE(review): confirm how the server treats `None` here.
    #[prost(message, optional, tag = "2")]
    pub workflow_id: Option<ProtoWorkflowId>,
    /// Optional resume cursor, encoded as a proto3 `optional uint64`, so
    /// "no cursor" (`None`) and `Some(0)` stay distinguishable on the wire.
    /// A JSON object that omits this field deserializes to `None`.
    /// NOTE(review): presumably compared against event `seq`; confirm whether
    /// the bound is inclusive or exclusive.
    #[prost(uint64, optional, tag = "3")]
    pub resume_from_seq: Option<u64>,
}
60
/// Subscription narrowed by optional workflow filters.
///
/// Each filter field is optional; unset filters are `None`.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct FilteredSubscription {
    /// Namespace the subscription is scoped to.
    #[prost(string, tag = "1")]
    pub namespace: String,
    /// Optional workflow-type filter (e.g. `"checkout"`).
    #[prost(string, optional, tag = "2")]
    pub workflow_type: Option<String>,
    /// Optional workflow-status filter, stored as the raw `i32` value of a
    /// [`ProtoWorkflowStatus`] (prost represents enumeration fields as `i32`).
    #[prost(enumeration = "ProtoWorkflowStatus", optional, tag = "3")]
    pub status: Option<i32>,
    /// Optional namespace selector.
    /// NOTE(review): its relationship to `namespace` is not defined in this
    /// file — confirm the intended semantics.
    #[prost(string, optional, tag = "4")]
    pub namespace_selector: Option<String>,
}
82
/// Subscription identified only by its namespace.
///
/// NOTE(review): presumably delivers every event in the namespace — confirm
/// against the server implementation.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct FirehoseSubscription {
    /// Namespace the subscription is scoped to.
    #[prost(string, tag = "1")]
    pub namespace: String,
}
95
/// A single event frame pushed to a subscriber.
///
/// The namespace is carried both on the frame and inside the envelope;
/// [`StreamedEvent::decode_event`] rejects frames where the two differ.
#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
pub struct StreamedEvent {
    /// Namespace the frame was emitted for.
    #[prost(string, tag = "1")]
    pub namespace: String,
    /// Encoded core event. `decode_event` treats `None` as an error.
    #[prost(message, optional, tag = "2")]
    pub event: Option<WireEnvelope>,
}
106
107impl StreamedEvent {
108 pub fn encode(
115 namespace: impl Into<String>,
116 request_id: Option<String>,
117 event: &aion_core::Event,
118 ) -> Result<Self, WireError> {
119 let namespace = namespace.into();
120 let event = encode_core_value(namespace.clone(), request_id, event)?;
121 Ok(Self {
122 namespace,
123 event: Some(event),
124 })
125 }
126
127 pub fn decode_event(&self) -> Result<aion_core::Event, WireError> {
135 let event = self
136 .event
137 .as_ref()
138 .ok_or_else(|| WireError::backend("streamed event envelope is missing"))?;
139 if event.namespace != self.namespace {
140 return Err(WireError::backend("streamed event namespace mismatch"));
141 }
142 decode_core_value(event)
143 }
144}
145
146pub fn encode_streamed_event(
152 namespace: impl Into<String>,
153 request_id: Option<String>,
154 event: &aion_core::Event,
155) -> Result<StreamedEvent, WireError> {
156 StreamedEvent::encode(namespace, request_id, event)
157}
158
#[cfg(test)]
mod tests {
    use chrono::{DateTime, Utc};
    use prost::Message;
    use serde_json::json;

    use super::{
        FilteredSubscription, FirehoseSubscription, PerWorkflowSubscription, StreamedEvent,
        SubscriptionRequest, encode_streamed_event, subscription_request,
    };
    use crate::convert::{ProtoWorkflowId, ProtoWorkflowStatus, WireEnvelope};
    use crate::error::WireError;

    /// Deterministic workflow id (the nil UUID) shared by every fixture.
    fn workflow_id() -> aion_core::WorkflowId {
        aion_core::WorkflowId::new(uuid::Uuid::nil())
    }

    /// Fixed timestamp so fixtures are reproducible across runs.
    fn recorded_at() -> Result<DateTime<Utc>, chrono::ParseError> {
        Ok(DateTime::parse_from_rfc3339("2026-01-01T00:00:00Z")?.with_timezone(&Utc))
    }

    /// Envelope for the first (`seq = 1`) event of the fixture workflow.
    fn event_envelope() -> Result<aion_core::EventEnvelope, chrono::ParseError> {
        Ok(aion_core::EventEnvelope {
            seq: 1,
            recorded_at: recorded_at()?,
            workflow_id: workflow_id(),
        })
    }

    /// `WorkflowStarted` event shared by the streamed-event tests.
    fn workflow_started_event() -> Result<aion_core::Event, Box<dyn std::error::Error>> {
        Ok(aion_core::Event::WorkflowStarted {
            envelope: event_envelope()?,
            workflow_type: String::from("checkout"),
            input: aion_core::Payload::from_json(&json!({ "cart": ["sku-1"] }))?,
            run_id: aion_core::RunId::new(uuid::Uuid::from_u128(1)),
            parent_run_id: None,
            package_version: aion_core::PackageVersion::new("a".repeat(64)),
        })
    }

    /// Every oneof variant, plus the empty request, must survive both the JSON
    /// and the protobuf codecs unchanged.
    #[test]
    fn subscription_request_round_trips_all_variants() -> Result<(), Box<dyn std::error::Error>> {
        let requests = [
            // No variant set: JSON `null` and an empty protobuf message.
            SubscriptionRequest { subscription: None },
            SubscriptionRequest {
                subscription: Some(subscription_request::Subscription::PerWorkflow(
                    PerWorkflowSubscription {
                        namespace: String::from("tenant-a"),
                        workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
                        resume_from_seq: None,
                    },
                )),
            },
            SubscriptionRequest {
                subscription: Some(subscription_request::Subscription::PerWorkflow(
                    PerWorkflowSubscription {
                        namespace: String::from("tenant-a"),
                        workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
                        resume_from_seq: Some(42),
                    },
                )),
            },
            SubscriptionRequest {
                subscription: Some(subscription_request::Subscription::Filtered(
                    FilteredSubscription {
                        namespace: String::from("tenant-a"),
                        workflow_type: Some(String::from("checkout")),
                        status: Some(ProtoWorkflowStatus::Running as i32),
                        namespace_selector: Some(String::from("tenant-a")),
                    },
                )),
            },
            SubscriptionRequest {
                subscription: Some(subscription_request::Subscription::Filtered(
                    FilteredSubscription {
                        namespace: String::from("tenant-a"),
                        workflow_type: None,
                        status: None,
                        namespace_selector: None,
                    },
                )),
            },
            SubscriptionRequest {
                subscription: Some(subscription_request::Subscription::Firehose(
                    FirehoseSubscription {
                        namespace: String::from("tenant-a"),
                    },
                )),
            },
        ];

        for request in requests {
            let json = serde_json::to_vec(&request)?;
            let from_json: SubscriptionRequest = serde_json::from_slice(&json)?;
            assert_eq!(from_json, request);

            let bytes = request.encode_to_vec();
            let from_proto = SubscriptionRequest::decode(bytes.as_slice())?;
            assert_eq!(from_proto, request);
        }

        Ok(())
    }

    /// The proto3 `optional` cursor keeps presence: `Some(n)` and `None` both
    /// survive a protobuf round trip.
    #[test]
    fn per_workflow_resume_cursor_round_trips_prost() -> Result<(), Box<dyn std::error::Error>> {
        let with_cursor = PerWorkflowSubscription {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            resume_from_seq: Some(7),
        };
        let decoded = PerWorkflowSubscription::decode(with_cursor.encode_to_vec().as_slice())?;
        assert_eq!(decoded, with_cursor);
        assert_eq!(decoded.resume_from_seq, Some(7));

        let without_cursor = PerWorkflowSubscription {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            resume_from_seq: None,
        };
        let decoded = PerWorkflowSubscription::decode(without_cursor.encode_to_vec().as_slice())?;
        assert_eq!(decoded, without_cursor);
        assert_eq!(decoded.resume_from_seq, None);

        Ok(())
    }

    /// Pins the JSON shape of `PerWorkflowSubscription` so that changes to the
    /// wire format are caught.
    #[test]
    fn per_workflow_resume_cursor_json_shape_is_pinned() -> Result<(), Box<dyn std::error::Error>> {
        let with_cursor = PerWorkflowSubscription {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            resume_from_seq: Some(7),
        };
        let value = serde_json::to_value(&with_cursor)?;
        assert_eq!(
            value,
            json!({
                "namespace": "tenant-a",
                "workflow_id": { "uuid": "00000000-0000-0000-0000-000000000000" },
                "resume_from_seq": 7,
            })
        );
        let from_json: PerWorkflowSubscription = serde_json::from_value(value)?;
        assert_eq!(from_json, with_cursor);

        let without_cursor = PerWorkflowSubscription {
            namespace: String::from("tenant-a"),
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            resume_from_seq: None,
        };
        let value = serde_json::to_value(&without_cursor)?;
        assert_eq!(
            value,
            json!({
                "namespace": "tenant-a",
                "workflow_id": { "uuid": "00000000-0000-0000-0000-000000000000" },
                "resume_from_seq": null,
            })
        );
        let from_json: PerWorkflowSubscription = serde_json::from_value(value)?;
        assert_eq!(from_json, without_cursor);

        Ok(())
    }

    /// Older clients that omit `resume_from_seq` entirely must still decode,
    /// with the cursor reading as `None`.
    #[test]
    fn subscription_request_without_resume_field_decodes_to_none()
    -> Result<(), Box<dyn std::error::Error>> {
        let request: SubscriptionRequest = serde_json::from_value(json!({
            "subscription": {
                "PerWorkflow": {
                    "namespace": "tenant-a",
                    "workflow_id": { "uuid": "00000000-0000-0000-0000-000000000000" },
                }
            }
        }))?;

        let Some(subscription_request::Subscription::PerWorkflow(per_workflow)) =
            request.subscription
        else {
            return Err(Box::from("expected a per-workflow subscription"));
        };
        assert_eq!(per_workflow.namespace, "tenant-a");
        assert_eq!(
            per_workflow.workflow_id,
            Some(ProtoWorkflowId::from(workflow_id()))
        );
        assert_eq!(per_workflow.resume_from_seq, None);

        Ok(())
    }

    /// Encoding then decoding a frame in memory returns the original core
    /// event, and the namespace and request id land on the envelope.
    #[test]
    fn streamed_event_round_trips_core_event() -> Result<(), Box<dyn std::error::Error>> {
        let event = workflow_started_event()?;

        let frame = encode_streamed_event("tenant-a", Some(String::from("request-1")), &event)?;
        assert_eq!(frame.namespace, "tenant-a");
        let envelope = frame
            .event
            .as_ref()
            .ok_or_else(|| WireError::backend("test streamed event envelope is missing"))?;
        assert_eq!(envelope.namespace, "tenant-a");
        assert_eq!(envelope.request_id.as_deref(), Some("request-1"));

        let decoded = frame.decode_event()?;
        assert_eq!(decoded, event);
        Ok(())
    }

    /// A frame must also survive the protobuf byte codec it is actually sent
    /// with, and still decode to the original core event on the far side.
    #[test]
    fn streamed_event_round_trips_prost_bytes() -> Result<(), Box<dyn std::error::Error>> {
        let event = workflow_started_event()?;
        let frame = encode_streamed_event("tenant-a", Some(String::from("request-1")), &event)?;

        let from_proto = StreamedEvent::decode(frame.encode_to_vec().as_slice())?;
        assert_eq!(from_proto, frame);
        assert_eq!(from_proto.decode_event()?, event);
        Ok(())
    }

    /// The frame and envelope namespaces must agree. The check runs before the
    /// payload is inspected, so a missing payload does not mask it.
    #[test]
    fn streamed_event_rejects_namespace_mismatch() {
        let frame = StreamedEvent {
            namespace: String::from("tenant-a"),
            event: Some(WireEnvelope {
                namespace: String::from("tenant-b"),
                request_id: None,
                payload: None,
            }),
        };

        assert_eq!(
            frame.decode_event(),
            Err(WireError::backend("streamed event namespace mismatch"))
        );
    }

    /// A frame without an envelope is rejected rather than decoded to a default.
    #[test]
    fn streamed_event_rejects_missing_envelope() {
        let frame = StreamedEvent {
            namespace: String::from("tenant-a"),
            event: None,
        };

        assert_eq!(
            frame.decode_event(),
            Err(WireError::backend("streamed event envelope is missing"))
        );
    }
}