Skip to main content

adk_runner/
context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let mut state = self.state.write().unwrap();
51        for (key, value) in delta {
52            // Skip temp: prefixed keys (they shouldn't persist)
53            if !key.starts_with("temp:") {
54                state.insert(key.clone(), value.clone());
55            }
56        }
57    }
58
59    /// Append an event to the session's event list.
60    /// This keeps the in-memory view consistent.
61    pub fn append_event(&self, event: Event) {
62        let mut events = self.events.write().unwrap();
63        events.push(event);
64    }
65
66    /// Get a snapshot of all events in the session.
67    /// Used by the runner for compaction decisions.
68    pub fn events_snapshot(&self) -> Vec<Event> {
69        let events = self.events.read().unwrap();
70        events.clone()
71    }
72
73    /// Build conversation history, optionally filtered for a specific agent.
74    ///
75    /// When `agent_name` is `Some`, events authored by other agents (not "user",
76    /// not the named agent, and not function/tool responses) are excluded. This
77    /// prevents a transferred sub-agent from seeing the parent's tool calls
78    /// mapped as "model" role, which would cause the LLM to think work is
79    /// already done.
80    ///
81    /// When `agent_name` is `None`, all events are included (backward-compatible).
82    pub fn conversation_history_for_agent_impl(
83        &self,
84        agent_name: Option<&str>,
85    ) -> Vec<adk_core::Content> {
86        let events = self.events.read().unwrap();
87        let mut history = Vec::new();
88
89        // Find the most recent compaction event — everything before its
90        // end_timestamp has been summarized and should be replaced by the
91        // compacted content.
92        let mut compaction_boundary = None;
93        for event in events.iter().rev() {
94            if let Some(ref compaction) = event.actions.compaction {
95                history.push(compaction.compacted_content.clone());
96                compaction_boundary = Some(compaction.end_timestamp);
97                break;
98            }
99        }
100
101        for event in events.iter() {
102            // Skip the compaction event itself
103            if event.actions.compaction.is_some() {
104                continue;
105            }
106
107            // Skip events that were already compacted
108            if let Some(boundary) = compaction_boundary {
109                if event.timestamp <= boundary {
110                    continue;
111                }
112            }
113
114            // When filtering for a specific agent, skip events from other agents.
115            // Keep: user messages and the agent's own events.
116            // Skip: other agents' events entirely (model-role, function calls,
117            // and function/tool responses). This prevents the sub-agent from
118            // seeing orphaned function responses without their preceding calls.
119            if let Some(name) = agent_name {
120                if event.author != "user" && event.author != name {
121                    continue;
122                }
123            }
124
125            if let Some(content) = &event.llm_response.content {
126                let mut mapped_content = content.clone();
127                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
128                    ("user", _) => "user",
129                    (_, "function" | "tool") => content.role.as_str(),
130                    _ => "model",
131                }
132                .to_string();
133                history.push(mapped_content);
134            }
135        }
136
137        history
138    }
139}
140
141impl adk_core::Session for MutableSession {
142    fn id(&self) -> &str {
143        self.inner.id()
144    }
145
146    fn app_name(&self) -> &str {
147        self.inner.app_name()
148    }
149
150    fn user_id(&self) -> &str {
151        self.inner.user_id()
152    }
153
154    fn state(&self) -> &dyn adk_core::State {
155        // SAFETY: We implement State for MutableSession, so this cast is valid.
156        // This pattern allows us to return a reference to self as a State trait object.
157        unsafe { &*(self as *const Self as *const dyn adk_core::State) }
158    }
159
160    fn conversation_history(&self) -> Vec<adk_core::Content> {
161        self.conversation_history_for_agent_impl(None)
162    }
163
164    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
165        self.conversation_history_for_agent_impl(Some(agent_name))
166    }
167}
168
169impl adk_core::State for MutableSession {
170    fn get(&self, key: &str) -> Option<serde_json::Value> {
171        let state = self.state.read().unwrap();
172        state.get(key).cloned()
173    }
174
175    fn set(&mut self, key: String, value: serde_json::Value) {
176        let mut state = self.state.write().unwrap();
177        state.insert(key, value);
178    }
179
180    fn all(&self) -> HashMap<String, serde_json::Value> {
181        let state = self.state.read().unwrap();
182        state.clone()
183    }
184}
185
186pub struct InvocationContext {
187    identity: ExecutionIdentity,
188    agent: Arc<dyn Agent>,
189    user_content: Content,
190    artifacts: Option<Arc<dyn Artifacts>>,
191    memory: Option<Arc<dyn Memory>>,
192    run_config: RunConfig,
193    ended: Arc<AtomicBool>,
194    /// Mutable session that allows state to be updated during execution.
195    /// This is shared across all agents in a workflow, enabling state
196    /// propagation between sequential/parallel agents.
197    session: Arc<MutableSession>,
198    /// Optional request context from the server's auth middleware bridge.
199    /// When present, `user_id()` returns `request_context.user_id` and
200    /// `user_scopes()` returns `request_context.scopes`.
201    request_context: Option<RequestContext>,
202}
203
204impl InvocationContext {
205    pub fn new(
206        invocation_id: String,
207        agent: Arc<dyn Agent>,
208        user_id: String,
209        app_name: String,
210        session_id: String,
211        user_content: Content,
212        session: Arc<dyn AdkSession>,
213    ) -> Self {
214        let identity = ExecutionIdentity {
215            adk: AdkIdentity {
216                app_name: AppName::new_unchecked(app_name),
217                user_id: UserId::new_unchecked(user_id),
218                session_id: SessionId::new_unchecked(session_id),
219            },
220            invocation_id: InvocationId::new_unchecked(invocation_id),
221            branch: String::new(),
222            agent_name: agent.name().to_string(),
223        };
224        Self {
225            identity,
226            agent,
227            user_content,
228            artifacts: None,
229            memory: None,
230            run_config: RunConfig::default(),
231            ended: Arc::new(AtomicBool::new(false)),
232            session: Arc::new(MutableSession::new(session)),
233            request_context: None,
234        }
235    }
236
237    /// Create an InvocationContext with an existing MutableSession.
238    /// This allows sharing the same mutable session across multiple contexts
239    /// (e.g., for agent transfers).
240    pub fn with_mutable_session(
241        invocation_id: String,
242        agent: Arc<dyn Agent>,
243        user_id: String,
244        app_name: String,
245        session_id: String,
246        user_content: Content,
247        session: Arc<MutableSession>,
248    ) -> Self {
249        let identity = ExecutionIdentity {
250            adk: AdkIdentity {
251                app_name: AppName::new_unchecked(app_name),
252                user_id: UserId::new_unchecked(user_id),
253                session_id: SessionId::new_unchecked(session_id),
254            },
255            invocation_id: InvocationId::new_unchecked(invocation_id),
256            branch: String::new(),
257            agent_name: agent.name().to_string(),
258        };
259        Self {
260            identity,
261            agent,
262            user_content,
263            artifacts: None,
264            memory: None,
265            run_config: RunConfig::default(),
266            ended: Arc::new(AtomicBool::new(false)),
267            session,
268            request_context: None,
269        }
270    }
271
272    pub fn with_branch(mut self, branch: String) -> Self {
273        self.identity.branch = branch;
274        self
275    }
276
277    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
278        self.artifacts = Some(artifacts);
279        self
280    }
281
282    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
283        self.memory = Some(memory);
284        self
285    }
286
287    pub fn with_run_config(mut self, config: RunConfig) -> Self {
288        self.run_config = config;
289        self
290    }
291
292    /// Set the request context from the server's auth middleware bridge.
293    ///
294    /// When set, `user_id()` returns `request_context.user_id` (overriding
295    /// the session-scoped identity), and `user_scopes()` returns
296    /// `request_context.scopes`. This is the explicit authenticated user
297    /// override — `RequestContext` remains separate from `ExecutionIdentity`
298    /// and `AdkIdentity` (it does not carry session or invocation IDs).
299    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
300        self.request_context = Some(ctx);
301        self
302    }
303
304    /// Get a reference to the mutable session.
305    /// This allows the Runner to apply state deltas when events are processed.
306    pub fn mutable_session(&self) -> &Arc<MutableSession> {
307        &self.session
308    }
309}
310
311#[async_trait]
312impl ReadonlyContext for InvocationContext {
313    fn invocation_id(&self) -> &str {
314        self.identity.invocation_id.as_ref()
315    }
316
317    fn agent_name(&self) -> &str {
318        self.agent.name()
319    }
320
321    fn user_id(&self) -> &str {
322        // Explicit authenticated user override: when a RequestContext is
323        // present (set via with_request_context from the auth middleware
324        // bridge), the authenticated user_id takes precedence over the
325        // session-scoped identity. This keeps auth binding explicit and
326        // ensures the runtime reflects the verified caller identity.
327        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
328    }
329
330    fn app_name(&self) -> &str {
331        self.identity.adk.app_name.as_ref()
332    }
333
334    fn session_id(&self) -> &str {
335        self.identity.adk.session_id.as_ref()
336    }
337
338    fn branch(&self) -> &str {
339        &self.identity.branch
340    }
341
342    fn user_content(&self) -> &Content {
343        &self.user_content
344    }
345}
346
347#[async_trait]
348impl CallbackContext for InvocationContext {
349    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
350        self.artifacts.clone()
351    }
352}
353
354#[async_trait]
355impl InvocationContextTrait for InvocationContext {
356    fn agent(&self) -> Arc<dyn Agent> {
357        self.agent.clone()
358    }
359
360    fn memory(&self) -> Option<Arc<dyn Memory>> {
361        self.memory.clone()
362    }
363
364    fn session(&self) -> &dyn adk_core::Session {
365        self.session.as_ref()
366    }
367
368    fn run_config(&self) -> &RunConfig {
369        &self.run_config
370    }
371
372    fn end_invocation(&self) {
373        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
374    }
375
376    fn ended(&self) -> bool {
377        self.ended.load(std::sync::atomic::Ordering::SeqCst)
378    }
379
380    fn user_scopes(&self) -> Vec<String> {
381        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
382    }
383
384    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
385        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
386            rc.metadata
387                .iter()
388                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
389                .collect()
390        })
391    }
392}