Skip to main content

copilot_sdk_supercharged/
session.rs

// Copyright (c) Microsoft Corporation. All rights reserved.

//! CopilotSession - represents a single conversation session with the Copilot CLI.
//!
//! Sessions are created via [`CopilotClient::create_session`] or resumed via
//! [`CopilotClient::resume_session`]. They maintain conversation state, handle events,
//! and manage tool execution.
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use serde_json::Value;
15use tokio::sync::{mpsc, Mutex};
16
17use crate::jsonrpc::JsonRpcClient;
18use crate::types::*;
19use crate::CopilotError;
20
// ============================================================================
// Handler Type Aliases
// ============================================================================
24
25/// A tool handler is an async function that takes arguments and invocation info,
26/// and returns a result value.
27pub type ToolHandler = Arc<
28    dyn Fn(
29            Value,
30            ToolInvocation,
31        ) -> Pin<Box<dyn Future<Output = Result<Value, CopilotError>> + Send>>
32        + Send
33        + Sync,
34>;
35
36/// A permission handler is an async function that takes a permission request
37/// and returns a permission result.
38pub type PermissionHandlerFn = Arc<
39    dyn Fn(
40            PermissionRequest,
41            String,
42        ) -> Pin<Box<dyn Future<Output = Result<PermissionRequestResult, CopilotError>> + Send>>
43        + Send
44        + Sync,
45>;
46
47/// A user input handler is an async function that takes a user input request
48/// and returns a response.
49pub type UserInputHandlerFn = Arc<
50    dyn Fn(
51            UserInputRequest,
52            String,
53        ) -> Pin<Box<dyn Future<Output = Result<UserInputResponse, CopilotError>> + Send>>
54        + Send
55        + Sync,
56>;
57
58/// A hooks handler is an async function that takes hook type, input, and session ID,
59/// and returns optional output.
60pub type HooksHandlerFn = Arc<
61    dyn Fn(
62            String,
63            Value,
64            String,
65        ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, CopilotError>> + Send>>
66        + Send
67        + Sync,
68>;
69
70/// A session event handler callback.
71pub type SessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
72
73/// A typed session event handler for a specific event type.
74pub type TypedSessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;
75
// ============================================================================
// Unsubscribe Guard
// ============================================================================
79
/// An RAII guard that unregisters an event handler when dropped.
///
/// Returned by [`CopilotSession::on`] and [`CopilotSession::on_event`]. The handler
/// stays registered only while this value is alive, so bind it to a named variable
/// (`let _sub = ...`; note that `let _ = ...` drops it immediately) for as long as
/// events should be delivered. [`Subscription::unsubscribe`] unregisters explicitly.
#[must_use = "dropping a Subscription immediately unregisters its handler"]
pub struct Subscription {
    /// Removal callback; `None` once it has run, so it runs at most once.
    unsubscribe_fn: Option<Box<dyn FnOnce() + Send>>,
}

impl Subscription {
    /// Wraps the callback that removes the handler from its registry.
    fn new(f: impl FnOnce() + Send + 'static) -> Self {
        Self {
            unsubscribe_fn: Some(Box::new(f)),
        }
    }

    /// Explicitly unsubscribes this handler. Equivalent to dropping the guard.
    pub fn unsubscribe(mut self) {
        // `Drop` runs afterwards but finds the callback already taken.
        self.run_unsubscribe();
    }

    /// Runs the removal callback if it has not run yet.
    fn run_unsubscribe(&mut self) {
        if let Some(f) = self.unsubscribe_fn.take() {
            f();
        }
    }
}

impl Drop for Subscription {
    fn drop(&mut self) {
        self.run_unsubscribe();
    }
}

impl std::fmt::Debug for Subscription {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Subscription")
            .field("active", &self.unsubscribe_fn.is_some())
            .finish()
    }
}
108
// ============================================================================
// CopilotSession
// ============================================================================
112
113/// Represents a single conversation session with the Copilot CLI.
114///
115/// A session maintains conversation state, handles events, and manages tool execution.
116///
117/// # Examples
118///
119/// ```rust,no_run
120/// # use copilot_sdk::*;
121/// # async fn example() -> Result<(), CopilotError> {
122/// let client = CopilotClient::new(CopilotClientOptions::default());
123/// let session = client.create_session(SessionConfig::default()).await?;
124///
125/// // Subscribe to events
126/// let sub = session.on(|event| {
127///     if event.is_assistant_message() {
128///         if let Some(content) = event.assistant_message_content() {
129///             println!("Assistant: {}", content);
130///         }
131///     }
132/// }).await;
133///
134/// // Send a message and wait for completion
135/// let response = session.send_and_wait(
136///     MessageOptions { prompt: "Hello!".into(), attachments: None, mode: None, response_format: None, image_options: None },
137///     None,
138/// ).await?;
139///
140/// // Clean up
141/// session.destroy().await?;
142/// # Ok(())
143/// # }
144/// ```
145pub struct CopilotSession {
146    /// The unique session ID.
147    session_id: String,
148    /// Path to the session workspace directory (when infinite sessions are enabled).
149    workspace_path: Option<String>,
150    /// Reference to the JSON-RPC client.
151    rpc_client: Arc<JsonRpcClient>,
152    /// Registered tool handlers, keyed by tool name.
153    tool_handlers: Arc<Mutex<HashMap<String, ToolHandler>>>,
154    /// Permission request handler.
155    permission_handler: Arc<Mutex<Option<PermissionHandlerFn>>>,
156    /// User input request handler.
157    user_input_handler: Arc<Mutex<Option<UserInputHandlerFn>>>,
158    /// Hooks handler.
159    hooks_handler: Arc<Mutex<Option<HooksHandlerFn>>>,
160    /// Wildcard event handlers.
161    event_handlers: Arc<Mutex<Vec<(u64, SessionEventHandlerFn)>>>,
162    /// Typed event handlers, keyed by event type string.
163    typed_event_handlers: Arc<Mutex<HashMap<String, Vec<(u64, TypedSessionEventHandlerFn)>>>>,
164    /// Counter for handler IDs (for unsubscribe).
165    next_handler_id: Arc<Mutex<u64>>,
166}
167
168impl CopilotSession {
169    /// Creates a new CopilotSession. This is called internally by CopilotClient.
170    pub(crate) fn new(
171        session_id: String,
172        rpc_client: Arc<JsonRpcClient>,
173        workspace_path: Option<String>,
174    ) -> Self {
175        Self {
176            session_id,
177            workspace_path,
178            rpc_client,
179            tool_handlers: Arc::new(Mutex::new(HashMap::new())),
180            permission_handler: Arc::new(Mutex::new(None)),
181            user_input_handler: Arc::new(Mutex::new(None)),
182            hooks_handler: Arc::new(Mutex::new(None)),
183            event_handlers: Arc::new(Mutex::new(Vec::new())),
184            typed_event_handlers: Arc::new(Mutex::new(HashMap::new())),
185            next_handler_id: Arc::new(Mutex::new(0)),
186        }
187    }
188
189    /// Returns the session ID.
190    pub fn session_id(&self) -> &str {
191        &self.session_id
192    }
193
194    /// Returns the workspace path (when infinite sessions are enabled).
195    pub fn workspace_path(&self) -> Option<&str> {
196        self.workspace_path.as_deref()
197    }
198
199    // ========================================================================
200    // Message Sending
201    // ========================================================================
202
203    /// Sends a message to this session.
204    ///
205    /// The message is processed asynchronously. Subscribe to events via `on()`
206    /// to receive streaming responses and other session events.
207    ///
208    /// Returns the message ID.
209    pub async fn send(&self, options: MessageOptions) -> Result<String, CopilotError> {
210        let params = serde_json::json!({
211            "sessionId": self.session_id,
212            "prompt": options.prompt,
213            "attachments": options.attachments,
214            "mode": options.mode,
215            "responseFormat": options.response_format,
216            "imageOptions": options.image_options,
217        });
218
219        let response = self.rpc_client.request("session.send", params, None).await?;
220        let message_id = response
221            .get("messageId")
222            .and_then(|v| v.as_str())
223            .unwrap_or("")
224            .to_string();
225        Ok(message_id)
226    }
227
228    /// Sends a message and waits until the session becomes idle.
229    ///
230    /// This combines `send()` with waiting for the `session.idle` event.
231    /// Returns the last `assistant.message` event received, or `None`.
232    ///
233    /// # Arguments
234    /// * `options` - The message options
235    /// * `timeout` - Optional timeout in milliseconds (defaults to 60000)
236    pub async fn send_and_wait(
237        &self,
238        options: MessageOptions,
239        timeout: Option<u64>,
240    ) -> Result<Option<SessionEvent>, CopilotError> {
241        let effective_timeout = timeout.unwrap_or(60_000);
242
243        // Channel to signal idle or error
244        let (idle_tx, mut idle_rx) = mpsc::channel::<Result<(), CopilotError>>(1);
245        let last_assistant_message: Arc<Mutex<Option<SessionEvent>>> =
246            Arc::new(Mutex::new(None));
247
248        let last_msg_clone = Arc::clone(&last_assistant_message);
249        let idle_tx_clone = idle_tx.clone();
250
251        // Register event handler BEFORE calling send to avoid race condition
252        let sub = self
253            .on(move |event: SessionEvent| {
254                if event.is_assistant_message() {
255                    let mut msg = last_msg_clone.blocking_lock();
256                    *msg = Some(event);
257                } else if event.is_session_idle() {
258                    let _ = idle_tx_clone.try_send(Ok(()));
259                } else if event.is_session_error() {
260                    let error_msg = event
261                        .error_message()
262                        .unwrap_or("Unknown error")
263                        .to_string();
264                    let _ = idle_tx_clone.try_send(Err(CopilotError::SessionError(error_msg)));
265                }
266            })
267            .await;
268
269        // Send the message
270        self.send(options).await?;
271
272        // Wait for idle or timeout
273        let result = tokio::time::timeout(
274            std::time::Duration::from_millis(effective_timeout),
275            idle_rx.recv(),
276        )
277        .await;
278
279        // Unsubscribe
280        sub.unsubscribe();
281
282        match result {
283            Ok(Some(Ok(()))) => {
284                let msg = last_assistant_message.lock().await;
285                Ok(msg.clone())
286            }
287            Ok(Some(Err(e))) => Err(e),
288            Ok(None) => Err(CopilotError::ConnectionClosed),
289            Err(_) => Err(CopilotError::Timeout(effective_timeout)),
290        }
291    }
292
293    // ========================================================================
294    // Event Subscription
295    // ========================================================================
296
297    /// Subscribes to all events from this session.
298    ///
299    /// Returns a `Subscription` that unsubscribes when dropped or when
300    /// `unsubscribe()` is called.
301    pub async fn on<F>(&self, handler: F) -> Subscription
302    where
303        F: Fn(SessionEvent) + Send + Sync + 'static,
304    {
305        let handler_id = {
306            let mut id = self.next_handler_id.lock().await;
307            *id += 1;
308            *id
309        };
310
311        let handler_arc: SessionEventHandlerFn = Arc::new(handler);
312        {
313            let mut handlers = self.event_handlers.lock().await;
314            handlers.push((handler_id, handler_arc));
315        }
316
317        let event_handlers = Arc::clone(&self.event_handlers);
318        Subscription::new(move || {
319            // We need to use blocking_lock since unsubscribe may be called from Drop
320            // in a non-async context
321            let mut handlers = event_handlers.blocking_lock();
322            handlers.retain(|(id, _)| *id != handler_id);
323        })
324    }
325
326    /// Subscribes to a specific event type from this session.
327    ///
328    /// # Arguments
329    /// * `event_type` - The event type string (e.g., "assistant.message", "session.idle")
330    /// * `handler` - The callback function
331    pub async fn on_event<F>(&self, event_type: &str, handler: F) -> Subscription
332    where
333        F: Fn(SessionEvent) + Send + Sync + 'static,
334    {
335        let handler_id = {
336            let mut id = self.next_handler_id.lock().await;
337            *id += 1;
338            *id
339        };
340
341        let handler_arc: TypedSessionEventHandlerFn = Arc::new(handler);
342        let event_type_str = event_type.to_string();
343        {
344            let mut handlers = self.typed_event_handlers.lock().await;
345            handlers
346                .entry(event_type_str.clone())
347                .or_default()
348                .push((handler_id, handler_arc));
349        }
350
351        let typed_handlers = Arc::clone(&self.typed_event_handlers);
352        let et = event_type_str;
353        Subscription::new(move || {
354            let mut handlers = typed_handlers.blocking_lock();
355            if let Some(list) = handlers.get_mut(&et) {
356                list.retain(|(id, _)| *id != handler_id);
357            }
358        })
359    }
360
361    // ========================================================================
362    // Event Dispatch (internal, called by CopilotClient)
363    // ========================================================================
364
365    /// Dispatches a session event to all registered handlers.
366    pub(crate) async fn dispatch_event(&self, event: SessionEvent) {
367        // Dispatch to typed handlers
368        {
369            let handlers = self.typed_event_handlers.lock().await;
370            if let Some(list) = handlers.get(&event.event_type) {
371                for (_, handler) in list {
372                    handler(event.clone());
373                }
374            }
375        }
376
377        // Dispatch to wildcard handlers
378        {
379            let handlers = self.event_handlers.lock().await;
380            for (_, handler) in handlers.iter() {
381                handler(event.clone());
382            }
383        }
384    }
385
386    // ========================================================================
387    // Tool Registration
388    // ========================================================================
389
390    /// Registers a tool handler.
391    pub async fn register_tool(&self, name: &str, handler: ToolHandler) {
392        let mut handlers = self.tool_handlers.lock().await;
393        handlers.insert(name.to_string(), handler);
394    }
395
396    /// Registers multiple tool handlers.
397    pub async fn register_tools(&self, tools: Vec<(String, ToolHandler)>) {
398        let mut handlers = self.tool_handlers.lock().await;
399        handlers.clear();
400        for (name, handler) in tools {
401            handlers.insert(name, handler);
402        }
403    }
404
405    /// Gets a registered tool handler by name.
406    pub(crate) async fn get_tool_handler(&self, name: &str) -> Option<ToolHandler> {
407        let handlers = self.tool_handlers.lock().await;
408        handlers.get(name).cloned()
409    }
410
411    // ========================================================================
412    // Permission Handler
413    // ========================================================================
414
415    /// Registers a permission request handler.
416    pub async fn register_permission_handler(&self, handler: PermissionHandlerFn) {
417        let mut h = self.permission_handler.lock().await;
418        *h = Some(handler);
419    }
420
421    /// Handles an incoming permission request from the server.
422    pub(crate) async fn handle_permission_request(
423        &self,
424        request: Value,
425    ) -> Result<PermissionRequestResult, CopilotError> {
426        let handler = self.permission_handler.lock().await;
427        if let Some(ref h) = *handler {
428            let perm_request: PermissionRequest = serde_json::from_value(request)
429                .map_err(|e| CopilotError::Serialization(e.to_string()))?;
430            h(perm_request, self.session_id.clone()).await
431        } else {
432            Ok(PermissionRequestResult {
433                kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
434                rules: None,
435            })
436        }
437    }
438
439    // ========================================================================
440    // User Input Handler
441    // ========================================================================
442
443    /// Registers a user input request handler.
444    pub async fn register_user_input_handler(&self, handler: UserInputHandlerFn) {
445        let mut h = self.user_input_handler.lock().await;
446        *h = Some(handler);
447    }
448
449    /// Handles an incoming user input request from the server.
450    pub(crate) async fn handle_user_input_request(
451        &self,
452        request: Value,
453    ) -> Result<UserInputResponse, CopilotError> {
454        let handler = self.user_input_handler.lock().await;
455        if let Some(ref h) = *handler {
456            let input_request: UserInputRequest = serde_json::from_value(request)
457                .map_err(|e| CopilotError::Serialization(e.to_string()))?;
458            h(input_request, self.session_id.clone()).await
459        } else {
460            Err(CopilotError::NoHandler(
461                "User input requested but no handler registered".to_string(),
462            ))
463        }
464    }
465
466    // ========================================================================
467    // Hooks Handler
468    // ========================================================================
469
470    /// Registers a hooks handler for all hook types.
471    pub async fn register_hooks_handler(&self, handler: HooksHandlerFn) {
472        let mut h = self.hooks_handler.lock().await;
473        *h = Some(handler);
474    }
475
476    /// Handles an incoming hooks invocation from the server.
477    pub(crate) async fn handle_hooks_invoke(
478        &self,
479        hook_type: &str,
480        input: Value,
481    ) -> Result<Option<Value>, CopilotError> {
482        let handler = self.hooks_handler.lock().await;
483        if let Some(ref h) = *handler {
484            h(hook_type.to_string(), input, self.session_id.clone()).await
485        } else {
486            Ok(None)
487        }
488    }
489
490    // ========================================================================
491    // Session Operations
492    // ========================================================================
493
494    /// Retrieves all events and messages from this session's history.
495    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, CopilotError> {
496        let params = serde_json::json!({ "sessionId": self.session_id });
497        let response = self
498            .rpc_client
499            .request("session.getMessages", params, None)
500            .await?;
501        let events: Vec<SessionEvent> = serde_json::from_value(
502            response
503                .get("events")
504                .cloned()
505                .unwrap_or(Value::Array(vec![])),
506        )
507        .map_err(|e| CopilotError::Serialization(e.to_string()))?;
508        Ok(events)
509    }
510
511    /// Retrieves metadata for this session.
512    pub async fn get_metadata(&self) -> Result<Value, CopilotError> {
513        let params = serde_json::json!({ "sessionId": self.session_id });
514        let response = self
515            .rpc_client
516            .request("session.getMetadata", params, None)
517            .await?;
518        Ok(response)
519    }
520
521    /// Destroys this session and releases all associated resources.
522    ///
523    /// After calling this method, the session can no longer be used.
524    pub async fn destroy(&self) -> Result<(), CopilotError> {
525        let params = serde_json::json!({ "sessionId": self.session_id });
526        self.rpc_client
527            .request("session.destroy", params, None)
528            .await?;
529
530        // Clear all handlers
531        {
532            let mut handlers = self.event_handlers.lock().await;
533            handlers.clear();
534        }
535        {
536            let mut handlers = self.typed_event_handlers.lock().await;
537            handlers.clear();
538        }
539        {
540            let mut handlers = self.tool_handlers.lock().await;
541            handlers.clear();
542        }
543        {
544            let mut handler = self.permission_handler.lock().await;
545            *handler = None;
546        }
547        {
548            let mut handler = self.user_input_handler.lock().await;
549            *handler = None;
550        }
551
552        Ok(())
553    }
554
555    /// Aborts the currently processing message in this session.
556    pub async fn abort(&self) -> Result<(), CopilotError> {
557        let params = serde_json::json!({ "sessionId": self.session_id });
558        self.rpc_client
559            .request("session.abort", params, None)
560            .await?;
561        Ok(())
562    }
563}