1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11pub struct MutableSession {
18 inner: Arc<dyn AdkSession>,
20 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23 events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let Ok(mut state) = self.state.write() else {
51 tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52 return;
53 };
54 for (key, value) in delta {
55 if !key.starts_with("temp:") {
57 state.insert(key.clone(), value.clone());
58 }
59 }
60 }
61
62 pub fn append_event(&self, event: Event) {
65 let Ok(mut events) = self.events.write() else {
66 tracing::error!("events RwLock poisoned in append_event — event dropped");
67 return;
68 };
69 events.push(event);
70 }
71
72 pub fn events_snapshot(&self) -> Vec<Event> {
75 let Ok(events) = self.events.read() else {
76 tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77 return Vec::new();
78 };
79 events.clone()
80 }
81
82 pub fn replace_events(&self, new_events: Vec<Event>) {
84 let Ok(mut events) = self.events.write() else {
85 tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86 return;
87 };
88 *events = new_events;
89 }
90
91 pub fn events_len(&self) -> usize {
93 let Ok(events) = self.events.read() else {
94 tracing::error!("events RwLock poisoned in events_len — returning 0");
95 return 0;
96 };
97 events.len()
98 }
99
100 pub fn conversation_history_for_agent_impl(
110 &self,
111 agent_name: Option<&str>,
112 ) -> Vec<adk_core::Content> {
113 let Ok(events) = self.events.read() else {
114 tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115 return Vec::new();
116 };
117 let mut history = Vec::new();
118
119 let mut compaction_boundary = None;
123 for event in events.iter().rev() {
124 if let Some(ref compaction) = event.actions.compaction {
125 history.push(compaction.compacted_content.clone());
126 compaction_boundary = Some(compaction.end_timestamp);
127 break;
128 }
129 }
130
131 for event in events.iter() {
132 if event.actions.compaction.is_some() {
134 continue;
135 }
136
137 if let Some(boundary) = compaction_boundary {
139 if event.timestamp <= boundary {
140 continue;
141 }
142 }
143
144 if let Some(name) = agent_name {
150 if event.author != "user" && event.author != name {
151 continue;
152 }
153 }
154
155 if let Some(content) = &event.llm_response.content {
156 let mut mapped_content = content.clone();
157 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
158 ("user", _) => "user",
159 (_, "function" | "tool") => content.role.as_str(),
160 _ => "model",
161 }
162 .to_string();
163 history.push(mapped_content);
164 }
165 }
166
167 history
168 }
169}
170
171impl adk_core::Session for MutableSession {
172 fn id(&self) -> &str {
173 self.inner.id()
174 }
175
176 fn app_name(&self) -> &str {
177 self.inner.app_name()
178 }
179
180 fn user_id(&self) -> &str {
181 self.inner.user_id()
182 }
183
184 fn state(&self) -> &dyn adk_core::State {
185 self
186 }
187
188 fn conversation_history(&self) -> Vec<adk_core::Content> {
189 self.conversation_history_for_agent_impl(None)
190 }
191
192 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
193 self.conversation_history_for_agent_impl(Some(agent_name))
194 }
195}
196
197impl adk_core::State for MutableSession {
198 fn get(&self, key: &str) -> Option<serde_json::Value> {
199 let Ok(state) = self.state.read() else {
200 tracing::error!("state RwLock poisoned in State::get — returning None");
201 return None;
202 };
203 state.get(key).cloned()
204 }
205
206 fn set(&mut self, key: String, value: serde_json::Value) {
207 if let Err(msg) = adk_core::validate_state_key(&key) {
208 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
209 return;
210 }
211 let Ok(mut state) = self.state.write() else {
212 tracing::error!("state RwLock poisoned in State::set — value dropped");
213 return;
214 };
215 state.insert(key, value);
216 }
217
218 fn all(&self) -> HashMap<String, serde_json::Value> {
219 let Ok(state) = self.state.read() else {
220 tracing::error!("state RwLock poisoned in State::all — returning empty");
221 return HashMap::new();
222 };
223 state.clone()
224 }
225}
226
227pub struct InvocationContext {
233 identity: ExecutionIdentity,
234 agent: Arc<dyn Agent>,
235 user_content: Content,
236 artifacts: Option<Arc<dyn Artifacts>>,
237 memory: Option<Arc<dyn Memory>>,
238 run_config: RunConfig,
239 ended: Arc<AtomicBool>,
240 session: Arc<MutableSession>,
244 request_context: Option<RequestContext>,
248 shared_state: Option<Arc<adk_core::SharedState>>,
250 secret_service: Option<Arc<dyn SecretService>>,
253}
254
255impl InvocationContext {
256 pub fn new_typed(
258 invocation_id: String,
259 agent: Arc<dyn Agent>,
260 user_id: UserId,
261 app_name: AppName,
262 session_id: SessionId,
263 user_content: Content,
264 session: Arc<dyn AdkSession>,
265 ) -> adk_core::Result<Self> {
266 let identity = ExecutionIdentity {
267 adk: AdkIdentity { app_name, user_id, session_id },
268 invocation_id: InvocationId::try_from(invocation_id)?,
269 branch: String::new(),
270 agent_name: agent.name().to_string(),
271 };
272 Ok(Self {
273 identity,
274 agent,
275 user_content,
276 artifacts: None,
277 memory: None,
278 run_config: RunConfig::default(),
279 ended: Arc::new(AtomicBool::new(false)),
280 session: Arc::new(MutableSession::new(session)),
281 request_context: None,
282 shared_state: None,
283 secret_service: None,
284 })
285 }
286
287 pub fn new(
292 invocation_id: String,
293 agent: Arc<dyn Agent>,
294 user_id: String,
295 app_name: String,
296 session_id: String,
297 user_content: Content,
298 session: Arc<dyn AdkSession>,
299 ) -> adk_core::Result<Self> {
300 Self::new_typed(
301 invocation_id,
302 agent,
303 UserId::try_from(user_id)?,
304 AppName::try_from(app_name)?,
305 SessionId::try_from(session_id)?,
306 user_content,
307 session,
308 )
309 }
310
311 pub fn with_mutable_session_typed(
314 invocation_id: String,
315 agent: Arc<dyn Agent>,
316 user_id: UserId,
317 app_name: AppName,
318 session_id: SessionId,
319 user_content: Content,
320 session: Arc<MutableSession>,
321 ) -> adk_core::Result<Self> {
322 let identity = ExecutionIdentity {
323 adk: AdkIdentity { app_name, user_id, session_id },
324 invocation_id: InvocationId::try_from(invocation_id)?,
325 branch: String::new(),
326 agent_name: agent.name().to_string(),
327 };
328 Ok(Self {
329 identity,
330 agent,
331 user_content,
332 artifacts: None,
333 memory: None,
334 run_config: RunConfig::default(),
335 ended: Arc::new(AtomicBool::new(false)),
336 session,
337 request_context: None,
338 shared_state: None,
339 secret_service: None,
340 })
341 }
342
343 pub fn with_mutable_session(
347 invocation_id: String,
348 agent: Arc<dyn Agent>,
349 user_id: String,
350 app_name: String,
351 session_id: String,
352 user_content: Content,
353 session: Arc<MutableSession>,
354 ) -> adk_core::Result<Self> {
355 Self::with_mutable_session_typed(
356 invocation_id,
357 agent,
358 UserId::try_from(user_id)?,
359 AppName::try_from(app_name)?,
360 SessionId::try_from(session_id)?,
361 user_content,
362 session,
363 )
364 }
365
366 pub fn with_branch(mut self, branch: String) -> Self {
368 self.identity.branch = branch;
369 self
370 }
371
372 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
374 self.artifacts = Some(artifacts);
375 self
376 }
377
378 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
380 self.memory = Some(memory);
381 self
382 }
383
384 pub fn with_run_config(mut self, config: RunConfig) -> Self {
386 self.run_config = config;
387 self
388 }
389
390 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
398 self.request_context = Some(ctx);
399 self
400 }
401
402 pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
404 self.shared_state = Some(shared);
405 self
406 }
407
408 pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
414 self.secret_service = Some(service);
415 self
416 }
417
418 pub fn mutable_session(&self) -> &Arc<MutableSession> {
421 &self.session
422 }
423}
424
425#[async_trait]
426impl ReadonlyContext for InvocationContext {
427 fn invocation_id(&self) -> &str {
428 self.identity.invocation_id.as_ref()
429 }
430
431 fn agent_name(&self) -> &str {
432 self.agent.name()
433 }
434
435 fn user_id(&self) -> &str {
436 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
442 }
443
444 fn app_name(&self) -> &str {
445 self.identity.adk.app_name.as_ref()
446 }
447
448 fn session_id(&self) -> &str {
449 self.identity.adk.session_id.as_ref()
450 }
451
452 fn branch(&self) -> &str {
453 &self.identity.branch
454 }
455
456 fn user_content(&self) -> &Content {
457 &self.user_content
458 }
459}
460
461#[async_trait]
462impl CallbackContext for InvocationContext {
463 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
464 self.artifacts.clone()
465 }
466
467 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
468 self.shared_state.clone()
469 }
470}
471
472#[async_trait]
473impl InvocationContextTrait for InvocationContext {
474 fn agent(&self) -> Arc<dyn Agent> {
475 self.agent.clone()
476 }
477
478 fn memory(&self) -> Option<Arc<dyn Memory>> {
479 self.memory.clone()
480 }
481
482 fn session(&self) -> &dyn adk_core::Session {
483 self.session.as_ref()
484 }
485
486 fn run_config(&self) -> &RunConfig {
487 &self.run_config
488 }
489
490 fn end_invocation(&self) {
491 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
492 }
493
494 fn ended(&self) -> bool {
495 self.ended.load(std::sync::atomic::Ordering::SeqCst)
496 }
497
498 fn user_scopes(&self) -> Vec<String> {
499 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
500 }
501
502 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
503 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
504 rc.metadata
505 .iter()
506 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
507 .collect()
508 })
509 }
510
511 async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
512 match &self.secret_service {
513 Some(service) => service.get_secret(name).await.map(Some),
514 None => Ok(None),
515 }
516 }
517}