1use adk_core::{
2 Agent, Artifacts, CallbackContext, Content, Event, InvocationContext as InvocationContextTrait,
3 Memory, ReadonlyContext, RunConfig,
4};
5use adk_session::Session as AdkSession;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock, atomic::AtomicBool};
9
/// Invocation-local view of a session whose state and event log can be
/// changed while an invocation runs.
///
/// Identity (`id`, `app_name`, `user_id`) is delegated to the wrapped session.
/// The state map and the event list are copied out of the wrapped session when
/// the wrapper is created, and are then read and written through this wrapper.
/// NOTE(review): nothing visible in this file writes those changes back to
/// `inner` — confirm where persistence happens.
pub struct MutableSession {
    /// The wrapped session; used for identity lookups.
    inner: Arc<dyn AdkSession>,
    /// Working copy of the session state, seeded from `inner.state().all()`.
    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Working copy of the event log, seeded from `inner.events().all()`.
    events: Arc<RwLock<Vec<Event>>>,
}
25
26impl MutableSession {
27 pub fn new(session: Arc<dyn AdkSession>) -> Self {
30 let initial_state = session.state().all();
32 let initial_events = session.events().all();
34
35 Self {
36 inner: session,
37 state: Arc::new(RwLock::new(initial_state)),
38 events: Arc::new(RwLock::new(initial_events)),
39 }
40 }
41
42 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
45 if delta.is_empty() {
46 return;
47 }
48
49 let mut state = self.state.write().unwrap();
50 for (key, value) in delta {
51 if !key.starts_with("temp:") {
53 state.insert(key.clone(), value.clone());
54 }
55 }
56 }
57
58 pub fn append_event(&self, event: Event) {
61 let mut events = self.events.write().unwrap();
62 events.push(event);
63 }
64
65 pub fn events_snapshot(&self) -> Vec<Event> {
68 let events = self.events.read().unwrap();
69 events.clone()
70 }
71}
72
73impl adk_core::Session for MutableSession {
74 fn id(&self) -> &str {
75 self.inner.id()
76 }
77
78 fn app_name(&self) -> &str {
79 self.inner.app_name()
80 }
81
82 fn user_id(&self) -> &str {
83 self.inner.user_id()
84 }
85
86 fn state(&self) -> &dyn adk_core::State {
87 unsafe { &*(self as *const Self as *const dyn adk_core::State) }
90 }
91
92 fn conversation_history(&self) -> Vec<adk_core::Content> {
93 let events = self.events.read().unwrap();
94 let mut history = Vec::new();
95
96 let mut compaction_boundary = None;
100 for event in events.iter().rev() {
101 if let Some(ref compaction) = event.actions.compaction {
102 history.push(compaction.compacted_content.clone());
104 compaction_boundary = Some(compaction.end_timestamp);
105 break;
106 }
107 }
108
109 for event in events.iter() {
110 if event.actions.compaction.is_some() {
112 continue;
113 }
114
115 if let Some(boundary) = compaction_boundary {
117 if event.timestamp <= boundary {
118 continue;
119 }
120 }
121
122 if let Some(content) = &event.llm_response.content {
123 let role = match event.author.as_str() {
124 "user" => "user".to_string(),
125 _ => "model".to_string(),
126 };
127
128 let mut mapped_content = content.clone();
129 mapped_content.role = role;
130 history.push(mapped_content);
131 }
132 }
133
134 history
135 }
136}
137
138impl adk_core::State for MutableSession {
139 fn get(&self, key: &str) -> Option<serde_json::Value> {
140 let state = self.state.read().unwrap();
141 state.get(key).cloned()
142 }
143
144 fn set(&mut self, key: String, value: serde_json::Value) {
145 let mut state = self.state.write().unwrap();
146 state.insert(key, value);
147 }
148
149 fn all(&self) -> HashMap<String, serde_json::Value> {
150 let state = self.state.read().unwrap();
151 state.clone()
152 }
153}
154
/// Concrete invocation context handed to an agent for a single run.
///
/// Bundles the invocation's identifiers, the agent being run, the user message
/// that started it, optional artifact and memory services, the run
/// configuration, an "ended" flag, and the invocation's mutable session.
pub struct InvocationContext {
    /// Identifier of this invocation.
    invocation_id: String,
    /// Agent being run.
    agent: Arc<dyn Agent>,
    /// User id reported through `ReadonlyContext`.
    user_id: String,
    /// Application name reported through `ReadonlyContext`.
    app_name: String,
    /// Session id reported through `ReadonlyContext`.
    /// NOTE(review): stored separately from `session`; not checked against `session.id()`.
    session_id: String,
    /// Branch label; set via `with_branch`.
    branch: String,
    /// The user message that started this invocation.
    user_content: Content,
    /// Optional artifact service; set via `with_artifacts`.
    artifacts: Option<Arc<dyn Artifacts>>,
    /// Optional memory service; set via `with_memory`.
    memory: Option<Arc<dyn Memory>>,
    /// Run configuration; set via `with_run_config`.
    run_config: RunConfig,
    /// Set by `end_invocation` and read by `ended`.
    ended: Arc<AtomicBool>,
    /// Session whose state and events can change during the invocation.
    session: Arc<MutableSession>,
}
172
173impl InvocationContext {
174 pub fn new(
175 invocation_id: String,
176 agent: Arc<dyn Agent>,
177 user_id: String,
178 app_name: String,
179 session_id: String,
180 user_content: Content,
181 session: Arc<dyn AdkSession>,
182 ) -> Self {
183 Self {
184 invocation_id,
185 agent,
186 user_id,
187 app_name,
188 session_id,
189 branch: String::new(),
190 user_content,
191 artifacts: None,
192 memory: None,
193 run_config: RunConfig::default(),
194 ended: Arc::new(AtomicBool::new(false)),
195 session: Arc::new(MutableSession::new(session)),
196 }
197 }
198
199 pub fn with_mutable_session(
203 invocation_id: String,
204 agent: Arc<dyn Agent>,
205 user_id: String,
206 app_name: String,
207 session_id: String,
208 user_content: Content,
209 session: Arc<MutableSession>,
210 ) -> Self {
211 Self {
212 invocation_id,
213 agent,
214 user_id,
215 app_name,
216 session_id,
217 branch: String::new(),
218 user_content,
219 artifacts: None,
220 memory: None,
221 run_config: RunConfig::default(),
222 ended: Arc::new(AtomicBool::new(false)),
223 session,
224 }
225 }
226
227 pub fn with_branch(mut self, branch: String) -> Self {
228 self.branch = branch;
229 self
230 }
231
232 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
233 self.artifacts = Some(artifacts);
234 self
235 }
236
237 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
238 self.memory = Some(memory);
239 self
240 }
241
242 pub fn with_run_config(mut self, config: RunConfig) -> Self {
243 self.run_config = config;
244 self
245 }
246
247 pub fn mutable_session(&self) -> &Arc<MutableSession> {
250 &self.session
251 }
252}
253
254#[async_trait]
255impl ReadonlyContext for InvocationContext {
256 fn invocation_id(&self) -> &str {
257 &self.invocation_id
258 }
259
260 fn agent_name(&self) -> &str {
261 self.agent.name()
262 }
263
264 fn user_id(&self) -> &str {
265 &self.user_id
266 }
267
268 fn app_name(&self) -> &str {
269 &self.app_name
270 }
271
272 fn session_id(&self) -> &str {
273 &self.session_id
274 }
275
276 fn branch(&self) -> &str {
277 &self.branch
278 }
279
280 fn user_content(&self) -> &Content {
281 &self.user_content
282 }
283}
284
285#[async_trait]
286impl CallbackContext for InvocationContext {
287 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
288 self.artifacts.clone()
289 }
290}
291
292#[async_trait]
293impl InvocationContextTrait for InvocationContext {
294 fn agent(&self) -> Arc<dyn Agent> {
295 self.agent.clone()
296 }
297
298 fn memory(&self) -> Option<Arc<dyn Memory>> {
299 self.memory.clone()
300 }
301
302 fn session(&self) -> &dyn adk_core::Session {
303 self.session.as_ref()
304 }
305
306 fn run_config(&self) -> &RunConfig {
307 &self.run_config
308 }
309
310 fn end_invocation(&self) {
311 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
312 }
313
314 fn ended(&self) -> bool {
315 self.ended.load(std::sync::atomic::Ordering::SeqCst)
316 }
317}