// opi_agent/agent.rs

//! Stateful Agent wrapper around the agent loop (S8.2).
//!
//! Provides `prompt`, `continue_`, `abort`, `subscribe`, `steer`, and
//! `follow_up` methods, managing conversation state, cancellation, event
//! subscribers, and message queues.
6
7use std::collections::VecDeque;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, Mutex};
11
12use opi_ai::message::{InputContent, Message, UserMessage};
13use opi_ai::provider::Provider;
14use tokio_util::sync::CancellationToken;
15
16use crate::event::{AgentEvent, AgentEventSink};
17use crate::hooks::AgentHooks;
18use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
19use crate::message::AgentMessage;
20use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
21
// -- Arc wrappers for Provider and Tool reuse across calls ------------------
23
24struct SharedProvider(Arc<dyn Provider>);
25
26impl Provider for SharedProvider {
27    fn id(&self) -> &str {
28        self.0.id()
29    }
30    fn models(&self) -> &[opi_ai::provider::ModelInfo] {
31        self.0.models()
32    }
33    fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
34        self.0.stream(request)
35    }
36}
37
38struct SharedTool(Arc<dyn Tool>);
39
40impl Tool for SharedTool {
41    fn definition(&self) -> opi_ai::message::ToolDef {
42        self.0.definition()
43    }
44
45    fn execute(
46        &self,
47        call_id: &str,
48        arguments: serde_json::Value,
49        signal: CancellationToken,
50        on_update: Option<crate::tool::UpdateCallback>,
51    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
52        self.0.execute(call_id, arguments, signal, on_update)
53    }
54
55    fn execution_mode(&self) -> ExecutionMode {
56        self.0.execution_mode()
57    }
58}
59
// -- Agent -------------------------------------------------------------------
61
62type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
63
64/// Stateful wrapper around `agent_loop` with conversation state, cancellation,
65/// event subscription, and message queue management.
66pub struct Agent {
67    provider: Arc<dyn Provider>,
68    tools: Vec<Arc<dyn Tool>>,
69    model: String,
70    system: Option<String>,
71    config: AgentLoopConfig,
72    hooks: Box<dyn AgentHooks>,
73    cancel: CancellationToken,
74    subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
75    messages: Vec<AgentMessage>,
76    steering_queue: Arc<Mutex<VecDeque<String>>>,
77    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
78}
79
80impl Agent {
81    /// Create a new Agent with the given provider, tools, model, and hooks.
82    pub fn new(
83        provider: Box<dyn Provider>,
84        tools: Vec<Box<dyn Tool>>,
85        model: String,
86        system: Option<String>,
87        config: AgentLoopConfig,
88        hooks: Box<dyn AgentHooks>,
89    ) -> Self {
90        Self {
91            provider: Arc::from(provider),
92            tools: tools.into_iter().map(Arc::from).collect(),
93            model,
94            system,
95            config,
96            hooks,
97            cancel: CancellationToken::new(),
98            subscribers: Arc::new(Mutex::new(Vec::new())),
99            messages: Vec::new(),
100            steering_queue: Arc::new(Mutex::new(VecDeque::new())),
101            follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
102        }
103    }
104
105    /// Send a user message and run the agent loop.
106    ///
107    /// Resets the cancellation state if the agent was previously aborted,
108    /// allowing a fresh conversation turn.
109    pub async fn prompt(
110        &mut self,
111        text: impl Into<String>,
112    ) -> Result<Vec<AgentMessage>, AgentError> {
113        self.maybe_reset_cancel();
114        let token = self.cancel.child_token();
115        self.messages
116            .push(AgentMessage::Llm(Message::User(UserMessage {
117                content: vec![InputContent::Text { text: text.into() }],
118                timestamp_ms: 0,
119            })));
120        self.run_with_token(token).await
121    }
122
123    /// Send a user message with arbitrary content (text + images) and run the
124    /// agent loop.
125    pub async fn prompt_with_content(
126        &mut self,
127        content: Vec<InputContent>,
128    ) -> Result<Vec<AgentMessage>, AgentError> {
129        self.maybe_reset_cancel();
130        let token = self.cancel.child_token();
131        self.messages
132            .push(AgentMessage::Llm(Message::User(UserMessage {
133                content,
134                timestamp_ms: 0,
135            })));
136        self.run_with_token(token).await
137    }
138
139    /// Continue the conversation with an additional user message.
140    ///
141    /// Requires the last context message to be a user message or tool result.
142    pub async fn continue_(
143        &mut self,
144        text: impl Into<String>,
145    ) -> Result<Vec<AgentMessage>, AgentError> {
146        self.maybe_reset_cancel();
147
148        if self.messages.is_empty() {
149            return Err(AgentError::Hook("cannot continue: no messages".into()));
150        }
151
152        let token = self.cancel.child_token();
153        self.messages
154            .push(AgentMessage::Llm(Message::User(UserMessage {
155                content: vec![InputContent::Text { text: text.into() }],
156                timestamp_ms: 0,
157            })));
158        self.run_with_token(token).await
159    }
160
161    /// Cancel the current operation.
162    ///
163    /// Equivalent to the first Ctrl+C. The running `prompt` or `continue_`
164    /// call will return `AgentError::Cancelled`.
165    pub fn abort(&self) {
166        self.cancel.cancel();
167    }
168
169    /// Add an additional tool to the agent's tool set.
170    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
171        self.tools.push(Arc::from(tool));
172    }
173
174    /// Return the active model spec.
175    pub fn model(&self) -> &str {
176        &self.model
177    }
178
179    /// Change the model used by subsequent provider requests.
180    pub fn set_model(&mut self, model: String) {
181        self.model = model;
182    }
183
184    /// Return the underlying provider metadata.
185    pub fn provider(&self) -> &dyn Provider {
186        self.provider.as_ref()
187    }
188
189    /// Set the initial conversation messages (for session resume).
190    ///
191    /// Must be called before `prompt` or `continue_`. Replaces any
192    /// existing messages in the agent's internal buffer.
193    pub fn set_initial_messages(&mut self, messages: Vec<AgentMessage>) {
194        self.messages = messages;
195    }
196
197    /// Inject a single message into the conversation buffer.
198    ///
199    /// Used after compaction to insert a `CompactionSummary` so subsequent
200    /// provider calls include the summary in their context window.
201    pub fn inject_message(&mut self, message: AgentMessage) {
202        self.messages.push(message);
203    }
204
205    /// Replace the entire conversation buffer.
206    ///
207    /// Used after compaction to install `[summary, ...kept]` so subsequent
208    /// provider requests no longer carry the compacted messages.
209    pub fn replace_messages(&mut self, messages: Vec<AgentMessage>) {
210        self.messages = messages;
211    }
212
213    /// Emit an `AgentEvent` to all subscribers outside of the agent loop.
214    ///
215    /// Used by callers (e.g. harness) to surface lifecycle events that occur
216    /// between loop invocations, such as compaction start/end.
217    pub fn emit_event(&self, event: AgentEvent) {
218        let subs = self.subscribers.lock().unwrap();
219        for sub in subs.iter() {
220            sub(&event);
221        }
222    }
223
224    /// Snapshot the current conversation buffer.
225    ///
226    /// The harness uses this after a turn (and any subsequent compaction) to
227    /// compute the next `turn_offset` and return the post-compaction message
228    /// list to callers.
229    pub fn messages_snapshot(&self) -> Vec<AgentMessage> {
230        self.messages.clone()
231    }
232
233    /// Register an event subscriber that receives all `AgentEvent`s.
234    pub fn subscribe(&mut self, callback: EventSubscriber) {
235        self.subscribers.lock().unwrap().push(callback);
236    }
237
238    /// Return a clonable cancellation token for external cancellation.
239    ///
240    /// Cancelling this token cancels the currently running loop operation.
241    pub fn cancel_token(&self) -> CancellationToken {
242        self.cancel.clone()
243    }
244
245    /// Add a steering message to be delivered before the next provider request.
246    ///
247    /// Steering messages are high-priority and delivered after the current
248    /// turn's tool calls complete but before the next provider request.
249    pub fn steer(&self, message: String) {
250        self.steering_queue.lock().unwrap().push_back(message);
251    }
252
253    /// Add a follow-up message to be delivered when the agent would otherwise stop.
254    ///
255    /// Follow-up messages are only delivered when the agent has no tool calls
256    /// pending and no steering messages queued.
257    pub fn follow_up(&self, message: String) {
258        self.follow_up_queue.lock().unwrap().push_back(message);
259    }
260
261    // -- Internal helpers ---------------------------------------------------
262
263    fn maybe_reset_cancel(&mut self) {
264        if self.cancel.is_cancelled() {
265            self.cancel = CancellationToken::new();
266        }
267    }
268
269    fn build_event_sink(&self) -> AgentEventSink {
270        let subscribers = self.subscribers.clone();
271        Box::new(move |event: AgentEvent| {
272            let subs = subscribers.lock().unwrap();
273            for sub in subs.iter() {
274                sub(&event);
275            }
276        })
277    }
278
279    async fn run_with_token(
280        &mut self,
281        cancel: CancellationToken,
282    ) -> Result<Vec<AgentMessage>, AgentError> {
283        let context = AgentLoopContext {
284            provider: Box::new(SharedProvider(self.provider.clone())),
285            tools: self
286                .tools
287                .iter()
288                .map(|t| Box::new(SharedTool(t.clone())) as Box<dyn Tool>)
289                .collect(),
290            messages: self.messages.clone(),
291            model: self.model.clone(),
292            system: self.system.clone(),
293            steering_queue: Some(self.steering_queue.clone()),
294            follow_up_queue: Some(self.follow_up_queue.clone()),
295        };
296
297        let sink = self.build_event_sink();
298        let result =
299            crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
300
301        self.messages = result.clone();
302        Ok(result)
303    }
304}