1use std::collections::VecDeque;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::{Arc, Mutex};
11
12use opi_ai::message::{InputContent, Message, UserMessage};
13use opi_ai::provider::Provider;
14use tokio_util::sync::CancellationToken;
15
16use crate::event::{AgentEvent, AgentEventSink};
17use crate::hooks::AgentHooks;
18use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
19use crate::message::AgentMessage;
20use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
21
22struct SharedProvider(Arc<dyn Provider>);
25
26impl Provider for SharedProvider {
27 fn id(&self) -> &str {
28 self.0.id()
29 }
30 fn models(&self) -> &[opi_ai::provider::ModelInfo] {
31 self.0.models()
32 }
33 fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
34 self.0.stream(request)
35 }
36}
37
38struct SharedTool(Arc<dyn Tool>);
39
40impl Tool for SharedTool {
41 fn definition(&self) -> opi_ai::message::ToolDef {
42 self.0.definition()
43 }
44
45 fn execute(
46 &self,
47 call_id: &str,
48 arguments: serde_json::Value,
49 signal: CancellationToken,
50 on_update: Option<crate::tool::UpdateCallback>,
51 ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
52 self.0.execute(call_id, arguments, signal, on_update)
53 }
54
55 fn execution_mode(&self) -> ExecutionMode {
56 self.0.execution_mode()
57 }
58}
59
60type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
63
/// Stateful front end over `crate::agent_loop`.
///
/// Holds the conversation history between calls, the shared provider and
/// tools, and the queues and subscribers that are handed to each run.
pub struct Agent {
    /// Provider shared (via `SharedProvider`) with every run.
    provider: Arc<dyn Provider>,
    /// Tools shared (via `SharedTool`) with every run.
    tools: Vec<Arc<dyn Tool>>,
    /// Model identifier copied into each run's context.
    model: String,
    /// Optional system prompt copied into each run's context.
    system: Option<String>,
    /// Loop configuration; a clone is passed to each run.
    config: AgentLoopConfig,
    /// Hooks passed by reference to each run.
    hooks: Box<dyn AgentHooks>,
    /// Parent cancellation token. Each run gets a child of it, `abort` cancels
    /// it, and it is replaced with a fresh token at the start of the next run
    /// once it has been cancelled.
    cancel: CancellationToken,
    /// Event subscribers; shared with the event sink built for each run.
    subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
    /// Conversation history; replaced by the loop's result after a successful run.
    messages: Vec<AgentMessage>,
    /// Messages pushed by `steer`; the queue is shared with the loop context.
    // NOTE(review): when the loop drains this queue is defined in agent_loop — confirm there.
    steering_queue: Arc<Mutex<VecDeque<String>>>,
    /// Messages pushed by `follow_up`; the queue is shared with the loop context.
    // NOTE(review): when the loop drains this queue is defined in agent_loop — confirm there.
    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
}
79
80impl Agent {
81 pub fn new(
83 provider: Box<dyn Provider>,
84 tools: Vec<Box<dyn Tool>>,
85 model: String,
86 system: Option<String>,
87 config: AgentLoopConfig,
88 hooks: Box<dyn AgentHooks>,
89 ) -> Self {
90 Self {
91 provider: Arc::from(provider),
92 tools: tools.into_iter().map(Arc::from).collect(),
93 model,
94 system,
95 config,
96 hooks,
97 cancel: CancellationToken::new(),
98 subscribers: Arc::new(Mutex::new(Vec::new())),
99 messages: Vec::new(),
100 steering_queue: Arc::new(Mutex::new(VecDeque::new())),
101 follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
102 }
103 }
104
105 pub async fn prompt(
110 &mut self,
111 text: impl Into<String>,
112 ) -> Result<Vec<AgentMessage>, AgentError> {
113 self.maybe_reset_cancel();
114 let token = self.cancel.child_token();
115 self.messages
116 .push(AgentMessage::Llm(Message::User(UserMessage {
117 content: vec![InputContent::Text { text: text.into() }],
118 timestamp_ms: 0,
119 })));
120 self.run_with_token(token).await
121 }
122
123 pub async fn continue_(
127 &mut self,
128 text: impl Into<String>,
129 ) -> Result<Vec<AgentMessage>, AgentError> {
130 self.maybe_reset_cancel();
131
132 if self.messages.is_empty() {
133 return Err(AgentError::Hook("cannot continue: no messages".into()));
134 }
135
136 let token = self.cancel.child_token();
137 self.messages
138 .push(AgentMessage::Llm(Message::User(UserMessage {
139 content: vec![InputContent::Text { text: text.into() }],
140 timestamp_ms: 0,
141 })));
142 self.run_with_token(token).await
143 }
144
145 pub fn abort(&self) {
150 self.cancel.cancel();
151 }
152
153 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
155 self.tools.push(Arc::from(tool));
156 }
157
158 pub fn subscribe(&mut self, callback: EventSubscriber) {
160 self.subscribers.lock().unwrap().push(callback);
161 }
162
163 pub fn cancel_token(&self) -> CancellationToken {
167 self.cancel.clone()
168 }
169
170 pub fn steer(&self, message: String) {
175 self.steering_queue.lock().unwrap().push_back(message);
176 }
177
178 pub fn follow_up(&self, message: String) {
183 self.follow_up_queue.lock().unwrap().push_back(message);
184 }
185
186 fn maybe_reset_cancel(&mut self) {
189 if self.cancel.is_cancelled() {
190 self.cancel = CancellationToken::new();
191 }
192 }
193
194 fn build_event_sink(&self) -> AgentEventSink {
195 let subscribers = self.subscribers.clone();
196 Box::new(move |event: AgentEvent| {
197 let subs = subscribers.lock().unwrap();
198 for sub in subs.iter() {
199 sub(&event);
200 }
201 })
202 }
203
204 async fn run_with_token(
205 &mut self,
206 cancel: CancellationToken,
207 ) -> Result<Vec<AgentMessage>, AgentError> {
208 let context = AgentLoopContext {
209 provider: Box::new(SharedProvider(self.provider.clone())),
210 tools: self
211 .tools
212 .iter()
213 .map(|t| Box::new(SharedTool(t.clone())) as Box<dyn Tool>)
214 .collect(),
215 messages: self.messages.clone(),
216 model: self.model.clone(),
217 system: self.system.clone(),
218 steering_queue: Some(self.steering_queue.clone()),
219 follow_up_queue: Some(self.follow_up_queue.clone()),
220 };
221
222 let sink = self.build_event_sink();
223 let result =
224 crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
225
226 self.messages = result.clone();
227 Ok(result)
228 }
229}