Skip to main content

adk_runner/
context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let Ok(mut state) = self.state.write() else {
51            tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52            return;
53        };
54        for (key, value) in delta {
55            // Skip temp: prefixed keys (they shouldn't persist)
56            if !key.starts_with("temp:") {
57                state.insert(key.clone(), value.clone());
58            }
59        }
60    }
61
62    /// Append an event to the session's event list.
63    /// This keeps the in-memory view consistent.
64    pub fn append_event(&self, event: Event) {
65        let Ok(mut events) = self.events.write() else {
66            tracing::error!("events RwLock poisoned in append_event — event dropped");
67            return;
68        };
69        events.push(event);
70    }
71
72    /// Get a snapshot of all events in the session.
73    /// Used by the runner for compaction decisions.
74    pub fn events_snapshot(&self) -> Vec<Event> {
75        let Ok(events) = self.events.read() else {
76            tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77            return Vec::new();
78        };
79        events.clone()
80    }
81
82    /// Replace all events with a new list (used by intra-invocation compaction).
83    pub fn replace_events(&self, new_events: Vec<Event>) {
84        let Ok(mut events) = self.events.write() else {
85            tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86            return;
87        };
88        *events = new_events;
89    }
90
91    /// Return the number of accumulated events without cloning the full list.
92    pub fn events_len(&self) -> usize {
93        let Ok(events) = self.events.read() else {
94            tracing::error!("events RwLock poisoned in events_len — returning 0");
95            return 0;
96        };
97        events.len()
98    }
99
100    /// Build conversation history, optionally filtered for a specific agent.
101    ///
102    /// When `agent_name` is `Some`, events authored by other agents (not "user",
103    /// not the named agent, and not function/tool responses) are excluded. This
104    /// prevents a transferred sub-agent from seeing the parent's tool calls
105    /// mapped as "model" role, which would cause the LLM to think work is
106    /// already done.
107    ///
108    /// When `agent_name` is `None`, all events are included (backward-compatible).
109    pub fn conversation_history_for_agent_impl(
110        &self,
111        agent_name: Option<&str>,
112    ) -> Vec<adk_core::Content> {
113        let Ok(events) = self.events.read() else {
114            tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115            return Vec::new();
116        };
117        let mut history = Vec::new();
118
119        // Find the most recent compaction event — everything before its
120        // end_timestamp has been summarized and should be replaced by the
121        // compacted content.
122        let mut compaction_boundary = None;
123        for event in events.iter().rev() {
124            if let Some(ref compaction) = event.actions.compaction {
125                history.push(compaction.compacted_content.clone());
126                compaction_boundary = Some(compaction.end_timestamp);
127                break;
128            }
129        }
130
131        for event in events.iter() {
132            // Skip the compaction event itself
133            if event.actions.compaction.is_some() {
134                continue;
135            }
136
137            // Skip events that were already compacted
138            if let Some(boundary) = compaction_boundary
139                && event.timestamp <= boundary
140            {
141                continue;
142            }
143
144            // When filtering for a specific agent, skip events from other agents.
145            // Keep: user messages and the agent's own events.
146            // Skip: other agents' events entirely (model-role, function calls,
147            // and function/tool responses). This prevents the sub-agent from
148            // seeing orphaned function responses without their preceding calls.
149            if let Some(name) = agent_name
150                && event.author != "user"
151                && event.author != name
152            {
153                continue;
154            }
155
156            if let Some(content) = &event.llm_response.content {
157                let mut mapped_content = content.clone();
158                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
159                    ("user", _) => "user",
160                    (_, "function" | "tool") => content.role.as_str(),
161                    _ => "model",
162                }
163                .to_string();
164                history.push(mapped_content);
165            }
166        }
167
168        history
169    }
170}
171
172impl adk_core::Session for MutableSession {
173    fn id(&self) -> &str {
174        self.inner.id()
175    }
176
177    fn app_name(&self) -> &str {
178        self.inner.app_name()
179    }
180
181    fn user_id(&self) -> &str {
182        self.inner.user_id()
183    }
184
185    fn state(&self) -> &dyn adk_core::State {
186        self
187    }
188
189    fn conversation_history(&self) -> Vec<adk_core::Content> {
190        self.conversation_history_for_agent_impl(None)
191    }
192
193    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
194        self.conversation_history_for_agent_impl(Some(agent_name))
195    }
196}
197
198impl adk_core::State for MutableSession {
199    fn get(&self, key: &str) -> Option<serde_json::Value> {
200        let Ok(state) = self.state.read() else {
201            tracing::error!("state RwLock poisoned in State::get — returning None");
202            return None;
203        };
204        state.get(key).cloned()
205    }
206
207    fn set(&mut self, key: String, value: serde_json::Value) {
208        if let Err(msg) = adk_core::validate_state_key(&key) {
209            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
210            return;
211        }
212        let Ok(mut state) = self.state.write() else {
213            tracing::error!("state RwLock poisoned in State::set — value dropped");
214            return;
215        };
216        state.insert(key, value);
217    }
218
219    fn all(&self) -> HashMap<String, serde_json::Value> {
220        let Ok(state) = self.state.read() else {
221            tracing::error!("state RwLock poisoned in State::all — returning empty");
222            return HashMap::new();
223        };
224        state.clone()
225    }
226}
227
228/// Runtime context for a single agent invocation.
229///
230/// Holds the agent, session, identity, and optional services (artifacts, memory,
231/// secrets) needed during execution. Created by the [`Runner`](crate::Runner) for
232/// each `run()` call.
233pub struct InvocationContext {
234    identity: ExecutionIdentity,
235    agent: Arc<dyn Agent>,
236    user_content: Content,
237    artifacts: Option<Arc<dyn Artifacts>>,
238    memory: Option<Arc<dyn Memory>>,
239    run_config: RunConfig,
240    ended: Arc<AtomicBool>,
241    /// Mutable session that allows state to be updated during execution.
242    /// This is shared across all agents in a workflow, enabling state
243    /// propagation between sequential/parallel agents.
244    session: Arc<MutableSession>,
245    /// Optional request context from the server's auth middleware bridge.
246    /// When present, `user_id()` returns `request_context.user_id` and
247    /// `user_scopes()` returns `request_context.scopes`.
248    request_context: Option<RequestContext>,
249    /// Optional shared state for parallel agent coordination.
250    shared_state: Option<Arc<adk_core::SharedState>>,
251    /// Optional secret service for retrieving secrets at runtime.
252    /// When present, `get_secret()` delegates to this service.
253    secret_service: Option<Arc<dyn SecretService>>,
254}
255
256impl InvocationContext {
257    /// Create a new invocation context from validated typed identifiers.
258    pub fn new_typed(
259        invocation_id: String,
260        agent: Arc<dyn Agent>,
261        user_id: UserId,
262        app_name: AppName,
263        session_id: SessionId,
264        user_content: Content,
265        session: Arc<dyn AdkSession>,
266    ) -> adk_core::Result<Self> {
267        let identity = ExecutionIdentity {
268            adk: AdkIdentity { app_name, user_id, session_id },
269            invocation_id: InvocationId::try_from(invocation_id)?,
270            branch: String::new(),
271            agent_name: agent.name().to_string(),
272        };
273        Ok(Self {
274            identity,
275            agent,
276            user_content,
277            artifacts: None,
278            memory: None,
279            run_config: RunConfig::default(),
280            ended: Arc::new(AtomicBool::new(false)),
281            session: Arc::new(MutableSession::new(session)),
282            request_context: None,
283            shared_state: None,
284            secret_service: None,
285        })
286    }
287
288    /// Create a new invocation context from raw string identifiers.
289    ///
290    /// Validates and converts the string identifiers into typed wrappers.
291    /// Prefer [`new_typed`](Self::new_typed) when you already have validated types.
292    pub fn new(
293        invocation_id: String,
294        agent: Arc<dyn Agent>,
295        user_id: String,
296        app_name: String,
297        session_id: String,
298        user_content: Content,
299        session: Arc<dyn AdkSession>,
300    ) -> adk_core::Result<Self> {
301        Self::new_typed(
302            invocation_id,
303            agent,
304            UserId::try_from(user_id)?,
305            AppName::try_from(app_name)?,
306            SessionId::try_from(session_id)?,
307            user_content,
308            session,
309        )
310    }
311
312    /// Create an invocation context that reuses an existing mutable session and
313    /// validated typed identifiers.
314    pub fn with_mutable_session_typed(
315        invocation_id: String,
316        agent: Arc<dyn Agent>,
317        user_id: UserId,
318        app_name: AppName,
319        session_id: SessionId,
320        user_content: Content,
321        session: Arc<MutableSession>,
322    ) -> adk_core::Result<Self> {
323        let identity = ExecutionIdentity {
324            adk: AdkIdentity { app_name, user_id, session_id },
325            invocation_id: InvocationId::try_from(invocation_id)?,
326            branch: String::new(),
327            agent_name: agent.name().to_string(),
328        };
329        Ok(Self {
330            identity,
331            agent,
332            user_content,
333            artifacts: None,
334            memory: None,
335            run_config: RunConfig::default(),
336            ended: Arc::new(AtomicBool::new(false)),
337            session,
338            request_context: None,
339            shared_state: None,
340            secret_service: None,
341        })
342    }
343
344    /// Create an InvocationContext with an existing MutableSession.
345    /// This allows sharing the same mutable session across multiple contexts
346    /// (e.g., for agent transfers).
347    pub fn with_mutable_session(
348        invocation_id: String,
349        agent: Arc<dyn Agent>,
350        user_id: String,
351        app_name: String,
352        session_id: String,
353        user_content: Content,
354        session: Arc<MutableSession>,
355    ) -> adk_core::Result<Self> {
356        Self::with_mutable_session_typed(
357            invocation_id,
358            agent,
359            UserId::try_from(user_id)?,
360            AppName::try_from(app_name)?,
361            SessionId::try_from(session_id)?,
362            user_content,
363            session,
364        )
365    }
366
367    /// Set the event branch identifier for this context.
368    pub fn with_branch(mut self, branch: String) -> Self {
369        self.identity.branch = branch;
370        self
371    }
372
373    /// Attach an artifact storage service to this context.
374    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
375        self.artifacts = Some(artifacts);
376        self
377    }
378
379    /// Attach a memory service for RAG/semantic search.
380    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
381        self.memory = Some(memory);
382        self
383    }
384
385    /// Set the run configuration (streaming mode, history limits, etc.).
386    pub fn with_run_config(mut self, config: RunConfig) -> Self {
387        self.run_config = config;
388        self
389    }
390
391    /// Set the request context from the server's auth middleware bridge.
392    ///
393    /// When set, `user_id()` returns `request_context.user_id` (overriding
394    /// the session-scoped identity), and `user_scopes()` returns
395    /// `request_context.scopes`. This is the explicit authenticated user
396    /// override — `RequestContext` remains separate from `ExecutionIdentity`
397    /// and `AdkIdentity` (it does not carry session or invocation IDs).
398    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
399        self.request_context = Some(ctx);
400        self
401    }
402
403    /// Set the shared state for parallel agent coordination.
404    pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
405        self.shared_state = Some(shared);
406        self
407    }
408
409    /// Set the secret service for runtime secret retrieval.
410    ///
411    /// When configured, tools can call `ctx.get_secret("name")` to retrieve
412    /// secrets from the configured provider (e.g., AWS Secrets Manager,
413    /// Azure Key Vault, GCP Secret Manager).
414    pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
415        self.secret_service = Some(service);
416        self
417    }
418
419    /// Get a reference to the mutable session.
420    /// This allows the Runner to apply state deltas when events are processed.
421    pub fn mutable_session(&self) -> &Arc<MutableSession> {
422        &self.session
423    }
424}
425
426#[async_trait]
427impl ReadonlyContext for InvocationContext {
428    fn invocation_id(&self) -> &str {
429        self.identity.invocation_id.as_ref()
430    }
431
432    fn agent_name(&self) -> &str {
433        self.agent.name()
434    }
435
436    fn user_id(&self) -> &str {
437        // Explicit authenticated user override: when a RequestContext is
438        // present (set via with_request_context from the auth middleware
439        // bridge), the authenticated user_id takes precedence over the
440        // session-scoped identity. This keeps auth binding explicit and
441        // ensures the runtime reflects the verified caller identity.
442        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
443    }
444
445    fn app_name(&self) -> &str {
446        self.identity.adk.app_name.as_ref()
447    }
448
449    fn session_id(&self) -> &str {
450        self.identity.adk.session_id.as_ref()
451    }
452
453    fn branch(&self) -> &str {
454        &self.identity.branch
455    }
456
457    fn user_content(&self) -> &Content {
458        &self.user_content
459    }
460}
461
462#[async_trait]
463impl CallbackContext for InvocationContext {
464    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
465        self.artifacts.clone()
466    }
467
468    fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
469        self.shared_state.clone()
470    }
471}
472
473#[async_trait]
474impl InvocationContextTrait for InvocationContext {
475    fn agent(&self) -> Arc<dyn Agent> {
476        self.agent.clone()
477    }
478
479    fn memory(&self) -> Option<Arc<dyn Memory>> {
480        self.memory.clone()
481    }
482
483    fn session(&self) -> &dyn adk_core::Session {
484        self.session.as_ref()
485    }
486
487    fn run_config(&self) -> &RunConfig {
488        &self.run_config
489    }
490
491    fn end_invocation(&self) {
492        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
493    }
494
495    fn ended(&self) -> bool {
496        self.ended.load(std::sync::atomic::Ordering::SeqCst)
497    }
498
499    fn user_scopes(&self) -> Vec<String> {
500        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
501    }
502
503    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
504        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
505            rc.metadata
506                .iter()
507                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
508                .collect()
509        })
510    }
511
512    async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
513        match &self.secret_service {
514            Some(service) => service.get_secret(name).await.map(Some),
515            None => Ok(None),
516        }
517    }
518}