1use std::collections::VecDeque;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, Mutex};
11
12use opi_ai::message::{InputContent, Message, UserMessage};
13use opi_ai::provider::{Provider, ThinkingConfig};
14use tokio_util::sync::CancellationToken;
15
16use crate::event::{AgentEvent, AgentEventSink};
17use crate::hooks::AgentHooks;
18use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
19use crate::message::AgentMessage;
20use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
21
/// Newtype around a shared `Arc<dyn Provider>`.
///
/// Its `Provider` impl forwards every call to the wrapped provider, which lets
/// `Agent::run_with_token` put a `Box`ed provider into the loop context for each
/// run while the agent keeps its own `Arc` handle.
struct SharedProvider(Arc<dyn Provider>);
25
26impl Provider for SharedProvider {
27 fn id(&self) -> &str {
28 self.0.id()
29 }
30 fn models(&self) -> &[opi_ai::provider::ModelInfo] {
31 self.0.models()
32 }
33 fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
34 self.0.stream(request)
35 }
36}
37
/// Newtype around a shared `Arc<dyn Tool>`.
///
/// Its `Tool` impl forwards every call to the wrapped tool, which lets
/// `Agent::run_with_token` hand `Box<dyn Tool>` values to the loop context for
/// each run without giving up the agent's own `Arc` handles.
struct SharedTool(Arc<dyn Tool>);
39
40impl Tool for SharedTool {
41 fn definition(&self) -> opi_ai::message::ToolDef {
42 self.0.definition()
43 }
44
45 fn execute(
46 &self,
47 call_id: &str,
48 arguments: serde_json::Value,
49 signal: CancellationToken,
50 on_update: Option<crate::tool::UpdateCallback>,
51 ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
52 self.0.execute(call_id, arguments, signal, on_update)
53 }
54
55 fn execution_mode(&self) -> ExecutionMode {
56 self.0.execution_mode()
57 }
58}
59
/// Boxed callback that receives a reference to each [`AgentEvent`].
///
/// The `Send + Sync` bounds allow the callbacks to be stored behind the
/// `Arc<Mutex<..>>` subscriber list held by [`Agent`].
type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
63
/// Cloneable handle for controlling an [`Agent`] without borrowing it.
///
/// Created by `Agent::control_handle`, which clones the agent's cancellation
/// token and the `Arc`s of its two message queues.
///
/// NOTE(review): the token is a clone of the agent's token at the moment the
/// handle is created. `Agent::maybe_reset_cancel` replaces the agent's token with
/// a fresh one once it has been cancelled, so a handle obtained before such a
/// reset keeps referring to the old token. Confirm whether callers are expected
/// to request a new handle after an abort.
#[derive(Clone)]
pub struct AgentControl {
    /// Clone of the agent's cancellation token (see the note above).
    cancel: CancellationToken,
    /// Queue shared with the agent's `steering_queue`.
    steering_queue: Arc<Mutex<VecDeque<String>>>,
    /// Queue shared with the agent's `follow_up_queue`.
    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
}
71
72impl AgentControl {
73 pub fn abort(&self) {
75 self.cancel.cancel();
76 }
77
78 pub fn steer(&self, message: String) {
80 self.steering_queue.lock().unwrap().push_back(message);
81 }
82
83 pub fn follow_up(&self, message: String) {
85 self.follow_up_queue.lock().unwrap().push_back(message);
86 }
87}
88
/// Stateful wrapper that owns everything needed to run `crate::agent_loop`
/// repeatedly: provider, tools, configuration, hooks, conversation history,
/// event subscribers, and the queues shared with [`AgentControl`] handles.
pub struct Agent {
    /// Shared provider; wrapped in `SharedProvider` for every run.
    provider: Arc<dyn Provider>,
    /// Registered tools; each one is wrapped in `SharedTool` for every run.
    tools: Vec<Arc<dyn Tool>>,
    /// Model name copied into the loop context on every run.
    model: String,
    /// Optional value copied into the loop context's `system` field
    /// (presumably the system prompt; confirm against `AgentLoopContext`).
    system: Option<String>,
    /// Loop configuration, cloned for each run; `thinking` and `max_tokens`
    /// can be changed through the setters.
    config: AgentLoopConfig,
    /// Hooks passed by reference to the agent loop.
    hooks: Box<dyn AgentHooks>,
    /// Root cancellation token. Runs use child tokens of it, and it is replaced
    /// by a fresh token when a new run starts after it was cancelled.
    cancel: CancellationToken,
    /// Callbacks invoked with every event delivered through the event sink or
    /// `emit_event`.
    subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
    /// Conversation history; replaced by the loop's result after a successful run.
    messages: Vec<AgentMessage>,
    /// Queue filled by `steer`, shared with control handles and the loop context.
    steering_queue: Arc<Mutex<VecDeque<String>>>,
    /// Queue filled by `follow_up`, shared with control handles and the loop context.
    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
}
104
105impl Agent {
106 pub fn new(
108 provider: Box<dyn Provider>,
109 tools: Vec<Box<dyn Tool>>,
110 model: String,
111 system: Option<String>,
112 config: AgentLoopConfig,
113 hooks: Box<dyn AgentHooks>,
114 ) -> Self {
115 Self {
116 provider: Arc::from(provider),
117 tools: tools.into_iter().map(Arc::from).collect(),
118 model,
119 system,
120 config,
121 hooks,
122 cancel: CancellationToken::new(),
123 subscribers: Arc::new(Mutex::new(Vec::new())),
124 messages: Vec::new(),
125 steering_queue: Arc::new(Mutex::new(VecDeque::new())),
126 follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
127 }
128 }
129
130 pub async fn prompt(
135 &mut self,
136 text: impl Into<String>,
137 ) -> Result<Vec<AgentMessage>, AgentError> {
138 self.maybe_reset_cancel();
139 let token = self.cancel.child_token();
140 self.messages
141 .push(AgentMessage::Llm(Message::User(UserMessage {
142 content: vec![InputContent::Text { text: text.into() }],
143 timestamp_ms: 0,
144 })));
145 self.run_with_token(token).await
146 }
147
148 pub async fn prompt_with_content(
151 &mut self,
152 content: Vec<InputContent>,
153 ) -> Result<Vec<AgentMessage>, AgentError> {
154 self.maybe_reset_cancel();
155 let token = self.cancel.child_token();
156 self.messages
157 .push(AgentMessage::Llm(Message::User(UserMessage {
158 content,
159 timestamp_ms: 0,
160 })));
161 self.run_with_token(token).await
162 }
163
164 pub async fn continue_(
168 &mut self,
169 text: impl Into<String>,
170 ) -> Result<Vec<AgentMessage>, AgentError> {
171 self.maybe_reset_cancel();
172
173 if self.messages.is_empty() {
174 return Err(AgentError::Hook("cannot continue: no messages".into()));
175 }
176
177 let token = self.cancel.child_token();
178 self.messages
179 .push(AgentMessage::Llm(Message::User(UserMessage {
180 content: vec![InputContent::Text { text: text.into() }],
181 timestamp_ms: 0,
182 })));
183 self.run_with_token(token).await
184 }
185
186 pub fn abort(&self) {
191 self.cancel.cancel();
192 }
193
194 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
196 self.tools.push(Arc::from(tool));
197 }
198
199 pub fn model(&self) -> &str {
201 &self.model
202 }
203
204 pub fn set_model(&mut self, model: String) {
206 self.model = model;
207 }
208
209 pub fn provider(&self) -> &dyn Provider {
211 self.provider.as_ref()
212 }
213
214 pub fn thinking_config(&self) -> ThinkingConfig {
216 self.config.thinking.clone().unwrap_or_default()
217 }
218
219 pub fn set_thinking_config(&mut self, thinking: Option<ThinkingConfig>) {
221 self.config.thinking = thinking;
222 }
223
224 pub fn set_max_tokens(&mut self, max_tokens: Option<u64>) {
226 self.config.max_tokens = max_tokens;
227 }
228
229 pub fn set_initial_messages(&mut self, messages: Vec<AgentMessage>) {
234 self.messages = messages;
235 }
236
237 pub fn inject_message(&mut self, message: AgentMessage) {
242 self.messages.push(message);
243 }
244
245 pub fn replace_messages(&mut self, messages: Vec<AgentMessage>) {
250 self.messages = messages;
251 }
252
253 pub fn emit_event(&self, event: AgentEvent) {
258 let subs = self.subscribers.lock().unwrap();
259 for sub in subs.iter() {
260 sub(&event);
261 }
262 }
263
264 pub fn messages_snapshot(&self) -> Vec<AgentMessage> {
270 self.messages.clone()
271 }
272
273 pub fn subscribe(&mut self, callback: EventSubscriber) {
275 self.subscribers.lock().unwrap().push(callback);
276 }
277
278 pub fn cancel_token(&self) -> CancellationToken {
282 self.cancel.clone()
283 }
284
285 pub fn control_handle(&self) -> AgentControl {
287 AgentControl {
288 cancel: self.cancel.clone(),
289 steering_queue: self.steering_queue.clone(),
290 follow_up_queue: self.follow_up_queue.clone(),
291 }
292 }
293
294 pub fn steer(&self, message: String) {
299 self.steering_queue.lock().unwrap().push_back(message);
300 }
301
302 pub fn follow_up(&self, message: String) {
307 self.follow_up_queue.lock().unwrap().push_back(message);
308 }
309
310 fn maybe_reset_cancel(&mut self) {
313 if self.cancel.is_cancelled() {
314 self.cancel = CancellationToken::new();
315 }
316 }
317
318 pub fn reset_cancel_if_cancelled(&mut self) {
320 self.maybe_reset_cancel();
321 }
322
323 fn build_event_sink(&self) -> AgentEventSink {
324 let subscribers = self.subscribers.clone();
325 Box::new(move |event: AgentEvent| {
326 let subs = subscribers.lock().unwrap();
327 for sub in subs.iter() {
328 sub(&event);
329 }
330 })
331 }
332
333 async fn run_with_token(
334 &mut self,
335 cancel: CancellationToken,
336 ) -> Result<Vec<AgentMessage>, AgentError> {
337 let context = AgentLoopContext {
338 provider: Box::new(SharedProvider(self.provider.clone())),
339 tools: self
340 .tools
341 .iter()
342 .map(|t| Box::new(SharedTool(t.clone())) as Box<dyn Tool>)
343 .collect(),
344 messages: self.messages.clone(),
345 model: self.model.clone(),
346 system: self.system.clone(),
347 steering_queue: Some(self.steering_queue.clone()),
348 follow_up_queue: Some(self.follow_up_queue.clone()),
349 };
350
351 let sink = self.build_event_sink();
352 let result =
353 crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
354
355 self.messages = result.clone();
356 Ok(result)
357 }
358}