1use std::collections::VecDeque;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, Mutex};
11
12use opi_ai::message::{InputContent, Message, UserMessage};
13use opi_ai::provider::Provider;
14use tokio_util::sync::CancellationToken;
15
16use crate::event::{AgentEvent, AgentEventSink};
17use crate::hooks::AgentHooks;
18use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
19use crate::message::AgentMessage;
20use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
21
22struct SharedProvider(Arc<dyn Provider>);
25
26impl Provider for SharedProvider {
27 fn id(&self) -> &str {
28 self.0.id()
29 }
30 fn models(&self) -> &[opi_ai::provider::ModelInfo] {
31 self.0.models()
32 }
33 fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
34 self.0.stream(request)
35 }
36}
37
38struct SharedTool(Arc<dyn Tool>);
39
40impl Tool for SharedTool {
41 fn definition(&self) -> opi_ai::message::ToolDef {
42 self.0.definition()
43 }
44
45 fn execute(
46 &self,
47 call_id: &str,
48 arguments: serde_json::Value,
49 signal: CancellationToken,
50 on_update: Option<crate::tool::UpdateCallback>,
51 ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
52 self.0.execute(call_id, arguments, signal, on_update)
53 }
54
55 fn execution_mode(&self) -> ExecutionMode {
56 self.0.execution_mode()
57 }
58}
59
60type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
63
64pub struct Agent {
67 provider: Arc<dyn Provider>,
68 tools: Vec<Arc<dyn Tool>>,
69 model: String,
70 system: Option<String>,
71 config: AgentLoopConfig,
72 hooks: Box<dyn AgentHooks>,
73 cancel: CancellationToken,
74 subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
75 messages: Vec<AgentMessage>,
76 steering_queue: Arc<Mutex<VecDeque<String>>>,
77 follow_up_queue: Arc<Mutex<VecDeque<String>>>,
78}
79
80impl Agent {
81 pub fn new(
83 provider: Box<dyn Provider>,
84 tools: Vec<Box<dyn Tool>>,
85 model: String,
86 system: Option<String>,
87 config: AgentLoopConfig,
88 hooks: Box<dyn AgentHooks>,
89 ) -> Self {
90 Self {
91 provider: Arc::from(provider),
92 tools: tools.into_iter().map(Arc::from).collect(),
93 model,
94 system,
95 config,
96 hooks,
97 cancel: CancellationToken::new(),
98 subscribers: Arc::new(Mutex::new(Vec::new())),
99 messages: Vec::new(),
100 steering_queue: Arc::new(Mutex::new(VecDeque::new())),
101 follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
102 }
103 }
104
105 pub async fn prompt(
110 &mut self,
111 text: impl Into<String>,
112 ) -> Result<Vec<AgentMessage>, AgentError> {
113 self.maybe_reset_cancel();
114 let token = self.cancel.child_token();
115 self.messages
116 .push(AgentMessage::Llm(Message::User(UserMessage {
117 content: vec![InputContent::Text { text: text.into() }],
118 timestamp_ms: 0,
119 })));
120 self.run_with_token(token).await
121 }
122
123 pub async fn prompt_with_content(
126 &mut self,
127 content: Vec<InputContent>,
128 ) -> Result<Vec<AgentMessage>, AgentError> {
129 self.maybe_reset_cancel();
130 let token = self.cancel.child_token();
131 self.messages
132 .push(AgentMessage::Llm(Message::User(UserMessage {
133 content,
134 timestamp_ms: 0,
135 })));
136 self.run_with_token(token).await
137 }
138
139 pub async fn continue_(
143 &mut self,
144 text: impl Into<String>,
145 ) -> Result<Vec<AgentMessage>, AgentError> {
146 self.maybe_reset_cancel();
147
148 if self.messages.is_empty() {
149 return Err(AgentError::Hook("cannot continue: no messages".into()));
150 }
151
152 let token = self.cancel.child_token();
153 self.messages
154 .push(AgentMessage::Llm(Message::User(UserMessage {
155 content: vec![InputContent::Text { text: text.into() }],
156 timestamp_ms: 0,
157 })));
158 self.run_with_token(token).await
159 }
160
161 pub fn abort(&self) {
166 self.cancel.cancel();
167 }
168
169 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
171 self.tools.push(Arc::from(tool));
172 }
173
174 pub fn model(&self) -> &str {
176 &self.model
177 }
178
179 pub fn set_model(&mut self, model: String) {
181 self.model = model;
182 }
183
184 pub fn provider(&self) -> &dyn Provider {
186 self.provider.as_ref()
187 }
188
189 pub fn set_initial_messages(&mut self, messages: Vec<AgentMessage>) {
194 self.messages = messages;
195 }
196
197 pub fn inject_message(&mut self, message: AgentMessage) {
202 self.messages.push(message);
203 }
204
205 pub fn replace_messages(&mut self, messages: Vec<AgentMessage>) {
210 self.messages = messages;
211 }
212
213 pub fn emit_event(&self, event: AgentEvent) {
218 let subs = self.subscribers.lock().unwrap();
219 for sub in subs.iter() {
220 sub(&event);
221 }
222 }
223
224 pub fn messages_snapshot(&self) -> Vec<AgentMessage> {
230 self.messages.clone()
231 }
232
233 pub fn subscribe(&mut self, callback: EventSubscriber) {
235 self.subscribers.lock().unwrap().push(callback);
236 }
237
238 pub fn cancel_token(&self) -> CancellationToken {
242 self.cancel.clone()
243 }
244
245 pub fn steer(&self, message: String) {
250 self.steering_queue.lock().unwrap().push_back(message);
251 }
252
253 pub fn follow_up(&self, message: String) {
258 self.follow_up_queue.lock().unwrap().push_back(message);
259 }
260
261 fn maybe_reset_cancel(&mut self) {
264 if self.cancel.is_cancelled() {
265 self.cancel = CancellationToken::new();
266 }
267 }
268
269 fn build_event_sink(&self) -> AgentEventSink {
270 let subscribers = self.subscribers.clone();
271 Box::new(move |event: AgentEvent| {
272 let subs = subscribers.lock().unwrap();
273 for sub in subs.iter() {
274 sub(&event);
275 }
276 })
277 }
278
279 async fn run_with_token(
280 &mut self,
281 cancel: CancellationToken,
282 ) -> Result<Vec<AgentMessage>, AgentError> {
283 let context = AgentLoopContext {
284 provider: Box::new(SharedProvider(self.provider.clone())),
285 tools: self
286 .tools
287 .iter()
288 .map(|t| Box::new(SharedTool(t.clone())) as Box<dyn Tool>)
289 .collect(),
290 messages: self.messages.clone(),
291 model: self.model.clone(),
292 system: self.system.clone(),
293 steering_queue: Some(self.steering_queue.clone()),
294 follow_up_queue: Some(self.follow_up_queue.clone()),
295 };
296
297 let sink = self.build_event_sink();
298 let result =
299 crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
300
301 self.messages = result.clone();
302 Ok(result)
303 }
304}