Skip to main content

baml_agent_tui/
agent_task.rs

1use baml_agent::{
2    process_step, AgentMessage, LoopConfig, LoopDetector, LoopEvent, MessageRole, Session,
3    SgrAgentStream,
4};
5use std::sync::Arc;
6use tokio::sync::Mutex;
7
/// Events emitted by the agent task back to the TUI.
///
/// Payloads are plain owned data (`usize`, `String`, `Vec<String>`), so the
/// enum derives `Debug`, `Clone`, `PartialEq` and `Eq`; this lets handlers log
/// events with `{:?}`, fan them out to several consumers, and compare them in
/// tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentTaskEvent {
    /// A step is starting. The number is 1-based.
    StepStart(usize),
    /// One chunk of streamed text from the LLM.
    StreamChunk(String),
    /// LLM decision: situation + task (STAR).
    Decision {
        /// The situation text reported with the decision.
        situation: String,
        /// The task entries reported with the decision.
        task: Vec<String>,
    },
    /// An action is about to be executed; carries a human-readable label.
    ActionStart(String),
    /// An action has been executed; carries its result output.
    ActionDone(String),
    /// Path of a file that a tool action modifies.
    FileModified(String),
    /// The context was trimmed; carries the count reported by the trim.
    Trimmed(usize),
    /// Non-fatal warning (loop detected, notes injected, max steps, ...).
    Warning(String),
    /// Error message.
    Error(String),
    /// The LLM reported the task as completed.
    Completed,
    /// The agent loop has finished. Always sent last.
    Done,
}
36
37/// Callback trait for handling agent events in the TUI.
38///
39/// Implement this to map agent events to your TUI's event system.
40/// Methods are sync — they typically just send to a channel.
41pub trait AgentEventHandler: Send + Sync + 'static {
42    /// Called for each agent task event. Return false to stop the loop.
43    fn on_event(&self, event: AgentTaskEvent) -> bool;
44}
45
46/// Channel-based event handler — maps AgentTaskEvent to your AppEvent type.
47pub struct ChannelHandler<T: Send + 'static> {
48    tx: tokio::sync::mpsc::Sender<T>,
49    mapper: Box<dyn Fn(AgentTaskEvent) -> T + Send + Sync>,
50}
51
52impl<T: Send + 'static> ChannelHandler<T> {
53    pub fn new(
54        tx: tokio::sync::mpsc::Sender<T>,
55        mapper: impl Fn(AgentTaskEvent) -> T + Send + Sync + 'static,
56    ) -> Self {
57        Self {
58            tx,
59            mapper: Box::new(mapper),
60        }
61    }
62}
63
64impl<T: Send + 'static> AgentEventHandler for ChannelHandler<T> {
65    fn on_event(&self, event: AgentTaskEvent) -> bool {
66        let mapped = (self.mapper)(event);
67        self.tx.try_send(mapped).is_ok()
68    }
69}
70
71/// TUI-specific agent extension.
72///
73/// Extends `SgrAgentStream` with optional methods for TUI display.
74/// Default implementations work out of the box — override for richer UI.
75pub trait TuiAgent: SgrAgentStream {
76    /// Human-readable action label for display. Defaults to `action_signature`.
77    fn action_label(action: &Self::Action) -> String {
78        Self::action_signature(action)
79    }
80
81    /// If an action modifies a file, return the path (for TUI refresh / git status).
82    fn file_modified(_action: &Self::Action) -> Option<String> {
83        None
84    }
85}
86
87/// Run the agent loop as a tokio task with TUI event integration.
88///
89/// Key design: session is **unlocked during LLM streaming** (so TUI can
90/// read messages for display) and **locked during action execution**
91/// (which modifies session).
92///
93/// # Arguments
94/// - `agent`: Shared agent ref. Takes `&self`, so `Arc` suffices (no Mutex).
95///   Use interior mutability (Mutex) inside agent for stateful tools (MCP, etc).
96/// - `session`: Shared session behind Mutex. TUI can lock to read messages.
97/// - `pending_notes`: Queue of user messages injected between steps.
98/// - `handler`: Receives `AgentTaskEvent`s — typically a `ChannelHandler`.
99/// - `config`: Loop config (max_steps, loop_abort_threshold).
100pub fn spawn_agent_loop<A, H>(
101    agent: Arc<A>,
102    session: Arc<Mutex<Session<A::Msg>>>,
103    pending_notes: Arc<Mutex<Vec<String>>>,
104    handler: H,
105    config: LoopConfig,
106) -> tokio::task::JoinHandle<()>
107where
108    A: TuiAgent + Send + Sync + 'static,
109    H: AgentEventHandler,
110{
111    tokio::spawn(async move {
112        let result = run_tui_loop(&*agent, &session, &pending_notes, &handler, &config).await;
113
114        if let Err(e) = result {
115            handler.on_event(AgentTaskEvent::Error(format!("Agent error: {}", e)));
116        }
117        handler.on_event(AgentTaskEvent::Done);
118    })
119}
120
121/// Inner loop logic. Separated for testability.
122async fn run_tui_loop<A, H>(
123    agent: &A,
124    session: &Mutex<Session<A::Msg>>,
125    pending_notes: &Mutex<Vec<String>>,
126    handler: &H,
127    config: &LoopConfig,
128) -> Result<usize, String>
129where
130    A: TuiAgent + Send + Sync,
131    H: AgentEventHandler,
132{
133    let mut detector = LoopDetector::new(config.loop_abort_threshold);
134
135    for step_num in 1..=config.max_steps {
136        // --- Inject pending user notes ---
137        {
138            let notes: Vec<String> = std::mem::take(&mut *pending_notes.lock().await);
139            if !notes.is_empty() {
140                let mut sess = session.lock().await;
141                for note in &notes {
142                    sess.push(
143                        <<A::Msg as AgentMessage>::Role>::user(),
144                        format!("User note while task is running:\n{}", note),
145                    );
146                }
147                handler.on_event(AgentTaskEvent::Warning(format!(
148                    "[NOTE] {} queued note(s) injected",
149                    notes.len()
150                )));
151            }
152        }
153
154        // --- Trim context ---
155        {
156            let mut sess = session.lock().await;
157            let trimmed = sess.trim();
158            if trimmed > 0 {
159                handler.on_event(AgentTaskEvent::Trimmed(trimmed));
160            }
161        }
162
163        if !handler.on_event(AgentTaskEvent::StepStart(step_num)) {
164            return Ok(step_num); // TUI requested stop
165        }
166
167        // --- Snapshot messages for LLM (session UNLOCKED during streaming) ---
168        let messages: Vec<A::Msg> = {
169            let sess = session.lock().await;
170            sess.messages().to_vec()
171        };
172
173        // --- Stream LLM decision ---
174        let decision = agent
175            .decide_stream(&messages, |token| {
176                handler.on_event(AgentTaskEvent::StreamChunk(token.to_string()));
177            })
178            .await
179            .map_err(|e| format!("{}", e))?;
180
181        // --- Process step (session LOCKED during execution) ---
182        let mut sess = session.lock().await;
183
184        // Map LoopEvent to AgentTaskEvent
185        let mut on_event = |event: LoopEvent<'_, A::Action>| {
186            match event {
187                LoopEvent::Decision { situation, task } => {
188                    handler.on_event(AgentTaskEvent::Decision {
189                        situation: situation.to_string(),
190                        task: task.to_vec(),
191                    });
192                }
193                LoopEvent::Completed => {
194                    handler.on_event(AgentTaskEvent::Completed);
195                }
196                LoopEvent::ActionStart(action) => {
197                    if let Some(path) = A::file_modified(action) {
198                        handler.on_event(AgentTaskEvent::FileModified(path));
199                    }
200                    handler.on_event(AgentTaskEvent::ActionStart(A::action_label(action)));
201                }
202                LoopEvent::ActionDone(result) => {
203                    handler.on_event(AgentTaskEvent::ActionDone(result.output.clone()));
204                }
205                LoopEvent::LoopWarning(n) => {
206                    handler.on_event(AgentTaskEvent::Warning(format!(
207                        "Loop detected — {} repeats",
208                        n
209                    )));
210                }
211                LoopEvent::LoopAbort(n) => {
212                    handler.on_event(AgentTaskEvent::Error(format!(
213                        "Agent stuck after {} identical actions — aborting",
214                        n
215                    )));
216                }
217                LoopEvent::Trimmed(n) => {
218                    handler.on_event(AgentTaskEvent::Trimmed(n));
219                }
220                LoopEvent::MaxStepsReached(n) => {
221                    handler.on_event(AgentTaskEvent::Warning(format!(
222                        "Max steps ({}) reached",
223                        n
224                    )));
225                }
226                LoopEvent::StepStart(_) => {}   // handled above
227                LoopEvent::StreamToken(_) => {} // handled above via decide_stream
228            }
229        };
230
231        if let Some(final_step) = process_step(
232            agent,
233            &mut sess,
234            decision,
235            step_num,
236            &mut detector,
237            &mut on_event,
238        )
239        .await
240        .map_err(|e| format!("{}", e))?
241        {
242            return Ok(final_step);
243        }
244    }
245
246    handler.on_event(AgentTaskEvent::Warning(format!(
247        "Max steps ({}) reached",
248        config.max_steps
249    )));
250    Ok(config.max_steps)
251}