adk_runner/
context.rs

1use adk_core::{
2    Agent, Artifacts, CallbackContext, Content, Event, InvocationContext as InvocationContextTrait,
3    Memory, ReadonlyContext, RunConfig,
4};
5use adk_session::Session as AdkSession;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{atomic::AtomicBool, Arc, RwLock};
9
10/// MutableSession wraps a session with shared mutable state.
11///
12/// This mirrors ADK-Go's MutableSession pattern where state changes from
13/// events are immediately visible to all agents sharing the same context.
14/// This is critical for SequentialAgent/LoopAgent patterns where downstream
15/// agents need to read state set by upstream agents via output_key.
16pub struct MutableSession {
17    /// The original session snapshot (for metadata like id, app_name, user_id)
18    inner: Arc<dyn AdkSession>,
19    /// Shared mutable state - updated when events are processed
20    /// This is the key difference from the old SessionAdapter which used immutable snapshots
21    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
22    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
23    events: Arc<RwLock<Vec<Event>>>,
24}
25
26impl MutableSession {
27    /// Create a new MutableSession from a session snapshot.
28    /// The state is copied from the session and becomes mutable.
29    pub fn new(session: Arc<dyn AdkSession>) -> Self {
30        // Clone the initial state from the session
31        let initial_state = session.state().all();
32        // Clone the initial events
33        let initial_events = session.events().all();
34
35        Self {
36            inner: session,
37            state: Arc::new(RwLock::new(initial_state)),
38            events: Arc::new(RwLock::new(initial_events)),
39        }
40    }
41
42    /// Apply state delta from an event to the mutable state.
43    /// This is called by the Runner when events are yielded.
44    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
45        if delta.is_empty() {
46            return;
47        }
48
49        let mut state = self.state.write().unwrap();
50        for (key, value) in delta {
51            // Skip temp: prefixed keys (they shouldn't persist)
52            if !key.starts_with("temp:") {
53                state.insert(key.clone(), value.clone());
54            }
55        }
56    }
57
58    /// Append an event to the session's event list.
59    /// This keeps the in-memory view consistent.
60    pub fn append_event(&self, event: Event) {
61        let mut events = self.events.write().unwrap();
62        events.push(event);
63    }
64}
65
66impl adk_core::Session for MutableSession {
67    fn id(&self) -> &str {
68        self.inner.id()
69    }
70
71    fn app_name(&self) -> &str {
72        self.inner.app_name()
73    }
74
75    fn user_id(&self) -> &str {
76        self.inner.user_id()
77    }
78
79    fn state(&self) -> &dyn adk_core::State {
80        // SAFETY: We implement State for MutableSession, so this cast is valid.
81        // This pattern allows us to return a reference to self as a State trait object.
82        unsafe { &*(self as *const Self as *const dyn adk_core::State) }
83    }
84
85    fn conversation_history(&self) -> Vec<adk_core::Content> {
86        let events = self.events.read().unwrap();
87        let mut history = Vec::new();
88
89        for event in events.iter() {
90            if let Some(content) = &event.llm_response.content {
91                let role = match event.author.as_str() {
92                    "user" => "user".to_string(),
93                    _ => "model".to_string(),
94                };
95
96                let mut mapped_content = content.clone();
97                mapped_content.role = role;
98                history.push(mapped_content);
99            }
100        }
101
102        history
103    }
104}
105
106impl adk_core::State for MutableSession {
107    fn get(&self, key: &str) -> Option<serde_json::Value> {
108        let state = self.state.read().unwrap();
109        state.get(key).cloned()
110    }
111
112    fn set(&mut self, key: String, value: serde_json::Value) {
113        let mut state = self.state.write().unwrap();
114        state.insert(key, value);
115    }
116
117    fn all(&self) -> HashMap<String, serde_json::Value> {
118        let state = self.state.read().unwrap();
119        state.clone()
120    }
121}
122
123pub struct InvocationContext {
124    invocation_id: String,
125    agent: Arc<dyn Agent>,
126    user_id: String,
127    app_name: String,
128    session_id: String,
129    branch: String,
130    user_content: Content,
131    artifacts: Option<Arc<dyn Artifacts>>,
132    memory: Option<Arc<dyn Memory>>,
133    run_config: RunConfig,
134    ended: Arc<AtomicBool>,
135    /// Mutable session that allows state to be updated during execution.
136    /// This is shared across all agents in a workflow, enabling state
137    /// propagation between sequential/parallel agents.
138    session: Arc<MutableSession>,
139}
140
141impl InvocationContext {
142    pub fn new(
143        invocation_id: String,
144        agent: Arc<dyn Agent>,
145        user_id: String,
146        app_name: String,
147        session_id: String,
148        user_content: Content,
149        session: Arc<dyn AdkSession>,
150    ) -> Self {
151        Self {
152            invocation_id,
153            agent,
154            user_id,
155            app_name,
156            session_id,
157            branch: String::new(),
158            user_content,
159            artifacts: None,
160            memory: None,
161            run_config: RunConfig::default(),
162            ended: Arc::new(AtomicBool::new(false)),
163            session: Arc::new(MutableSession::new(session)),
164        }
165    }
166
167    /// Create an InvocationContext with an existing MutableSession.
168    /// This allows sharing the same mutable session across multiple contexts
169    /// (e.g., for agent transfers).
170    pub fn with_mutable_session(
171        invocation_id: String,
172        agent: Arc<dyn Agent>,
173        user_id: String,
174        app_name: String,
175        session_id: String,
176        user_content: Content,
177        session: Arc<MutableSession>,
178    ) -> Self {
179        Self {
180            invocation_id,
181            agent,
182            user_id,
183            app_name,
184            session_id,
185            branch: String::new(),
186            user_content,
187            artifacts: None,
188            memory: None,
189            run_config: RunConfig::default(),
190            ended: Arc::new(AtomicBool::new(false)),
191            session,
192        }
193    }
194
195    pub fn with_branch(mut self, branch: String) -> Self {
196        self.branch = branch;
197        self
198    }
199
200    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
201        self.artifacts = Some(artifacts);
202        self
203    }
204
205    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
206        self.memory = Some(memory);
207        self
208    }
209
210    pub fn with_run_config(mut self, config: RunConfig) -> Self {
211        self.run_config = config;
212        self
213    }
214
215    /// Get a reference to the mutable session.
216    /// This allows the Runner to apply state deltas when events are processed.
217    pub fn mutable_session(&self) -> &Arc<MutableSession> {
218        &self.session
219    }
220}
221
222#[async_trait]
223impl ReadonlyContext for InvocationContext {
224    fn invocation_id(&self) -> &str {
225        &self.invocation_id
226    }
227
228    fn agent_name(&self) -> &str {
229        self.agent.name()
230    }
231
232    fn user_id(&self) -> &str {
233        &self.user_id
234    }
235
236    fn app_name(&self) -> &str {
237        &self.app_name
238    }
239
240    fn session_id(&self) -> &str {
241        &self.session_id
242    }
243
244    fn branch(&self) -> &str {
245        &self.branch
246    }
247
248    fn user_content(&self) -> &Content {
249        &self.user_content
250    }
251}
252
253#[async_trait]
254impl CallbackContext for InvocationContext {
255    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
256        self.artifacts.clone()
257    }
258}
259
260#[async_trait]
261impl InvocationContextTrait for InvocationContext {
262    fn agent(&self) -> Arc<dyn Agent> {
263        self.agent.clone()
264    }
265
266    fn memory(&self) -> Option<Arc<dyn Memory>> {
267        self.memory.clone()
268    }
269
270    fn session(&self) -> &dyn adk_core::Session {
271        self.session.as_ref()
272    }
273
274    fn run_config(&self) -> &RunConfig {
275        &self.run_config
276    }
277
278    fn end_invocation(&self) {
279        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
280    }
281
282    fn ended(&self) -> bool {
283        self.ended.load(std::sync::atomic::Ordering::SeqCst)
284    }
285}