// adk_runner/context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let Ok(mut state) = self.state.write() else {
51            tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52            return;
53        };
54        for (key, value) in delta {
55            // Skip temp: prefixed keys (they shouldn't persist)
56            if !key.starts_with("temp:") {
57                state.insert(key.clone(), value.clone());
58            }
59        }
60    }
61
62    /// Append an event to the session's event list.
63    /// This keeps the in-memory view consistent.
64    pub fn append_event(&self, event: Event) {
65        let Ok(mut events) = self.events.write() else {
66            tracing::error!("events RwLock poisoned in append_event — event dropped");
67            return;
68        };
69        events.push(event);
70    }
71
72    /// Get a snapshot of all events in the session.
73    /// Used by the runner for compaction decisions.
74    pub fn events_snapshot(&self) -> Vec<Event> {
75        let Ok(events) = self.events.read() else {
76            tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77            return Vec::new();
78        };
79        events.clone()
80    }
81
82    /// Replace all events with a new list (used by intra-invocation compaction).
83    pub fn replace_events(&self, new_events: Vec<Event>) {
84        let Ok(mut events) = self.events.write() else {
85            tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86            return;
87        };
88        *events = new_events;
89    }
90
91    /// Return the number of accumulated events without cloning the full list.
92    pub fn events_len(&self) -> usize {
93        let Ok(events) = self.events.read() else {
94            tracing::error!("events RwLock poisoned in events_len — returning 0");
95            return 0;
96        };
97        events.len()
98    }
99
100    /// Build conversation history, optionally filtered for a specific agent.
101    ///
102    /// When `agent_name` is `Some`, events authored by other agents (not "user",
103    /// not the named agent, and not function/tool responses) are excluded. This
104    /// prevents a transferred sub-agent from seeing the parent's tool calls
105    /// mapped as "model" role, which would cause the LLM to think work is
106    /// already done.
107    ///
108    /// When `agent_name` is `None`, all events are included (backward-compatible).
109    pub fn conversation_history_for_agent_impl(
110        &self,
111        agent_name: Option<&str>,
112    ) -> Vec<adk_core::Content> {
113        let Ok(events) = self.events.read() else {
114            tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115            return Vec::new();
116        };
117        let mut history = Vec::new();
118
119        // Find the most recent compaction event — everything before its
120        // end_timestamp has been summarized and should be replaced by the
121        // compacted content.
122        let mut compaction_boundary = None;
123        for event in events.iter().rev() {
124            if let Some(ref compaction) = event.actions.compaction {
125                history.push(compaction.compacted_content.clone());
126                compaction_boundary = Some(compaction.end_timestamp);
127                break;
128            }
129        }
130
131        for event in events.iter() {
132            // Skip the compaction event itself
133            if event.actions.compaction.is_some() {
134                continue;
135            }
136
137            // Skip events that were already compacted
138            if let Some(boundary) = compaction_boundary {
139                if event.timestamp <= boundary {
140                    continue;
141                }
142            }
143
144            // When filtering for a specific agent, skip events from other agents.
145            // Keep: user messages and the agent's own events.
146            // Skip: other agents' events entirely (model-role, function calls,
147            // and function/tool responses). This prevents the sub-agent from
148            // seeing orphaned function responses without their preceding calls.
149            if let Some(name) = agent_name {
150                if event.author != "user" && event.author != name {
151                    continue;
152                }
153            }
154
155            if let Some(content) = &event.llm_response.content {
156                let mut mapped_content = content.clone();
157                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
158                    ("user", _) => "user",
159                    (_, "function" | "tool") => content.role.as_str(),
160                    _ => "model",
161                }
162                .to_string();
163                history.push(mapped_content);
164            }
165        }
166
167        history
168    }
169}
170
171impl adk_core::Session for MutableSession {
172    fn id(&self) -> &str {
173        self.inner.id()
174    }
175
176    fn app_name(&self) -> &str {
177        self.inner.app_name()
178    }
179
180    fn user_id(&self) -> &str {
181        self.inner.user_id()
182    }
183
184    fn state(&self) -> &dyn adk_core::State {
185        self
186    }
187
188    fn conversation_history(&self) -> Vec<adk_core::Content> {
189        self.conversation_history_for_agent_impl(None)
190    }
191
192    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
193        self.conversation_history_for_agent_impl(Some(agent_name))
194    }
195}
196
197impl adk_core::State for MutableSession {
198    fn get(&self, key: &str) -> Option<serde_json::Value> {
199        let Ok(state) = self.state.read() else {
200            tracing::error!("state RwLock poisoned in State::get — returning None");
201            return None;
202        };
203        state.get(key).cloned()
204    }
205
206    fn set(&mut self, key: String, value: serde_json::Value) {
207        if let Err(msg) = adk_core::validate_state_key(&key) {
208            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
209            return;
210        }
211        let Ok(mut state) = self.state.write() else {
212            tracing::error!("state RwLock poisoned in State::set — value dropped");
213            return;
214        };
215        state.insert(key, value);
216    }
217
218    fn all(&self) -> HashMap<String, serde_json::Value> {
219        let Ok(state) = self.state.read() else {
220            tracing::error!("state RwLock poisoned in State::all — returning empty");
221            return HashMap::new();
222        };
223        state.clone()
224    }
225}
226
227/// Runtime context for a single agent invocation.
228///
229/// Holds the agent, session, identity, and optional services (artifacts, memory,
230/// secrets) needed during execution. Created by the [`Runner`](crate::Runner) for
231/// each `run()` call.
232pub struct InvocationContext {
233    identity: ExecutionIdentity,
234    agent: Arc<dyn Agent>,
235    user_content: Content,
236    artifacts: Option<Arc<dyn Artifacts>>,
237    memory: Option<Arc<dyn Memory>>,
238    run_config: RunConfig,
239    ended: Arc<AtomicBool>,
240    /// Mutable session that allows state to be updated during execution.
241    /// This is shared across all agents in a workflow, enabling state
242    /// propagation between sequential/parallel agents.
243    session: Arc<MutableSession>,
244    /// Optional request context from the server's auth middleware bridge.
245    /// When present, `user_id()` returns `request_context.user_id` and
246    /// `user_scopes()` returns `request_context.scopes`.
247    request_context: Option<RequestContext>,
248    /// Optional shared state for parallel agent coordination.
249    shared_state: Option<Arc<adk_core::SharedState>>,
250    /// Optional secret service for retrieving secrets at runtime.
251    /// When present, `get_secret()` delegates to this service.
252    secret_service: Option<Arc<dyn SecretService>>,
253}
254
255impl InvocationContext {
256    /// Create a new invocation context from validated typed identifiers.
257    pub fn new_typed(
258        invocation_id: String,
259        agent: Arc<dyn Agent>,
260        user_id: UserId,
261        app_name: AppName,
262        session_id: SessionId,
263        user_content: Content,
264        session: Arc<dyn AdkSession>,
265    ) -> adk_core::Result<Self> {
266        let identity = ExecutionIdentity {
267            adk: AdkIdentity { app_name, user_id, session_id },
268            invocation_id: InvocationId::try_from(invocation_id)?,
269            branch: String::new(),
270            agent_name: agent.name().to_string(),
271        };
272        Ok(Self {
273            identity,
274            agent,
275            user_content,
276            artifacts: None,
277            memory: None,
278            run_config: RunConfig::default(),
279            ended: Arc::new(AtomicBool::new(false)),
280            session: Arc::new(MutableSession::new(session)),
281            request_context: None,
282            shared_state: None,
283            secret_service: None,
284        })
285    }
286
287    /// Create a new invocation context from raw string identifiers.
288    ///
289    /// Validates and converts the string identifiers into typed wrappers.
290    /// Prefer [`new_typed`](Self::new_typed) when you already have validated types.
291    pub fn new(
292        invocation_id: String,
293        agent: Arc<dyn Agent>,
294        user_id: String,
295        app_name: String,
296        session_id: String,
297        user_content: Content,
298        session: Arc<dyn AdkSession>,
299    ) -> adk_core::Result<Self> {
300        Self::new_typed(
301            invocation_id,
302            agent,
303            UserId::try_from(user_id)?,
304            AppName::try_from(app_name)?,
305            SessionId::try_from(session_id)?,
306            user_content,
307            session,
308        )
309    }
310
311    /// Create an invocation context that reuses an existing mutable session and
312    /// validated typed identifiers.
313    pub fn with_mutable_session_typed(
314        invocation_id: String,
315        agent: Arc<dyn Agent>,
316        user_id: UserId,
317        app_name: AppName,
318        session_id: SessionId,
319        user_content: Content,
320        session: Arc<MutableSession>,
321    ) -> adk_core::Result<Self> {
322        let identity = ExecutionIdentity {
323            adk: AdkIdentity { app_name, user_id, session_id },
324            invocation_id: InvocationId::try_from(invocation_id)?,
325            branch: String::new(),
326            agent_name: agent.name().to_string(),
327        };
328        Ok(Self {
329            identity,
330            agent,
331            user_content,
332            artifacts: None,
333            memory: None,
334            run_config: RunConfig::default(),
335            ended: Arc::new(AtomicBool::new(false)),
336            session,
337            request_context: None,
338            shared_state: None,
339            secret_service: None,
340        })
341    }
342
343    /// Create an InvocationContext with an existing MutableSession.
344    /// This allows sharing the same mutable session across multiple contexts
345    /// (e.g., for agent transfers).
346    pub fn with_mutable_session(
347        invocation_id: String,
348        agent: Arc<dyn Agent>,
349        user_id: String,
350        app_name: String,
351        session_id: String,
352        user_content: Content,
353        session: Arc<MutableSession>,
354    ) -> adk_core::Result<Self> {
355        Self::with_mutable_session_typed(
356            invocation_id,
357            agent,
358            UserId::try_from(user_id)?,
359            AppName::try_from(app_name)?,
360            SessionId::try_from(session_id)?,
361            user_content,
362            session,
363        )
364    }
365
366    /// Set the event branch identifier for this context.
367    pub fn with_branch(mut self, branch: String) -> Self {
368        self.identity.branch = branch;
369        self
370    }
371
372    /// Attach an artifact storage service to this context.
373    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
374        self.artifacts = Some(artifacts);
375        self
376    }
377
378    /// Attach a memory service for RAG/semantic search.
379    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
380        self.memory = Some(memory);
381        self
382    }
383
384    /// Set the run configuration (streaming mode, history limits, etc.).
385    pub fn with_run_config(mut self, config: RunConfig) -> Self {
386        self.run_config = config;
387        self
388    }
389
390    /// Set the request context from the server's auth middleware bridge.
391    ///
392    /// When set, `user_id()` returns `request_context.user_id` (overriding
393    /// the session-scoped identity), and `user_scopes()` returns
394    /// `request_context.scopes`. This is the explicit authenticated user
395    /// override — `RequestContext` remains separate from `ExecutionIdentity`
396    /// and `AdkIdentity` (it does not carry session or invocation IDs).
397    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
398        self.request_context = Some(ctx);
399        self
400    }
401
402    /// Set the shared state for parallel agent coordination.
403    pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
404        self.shared_state = Some(shared);
405        self
406    }
407
408    /// Set the secret service for runtime secret retrieval.
409    ///
410    /// When configured, tools can call `ctx.get_secret("name")` to retrieve
411    /// secrets from the configured provider (e.g., AWS Secrets Manager,
412    /// Azure Key Vault, GCP Secret Manager).
413    pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
414        self.secret_service = Some(service);
415        self
416    }
417
418    /// Get a reference to the mutable session.
419    /// This allows the Runner to apply state deltas when events are processed.
420    pub fn mutable_session(&self) -> &Arc<MutableSession> {
421        &self.session
422    }
423}
424
425#[async_trait]
426impl ReadonlyContext for InvocationContext {
427    fn invocation_id(&self) -> &str {
428        self.identity.invocation_id.as_ref()
429    }
430
431    fn agent_name(&self) -> &str {
432        self.agent.name()
433    }
434
435    fn user_id(&self) -> &str {
436        // Explicit authenticated user override: when a RequestContext is
437        // present (set via with_request_context from the auth middleware
438        // bridge), the authenticated user_id takes precedence over the
439        // session-scoped identity. This keeps auth binding explicit and
440        // ensures the runtime reflects the verified caller identity.
441        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
442    }
443
444    fn app_name(&self) -> &str {
445        self.identity.adk.app_name.as_ref()
446    }
447
448    fn session_id(&self) -> &str {
449        self.identity.adk.session_id.as_ref()
450    }
451
452    fn branch(&self) -> &str {
453        &self.identity.branch
454    }
455
456    fn user_content(&self) -> &Content {
457        &self.user_content
458    }
459}
460
461#[async_trait]
462impl CallbackContext for InvocationContext {
463    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
464        self.artifacts.clone()
465    }
466
467    fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
468        self.shared_state.clone()
469    }
470}
471
472#[async_trait]
473impl InvocationContextTrait for InvocationContext {
474    fn agent(&self) -> Arc<dyn Agent> {
475        self.agent.clone()
476    }
477
478    fn memory(&self) -> Option<Arc<dyn Memory>> {
479        self.memory.clone()
480    }
481
482    fn session(&self) -> &dyn adk_core::Session {
483        self.session.as_ref()
484    }
485
486    fn run_config(&self) -> &RunConfig {
487        &self.run_config
488    }
489
490    fn end_invocation(&self) {
491        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
492    }
493
494    fn ended(&self) -> bool {
495        self.ended.load(std::sync::atomic::Ordering::SeqCst)
496    }
497
498    fn user_scopes(&self) -> Vec<String> {
499        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
500    }
501
502    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
503        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
504            rc.metadata
505                .iter()
506                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
507                .collect()
508        })
509    }
510
511    async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
512        match &self.secret_service {
513            Some(service) => service.get_secret(name).await.map(Some),
514            None => Ok(None),
515        }
516    }
517}