baml_agent_tui/
agent_task.rs1use baml_agent::{
2 process_step, AgentMessage, LoopConfig, LoopDetector, LoopEvent, MessageRole, Session,
3 SgrAgentStream,
4};
5use std::sync::Arc;
6use tokio::sync::Mutex;
7
/// Events emitted by the agent loop to a front-end through an [`AgentEventHandler`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so events can be logged, fanned out
/// to several consumers, and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentTaskEvent {
    /// A new step is starting; carries the 1-based step number.
    StepStart(usize),
    /// A chunk of model output forwarded from the streaming decision callback.
    StreamChunk(String),
    /// The decision produced for the current step.
    Decision {
        /// The decision's `situation` text.
        situation: String,
        /// The decision's task list.
        task: Vec<String>,
    },
    /// An action is starting; carries its display label (`TuiAgent::action_label`).
    ActionStart(String),
    /// An action finished; carries its output text.
    ActionDone(String),
    /// A starting action reported (via `TuiAgent::file_modified`) that it touches this path.
    FileModified(String),
    /// The session was trimmed; carries the count returned by `Session::trim`.
    Trimmed(usize),
    /// Non-fatal condition: injected notes, loop-detector warnings, step limit reached.
    Warning(String),
    /// Failure or abort condition (agent error, loop-detector abort).
    Error(String),
    /// The loop reported that the agent completed its task.
    Completed,
    /// Sent once by `spawn_agent_loop` after the loop has returned (success or error).
    Done,
}
36
37pub trait AgentEventHandler: Send + Sync + 'static {
42 fn on_event(&self, event: AgentTaskEvent) -> bool;
44}
45
46pub struct ChannelHandler<T: Send + 'static> {
48 tx: tokio::sync::mpsc::Sender<T>,
49 mapper: Box<dyn Fn(AgentTaskEvent) -> T + Send + Sync>,
50}
51
52impl<T: Send + 'static> ChannelHandler<T> {
53 pub fn new(
54 tx: tokio::sync::mpsc::Sender<T>,
55 mapper: impl Fn(AgentTaskEvent) -> T + Send + Sync + 'static,
56 ) -> Self {
57 Self {
58 tx,
59 mapper: Box::new(mapper),
60 }
61 }
62}
63
64impl<T: Send + 'static> AgentEventHandler for ChannelHandler<T> {
65 fn on_event(&self, event: AgentTaskEvent) -> bool {
66 let mapped = (self.mapper)(event);
67 self.tx.try_send(mapped).is_ok()
68 }
69}
70
71pub trait TuiAgent: SgrAgentStream {
76 fn action_label(action: &Self::Action) -> String {
78 Self::action_signature(action)
79 }
80
81 fn file_modified(_action: &Self::Action) -> Option<String> {
83 None
84 }
85}
86
87pub fn spawn_agent_loop<A, H>(
101 agent: Arc<A>,
102 session: Arc<Mutex<Session<A::Msg>>>,
103 pending_notes: Arc<Mutex<Vec<String>>>,
104 handler: H,
105 config: LoopConfig,
106) -> tokio::task::JoinHandle<()>
107where
108 A: TuiAgent + Send + Sync + 'static,
109 H: AgentEventHandler,
110{
111 tokio::spawn(async move {
112 let result = run_tui_loop(&*agent, &session, &pending_notes, &handler, &config).await;
113
114 if let Err(e) = result {
115 handler.on_event(AgentTaskEvent::Error(format!("Agent error: {}", e)));
116 }
117 handler.on_event(AgentTaskEvent::Done);
118 })
119}
120
121async fn run_tui_loop<A, H>(
123 agent: &A,
124 session: &Mutex<Session<A::Msg>>,
125 pending_notes: &Mutex<Vec<String>>,
126 handler: &H,
127 config: &LoopConfig,
128) -> Result<usize, String>
129where
130 A: TuiAgent + Send + Sync,
131 H: AgentEventHandler,
132{
133 let mut detector = LoopDetector::new(config.loop_abort_threshold);
134
135 for step_num in 1..=config.max_steps {
136 {
138 let notes: Vec<String> = std::mem::take(&mut *pending_notes.lock().await);
139 if !notes.is_empty() {
140 let mut sess = session.lock().await;
141 for note in ¬es {
142 sess.push(
143 <<A::Msg as AgentMessage>::Role>::user(),
144 format!("User note while task is running:\n{}", note),
145 );
146 }
147 handler.on_event(AgentTaskEvent::Warning(format!(
148 "[NOTE] {} queued note(s) injected",
149 notes.len()
150 )));
151 }
152 }
153
154 {
156 let mut sess = session.lock().await;
157 let trimmed = sess.trim();
158 if trimmed > 0 {
159 handler.on_event(AgentTaskEvent::Trimmed(trimmed));
160 }
161 }
162
163 if !handler.on_event(AgentTaskEvent::StepStart(step_num)) {
164 return Ok(step_num); }
166
167 let messages: Vec<A::Msg> = {
169 let sess = session.lock().await;
170 sess.messages().to_vec()
171 };
172
173 let decision = agent
175 .decide_stream(&messages, |token| {
176 handler.on_event(AgentTaskEvent::StreamChunk(token.to_string()));
177 })
178 .await
179 .map_err(|e| format!("{}", e))?;
180
181 let mut sess = session.lock().await;
183
184 let mut on_event = |event: LoopEvent<'_, A::Action>| {
186 match event {
187 LoopEvent::Decision { situation, task } => {
188 handler.on_event(AgentTaskEvent::Decision {
189 situation: situation.to_string(),
190 task: task.to_vec(),
191 });
192 }
193 LoopEvent::Completed => {
194 handler.on_event(AgentTaskEvent::Completed);
195 }
196 LoopEvent::ActionStart(action) => {
197 if let Some(path) = A::file_modified(action) {
198 handler.on_event(AgentTaskEvent::FileModified(path));
199 }
200 handler.on_event(AgentTaskEvent::ActionStart(A::action_label(action)));
201 }
202 LoopEvent::ActionDone(result) => {
203 handler.on_event(AgentTaskEvent::ActionDone(result.output.clone()));
204 }
205 LoopEvent::LoopWarning(n) => {
206 handler.on_event(AgentTaskEvent::Warning(format!(
207 "Loop detected — {} repeats",
208 n
209 )));
210 }
211 LoopEvent::LoopAbort(n) => {
212 handler.on_event(AgentTaskEvent::Error(format!(
213 "Agent stuck after {} identical actions — aborting",
214 n
215 )));
216 }
217 LoopEvent::Trimmed(n) => {
218 handler.on_event(AgentTaskEvent::Trimmed(n));
219 }
220 LoopEvent::MaxStepsReached(n) => {
221 handler.on_event(AgentTaskEvent::Warning(format!(
222 "Max steps ({}) reached",
223 n
224 )));
225 }
226 LoopEvent::StepStart(_) => {} LoopEvent::StreamToken(_) => {} }
229 };
230
231 if let Some(final_step) = process_step(
232 agent,
233 &mut sess,
234 decision,
235 step_num,
236 &mut detector,
237 &mut on_event,
238 )
239 .await
240 .map_err(|e| format!("{}", e))?
241 {
242 return Ok(final_step);
243 }
244 }
245
246 handler.on_event(AgentTaskEvent::Warning(format!(
247 "Max steps ({}) reached",
248 config.max_steps
249 )));
250 Ok(config.max_steps)
251}