Skip to main content

adk_runner/
context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let Ok(mut state) = self.state.write() else {
51            tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52            return;
53        };
54        for (key, value) in delta {
55            // Skip temp: prefixed keys (they shouldn't persist)
56            if !key.starts_with("temp:") {
57                state.insert(key.clone(), value.clone());
58            }
59        }
60    }
61
62    /// Append an event to the session's event list.
63    /// This keeps the in-memory view consistent.
64    pub fn append_event(&self, event: Event) {
65        let Ok(mut events) = self.events.write() else {
66            tracing::error!("events RwLock poisoned in append_event — event dropped");
67            return;
68        };
69        events.push(event);
70    }
71
72    /// Get a snapshot of all events in the session.
73    /// Used by the runner for compaction decisions.
74    pub fn events_snapshot(&self) -> Vec<Event> {
75        let Ok(events) = self.events.read() else {
76            tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77            return Vec::new();
78        };
79        events.clone()
80    }
81
82    /// Return the number of accumulated events without cloning the full list.
83    pub fn events_len(&self) -> usize {
84        let Ok(events) = self.events.read() else {
85            tracing::error!("events RwLock poisoned in events_len — returning 0");
86            return 0;
87        };
88        events.len()
89    }
90
91    /// Build conversation history, optionally filtered for a specific agent.
92    ///
93    /// When `agent_name` is `Some`, events authored by other agents (not "user",
94    /// not the named agent, and not function/tool responses) are excluded. This
95    /// prevents a transferred sub-agent from seeing the parent's tool calls
96    /// mapped as "model" role, which would cause the LLM to think work is
97    /// already done.
98    ///
99    /// When `agent_name` is `None`, all events are included (backward-compatible).
100    pub fn conversation_history_for_agent_impl(
101        &self,
102        agent_name: Option<&str>,
103    ) -> Vec<adk_core::Content> {
104        let Ok(events) = self.events.read() else {
105            tracing::error!("events RwLock poisoned in conversation_history — returning empty");
106            return Vec::new();
107        };
108        let mut history = Vec::new();
109
110        // Find the most recent compaction event — everything before its
111        // end_timestamp has been summarized and should be replaced by the
112        // compacted content.
113        let mut compaction_boundary = None;
114        for event in events.iter().rev() {
115            if let Some(ref compaction) = event.actions.compaction {
116                history.push(compaction.compacted_content.clone());
117                compaction_boundary = Some(compaction.end_timestamp);
118                break;
119            }
120        }
121
122        for event in events.iter() {
123            // Skip the compaction event itself
124            if event.actions.compaction.is_some() {
125                continue;
126            }
127
128            // Skip events that were already compacted
129            if let Some(boundary) = compaction_boundary {
130                if event.timestamp <= boundary {
131                    continue;
132                }
133            }
134
135            // When filtering for a specific agent, skip events from other agents.
136            // Keep: user messages and the agent's own events.
137            // Skip: other agents' events entirely (model-role, function calls,
138            // and function/tool responses). This prevents the sub-agent from
139            // seeing orphaned function responses without their preceding calls.
140            if let Some(name) = agent_name {
141                if event.author != "user" && event.author != name {
142                    continue;
143                }
144            }
145
146            if let Some(content) = &event.llm_response.content {
147                let mut mapped_content = content.clone();
148                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
149                    ("user", _) => "user",
150                    (_, "function" | "tool") => content.role.as_str(),
151                    _ => "model",
152                }
153                .to_string();
154                history.push(mapped_content);
155            }
156        }
157
158        history
159    }
160}
161
162impl adk_core::Session for MutableSession {
163    fn id(&self) -> &str {
164        self.inner.id()
165    }
166
167    fn app_name(&self) -> &str {
168        self.inner.app_name()
169    }
170
171    fn user_id(&self) -> &str {
172        self.inner.user_id()
173    }
174
175    fn state(&self) -> &dyn adk_core::State {
176        self
177    }
178
179    fn conversation_history(&self) -> Vec<adk_core::Content> {
180        self.conversation_history_for_agent_impl(None)
181    }
182
183    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
184        self.conversation_history_for_agent_impl(Some(agent_name))
185    }
186}
187
188impl adk_core::State for MutableSession {
189    fn get(&self, key: &str) -> Option<serde_json::Value> {
190        let Ok(state) = self.state.read() else {
191            tracing::error!("state RwLock poisoned in State::get — returning None");
192            return None;
193        };
194        state.get(key).cloned()
195    }
196
197    fn set(&mut self, key: String, value: serde_json::Value) {
198        if let Err(msg) = adk_core::validate_state_key(&key) {
199            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
200            return;
201        }
202        let Ok(mut state) = self.state.write() else {
203            tracing::error!("state RwLock poisoned in State::set — value dropped");
204            return;
205        };
206        state.insert(key, value);
207    }
208
209    fn all(&self) -> HashMap<String, serde_json::Value> {
210        let Ok(state) = self.state.read() else {
211            tracing::error!("state RwLock poisoned in State::all — returning empty");
212            return HashMap::new();
213        };
214        state.clone()
215    }
216}
217
218pub struct InvocationContext {
219    identity: ExecutionIdentity,
220    agent: Arc<dyn Agent>,
221    user_content: Content,
222    artifacts: Option<Arc<dyn Artifacts>>,
223    memory: Option<Arc<dyn Memory>>,
224    run_config: RunConfig,
225    ended: Arc<AtomicBool>,
226    /// Mutable session that allows state to be updated during execution.
227    /// This is shared across all agents in a workflow, enabling state
228    /// propagation between sequential/parallel agents.
229    session: Arc<MutableSession>,
230    /// Optional request context from the server's auth middleware bridge.
231    /// When present, `user_id()` returns `request_context.user_id` and
232    /// `user_scopes()` returns `request_context.scopes`.
233    request_context: Option<RequestContext>,
234    /// Optional shared state for parallel agent coordination.
235    shared_state: Option<Arc<adk_core::SharedState>>,
236}
237
238impl InvocationContext {
239    /// Create a new invocation context from validated typed identifiers.
240    pub fn new_typed(
241        invocation_id: String,
242        agent: Arc<dyn Agent>,
243        user_id: UserId,
244        app_name: AppName,
245        session_id: SessionId,
246        user_content: Content,
247        session: Arc<dyn AdkSession>,
248    ) -> adk_core::Result<Self> {
249        let identity = ExecutionIdentity {
250            adk: AdkIdentity { app_name, user_id, session_id },
251            invocation_id: InvocationId::try_from(invocation_id)?,
252            branch: String::new(),
253            agent_name: agent.name().to_string(),
254        };
255        Ok(Self {
256            identity,
257            agent,
258            user_content,
259            artifacts: None,
260            memory: None,
261            run_config: RunConfig::default(),
262            ended: Arc::new(AtomicBool::new(false)),
263            session: Arc::new(MutableSession::new(session)),
264            request_context: None,
265            shared_state: None,
266        })
267    }
268
269    pub fn new(
270        invocation_id: String,
271        agent: Arc<dyn Agent>,
272        user_id: String,
273        app_name: String,
274        session_id: String,
275        user_content: Content,
276        session: Arc<dyn AdkSession>,
277    ) -> adk_core::Result<Self> {
278        Self::new_typed(
279            invocation_id,
280            agent,
281            UserId::try_from(user_id)?,
282            AppName::try_from(app_name)?,
283            SessionId::try_from(session_id)?,
284            user_content,
285            session,
286        )
287    }
288
289    /// Create an invocation context that reuses an existing mutable session and
290    /// validated typed identifiers.
291    pub fn with_mutable_session_typed(
292        invocation_id: String,
293        agent: Arc<dyn Agent>,
294        user_id: UserId,
295        app_name: AppName,
296        session_id: SessionId,
297        user_content: Content,
298        session: Arc<MutableSession>,
299    ) -> adk_core::Result<Self> {
300        let identity = ExecutionIdentity {
301            adk: AdkIdentity { app_name, user_id, session_id },
302            invocation_id: InvocationId::try_from(invocation_id)?,
303            branch: String::new(),
304            agent_name: agent.name().to_string(),
305        };
306        Ok(Self {
307            identity,
308            agent,
309            user_content,
310            artifacts: None,
311            memory: None,
312            run_config: RunConfig::default(),
313            ended: Arc::new(AtomicBool::new(false)),
314            session,
315            request_context: None,
316            shared_state: None,
317        })
318    }
319
320    /// Create an InvocationContext with an existing MutableSession.
321    /// This allows sharing the same mutable session across multiple contexts
322    /// (e.g., for agent transfers).
323    pub fn with_mutable_session(
324        invocation_id: String,
325        agent: Arc<dyn Agent>,
326        user_id: String,
327        app_name: String,
328        session_id: String,
329        user_content: Content,
330        session: Arc<MutableSession>,
331    ) -> adk_core::Result<Self> {
332        Self::with_mutable_session_typed(
333            invocation_id,
334            agent,
335            UserId::try_from(user_id)?,
336            AppName::try_from(app_name)?,
337            SessionId::try_from(session_id)?,
338            user_content,
339            session,
340        )
341    }
342
343    pub fn with_branch(mut self, branch: String) -> Self {
344        self.identity.branch = branch;
345        self
346    }
347
348    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
349        self.artifacts = Some(artifacts);
350        self
351    }
352
353    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
354        self.memory = Some(memory);
355        self
356    }
357
358    pub fn with_run_config(mut self, config: RunConfig) -> Self {
359        self.run_config = config;
360        self
361    }
362
363    /// Set the request context from the server's auth middleware bridge.
364    ///
365    /// When set, `user_id()` returns `request_context.user_id` (overriding
366    /// the session-scoped identity), and `user_scopes()` returns
367    /// `request_context.scopes`. This is the explicit authenticated user
368    /// override — `RequestContext` remains separate from `ExecutionIdentity`
369    /// and `AdkIdentity` (it does not carry session or invocation IDs).
370    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
371        self.request_context = Some(ctx);
372        self
373    }
374
375    /// Set the shared state for parallel agent coordination.
376    pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
377        self.shared_state = Some(shared);
378        self
379    }
380
381    /// Get a reference to the mutable session.
382    /// This allows the Runner to apply state deltas when events are processed.
383    pub fn mutable_session(&self) -> &Arc<MutableSession> {
384        &self.session
385    }
386}
387
388#[async_trait]
389impl ReadonlyContext for InvocationContext {
390    fn invocation_id(&self) -> &str {
391        self.identity.invocation_id.as_ref()
392    }
393
394    fn agent_name(&self) -> &str {
395        self.agent.name()
396    }
397
398    fn user_id(&self) -> &str {
399        // Explicit authenticated user override: when a RequestContext is
400        // present (set via with_request_context from the auth middleware
401        // bridge), the authenticated user_id takes precedence over the
402        // session-scoped identity. This keeps auth binding explicit and
403        // ensures the runtime reflects the verified caller identity.
404        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
405    }
406
407    fn app_name(&self) -> &str {
408        self.identity.adk.app_name.as_ref()
409    }
410
411    fn session_id(&self) -> &str {
412        self.identity.adk.session_id.as_ref()
413    }
414
415    fn branch(&self) -> &str {
416        &self.identity.branch
417    }
418
419    fn user_content(&self) -> &Content {
420        &self.user_content
421    }
422}
423
424#[async_trait]
425impl CallbackContext for InvocationContext {
426    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
427        self.artifacts.clone()
428    }
429
430    fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
431        self.shared_state.clone()
432    }
433}
434
435#[async_trait]
436impl InvocationContextTrait for InvocationContext {
437    fn agent(&self) -> Arc<dyn Agent> {
438        self.agent.clone()
439    }
440
441    fn memory(&self) -> Option<Arc<dyn Memory>> {
442        self.memory.clone()
443    }
444
445    fn session(&self) -> &dyn adk_core::Session {
446        self.session.as_ref()
447    }
448
449    fn run_config(&self) -> &RunConfig {
450        &self.run_config
451    }
452
453    fn end_invocation(&self) {
454        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
455    }
456
457    fn ended(&self) -> bool {
458        self.ended.load(std::sync::atomic::Ordering::SeqCst)
459    }
460
461    fn user_scopes(&self) -> Vec<String> {
462        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
463    }
464
465    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
466        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
467            rc.metadata
468                .iter()
469                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
470                .collect()
471        })
472    }
473}