1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SecretService, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
/// Session wrapper that keeps its own lock-guarded copies of state and events.
///
/// `MutableSession::new` seeds both copies from the wrapped session. Afterwards
/// state and event reads/writes made through this type go to the copies, while
/// `id`, `app_name` and `user_id` are still answered by the wrapped session.
pub struct MutableSession {
    /// The wrapped session; consulted for `id`, `app_name` and `user_id`.
    inner: Arc<dyn AdkSession>,
    /// Key/value state, initially copied from the wrapped session's state.
    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Event list, initially copied from the wrapped session's events.
    events: Arc<RwLock<Vec<Event>>>,
}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let Ok(mut state) = self.state.write() else {
51 tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52 return;
53 };
54 for (key, value) in delta {
55 if !key.starts_with("temp:") {
57 state.insert(key.clone(), value.clone());
58 }
59 }
60 }
61
62 pub fn append_event(&self, event: Event) {
65 let Ok(mut events) = self.events.write() else {
66 tracing::error!("events RwLock poisoned in append_event — event dropped");
67 return;
68 };
69 events.push(event);
70 }
71
72 pub fn events_snapshot(&self) -> Vec<Event> {
75 let Ok(events) = self.events.read() else {
76 tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77 return Vec::new();
78 };
79 events.clone()
80 }
81
82 pub fn replace_events(&self, new_events: Vec<Event>) {
84 let Ok(mut events) = self.events.write() else {
85 tracing::error!("events RwLock poisoned in replace_events — events unchanged");
86 return;
87 };
88 *events = new_events;
89 }
90
91 pub fn events_len(&self) -> usize {
93 let Ok(events) = self.events.read() else {
94 tracing::error!("events RwLock poisoned in events_len — returning 0");
95 return 0;
96 };
97 events.len()
98 }
99
100 pub fn conversation_history_for_agent_impl(
110 &self,
111 agent_name: Option<&str>,
112 ) -> Vec<adk_core::Content> {
113 let Ok(events) = self.events.read() else {
114 tracing::error!("events RwLock poisoned in conversation_history — returning empty");
115 return Vec::new();
116 };
117 let mut history = Vec::new();
118
119 let mut compaction_boundary = None;
123 for event in events.iter().rev() {
124 if let Some(ref compaction) = event.actions.compaction {
125 history.push(compaction.compacted_content.clone());
126 compaction_boundary = Some(compaction.end_timestamp);
127 break;
128 }
129 }
130
131 for event in events.iter() {
132 if event.actions.compaction.is_some() {
134 continue;
135 }
136
137 if let Some(boundary) = compaction_boundary
139 && event.timestamp <= boundary
140 {
141 continue;
142 }
143
144 if let Some(name) = agent_name
150 && event.author != "user"
151 && event.author != name
152 {
153 continue;
154 }
155
156 if let Some(content) = &event.llm_response.content {
157 let mut mapped_content = content.clone();
158 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
159 ("user", _) => "user",
160 (_, "function" | "tool") => content.role.as_str(),
161 _ => "model",
162 }
163 .to_string();
164 history.push(mapped_content);
165 }
166 }
167
168 history
169 }
170}
171
172impl adk_core::Session for MutableSession {
173 fn id(&self) -> &str {
174 self.inner.id()
175 }
176
177 fn app_name(&self) -> &str {
178 self.inner.app_name()
179 }
180
181 fn user_id(&self) -> &str {
182 self.inner.user_id()
183 }
184
185 fn state(&self) -> &dyn adk_core::State {
186 self
187 }
188
189 fn conversation_history(&self) -> Vec<adk_core::Content> {
190 self.conversation_history_for_agent_impl(None)
191 }
192
193 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
194 self.conversation_history_for_agent_impl(Some(agent_name))
195 }
196}
197
198impl adk_core::State for MutableSession {
199 fn get(&self, key: &str) -> Option<serde_json::Value> {
200 let Ok(state) = self.state.read() else {
201 tracing::error!("state RwLock poisoned in State::get — returning None");
202 return None;
203 };
204 state.get(key).cloned()
205 }
206
207 fn set(&mut self, key: String, value: serde_json::Value) {
208 if let Err(msg) = adk_core::validate_state_key(&key) {
209 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
210 return;
211 }
212 let Ok(mut state) = self.state.write() else {
213 tracing::error!("state RwLock poisoned in State::set — value dropped");
214 return;
215 };
216 state.insert(key, value);
217 }
218
219 fn all(&self) -> HashMap<String, serde_json::Value> {
220 let Ok(state) = self.state.read() else {
221 tracing::error!("state RwLock poisoned in State::all — returning empty");
222 return HashMap::new();
223 };
224 state.clone()
225 }
226}
227
/// Concrete invocation context: identity, agent, user input, session and the
/// optional services attached through the `with_*` builder methods.
pub struct InvocationContext {
    /// App name, user id, session id, invocation id, branch and agent name.
    identity: ExecutionIdentity,
    /// The agent this invocation belongs to.
    agent: Arc<dyn Agent>,
    /// The content supplied by the user for this invocation.
    user_content: Content,
    /// Optional artifact service; `None` until `with_artifacts` is called.
    artifacts: Option<Arc<dyn Artifacts>>,
    /// Optional memory service; `None` until `with_memory` is called.
    memory: Option<Arc<dyn Memory>>,
    /// Run configuration; constructors start from `RunConfig::default()`.
    run_config: RunConfig,
    /// Flag set to `true` by `end_invocation` and read by `ended`.
    ended: Arc<AtomicBool>,
    /// Session wrapper holding the local state and event copies.
    session: Arc<MutableSession>,
    /// Optional request context. When present, `user_id()` returns its
    /// `user_id`, and it supplies `user_scopes()` and `request_metadata()`.
    request_context: Option<RequestContext>,
    /// Optional shared state; `None` until `with_shared_state` is called.
    shared_state: Option<Arc<adk_core::SharedState>>,
    /// Optional secret service consulted by `get_secret`.
    secret_service: Option<Arc<dyn SecretService>>,
}
255
256impl InvocationContext {
257 pub fn new_typed(
259 invocation_id: String,
260 agent: Arc<dyn Agent>,
261 user_id: UserId,
262 app_name: AppName,
263 session_id: SessionId,
264 user_content: Content,
265 session: Arc<dyn AdkSession>,
266 ) -> adk_core::Result<Self> {
267 let identity = ExecutionIdentity {
268 adk: AdkIdentity { app_name, user_id, session_id },
269 invocation_id: InvocationId::try_from(invocation_id)?,
270 branch: String::new(),
271 agent_name: agent.name().to_string(),
272 };
273 Ok(Self {
274 identity,
275 agent,
276 user_content,
277 artifacts: None,
278 memory: None,
279 run_config: RunConfig::default(),
280 ended: Arc::new(AtomicBool::new(false)),
281 session: Arc::new(MutableSession::new(session)),
282 request_context: None,
283 shared_state: None,
284 secret_service: None,
285 })
286 }
287
288 pub fn new(
293 invocation_id: String,
294 agent: Arc<dyn Agent>,
295 user_id: String,
296 app_name: String,
297 session_id: String,
298 user_content: Content,
299 session: Arc<dyn AdkSession>,
300 ) -> adk_core::Result<Self> {
301 Self::new_typed(
302 invocation_id,
303 agent,
304 UserId::try_from(user_id)?,
305 AppName::try_from(app_name)?,
306 SessionId::try_from(session_id)?,
307 user_content,
308 session,
309 )
310 }
311
312 pub fn with_mutable_session_typed(
315 invocation_id: String,
316 agent: Arc<dyn Agent>,
317 user_id: UserId,
318 app_name: AppName,
319 session_id: SessionId,
320 user_content: Content,
321 session: Arc<MutableSession>,
322 ) -> adk_core::Result<Self> {
323 let identity = ExecutionIdentity {
324 adk: AdkIdentity { app_name, user_id, session_id },
325 invocation_id: InvocationId::try_from(invocation_id)?,
326 branch: String::new(),
327 agent_name: agent.name().to_string(),
328 };
329 Ok(Self {
330 identity,
331 agent,
332 user_content,
333 artifacts: None,
334 memory: None,
335 run_config: RunConfig::default(),
336 ended: Arc::new(AtomicBool::new(false)),
337 session,
338 request_context: None,
339 shared_state: None,
340 secret_service: None,
341 })
342 }
343
344 pub fn with_mutable_session(
348 invocation_id: String,
349 agent: Arc<dyn Agent>,
350 user_id: String,
351 app_name: String,
352 session_id: String,
353 user_content: Content,
354 session: Arc<MutableSession>,
355 ) -> adk_core::Result<Self> {
356 Self::with_mutable_session_typed(
357 invocation_id,
358 agent,
359 UserId::try_from(user_id)?,
360 AppName::try_from(app_name)?,
361 SessionId::try_from(session_id)?,
362 user_content,
363 session,
364 )
365 }
366
367 pub fn with_branch(mut self, branch: String) -> Self {
369 self.identity.branch = branch;
370 self
371 }
372
373 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
375 self.artifacts = Some(artifacts);
376 self
377 }
378
379 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
381 self.memory = Some(memory);
382 self
383 }
384
385 pub fn with_run_config(mut self, config: RunConfig) -> Self {
387 self.run_config = config;
388 self
389 }
390
391 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
399 self.request_context = Some(ctx);
400 self
401 }
402
403 pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
405 self.shared_state = Some(shared);
406 self
407 }
408
409 pub fn with_secret_service(mut self, service: Arc<dyn SecretService>) -> Self {
415 self.secret_service = Some(service);
416 self
417 }
418
419 pub fn mutable_session(&self) -> &Arc<MutableSession> {
422 &self.session
423 }
424}
425
426#[async_trait]
427impl ReadonlyContext for InvocationContext {
428 fn invocation_id(&self) -> &str {
429 self.identity.invocation_id.as_ref()
430 }
431
432 fn agent_name(&self) -> &str {
433 self.agent.name()
434 }
435
436 fn user_id(&self) -> &str {
437 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
443 }
444
445 fn app_name(&self) -> &str {
446 self.identity.adk.app_name.as_ref()
447 }
448
449 fn session_id(&self) -> &str {
450 self.identity.adk.session_id.as_ref()
451 }
452
453 fn branch(&self) -> &str {
454 &self.identity.branch
455 }
456
457 fn user_content(&self) -> &Content {
458 &self.user_content
459 }
460}
461
462#[async_trait]
463impl CallbackContext for InvocationContext {
464 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
465 self.artifacts.clone()
466 }
467
468 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
469 self.shared_state.clone()
470 }
471}
472
473#[async_trait]
474impl InvocationContextTrait for InvocationContext {
475 fn agent(&self) -> Arc<dyn Agent> {
476 self.agent.clone()
477 }
478
479 fn memory(&self) -> Option<Arc<dyn Memory>> {
480 self.memory.clone()
481 }
482
483 fn session(&self) -> &dyn adk_core::Session {
484 self.session.as_ref()
485 }
486
487 fn run_config(&self) -> &RunConfig {
488 &self.run_config
489 }
490
491 fn end_invocation(&self) {
492 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
493 }
494
495 fn ended(&self) -> bool {
496 self.ended.load(std::sync::atomic::Ordering::SeqCst)
497 }
498
499 fn user_scopes(&self) -> Vec<String> {
500 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
501 }
502
503 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
504 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
505 rc.metadata
506 .iter()
507 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
508 .collect()
509 })
510 }
511
512 async fn get_secret(&self, name: &str) -> adk_core::Result<Option<String>> {
513 match &self.secret_service {
514 Some(service) => service.get_secret(name).await.map(Some),
515 None => Ok(None),
516 }
517 }
518}