oxi_agent/advisor/
agent_advisor.rs1use std::sync::Arc;
18
19use async_trait::async_trait;
20
21use crate::Agent;
22use crate::advisor::runtime::AdvisorAgent;
23
24pub type AdvisorPromptHook = Arc<dyn Fn(&Agent) + Send + Sync>;
27
28pub struct AgentAdvisor {
30 agent: Arc<Agent>,
31 on_prompted: Option<AdvisorPromptHook>,
32}
33
34impl AgentAdvisor {
35 #[must_use]
37 pub fn new(agent: Arc<Agent>) -> Self {
38 Self {
39 agent,
40 on_prompted: None,
41 }
42 }
43
44 #[must_use]
48 pub fn with_post_prompt_hook(agent: Arc<Agent>, hook: AdvisorPromptHook) -> Self {
49 Self {
50 agent,
51 on_prompted: Some(hook),
52 }
53 }
54
55 #[must_use]
57 pub fn agent(&self) -> &Agent {
58 &self.agent
59 }
60
61 #[must_use]
63 pub fn into_agent(self) -> Arc<Agent> {
64 self.agent
65 }
66}
67
68#[async_trait]
69impl AdvisorAgent for AgentAdvisor {
70 async fn prompt(&self, input: String) -> Result<(), String> {
71 self.agent
76 .continue_with(input)
77 .await
78 .map(|_| {
79 if let Some(hook) = &self.on_prompted {
80 hook(&self.agent);
81 }
82 })
83 .map_err(|e| e.to_string())
84 }
85
86 fn abort(&self, _reason: &str) {
87 self.agent.cancel();
90 }
91
92 fn reset(&self) {
93 self.agent.reset();
94 }
95
96 async fn rollback_to(&self, count: usize) {
97 self.agent.update_state(|s| s.messages.truncate(count));
98 }
99
100 fn message_count(&self) -> usize {
101 self.agent.state().messages.len()
102 }
103}
104
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::config::AgentConfig;
    use oxi_ai::{Message, Provider};

    /// Provider whose streams finish immediately without emitting any event.
    struct NopProvider;

    impl Provider for NopProvider {
        fn stream<'a>(
            &'a self,
            _model: &'a oxi_ai::Model,
            _context: &'a oxi_ai::Context,
            _options: Option<oxi_ai::StreamOptions>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = oxi_ai::StreamResult> + Send + 'a>>
        {
            let s: std::pin::Pin<Box<dyn futures::Stream<Item = oxi_ai::ProviderEvent> + Send>> =
                Box::pin(futures::stream::empty::<oxi_ai::ProviderEvent>());
            Box::pin(async move { Ok(s) })
        }

        fn name(&self) -> &str {
            "nop"
        }
    }

    /// Builds an empty agent backed by [`NopProvider`] and the default config.
    fn empty_agent() -> Arc<Agent> {
        let provider: Arc<dyn Provider> = Arc::new(NopProvider);
        Arc::new(Agent::new_empty(provider, AgentConfig::default()))
    }

    #[tokio::test]
    async fn message_count_tracks_state() {
        let agent = empty_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));

        assert_eq!(advisor.message_count(), 0);
        agent.update_state(|s| {
            s.messages.push(Message::user("hello"));
            s.messages.push(Message::user("world"));
        });
        assert_eq!(advisor.message_count(), 2);
    }

    #[tokio::test]
    async fn rollback_to_truncates_messages() {
        let agent = empty_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));
        agent.update_state(|s| {
            s.messages.push(Message::user("a"));
            s.messages.push(Message::user("b"));
            s.messages.push(Message::user("c"));
            s.messages.push(Message::user("d"));
        });
        advisor.rollback_to(2).await;
        assert_eq!(advisor.message_count(), 2);
        // Both retained messages must be the oldest ones, in their original order.
        assert_eq!(agent.state().messages[0].text_content().unwrap(), "a");
        assert_eq!(agent.state().messages[1].text_content().unwrap(), "b");
    }

    #[tokio::test]
    async fn rollback_to_handles_boundary_counts() {
        let agent = empty_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));
        agent.update_state(|s| {
            s.messages.push(Message::user("a"));
            s.messages.push(Message::user("b"));
        });
        // A count past the current length leaves the history untouched.
        advisor.rollback_to(5).await;
        assert_eq!(advisor.message_count(), 2);
        // A count of zero clears the history entirely.
        advisor.rollback_to(0).await;
        assert_eq!(advisor.message_count(), 0);
    }

    #[tokio::test]
    async fn reset_clears_state() {
        let agent = empty_agent();
        let advisor = AgentAdvisor::new(Arc::clone(&agent));
        agent.update_state(|s| {
            s.messages.push(Message::user("a"));
        });
        assert_eq!(advisor.message_count(), 1);
        advisor.reset();
        assert_eq!(advisor.message_count(), 0);
    }

    #[test]
    fn agent_accessor_and_into_agent_round_trip() {
        let agent = empty_agent();
        let cloned = Arc::clone(&agent);
        let advisor = AgentAdvisor::new(cloned);
        assert!(std::ptr::eq(advisor.agent(), Arc::as_ref(&agent)));
        assert!(Arc::ptr_eq(&advisor.into_agent(), &agent));
    }

    #[test]
    fn with_post_prompt_hook_wraps_same_agent() {
        let agent = empty_agent();
        let hook: AdvisorPromptHook = Arc::new(|_: &Agent| {});
        let advisor = AgentAdvisor::with_post_prompt_hook(Arc::clone(&agent), hook);
        assert!(std::ptr::eq(advisor.agent(), Arc::as_ref(&agent)));
        assert!(Arc::ptr_eq(&advisor.into_agent(), &agent));
    }
}