swink_agent/message_provider.rs
1//! Trait for polling steering and follow-up messages.
2//!
3//! [`MessageProvider`] replaces inline closures in [`AgentLoopConfig`](crate::loop_::AgentLoopConfig),
4//! giving callers a named, testable abstraction for injecting messages into the
5//! agent loop between turns.
6//!
7//! For push-based messaging, see [`ChannelMessageProvider`] and [`MessageSender`].
8
9use std::sync::Mutex;
10
11use crate::types::AgentMessage;
12
13/// Provides steering and follow-up messages to the agent loop.
14///
15/// Implementors are polled at well-defined points during loop execution:
16/// - [`poll_steering`](Self::poll_steering) is called after each tool execution batch.
17/// - [`poll_follow_up`](Self::poll_follow_up) is called when the agent would otherwise stop.
18pub trait MessageProvider: Send + Sync {
19 /// Return pending steering messages, if any.
20 ///
21 /// Called after tool execution completes. Returning a non-empty vec causes
22 /// a steering interrupt — pending tool calls may be cancelled and the new
23 /// messages are injected into the conversation.
24 fn poll_steering(&self) -> Vec<AgentMessage>;
25
26 /// Return pending follow-up messages, if any.
27 ///
28 /// Called when the model has finished a turn and no tool calls remain.
29 /// Returning a non-empty vec triggers another outer-loop iteration.
30 fn poll_follow_up(&self) -> Vec<AgentMessage>;
31
32 /// Non-draining check for pending steering messages.
33 ///
34 /// Used by tool-dispatch workers to detect steering interrupts early
35 /// without consuming queued messages — the authoritative drain happens
36 /// via [`poll_steering`](Self::poll_steering) in the interrupt collector.
37 ///
38 /// The default implementation returns `false`, so external providers
39 /// that only implement `poll_steering`/`poll_follow_up` will never
40 /// trigger a worker-initiated early interrupt. Built-in channel/queue
41 /// providers override this with a non-draining peek.
42 fn has_steering(&self) -> bool {
43 false
44 }
45}
46
/// A [`MessageProvider`] built from two closures.
///
/// Created via [`from_fns`]. The `steering` closure answers
/// [`MessageProvider::poll_steering`] and the `follow_up` closure answers
/// [`MessageProvider::poll_follow_up`]; `has_steering` is not overridden, so it
/// keeps the trait default (`false`).
///
/// The closure bounds are stated on the [`MessageProvider`] impl and on
/// [`from_fns`] rather than on the struct, so merely naming the type does not
/// require repeating them.
pub struct FnMessageProvider<S, F> {
    /// Invoked by `poll_steering`.
    steering: S,
    /// Invoked by `poll_follow_up`.
    follow_up: F,
}

impl<S, F> std::fmt::Debug for FnMessageProvider<S, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures have no useful `Debug` output, so only the type name is printed.
        f.debug_struct("FnMessageProvider").finish_non_exhaustive()
    }
}
58
59impl<S, F> MessageProvider for FnMessageProvider<S, F>
60where
61 S: Fn() -> Vec<AgentMessage> + Send + Sync,
62 F: Fn() -> Vec<AgentMessage> + Send + Sync,
63{
64 fn poll_steering(&self) -> Vec<AgentMessage> {
65 (self.steering)()
66 }
67
68 fn poll_follow_up(&self) -> Vec<AgentMessage> {
69 (self.follow_up)()
70 }
71}
72
73/// Create a [`MessageProvider`] from two closures.
74///
75/// # Example
76///
77/// ```
78/// use swink_agent::from_fns;
79///
80/// let provider = from_fns(
81/// || vec![], // no steering messages
82/// || vec![], // no follow-up messages
83/// );
84/// ```
85pub const fn from_fns<S, F>(steering: S, follow_up: F) -> FnMessageProvider<S, F>
86where
87 S: Fn() -> Vec<AgentMessage> + Send + Sync,
88 F: Fn() -> Vec<AgentMessage> + Send + Sync,
89{
90 FnMessageProvider {
91 steering,
92 follow_up,
93 }
94}
95
96// ─── Channel-based MessageProvider ──────────────────────────────────────────
97
98/// A clonable handle for pushing messages into a [`ChannelMessageProvider`].
99///
100/// Obtained from [`message_channel`]. Messages sent through this handle are
101/// delivered as **follow-up** messages by default. Use [`send_steering`](Self::send_steering)
102/// to inject steering messages instead.
103#[derive(Clone)]
104pub struct MessageSender {
105 steering_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
106 follow_up_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
107}
108
109impl MessageSender {
110 /// Push a steering message to the agent.
111 ///
112 /// Steering messages are polled after each tool execution batch and can
113 /// interrupt in-progress tool calls.
114 ///
115 /// Returns `false` if the receiver has been dropped.
116 pub fn send_steering(&self, message: AgentMessage) -> bool {
117 self.steering_tx.send(message).is_ok()
118 }
119
120 /// Push a follow-up message to the agent.
121 ///
122 /// Follow-up messages are polled when the agent would otherwise stop,
123 /// triggering another outer-loop iteration.
124 ///
125 /// Returns `false` if the receiver has been dropped.
126 pub fn send_follow_up(&self, message: AgentMessage) -> bool {
127 self.follow_up_tx.send(message).is_ok()
128 }
129
130 /// Alias for [`send_follow_up`](Self::send_follow_up).
131 pub fn send(&self, message: AgentMessage) -> bool {
132 self.send_follow_up(message)
133 }
134}
135
136impl std::fmt::Debug for MessageSender {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("MessageSender").finish_non_exhaustive()
139 }
140}
141
142/// A [`MessageProvider`] backed by tokio unbounded mpsc channels.
143///
144/// Created via [`message_channel`]. External code pushes messages through the
145/// paired [`MessageSender`]; the provider drains them when the agent loop polls.
146pub struct ChannelMessageProvider {
147 steering_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
148 follow_up_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
149}
150
151impl ChannelMessageProvider {
152 /// Drain all currently buffered messages from a receiver.
153 fn drain_receiver(
154 rx: &Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
155 ) -> Vec<AgentMessage> {
156 let mut guard = rx.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
157 let mut messages = Vec::new();
158 while let Ok(msg) = guard.try_recv() {
159 messages.push(msg);
160 }
161 messages
162 }
163}
164
165impl MessageProvider for ChannelMessageProvider {
166 fn poll_steering(&self) -> Vec<AgentMessage> {
167 Self::drain_receiver(&self.steering_rx)
168 }
169
170 fn poll_follow_up(&self) -> Vec<AgentMessage> {
171 Self::drain_receiver(&self.follow_up_rx)
172 }
173
174 fn has_steering(&self) -> bool {
175 let guard = self
176 .steering_rx
177 .lock()
178 .unwrap_or_else(std::sync::PoisonError::into_inner);
179 !guard.is_empty()
180 }
181}
182
183/// A [`MessageProvider`] that combines two providers, draining both on each poll.
184///
185/// Messages from the primary provider are returned first, followed by those
186/// from the secondary provider.
187pub struct ComposedMessageProvider {
188 primary: std::sync::Arc<dyn MessageProvider>,
189 secondary: std::sync::Arc<dyn MessageProvider>,
190}
191
192impl ComposedMessageProvider {
193 /// Create a composed provider from two providers.
194 pub fn new(
195 primary: std::sync::Arc<dyn MessageProvider>,
196 secondary: std::sync::Arc<dyn MessageProvider>,
197 ) -> Self {
198 Self { primary, secondary }
199 }
200}
201
202impl MessageProvider for ComposedMessageProvider {
203 fn poll_steering(&self) -> Vec<AgentMessage> {
204 let mut msgs = self.primary.poll_steering();
205 msgs.extend(self.secondary.poll_steering());
206 msgs
207 }
208
209 fn poll_follow_up(&self) -> Vec<AgentMessage> {
210 let mut msgs = self.primary.poll_follow_up();
211 msgs.extend(self.secondary.poll_follow_up());
212 msgs
213 }
214
215 fn has_steering(&self) -> bool {
216 self.primary.has_steering() || self.secondary.has_steering()
217 }
218}
219
220/// Create a channel-backed [`MessageProvider`] and its paired [`MessageSender`].
221///
222/// The returned `ChannelMessageProvider` implements [`MessageProvider`] and can
223/// be passed to [`AgentLoopConfig`](crate::loop_::AgentLoopConfig) or used with
224/// [`AgentOptions::with_message_channel`](crate::AgentOptions::with_message_channel).
225/// The `MessageSender` is a clonable handle that external code uses to push
226/// messages into the agent.
227///
228/// # Example
229///
230/// ```
231/// use swink_agent::message_channel;
232///
233/// let (provider, sender) = message_channel();
234/// // sender.send(msg) pushes a follow-up message
235/// // sender.send_steering(msg) pushes a steering message
236/// ```
237pub fn message_channel() -> (ChannelMessageProvider, MessageSender) {
238 let (steering_tx, steering_rx) = tokio::sync::mpsc::unbounded_channel();
239 let (follow_up_tx, follow_up_rx) = tokio::sync::mpsc::unbounded_channel();
240
241 let provider = ChannelMessageProvider {
242 steering_rx: Mutex::new(steering_rx),
243 follow_up_rx: Mutex::new(follow_up_rx),
244 };
245
246 let sender = MessageSender {
247 steering_tx,
248 follow_up_tx,
249 };
250
251 (provider, sender)
252}