1use crate::Channel;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
/// Identifies one in-flight draft: who it is addressed to and which channel
/// carries it. Two keys are equal only when both fields match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DraftKey {
    /// Recipient the draft is sent to; passed to every draft call on the channel.
    pub recipient: String,
    /// Channel name; keeps the same recipient on different channels distinct.
    pub channel_name: String,
}
20
/// Bookkeeping for a single tracked draft message.
#[derive(Debug)]
struct DraftState {
    /// Channel-assigned id of the draft message that edits are applied to.
    message_id: String,
    /// When the draft content was last sent to the channel; drives throttling.
    last_update: Instant,
    /// Newest text supplied by the caller, which may not have been sent yet.
    latest_text: String,
}
28
29pub struct DraftTracker {
31 drafts: Arc<Mutex<HashMap<DraftKey, DraftState>>>,
32 update_interval: Duration,
33}
34
35impl DraftTracker {
36 pub fn new(update_interval_ms: u64) -> Self {
37 Self {
38 drafts: Arc::new(Mutex::new(HashMap::new())),
39 update_interval: Duration::from_millis(update_interval_ms),
40 }
41 }
42
43 pub async fn start(
48 &self,
49 key: DraftKey,
50 initial_text: &str,
51 channel: &Arc<dyn Channel>,
52 ) -> anyhow::Result<Option<String>> {
53 if !channel.supports_draft_updates() {
54 return Ok(None);
55 }
56
57 let msg = crate::SendMessage::new(initial_text, &key.recipient);
58 let message_id = channel.send_draft(&msg).await?;
59
60 if let Some(id) = &message_id {
61 let state = DraftState {
62 message_id: id.clone(),
63 last_update: Instant::now(),
64 latest_text: initial_text.to_string(),
65 };
66 self.drafts.lock().await.insert(key, state);
67 }
68
69 Ok(message_id)
70 }
71
72 pub async fn update(
78 &self,
79 key: &DraftKey,
80 text: &str,
81 channel: &Arc<dyn Channel>,
82 ) -> anyhow::Result<bool> {
83 let mut drafts = self.drafts.lock().await;
84 let Some(state) = drafts.get_mut(key) else {
85 return Ok(false);
86 };
87
88 state.latest_text = text.to_string();
89
90 if state.last_update.elapsed() < self.update_interval {
91 return Ok(false);
92 }
93
94 channel
95 .update_draft(&key.recipient, &state.message_id, text)
96 .await?;
97 state.last_update = Instant::now();
98
99 Ok(true)
100 }
101
102 pub async fn finalize(
104 &self,
105 key: &DraftKey,
106 final_text: &str,
107 channel: &Arc<dyn Channel>,
108 ) -> anyhow::Result<()> {
109 let state = self.drafts.lock().await.remove(key);
110 if let Some(state) = state {
111 channel
112 .finalize_draft(&key.recipient, &state.message_id, final_text)
113 .await?;
114 }
115 Ok(())
116 }
117
118 pub async fn cancel(&self, key: &DraftKey, channel: &Arc<dyn Channel>) -> anyhow::Result<()> {
120 let state = self.drafts.lock().await.remove(key);
121 if let Some(state) = state {
122 channel
123 .cancel_draft(&key.recipient, &state.message_id)
124 .await?;
125 }
126 Ok(())
127 }
128
129 pub async fn has_draft(&self, key: &DraftKey) -> bool {
131 self.drafts.lock().await.contains_key(key)
132 }
133
134 pub async fn active_count(&self) -> usize {
136 self.drafts.lock().await.len()
137 }
138
139 pub async fn flush(&self, key: &DraftKey, channel: &Arc<dyn Channel>) -> anyhow::Result<bool> {
141 let mut drafts = self.drafts.lock().await;
142 let Some(state) = drafts.get_mut(key) else {
143 return Ok(false);
144 };
145
146 let text = state.latest_text.clone();
147 channel
148 .update_draft(&key.recipient, &state.message_id, &text)
149 .await?;
150 state.last_update = Instant::now();
151
152 Ok(true)
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Draft-capable channel double that counts every draft call it receives.
    struct MockDraftChannel {
        draft_sends: AtomicU32,
        draft_updates: AtomicU32,
        draft_finalizes: AtomicU32,
        draft_cancels: AtomicU32,
    }

    impl MockDraftChannel {
        fn new() -> Self {
            Self {
                draft_sends: AtomicU32::new(0),
                draft_updates: AtomicU32::new(0),
                draft_finalizes: AtomicU32::new(0),
                draft_cancels: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl Channel for MockDraftChannel {
        fn name(&self) -> &str {
            "mock-draft"
        }

        async fn send(&self, _message: &crate::SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<crate::ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        fn supports_draft_updates(&self) -> bool {
            true
        }

        async fn send_draft(
            &self,
            _message: &crate::SendMessage,
        ) -> anyhow::Result<Option<String>> {
            // `fetch_add` returns the previous value. Using it (instead of a
            // separate `load`) keeps ids unique even if sends race.
            let n = self.draft_sends.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(Some(format!("draft-{n}")))
        }

        async fn update_draft(
            &self,
            _recipient: &str,
            _message_id: &str,
            _text: &str,
        ) -> anyhow::Result<Option<String>> {
            self.draft_updates.fetch_add(1, Ordering::SeqCst);
            Ok(None)
        }

        async fn finalize_draft(
            &self,
            _recipient: &str,
            _message_id: &str,
            _text: &str,
        ) -> anyhow::Result<()> {
            self.draft_finalizes.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn cancel_draft(&self, _recipient: &str, _message_id: &str) -> anyhow::Result<()> {
            self.draft_cancels.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Channel double relying on the trait's default draft behaviour.
    struct NoDraftChannel;

    #[async_trait]
    impl Channel for NoDraftChannel {
        fn name(&self) -> &str {
            "no-draft"
        }

        async fn send(&self, _message: &crate::SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<crate::ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    fn test_key() -> DraftKey {
        DraftKey {
            recipient: "user-1".into(),
            channel_name: "mock-draft".into(),
        }
    }

    #[tokio::test]
    async fn start_creates_draft_and_tracks_it() {
        let ch: Arc<dyn Channel> = Arc::new(MockDraftChannel::new());
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let id = tracker.start(key.clone(), "hello", &ch).await.unwrap();
        assert_eq!(id.as_deref(), Some("draft-1"));
        assert!(tracker.has_draft(&key).await);
        assert_eq!(tracker.active_count().await, 1);
    }

    #[tokio::test]
    async fn start_returns_none_for_non_draft_channel() {
        let ch: Arc<dyn Channel> = Arc::new(NoDraftChannel);
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let id = tracker.start(key.clone(), "hello", &ch).await.unwrap();
        assert!(id.is_none());
        assert!(!tracker.has_draft(&key).await);
    }

    #[tokio::test]
    async fn update_throttles_rapid_calls() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(1000);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        let sent = tracker.update(&key, "update-1", &ch).await.unwrap();
        assert!(!sent, "first rapid update should be throttled");
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn update_sends_after_interval() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(50);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        tokio::time::sleep(Duration::from_millis(60)).await;

        let sent = tracker.update(&key, "update-1", &ch).await.unwrap();
        assert!(sent, "update after interval should succeed");
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn update_without_start_returns_false() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(0);
        let key = test_key();

        let sent = tracker.update(&key, "text", &ch).await.unwrap();
        assert!(!sent);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalize_sends_final_and_removes_tracking() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();
        assert!(tracker.has_draft(&key).await);

        tracker.finalize(&key, "final text", &ch).await.unwrap();
        assert!(!tracker.has_draft(&key).await);
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cancel_removes_tracking_and_notifies() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();
        tracker.cancel(&key, &ch).await.unwrap();

        assert!(!tracker.has_draft(&key).await);
        assert_eq!(mock.draft_cancels.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cancel_without_start_is_noop() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.cancel(&key, &ch).await.unwrap();
        assert_eq!(mock.draft_cancels.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalize_without_start_is_noop() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.finalize(&key, "text", &ch).await.unwrap();
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 0);
        assert_eq!(tracker.active_count().await, 0);
    }

    #[tokio::test]
    async fn flush_sends_regardless_of_throttle() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(10_000);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        // Throttled: only recorded, not sent.
        tracker.update(&key, "latest", &ch).await.unwrap();

        let flushed = tracker.flush(&key, &ch).await.unwrap();
        assert!(flushed);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn flush_without_start_returns_false() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let flushed = tracker.flush(&key, &ch).await.unwrap();
        assert!(!flushed);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn full_lifecycle_send_update_finalize() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(10);
        let key = test_key();

        let id = tracker
            .start(key.clone(), "thinking...", &ch)
            .await
            .unwrap();
        assert!(id.is_some());
        assert_eq!(mock.draft_sends.load(Ordering::SeqCst), 1);

        tokio::time::sleep(Duration::from_millis(20)).await;

        tracker.update(&key, "thinking... more", &ch).await.unwrap();
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);

        tracker
            .finalize(&key, "Here is the answer.", &ch)
            .await
            .unwrap();
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 1);
        assert_eq!(tracker.active_count().await, 0);
    }
}