swink_agent/agent/
control.rs1use std::sync::Arc;
2use std::sync::atomic::Ordering;
3
4use tokio::sync::Notify;
5use tracing::info;
6
7use super::Agent;
8
9async fn wait_for_idle_future<F>(
10 notify: Arc<Notify>,
11 active: Arc<std::sync::atomic::AtomicBool>,
12 after_register: F,
13) where
14 F: Fn() + Send + Sync + 'static,
15{
16 loop {
17 let notified = notify.notified();
18 after_register();
19 if !active.load(Ordering::Acquire) {
20 return;
21 }
22 notified.await;
23 }
24}
25
26impl Agent {
27 pub(super) fn clear_transient_runtime_state(&mut self) {
28 self.state.is_running = false;
29 self.state.stream_message = None;
30 self.state.pending_tool_calls.clear();
31 self.state.error = None;
32 self.abort_controller = None;
33 self.in_flight_llm_messages = None;
34 self.in_flight_messages = None;
35 self.pending_message_snapshot.clear();
36 self.loop_context_snapshot.clear();
37 }
38
39 pub fn abort(&mut self) {
41 if let Some(ref token) = self.abort_controller {
42 info!("aborting agent loop");
43 token.cancel();
44 }
45 }
46
47 pub fn reset(&mut self) {
53 if let Some(ref token) = self.abort_controller {
56 token.cancel();
57 }
58
59 self.loop_generation.fetch_add(1, Ordering::AcqRel);
62
63 self.state.messages.clear();
64 self.loop_active.store(false, Ordering::Release);
65 self.clear_transient_runtime_state();
66 self.clear_queues();
67 self.idle_notify.notify_waiters();
68 }
69
70 pub fn wait_for_idle(&self) -> impl Future<Output = ()> + Send + '_ {
75 wait_for_idle_future(
76 Arc::clone(&self.idle_notify),
77 Arc::clone(&self.loop_active),
78 || {},
79 )
80 }
81}
82
#[cfg(all(test, feature = "testkit"))]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::Poll;

    use futures::pin_mut;
    use tokio::sync::Notify;

    use crate::agent_options::AgentOptions;
    use crate::stream::StreamFn;
    use crate::testing::{
        MockStreamFn, default_convert, default_model, text_only_events, user_msg,
    };

    use super::{Agent, wait_for_idle_future};

    /// The idle transition (flag cleared + `notify_waiters`) happens inside the hook, i.e.
    /// after the waiter has registered for notifications but before it reads the flag. The
    /// waiter must see the idle state on its very first poll instead of parking forever.
    #[tokio::test]
    async fn wait_for_idle_returns_when_idle_transition_happens_after_registration() {
        let notify = Arc::new(Notify::new());
        let active = Arc::new(AtomicBool::new(true));
        let active_for_hook = Arc::clone(&active);
        let notify_for_hook = Arc::clone(&notify);

        let wait_for_idle = wait_for_idle_future(notify, active, move || {
            active_for_hook.store(false, Ordering::Release);
            notify_for_hook.notify_waiters();
        });
        pin_mut!(wait_for_idle);

        assert_eq!(futures::poll!(wait_for_idle.as_mut()), Poll::Ready(()));
    }

    /// While the flag stays set the waiter must remain pending, and it must complete once
    /// the flag is cleared and waiters are notified.
    #[tokio::test]
    async fn wait_for_idle_stays_pending_until_idle_notification() {
        let notify = Arc::new(Notify::new());
        let active = Arc::new(AtomicBool::new(true));

        let wait_for_idle = wait_for_idle_future(Arc::clone(&notify), Arc::clone(&active), || {});
        pin_mut!(wait_for_idle);

        assert_eq!(futures::poll!(wait_for_idle.as_mut()), Poll::Pending);
        // Polling the waiter must not have touched the flag.
        assert!(active.load(Ordering::Acquire));

        active.store(false, Ordering::Release);
        notify.notify_waiters();

        assert_eq!(futures::poll!(wait_for_idle.as_mut()), Poll::Ready(()));
    }

    /// `Agent::reset` must wake a waiter that parked while a prompt loop was active.
    #[tokio::test]
    async fn reset_notifies_pending_wait_for_idle_waiters() {
        let stream_fn = Arc::new(MockStreamFn::new(vec![text_only_events("done")]));
        let mut agent = Agent::new(AgentOptions::new(
            "sys",
            default_model(),
            stream_fn as Arc<dyn StreamFn>,
            default_convert,
        ));

        let _stream = agent
            .prompt_stream(vec![user_msg("hi")])
            .expect("prompt_stream should start a loop");

        let wait_for_idle = wait_for_idle_future(
            Arc::clone(&agent.idle_notify),
            Arc::clone(&agent.loop_active),
            || {},
        );
        pin_mut!(wait_for_idle);

        assert_eq!(futures::poll!(wait_for_idle.as_mut()), Poll::Pending);

        agent.reset();

        assert_eq!(futures::poll!(wait_for_idle.as_mut()), Poll::Ready(()));
    }
}