use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use opi_ai::message::{InputContent, Message, UserMessage};
use opi_ai::provider::{Provider, ThinkingConfig};
use tokio_util::sync::CancellationToken;
use crate::event::{AgentEvent, AgentEventSink};
use crate::hooks::AgentHooks;
use crate::loop_types::{AgentError, AgentLoopConfig, AgentLoopContext};
use crate::message::AgentMessage;
use crate::tool::{ExecutionMode, Tool, ToolError, ToolResult};
/// Newtype that lets a shared `Arc<dyn Provider>` be handed out wherever an
/// owned `Provider` value is required.
///
/// Every trait method simply forwards to the wrapped provider.
struct SharedProvider(Arc<dyn Provider>);

impl Provider for SharedProvider {
    /// Forwards to the wrapped provider's `id`.
    fn id(&self) -> &str {
        let Self(inner) = self;
        inner.id()
    }

    /// Forwards to the wrapped provider's `models`.
    fn models(&self) -> &[opi_ai::provider::ModelInfo] {
        let Self(inner) = self;
        inner.models()
    }

    /// Forwards `request` unchanged to the wrapped provider's `stream`.
    fn stream(&self, request: opi_ai::provider::Request) -> opi_ai::provider::EventStream {
        let Self(inner) = self;
        inner.stream(request)
    }
}
/// Newtype that lets a shared `Arc<dyn Tool>` be handed out wherever an owned
/// `Tool` value is required.
///
/// Every trait method simply forwards to the wrapped tool.
struct SharedTool(Arc<dyn Tool>);

impl Tool for SharedTool {
    /// Forwards to the wrapped tool's `definition`.
    fn definition(&self) -> opi_ai::message::ToolDef {
        let Self(inner) = self;
        inner.definition()
    }

    /// Forwards all arguments unchanged to the wrapped tool's `execute` and
    /// returns the future it produces.
    fn execute(
        &self,
        call_id: &str,
        arguments: serde_json::Value,
        signal: CancellationToken,
        on_update: Option<crate::tool::UpdateCallback>,
    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
        let Self(inner) = self;
        inner.execute(call_id, arguments, signal, on_update)
    }

    /// Forwards to the wrapped tool's `execution_mode`.
    fn execution_mode(&self) -> ExecutionMode {
        let Self(inner) = self;
        inner.execution_mode()
    }
}
/// Boxed callback that receives a shared reference to each [`AgentEvent`].
///
/// The `Send + Sync` bounds allow the callbacks to be stored behind the
/// `Arc<Mutex<..>>` subscriber list that is shared with the event sink.
type EventSubscriber = Box<dyn Fn(&AgentEvent) + Send + Sync>;
/// Cloneable handle for influencing an [`Agent`] without holding a borrow of it.
///
/// Built by `Agent::control_handle`, it carries a clone of the agent's
/// cancellation token plus shared references to the agent's steering and
/// follow-up queues.
///
/// NOTE(review): the token is cloned when the handle is created. The agent
/// replaces its own token with a new one once the old one has been cancelled,
/// so a handle obtained before that point presumably no longer reaches later
/// runs through `abort` — confirm whether callers are expected to re-fetch it.
#[derive(Clone)]
pub struct AgentControl {
    cancel: CancellationToken,
    steering_queue: Arc<Mutex<VecDeque<String>>>,
    follow_up_queue: Arc<Mutex<VecDeque<String>>>,
}

impl AgentControl {
    /// Cancels the token this handle was created with.
    pub fn abort(&self) {
        let Self { cancel, .. } = self;
        cancel.cancel();
    }

    /// Appends `message` to the back of the shared steering queue.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub fn steer(&self, message: String) {
        let mut pending = self.steering_queue.lock().unwrap();
        pending.push_back(message);
    }

    /// Appends `message` to the back of the shared follow-up queue.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub fn follow_up(&self, message: String) {
        let mut pending = self.follow_up_queue.lock().unwrap();
        pending.push_back(message);
    }
}
/// Stateful front end for the agent loop: owns the provider, tools, settings,
/// conversation history, event subscribers and control queues, and feeds them
/// into `crate::agent_loop` on every run.
pub struct Agent {
// Shared so each run can give the loop its own boxed wrapper around the same provider.
provider: Arc<dyn Provider>,
// Shared for the same reason as `provider`; each run re-wraps every tool.
tools: Vec<Arc<dyn Tool>>,
// Model identifier copied into the loop context for each run.
model: String,
// Optional system prompt copied into the loop context for each run.
system: Option<String>,
// Loop configuration; cloned for each run.
config: AgentLoopConfig,
// Hooks passed by reference to the loop.
hooks: Box<dyn AgentHooks>,
// Root cancellation token. Runs execute under a child of it, and it is
// replaced with a fresh token before a run if it has already been cancelled.
cancel: CancellationToken,
// Callbacks invoked, in registration order, for every emitted event.
subscribers: Arc<Mutex<Vec<EventSubscriber>>>,
// Conversation history; overwritten with the loop's result after a successful run.
messages: Vec<AgentMessage>,
// Queue shared with `AgentControl` handles and handed to the loop context.
// NOTE(review): how the loop consumes steering messages is not visible here.
steering_queue: Arc<Mutex<VecDeque<String>>>,
// Queue shared with `AgentControl` handles and handed to the loop context.
// NOTE(review): how the loop consumes follow-up messages is not visible here.
follow_up_queue: Arc<Mutex<VecDeque<String>>>,
}
impl Agent {
    /// Builds an agent from its provider, tools, model identifier, optional
    /// system prompt, loop configuration and hooks.
    ///
    /// The boxed provider and tools are converted into `Arc`s so that each run
    /// can wrap the same instances again. The agent starts with an empty
    /// history, no subscribers, empty steering and follow-up queues, and a
    /// fresh cancellation token.
    pub fn new(
        provider: Box<dyn Provider>,
        tools: Vec<Box<dyn Tool>>,
        model: String,
        system: Option<String>,
        config: AgentLoopConfig,
        hooks: Box<dyn AgentHooks>,
    ) -> Self {
        Self {
            provider: Arc::from(provider),
            tools: tools.into_iter().map(Arc::from).collect(),
            model,
            system,
            config,
            hooks,
            cancel: CancellationToken::new(),
            subscribers: Arc::new(Mutex::new(Vec::new())),
            messages: Vec::new(),
            steering_queue: Arc::new(Mutex::new(VecDeque::new())),
            follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    /// Appends `text` as a user message and runs the agent loop.
    ///
    /// A previously cancelled token is replaced first, so an earlier abort
    /// does not affect this run. Returns the message list produced by the
    /// loop, which also becomes the stored history.
    ///
    /// # Errors
    ///
    /// Propagates any [`AgentError`] from the loop. In that case the stored
    /// history keeps the new user message but is otherwise left untouched.
    pub async fn prompt(
        &mut self,
        text: impl Into<String>,
    ) -> Result<Vec<AgentMessage>, AgentError> {
        self.maybe_reset_cancel();
        self.run_user_turn(vec![InputContent::Text { text: text.into() }])
            .await
    }

    /// Same as [`Agent::prompt`], but the user message carries the given
    /// content parts instead of a single text part.
    ///
    /// # Errors
    ///
    /// Propagates any [`AgentError`] from the loop.
    pub async fn prompt_with_content(
        &mut self,
        content: Vec<InputContent>,
    ) -> Result<Vec<AgentMessage>, AgentError> {
        self.maybe_reset_cancel();
        self.run_user_turn(content).await
    }

    /// Appends `text` as a user message to an existing conversation and runs
    /// the agent loop.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Hook("cannot continue: no messages")` when the
    /// history is empty; otherwise propagates any [`AgentError`] from the loop.
    pub async fn continue_(
        &mut self,
        text: impl Into<String>,
    ) -> Result<Vec<AgentMessage>, AgentError> {
        // The token is reset before the emptiness check, so it is refreshed
        // even when this call is rejected.
        self.maybe_reset_cancel();
        if self.messages.is_empty() {
            return Err(AgentError::Hook("cannot continue: no messages".into()));
        }
        self.run_user_turn(vec![InputContent::Text { text: text.into() }])
            .await
    }

    /// Cancels the agent's current root token.
    ///
    /// Every run executes under a child of that token. While a run is in
    /// flight the agent is mutably borrowed, so aborting from elsewhere has to
    /// go through a token or handle obtained beforehand via
    /// [`Agent::cancel_token`] or [`Agent::control_handle`].
    pub fn abort(&self) {
        self.cancel.cancel();
    }

    /// Registers an additional tool; it is included in runs started afterwards.
    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
        self.tools.push(Arc::from(tool));
    }

    /// Returns the model identifier used for runs.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Replaces the model identifier used for subsequent runs.
    pub fn set_model(&mut self, model: String) {
        self.model = model;
    }

    /// Borrows the provider this agent was built with.
    pub fn provider(&self) -> &dyn Provider {
        self.provider.as_ref()
    }

    /// Returns a copy of the configured thinking settings, or
    /// `ThinkingConfig::default()` when none are configured.
    pub fn thinking_config(&self) -> ThinkingConfig {
        self.config.thinking.clone().unwrap_or_default()
    }

    /// Sets or clears the thinking settings in the loop configuration.
    pub fn set_thinking_config(&mut self, thinking: Option<ThinkingConfig>) {
        self.config.thinking = thinking;
    }

    /// Sets or clears the `max_tokens` value in the loop configuration.
    pub fn set_max_tokens(&mut self, max_tokens: Option<u64>) {
        self.config.max_tokens = max_tokens;
    }

    /// Replaces the stored history with `messages`.
    ///
    /// Behaves exactly like [`Agent::replace_messages`].
    pub fn set_initial_messages(&mut self, messages: Vec<AgentMessage>) {
        self.replace_messages(messages);
    }

    /// Appends `message` to the stored history without starting a run.
    pub fn inject_message(&mut self, message: AgentMessage) {
        self.messages.push(message);
    }

    /// Replaces the stored history with `messages`.
    pub fn replace_messages(&mut self, messages: Vec<AgentMessage>) {
        self.messages = messages;
    }

    /// Delivers `event` to every registered subscriber, in registration order.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber list's mutex is poisoned.
    pub fn emit_event(&self, event: AgentEvent) {
        Self::notify(&self.subscribers, &event);
    }

    /// Returns a clone of the stored history.
    pub fn messages_snapshot(&self) -> Vec<AgentMessage> {
        self.messages.clone()
    }

    /// Registers `callback` to be invoked for every subsequent event.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber list's mutex is poisoned.
    pub fn subscribe(&mut self, callback: EventSubscriber) {
        self.subscribers.lock().unwrap().push(callback);
    }

    /// Returns a clone of the agent's current root cancellation token.
    ///
    /// NOTE(review): the agent swaps in a new token once this one has been
    /// cancelled, so a clone taken earlier presumably does not reach runs
    /// started after that swap — confirm against callers.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Creates an [`AgentControl`] carrying a clone of the current root token
    /// and shared references to the steering and follow-up queues.
    ///
    /// NOTE(review): the same token-replacement caveat as on
    /// [`Agent::cancel_token`] applies to the handle's `abort`.
    pub fn control_handle(&self) -> AgentControl {
        AgentControl {
            cancel: self.cancel.clone(),
            steering_queue: Arc::clone(&self.steering_queue),
            follow_up_queue: Arc::clone(&self.follow_up_queue),
        }
    }

    /// Appends `message` to the back of the steering queue.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub fn steer(&self, message: String) {
        self.steering_queue.lock().unwrap().push_back(message);
    }

    /// Appends `message` to the back of the follow-up queue.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub fn follow_up(&self, message: String) {
        self.follow_up_queue.lock().unwrap().push_back(message);
    }

    /// Swaps in a fresh root token if the current one has been cancelled, so
    /// that an earlier abort does not immediately cancel the next run.
    fn maybe_reset_cancel(&mut self) {
        if self.cancel.is_cancelled() {
            self.cancel = CancellationToken::new();
        }
    }

    /// Public entry point for the same reset performed before every run:
    /// replaces the root token with a fresh one if it has been cancelled.
    pub fn reset_cancel_if_cancelled(&mut self) {
        self.maybe_reset_cancel();
    }

    /// Calls every subscriber in `subscribers` with `event`, in registration
    /// order. The list stays locked for the duration of the callbacks.
    fn notify(subscribers: &Mutex<Vec<EventSubscriber>>, event: &AgentEvent) {
        let subs = subscribers.lock().unwrap();
        for sub in subs.iter() {
            sub(event);
        }
    }

    /// Builds the sink handed to the agent loop; it fans each event out to the
    /// subscribers registered on this agent.
    fn build_event_sink(&self) -> AgentEventSink {
        let subscribers = Arc::clone(&self.subscribers);
        Box::new(move |event: AgentEvent| Self::notify(&subscribers, &event))
    }

    /// Shared tail of the prompt-style entry points: derives a child token for
    /// the run, records `content` as a user message and runs the loop.
    ///
    /// Callers are responsible for resetting a cancelled root token first.
    async fn run_user_turn(
        &mut self,
        content: Vec<InputContent>,
    ) -> Result<Vec<AgentMessage>, AgentError> {
        let token = self.cancel.child_token();
        // NOTE(review): the timestamp is always written as 0 here; confirm
        // whether something downstream is meant to stamp it.
        self.messages
            .push(AgentMessage::Llm(Message::User(UserMessage {
                content,
                timestamp_ms: 0,
            })));
        self.run_with_token(token).await
    }

    /// Runs the agent loop over the current state under `cancel`.
    ///
    /// The loop receives boxed wrappers around the shared provider and tools,
    /// clones of the history, model, system prompt and configuration, and
    /// handles to both queues. On success the loop's message list replaces the
    /// stored history and is returned.
    ///
    /// # Errors
    ///
    /// Propagates any [`AgentError`] from the loop; the stored history is not
    /// modified in that case.
    async fn run_with_token(
        &mut self,
        cancel: CancellationToken,
    ) -> Result<Vec<AgentMessage>, AgentError> {
        let context = AgentLoopContext {
            provider: Box::new(SharedProvider(Arc::clone(&self.provider))),
            tools: self
                .tools
                .iter()
                .map(|t| Box::new(SharedTool(Arc::clone(t))) as Box<dyn Tool>)
                .collect(),
            messages: self.messages.clone(),
            model: self.model.clone(),
            system: self.system.clone(),
            steering_queue: Some(Arc::clone(&self.steering_queue)),
            follow_up_queue: Some(Arc::clone(&self.follow_up_queue)),
        };
        let sink = self.build_event_sink();
        let result =
            crate::agent_loop(context, self.config.clone(), &*self.hooks, sink, cancel).await?;
        // Only reached on success: the `?` above returns before the history is touched.
        self.messages.clone_from(&result);
        Ok(result)
    }
}