Skip to main content

copilot_sdk_supercharged/
session.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2
3//! CopilotSession - represents a single conversation session with the Copilot CLI.
4//!
5//! Sessions are created via [`CopilotClient::create_session`] or resumed via
6//! [`CopilotClient::resume_session`]. They maintain conversation state, handle events,
7//! and manage tool execution.
8
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::{PoisonError, RwLock};

use serde_json::Value;
use tokio::sync::{mpsc, Mutex};

use crate::jsonrpc::JsonRpcClient;
use crate::types::*;
use crate::CopilotError;
20
21// ============================================================================
22// Handler Type Aliases
23// ============================================================================
24
25/// A tool handler is an async function that takes arguments and invocation info,
26/// and returns a result value.
27pub type ToolHandler = Arc<
28    dyn Fn(
29            Value,
30            ToolInvocation,
31        ) -> Pin<Box<dyn Future<Output = Result<Value, CopilotError>> + Send>>
32        + Send
33        + Sync,
34>;
35
36/// A permission handler is an async function that takes a permission request
37/// and returns a permission result.
38pub type PermissionHandlerFn = Arc<
39    dyn Fn(
40            PermissionRequest,
41            String,
42        ) -> Pin<Box<dyn Future<Output = Result<PermissionRequestResult, CopilotError>> + Send>>
43        + Send
44        + Sync,
45>;
46
47/// A user input handler is an async function that takes a user input request
48/// and returns a response.
49pub type UserInputHandlerFn = Arc<
50    dyn Fn(
51            UserInputRequest,
52            String,
53        ) -> Pin<Box<dyn Future<Output = Result<UserInputResponse, CopilotError>> + Send>>
54        + Send
55        + Sync,
56>;
57
58/// A hooks handler is an async function that takes hook type, input, and session ID,
59/// and returns optional output.
60pub type HooksHandlerFn = Arc<
61    dyn Fn(
62            String,
63            Value,
64            String,
65        ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, CopilotError>> + Send>>
66        + Send
67        + Sync,
68>;
69
70/// A session event handler callback.
71pub type SessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
72
73/// A typed session event handler for a specific event type.
74pub type TypedSessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
75
76// ============================================================================
77// Unsubscribe Guard
78// ============================================================================
79
/// An RAII guard that keeps an event handler registered while it is alive.
///
/// Dropping the guard unsubscribes the handler; [`Subscription::unsubscribe`]
/// does the same explicitly. The unsubscribe action runs at most once.
///
/// Because dropping has a side effect, ignoring the returned guard (e.g.
/// `session.on(..).await;`) would unsubscribe immediately, hence `#[must_use]`.
#[must_use = "dropping a Subscription immediately unsubscribes its handler"]
pub struct Subscription {
    /// Deferred unsubscribe action; `None` once it has run.
    unsubscribe_fn: Option<Box<dyn FnOnce() + Send>>,
}

impl Subscription {
    /// Wraps `f` so that it runs exactly once: on `unsubscribe()` or on drop.
    fn new(f: impl FnOnce() + Send + 'static) -> Self {
        Self {
            unsubscribe_fn: Some(Box::new(f)),
        }
    }

    /// Explicitly unsubscribes this handler.
    ///
    /// Consumes the guard; the subsequent `Drop` sees nothing left to run.
    pub fn unsubscribe(mut self) {
        self.run_unsubscribe();
    }

    /// Runs the unsubscribe action if it has not run yet.
    fn run_unsubscribe(&mut self) {
        if let Some(f) = self.unsubscribe_fn.take() {
            f();
        }
    }
}

impl std::fmt::Debug for Subscription {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Subscription")
            .field("active", &self.unsubscribe_fn.is_some())
            .finish()
    }
}

impl Drop for Subscription {
    fn drop(&mut self) {
        self.run_unsubscribe();
    }
}
108
109// ============================================================================
110// CopilotSession
111// ============================================================================
112
113/// Represents a single conversation session with the Copilot CLI.
114///
115/// A session maintains conversation state, handles events, and manages tool execution.
116///
117/// # Examples
118///
119/// ```rust,no_run
120/// # use copilot_sdk::*;
121/// # async fn example() -> Result<(), CopilotError> {
122/// let client = CopilotClient::new(CopilotClientOptions::default());
123/// let session = client.create_session(SessionConfig::default()).await?;
124///
125/// // Subscribe to events
126/// let sub = session.on(|event| {
127///     if event.is_assistant_message() {
128///         if let Some(content) = event.assistant_message_content() {
129///             println!("Assistant: {}", content);
130///         }
131///     }
132/// }).await;
133///
134/// // Send a message and wait for completion
135/// let response = session.send_and_wait(
136///     MessageOptions { prompt: "Hello!".into(), attachments: None, mode: None },
137///     None,
138/// ).await?;
139///
140/// // Clean up
141/// session.destroy().await?;
142/// # Ok(())
143/// # }
144/// ```
145pub struct CopilotSession {
146    /// The unique session ID.
147    session_id: String,
148    /// Path to the session workspace directory (when infinite sessions are enabled).
149    workspace_path: Option<String>,
150    /// Reference to the JSON-RPC client.
151    rpc_client: Arc<JsonRpcClient>,
152    /// Registered tool handlers, keyed by tool name.
153    tool_handlers: Arc<Mutex<HashMap<String, ToolHandler>>>,
154    /// Permission request handler.
155    permission_handler: Arc<Mutex<Option<PermissionHandlerFn>>>,
156    /// User input request handler.
157    user_input_handler: Arc<Mutex<Option<UserInputHandlerFn>>>,
158    /// Hooks handler.
159    hooks_handler: Arc<Mutex<Option<HooksHandlerFn>>>,
160    /// Wildcard event handlers.
161    event_handlers: Arc<Mutex<Vec<(u64, SessionEventHandlerFn)>>>,
162    /// Typed event handlers, keyed by event type string.
163    typed_event_handlers: Arc<Mutex<HashMap<String, Vec<(u64, TypedSessionEventHandlerFn)>>>>,
164    /// Counter for handler IDs (for unsubscribe).
165    next_handler_id: Arc<Mutex<u64>>,
166}
167
168impl CopilotSession {
169    /// Creates a new CopilotSession. This is called internally by CopilotClient.
170    pub(crate) fn new(
171        session_id: String,
172        rpc_client: Arc<JsonRpcClient>,
173        workspace_path: Option<String>,
174    ) -> Self {
175        Self {
176            session_id,
177            workspace_path,
178            rpc_client,
179            tool_handlers: Arc::new(Mutex::new(HashMap::new())),
180            permission_handler: Arc::new(Mutex::new(None)),
181            user_input_handler: Arc::new(Mutex::new(None)),
182            hooks_handler: Arc::new(Mutex::new(None)),
183            event_handlers: Arc::new(Mutex::new(Vec::new())),
184            typed_event_handlers: Arc::new(Mutex::new(HashMap::new())),
185            next_handler_id: Arc::new(Mutex::new(0)),
186        }
187    }
188
189    /// Returns the session ID.
190    pub fn session_id(&self) -> &str {
191        &self.session_id
192    }
193
194    /// Returns the workspace path (when infinite sessions are enabled).
195    pub fn workspace_path(&self) -> Option<&str> {
196        self.workspace_path.as_deref()
197    }
198
199    // ========================================================================
200    // Message Sending
201    // ========================================================================
202
203    /// Sends a message to this session.
204    ///
205    /// The message is processed asynchronously. Subscribe to events via `on()`
206    /// to receive streaming responses and other session events.
207    ///
208    /// Returns the message ID.
209    pub async fn send(&self, options: MessageOptions) -> Result<String, CopilotError> {
210        let params = serde_json::json!({
211            "sessionId": self.session_id,
212            "prompt": options.prompt,
213            "attachments": options.attachments,
214            "mode": options.mode,
215        });
216
217        let response = self.rpc_client.request("session.send", params, None).await?;
218        let message_id = response
219            .get("messageId")
220            .and_then(|v| v.as_str())
221            .unwrap_or("")
222            .to_string();
223        Ok(message_id)
224    }
225
226    /// Sends a message and waits until the session becomes idle.
227    ///
228    /// This combines `send()` with waiting for the `session.idle` event.
229    /// Returns the last `assistant.message` event received, or `None`.
230    ///
231    /// # Arguments
232    /// * `options` - The message options
233    /// * `timeout` - Optional timeout in milliseconds (defaults to 60000)
234    pub async fn send_and_wait(
235        &self,
236        options: MessageOptions,
237        timeout: Option<u64>,
238    ) -> Result<Option<SessionEvent>, CopilotError> {
239        let effective_timeout = timeout.unwrap_or(60_000);
240
241        // Channel to signal idle or error
242        let (idle_tx, mut idle_rx) = mpsc::channel::<Result<(), CopilotError>>(1);
243        let last_assistant_message: Arc<Mutex<Option<SessionEvent>>> =
244            Arc::new(Mutex::new(None));
245
246        let last_msg_clone = Arc::clone(&last_assistant_message);
247        let idle_tx_clone = idle_tx.clone();
248
249        // Register event handler BEFORE calling send to avoid race condition
250        let sub = self
251            .on(move |event: SessionEvent| {
252                if event.is_assistant_message() {
253                    let mut msg = last_msg_clone.blocking_lock();
254                    *msg = Some(event);
255                } else if event.is_session_idle() {
256                    let _ = idle_tx_clone.try_send(Ok(()));
257                } else if event.is_session_error() {
258                    let error_msg = event
259                        .error_message()
260                        .unwrap_or("Unknown error")
261                        .to_string();
262                    let _ = idle_tx_clone.try_send(Err(CopilotError::SessionError(error_msg)));
263                }
264            })
265            .await;
266
267        // Send the message
268        self.send(options).await?;
269
270        // Wait for idle or timeout
271        let result = tokio::time::timeout(
272            std::time::Duration::from_millis(effective_timeout),
273            idle_rx.recv(),
274        )
275        .await;
276
277        // Unsubscribe
278        sub.unsubscribe();
279
280        match result {
281            Ok(Some(Ok(()))) => {
282                let msg = last_assistant_message.lock().await;
283                Ok(msg.clone())
284            }
285            Ok(Some(Err(e))) => Err(e),
286            Ok(None) => Err(CopilotError::ConnectionClosed),
287            Err(_) => Err(CopilotError::Timeout(effective_timeout)),
288        }
289    }
290
291    // ========================================================================
292    // Event Subscription
293    // ========================================================================
294
295    /// Subscribes to all events from this session.
296    ///
297    /// Returns a `Subscription` that unsubscribes when dropped or when
298    /// `unsubscribe()` is called.
299    pub async fn on<F>(&self, handler: F) -> Subscription
300    where
301        F: Fn(SessionEvent) + Send + Sync + 'static,
302    {
303        let handler_id = {
304            let mut id = self.next_handler_id.lock().await;
305            *id += 1;
306            *id
307        };
308
309        let handler_arc: SessionEventHandlerFn = Arc::new(handler);
310        {
311            let mut handlers = self.event_handlers.lock().await;
312            handlers.push((handler_id, handler_arc));
313        }
314
315        let event_handlers = Arc::clone(&self.event_handlers);
316        Subscription::new(move || {
317            // We need to use blocking_lock since unsubscribe may be called from Drop
318            // in a non-async context
319            let mut handlers = event_handlers.blocking_lock();
320            handlers.retain(|(id, _)| *id != handler_id);
321        })
322    }
323
324    /// Subscribes to a specific event type from this session.
325    ///
326    /// # Arguments
327    /// * `event_type` - The event type string (e.g., "assistant.message", "session.idle")
328    /// * `handler` - The callback function
329    pub async fn on_event<F>(&self, event_type: &str, handler: F) -> Subscription
330    where
331        F: Fn(SessionEvent) + Send + Sync + 'static,
332    {
333        let handler_id = {
334            let mut id = self.next_handler_id.lock().await;
335            *id += 1;
336            *id
337        };
338
339        let handler_arc: TypedSessionEventHandlerFn = Arc::new(handler);
340        let event_type_str = event_type.to_string();
341        {
342            let mut handlers = self.typed_event_handlers.lock().await;
343            handlers
344                .entry(event_type_str.clone())
345                .or_default()
346                .push((handler_id, handler_arc));
347        }
348
349        let typed_handlers = Arc::clone(&self.typed_event_handlers);
350        let et = event_type_str;
351        Subscription::new(move || {
352            let mut handlers = typed_handlers.blocking_lock();
353            if let Some(list) = handlers.get_mut(&et) {
354                list.retain(|(id, _)| *id != handler_id);
355            }
356        })
357    }
358
359    // ========================================================================
360    // Event Dispatch (internal, called by CopilotClient)
361    // ========================================================================
362
363    /// Dispatches a session event to all registered handlers.
364    pub(crate) async fn dispatch_event(&self, event: SessionEvent) {
365        // Dispatch to typed handlers
366        {
367            let handlers = self.typed_event_handlers.lock().await;
368            if let Some(list) = handlers.get(&event.event_type) {
369                for (_, handler) in list {
370                    handler(event.clone());
371                }
372            }
373        }
374
375        // Dispatch to wildcard handlers
376        {
377            let handlers = self.event_handlers.lock().await;
378            for (_, handler) in handlers.iter() {
379                handler(event.clone());
380            }
381        }
382    }
383
384    // ========================================================================
385    // Tool Registration
386    // ========================================================================
387
388    /// Registers a tool handler.
389    pub async fn register_tool(&self, name: &str, handler: ToolHandler) {
390        let mut handlers = self.tool_handlers.lock().await;
391        handlers.insert(name.to_string(), handler);
392    }
393
394    /// Registers multiple tool handlers.
395    pub async fn register_tools(&self, tools: Vec<(String, ToolHandler)>) {
396        let mut handlers = self.tool_handlers.lock().await;
397        handlers.clear();
398        for (name, handler) in tools {
399            handlers.insert(name, handler);
400        }
401    }
402
403    /// Gets a registered tool handler by name.
404    pub(crate) async fn get_tool_handler(&self, name: &str) -> Option<ToolHandler> {
405        let handlers = self.tool_handlers.lock().await;
406        handlers.get(name).cloned()
407    }
408
409    // ========================================================================
410    // Permission Handler
411    // ========================================================================
412
413    /// Registers a permission request handler.
414    pub async fn register_permission_handler(&self, handler: PermissionHandlerFn) {
415        let mut h = self.permission_handler.lock().await;
416        *h = Some(handler);
417    }
418
419    /// Handles an incoming permission request from the server.
420    pub(crate) async fn handle_permission_request(
421        &self,
422        request: Value,
423    ) -> Result<PermissionRequestResult, CopilotError> {
424        let handler = self.permission_handler.lock().await;
425        if let Some(ref h) = *handler {
426            let perm_request: PermissionRequest = serde_json::from_value(request)
427                .map_err(|e| CopilotError::Serialization(e.to_string()))?;
428            h(perm_request, self.session_id.clone()).await
429        } else {
430            Ok(PermissionRequestResult {
431                kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
432                rules: None,
433            })
434        }
435    }
436
437    // ========================================================================
438    // User Input Handler
439    // ========================================================================
440
441    /// Registers a user input request handler.
442    pub async fn register_user_input_handler(&self, handler: UserInputHandlerFn) {
443        let mut h = self.user_input_handler.lock().await;
444        *h = Some(handler);
445    }
446
447    /// Handles an incoming user input request from the server.
448    pub(crate) async fn handle_user_input_request(
449        &self,
450        request: Value,
451    ) -> Result<UserInputResponse, CopilotError> {
452        let handler = self.user_input_handler.lock().await;
453        if let Some(ref h) = *handler {
454            let input_request: UserInputRequest = serde_json::from_value(request)
455                .map_err(|e| CopilotError::Serialization(e.to_string()))?;
456            h(input_request, self.session_id.clone()).await
457        } else {
458            Err(CopilotError::NoHandler(
459                "User input requested but no handler registered".to_string(),
460            ))
461        }
462    }
463
464    // ========================================================================
465    // Hooks Handler
466    // ========================================================================
467
468    /// Registers a hooks handler for all hook types.
469    pub async fn register_hooks_handler(&self, handler: HooksHandlerFn) {
470        let mut h = self.hooks_handler.lock().await;
471        *h = Some(handler);
472    }
473
474    /// Handles an incoming hooks invocation from the server.
475    pub(crate) async fn handle_hooks_invoke(
476        &self,
477        hook_type: &str,
478        input: Value,
479    ) -> Result<Option<Value>, CopilotError> {
480        let handler = self.hooks_handler.lock().await;
481        if let Some(ref h) = *handler {
482            h(hook_type.to_string(), input, self.session_id.clone()).await
483        } else {
484            Ok(None)
485        }
486    }
487
488    // ========================================================================
489    // Session Operations
490    // ========================================================================
491
492    /// Retrieves all events and messages from this session's history.
493    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, CopilotError> {
494        let params = serde_json::json!({ "sessionId": self.session_id });
495        let response = self
496            .rpc_client
497            .request("session.getMessages", params, None)
498            .await?;
499        let events: Vec<SessionEvent> = serde_json::from_value(
500            response
501                .get("events")
502                .cloned()
503                .unwrap_or(Value::Array(vec![])),
504        )
505        .map_err(|e| CopilotError::Serialization(e.to_string()))?;
506        Ok(events)
507    }
508
509    /// Destroys this session and releases all associated resources.
510    ///
511    /// After calling this method, the session can no longer be used.
512    pub async fn destroy(&self) -> Result<(), CopilotError> {
513        let params = serde_json::json!({ "sessionId": self.session_id });
514        self.rpc_client
515            .request("session.destroy", params, None)
516            .await?;
517
518        // Clear all handlers
519        {
520            let mut handlers = self.event_handlers.lock().await;
521            handlers.clear();
522        }
523        {
524            let mut handlers = self.typed_event_handlers.lock().await;
525            handlers.clear();
526        }
527        {
528            let mut handlers = self.tool_handlers.lock().await;
529            handlers.clear();
530        }
531        {
532            let mut handler = self.permission_handler.lock().await;
533            *handler = None;
534        }
535        {
536            let mut handler = self.user_input_handler.lock().await;
537            *handler = None;
538        }
539
540        Ok(())
541    }
542
543    /// Aborts the currently processing message in this session.
544    pub async fn abort(&self) -> Result<(), CopilotError> {
545        let params = serde_json::json!({ "sessionId": self.session_id });
546        self.rpc_client
547            .request("session.abort", params, None)
548            .await?;
549        Ok(())
550    }
551}