Skip to main content

oxi_agent/advisor/
agent_advisor.rs

1//! Concrete [`AdvisorAgent`] adapter over [`crate::Agent`].
2//!
3//! The SDK exposes the [`AdvisorAgent`] trait (so tests/SDK consumers can
4//! hand-roll a fake), but every real consumer that drives an `oxi_agent::Agent`
5//! repeats the same 5-method mapping. This adapter removes that boilerplate —
6//! the host wiring and SDK consumers both use [`AgentAdvisor::new`].
7//!
8//! Method mapping (verified against `Agent`):
9//! - `prompt` → `Agent::continue_with` (preserves advisor context across turns;
10//!   `run_with_channel` auto-resets the cancel flag, so a prior `abort` does
11//!   not strand the next prompt)
12//! - `abort` → `Agent::cancel`
13//! - `reset` → `Agent::reset`
14//! - `rollback_to` → `Agent::update_state(|s| s.messages.truncate(count))`
15//! - `message_count` → `Agent::state().messages.len()`
16
17use std::sync::Arc;
18
19use async_trait::async_trait;
20
21use crate::Agent;
22use crate::advisor::runtime::AdvisorAgent;
23
24/// Hook fired after each successful advisor prompt, with the advisor agent.
25/// The host uses it to persist advisor turns to `<session>/__advisor.jsonl`.
26pub type AdvisorPromptHook = Arc<dyn Fn(&Agent) + Send + Sync>;
27
28/// An [`AdvisorAgent`] backed by a concrete [`Agent`].
29pub struct AgentAdvisor {
30    agent: Arc<Agent>,
31    on_prompted: Option<AdvisorPromptHook>,
32}
33
34impl AgentAdvisor {
35    /// Wrap an `Agent` as an advisor-driving [`AdvisorAgent`].
36    #[must_use]
37    pub fn new(agent: Arc<Agent>) -> Self {
38        Self {
39            agent,
40            on_prompted: None,
41        }
42    }
43
44    /// Wrap an `Agent` with a hook fired after each successful advisor prompt.
45    /// Used by the host to persist advisor turns to `<session>/__advisor.jsonl`
46    /// for stats attribution / observability.
47    #[must_use]
48    pub fn with_post_prompt_hook(agent: Arc<Agent>, hook: AdvisorPromptHook) -> Self {
49        Self {
50            agent,
51            on_prompted: Some(hook),
52        }
53    }
54
55    /// Access the underlying agent (for transcript-recorder / event wiring).
56    #[must_use]
57    pub fn agent(&self) -> &Agent {
58        &self.agent
59    }
60
61    /// Clone the underlying agent handle (cheap — `Arc` clone).
62    #[must_use]
63    pub fn into_agent(self) -> Arc<Agent> {
64        self.agent
65    }
66}
67
68#[async_trait]
69impl AdvisorAgent for AgentAdvisor {
70    async fn prompt(&self, input: String) -> Result<(), String> {
71        // continue_with preserves the advisor's conversation state across
72        // turns (appends rather than resetting). The Response + events are
73        // discarded — the host wires a separate event subscription for the
74        // transcript recorder if it wants advisor-turn observability.
75        self.agent
76            .continue_with(input)
77            .await
78            .map(|_| {
79                if let Some(hook) = &self.on_prompted {
80                    hook(&self.agent);
81                }
82            })
83            .map_err(|e| e.to_string())
84    }
85
86    fn abort(&self, _reason: &str) {
87        // Sets the cancel flag; the next prompt() -> run_with_channel resets
88        // it, so abort does not strand subsequent turns.
89        self.agent.cancel();
90    }
91
92    fn reset(&self) {
93        self.agent.reset();
94    }
95
96    async fn rollback_to(&self, count: usize) {
97        self.agent.update_state(|s| s.messages.truncate(count));
98    }
99
100    fn message_count(&self) -> usize {
101        self.agent.state().messages.len()
102    }
103}
104
#[cfg(test)]
mod tests {
    // Tests may unwrap freely: a failed unwrap is a test failure.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::config::AgentConfig;
    use oxi_ai::{Message, Provider};

    /// Provider stub used only to construct an `Agent`. These tests never
    /// call `prompt()`, so only the non-network methods
    /// (reset/rollback_to/message_count) are exercised.
    struct NopProvider;
    impl Provider for NopProvider {
        fn stream<'a>(
            &'a self,
            _model: &'a oxi_ai::Model,
            _context: &'a oxi_ai::Context,
            _options: Option<oxi_ai::StreamOptions>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = oxi_ai::StreamResult> + Send + 'a>>
        {
            // Resolves immediately to a stream that yields no events.
            let empty: std::pin::Pin<
                Box<dyn futures::Stream<Item = oxi_ai::ProviderEvent> + Send>,
            > = Box::pin(futures::stream::empty::<oxi_ai::ProviderEvent>());
            Box::pin(async move { Ok(empty) })
        }
        fn name(&self) -> &str {
            "nop"
        }
    }

    /// Build an empty agent over [`NopProvider`] together with an advisor
    /// wrapping a clone of the same `Arc`.
    fn agent_and_advisor() -> (Arc<Agent>, AgentAdvisor) {
        let provider: Arc<dyn Provider> = Arc::new(NopProvider);
        let agent = Arc::new(Agent::new_empty(provider, AgentConfig::default()));
        let advisor = AgentAdvisor::new(Arc::clone(&agent));
        (agent, advisor)
    }

    #[tokio::test]
    async fn message_count_tracks_state() {
        let (agent, advisor) = agent_and_advisor();

        assert_eq!(advisor.message_count(), 0);
        // Push through the shared agent handle; the advisor must see it.
        agent.update_state(|state| {
            state.messages.push(Message::user("hello"));
            state.messages.push(Message::user("world"));
        });
        assert_eq!(advisor.message_count(), 2);
    }

    #[tokio::test]
    async fn rollback_to_truncates_messages() {
        let (agent, advisor) = agent_and_advisor();
        agent.update_state(|state| {
            state.messages.push(Message::user("a"));
            state.messages.push(Message::user("b"));
            state.messages.push(Message::user("c"));
            state.messages.push(Message::user("d"));
        });
        advisor.rollback_to(2).await;
        assert_eq!(advisor.message_count(), 2);
        assert_eq!(agent.state().messages[0].text_content().unwrap(), "a");
    }

    #[tokio::test]
    async fn reset_clears_state() {
        let (agent, advisor) = agent_and_advisor();
        agent.update_state(|state| {
            state.messages.push(Message::user("a"));
        });
        assert_eq!(advisor.message_count(), 1);
        advisor.reset();
        assert_eq!(advisor.message_count(), 0);
    }

    #[test]
    fn agent_accessor_and_into_agent_round_trip() {
        let (agent, advisor) = agent_and_advisor();
        // The borrowed accessor and the consuming accessor must both point at
        // the allocation held by `agent`.
        assert!(std::ptr::eq(advisor.agent(), Arc::as_ref(&agent)));
        assert!(Arc::ptr_eq(&advisor.into_agent(), &agent));
    }
}