// adk_runner/context.rs

1use adk_core::{
2    Agent, Artifacts, CallbackContext, Content, Event, InvocationContext as InvocationContextTrait,
3    Memory, ReadonlyContext, RunConfig,
4};
5use adk_session::Session as AdkSession;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock, atomic::AtomicBool};
9
10/// MutableSession wraps a session with shared mutable state.
11///
12/// This mirrors ADK-Go's MutableSession pattern where state changes from
13/// events are immediately visible to all agents sharing the same context.
14/// This is critical for SequentialAgent/LoopAgent patterns where downstream
15/// agents need to read state set by upstream agents via output_key.
16pub struct MutableSession {
17    /// The original session snapshot (for metadata like id, app_name, user_id)
18    inner: Arc<dyn AdkSession>,
19    /// Shared mutable state - updated when events are processed
20    /// This is the key difference from the old SessionAdapter which used immutable snapshots
21    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
22    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
23    events: Arc<RwLock<Vec<Event>>>,
24}
25
26impl MutableSession {
27    /// Create a new MutableSession from a session snapshot.
28    /// The state is copied from the session and becomes mutable.
29    pub fn new(session: Arc<dyn AdkSession>) -> Self {
30        // Clone the initial state from the session
31        let initial_state = session.state().all();
32        // Clone the initial events
33        let initial_events = session.events().all();
34
35        Self {
36            inner: session,
37            state: Arc::new(RwLock::new(initial_state)),
38            events: Arc::new(RwLock::new(initial_events)),
39        }
40    }
41
42    /// Apply state delta from an event to the mutable state.
43    /// This is called by the Runner when events are yielded.
44    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
45        if delta.is_empty() {
46            return;
47        }
48
49        let mut state = self.state.write().unwrap();
50        for (key, value) in delta {
51            // Skip temp: prefixed keys (they shouldn't persist)
52            if !key.starts_with("temp:") {
53                state.insert(key.clone(), value.clone());
54            }
55        }
56    }
57
58    /// Append an event to the session's event list.
59    /// This keeps the in-memory view consistent.
60    pub fn append_event(&self, event: Event) {
61        let mut events = self.events.write().unwrap();
62        events.push(event);
63    }
64
65    /// Get a snapshot of all events in the session.
66    /// Used by the runner for compaction decisions.
67    pub fn events_snapshot(&self) -> Vec<Event> {
68        let events = self.events.read().unwrap();
69        events.clone()
70    }
71}
72
73impl adk_core::Session for MutableSession {
74    fn id(&self) -> &str {
75        self.inner.id()
76    }
77
78    fn app_name(&self) -> &str {
79        self.inner.app_name()
80    }
81
82    fn user_id(&self) -> &str {
83        self.inner.user_id()
84    }
85
86    fn state(&self) -> &dyn adk_core::State {
87        // SAFETY: We implement State for MutableSession, so this cast is valid.
88        // This pattern allows us to return a reference to self as a State trait object.
89        unsafe { &*(self as *const Self as *const dyn adk_core::State) }
90    }
91
92    fn conversation_history(&self) -> Vec<adk_core::Content> {
93        let events = self.events.read().unwrap();
94        let mut history = Vec::new();
95
96        // Find the most recent compaction event — everything before its
97        // end_timestamp has been summarized and should be replaced by the
98        // compacted content.
99        let mut compaction_boundary = None;
100        for event in events.iter().rev() {
101            if let Some(ref compaction) = event.actions.compaction {
102                // Insert the summary as the first history entry
103                history.push(compaction.compacted_content.clone());
104                compaction_boundary = Some(compaction.end_timestamp);
105                break;
106            }
107        }
108
109        for event in events.iter() {
110            // Skip the compaction event itself (author == "system" with compaction data)
111            if event.actions.compaction.is_some() {
112                continue;
113            }
114
115            // Skip events that were already compacted
116            if let Some(boundary) = compaction_boundary {
117                if event.timestamp <= boundary {
118                    continue;
119                }
120            }
121
122            if let Some(content) = &event.llm_response.content {
123                let mut mapped_content = content.clone();
124                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
125                    ("user", _) => "user",
126                    (_, "function" | "tool") => content.role.as_str(),
127                    _ => "model",
128                }
129                .to_string();
130                history.push(mapped_content);
131            }
132        }
133
134        history
135    }
136}
137
138impl adk_core::State for MutableSession {
139    fn get(&self, key: &str) -> Option<serde_json::Value> {
140        let state = self.state.read().unwrap();
141        state.get(key).cloned()
142    }
143
144    fn set(&mut self, key: String, value: serde_json::Value) {
145        let mut state = self.state.write().unwrap();
146        state.insert(key, value);
147    }
148
149    fn all(&self) -> HashMap<String, serde_json::Value> {
150        let state = self.state.read().unwrap();
151        state.clone()
152    }
153}
154
155pub struct InvocationContext {
156    invocation_id: String,
157    agent: Arc<dyn Agent>,
158    user_id: String,
159    app_name: String,
160    session_id: String,
161    branch: String,
162    user_content: Content,
163    artifacts: Option<Arc<dyn Artifacts>>,
164    memory: Option<Arc<dyn Memory>>,
165    run_config: RunConfig,
166    ended: Arc<AtomicBool>,
167    /// Mutable session that allows state to be updated during execution.
168    /// This is shared across all agents in a workflow, enabling state
169    /// propagation between sequential/parallel agents.
170    session: Arc<MutableSession>,
171}
172
173impl InvocationContext {
174    pub fn new(
175        invocation_id: String,
176        agent: Arc<dyn Agent>,
177        user_id: String,
178        app_name: String,
179        session_id: String,
180        user_content: Content,
181        session: Arc<dyn AdkSession>,
182    ) -> Self {
183        Self {
184            invocation_id,
185            agent,
186            user_id,
187            app_name,
188            session_id,
189            branch: String::new(),
190            user_content,
191            artifacts: None,
192            memory: None,
193            run_config: RunConfig::default(),
194            ended: Arc::new(AtomicBool::new(false)),
195            session: Arc::new(MutableSession::new(session)),
196        }
197    }
198
199    /// Create an InvocationContext with an existing MutableSession.
200    /// This allows sharing the same mutable session across multiple contexts
201    /// (e.g., for agent transfers).
202    pub fn with_mutable_session(
203        invocation_id: String,
204        agent: Arc<dyn Agent>,
205        user_id: String,
206        app_name: String,
207        session_id: String,
208        user_content: Content,
209        session: Arc<MutableSession>,
210    ) -> Self {
211        Self {
212            invocation_id,
213            agent,
214            user_id,
215            app_name,
216            session_id,
217            branch: String::new(),
218            user_content,
219            artifacts: None,
220            memory: None,
221            run_config: RunConfig::default(),
222            ended: Arc::new(AtomicBool::new(false)),
223            session,
224        }
225    }
226
227    pub fn with_branch(mut self, branch: String) -> Self {
228        self.branch = branch;
229        self
230    }
231
232    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
233        self.artifacts = Some(artifacts);
234        self
235    }
236
237    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
238        self.memory = Some(memory);
239        self
240    }
241
242    pub fn with_run_config(mut self, config: RunConfig) -> Self {
243        self.run_config = config;
244        self
245    }
246
247    /// Get a reference to the mutable session.
248    /// This allows the Runner to apply state deltas when events are processed.
249    pub fn mutable_session(&self) -> &Arc<MutableSession> {
250        &self.session
251    }
252}
253
254#[async_trait]
255impl ReadonlyContext for InvocationContext {
256    fn invocation_id(&self) -> &str {
257        &self.invocation_id
258    }
259
260    fn agent_name(&self) -> &str {
261        self.agent.name()
262    }
263
264    fn user_id(&self) -> &str {
265        &self.user_id
266    }
267
268    fn app_name(&self) -> &str {
269        &self.app_name
270    }
271
272    fn session_id(&self) -> &str {
273        &self.session_id
274    }
275
276    fn branch(&self) -> &str {
277        &self.branch
278    }
279
280    fn user_content(&self) -> &Content {
281        &self.user_content
282    }
283}
284
285#[async_trait]
286impl CallbackContext for InvocationContext {
287    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
288        self.artifacts.clone()
289    }
290}
291
292#[async_trait]
293impl InvocationContextTrait for InvocationContext {
294    fn agent(&self) -> Arc<dyn Agent> {
295        self.agent.clone()
296    }
297
298    fn memory(&self) -> Option<Arc<dyn Memory>> {
299        self.memory.clone()
300    }
301
302    fn session(&self) -> &dyn adk_core::Session {
303        self.session.as_ref()
304    }
305
306    fn run_config(&self) -> &RunConfig {
307        &self.run_config
308    }
309
310    fn end_invocation(&self) {
311        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
312    }
313
314    fn ended(&self) -> bool {
315        self.ended.load(std::sync::atomic::Ordering::SeqCst)
316    }
317}