1use adk_core::{
2 AdkIdentity, Agent, AppName, Artifacts, CallbackContext, Content, Event, ExecutionIdentity,
3 InvocationContext as InvocationContextTrait, InvocationId, Memory, ReadonlyContext,
4 RequestContext, RunConfig, SessionId, UserId,
5};
6use adk_session::Session as AdkSession;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock, atomic::AtomicBool};
10
11pub struct MutableSession {
18 inner: Arc<dyn AdkSession>,
20 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
23 events: Arc<RwLock<Vec<Event>>>,
25}
26
27impl MutableSession {
28 pub fn new(session: Arc<dyn AdkSession>) -> Self {
31 let initial_state = session.state().all();
33 let initial_events = session.events().all();
35
36 Self {
37 inner: session,
38 state: Arc::new(RwLock::new(initial_state)),
39 events: Arc::new(RwLock::new(initial_events)),
40 }
41 }
42
43 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
46 if delta.is_empty() {
47 return;
48 }
49
50 let mut state = self.state.write().unwrap();
51 for (key, value) in delta {
52 if !key.starts_with("temp:") {
54 state.insert(key.clone(), value.clone());
55 }
56 }
57 }
58
59 pub fn append_event(&self, event: Event) {
62 let mut events = self.events.write().unwrap();
63 events.push(event);
64 }
65
66 pub fn events_snapshot(&self) -> Vec<Event> {
69 let events = self.events.read().unwrap();
70 events.clone()
71 }
72
73 pub fn conversation_history_for_agent_impl(
83 &self,
84 agent_name: Option<&str>,
85 ) -> Vec<adk_core::Content> {
86 let events = self.events.read().unwrap();
87 let mut history = Vec::new();
88
89 let mut compaction_boundary = None;
93 for event in events.iter().rev() {
94 if let Some(ref compaction) = event.actions.compaction {
95 history.push(compaction.compacted_content.clone());
96 compaction_boundary = Some(compaction.end_timestamp);
97 break;
98 }
99 }
100
101 for event in events.iter() {
102 if event.actions.compaction.is_some() {
104 continue;
105 }
106
107 if let Some(boundary) = compaction_boundary {
109 if event.timestamp <= boundary {
110 continue;
111 }
112 }
113
114 if let Some(name) = agent_name {
120 if event.author != "user" && event.author != name {
121 continue;
122 }
123 }
124
125 if let Some(content) = &event.llm_response.content {
126 let mut mapped_content = content.clone();
127 mapped_content.role = match (event.author.as_str(), content.role.as_str()) {
128 ("user", _) => "user",
129 (_, "function" | "tool") => content.role.as_str(),
130 _ => "model",
131 }
132 .to_string();
133 history.push(mapped_content);
134 }
135 }
136
137 history
138 }
139}
140
141impl adk_core::Session for MutableSession {
142 fn id(&self) -> &str {
143 self.inner.id()
144 }
145
146 fn app_name(&self) -> &str {
147 self.inner.app_name()
148 }
149
150 fn user_id(&self) -> &str {
151 self.inner.user_id()
152 }
153
154 fn state(&self) -> &dyn adk_core::State {
155 unsafe { &*(self as *const Self as *const dyn adk_core::State) }
158 }
159
160 fn conversation_history(&self) -> Vec<adk_core::Content> {
161 self.conversation_history_for_agent_impl(None)
162 }
163
164 fn conversation_history_for_agent(&self, agent_name: &str) -> Vec<adk_core::Content> {
165 self.conversation_history_for_agent_impl(Some(agent_name))
166 }
167}
168
169impl adk_core::State for MutableSession {
170 fn get(&self, key: &str) -> Option<serde_json::Value> {
171 let state = self.state.read().unwrap();
172 state.get(key).cloned()
173 }
174
175 fn set(&mut self, key: String, value: serde_json::Value) {
176 let mut state = self.state.write().unwrap();
177 state.insert(key, value);
178 }
179
180 fn all(&self) -> HashMap<String, serde_json::Value> {
181 let state = self.state.read().unwrap();
182 state.clone()
183 }
184}
185
186pub struct InvocationContext {
187 identity: ExecutionIdentity,
188 agent: Arc<dyn Agent>,
189 user_content: Content,
190 artifacts: Option<Arc<dyn Artifacts>>,
191 memory: Option<Arc<dyn Memory>>,
192 run_config: RunConfig,
193 ended: Arc<AtomicBool>,
194 session: Arc<MutableSession>,
198 request_context: Option<RequestContext>,
202}
203
204impl InvocationContext {
205 pub fn new(
206 invocation_id: String,
207 agent: Arc<dyn Agent>,
208 user_id: String,
209 app_name: String,
210 session_id: String,
211 user_content: Content,
212 session: Arc<dyn AdkSession>,
213 ) -> Self {
214 let identity = ExecutionIdentity {
215 adk: AdkIdentity {
216 app_name: AppName::new_unchecked(app_name),
217 user_id: UserId::new_unchecked(user_id),
218 session_id: SessionId::new_unchecked(session_id),
219 },
220 invocation_id: InvocationId::new_unchecked(invocation_id),
221 branch: String::new(),
222 agent_name: agent.name().to_string(),
223 };
224 Self {
225 identity,
226 agent,
227 user_content,
228 artifacts: None,
229 memory: None,
230 run_config: RunConfig::default(),
231 ended: Arc::new(AtomicBool::new(false)),
232 session: Arc::new(MutableSession::new(session)),
233 request_context: None,
234 }
235 }
236
237 pub fn with_mutable_session(
241 invocation_id: String,
242 agent: Arc<dyn Agent>,
243 user_id: String,
244 app_name: String,
245 session_id: String,
246 user_content: Content,
247 session: Arc<MutableSession>,
248 ) -> Self {
249 let identity = ExecutionIdentity {
250 adk: AdkIdentity {
251 app_name: AppName::new_unchecked(app_name),
252 user_id: UserId::new_unchecked(user_id),
253 session_id: SessionId::new_unchecked(session_id),
254 },
255 invocation_id: InvocationId::new_unchecked(invocation_id),
256 branch: String::new(),
257 agent_name: agent.name().to_string(),
258 };
259 Self {
260 identity,
261 agent,
262 user_content,
263 artifacts: None,
264 memory: None,
265 run_config: RunConfig::default(),
266 ended: Arc::new(AtomicBool::new(false)),
267 session,
268 request_context: None,
269 }
270 }
271
272 pub fn with_branch(mut self, branch: String) -> Self {
273 self.identity.branch = branch;
274 self
275 }
276
277 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
278 self.artifacts = Some(artifacts);
279 self
280 }
281
282 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
283 self.memory = Some(memory);
284 self
285 }
286
287 pub fn with_run_config(mut self, config: RunConfig) -> Self {
288 self.run_config = config;
289 self
290 }
291
292 pub fn with_request_context(mut self, ctx: RequestContext) -> Self {
300 self.request_context = Some(ctx);
301 self
302 }
303
304 pub fn mutable_session(&self) -> &Arc<MutableSession> {
307 &self.session
308 }
309}
310
311#[async_trait]
312impl ReadonlyContext for InvocationContext {
313 fn invocation_id(&self) -> &str {
314 self.identity.invocation_id.as_ref()
315 }
316
317 fn agent_name(&self) -> &str {
318 self.agent.name()
319 }
320
321 fn user_id(&self) -> &str {
322 self.request_context.as_ref().map_or(self.identity.adk.user_id.as_ref(), |rc| &rc.user_id)
328 }
329
330 fn app_name(&self) -> &str {
331 self.identity.adk.app_name.as_ref()
332 }
333
334 fn session_id(&self) -> &str {
335 self.identity.adk.session_id.as_ref()
336 }
337
338 fn branch(&self) -> &str {
339 &self.identity.branch
340 }
341
342 fn user_content(&self) -> &Content {
343 &self.user_content
344 }
345}
346
347#[async_trait]
348impl CallbackContext for InvocationContext {
349 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
350 self.artifacts.clone()
351 }
352}
353
354#[async_trait]
355impl InvocationContextTrait for InvocationContext {
356 fn agent(&self) -> Arc<dyn Agent> {
357 self.agent.clone()
358 }
359
360 fn memory(&self) -> Option<Arc<dyn Memory>> {
361 self.memory.clone()
362 }
363
364 fn session(&self) -> &dyn adk_core::Session {
365 self.session.as_ref()
366 }
367
368 fn run_config(&self) -> &RunConfig {
369 &self.run_config
370 }
371
372 fn end_invocation(&self) {
373 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
374 }
375
376 fn ended(&self) -> bool {
377 self.ended.load(std::sync::atomic::Ordering::SeqCst)
378 }
379
380 fn user_scopes(&self) -> Vec<String> {
381 self.request_context.as_ref().map_or_else(Vec::new, |rc| rc.scopes.clone())
382 }
383
384 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
385 self.request_context.as_ref().map_or_else(HashMap::new, |rc| {
386 rc.metadata
387 .iter()
388 .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
389 .collect()
390 })
391 }
392}