1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
/// Session wrapper that keeps its own lock-guarded copies of the state map and
/// the event list next to the session it wraps.
///
/// The copies are taken once in [`MutableSession::new`]; later mutations made
/// through this type touch only the copies, not the wrapped session.
pub struct MutableSession {
    /// The wrapped session. `id`, `app_name` and `user_id` are answered by it.
    inner: Arc<dyn AdkSession>,
    /// State map, seeded from `inner` in `new` and updated by
    /// `apply_state_delta` and `State::set`.
    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Event list, seeded from `inner` in `new` and updated by `append_event`
    /// and `replace_events`.
    events: Arc<RwLock<Vec<Event>>>,
}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let Ok(mut state) = self.state.write() else {
51 tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52 return;
53 };
54 for (key, value) in delta {
55 if !key.starts_with("temp:") {
57 state.insert(key.clone(), value.clone());
58 }
59 }
60 }
61
62 pub fn append_event(&self, event: Event) {
65 let Ok(mut events) = self.events.write() else {
66 tracing::error!("events RwLock poisoned in append_event — event dropped");
67 return;
68 };
69 events.push(event);
70 }
71
72 pub fn events_snapshot(&self) -> Vec<Event> {
75 let Ok(events) = self.events.read() else {
76 tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77 return Vec::new();
78 };
79 events.clone()
80 }
81
82 pub fn replace_events(&self, new_events: Vec<Event>) {
84 let Ok(mut events) = self.events.write() else {
85 tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86 return;
87 };
88 *events = new_events;
89 }
90
91 pub fn events_len(&self) -> usize {
93 let Ok(events) = self.events.read() else {
94 tracing::error!("events RwLock poisoned in events_len — returning 0");
95 return 0;
96 };
97 events.len()
98 }
99
100 pub fn conversation_history_for_agent_impl(
110 &self,
111 agent_name: Option<&str>,
112 ) -> Vec<adk_core::Content> {
113 let Ok(events) = self.events.read() else {
114 tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115 return Vec::new();
116 };
117 let mut history = Vec::new();
118
119 let mut compaction_boundary = None;
123 for event in events.iter().rev() {
124 if let Some(ref compaction) = event.actions.compaction {
125 history.push(compaction.compacted_content.clone());
126 compaction_boundary = Some(compaction.end_timestamp);
127 break;
128 }
129 }
130
131 for event in events.iter() {
132 if event.actions.compaction.is_some() {
134 continue;
135 }
136
137 if let Some(boundary) = compaction_boundary {
139 if event.timestamp <= boundary {
140 continue;
141 }
142 }
143
144 if let Some(name) = agent_name {
150 if event.author != "user" && event.author != name {
151 continue;
152 }
153 }
154
155 if let Some(content) = &event.llm_response.content {
156 let mut mapped_content = content.clone();
157 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
158 ("user", _) => "user",
159 (_, "function" | "tool") => content.role.as_str(),
160 _ => "model",
161 }
162 .to_string();
163 history.push(mapped_content);
164 }
165 }
166
167 history
168 }
169}
170
171impl adk_core::Session for MutableSession {
172 fn id(&self) -> &str {
173 self.inner.id()
174 }
175
176 fn app_name(&self) -> &str {
177 self.inner.app_name()
178 }
179
180 fn user_id(&self) -> &str {
181 self.inner.user_id()
182 }
183
184 fn state(&self) -> &dyn adk_core::State {
185 self
186 }
187
188 fn conversation_history(&self) -> Vec<adk_core::Content> {
189 self.conversation_history_for_agent_impl(None)
190 }
191
192 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
193 self.conversation_history_for_agent_impl(Some(agent_name))
194 }
195}
196
197impl adk_core::State for MutableSession {
198 fn get(&self, key: &str) -> Option<serde_json::Value> {
199 let Ok(state) = self.state.read() else {
200 tracing::error!("state RwLock poisoned in State::get — returning None");
201 return None;
202 };
203 state.get(key).cloned()
204 }
205
206 fn set(&mut self, key: String, value: serde_json::Value) {
207 if let Err(msg) = adk_core::validate_state_key(&key) {
208 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
209 return;
210 }
211 let Ok(mut state) = self.state.write() else {
212 tracing::error!("state RwLock poisoned in State::set — value dropped");
213 return;
214 };
215 state.insert(key, value);
216 }
217
218 fn all(&self) -> HashMap<String, serde_json::Value> {
219 let Ok(state) = self.state.read() else {
220 tracing::error!("state RwLock poisoned in State::all — returning empty");
221 return HashMap::new();
222 };
223 state.clone()
224 }
225}
226
/// Context for a single invocation: the identity, the agent, the user's
/// content, the mutable session and a set of optional services.
pub struct InvocationContext {
    /// App, user, session and invocation ids, plus the branch and the agent
    /// name captured at construction.
    identity: ExecutionIdentity,
    /// The agent this context was created for.
    agent: Arc<dyn Agent>,
    /// The content supplied by the user for this invocation.
    user_content: Content,
    /// Optional artifact service; `None` until `with_artifacts` is called.
    artifacts: Option<Arc<dyn Artifacts>>,
    /// Optional memory service; `None` until `with_memory` is called.
    memory: Option<Arc<dyn Memory>>,
    /// Run configuration; `RunConfig::default()` unless replaced via
    /// `with_run_config`.
    run_config: RunConfig,
    /// Flag set by `end_invocation` and read by `ended`.
    ended: Arc<AtomicBool>,
    /// The session wrapper exposed through `session()` and `mutable_session()`.
    session: Arc<MutableSession>,
    /// Optional request context. When present it supplies `user_id`,
    /// `user_scopes` and `request_metadata`.
    request_context: Option<RequestContext>,
    /// Optional shared state handle returned by `shared_state()`.
    shared_state: Option<Arc<adk_core::SharedState>>,
    /// Optional secret service consulted by `get_secret`.
    secret_service: Option<Arc<dyn SecretService>>,
}
249
250impl InvocationContext {
251 pub fn new_typed(
253 invocation_id: String,
254 agent: Arc<dyn Agent>,
255 user_id: UserId,
256 app_name: AppName,
257 session_id: SessionId,
258 user_content: Content,
259 session: Arc<dyn AdkSession>,
260 ) -> adk_core::Result<Self> {
261 let identity = ExecutionIdentity {
262 adk: AdkIdentity { app_name, user_id, session_id },
263 invocation_id: InvocationId::try_from(invocation_id)?,
264 branch: String::new(),
265 agent_name: agent.name().to_string(),
266 };
267 Ok(Self {
268 identity,
269 agent,
270 user_content,
271 artifacts: None,
272 memory: None,
273 run_config: RunConfig::default(),
274 ended: Arc::new(AtomicBool::new(false)),
275 session: Arc::new(MutableSession::new(session)),
276 request_context: None,
277 shared_state: None,
278 secret_service: None,
279 })
280 }
281
282 pub fn new(
283 invocation_id: String,
284 agent: Arc<dyn Agent>,
285 user_id: String,
286 app_name: String,
287 session_id: String,
288 user_content: Content,
289 session: Arc<dyn AdkSession>,
290 ) -> adk_core::Result<Self> {
291 Self::new_typed(
292 invocation_id,
293 agent,
294 UserId::try_from(user_id)?,
295 AppName::try_from(app_name)?,
296 SessionId::try_from(session_id)?,
297 user_content,
298 session,
299 )
300 }
301
302 pub fn with_mutable_session_typed(
305 invocation_id: String,
306 agent: Arc<dyn Agent>,
307 user_id: UserId,
308 app_name: AppName,
309 session_id: SessionId,
310 user_content: Content,
311 session: Arc<MutableSession>,
312 ) -> adk_core::Result<Self> {
313 let identity = ExecutionIdentity {
314 adk: AdkIdentity { app_name, user_id, session_id },
315 invocation_id: InvocationId::try_from(invocation_id)?,
316 branch: String::new(),
317 agent_name: agent.name().to_string(),
318 };
319 Ok(Self {
320 identity,
321 agent,
322 user_content,
323 artifacts: None,
324 memory: None,
325 run_config: RunConfig::default(),
326 ended: Arc::new(AtomicBool::new(false)),
327 session,
328 request_context: None,
329 shared_state: None,
330 secret_service: None,
331 })
332 }
333
334 pub fn with_mutable_session(
338 invocation_id: String,
339 agent: Arc<dyn Agent>,
340 user_id: String,
341 app_name: String,
342 session_id: String,
343 user_content: Content,
344 session: Arc<MutableSession>,
345 ) -> adk_core::Result<Self> {
346 Self::with_mutable_session_typed(
347 invocation_id,
348 agent,
349 UserId::try_from(user_id)?,
350 AppName::try_from(app_name)?,
351 SessionId::try_from(session_id)?,
352 user_content,
353 session,
354 )
355 }
356
357 pub fn with_branch(mut self, branch: String) -> Self {
358 self.identity.branch = branch;
359 self
360 }
361
362 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
363 self.artifacts = Some(artifacts);
364 self
365 }
366
367 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
368 self.memory = Some(memory);
369 self
370 }
371
372 pub fn with_run_config(mut self, config: RunConfig) -> Self {
373 self.run_config = config;
374 self
375 }
376
377 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
385 self.request_context = Some(ctx);
386 self
387 }
388
389 pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
391 self.shared_state = Some(shared);
392 self
393 }
394
395 pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
401 self.secret_service = Some(service);
402 self
403 }
404
405 pub fn mutable_session(&self) -> &Arc<MutableSession> {
408 &self.session
409 }
410}
411
412#[async_trait]
413impl ReadonlyContext for InvocationContext {
414 fn invocation_id(&self) -> &str {
415 self.identity.invocation_id.as_ref()
416 }
417
418 fn agent_name(&self) -> &str {
419 self.agent.name()
420 }
421
422 fn user_id(&self) -> &str {
423 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
429 }
430
431 fn app_name(&self) -> &str {
432 self.identity.adk.app_name.as_ref()
433 }
434
435 fn session_id(&self) -> &str {
436 self.identity.adk.session_id.as_ref()
437 }
438
439 fn branch(&self) -> &str {
440 &self.identity.branch
441 }
442
443 fn user_content(&self) -> &Content {
444 &self.user_content
445 }
446}
447
448#[async_trait]
449impl CallbackContext for InvocationContext {
450 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
451 self.artifacts.clone()
452 }
453
454 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
455 self.shared_state.clone()
456 }
457}
458
459#[async_trait]
460impl InvocationContextTrait for InvocationContext {
461 fn agent(&self) -> Arc<dyn Agent> {
462 self.agent.clone()
463 }
464
465 fn memory(&self) -> Option<Arc<dyn Memory>> {
466 self.memory.clone()
467 }
468
469 fn session(&self) -> &dyn adk_core::Session {
470 self.session.as_ref()
471 }
472
473 fn run_config(&self) -> &RunConfig {
474 &self.run_config
475 }
476
477 fn end_invocation(&self) {
478 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
479 }
480
481 fn ended(&self) -> bool {
482 self.ended.load(std::sync::atomic::Ordering::SeqCst)
483 }
484
485 fn user_scopes(&self) -> Vec<String> {
486 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
487 }
488
489 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
490 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
491 rc.metadata
492 .iter()
493 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
494 .collect()
495 })
496 }
497
498 async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
499 match &self.secret_service {
500 Some(service) => service.get_secret(name).await.map(Some),
501 None => Ok(None),
502 }
503 }
504}