Skip to main content

opi_agent/
agent.rs

1//! Stateful Agent wrapper around the agent loop (S8.2).
2//!
3//! Provides `prompt`, `continue_`, `abort`, `subscribe`, `steer`, and
4//! `follow_up` methods, managing conversation state, cancellation, event
5//! subscribers, and message queues.
6
7use std::collections::VecDeque;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, Mutex};
11
12use opi_ai::message::{InputContent, Message, UserMessage};
13use opi_ai::provider::Provider;
14use tokio_util::sync::CancellationToken;
15
16use crate::event::{AgentEvent, AgentEventSink};
17use crate::hooks::AgentHooks;
18use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
19use crate::message::AgentMessage;
20use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
21
22// -- Arc wrappers for Provider and Tool reuse across calls ------------------
23
24struct SharedProvider(Arc<dyn Provider>);
25
26impl Provider for SharedProvider {
27    fn id(&self) -> &str {
28        self.0.id()
29    }
30    fn models(&self) -> &[opi_ai::provider::ModelInfo] {
31        self.0.models()
32    }
33    fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
34        self.0.stream(request)
35    }
36}
37
38struct SharedTool(Arc<dyn Tool>);
39
40impl Tool for SharedTool {
41    fn definition(&self) -> opi_ai::message::ToolDef {
42        self.0.definition()
43    }
44
45    fn execute(
46        &self,
47        call_id: &str,
48        arguments: serde_json::Value,
49        signal: CancellationToken,
50        on_update: Option<crate::tool::UpdateCallback>,
51    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
52        self.0.execute(call_id, arguments, signal, on_update)
53    }
54
55    fn execution_mode(&self) -> ExecutionMode {
56        self.0.execution_mode()
57    }
58}
59
60// -- Agent -------------------------------------------------------------------
61
62type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
63
64/// Stateful wrapper around `agent_loop` with conversation state, cancellation,
65/// event subscription, and message queue management.
66pub struct Agent {
67    provider: Arc<dyn Provider>,
68    tools: Vec<Arc<dyn Tool>>,
69    model: String,
70    system: Option<String>,
71    config: AgentLoopConfig,
72    hooks: Box<dyn AgentHooks>,
73    cancel: CancellationToken,
74    subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
75    messages: Vec<AgentMessage>,
76    steering_queue: Arc<Mutex<VecDeque<String>>>,
77    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
78}
79
80impl Agent {
81    /// Create a new Agent with the given provider, tools, model, and hooks.
82    pub fn new(
83        provider: Box<dyn Provider>,
84        tools: Vec<Box<dyn Tool>>,
85        model: String,
86        system: Option<String>,
87        config: AgentLoopConfig,
88        hooks: Box<dyn AgentHooks>,
89    ) -> Self {
90        Self {
91            provider: Arc::from(provider),
92            tools: tools.into_iter().map(Arc::from).collect(),
93            model,
94            system,
95            config,
96            hooks,
97            cancel: CancellationToken::new(),
98            subscribers: Arc::new(Mutex::new(Vec::new())),
99            messages: Vec::new(),
100            steering_queue: Arc::new(Mutex::new(VecDeque::new())),
101            follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
102        }
103    }
104
105    /// Send a user message and run the agent loop.
106    ///
107    /// Resets the cancellation state if the agent was previously aborted,
108    /// allowing a fresh conversation turn.
109    pub async fn prompt(
110        &mut self,
111        text: impl Into<String>,
112    ) -> Result<Vec<AgentMessage>, AgentError> {
113        self.maybe_reset_cancel();
114        let token = self.cancel.child_token();
115        self.messages
116            .push(AgentMessage::Llm(Message::User(UserMessage {
117                content: vec![InputContent::Text { text: text.into() }],
118                timestamp_ms: 0,
119            })));
120        self.run_with_token(token).await
121    }
122
123    /// Continue the conversation with an additional user message.
124    ///
125    /// Requires the last context message to be a user message or tool result.
126    pub async fn continue_(
127        &mut self,
128        text: impl Into<String>,
129    ) -> Result<Vec<AgentMessage>, AgentError> {
130        self.maybe_reset_cancel();
131
132        if self.messages.is_empty() {
133            return Err(AgentError::Hook("cannot continue: no messages".into()));
134        }
135
136        let token = self.cancel.child_token();
137        self.messages
138            .push(AgentMessage::Llm(Message::User(UserMessage {
139                content: vec![InputContent::Text { text: text.into() }],
140                timestamp_ms: 0,
141            })));
142        self.run_with_token(token).await
143    }
144
145    /// Cancel the current operation.
146    ///
147    /// Equivalent to the first Ctrl+C. The running `prompt` or `continue_`
148    /// call will return `AgentError::Cancelled`.
149    pub fn abort(&self) {
150        self.cancel.cancel();
151    }
152
153    /// Add an additional tool to the agent's tool set.
154    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
155        self.tools.push(Arc::from(tool));
156    }
157
158    /// Register an event subscriber that receives all `AgentEvent`s.
159    pub fn subscribe(&mut self, callback: EventSubscriber) {
160        self.subscribers.lock().unwrap().push(callback);
161    }
162
163    /// Return a clonable cancellation token for external cancellation.
164    ///
165    /// Cancelling this token cancels the currently running loop operation.
166    pub fn cancel_token(&self) -> CancellationToken {
167        self.cancel.clone()
168    }
169
170    /// Add a steering message to be delivered before the next provider request.
171    ///
172    /// Steering messages are high-priority and delivered after the current
173    /// turn's tool calls complete but before the next provider request.
174    pub fn steer(&self, message: String) {
175        self.steering_queue.lock().unwrap().push_back(message);
176    }
177
178    /// Add a follow-up message to be delivered when the agent would otherwise stop.
179    ///
180    /// Follow-up messages are only delivered when the agent has no tool calls
181    /// pending and no steering messages queued.
182    pub fn follow_up(&self, message: String) {
183        self.follow_up_queue.lock().unwrap().push_back(message);
184    }
185
186    // -- Internal helpers ---------------------------------------------------
187
188    fn maybe_reset_cancel(&mut self) {
189        if self.cancel.is_cancelled() {
190            self.cancel = CancellationToken::new();
191        }
192    }
193
194    fn build_event_sink(&self) -> AgentEventSink {
195        let subscribers = self.subscribers.clone();
196        Box::new(move |event: AgentEvent| {
197            let subs = subscribers.lock().unwrap();
198            for sub in subs.iter() {
199                sub(&event);
200            }
201        })
202    }
203
204    async fn run_with_token(
205        &mut self,
206        cancel: CancellationToken,
207    ) -> Result<Vec<AgentMessage>, AgentError> {
208        let context = AgentLoopContext {
209            provider: Box::new(SharedProvider(self.provider.clone())),
210            tools: self
211                .tools
212                .iter()
213                .map(|t| Box::new(SharedTool(t.clone())) as Box<dyn Tool>)
214                .collect(),
215            messages: self.messages.clone(),
216            model: self.model.clone(),
217            system: self.system.clone(),
218            steering_queue: Some(self.steering_queue.clone()),
219            follow_up_queue: Some(self.follow_up_queue.clone()),
220        };
221
222        let sink = self.build_event_sink();
223        let result =
224            crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
225
226        self.messages = result.clone();
227        Ok(result)
228    }
229}