use std::sync::Arc;
use async_trait::async_trait;
use crate::Agent;
use crate::advisor::runtime::AdvisorAgent;
/// Callback invoked with the advisor's [`Agent`] after each successful
/// [`AdvisorAgent::prompt`] call.
///
/// Held behind an [`Arc`], so one hook value can be cloned and shared.
pub type AdvisorPromptHook = Arc<dyn Fn(&Agent) + Send + Sync>;

/// [`AdvisorAgent`] implementation that forwards every operation to a shared
/// [`Agent`].
pub struct AgentAdvisor {
    /// Agent that receives prompts and holds the message state.
    agent: Arc<Agent>,
    /// Optional hook run after a prompt completes successfully.
    on_prompted: Option<AdvisorPromptHook>,
}

impl std::fmt::Debug for AgentAdvisor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The hook is a `dyn Fn` (never `Debug`) and `Agent` is not required
        // to be `Debug`, so only report whether a hook is installed.
        f.debug_struct("AgentAdvisor")
            .field("has_post_prompt_hook", &self.on_prompted.is_some())
            .finish_non_exhaustive()
    }
}
impl AgentAdvisor {
    /// Creates an advisor over `agent` with no post-prompt hook.
    #[must_use]
    pub fn new(agent: Arc<Agent>) -> Self {
        Self {
            on_prompted: None,
            agent,
        }
    }

    /// Creates an advisor over `agent` that calls `hook` after every
    /// successful prompt.
    #[must_use]
    pub fn with_post_prompt_hook(agent: Arc<Agent>, hook: AdvisorPromptHook) -> Self {
        Self {
            on_prompted: Some(hook),
            ..Self::new(agent)
        }
    }

    /// Borrows the wrapped agent.
    #[must_use]
    pub fn agent(&self) -> &Agent {
        &*self.agent
    }

    /// Consumes the advisor and hands back the shared agent handle.
    ///
    /// Any installed hook is dropped.
    #[must_use]
    pub fn into_agent(self) -> Arc<Agent> {
        let Self { agent, .. } = self;
        agent
    }
}
#[async_trait]
impl AdvisorAgent for AgentAdvisor {
    /// Feeds `input` to the agent via `continue_with`.
    ///
    /// When that succeeds, the post-prompt hook (if any) is called with the
    /// agent before `Ok(())` is returned. An agent error is converted to its
    /// string form and the hook is not called.
    async fn prompt(&self, input: String) -> Result<(), String> {
        match self.agent.continue_with(input).await {
            Ok(_outcome) => {
                if let Some(hook) = self.on_prompted.as_ref() {
                    hook(&self.agent);
                }
                Ok(())
            }
            Err(err) => Err(err.to_string()),
        }
    }

    /// Delegates to [`Agent::cancel`]; `_reason` is not used.
    fn abort(&self, _reason: &str) {
        self.agent.cancel();
    }

    /// Delegates to [`Agent::reset`].
    fn reset(&self) {
        self.agent.reset();
    }

    /// Truncates the agent's message list to its first `count` entries.
    fn rollback_to(&self, count: usize) {
        self.agent
            .update_state(move |state| state.messages.truncate(count));
    }

    /// Number of messages currently in the agent's state.
    fn message_count(&self) -> usize {
        let state = self.agent.state();
        state.messages.len()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::config::AgentConfig;
    use oxi_ai::{Message, Provider};

    /// Provider stub whose stream is empty.
    struct NopProvider;

    impl Provider for NopProvider {
        fn stream<'a>(
            &'a self,
            _model: &'a oxi_ai::Model,
            _context: &'a oxi_ai::Context,
            _options: Option<oxi_ai::StreamOptions>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = oxi_ai::StreamResult> + Send + 'a>>
        {
            type EventStream =
                std::pin::Pin<Box<dyn futures::Stream<Item = oxi_ai::ProviderEvent> + Send>>;
            let events: EventStream =
                Box::pin(futures::stream::empty::<oxi_ai::ProviderEvent>());
            Box::pin(async move { Ok(events) })
        }

        fn name(&self) -> &str {
            "nop"
        }
    }

    /// Builds a fresh empty agent backed by [`NopProvider`].
    fn nop_agent() -> Arc<Agent> {
        let provider: Arc<dyn Provider> = Arc::new(NopProvider);
        Arc::new(Agent::new_empty(provider, AgentConfig::default()))
    }

    #[tokio::test]
    async fn message_count_tracks_state() {
        let agent = nop_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));
        assert_eq!(advisor.message_count(), 0);

        agent.update_state(|s| {
            for text in ["hello", "world"] {
                s.messages.push(Message::user(text));
            }
        });
        assert_eq!(advisor.message_count(), 2);
    }

    #[tokio::test]
    async fn rollback_to_truncates_messages() {
        let agent = nop_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));

        agent.update_state(|s| {
            for text in ["a", "b", "c", "d"] {
                s.messages.push(Message::user(text));
            }
        });
        advisor.rollback_to(2).await;

        assert_eq!(advisor.message_count(), 2);
        assert_eq!(agent.state().messages[0].text_content().unwrap(), "a");
    }

    #[tokio::test]
    async fn reset_clears_state() {
        let agent = nop_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));

        agent.update_state(|s| {
            s.messages.push(Message::user("a"));
        });
        assert_eq!(advisor.message_count(), 1);

        advisor.reset();
        assert_eq!(advisor.message_count(), 0);
    }

    #[test]
    fn agent_accessor_and_into_agent_round_trip() {
        let agent = nop_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));

        assert!(std::ptr::eq(advisor.agent(), Arc::as_ref(&agent)));
        assert!(Arc::ptr_eq(&advisor.into_agent(), &agent));
    }
}