Skip to main content

agentzero_channels/
drafts.rs

//! Draft lifecycle orchestration: send_draft → update_draft → finalize_draft
//! (or cancel_draft to abandon a draft).
//!
//! When `stream_mode` is enabled, the agent sends incremental "draft" messages
//! that are progressively updated as the response streams in. The `DraftTracker`
//! manages the state of each in-flight draft and throttles update calls to
//! respect `draft_update_interval_ms`.

8use crate::Channel;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::Mutex;
13
/// Key under which a draft is tracked: the recipient the draft is addressed
/// to, paired with the name of the channel it was sent through.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DraftKey {
    /// Recipient the draft message is addressed to.
    pub recipient: String,
    /// Name of the channel carrying the draft.
    pub channel_name: String,
}
20
/// Bookkeeping for a single in-flight draft message.
#[derive(Debug)]
struct DraftState {
    /// Message ID returned by the channel when the draft was created.
    message_id: String,
    /// When the draft was last sent to the channel; drives update throttling.
    last_update: Instant,
    /// Most recent text supplied for the draft, which may not have been sent yet.
    latest_text: String,
}
28
29/// Tracks in-flight drafts across channels and throttles updates.
30pub struct DraftTracker {
31    drafts: Arc<Mutex<HashMap<DraftKey, DraftState>>>,
32    update_interval: Duration,
33}
34
35impl DraftTracker {
36    pub fn new(update_interval_ms: u64) -> Self {
37        Self {
38            drafts: Arc::new(Mutex::new(HashMap::new())),
39            update_interval: Duration::from_millis(update_interval_ms),
40        }
41    }
42
43    /// Start a new draft. Calls `channel.send_draft()` and tracks the returned message ID.
44    ///
45    /// Returns `Ok(message_id)` if the channel supports drafts and created one,
46    /// or `Ok(None)` if the channel does not support drafts.
47    pub async fn start(
48        &self,
49        key: DraftKey,
50        initial_text: &str,
51        channel: &Arc<dyn Channel>,
52    ) -> anyhow::Result<Option<String>> {
53        if !channel.supports_draft_updates() {
54            return Ok(None);
55        }
56
57        let msg = crate::SendMessage::new(initial_text, &key.recipient);
58        let message_id = channel.send_draft(&msg).await?;
59
60        if let Some(id) = &message_id {
61            let state = DraftState {
62                message_id: id.clone(),
63                last_update: Instant::now(),
64                latest_text: initial_text.to_string(),
65            };
66            self.drafts.lock().await.insert(key, state);
67        }
68
69        Ok(message_id)
70    }
71
72    /// Update an in-flight draft with new text.
73    ///
74    /// Respects the throttle interval — if called too frequently, the text is
75    /// buffered and only the latest version is sent when the interval elapses.
76    /// Returns `true` if an update was actually sent to the channel.
77    pub async fn update(
78        &self,
79        key: &DraftKey,
80        text: &str,
81        channel: &Arc<dyn Channel>,
82    ) -> anyhow::Result<bool> {
83        let mut drafts = self.drafts.lock().await;
84        let Some(state) = drafts.get_mut(key) else {
85            return Ok(false);
86        };
87
88        state.latest_text = text.to_string();
89
90        if state.last_update.elapsed() < self.update_interval {
91            return Ok(false);
92        }
93
94        channel
95            .update_draft(&key.recipient, &state.message_id, text)
96            .await?;
97        state.last_update = Instant::now();
98
99        Ok(true)
100    }
101
102    /// Finalize a draft — sends the final text and removes tracking.
103    pub async fn finalize(
104        &self,
105        key: &DraftKey,
106        final_text: &str,
107        channel: &Arc<dyn Channel>,
108    ) -> anyhow::Result<()> {
109        let state = self.drafts.lock().await.remove(key);
110        if let Some(state) = state {
111            channel
112                .finalize_draft(&key.recipient, &state.message_id, final_text)
113                .await?;
114        }
115        Ok(())
116    }
117
118    /// Cancel a draft — removes tracking and notifies the channel.
119    pub async fn cancel(&self, key: &DraftKey, channel: &Arc<dyn Channel>) -> anyhow::Result<()> {
120        let state = self.drafts.lock().await.remove(key);
121        if let Some(state) = state {
122            channel
123                .cancel_draft(&key.recipient, &state.message_id)
124                .await?;
125        }
126        Ok(())
127    }
128
129    /// Check if a draft is currently tracked for the given key.
130    pub async fn has_draft(&self, key: &DraftKey) -> bool {
131        self.drafts.lock().await.contains_key(key)
132    }
133
134    /// Get the number of active drafts.
135    pub async fn active_count(&self) -> usize {
136        self.drafts.lock().await.len()
137    }
138
139    /// Flush a pending update — sends the latest buffered text regardless of throttle.
140    pub async fn flush(&self, key: &DraftKey, channel: &Arc<dyn Channel>) -> anyhow::Result<bool> {
141        let mut drafts = self.drafts.lock().await;
142        let Some(state) = drafts.get_mut(key) else {
143            return Ok(false);
144        };
145
146        let text = state.latest_text.clone();
147        channel
148            .update_draft(&key.recipient, &state.message_id, &text)
149            .await?;
150        state.last_update = Instant::now();
151
152        Ok(true)
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A mock channel that counts draft method calls and records the text of
    /// the most recent `update_draft` call.
    struct MockDraftChannel {
        draft_sends: AtomicU32,
        draft_updates: AtomicU32,
        draft_finalizes: AtomicU32,
        draft_cancels: AtomicU32,
        // Fully qualified: `Mutex` from `super::*` is tokio's async mutex.
        last_update_text: std::sync::Mutex<Option<String>>,
    }

    impl MockDraftChannel {
        fn new() -> Self {
            Self {
                draft_sends: AtomicU32::new(0),
                draft_updates: AtomicU32::new(0),
                draft_finalizes: AtomicU32::new(0),
                draft_cancels: AtomicU32::new(0),
                last_update_text: std::sync::Mutex::new(None),
            }
        }

        /// Text passed to the most recent `update_draft` call, if any.
        fn last_update_text(&self) -> Option<String> {
            self.last_update_text.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Channel for MockDraftChannel {
        fn name(&self) -> &str {
            "mock-draft"
        }

        async fn send(&self, _message: &crate::SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<crate::ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        fn supports_draft_updates(&self) -> bool {
            true
        }

        async fn send_draft(
            &self,
            _message: &crate::SendMessage,
        ) -> anyhow::Result<Option<String>> {
            // `fetch_add` returns the previous value. Deriving the id from it
            // (instead of a separate `load`) keeps ids unique under concurrency.
            let n = self.draft_sends.fetch_add(1, Ordering::SeqCst) + 1;
            Ok(Some(format!("draft-{}", n)))
        }

        async fn update_draft(
            &self,
            _recipient: &str,
            _message_id: &str,
            text: &str,
        ) -> anyhow::Result<Option<String>> {
            self.draft_updates.fetch_add(1, Ordering::SeqCst);
            *self.last_update_text.lock().unwrap() = Some(text.to_string());
            Ok(None)
        }

        async fn finalize_draft(
            &self,
            _recipient: &str,
            _message_id: &str,
            _text: &str,
        ) -> anyhow::Result<()> {
            self.draft_finalizes.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn cancel_draft(&self, _recipient: &str, _message_id: &str) -> anyhow::Result<()> {
            self.draft_cancels.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// A channel that does not support drafts.
    struct NoDraftChannel;

    #[async_trait]
    impl Channel for NoDraftChannel {
        fn name(&self) -> &str {
            "no-draft"
        }
        async fn send(&self, _message: &crate::SendMessage) -> anyhow::Result<()> {
            Ok(())
        }
        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<crate::ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    fn test_key() -> DraftKey {
        DraftKey {
            recipient: "user-1".into(),
            channel_name: "mock-draft".into(),
        }
    }

    #[tokio::test]
    async fn start_creates_draft_and_tracks_it() {
        let ch: Arc<dyn Channel> = Arc::new(MockDraftChannel::new());
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let id = tracker.start(key.clone(), "hello", &ch).await.unwrap();
        assert_eq!(id.as_deref(), Some("draft-1"));
        assert!(tracker.has_draft(&key).await);
        assert_eq!(tracker.active_count().await, 1);
    }

    #[tokio::test]
    async fn start_returns_none_for_non_draft_channel() {
        let ch: Arc<dyn Channel> = Arc::new(NoDraftChannel);
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let id = tracker.start(key.clone(), "hello", &ch).await.unwrap();
        assert!(id.is_none());
        assert!(!tracker.has_draft(&key).await);
    }

    #[tokio::test]
    async fn update_throttles_rapid_calls() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(1000); // 1 second throttle
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        // Immediate update should be throttled (interval not elapsed)
        let sent = tracker.update(&key, "update-1", &ch).await.unwrap();
        assert!(!sent, "first rapid update should be throttled");

        // Updates counter should still be 0
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn update_sends_after_interval() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(50); // 50ms throttle
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        // Wait for throttle interval
        tokio::time::sleep(Duration::from_millis(60)).await;

        let sent = tracker.update(&key, "update-1", &ch).await.unwrap();
        assert!(sent, "update after interval should succeed");
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);
        assert_eq!(mock.last_update_text().as_deref(), Some("update-1"));
    }

    #[tokio::test]
    async fn update_without_start_returns_false() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(0);
        let key = test_key();

        let sent = tracker.update(&key, "text", &ch).await.unwrap();
        assert!(!sent);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalize_sends_final_and_removes_tracking() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();
        assert!(tracker.has_draft(&key).await);

        tracker.finalize(&key, "final text", &ch).await.unwrap();
        assert!(!tracker.has_draft(&key).await);
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cancel_removes_tracking_and_notifies() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();
        tracker.cancel(&key, &ch).await.unwrap();

        assert!(!tracker.has_draft(&key).await);
        assert_eq!(mock.draft_cancels.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cancel_without_start_is_noop() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        tracker.cancel(&key, &ch).await.unwrap();
        assert_eq!(mock.draft_cancels.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn finalize_without_start_is_noop() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        // Should not error, and must not reach the channel.
        tracker.finalize(&key, "text", &ch).await.unwrap();
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn flush_sends_regardless_of_throttle() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(10_000); // very long throttle
        let key = test_key();

        tracker.start(key.clone(), "initial", &ch).await.unwrap();

        // Buffer an update (throttled)
        tracker.update(&key, "latest", &ch).await.unwrap();

        // Flush should send the buffered text immediately
        let flushed = tracker.flush(&key, &ch).await.unwrap();
        assert!(flushed);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);
        assert_eq!(mock.last_update_text().as_deref(), Some("latest"));
    }

    #[tokio::test]
    async fn flush_without_start_returns_false() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(500);
        let key = test_key();

        let flushed = tracker.flush(&key, &ch).await.unwrap();
        assert!(!flushed);
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn full_lifecycle_send_update_finalize() {
        let mock = Arc::new(MockDraftChannel::new());
        let ch: Arc<dyn Channel> = mock.clone();
        let tracker = DraftTracker::new(10); // short throttle for testing
        let key = test_key();

        // Start
        let id = tracker
            .start(key.clone(), "thinking...", &ch)
            .await
            .unwrap();
        assert!(id.is_some());
        assert_eq!(mock.draft_sends.load(Ordering::SeqCst), 1);

        // Update (after throttle)
        tokio::time::sleep(Duration::from_millis(20)).await;
        tracker.update(&key, "thinking... more", &ch).await.unwrap();
        assert_eq!(mock.draft_updates.load(Ordering::SeqCst), 1);

        // Finalize
        tracker
            .finalize(&key, "Here is the answer.", &ch)
            .await
            .unwrap();
        assert_eq!(mock.draft_finalizes.load(Ordering::SeqCst), 1);
        assert_eq!(tracker.active_count().await, 0);
    }
}