Skip to main content

adk_runner/
context.rs

1use adk_core::{
2    AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3    InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4    RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11/// MutableSession wraps a session with shared mutable state.
12///
13/// This mirrors ADK-Go's MutableSession pattern where state changes from
14/// events are immediately visible to all agents sharing the same context.
15/// This is critical for SequentialAgent/LoopAgent patterns where downstream
16/// agents need to read state set by upstream agents via output_key.
17pub struct MutableSession {
18    /// The original session snapshot (for metadata like id, app_name, user_id)
19    inner: Arc<dyn AdkSession>,
20    /// Shared mutable state - updated when events are processed
21    /// This is the key difference from the old SessionAdapter which used immutable snapshots
22    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23    /// Accumulated events during this invocation (uses adk_core::Event which is re-exported by adk_session)
24    events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28    /// Create a new MutableSession from a session snapshot.
29    /// The state is copied from the session and becomes mutable.
30    pub fn new(session: Arc<dyn AdkSession>) -> Self {
31        // Clone the initial state from the session
32        let initial_state = session.state().all();
33        // Clone the initial events
34        let initial_events = session.events().all();
35
36        Self {
37            inner: session,
38            state: Arc::new(RwLock::new(initial_state)),
39            events: Arc::new(RwLock::new(initial_events)),
40        }
41    }
42
43    /// Apply state delta from an event to the mutable state.
44    /// This is called by the Runner when events are yielded.
45    pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46        if delta.is_empty() {
47            return;
48        }
49
50        let Ok(mut state) = self.state.write() else {
51            tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52            return;
53        };
54        for (key, value) in delta {
55            // Skip temp: prefixed keys (they shouldn't persist)
56            if !key.starts_with("temp:") {
57                state.insert(key.clone(), value.clone());
58            }
59        }
60    }
61
62    /// Append an event to the session's event list.
63    /// This keeps the in-memory view consistent.
64    pub fn append_event(&self, event: Event) {
65        let Ok(mut events) = self.events.write() else {
66            tracing::error!("events RwLock poisoned in append_event — event dropped");
67            return;
68        };
69        events.push(event);
70    }
71
72    /// Get a snapshot of all events in the session.
73    /// Used by the runner for compaction decisions.
74    pub fn events_snapshot(&self) -> Vec<Event> {
75        let Ok(events) = self.events.read() else {
76            tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77            return Vec::new();
78        };
79        events.clone()
80    }
81
82    /// Replace all events with a new list (used by intra-invocation compaction).
83    pub fn replace_events(&self, new_events: Vec<Event>) {
84        let Ok(mut events) = self.events.write() else {
85            tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86            return;
87        };
88        *events = new_events;
89    }
90
91    /// Return the number of accumulated events without cloning the full list.
92    pub fn events_len(&self) -> usize {
93        let Ok(events) = self.events.read() else {
94            tracing::error!("events RwLock poisoned in events_len — returning 0");
95            return 0;
96        };
97        events.len()
98    }
99
100    /// Build conversation history, optionally filtered for a specific agent.
101    ///
102    /// When `agent_name` is `Some`, events authored by other agents (not "user",
103    /// not the named agent, and not function/tool responses) are excluded. This
104    /// prevents a transferred sub-agent from seeing the parent's tool calls
105    /// mapped as "model" role, which would cause the LLM to think work is
106    /// already done.
107    ///
108    /// When `agent_name` is `None`, all events are included (backward-compatible).
109    pub fn conversation_history_for_agent_impl(
110        &self,
111        agent_name: Option<&str>,
112    ) -> Vec<adk_core::Content> {
113        let Ok(events) = self.events.read() else {
114            tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115            return Vec::new();
116        };
117        let mut history = Vec::new();
118
119        // Find the most recent compaction event — everything before its
120        // end_timestamp has been summarized and should be replaced by the
121        // compacted content.
122        let mut compaction_boundary = None;
123        for event in events.iter().rev() {
124            if let Some(ref compaction) = event.actions.compaction {
125                history.push(compaction.compacted_content.clone());
126                compaction_boundary = Some(compaction.end_timestamp);
127                break;
128            }
129        }
130
131        for event in events.iter() {
132            // Skip the compaction event itself
133            if event.actions.compaction.is_some() {
134                continue;
135            }
136
137            // Skip events that were already compacted
138            if let Some(boundary) = compaction_boundary {
139                if event.timestamp <= boundary {
140                    continue;
141                }
142            }
143
144            // When filtering for a specific agent, skip events from other agents.
145            // Keep: user messages and the agent's own events.
146            // Skip: other agents' events entirely (model-role, function calls,
147            // and function/tool responses). This prevents the sub-agent from
148            // seeing orphaned function responses without their preceding calls.
149            if let Some(name) = agent_name {
150                if event.author != "user" && event.author != name {
151                    continue;
152                }
153            }
154
155            if let Some(content) = &event.llm_response.content {
156                let mut mapped_content = content.clone();
157                mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
158                    ("user", _) => "user",
159                    (_, "function" | "tool") => content.role.as_str(),
160                    _ => "model",
161                }
162                .to_string();
163                history.push(mapped_content);
164            }
165        }
166
167        history
168    }
169}
170
171impl adk_core::Session for MutableSession {
172    fn id(&self) -> &str {
173        self.inner.id()
174    }
175
176    fn app_name(&self) -> &str {
177        self.inner.app_name()
178    }
179
180    fn user_id(&self) -> &str {
181        self.inner.user_id()
182    }
183
184    fn state(&self) -> &dyn adk_core::State {
185        self
186    }
187
188    fn conversation_history(&self) -> Vec<adk_core::Content> {
189        self.conversation_history_for_agent_impl(None)
190    }
191
192    fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
193        self.conversation_history_for_agent_impl(Some(agent_name))
194    }
195}
196
197impl adk_core::State for MutableSession {
198    fn get(&self, key: &str) -> Option<serde_json::Value> {
199        let Ok(state) = self.state.read() else {
200            tracing::error!("state RwLock poisoned in State::get — returning None");
201            return None;
202        };
203        state.get(key).cloned()
204    }
205
206    fn set(&mut self, key: String, value: serde_json::Value) {
207        if let Err(msg) = adk_core::validate_state_key(&key) {
208            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
209            return;
210        }
211        let Ok(mut state) = self.state.write() else {
212            tracing::error!("state RwLock poisoned in State::set — value dropped");
213            return;
214        };
215        state.insert(key, value);
216    }
217
218    fn all(&self) -> HashMap<String, serde_json::Value> {
219        let Ok(state) = self.state.read() else {
220            tracing::error!("state RwLock poisoned in State::all — returning empty");
221            return HashMap::new();
222        };
223        state.clone()
224    }
225}
226
227pub struct InvocationContext {
228    identity: ExecutionIdentity,
229    agent: Arc<dyn Agent>,
230    user_content: Content,
231    artifacts: Option<Arc<dyn Artifacts>>,
232    memory: Option<Arc<dyn Memory>>,
233    run_config: RunConfig,
234    ended: Arc<AtomicBool>,
235    /// Mutable session that allows state to be updated during execution.
236    /// This is shared across all agents in a workflow, enabling state
237    /// propagation between sequential/parallel agents.
238    session: Arc<MutableSession>,
239    /// Optional request context from the server's auth middleware bridge.
240    /// When present, `user_id()` returns `request_context.user_id` and
241    /// `user_scopes()` returns `request_context.scopes`.
242    request_context: Option<RequestContext>,
243    /// Optional shared state for parallel agent coordination.
244    shared_state: Option<Arc<adk_core::SharedState>>,
245    /// Optional secret service for retrieving secrets at runtime.
246    /// When present, `get_secret()` delegates to this service.
247    secret_service: Option<Arc<dyn SecretService>>,
248}
249
250impl InvocationContext {
251    /// Create a new invocation context from validated typed identifiers.
252    pub fn new_typed(
253        invocation_id: String,
254        agent: Arc<dyn Agent>,
255        user_id: UserId,
256        app_name: AppName,
257        session_id: SessionId,
258        user_content: Content,
259        session: Arc<dyn AdkSession>,
260    ) -> adk_core::Result<Self> {
261        let identity = ExecutionIdentity {
262            adk: AdkIdentity { app_name, user_id, session_id },
263            invocation_id: InvocationId::try_from(invocation_id)?,
264            branch: String::new(),
265            agent_name: agent.name().to_string(),
266        };
267        Ok(Self {
268            identity,
269            agent,
270            user_content,
271            artifacts: None,
272            memory: None,
273            run_config: RunConfig::default(),
274            ended: Arc::new(AtomicBool::new(false)),
275            session: Arc::new(MutableSession::new(session)),
276            request_context: None,
277            shared_state: None,
278            secret_service: None,
279        })
280    }
281
282    pub fn new(
283        invocation_id: String,
284        agent: Arc<dyn Agent>,
285        user_id: String,
286        app_name: String,
287        session_id: String,
288        user_content: Content,
289        session: Arc<dyn AdkSession>,
290    ) -> adk_core::Result<Self> {
291        Self::new_typed(
292            invocation_id,
293            agent,
294            UserId::try_from(user_id)?,
295            AppName::try_from(app_name)?,
296            SessionId::try_from(session_id)?,
297            user_content,
298            session,
299        )
300    }
301
302    /// Create an invocation context that reuses an existing mutable session and
303    /// validated typed identifiers.
304    pub fn with_mutable_session_typed(
305        invocation_id: String,
306        agent: Arc<dyn Agent>,
307        user_id: UserId,
308        app_name: AppName,
309        session_id: SessionId,
310        user_content: Content,
311        session: Arc<MutableSession>,
312    ) -> adk_core::Result<Self> {
313        let identity = ExecutionIdentity {
314            adk: AdkIdentity { app_name, user_id, session_id },
315            invocation_id: InvocationId::try_from(invocation_id)?,
316            branch: String::new(),
317            agent_name: agent.name().to_string(),
318        };
319        Ok(Self {
320            identity,
321            agent,
322            user_content,
323            artifacts: None,
324            memory: None,
325            run_config: RunConfig::default(),
326            ended: Arc::new(AtomicBool::new(false)),
327            session,
328            request_context: None,
329            shared_state: None,
330            secret_service: None,
331        })
332    }
333
334    /// Create an InvocationContext with an existing MutableSession.
335    /// This allows sharing the same mutable session across multiple contexts
336    /// (e.g., for agent transfers).
337    pub fn with_mutable_session(
338        invocation_id: String,
339        agent: Arc<dyn Agent>,
340        user_id: String,
341        app_name: String,
342        session_id: String,
343        user_content: Content,
344        session: Arc<MutableSession>,
345    ) -> adk_core::Result<Self> {
346        Self::with_mutable_session_typed(
347            invocation_id,
348            agent,
349            UserId::try_from(user_id)?,
350            AppName::try_from(app_name)?,
351            SessionId::try_from(session_id)?,
352            user_content,
353            session,
354        )
355    }
356
357    pub fn with_branch(mut self, branch: String) -> Self {
358        self.identity.branch = branch;
359        self
360    }
361
362    pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
363        self.artifacts = Some(artifacts);
364        self
365    }
366
367    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
368        self.memory = Some(memory);
369        self
370    }
371
372    pub fn with_run_config(mut self, config: RunConfig) -> Self {
373        self.run_config = config;
374        self
375    }
376
377    /// Set the request context from the server's auth middleware bridge.
378    ///
379    /// When set, `user_id()` returns `request_context.user_id` (overriding
380    /// the session-scoped identity), and `user_scopes()` returns
381    /// `request_context.scopes`. This is the explicit authenticated user
382    /// override — `RequestContext` remains separate from `ExecutionIdentity`
383    /// and `AdkIdentity` (it does not carry session or invocation IDs).
384    pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
385        self.request_context = Some(ctx);
386        self
387    }
388
389    /// Set the shared state for parallel agent coordination.
390    pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
391        self.shared_state = Some(shared);
392        self
393    }
394
395    /// Set the secret service for runtime secret retrieval.
396    ///
397    /// When configured, tools can call `ctx.get_secret("name")` to retrieve
398    /// secrets from the configured provider (e.g., AWS Secrets Manager,
399    /// Azure Key Vault, GCP Secret Manager).
400    pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
401        self.secret_service = Some(service);
402        self
403    }
404
405    /// Get a reference to the mutable session.
406    /// This allows the Runner to apply state deltas when events are processed.
407    pub fn mutable_session(&self) -> &Arc<MutableSession> {
408        &self.session
409    }
410}
411
412#[async_trait]
413impl ReadonlyContext for InvocationContext {
414    fn invocation_id(&self) -> &str {
415        self.identity.invocation_id.as_ref()
416    }
417
418    fn agent_name(&self) -> &str {
419        self.agent.name()
420    }
421
422    fn user_id(&self) -> &str {
423        // Explicit authenticated user override: when a RequestContext is
424        // present (set via with_request_context from the auth middleware
425        // bridge), the authenticated user_id takes precedence over the
426        // session-scoped identity. This keeps auth binding explicit and
427        // ensures the runtime reflects the verified caller identity.
428        self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
429    }
430
431    fn app_name(&self) -> &str {
432        self.identity.adk.app_name.as_ref()
433    }
434
435    fn session_id(&self) -> &str {
436        self.identity.adk.session_id.as_ref()
437    }
438
439    fn branch(&self) -> &str {
440        &self.identity.branch
441    }
442
443    fn user_content(&self) -> &Content {
444        &self.user_content
445    }
446}
447
448#[async_trait]
449impl CallbackContext for InvocationContext {
450    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
451        self.artifacts.clone()
452    }
453
454    fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
455        self.shared_state.clone()
456    }
457}
458
459#[async_trait]
460impl InvocationContextTrait for InvocationContext {
461    fn agent(&self) -> Arc<dyn Agent> {
462        self.agent.clone()
463    }
464
465    fn memory(&self) -> Option<Arc<dyn Memory>> {
466        self.memory.clone()
467    }
468
469    fn session(&self) -> &dyn adk_core::Session {
470        self.session.as_ref()
471    }
472
473    fn run_config(&self) -> &RunConfig {
474        &self.run_config
475    }
476
477    fn end_invocation(&self) {
478        self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
479    }
480
481    fn ended(&self) -> bool {
482        self.ended.load(std::sync::atomic::Ordering::SeqCst)
483    }
484
485    fn user_scopes(&self) -> Vec<String> {
486        self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
487    }
488
489    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
490        self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
491            rc.metadata
492                .iter()
493                .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
494                .collect()
495        })
496    }
497
498    async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
499        match &self.secret_service {
500            Some(service) => service.get_secret(name).await.map(Some),
501            None => Ok(None),
502        }
503    }
504}