1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11pub struct MutableSession {
18 inner: Arc<dyn AdkSession>,
20 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23 events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let Ok(mut state) = self.state.write() else {
51 tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52 return;
53 };
54 for (key, value) in delta {
55 if !key.starts_with("temp:") {
57 state.insert(key.clone(), value.clone());
58 }
59 }
60 }
61
62 pub fn append_event(&self, event: Event) {
65 let Ok(mut events) = self.events.write() else {
66 tracing::error!("events RwLock poisoned in append_event — event dropped");
67 return;
68 };
69 events.push(event);
70 }
71
72 pub fn events_snapshot(&self) -> Vec<Event> {
75 let Ok(events) = self.events.read() else {
76 tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77 return Vec::new();
78 };
79 events.clone()
80 }
81
82 pub fn events_len(&self) -> usize {
84 let Ok(events) = self.events.read() else {
85 tracing::error!("events RwLock poisoned in events_len — returning 0");
86 return 0;
87 };
88 events.len()
89 }
90
91 pub fn conversation_history_for_agent_impl(
101 &self,
102 agent_name: Option<&str>,
103 ) -> Vec<adk_core::Content> {
104 let Ok(events) = self.events.read() else {
105 tracing::error!("events RwLock poisoned in conversation_history — returning empty");
106 return Vec::new();
107 };
108 let mut history = Vec::new();
109
110 let mut compaction_boundary = None;
114 for event in events.iter().rev() {
115 if let Some(ref compaction) = event.actions.compaction {
116 history.push(compaction.compacted_content.clone());
117 compaction_boundary = Some(compaction.end_timestamp);
118 break;
119 }
120 }
121
122 for event in events.iter() {
123 if event.actions.compaction.is_some() {
125 continue;
126 }
127
128 if let Some(boundary) = compaction_boundary {
130 if event.timestamp <= boundary {
131 continue;
132 }
133 }
134
135 if let Some(name) = agent_name {
141 if event.author != "user" && event.author != name {
142 continue;
143 }
144 }
145
146 if let Some(content) = &event.llm_response.content {
147 let mut mapped_content = content.clone();
148 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
149 ("user", _) => "user",
150 (_, "function" | "tool") => content.role.as_str(),
151 _ => "model",
152 }
153 .to_string();
154 history.push(mapped_content);
155 }
156 }
157
158 history
159 }
160}
161
162impl adk_core::Session for MutableSession {
163 fn id(&self) -> &str {
164 self.inner.id()
165 }
166
167 fn app_name(&self) -> &str {
168 self.inner.app_name()
169 }
170
171 fn user_id(&self) -> &str {
172 self.inner.user_id()
173 }
174
175 fn state(&self) -> &dyn adk_core::State {
176 self
177 }
178
179 fn conversation_history(&self) -> Vec<adk_core::Content> {
180 self.conversation_history_for_agent_impl(None)
181 }
182
183 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
184 self.conversation_history_for_agent_impl(Some(agent_name))
185 }
186}
187
188impl adk_core::State for MutableSession {
189 fn get(&self, key: &str) -> Option<serde_json::Value> {
190 let Ok(state) = self.state.read() else {
191 tracing::error!("state RwLock poisoned in State::get — returning None");
192 return None;
193 };
194 state.get(key).cloned()
195 }
196
197 fn set(&mut self, key: String, value: serde_json::Value) {
198 if let Err(msg) = adk_core::validate_state_key(&key) {
199 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
200 return;
201 }
202 let Ok(mut state) = self.state.write() else {
203 tracing::error!("state RwLock poisoned in State::set — value dropped");
204 return;
205 };
206 state.insert(key, value);
207 }
208
209 fn all(&self) -> HashMap<String, serde_json::Value> {
210 let Ok(state) = self.state.read() else {
211 tracing::error!("state RwLock poisoned in State::all — returning empty");
212 return HashMap::new();
213 };
214 state.clone()
215 }
216}
217
218pub struct InvocationContext {
219 identity: ExecutionIdentity,
220 agent: Arc<dyn Agent>,
221 user_content: Content,
222 artifacts: Option<Arc<dyn Artifacts>>,
223 memory: Option<Arc<dyn Memory>>,
224 run_config: RunConfig,
225 ended: Arc<AtomicBool>,
226 session: Arc<MutableSession>,
230 request_context: Option<RequestContext>,
234 shared_state: Option<Arc<adk_core::SharedState>>,
236}
237
238impl InvocationContext {
239 pub fn new_typed(
241 invocation_id: String,
242 agent: Arc<dyn Agent>,
243 user_id: UserId,
244 app_name: AppName,
245 session_id: SessionId,
246 user_content: Content,
247 session: Arc<dyn AdkSession>,
248 ) -> adk_core::Result<Self> {
249 let identity = ExecutionIdentity {
250 adk: AdkIdentity { app_name, user_id, session_id },
251 invocation_id: InvocationId::try_from(invocation_id)?,
252 branch: String::new(),
253 agent_name: agent.name().to_string(),
254 };
255 Ok(Self {
256 identity,
257 agent,
258 user_content,
259 artifacts: None,
260 memory: None,
261 run_config: RunConfig::default(),
262 ended: Arc::new(AtomicBool::new(false)),
263 session: Arc::new(MutableSession::new(session)),
264 request_context: None,
265 shared_state: None,
266 })
267 }
268
269 pub fn new(
270 invocation_id: String,
271 agent: Arc<dyn Agent>,
272 user_id: String,
273 app_name: String,
274 session_id: String,
275 user_content: Content,
276 session: Arc<dyn AdkSession>,
277 ) -> adk_core::Result<Self> {
278 Self::new_typed(
279 invocation_id,
280 agent,
281 UserId::try_from(user_id)?,
282 AppName::try_from(app_name)?,
283 SessionId::try_from(session_id)?,
284 user_content,
285 session,
286 )
287 }
288
289 pub fn with_mutable_session_typed(
292 invocation_id: String,
293 agent: Arc<dyn Agent>,
294 user_id: UserId,
295 app_name: AppName,
296 session_id: SessionId,
297 user_content: Content,
298 session: Arc<MutableSession>,
299 ) -> adk_core::Result<Self> {
300 let identity = ExecutionIdentity {
301 adk: AdkIdentity { app_name, user_id, session_id },
302 invocation_id: InvocationId::try_from(invocation_id)?,
303 branch: String::new(),
304 agent_name: agent.name().to_string(),
305 };
306 Ok(Self {
307 identity,
308 agent,
309 user_content,
310 artifacts: None,
311 memory: None,
312 run_config: RunConfig::default(),
313 ended: Arc::new(AtomicBool::new(false)),
314 session,
315 request_context: None,
316 shared_state: None,
317 })
318 }
319
320 pub fn with_mutable_session(
324 invocation_id: String,
325 agent: Arc<dyn Agent>,
326 user_id: String,
327 app_name: String,
328 session_id: String,
329 user_content: Content,
330 session: Arc<MutableSession>,
331 ) -> adk_core::Result<Self> {
332 Self::with_mutable_session_typed(
333 invocation_id,
334 agent,
335 UserId::try_from(user_id)?,
336 AppName::try_from(app_name)?,
337 SessionId::try_from(session_id)?,
338 user_content,
339 session,
340 )
341 }
342
343 pub fn with_branch(mut self, branch: String) -> Self {
344 self.identity.branch = branch;
345 self
346 }
347
348 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
349 self.artifacts = Some(artifacts);
350 self
351 }
352
353 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
354 self.memory = Some(memory);
355 self
356 }
357
358 pub fn with_run_config(mut self, config: RunConfig) -> Self {
359 self.run_config = config;
360 self
361 }
362
363 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
371 self.request_context = Some(ctx);
372 self
373 }
374
375 pub fn with_shared_state(mut self, shared: Arc<adk_core::SharedState>) -> Self {
377 self.shared_state = Some(shared);
378 self
379 }
380
381 pub fn mutable_session(&self) -> &Arc<MutableSession> {
384 &self.session
385 }
386}
387
388#[async_trait]
389impl ReadonlyContext for InvocationContext {
390 fn invocation_id(&self) -> &str {
391 self.identity.invocation_id.as_ref()
392 }
393
394 fn agent_name(&self) -> &str {
395 self.agent.name()
396 }
397
398 fn user_id(&self) -> &str {
399 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
405 }
406
407 fn app_name(&self) -> &str {
408 self.identity.adk.app_name.as_ref()
409 }
410
411 fn session_id(&self) -> &str {
412 self.identity.adk.session_id.as_ref()
413 }
414
415 fn branch(&self) -> &str {
416 &self.identity.branch
417 }
418
419 fn user_content(&self) -> &Content {
420 &self.user_content
421 }
422}
423
424#[async_trait]
425impl CallbackContext for InvocationContext {
426 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
427 self.artifacts.clone()
428 }
429
430 fn shared_state(&self) -> Option<Arc<adk_core::SharedState>> {
431 self.shared_state.clone()
432 }
433}
434
435#[async_trait]
436impl InvocationContextTrait for InvocationContext {
437 fn agent(&self) -> Arc<dyn Agent> {
438 self.agent.clone()
439 }
440
441 fn memory(&self) -> Option<Arc<dyn Memory>> {
442 self.memory.clone()
443 }
444
445 fn session(&self) -> &dyn adk_core::Session {
446 self.session.as_ref()
447 }
448
449 fn run_config(&self) -> &RunConfig {
450 &self.run_config
451 }
452
453 fn end_invocation(&self) {
454 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
455 }
456
457 fn ended(&self) -> bool {
458 self.ended.load(std::sync::atomic::Ordering::SeqCst)
459 }
460
461 fn user_scopes(&self) -> Vec<String> {
462 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
463 }
464
465 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
466 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
467 rc.metadata
468 .iter()
469 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
470 .collect()
471 })
472 }
473}