1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11pub struct MutableSession {
18 inner: Arc<dyn AdkSession>,
20 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23 events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let Ok(mut state) = self.state.write() else {
51 tracing::error!("state RwLock poisoned in apply_state_delta — skipping delta");
52 return;
53 };
54 for (key, value) in delta {
55 if !key.starts_with("temp:") {
57 state.insert(key.clone(), value.clone());
58 }
59 }
60 }
61
62 pub fn append_event(&self, event: Event) {
65 let Ok(mut events) = self.events.write() else {
66 tracing::error!("events RwLock poisoned in append_event — event dropped");
67 return;
68 };
69 events.push(event);
70 }
71
72 pub fn events_snapshot(&self) -> Vec<Event> {
75 let Ok(events) = self.events.read() else {
76 tracing::error!("events RwLock poisoned in events_snapshot — returning empty");
77 return Vec::new();
78 };
79 events.clone()
80 }
81
82 pub fn events_len(&self) -> usize {
84 let Ok(events) = self.events.read() else {
85 tracing::error!("events RwLock poisoned in events_len — returning 0");
86 return 0;
87 };
88 events.len()
89 }
90
91 pub fn conversation_history_for_agent_impl(
101 &self,
102 agent_name: Option<&str>,
103 ) -> Vec<adk_core::Content> {
104 let Ok(events) = self.events.read() else {
105 tracing::error!("events RwLock poisoned in conversation_history — returning empty");
106 return Vec::new();
107 };
108 let mut history = Vec::new();
109
110 let mut compaction_boundary = None;
114 for event in events.iter().rev() {
115 if let Some(ref compaction) = event.actions.compaction {
116 history.push(compaction.compacted_content.clone());
117 compaction_boundary = Some(compaction.end_timestamp);
118 break;
119 }
120 }
121
122 for event in events.iter() {
123 if event.actions.compaction.is_some() {
125 continue;
126 }
127
128 if let Some(boundary) = compaction_boundary {
130 if event.timestamp <= boundary {
131 continue;
132 }
133 }
134
135 if let Some(name) = agent_name {
141 if event.author != "user" && event.author != name {
142 continue;
143 }
144 }
145
146 if let Some(content) = &event.llm_response.content {
147 let mut mapped_content = content.clone();
148 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
149 ("user", _) => "user",
150 (_, "function" | "tool") => content.role.as_str(),
151 _ => "model",
152 }
153 .to_string();
154 history.push(mapped_content);
155 }
156 }
157
158 history
159 }
160}
161
162impl adk_core::Session for MutableSession {
163 fn id(&self) -> &str {
164 self.inner.id()
165 }
166
167 fn app_name(&self) -> &str {
168 self.inner.app_name()
169 }
170
171 fn user_id(&self) -> &str {
172 self.inner.user_id()
173 }
174
175 fn state(&self) -> &dyn adk_core::State {
176 self
177 }
178
179 fn conversation_history(&self) -> Vec<adk_core::Content> {
180 self.conversation_history_for_agent_impl(None)
181 }
182
183 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
184 self.conversation_history_for_agent_impl(Some(agent_name))
185 }
186}
187
188impl adk_core::State for MutableSession {
189 fn get(&self, key: &str) -> Option<serde_json::Value> {
190 let Ok(state) = self.state.read() else {
191 tracing::error!("state RwLock poisoned in State::get — returning None");
192 return None;
193 };
194 state.get(key).cloned()
195 }
196
197 fn set(&mut self, key: String, value: serde_json::Value) {
198 if let Err(msg) = adk_core::validate_state_key(&key) {
199 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
200 return;
201 }
202 let Ok(mut state) = self.state.write() else {
203 tracing::error!("state RwLock poisoned in State::set — value dropped");
204 return;
205 };
206 state.insert(key, value);
207 }
208
209 fn all(&self) -> HashMap<String, serde_json::Value> {
210 let Ok(state) = self.state.read() else {
211 tracing::error!("state RwLock poisoned in State::all — returning empty");
212 return HashMap::new();
213 };
214 state.clone()
215 }
216}
217
218pub struct InvocationContext {
219 identity: ExecutionIdentity,
220 agent: Arc<dyn Agent>,
221 user_content: Content,
222 artifacts: Option<Arc<dyn Artifacts>>,
223 memory: Option<Arc<dyn Memory>>,
224 run_config: RunConfig,
225 ended: Arc<AtomicBool>,
226 session: Arc<MutableSession>,
230 request_context: Option<RequestContext>,
234}
235
236impl InvocationContext {
237 pub fn new_typed(
239 invocation_id: String,
240 agent: Arc<dyn Agent>,
241 user_id: UserId,
242 app_name: AppName,
243 session_id: SessionId,
244 user_content: Content,
245 session: Arc<dyn AdkSession>,
246 ) -> adk_core::Result<Self> {
247 let identity = ExecutionIdentity {
248 adk: AdkIdentity { app_name, user_id, session_id },
249 invocation_id: InvocationId::try_from(invocation_id)?,
250 branch: String::new(),
251 agent_name: agent.name().to_string(),
252 };
253 Ok(Self {
254 identity,
255 agent,
256 user_content,
257 artifacts: None,
258 memory: None,
259 run_config: RunConfig::default(),
260 ended: Arc::new(AtomicBool::new(false)),
261 session: Arc::new(MutableSession::new(session)),
262 request_context: None,
263 })
264 }
265
266 pub fn new(
267 invocation_id: String,
268 agent: Arc<dyn Agent>,
269 user_id: String,
270 app_name: String,
271 session_id: String,
272 user_content: Content,
273 session: Arc<dyn AdkSession>,
274 ) -> adk_core::Result<Self> {
275 Self::new_typed(
276 invocation_id,
277 agent,
278 UserId::try_from(user_id)?,
279 AppName::try_from(app_name)?,
280 SessionId::try_from(session_id)?,
281 user_content,
282 session,
283 )
284 }
285
286 pub fn with_mutable_session_typed(
289 invocation_id: String,
290 agent: Arc<dyn Agent>,
291 user_id: UserId,
292 app_name: AppName,
293 session_id: SessionId,
294 user_content: Content,
295 session: Arc<MutableSession>,
296 ) -> adk_core::Result<Self> {
297 let identity = ExecutionIdentity {
298 adk: AdkIdentity { app_name, user_id, session_id },
299 invocation_id: InvocationId::try_from(invocation_id)?,
300 branch: String::new(),
301 agent_name: agent.name().to_string(),
302 };
303 Ok(Self {
304 identity,
305 agent,
306 user_content,
307 artifacts: None,
308 memory: None,
309 run_config: RunConfig::default(),
310 ended: Arc::new(AtomicBool::new(false)),
311 session,
312 request_context: None,
313 })
314 }
315
316 pub fn with_mutable_session(
320 invocation_id: String,
321 agent: Arc<dyn Agent>,
322 user_id: String,
323 app_name: String,
324 session_id: String,
325 user_content: Content,
326 session: Arc<MutableSession>,
327 ) -> adk_core::Result<Self> {
328 Self::with_mutable_session_typed(
329 invocation_id,
330 agent,
331 UserId::try_from(user_id)?,
332 AppName::try_from(app_name)?,
333 SessionId::try_from(session_id)?,
334 user_content,
335 session,
336 )
337 }
338
339 pub fn with_branch(mut self, branch: String) -> Self {
340 self.identity.branch = branch;
341 self
342 }
343
344 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
345 self.artifacts = Some(artifacts);
346 self
347 }
348
349 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
350 self.memory = Some(memory);
351 self
352 }
353
354 pub fn with_run_config(mut self, config: RunConfig) -> Self {
355 self.run_config = config;
356 self
357 }
358
359 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
367 self.request_context = Some(ctx);
368 self
369 }
370
371 pub fn mutable_session(&self) -> &Arc<MutableSession> {
374 &self.session
375 }
376}
377
378#[async_trait]
379impl ReadonlyContext for InvocationContext {
380 fn invocation_id(&self) -> &str {
381 self.identity.invocation_id.as_ref()
382 }
383
384 fn agent_name(&self) -> &str {
385 self.agent.name()
386 }
387
388 fn user_id(&self) -> &str {
389 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
395 }
396
397 fn app_name(&self) -> &str {
398 self.identity.adk.app_name.as_ref()
399 }
400
401 fn session_id(&self) -> &str {
402 self.identity.adk.session_id.as_ref()
403 }
404
405 fn branch(&self) -> &str {
406 &self.identity.branch
407 }
408
409 fn user_content(&self) -> &Content {
410 &self.user_content
411 }
412}
413
414#[async_trait]
415impl CallbackContext for InvocationContext {
416 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
417 self.artifacts.clone()
418 }
419}
420
421#[async_trait]
422impl InvocationContextTrait for InvocationContext {
423 fn agent(&self) -> Arc<dyn Agent> {
424 self.agent.clone()
425 }
426
427 fn memory(&self) -> Option<Arc<dyn Memory>> {
428 self.memory.clone()
429 }
430
431 fn session(&self) -> &dyn adk_core::Session {
432 self.session.as_ref()
433 }
434
435 fn run_config(&self) -> &RunConfig {
436 &self.run_config
437 }
438
439 fn end_invocation(&self) {
440 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
441 }
442
443 fn ended(&self) -> bool {
444 self.ended.load(std::sync::atomic::Ordering::SeqCst)
445 }
446
447 fn user_scopes(&self) -> Vec<String> {
448 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
449 }
450
451 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
452 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
453 rc.metadata
454 .iter()
455 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
456 .collect()
457 })
458 }
459}