copilot_sdk_supercharged/
session.rs1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use serde_json::Value;
15use tokio::sync::{mpsc, Mutex};
16
17use crate::jsonrpc::JsonRpcClient;
18use crate::types::*;
19use crate::CopilotError;
20
21pub type ToolHandler = Arc<
28 dyn Fn(
29 Value,
30 ToolInvocation,
31 ) -> Pin<Box<dyn Future<Output = Result<Value, CopilotError>> + Send>>
32 + Send
33 + Sync,
34>;
35
36pub type PermissionHandlerFn = Arc<
39 dyn Fn(
40 PermissionRequest,
41 String,
42 ) -> Pin<Box<dyn Future<Output = Result<PermissionRequestResult, CopilotError>> + Send>>
43 + Send
44 + Sync,
45>;
46
47pub type UserInputHandlerFn = Arc<
50 dyn Fn(
51 UserInputRequest,
52 String,
53 ) -> Pin<Box<dyn Future<Output = Result<UserInputResponse, CopilotError>> + Send>>
54 + Send
55 + Sync,
56>;
57
58pub type HooksHandlerFn = Arc<
61 dyn Fn(
62 String,
63 Value,
64 String,
65 ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, CopilotError>> + Send>>
66 + Send
67 + Sync,
68>;
69
70pub type SessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
72
73pub type TypedSessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
75
/// Handle for a registered event listener.
///
/// The listener stays registered only while this handle is alive: calling
/// `unsubscribe` or dropping the handle runs the stored removal callback
/// exactly once. Discarding the value returned by a subscribe call therefore
/// unregisters the listener immediately, hence `#[must_use]`.
#[must_use = "dropping a Subscription immediately unsubscribes its handler; bind it to a variable"]
pub struct Subscription {
    /// Removal callback; taken (set to `None`) the first time it runs.
    unsubscribe_fn: Option<Box<dyn FnOnce() + Send>>,
}
85
86impl Subscription {
87 fn new(f: impl FnOnce() + Send + 'static) -> Self {
88 Self {
89 unsubscribe_fn: Some(Box::new(f)),
90 }
91 }
92
93 pub fn unsubscribe(mut self) {
95 if let Some(f) = self.unsubscribe_fn.take() {
96 f();
97 }
98 }
99}
100
101impl Drop for Subscription {
102 fn drop(&mut self) {
103 if let Some(f) = self.unsubscribe_fn.take() {
104 f();
105 }
106 }
107}
108
109pub struct CopilotSession {
146 session_id: String,
148 workspace_path: Option<String>,
150 rpc_client: Arc<JsonRpcClient>,
152 tool_handlers: Arc<Mutex<HashMap<String, ToolHandler>>>,
154 permission_handler: Arc<Mutex<Option<PermissionHandlerFn>>>,
156 user_input_handler: Arc<Mutex<Option<UserInputHandlerFn>>>,
158 hooks_handler: Arc<Mutex<Option<HooksHandlerFn>>>,
160 event_handlers: Arc<Mutex<Vec<(u64, SessionEventHandlerFn)>>>,
162 typed_event_handlers: Arc<Mutex<HashMap<String, Vec<(u64, TypedSessionEventHandlerFn)>>>>,
164 next_handler_id: Arc<Mutex<u64>>,
166}
167
168impl CopilotSession {
169 pub(crate) fn new(
171 session_id: String,
172 rpc_client: Arc<JsonRpcClient>,
173 workspace_path: Option<String>,
174 ) -> Self {
175 Self {
176 session_id,
177 workspace_path,
178 rpc_client,
179 tool_handlers: Arc::new(Mutex::new(HashMap::new())),
180 permission_handler: Arc::new(Mutex::new(None)),
181 user_input_handler: Arc::new(Mutex::new(None)),
182 hooks_handler: Arc::new(Mutex::new(None)),
183 event_handlers: Arc::new(Mutex::new(Vec::new())),
184 typed_event_handlers: Arc::new(Mutex::new(HashMap::new())),
185 next_handler_id: Arc::new(Mutex::new(0)),
186 }
187 }
188
189 pub fn session_id(&self) -> &str {
191 &self.session_id
192 }
193
194 pub fn workspace_path(&self) -> Option<&str> {
196 self.workspace_path.as_deref()
197 }
198
199 pub async fn send(&self, options: MessageOptions) -> Result<String, CopilotError> {
210 let params = serde_json::json!({
211 "sessionId": self.session_id,
212 "prompt": options.prompt,
213 "attachments": options.attachments,
214 "mode": options.mode,
215 });
216
217 let response = self.rpc_client.request("session.send", params, None).await?;
218 let message_id = response
219 .get("messageId")
220 .and_then(|v| v.as_str())
221 .unwrap_or("")
222 .to_string();
223 Ok(message_id)
224 }
225
226 pub async fn send_and_wait(
235 &self,
236 options: MessageOptions,
237 timeout: Option<u64>,
238 ) -> Result<Option<SessionEvent>, CopilotError> {
239 let effective_timeout = timeout.unwrap_or(60_000);
240
241 let (idle_tx, mut idle_rx) = mpsc::channel::<Result<(), CopilotError>>(1);
243 let last_assistant_message: Arc<Mutex<Option<SessionEvent>>> =
244 Arc::new(Mutex::new(None));
245
246 let last_msg_clone = Arc::clone(&last_assistant_message);
247 let idle_tx_clone = idle_tx.clone();
248
249 let sub = self
251 .on(move |event: SessionEvent| {
252 if event.is_assistant_message() {
253 let mut msg = last_msg_clone.blocking_lock();
254 *msg = Some(event);
255 } else if event.is_session_idle() {
256 let _ = idle_tx_clone.try_send(Ok(()));
257 } else if event.is_session_error() {
258 let error_msg = event
259 .error_message()
260 .unwrap_or("Unknown error")
261 .to_string();
262 let _ = idle_tx_clone.try_send(Err(CopilotError::SessionError(error_msg)));
263 }
264 })
265 .await;
266
267 self.send(options).await?;
269
270 let result = tokio::time::timeout(
272 std::time::Duration::from_millis(effective_timeout),
273 idle_rx.recv(),
274 )
275 .await;
276
277 sub.unsubscribe();
279
280 match result {
281 Ok(Some(Ok(()))) => {
282 let msg = last_assistant_message.lock().await;
283 Ok(msg.clone())
284 }
285 Ok(Some(Err(e))) => Err(e),
286 Ok(None) => Err(CopilotError::ConnectionClosed),
287 Err(_) => Err(CopilotError::Timeout(effective_timeout)),
288 }
289 }
290
291 pub async fn on<F>(&self, handler: F) -> Subscription
300 where
301 F: Fn(SessionEvent) + Send + Sync + 'static,
302 {
303 let handler_id = {
304 let mut id = self.next_handler_id.lock().await;
305 *id += 1;
306 *id
307 };
308
309 let handler_arc: SessionEventHandlerFn = Arc::new(handler);
310 {
311 let mut handlers = self.event_handlers.lock().await;
312 handlers.push((handler_id, handler_arc));
313 }
314
315 let event_handlers = Arc::clone(&self.event_handlers);
316 Subscription::new(move || {
317 let mut handlers = event_handlers.blocking_lock();
320 handlers.retain(|(id, _)| *id != handler_id);
321 })
322 }
323
324 pub async fn on_event<F>(&self, event_type: &str, handler: F) -> Subscription
330 where
331 F: Fn(SessionEvent) + Send + Sync + 'static,
332 {
333 let handler_id = {
334 let mut id = self.next_handler_id.lock().await;
335 *id += 1;
336 *id
337 };
338
339 let handler_arc: TypedSessionEventHandlerFn = Arc::new(handler);
340 let event_type_str = event_type.to_string();
341 {
342 let mut handlers = self.typed_event_handlers.lock().await;
343 handlers
344 .entry(event_type_str.clone())
345 .or_default()
346 .push((handler_id, handler_arc));
347 }
348
349 let typed_handlers = Arc::clone(&self.typed_event_handlers);
350 let et = event_type_str;
351 Subscription::new(move || {
352 let mut handlers = typed_handlers.blocking_lock();
353 if let Some(list) = handlers.get_mut(&et) {
354 list.retain(|(id, _)| *id != handler_id);
355 }
356 })
357 }
358
359 pub(crate) async fn dispatch_event(&self, event: SessionEvent) {
365 {
367 let handlers = self.typed_event_handlers.lock().await;
368 if let Some(list) = handlers.get(&event.event_type) {
369 for (_, handler) in list {
370 handler(event.clone());
371 }
372 }
373 }
374
375 {
377 let handlers = self.event_handlers.lock().await;
378 for (_, handler) in handlers.iter() {
379 handler(event.clone());
380 }
381 }
382 }
383
384 pub async fn register_tool(&self, name: &str, handler: ToolHandler) {
390 let mut handlers = self.tool_handlers.lock().await;
391 handlers.insert(name.to_string(), handler);
392 }
393
394 pub async fn register_tools(&self, tools: Vec<(String, ToolHandler)>) {
396 let mut handlers = self.tool_handlers.lock().await;
397 handlers.clear();
398 for (name, handler) in tools {
399 handlers.insert(name, handler);
400 }
401 }
402
403 pub(crate) async fn get_tool_handler(&self, name: &str) -> Option<ToolHandler> {
405 let handlers = self.tool_handlers.lock().await;
406 handlers.get(name).cloned()
407 }
408
409 pub async fn register_permission_handler(&self, handler: PermissionHandlerFn) {
415 let mut h = self.permission_handler.lock().await;
416 *h = Some(handler);
417 }
418
419 pub(crate) async fn handle_permission_request(
421 &self,
422 request: Value,
423 ) -> Result<PermissionRequestResult, CopilotError> {
424 let handler = self.permission_handler.lock().await;
425 if let Some(ref h) = *handler {
426 let perm_request: PermissionRequest = serde_json::from_value(request)
427 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
428 h(perm_request, self.session_id.clone()).await
429 } else {
430 Ok(PermissionRequestResult {
431 kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
432 rules: None,
433 })
434 }
435 }
436
437 pub async fn register_user_input_handler(&self, handler: UserInputHandlerFn) {
443 let mut h = self.user_input_handler.lock().await;
444 *h = Some(handler);
445 }
446
447 pub(crate) async fn handle_user_input_request(
449 &self,
450 request: Value,
451 ) -> Result<UserInputResponse, CopilotError> {
452 let handler = self.user_input_handler.lock().await;
453 if let Some(ref h) = *handler {
454 let input_request: UserInputRequest = serde_json::from_value(request)
455 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
456 h(input_request, self.session_id.clone()).await
457 } else {
458 Err(CopilotError::NoHandler(
459 "User input requested but no handler registered".to_string(),
460 ))
461 }
462 }
463
464 pub async fn register_hooks_handler(&self, handler: HooksHandlerFn) {
470 let mut h = self.hooks_handler.lock().await;
471 *h = Some(handler);
472 }
473
474 pub(crate) async fn handle_hooks_invoke(
476 &self,
477 hook_type: &str,
478 input: Value,
479 ) -> Result<Option<Value>, CopilotError> {
480 let handler = self.hooks_handler.lock().await;
481 if let Some(ref h) = *handler {
482 h(hook_type.to_string(), input, self.session_id.clone()).await
483 } else {
484 Ok(None)
485 }
486 }
487
488 pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, CopilotError> {
494 let params = serde_json::json!({ "sessionId": self.session_id });
495 let response = self
496 .rpc_client
497 .request("session.getMessages", params, None)
498 .await?;
499 let events: Vec<SessionEvent> = serde_json::from_value(
500 response
501 .get("events")
502 .cloned()
503 .unwrap_or(Value::Array(vec![])),
504 )
505 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
506 Ok(events)
507 }
508
509 pub async fn destroy(&self) -> Result<(), CopilotError> {
513 let params = serde_json::json!({ "sessionId": self.session_id });
514 self.rpc_client
515 .request("session.destroy", params, None)
516 .await?;
517
518 {
520 let mut handlers = self.event_handlers.lock().await;
521 handlers.clear();
522 }
523 {
524 let mut handlers = self.typed_event_handlers.lock().await;
525 handlers.clear();
526 }
527 {
528 let mut handlers = self.tool_handlers.lock().await;
529 handlers.clear();
530 }
531 {
532 let mut handler = self.permission_handler.lock().await;
533 *handler = None;
534 }
535 {
536 let mut handler = self.user_input_handler.lock().await;
537 *handler = None;
538 }
539
540 Ok(())
541 }
542
543 pub async fn abort(&self) -> Result<(), CopilotError> {
545 let params = serde_json::json!({ "sessionId": self.session_id });
546 self.rpc_client
547 .request("session.abort", params, None)
548 .await?;
549 Ok(())
550 }
551}