copilot_sdk_supercharged/
session.rs1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use serde_json::Value;
15use tokio::sync::{mpsc, Mutex};
16
17use crate::jsonrpc::JsonRpcClient;
18use crate::types::*;
19use crate::CopilotError;
20
21pub type ToolHandler = Arc<
28 dyn Fn(
29 Value,
30 ToolInvocation,
31 ) -> Pin<Box<dyn Future<Output = Result<Value, CopilotError>> + Send>>
32 + Send
33 + Sync,
34>;
35
36pub type PermissionHandlerFn = Arc<
39 dyn Fn(
40 PermissionRequest,
41 String,
42 ) -> Pin<Box<dyn Future<Output = Result<PermissionRequestResult, CopilotError>> + Send>>
43 + Send
44 + Sync,
45>;
46
47pub type UserInputHandlerFn = Arc<
50 dyn Fn(
51 UserInputRequest,
52 String,
53 ) -> Pin<Box<dyn Future<Output = Result<UserInputResponse, CopilotError>> + Send>>
54 + Send
55 + Sync,
56>;
57
58pub type HooksHandlerFn = Arc<
61 dyn Fn(
62 String,
63 Value,
64 String,
65 ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, CopilotError>> + Send>>
66 + Send
67 + Sync,
68>;
69
70pub type SessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
72
73pub type TypedSessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
75
/// Handle for an event-handler registration on a `CopilotSession`.
///
/// The registration is removed when `unsubscribe` is called or when the handle is
/// dropped, whichever comes first; the removal callback is taken before it runs, so it
/// runs at most once. Because dropping removes the handler, discarding the handle
/// (e.g. `session.on(h).await;`) unregisters the handler immediately — hence `#[must_use]`.
#[must_use = "dropping a Subscription immediately unregisters its handler; bind it to a variable"]
pub struct Subscription {
    /// Removal callback; `None` once it has been taken and run.
    unsubscribe_fn: Option<Box<dyn FnOnce() + Send>>,
}
85
86impl Subscription {
87 fn new(f: impl FnOnce() + Send + 'static) -> Self {
88 Self {
89 unsubscribe_fn: Some(Box::new(f)),
90 }
91 }
92
93 pub fn unsubscribe(mut self) {
95 if let Some(f) = self.unsubscribe_fn.take() {
96 f();
97 }
98 }
99}
100
101impl Drop for Subscription {
102 fn drop(&mut self) {
103 if let Some(f) = self.unsubscribe_fn.take() {
104 f();
105 }
106 }
107}
108
109pub struct CopilotSession {
146 session_id: String,
148 workspace_path: Option<String>,
150 rpc_client: Arc<JsonRpcClient>,
152 tool_handlers: Arc<Mutex<HashMap<String, ToolHandler>>>,
154 permission_handler: Arc<Mutex<Option<PermissionHandlerFn>>>,
156 user_input_handler: Arc<Mutex<Option<UserInputHandlerFn>>>,
158 hooks_handler: Arc<Mutex<Option<HooksHandlerFn>>>,
160 event_handlers: Arc<Mutex<Vec<(u64, SessionEventHandlerFn)>>>,
162 typed_event_handlers: Arc<Mutex<HashMap<String, Vec<(u64, TypedSessionEventHandlerFn)>>>>,
164 next_handler_id: Arc<Mutex<u64>>,
166}
167
168impl CopilotSession {
169 pub(crate) fn new(
171 session_id: String,
172 rpc_client: Arc<JsonRpcClient>,
173 workspace_path: Option<String>,
174 ) -> Self {
175 Self {
176 session_id,
177 workspace_path,
178 rpc_client,
179 tool_handlers: Arc::new(Mutex::new(HashMap::new())),
180 permission_handler: Arc::new(Mutex::new(None)),
181 user_input_handler: Arc::new(Mutex::new(None)),
182 hooks_handler: Arc::new(Mutex::new(None)),
183 event_handlers: Arc::new(Mutex::new(Vec::new())),
184 typed_event_handlers: Arc::new(Mutex::new(HashMap::new())),
185 next_handler_id: Arc::new(Mutex::new(0)),
186 }
187 }
188
189 pub fn session_id(&self) -> &str {
191 &self.session_id
192 }
193
194 pub fn workspace_path(&self) -> Option<&str> {
196 self.workspace_path.as_deref()
197 }
198
199 pub async fn send(&self, options: MessageOptions) -> Result<String, CopilotError> {
210 let params = serde_json::json!({
211 "sessionId": self.session_id,
212 "prompt": options.prompt,
213 "attachments": options.attachments,
214 "mode": options.mode,
215 "responseFormat": options.response_format,
216 "imageOptions": options.image_options,
217 });
218
219 let response = self.rpc_client.request("session.send", params, None).await?;
220 let message_id = response
221 .get("messageId")
222 .and_then(|v| v.as_str())
223 .unwrap_or("")
224 .to_string();
225 Ok(message_id)
226 }
227
228 pub async fn send_and_wait(
237 &self,
238 options: MessageOptions,
239 timeout: Option<u64>,
240 ) -> Result<Option<SessionEvent>, CopilotError> {
241 let effective_timeout = timeout.unwrap_or(60_000);
242
243 let (idle_tx, mut idle_rx) = mpsc::channel::<Result<(), CopilotError>>(1);
245 let last_assistant_message: Arc<Mutex<Option<SessionEvent>>> =
246 Arc::new(Mutex::new(None));
247
248 let last_msg_clone = Arc::clone(&last_assistant_message);
249 let idle_tx_clone = idle_tx.clone();
250
251 let sub = self
253 .on(move |event: SessionEvent| {
254 if event.is_assistant_message() {
255 let mut msg = last_msg_clone.blocking_lock();
256 *msg = Some(event);
257 } else if event.is_session_idle() {
258 let _ = idle_tx_clone.try_send(Ok(()));
259 } else if event.is_session_error() {
260 let error_msg = event
261 .error_message()
262 .unwrap_or("Unknown error")
263 .to_string();
264 let _ = idle_tx_clone.try_send(Err(CopilotError::SessionError(error_msg)));
265 }
266 })
267 .await;
268
269 self.send(options).await?;
271
272 let result = tokio::time::timeout(
274 std::time::Duration::from_millis(effective_timeout),
275 idle_rx.recv(),
276 )
277 .await;
278
279 sub.unsubscribe();
281
282 match result {
283 Ok(Some(Ok(()))) => {
284 let msg = last_assistant_message.lock().await;
285 Ok(msg.clone())
286 }
287 Ok(Some(Err(e))) => Err(e),
288 Ok(None) => Err(CopilotError::ConnectionClosed),
289 Err(_) => Err(CopilotError::Timeout(effective_timeout)),
290 }
291 }
292
293 pub async fn on<F>(&self, handler: F) -> Subscription
302 where
303 F: Fn(SessionEvent) + Send + Sync + 'static,
304 {
305 let handler_id = {
306 let mut id = self.next_handler_id.lock().await;
307 *id += 1;
308 *id
309 };
310
311 let handler_arc: SessionEventHandlerFn = Arc::new(handler);
312 {
313 let mut handlers = self.event_handlers.lock().await;
314 handlers.push((handler_id, handler_arc));
315 }
316
317 let event_handlers = Arc::clone(&self.event_handlers);
318 Subscription::new(move || {
319 let mut handlers = event_handlers.blocking_lock();
322 handlers.retain(|(id, _)| *id != handler_id);
323 })
324 }
325
326 pub async fn on_event<F>(&self, event_type: &str, handler: F) -> Subscription
332 where
333 F: Fn(SessionEvent) + Send + Sync + 'static,
334 {
335 let handler_id = {
336 let mut id = self.next_handler_id.lock().await;
337 *id += 1;
338 *id
339 };
340
341 let handler_arc: TypedSessionEventHandlerFn = Arc::new(handler);
342 let event_type_str = event_type.to_string();
343 {
344 let mut handlers = self.typed_event_handlers.lock().await;
345 handlers
346 .entry(event_type_str.clone())
347 .or_default()
348 .push((handler_id, handler_arc));
349 }
350
351 let typed_handlers = Arc::clone(&self.typed_event_handlers);
352 let et = event_type_str;
353 Subscription::new(move || {
354 let mut handlers = typed_handlers.blocking_lock();
355 if let Some(list) = handlers.get_mut(&et) {
356 list.retain(|(id, _)| *id != handler_id);
357 }
358 })
359 }
360
361 pub(crate) async fn dispatch_event(&self, event: SessionEvent) {
367 {
369 let handlers = self.typed_event_handlers.lock().await;
370 if let Some(list) = handlers.get(&event.event_type) {
371 for (_, handler) in list {
372 handler(event.clone());
373 }
374 }
375 }
376
377 {
379 let handlers = self.event_handlers.lock().await;
380 for (_, handler) in handlers.iter() {
381 handler(event.clone());
382 }
383 }
384 }
385
386 pub async fn register_tool(&self, name: &str, handler: ToolHandler) {
392 let mut handlers = self.tool_handlers.lock().await;
393 handlers.insert(name.to_string(), handler);
394 }
395
396 pub async fn register_tools(&self, tools: Vec<(String, ToolHandler)>) {
398 let mut handlers = self.tool_handlers.lock().await;
399 handlers.clear();
400 for (name, handler) in tools {
401 handlers.insert(name, handler);
402 }
403 }
404
405 pub(crate) async fn get_tool_handler(&self, name: &str) -> Option<ToolHandler> {
407 let handlers = self.tool_handlers.lock().await;
408 handlers.get(name).cloned()
409 }
410
411 pub async fn register_permission_handler(&self, handler: PermissionHandlerFn) {
417 let mut h = self.permission_handler.lock().await;
418 *h = Some(handler);
419 }
420
421 pub(crate) async fn handle_permission_request(
423 &self,
424 request: Value,
425 ) -> Result<PermissionRequestResult, CopilotError> {
426 let handler = self.permission_handler.lock().await;
427 if let Some(ref h) = *handler {
428 let perm_request: PermissionRequest = serde_json::from_value(request)
429 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
430 h(perm_request, self.session_id.clone()).await
431 } else {
432 Ok(PermissionRequestResult {
433 kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
434 rules: None,
435 })
436 }
437 }
438
439 pub async fn register_user_input_handler(&self, handler: UserInputHandlerFn) {
445 let mut h = self.user_input_handler.lock().await;
446 *h = Some(handler);
447 }
448
449 pub(crate) async fn handle_user_input_request(
451 &self,
452 request: Value,
453 ) -> Result<UserInputResponse, CopilotError> {
454 let handler = self.user_input_handler.lock().await;
455 if let Some(ref h) = *handler {
456 let input_request: UserInputRequest = serde_json::from_value(request)
457 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
458 h(input_request, self.session_id.clone()).await
459 } else {
460 Err(CopilotError::NoHandler(
461 "User input requested but no handler registered".to_string(),
462 ))
463 }
464 }
465
466 pub async fn register_hooks_handler(&self, handler: HooksHandlerFn) {
472 let mut h = self.hooks_handler.lock().await;
473 *h = Some(handler);
474 }
475
476 pub(crate) async fn handle_hooks_invoke(
478 &self,
479 hook_type: &str,
480 input: Value,
481 ) -> Result<Option<Value>, CopilotError> {
482 let handler = self.hooks_handler.lock().await;
483 if let Some(ref h) = *handler {
484 h(hook_type.to_string(), input, self.session_id.clone()).await
485 } else {
486 Ok(None)
487 }
488 }
489
490 pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, CopilotError> {
496 let params = serde_json::json!({ "sessionId": self.session_id });
497 let response = self
498 .rpc_client
499 .request("session.getMessages", params, None)
500 .await?;
501 let events: Vec<SessionEvent> = serde_json::from_value(
502 response
503 .get("events")
504 .cloned()
505 .unwrap_or(Value::Array(vec![])),
506 )
507 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
508 Ok(events)
509 }
510
511 pub async fn get_metadata(&self) -> Result<Value, CopilotError> {
513 let params = serde_json::json!({ "sessionId": self.session_id });
514 let response = self
515 .rpc_client
516 .request("session.getMetadata", params, None)
517 .await?;
518 Ok(response)
519 }
520
521 pub async fn destroy(&self) -> Result<(), CopilotError> {
525 let params = serde_json::json!({ "sessionId": self.session_id });
526 self.rpc_client
527 .request("session.destroy", params, None)
528 .await?;
529
530 {
532 let mut handlers = self.event_handlers.lock().await;
533 handlers.clear();
534 }
535 {
536 let mut handlers = self.typed_event_handlers.lock().await;
537 handlers.clear();
538 }
539 {
540 let mut handlers = self.tool_handlers.lock().await;
541 handlers.clear();
542 }
543 {
544 let mut handler = self.permission_handler.lock().await;
545 *handler = None;
546 }
547 {
548 let mut handler = self.user_input_handler.lock().await;
549 *handler = None;
550 }
551
552 Ok(())
553 }
554
555 pub async fn abort(&self) -> Result<(), CopilotError> {
557 let params = serde_json::json!({ "sessionId": self.session_id });
558 self.rpc_client
559 .request("session.abort", params, None)
560 .await?;
561 Ok(())
562 }
563}