1use adk_core::{
2 Agent, Artifacts, CallbackContext, Content, Event, InvocationContext as InvocationContextTrait,
3 Memory, ReadonlyContext, RunConfig,
4};
5use adk_session::Session as AdkSession;
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{atomic::AtomicBool, Arc, RwLock};
9
10pub struct MutableSession {
17 inner: Arc<dyn AdkSession>,
19 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
22 events: Arc<RwLock<Vec<Event>>>,
24}
25
26impl MutableSession {
27 pub fn new(session: Arc<dyn AdkSession>) -> Self {
30 let initial_state = session.state().all();
32 let initial_events = session.events().all();
34
35 Self {
36 inner: session,
37 state: Arc::new(RwLock::new(initial_state)),
38 events: Arc::new(RwLock::new(initial_events)),
39 }
40 }
41
42 pub fn apply_state_delta(&self, delta: &HashMap<String, serde_json::Value>) {
45 if delta.is_empty() {
46 return;
47 }
48
49 let mut state = self.state.write().unwrap();
50 for (key, value) in delta {
51 if !key.starts_with("temp:") {
53 state.insert(key.clone(), value.clone());
54 }
55 }
56 }
57
58 pub fn append_event(&self, event: Event) {
61 let mut events = self.events.write().unwrap();
62 events.push(event);
63 }
64}
65
66impl adk_core::Session for MutableSession {
67 fn id(&self) -> &str {
68 self.inner.id()
69 }
70
71 fn app_name(&self) -> &str {
72 self.inner.app_name()
73 }
74
75 fn user_id(&self) -> &str {
76 self.inner.user_id()
77 }
78
79 fn state(&self) -> &dyn adk_core::State {
80 unsafe { &*(self as *const Self as *const dyn adk_core::State) }
83 }
84
85 fn conversation_history(&self) -> Vec<adk_core::Content> {
86 let events = self.events.read().unwrap();
87 let mut history = Vec::new();
88
89 for event in events.iter() {
90 if let Some(content) = &event.llm_response.content {
91 let role = match event.author.as_str() {
92 "user" => "user".to_string(),
93 _ => "model".to_string(),
94 };
95
96 let mut mapped_content = content.clone();
97 mapped_content.role = role;
98 history.push(mapped_content);
99 }
100 }
101
102 history
103 }
104}
105
106impl adk_core::State for MutableSession {
107 fn get(&self, key: &str) -> Option<serde_json::Value> {
108 let state = self.state.read().unwrap();
109 state.get(key).cloned()
110 }
111
112 fn set(&mut self, key: String, value: serde_json::Value) {
113 let mut state = self.state.write().unwrap();
114 state.insert(key, value);
115 }
116
117 fn all(&self) -> HashMap<String, serde_json::Value> {
118 let state = self.state.read().unwrap();
119 state.clone()
120 }
121}
122
123pub struct InvocationContext {
124 invocation_id: String,
125 agent: Arc<dyn Agent>,
126 user_id: String,
127 app_name: String,
128 session_id: String,
129 branch: String,
130 user_content: Content,
131 artifacts: Option<Arc<dyn Artifacts>>,
132 memory: Option<Arc<dyn Memory>>,
133 run_config: RunConfig,
134 ended: Arc<AtomicBool>,
135 session: Arc<MutableSession>,
139}
140
141impl InvocationContext {
142 pub fn new(
143 invocation_id: String,
144 agent: Arc<dyn Agent>,
145 user_id: String,
146 app_name: String,
147 session_id: String,
148 user_content: Content,
149 session: Arc<dyn AdkSession>,
150 ) -> Self {
151 Self {
152 invocation_id,
153 agent,
154 user_id,
155 app_name,
156 session_id,
157 branch: String::new(),
158 user_content,
159 artifacts: None,
160 memory: None,
161 run_config: RunConfig::default(),
162 ended: Arc::new(AtomicBool::new(false)),
163 session: Arc::new(MutableSession::new(session)),
164 }
165 }
166
167 pub fn with_mutable_session(
171 invocation_id: String,
172 agent: Arc<dyn Agent>,
173 user_id: String,
174 app_name: String,
175 session_id: String,
176 user_content: Content,
177 session: Arc<MutableSession>,
178 ) -> Self {
179 Self {
180 invocation_id,
181 agent,
182 user_id,
183 app_name,
184 session_id,
185 branch: String::new(),
186 user_content,
187 artifacts: None,
188 memory: None,
189 run_config: RunConfig::default(),
190 ended: Arc::new(AtomicBool::new(false)),
191 session,
192 }
193 }
194
195 pub fn with_branch(mut self, branch: String) -> Self {
196 self.branch = branch;
197 self
198 }
199
200 pub fn with_artifacts(mut self, artifacts: Arc<dyn Artifacts>) -> Self {
201 self.artifacts = Some(artifacts);
202 self
203 }
204
205 pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
206 self.memory = Some(memory);
207 self
208 }
209
210 pub fn with_run_config(mut self, config: RunConfig) -> Self {
211 self.run_config = config;
212 self
213 }
214
215 pub fn mutable_session(&self) -> &Arc<MutableSession> {
218 &self.session
219 }
220}
221
222#[async_trait]
223impl ReadonlyContext for InvocationContext {
224 fn invocation_id(&self) -> &str {
225 &self.invocation_id
226 }
227
228 fn agent_name(&self) -> &str {
229 self.agent.name()
230 }
231
232 fn user_id(&self) -> &str {
233 &self.user_id
234 }
235
236 fn app_name(&self) -> &str {
237 &self.app_name
238 }
239
240 fn session_id(&self) -> &str {
241 &self.session_id
242 }
243
244 fn branch(&self) -> &str {
245 &self.branch
246 }
247
248 fn user_content(&self) -> &Content {
249 &self.user_content
250 }
251}
252
253#[async_trait]
254impl CallbackContext for InvocationContext {
255 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
256 self.artifacts.clone()
257 }
258}
259
260#[async_trait]
261impl InvocationContextTrait for InvocationContext {
262 fn agent(&self) -> Arc<dyn Agent> {
263 self.agent.clone()
264 }
265
266 fn memory(&self) -> Option<Arc<dyn Memory>> {
267 self.memory.clone()
268 }
269
270 fn session(&self) -> &dyn adk_core::Session {
271 self.session.as_ref()
272 }
273
274 fn run_config(&self) -> &RunConfig {
275 &self.run_config
276 }
277
278 fn end_invocation(&self) {
279 self.ended.store(true, std::sync::atomic::Ordering::SeqCst);
280 }
281
282 fn ended(&self) -> bool {
283 self.ended.load(std::sync::atomic::Ordering::SeqCst)
284 }
285}