1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use super::{DurabilityError, DurableStore};
6
7mod codec;
8#[cfg(test)]
9mod tests;
10
/// Maximum number of entries requested per `read_from` call while replaying a
/// conversation stream.
const READ_BATCH_SIZE: usize = 1_024;
12
/// A single fact recorded in a conversation's durable event stream.
///
/// Each event is serialized before being appended to the store and
/// deserialized again during replay (presumably by the `codec` submodule —
/// confirm there).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConversationEvent {
    /// Records that a message was received.
    MessageReceived {
        /// Identifier of the received message.
        message_id: String,
        /// Caller-supplied receipt timestamp (unit not fixed here — TODO confirm).
        received_at: u64,
    },
    /// Records that processing of a message began.
    ProcessingStarted {
        /// Identifier of the message being processed.
        message_id: String,
    },
    /// Records the output of one completed processing step.
    StepCompleted {
        /// Identifier of the message the step belongs to.
        message_id: String,
        /// Index of the completed step.
        step_index: u32,
        /// Opaque bytes produced by the step.
        output: Vec<u8>,
    },
    /// Records that processing of a message finished.
    ProcessingFinished {
        /// Identifier of the finished message.
        message_id: String,
    },
    /// Records an error raised while processing a message.
    ErrorOccurred {
        /// Identifier of the message that failed.
        message_id: String,
        /// Human-readable error description.
        error: String,
    },
}

impl ConversationEvent {
    /// Returns the identifier of the message this event refers to.
    ///
    /// Every variant carries a `message_id`, so this lookup cannot fail.
    #[must_use]
    pub fn message_id(&self) -> &str {
        let id = match self {
            Self::MessageReceived { message_id: id, .. }
            | Self::ProcessingStarted { message_id: id }
            | Self::StepCompleted { message_id: id, .. }
            | Self::ProcessingFinished { message_id: id }
            | Self::ErrorOccurred { message_id: id, .. } => id,
        };
        id.as_str()
    }
}
64
/// Materialized view of a conversation, built by folding
/// [`ConversationEvent`]s through [`ConversationState::apply`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConversationState {
    /// Ids of every message with an applied `MessageReceived` event.
    pub received_messages: HashSet<String>,
    /// Ids whose most recent processing event was `ProcessingStarted`, i.e. not
    /// followed by a `ProcessingFinished` or `ErrorOccurred` event.
    pub in_progress: HashSet<String>,
    /// Output of each completed step, keyed by `(message_id, step_index)`; a
    /// repeated `StepCompleted` for the same key keeps the latest output.
    pub completed_steps: HashMap<(String, u32), Vec<u8>>,
    /// Ids with an applied `ProcessingFinished` event; these are the messages
    /// reported as fully processed.
    pub finished_messages: HashSet<String>,
    /// Most recently recorded error text for each message id.
    pub errored_messages: HashMap<String, String>,
}
79
80impl ConversationState {
81 #[must_use]
83 pub fn replay(events: &[ConversationEvent]) -> Self {
84 let mut state = Self::default();
85 for event in events {
86 state.apply(event);
87 }
88 state
89 }
90
91 pub fn apply(&mut self, event: &ConversationEvent) {
93 match event {
94 ConversationEvent::MessageReceived { message_id, .. } => {
95 self.received_messages.insert(message_id.clone());
96 }
97 ConversationEvent::ProcessingStarted { message_id } => {
98 self.in_progress.insert(message_id.clone());
99 }
100 ConversationEvent::StepCompleted {
101 message_id,
102 step_index,
103 output,
104 } => {
105 self.completed_steps
106 .insert((message_id.clone(), *step_index), output.clone());
107 }
108 ConversationEvent::ProcessingFinished { message_id } => {
109 self.finished_messages.insert(message_id.clone());
110 self.in_progress.remove(message_id);
111 }
112 ConversationEvent::ErrorOccurred { message_id, error } => {
113 self.errored_messages
114 .insert(message_id.clone(), error.clone());
115 self.in_progress.remove(message_id);
116 }
117 }
118 }
119
120 #[must_use]
122 pub fn is_fully_processed(&self, message_id: &str) -> bool {
123 self.finished_messages.contains(message_id)
124 }
125
126 #[must_use]
128 pub fn last_completed_step(&self, message_id: &str) -> Option<u32> {
129 self.completed_steps
130 .keys()
131 .filter(|(stored_id, _)| stored_id.as_str() == message_id)
132 .map(|(_, step_index)| *step_index)
133 .max()
134 }
135
136 fn next_step_index(&self, message_id: &str) -> Result<u32, DurabilityError> {
137 self.last_completed_step(message_id).map_or(Ok(0), |step| {
138 step.checked_add(1).ok_or_else(|| {
139 DurabilityError::ConfigError("conversation step index overflow".to_owned())
140 })
141 })
142 }
143}
144
/// Outcome of [`DurableConversation::receive_message`] for a delivered message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RedeliveryDecision {
    /// The message already has a `ProcessingFinished` event recorded.
    Skip,
    /// The message was received earlier but has not finished. The index is one
    /// past the highest completed step, or `0` if no step completed.
    ResumeFrom(u32),
    /// First delivery of the message; a `MessageReceived` event was just appended.
    Start,
}
155
/// Conversation handle that appends [`ConversationEvent`]s to a
/// [`DurableStore`] stream and keeps a [`ConversationState`] in memory.
///
/// Cloning shares the same `Arc` store handle but copies `state` and
/// `expected_seq`. NOTE(review): clones track `expected_seq` independently, so
/// presumably only one of them should keep appending — confirm.
#[derive(Clone)]
pub struct DurableConversation {
    /// Stream identifier used for every store read and append.
    conversation_id: String,
    /// Backing event store.
    store: Arc<dyn DurableStore>,
    /// In-memory view, rebuilt by `recover` and updated by `append_event`.
    state: ConversationState,
    /// Sequence number passed to the store's `append` for the next event:
    /// one past the last sequence assigned by the store or seen during replay.
    expected_seq: u64,
}
164
165impl DurableConversation {
166 #[must_use]
168 pub fn new(conversation_id: impl Into<String>, store: Arc<dyn DurableStore>) -> Self {
169 Self {
170 conversation_id: conversation_id.into(),
171 store,
172 state: ConversationState::default(),
173 expected_seq: 0,
174 }
175 }
176
177 pub async fn recover(
183 conversation_id: impl Into<String>,
184 store: Arc<dyn DurableStore>,
185 ) -> Result<Self, DurabilityError> {
186 let conversation_id = conversation_id.into();
187 let (state, expected_seq) = replay_stream(store.as_ref(), &conversation_id).await?;
188 Ok(Self {
189 conversation_id,
190 store,
191 state,
192 expected_seq,
193 })
194 }
195
196 #[must_use]
198 pub fn conversation_id(&self) -> &str {
199 &self.conversation_id
200 }
201
202 #[must_use]
204 pub const fn state(&self) -> &ConversationState {
205 &self.state
206 }
207
208 #[must_use]
210 pub const fn expected_seq(&self) -> u64 {
211 self.expected_seq
212 }
213
214 pub async fn receive_message(
221 &mut self,
222 message_id: impl Into<String>,
223 received_at: u64,
224 ) -> Result<RedeliveryDecision, DurabilityError> {
225 let message_id = message_id.into();
226 if self.state.is_fully_processed(&message_id) {
227 return Ok(RedeliveryDecision::Skip);
228 }
229 if self.state.received_messages.contains(&message_id) {
230 return Ok(RedeliveryDecision::ResumeFrom(
231 self.state.next_step_index(&message_id)?,
232 ));
233 }
234 self.record_message_received(message_id, received_at)
235 .await?;
236 Ok(RedeliveryDecision::Start)
237 }
238
239 pub async fn record_message_received(
246 &mut self,
247 message_id: impl Into<String>,
248 received_at: u64,
249 ) -> Result<u64, DurabilityError> {
250 self.append_event(ConversationEvent::MessageReceived {
251 message_id: message_id.into(),
252 received_at,
253 })
254 .await
255 }
256
257 pub async fn record_processing_started(
264 &mut self,
265 message_id: impl Into<String>,
266 ) -> Result<u64, DurabilityError> {
267 self.append_event(ConversationEvent::ProcessingStarted {
268 message_id: message_id.into(),
269 })
270 .await
271 }
272
273 pub async fn record_step_completed(
280 &mut self,
281 message_id: impl Into<String>,
282 step_index: u32,
283 output: Vec<u8>,
284 ) -> Result<u64, DurabilityError> {
285 self.append_event(ConversationEvent::StepCompleted {
286 message_id: message_id.into(),
287 step_index,
288 output,
289 })
290 .await
291 }
292
293 pub async fn record_processing_finished(
300 &mut self,
301 message_id: impl Into<String>,
302 ) -> Result<u64, DurabilityError> {
303 self.append_event(ConversationEvent::ProcessingFinished {
304 message_id: message_id.into(),
305 })
306 .await
307 }
308
309 pub async fn record_error(
316 &mut self,
317 message_id: impl Into<String>,
318 error: impl Into<String>,
319 ) -> Result<u64, DurabilityError> {
320 self.append_event(ConversationEvent::ErrorOccurred {
321 message_id: message_id.into(),
322 error: error.into(),
323 })
324 .await
325 }
326
327 async fn append_event(&mut self, event: ConversationEvent) -> Result<u64, DurabilityError> {
328 let payload = event.serialize()?;
329 let assigned_seq = self
330 .store
331 .append(&self.conversation_id, payload, self.expected_seq)
332 .await?;
333 self.expected_seq = assigned_seq.checked_add(1).ok_or_else(|| {
334 DurabilityError::ConfigError(
335 "sequence number overflow after conversation append".to_owned(),
336 )
337 })?;
338 self.state.apply(&event);
339 Ok(assigned_seq)
340 }
341}
342
343impl fmt::Debug for DurableConversation {
344 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
345 formatter
346 .debug_struct("DurableConversation")
347 .field("conversation_id", &self.conversation_id)
348 .field("state", &self.state)
349 .field("expected_seq", &self.expected_seq)
350 .field("store", &self.store)
351 .finish()
352 }
353}
354
355async fn replay_stream(
356 store: &dyn DurableStore,
357 conversation_id: &str,
358) -> Result<(ConversationState, u64), DurabilityError> {
359 let mut state = ConversationState::default();
360 let mut offset = 0;
361 let mut last_sequence = None;
362 loop {
363 let batch = store
364 .read_from(conversation_id, offset, READ_BATCH_SIZE)
365 .await?;
366 let batch_len = batch.len();
367 if batch_len == 0 {
368 break;
369 }
370 for stored in &batch {
371 let event = ConversationEvent::deserialize(&stored.payload)?;
372 state.apply(&event);
373 last_sequence = Some(stored.sequence);
374 }
375 offset = offset.checked_add(len_to_u64(batch_len)?).ok_or_else(|| {
376 DurabilityError::ConfigError("conversation read offset overflow".to_owned())
377 })?;
378 if batch_len < READ_BATCH_SIZE {
379 break;
380 }
381 }
382 let expected_seq = last_sequence.map_or(Ok(0), |sequence| {
383 sequence.checked_add(1).ok_or_else(|| {
384 DurabilityError::ConfigError(
385 "sequence number overflow after conversation replay".to_owned(),
386 )
387 })
388 })?;
389 Ok((state, expected_seq))
390}
391
392fn len_to_u64(len: usize) -> Result<u64, DurabilityError> {
393 u64::try_from(len).map_err(|error| {
394 DurabilityError::ConfigError(format!("conversation entry count cannot fit u64: {error}"))
395 })
396}