Skip to main content

adk_runner/
context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let Ok(mut state) = self.state.write() else {
51            tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52            return;
53        };
54        for (key, value) in delta {
55            // Skip temp: prefixed keys (they shouldn't persist)
56            if !key.starts_with("temp:") {
57                state.insert(key.clone(), value.clone());
58            }
59        }
60    }
61
62    /// Append an event to the session's event list.
63    /// This keeps the in-memory view consistent.
64    pub fn append_event(&self, event: Event) {
65        let Ok(mut events) = self.events.write() else {
66            tracing::error!("events RwLock poisoned in append_event — event dropped");
67            return;
68        };
69        events.push(event);
70    }
71
72    /// Get a snapshot of all events in the session.
73    /// Used by the runner for compaction decisions.
74    pub fn events_snapshot(&self) -> Vec<Event> {
75        let Ok(events) = self.events.read() else {
76            tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77            return Vec::new();
78        };
79        events.clone()
80    }
81
82    /// Return the number of accumulated events without cloning the full list.
83    pub fn events_len(&self) -> usize {
84        let Ok(events) = self.events.read() else {
85            tracing::error!("events RwLock poisoned in events_len — returning 0");
86            return 0;
87        };
88        events.len()
89    }
90
91    /// Build conversation history, optionally filtered for a specific agent.
92    ///
93    /// When `agent_name` is `Some`, events authored by other agents (not "user",
94    /// not the named agent, and not function/tool responses) are excluded. This
95    /// prevents a transferred sub-agent from seeing the parent's tool calls
96    /// mapped as "model" role, which would cause the LLM to think work is
97    /// already done.
98    ///
99    /// When `agent_name` is `None`, all events are included (backward-compatible).
100    pub fn conversation_history_for_agent_impl(
101        &self,
102        agent_name: Option<&str>,
103    ) -> Vec<adk_core::Content> {
104        let Ok(events) = self.events.read() else {
105            tracing::error!("events RwLock poisoned in conversation_history — returning empty");
106            return Vec::new();
107        };
108        let mut history = Vec::new();
109
110        // Find the most recent compaction event — everything before its
111        // end_timestamp has been summarized and should be replaced by the
112        // compacted content.
113        let mut compaction_boundary = None;
114        for event in events.iter().rev() {
115            if let Some(ref compaction) = event.actions.compaction {
116                history.push(compaction.compacted_content.clone());
117                compaction_boundary = Some(compaction.end_timestamp);
118                break;
119            }
120        }
121
122        for event in events.iter() {
123            // Skip the compaction event itself
124            if event.actions.compaction.is_some() {
125                continue;
126            }
127
128            // Skip events that were already compacted
129            if let Some(boundary) = compaction_boundary {
130                if event.timestamp <= boundary {
131                    continue;
132                }
133            }
134
135            // When filtering for a specific agent, skip events from other agents.
136            // Keep: user messages and the agent's own events.
137            // Skip: other agents' events entirely (model-role, function calls,
138            // and function/tool responses). This prevents the sub-agent from
139            // seeing orphaned function responses without their preceding calls.
140            if let Some(name) = agent_name {
141                if event.author != "user" && event.author != name {
142                    continue;
143                }
144            }
145
146            if let Some(content) = &event.llm_response.content {
147                let mut mapped_content = content.clone();
148                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
149                    ("user", _) => "user",
150                    (_, "function" | "tool") => content.role.as_str(),
151                    _ => "model",
152                }
153                .to_string();
154                history.push(mapped_content);
155            }
156        }
157
158        history
159    }
160}
161
162impl adk_core::Session for MutableSession {
163    fn id(&self) -> &str {
164        self.inner.id()
165    }
166
167    fn app_name(&self) -> &str {
168        self.inner.app_name()
169    }
170
171    fn user_id(&self) -> &str {
172        self.inner.user_id()
173    }
174
175    fn state(&self) -> &dyn adk_core::State {
176        self
177    }
178
179    fn conversation_history(&self) -> Vec<adk_core::Content> {
180        self.conversation_history_for_agent_impl(None)
181    }
182
183    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
184        self.conversation_history_for_agent_impl(Some(agent_name))
185    }
186}
187
188impl adk_core::State for MutableSession {
189    fn get(&self, key: &str) -> Option<serde_json::Value> {
190        let Ok(state) = self.state.read() else {
191            tracing::error!("state RwLock poisoned in State::get — returning None");
192            return None;
193        };
194        state.get(key).cloned()
195    }
196
197    fn set(&mut self, key: String, value: serde_json::Value) {
198        if let Err(msg) = adk_core::validate_state_key(&key) {
199            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
200            return;
201        }
202        let Ok(mut state) = self.state.write() else {
203            tracing::error!("state RwLock poisoned in State::set — value dropped");
204            return;
205        };
206        state.insert(key, value);
207    }
208
209    fn all(&self) -> HashMap<String, serde_json::Value> {
210        let Ok(state) = self.state.read() else {
211            tracing::error!("state RwLock poisoned in State::all — returning empty");
212            return HashMap::new();
213        };
214        state.clone()
215    }
216}
217
218pub struct InvocationContext {
219    identity: ExecutionIdentity,
220    agent: Arc<dyn Agent>,
221    user_content: Content,
222    artifacts: Option<Arc<dyn Artifacts>>,
223    memory: Option<Arc<dyn Memory>>,
224    run_config: RunConfig,
225    ended: Arc<AtomicBool>,
226    /// Mutable session that allows state to be updated during execution.
227    /// This is shared across all agents in a workflow, enabling state
228    /// propagation between sequential/parallel agents.
229    session: Arc<MutableSession>,
230    /// Optional request context from the server's auth middleware bridge.
231    /// When present, `user_id()` returns `request_context.user_id` and
232    /// `user_scopes()` returns `request_context.scopes`.
233    request_context: Option<RequestContext>,
234}
235
236impl InvocationContext {
237    /// Create a new invocation context from validated typed identifiers.
238    pub fn new_typed(
239        invocation_id: String,
240        agent: Arc<dyn Agent>,
241        user_id: UserId,
242        app_name: AppName,
243        session_id: SessionId,
244        user_content: Content,
245        session: Arc<dyn AdkSession>,
246    ) -> adk_core::Result<Self> {
247        let identity = ExecutionIdentity {
248            adk: AdkIdentity { app_name, user_id, session_id },
249            invocation_id: InvocationId::try_from(invocation_id)?,
250            branch: String::new(),
251            agent_name: agent.name().to_string(),
252        };
253        Ok(Self {
254            identity,
255            agent,
256            user_content,
257            artifacts: None,
258            memory: None,
259            run_config: RunConfig::default(),
260            ended: Arc::new(AtomicBool::new(false)),
261            session: Arc::new(MutableSession::new(session)),
262            request_context: None,
263        })
264    }
265
266    pub fn new(
267        invocation_id: String,
268        agent: Arc<dyn Agent>,
269        user_id: String,
270        app_name: String,
271        session_id: String,
272        user_content: Content,
273        session: Arc<dyn AdkSession>,
274    ) -> adk_core::Result<Self> {
275        Self::new_typed(
276            invocation_id,
277            agent,
278            UserId::try_from(user_id)?,
279            AppName::try_from(app_name)?,
280            SessionId::try_from(session_id)?,
281            user_content,
282            session,
283        )
284    }
285
286    /// Create an invocation context that reuses an existing mutable session and
287    /// validated typed identifiers.
288    pub fn with_mutable_session_typed(
289        invocation_id: String,
290        agent: Arc<dyn Agent>,
291        user_id: UserId,
292        app_name: AppName,
293        session_id: SessionId,
294        user_content: Content,
295        session: Arc<MutableSession>,
296    ) -> adk_core::Result<Self> {
297        let identity = ExecutionIdentity {
298            adk: AdkIdentity { app_name, user_id, session_id },
299            invocation_id: InvocationId::try_from(invocation_id)?,
300            branch: String::new(),
301            agent_name: agent.name().to_string(),
302        };
303        Ok(Self {
304            identity,
305            agent,
306            user_content,
307            artifacts: None,
308            memory: None,
309            run_config: RunConfig::default(),
310            ended: Arc::new(AtomicBool::new(false)),
311            session,
312            request_context: None,
313        })
314    }
315
316    /// Create an InvocationContext with an existing MutableSession.
317    /// This allows sharing the same mutable session across multiple contexts
318    /// (e.g., for agent transfers).
319    pub fn with_mutable_session(
320        invocation_id: String,
321        agent: Arc<dyn Agent>,
322        user_id: String,
323        app_name: String,
324        session_id: String,
325        user_content: Content,
326        session: Arc<MutableSession>,
327    ) -> adk_core::Result<Self> {
328        Self::with_mutable_session_typed(
329            invocation_id,
330            agent,
331            UserId::try_from(user_id)?,
332            AppName::try_from(app_name)?,
333            SessionId::try_from(session_id)?,
334            user_content,
335            session,
336        )
337    }
338
339    pub fn with_branch(mut self, branch: String) -> Self {
340        self.identity.branch = branch;
341        self
342    }
343
344    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
345        self.artifacts = Some(artifacts);
346        self
347    }
348
349    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
350        self.memory = Some(memory);
351        self
352    }
353
354    pub fn with_run_config(mut self, config: RunConfig) -> Self {
355        self.run_config = config;
356        self
357    }
358
359    /// Set the request context from the server's auth middleware bridge.
360    ///
361    /// When set, `user_id()` returns `request_context.user_id` (overriding
362    /// the session-scoped identity), and `user_scopes()` returns
363    /// `request_context.scopes`. This is the explicit authenticated user
364    /// override — `RequestContext` remains separate from `ExecutionIdentity`
365    /// and `AdkIdentity` (it does not carry session or invocation IDs).
366    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
367        self.request_context = Some(ctx);
368        self
369    }
370
371    /// Get a reference to the mutable session.
372    /// This allows the Runner to apply state deltas when events are processed.
373    pub fn mutable_session(&self) -> &Arc<MutableSession> {
374        &self.session
375    }
376}
377
378#[async_trait]
379impl ReadonlyContext for InvocationContext {
380    fn invocation_id(&self) -> &str {
381        self.identity.invocation_id.as_ref()
382    }
383
384    fn agent_name(&self) -> &str {
385        self.agent.name()
386    }
387
388    fn user_id(&self) -> &str {
389        // Explicit authenticated user override: when a RequestContext is
390        // present (set via with_request_context from the auth middleware
391        // bridge), the authenticated user_id takes precedence over the
392        // session-scoped identity. This keeps auth binding explicit and
393        // ensures the runtime reflects the verified caller identity.
394        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
395    }
396
397    fn app_name(&self) -> &str {
398        self.identity.adk.app_name.as_ref()
399    }
400
401    fn session_id(&self) -> &str {
402        self.identity.adk.session_id.as_ref()
403    }
404
405    fn branch(&self) -> &str {
406        &self.identity.branch
407    }
408
409    fn user_content(&self) -> &Content {
410        &self.user_content
411    }
412}
413
414#[async_trait]
415impl CallbackContext for InvocationContext {
416    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
417        self.artifacts.clone()
418    }
419}
420
421#[async_trait]
422impl InvocationContextTrait for InvocationContext {
423    fn agent(&self) -> Arc<dyn Agent> {
424        self.agent.clone()
425    }
426
427    fn memory(&self) -> Option<Arc<dyn Memory>> {
428        self.memory.clone()
429    }
430
431    fn session(&self) -> &dyn adk_core::Session {
432        self.session.as_ref()
433    }
434
435    fn run_config(&self) -> &RunConfig {
436        &self.run_config
437    }
438
439    fn end_invocation(&self) {
440        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
441    }
442
443    fn ended(&self) -> bool {
444        self.ended.load(std::sync::atomic::Ordering::SeqCst)
445    }
446
447    fn user_scopes(&self) -> Vec<String> {
448        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
449    }
450
451    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
452        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
453            rc.metadata
454                .iter()
455                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
456                .collect()
457        })
458    }
459}