Skip to main content

copilot_sdk/
session.rs

1// Copyright (c) 2026 Elias Bachaalany
2// SPDX-License-Identifier: MIT
3
4//! Session management for the Copilot SDK.
5//!
6//! A session represents a conversation with the Copilot CLI.
7
8use crate::error::{CopilotError, Result};
9use crate::events::{SessionEvent, SessionEventData};
10use crate::types::{
11    ErrorOccurredHookInput, MessageOptions, PermissionRequest, PermissionRequestResult,
12    PostToolUseHookInput, PreToolUseHookInput, SessionEndHookInput, SessionHooks,
13    SessionStartHookInput, Tool, ToolResultObject, UserInputInvocation, UserInputRequest,
14    UserInputResponse, UserPromptSubmittedHookInput,
15};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{broadcast, RwLock};
22
23// =============================================================================
24// Event Handler Types
25// =============================================================================
26
/// Callback invoked with every [`SessionEvent`] dispatched to a session.
///
/// Shared behind an [`Arc`] so the same closure can be stored in the session
/// state and invoked from any thread (`Send + Sync`).
pub type EventHandler = Arc<dyn Fn(&SessionEvent) + Send + Sync>;

/// Callback that decides the outcome of a [`PermissionRequest`].
pub type PermissionHandler =
    Arc<dyn Fn(&PermissionRequest) -> PermissionRequestResult + Send + Sync>;

/// Callback that executes a tool.
///
/// Receives the tool name and the JSON arguments of the invocation and
/// returns the tool result synchronously.
pub type ToolHandler = Arc<dyn Fn(&str, &Value) -> ToolResultObject + Send + Sync>;

/// Callback that answers a [`UserInputRequest`]; the [`UserInputInvocation`]
/// identifies the session the request belongs to.
pub type UserInputHandler =
    Arc<dyn Fn(&UserInputRequest, &UserInputInvocation) -> UserInputResponse + Send + Sync>;

/// Boxed, pinned, `Send` future returned by the JSON-RPC invoke function.
pub type InvokeFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send>>;

// Type-erased JSON-RPC invoke function: `(method, optional params) -> future of the response`.
type InvokeFn = dyn Fn(&str, Option<Value>) -> InvokeFuture + Send + Sync;
45
46// =============================================================================
47// Event Subscription
48// =============================================================================
49
/// A subscription to session events.
///
/// Wraps the receiving half of the session's broadcast channel. The receiver
/// is public so callers can use the underlying `broadcast::Receiver` API
/// directly instead of [`EventSubscription::recv`].
pub struct EventSubscription {
    /// Receiving half of the session's event broadcast channel.
    pub receiver: broadcast::Receiver<SessionEvent>,
}
56
57impl EventSubscription {
58    /// Receive the next event.
59    pub async fn recv(&mut self) -> std::result::Result<SessionEvent, broadcast::error::RecvError> {
60        self.receiver.recv().await
61    }
62}
63
64// =============================================================================
65// Registered Tool
66// =============================================================================
67
/// A tool registered with the session, paired with its optional handler.
///
/// Cloning is cheap for the handler part: it is an `Arc` and only the
/// reference count is bumped.
#[derive(Clone)]
pub struct RegisteredTool {
    /// Tool definition.
    pub tool: Tool,
    /// Handler for tool invocations; `None` when the tool was registered
    /// without one.
    pub handler: Option<ToolHandler>,
}
76
77// =============================================================================
78// Session
79// =============================================================================
80
/// Mutable session state shared behind `Arc<RwLock<..>>`.
struct SessionState {
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, RegisteredTool>,
    /// Permission handler, if one has been registered.
    permission_handler: Option<PermissionHandler>,
    /// User input handler, if one has been registered.
    user_input_handler: Option<UserInputHandler>,
    /// Session hooks, if any have been registered.
    hooks: Option<SessionHooks>,
    /// Callback-based event handlers, keyed by handler ID.
    event_handlers: HashMap<u64, EventHandler>,
    /// ID handed to the next registered event handler (starts at 1).
    next_handler_id: AtomicU64,
}
96
/// A Copilot conversation session.
///
/// A session owns the event broadcast channel, the registered tools and
/// handlers, and the JSON-RPC invoke function used to talk to the server.
///
/// # Example
///
/// ```no_run
/// use copilot_sdk::{Client, SessionConfig, SessionEventData};
///
/// #[tokio::main]
/// async fn main() -> copilot_sdk::Result<()> {
/// let client = Client::builder().build()?;
/// client.start().await?;
/// let session = client.create_session(SessionConfig::default()).await?;
///
/// // Subscribe to events
/// let mut events = session.subscribe();
///
/// // Send a message
/// session.send("Hello!").await?;
///
/// // Process events
/// while let Ok(event) = events.recv().await {
///     match &event.data {
///         SessionEventData::AssistantMessage(msg) => println!("{}", msg.content),
///         SessionEventData::SessionIdle(_) => break,
///         _ => {}
///     }
/// }
/// client.stop().await;
/// # Ok(())
/// # }
/// ```
pub struct Session {
    /// Session ID, sent as `sessionId` with every request.
    session_id: String,
    /// Workspace path for infinite sessions (`None` when not provided).
    workspace_path: Option<String>,
    /// Sending half of the event broadcast channel; receivers are created
    /// on demand by `subscribe`.
    event_tx: broadcast::Sender<SessionEvent>,
    /// Tools, handlers and hooks, shared behind an async read/write lock.
    state: Arc<RwLock<SessionState>>,
    /// JSON-RPC invoke function (injected by Client).
    invoke_fn: Arc<InvokeFn>,
}
142
143impl Session {
144    /// Create a new session.
145    ///
146    /// This is typically called by the Client when creating a session.
147    pub fn new<F>(session_id: String, workspace_path: Option<String>, invoke_fn: F) -> Self
148    where
149        F: Fn(&str, Option<Value>) -> InvokeFuture + Send + Sync + 'static,
150    {
151        let (event_tx, _) = broadcast::channel(1024);
152
153        Self {
154            session_id,
155            workspace_path,
156            event_tx,
157            state: Arc::new(RwLock::new(SessionState {
158                tools: HashMap::new(),
159                permission_handler: None,
160                user_input_handler: None,
161                hooks: None,
162                event_handlers: HashMap::new(),
163                next_handler_id: AtomicU64::new(1),
164            })),
165            invoke_fn: Arc::new(invoke_fn),
166        }
167    }
168
169    // =========================================================================
170    // Session Properties
171    // =========================================================================
172
173    /// Get the session ID.
174    pub fn session_id(&self) -> &str {
175        &self.session_id
176    }
177
178    /// Get the workspace path for infinite sessions.
179    ///
180    /// Contains checkpoints/, plan.md, and files/ subdirectories.
181    /// Returns None if infinite sessions are disabled.
182    pub fn workspace_path(&self) -> Option<&str> {
183        self.workspace_path.as_deref()
184    }
185
186    // =========================================================================
187    // Event Handling
188    // =========================================================================
189
190    /// Subscribe to session events.
191    ///
192    /// Returns a receiver that will receive all session events.
193    pub fn subscribe(&self) -> EventSubscription {
194        EventSubscription {
195            receiver: self.event_tx.subscribe(),
196        }
197    }
198
199    /// Register a callback-based event handler.
200    ///
201    /// Returns an unsubscribe closure. Call it to remove the handler.
202    /// Alternatively, use [`off`] with the internal handler ID.
203    pub async fn on<F>(&self, handler: F) -> impl FnOnce()
204    where
205        F: Fn(&SessionEvent) + Send + Sync + 'static,
206    {
207        let mut state = self.state.write().await;
208        let id = state.next_handler_id.fetch_add(1, Ordering::SeqCst);
209        state.event_handlers.insert(id, Arc::new(handler));
210
211        let state_ref = Arc::clone(&self.state);
212        move || {
213            tokio::spawn(async move {
214                state_ref.write().await.event_handlers.remove(&id);
215            });
216        }
217    }
218
219    /// Unsubscribe a callback-based event handler.
220    pub async fn off(&self, handler_id: u64) {
221        let mut state = self.state.write().await;
222        state.event_handlers.remove(&handler_id);
223    }
224
225    /// Dispatch an event to all subscribers.
226    ///
227    /// This is called by the Client when events are received.
228    pub async fn dispatch_event(&self, event: SessionEvent) {
229        // Send to broadcast channel
230        let _ = self.event_tx.send(event.clone());
231
232        // Call registered handlers
233        let state = self.state.read().await;
234        for handler in state.event_handlers.values() {
235            handler(&event);
236        }
237    }
238
239    // =========================================================================
240    // Messaging
241    // =========================================================================
242
243    /// Send a message to the session.
244    ///
245    /// Returns the message ID.
246    pub async fn send(&self, options: impl Into<MessageOptions>) -> Result<String> {
247        let options = options.into();
248        let params = serde_json::json!({
249            "sessionId": self.session_id,
250            "prompt": options.prompt,
251            "attachments": options.attachments,
252            "mode": options.mode,
253        });
254
255        let result = (self.invoke_fn)("session.send", Some(params)).await?;
256
257        result
258            .get("messageId")
259            .and_then(|v| v.as_str())
260            .map(|s| s.to_string())
261            .ok_or_else(|| CopilotError::Protocol("Missing messageId in response".into()))
262    }
263
264    /// Abort the current message processing.
265    pub async fn abort(&self) -> Result<()> {
266        let params = serde_json::json!({
267            "sessionId": self.session_id,
268        });
269
270        (self.invoke_fn)("session.abort", Some(params)).await?;
271        Ok(())
272    }
273
274    /// Get all messages in the session.
275    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>> {
276        let params = serde_json::json!({
277            "sessionId": self.session_id,
278        });
279
280        let result = (self.invoke_fn)("session.getMessages", Some(params)).await?;
281
282        let events: Vec<SessionEvent> = result
283            .get("events")
284            .and_then(|v| v.as_array())
285            .map(|arr| {
286                arr.iter()
287                    .filter_map(|v| SessionEvent::from_json(v).ok())
288                    .collect()
289            })
290            .or_else(|| {
291                result
292                    .get("messages")
293                    .and_then(|v| v.as_array())
294                    .map(|arr| {
295                        arr.iter()
296                            .filter_map(|v| SessionEvent::from_json(v).ok())
297                            .collect()
298                    })
299            })
300            .ok_or_else(|| {
301                CopilotError::Protocol("Missing events in getMessages response".into())
302            })?;
303
304        Ok(events)
305    }
306
307    // =========================================================================
308    // Tool Management
309    // =========================================================================
310
311    /// Register a tool with this session.
312    pub async fn register_tool(&self, tool: Tool) {
313        self.register_tool_with_handler(tool, None).await;
314    }
315
316    /// Register a tool with a handler.
317    pub async fn register_tool_with_handler(&self, tool: Tool, handler: Option<ToolHandler>) {
318        let mut state = self.state.write().await;
319        let name = tool.name.clone();
320        state.tools.insert(name, RegisteredTool { tool, handler });
321    }
322
323    /// Register multiple tools.
324    pub async fn register_tools(&self, tools: Vec<Tool>) {
325        let mut state = self.state.write().await;
326        for tool in tools {
327            let name = tool.name.clone();
328            state.tools.insert(
329                name,
330                RegisteredTool {
331                    tool,
332                    handler: None,
333                },
334            );
335        }
336    }
337
338    /// Get a registered tool by name.
339    pub async fn get_tool(&self, name: &str) -> Option<Tool> {
340        let state = self.state.read().await;
341        state.tools.get(name).map(|rt| rt.tool.clone())
342    }
343
344    /// Get all registered tools.
345    pub async fn get_tools(&self) -> Vec<Tool> {
346        let state = self.state.read().await;
347        state.tools.values().map(|rt| rt.tool.clone()).collect()
348    }
349
350    /// Invoke a tool handler.
351    pub async fn invoke_tool(&self, name: &str, arguments: &Value) -> Result<ToolResultObject> {
352        let state = self.state.read().await;
353        let registered = state
354            .tools
355            .get(name)
356            .ok_or_else(|| CopilotError::ToolNotFound(name.to_string()))?;
357
358        let handler = registered
359            .handler
360            .as_ref()
361            .ok_or_else(|| CopilotError::ToolError(format!("No handler for tool: {}", name)))?;
362
363        Ok(handler(name, arguments))
364    }
365
366    // =========================================================================
367    // Permission Handling
368    // =========================================================================
369
370    /// Register a permission handler.
371    pub async fn register_permission_handler<F>(&self, handler: F)
372    where
373        F: Fn(&PermissionRequest) -> PermissionRequestResult + Send + Sync + 'static,
374    {
375        let mut state = self.state.write().await;
376        state.permission_handler = Some(Arc::new(handler));
377    }
378
379    /// Handle a permission request.
380    ///
381    /// Delegates to the registered permission handler, or denies by default
382    /// if no handler is set.
383    pub async fn handle_permission_request(
384        &self,
385        request: &PermissionRequest,
386    ) -> PermissionRequestResult {
387        let state = self.state.read().await;
388
389        if let Some(handler) = &state.permission_handler {
390            handler(request)
391        } else {
392            // Default: deny all permissions
393            PermissionRequestResult::denied()
394        }
395    }
396
397    // =========================================================================
398    // User Input Handling
399    // =========================================================================
400
401    /// Register a handler for user input requests from the server.
402    pub async fn register_user_input_handler<F>(&self, handler: F)
403    where
404        F: Fn(&UserInputRequest, &UserInputInvocation) -> UserInputResponse + Send + Sync + 'static,
405    {
406        let mut state = self.state.write().await;
407        state.user_input_handler = Some(Arc::new(handler));
408    }
409
410    /// Handle a user input request from the server.
411    pub async fn handle_user_input_request(
412        &self,
413        request: &UserInputRequest,
414    ) -> Result<UserInputResponse> {
415        let state = self.state.read().await;
416        if let Some(handler) = &state.user_input_handler {
417            let invocation = UserInputInvocation {
418                session_id: self.session_id.clone(),
419            };
420            Ok(handler(request, &invocation))
421        } else {
422            Err(CopilotError::Protocol(
423                "No user input handler registered".into(),
424            ))
425        }
426    }
427
428    /// Check if a user input handler is registered.
429    pub async fn has_user_input_handler(&self) -> bool {
430        let state = self.state.read().await;
431        state.user_input_handler.is_some()
432    }
433
434    // =========================================================================
435    // Hooks
436    // =========================================================================
437
438    /// Register session hooks.
439    pub async fn register_hooks(&self, hooks: SessionHooks) {
440        let mut state = self.state.write().await;
441        state.hooks = Some(hooks);
442    }
443
444    /// Check if any hooks are registered.
445    pub async fn has_hooks(&self) -> bool {
446        let state = self.state.read().await;
447        state.hooks.as_ref().is_some_and(|h| h.has_any())
448    }
449
450    /// Handle a `hooks.invoke` callback from the server.
451    ///
452    /// Dispatches to the appropriate hook handler based on `hook_type` and returns
453    /// the serialized output JSON.
454    pub async fn handle_hooks_invoke(&self, hook_type: &str, input: &Value) -> Result<Value> {
455        let state = self.state.read().await;
456        let hooks = match &state.hooks {
457            Some(h) => h,
458            None => return Ok(Value::Null),
459        };
460
461        match hook_type {
462            "preToolUse" => {
463                if let Some(handler) = &hooks.on_pre_tool_use {
464                    let hook_input: PreToolUseHookInput = serde_json::from_value(input.clone())
465                        .map_err(|e| {
466                            CopilotError::Protocol(format!("Invalid preToolUse input: {}", e))
467                        })?;
468                    let output = handler(hook_input);
469                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
470                } else {
471                    Ok(Value::Null)
472                }
473            }
474            "postToolUse" => {
475                if let Some(handler) = &hooks.on_post_tool_use {
476                    let hook_input: PostToolUseHookInput = serde_json::from_value(input.clone())
477                        .map_err(|e| {
478                            CopilotError::Protocol(format!("Invalid postToolUse input: {}", e))
479                        })?;
480                    let output = handler(hook_input);
481                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
482                } else {
483                    Ok(Value::Null)
484                }
485            }
486            "userPromptSubmitted" => {
487                if let Some(handler) = &hooks.on_user_prompt_submitted {
488                    let hook_input: UserPromptSubmittedHookInput =
489                        serde_json::from_value(input.clone()).map_err(|e| {
490                            CopilotError::Protocol(format!(
491                                "Invalid userPromptSubmitted input: {}",
492                                e
493                            ))
494                        })?;
495                    let output = handler(hook_input);
496                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
497                } else {
498                    Ok(Value::Null)
499                }
500            }
501            "sessionStart" => {
502                if let Some(handler) = &hooks.on_session_start {
503                    let hook_input: SessionStartHookInput = serde_json::from_value(input.clone())
504                        .map_err(|e| {
505                        CopilotError::Protocol(format!("Invalid sessionStart input: {}", e))
506                    })?;
507                    let output = handler(hook_input);
508                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
509                } else {
510                    Ok(Value::Null)
511                }
512            }
513            "sessionEnd" => {
514                if let Some(handler) = &hooks.on_session_end {
515                    let hook_input: SessionEndHookInput = serde_json::from_value(input.clone())
516                        .map_err(|e| {
517                            CopilotError::Protocol(format!("Invalid sessionEnd input: {}", e))
518                        })?;
519                    let output = handler(hook_input);
520                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
521                } else {
522                    Ok(Value::Null)
523                }
524            }
525            "errorOccurred" => {
526                if let Some(handler) = &hooks.on_error_occurred {
527                    let hook_input: ErrorOccurredHookInput = serde_json::from_value(input.clone())
528                        .map_err(|e| {
529                            CopilotError::Protocol(format!("Invalid errorOccurred input: {}", e))
530                        })?;
531                    let output = handler(hook_input);
532                    Ok(serde_json::to_value(output).unwrap_or(Value::Null))
533                } else {
534                    Ok(Value::Null)
535                }
536            }
537            _ => Ok(Value::Null),
538        }
539    }
540
541    // =========================================================================
542    // Lifecycle
543    // =========================================================================
544
545    /// Destroy the session.
546    pub async fn destroy(&self) -> Result<()> {
547        let params = serde_json::json!({
548            "sessionId": self.session_id,
549        });
550
551        (self.invoke_fn)("session.destroy", Some(params)).await?;
552        Ok(())
553    }
554}
555
556// =============================================================================
557// Convenience methods for waiting on events
558// =============================================================================
559
560impl Session {
561    /// Default timeout for waiting on session events (60 seconds).
562    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
563
564    /// Wait for the session to become idle.
565    ///
566    /// Returns the last assistant message event, or None if no message was received.
567    /// Uses the specified timeout, or 60 seconds if None.
568    pub async fn wait_for_idle(&self, timeout: Option<Duration>) -> Result<Option<SessionEvent>> {
569        let timeout = timeout.unwrap_or(Self::DEFAULT_TIMEOUT);
570        let mut subscription = self.subscribe();
571        let mut last_assistant_message: Option<SessionEvent> = None;
572
573        let result = tokio::time::timeout(timeout, async {
574            loop {
575                match subscription.recv().await {
576                    Ok(event) => match &event.data {
577                        SessionEventData::AssistantMessage(_) => {
578                            last_assistant_message = Some(event);
579                        }
580                        SessionEventData::AssistantMessageDelta(_) => {
581                            // Deltas are intermediate; we track the full message
582                        }
583                        SessionEventData::SessionIdle(_) => {
584                            break;
585                        }
586                        SessionEventData::SessionError(err) => {
587                            return Err(CopilotError::Protocol(format!(
588                                "Session error: {}",
589                                err.message
590                            )));
591                        }
592                        _ => {}
593                    },
594                    Err(broadcast::error::RecvError::Closed) => {
595                        return Err(CopilotError::ConnectionClosed);
596                    }
597                    Err(broadcast::error::RecvError::Lagged(_)) => {
598                        // Continue - we missed some events but can recover
599                    }
600                }
601            }
602            Ok(())
603        })
604        .await;
605
606        match result {
607            Ok(Ok(())) => Ok(last_assistant_message),
608            Ok(Err(e)) => Err(e),
609            Err(_) => Err(CopilotError::Timeout(timeout)),
610        }
611    }
612
613    /// Send a message and wait for the complete response.
614    ///
615    /// Returns the last `AssistantMessage` event, or `None` if session
616    /// became idle without producing an assistant message.
617    /// Uses the specified timeout, or 60 seconds if None.
618    pub async fn send_and_wait(
619        &self,
620        options: impl Into<MessageOptions>,
621        timeout: Option<Duration>,
622    ) -> Result<Option<SessionEvent>> {
623        self.send(options).await?;
624        self.wait_for_idle(timeout).await
625    }
626
627    /// Send a message and wait for the response content as a string.
628    ///
629    /// Convenience method that collects all assistant message/delta content.
630    /// Uses the specified timeout, or 60 seconds if None.
631    pub async fn send_and_collect(
632        &self,
633        options: impl Into<MessageOptions>,
634        timeout: Option<Duration>,
635    ) -> Result<String> {
636        let timeout = timeout.unwrap_or(Self::DEFAULT_TIMEOUT);
637        self.send(options).await?;
638
639        let mut subscription = self.subscribe();
640        let mut content = String::new();
641
642        let result = tokio::time::timeout(timeout, async {
643            loop {
644                match subscription.recv().await {
645                    Ok(event) => match &event.data {
646                        SessionEventData::AssistantMessage(msg) => {
647                            content.push_str(&msg.content);
648                        }
649                        SessionEventData::AssistantMessageDelta(delta) => {
650                            content.push_str(&delta.delta_content);
651                        }
652                        SessionEventData::SessionIdle(_) => {
653                            break;
654                        }
655                        SessionEventData::SessionError(err) => {
656                            return Err(CopilotError::Protocol(format!(
657                                "Session error: {}",
658                                err.message
659                            )));
660                        }
661                        _ => {}
662                    },
663                    Err(broadcast::error::RecvError::Closed) => {
664                        return Err(CopilotError::ConnectionClosed);
665                    }
666                    Err(broadcast::error::RecvError::Lagged(_)) => {}
667                }
668            }
669            Ok(())
670        })
671        .await;
672
673        match result {
674            Ok(Ok(())) => Ok(content),
675            Ok(Err(e)) => Err(e),
676            Err(_) => Err(CopilotError::Timeout(timeout)),
677        }
678    }
679}
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use std::sync::atomic::AtomicUsize;
685
686    fn mock_invoke(_method: &str, _params: Option<Value>) -> InvokeFuture {
687        Box::pin(async { Ok(serde_json::json!({"messageId": "test-msg-123"})) })
688    }
689
690    fn mock_invoke_with_events(method: &str, _params: Option<Value>) -> InvokeFuture {
691        let method = method.to_string();
692        Box::pin(async move {
693            if method == "session.getMessages" {
694                return Ok(serde_json::json!({
695                    "events": [{
696                        "id": "evt-1",
697                        "timestamp": "2024-01-01T00:00:00Z",
698                        "type": "session.idle",
699                        "data": {}
700                    }]
701                }));
702            }
703            Ok(serde_json::json!({"messageId": "test-msg-123"}))
704        })
705    }
706
707    #[tokio::test]
708    async fn test_session_id() {
709        let session = Session::new("test-session-123".to_string(), None, mock_invoke);
710        assert_eq!(session.session_id(), "test-session-123");
711    }
712
713    #[tokio::test]
714    async fn test_workspace_path() {
715        let session = Session::new(
716            "test".to_string(),
717            Some("/tmp/workspace".to_string()),
718            mock_invoke,
719        );
720        assert_eq!(session.workspace_path(), Some("/tmp/workspace"));
721    }
722
723    #[tokio::test]
724    async fn test_register_tool() {
725        let session = Session::new("test".to_string(), None, mock_invoke);
726
727        let tool = Tool::new("my_tool").description("A test tool");
728
729        session.register_tool(tool.clone()).await;
730
731        let retrieved = session.get_tool("my_tool").await;
732        assert!(retrieved.is_some());
733        assert_eq!(retrieved.unwrap().name, "my_tool");
734    }
735
736    #[tokio::test]
737    async fn test_register_tool_with_handler() {
738        let session = Session::new("test".to_string(), None, mock_invoke);
739
740        let tool = Tool::new("echo").description("Echo tool");
741        let handler: ToolHandler = Arc::new(|_name, args| {
742            let text = args.get("text").and_then(|v| v.as_str()).unwrap_or("empty");
743            ToolResultObject::text(text)
744        });
745
746        session
747            .register_tool_with_handler(tool, Some(handler))
748            .await;
749
750        let result = session
751            .invoke_tool("echo", &serde_json::json!({"text": "hello"}))
752            .await
753            .unwrap();
754
755        assert_eq!(result.text_result_for_llm, "hello");
756    }
757
758    #[tokio::test]
759    async fn test_invoke_unknown_tool() {
760        let session = Session::new("test".to_string(), None, mock_invoke);
761
762        let result = session.invoke_tool("unknown", &serde_json::json!({})).await;
763
764        assert!(matches!(result, Err(CopilotError::ToolNotFound(_))));
765    }
766
767    #[tokio::test]
768    async fn test_event_subscription() {
769        let session = Session::new("test".to_string(), None, mock_invoke);
770
771        let mut sub1 = session.subscribe();
772        let mut sub2 = session.subscribe();
773
774        // Dispatch an event
775        let event = SessionEvent::from_json(&serde_json::json!({
776            "id": "evt-1",
777            "timestamp": "2024-01-01T00:00:00Z",
778            "type": "session.idle",
779            "data": {}
780        }))
781        .unwrap();
782
783        session.dispatch_event(event).await;
784
785        // Both subscribers should receive it
786        let received1 = sub1.recv().await.unwrap();
787        let received2 = sub2.recv().await.unwrap();
788
789        assert_eq!(received1.id, "evt-1");
790        assert_eq!(received2.id, "evt-1");
791    }
792
793    #[tokio::test]
794    async fn test_callback_handler() {
795        let session = Session::new("test".to_string(), None, mock_invoke);
796        let call_count = Arc::new(AtomicUsize::new(0));
797
798        let count_clone = Arc::clone(&call_count);
799        let unsubscribe = session
800            .on(move |_event| {
801                count_clone.fetch_add(1, Ordering::SeqCst);
802            })
803            .await;
804
805        // Dispatch events
806        let event = SessionEvent::from_json(&serde_json::json!({
807            "id": "evt-callback-1",
808            "timestamp": "2024-01-01T00:00:00Z",
809            "type": "session.idle",
810            "data": {}
811        }))
812        .unwrap();
813
814        session.dispatch_event(event).await;
815
816        assert_eq!(call_count.load(Ordering::SeqCst), 1);
817
818        // Unsubscribe
819        unsubscribe();
820    }
821
822    #[tokio::test]
823    async fn test_permission_handler() {
824        let session = Session::new("test".to_string(), None, mock_invoke);
825
826        // Default handler denies
827        let request = PermissionRequest {
828            kind: "tool_execution".to_string(),
829            tool_call_id: Some("call-123".to_string()),
830            extension_data: HashMap::new(),
831        };
832        let result = session.handle_permission_request(&request).await;
833        assert!(result.kind.contains("denied"));
834
835        // Register custom handler that approves
836        session
837            .register_permission_handler(|_req| PermissionRequestResult::approved())
838            .await;
839
840        let result = session.handle_permission_request(&request).await;
841        assert_eq!(result.kind, "approved");
842    }
843
844    #[tokio::test]
845    async fn test_get_messages_with_events_field() {
846        let session = Session::new("test".to_string(), None, mock_invoke_with_events);
847        let messages = session.get_messages().await.unwrap();
848        assert_eq!(messages.len(), 1);
849        assert!(matches!(
850            messages[0].data,
851            crate::events::SessionEventData::SessionIdle(_)
852        ));
853    }
854
855    #[tokio::test]
856    async fn test_user_input_handler() {
857        let session = Session::new("test".to_string(), None, mock_invoke);
858
859        session
860            .register_user_input_handler(|req, _inv| {
861                assert_eq!(req.question, "What color?");
862                UserInputResponse {
863                    answer: "blue".into(),
864                    was_freeform: Some(true),
865                }
866            })
867            .await;
868
869        let request = UserInputRequest {
870            question: "What color?".into(),
871            choices: Some(vec!["red".into(), "blue".into()]),
872            allow_freeform: Some(true),
873        };
874
875        let response = session.handle_user_input_request(&request).await.unwrap();
876        assert_eq!(response.answer, "blue");
877        assert_eq!(response.was_freeform, Some(true));
878    }
879
880    #[tokio::test]
881    async fn test_user_input_no_handler_errors() {
882        let session = Session::new("test".to_string(), None, mock_invoke);
883
884        let request = UserInputRequest {
885            question: "?".into(),
886            choices: None,
887            allow_freeform: None,
888        };
889
890        let result = session.handle_user_input_request(&request).await;
891        assert!(result.is_err());
892    }
893
894    #[tokio::test]
895    async fn test_register_hooks() {
896        let session = Session::new("test".to_string(), None, mock_invoke);
897
898        assert!(!session.has_hooks().await);
899
900        let hooks = crate::types::SessionHooks {
901            on_pre_tool_use: Some(Arc::new(|input| {
902                assert_eq!(input.tool_name, "my_tool");
903                crate::types::PreToolUseHookOutput {
904                    permission_decision: Some("allow".into()),
905                    ..Default::default()
906                }
907            })),
908            ..Default::default()
909        };
910
911        session.register_hooks(hooks).await;
912        assert!(session.has_hooks().await);
913    }
914
915    #[tokio::test]
916    async fn test_hooks_invoke_pre_tool_use() {
917        let session = Session::new("test".to_string(), None, mock_invoke);
918
919        let hooks = crate::types::SessionHooks {
920            on_pre_tool_use: Some(Arc::new(|_input| crate::types::PreToolUseHookOutput {
921                permission_decision: Some("allow".into()),
922                additional_context: Some("extra context".into()),
923                ..Default::default()
924            })),
925            ..Default::default()
926        };
927
928        session.register_hooks(hooks).await;
929
930        let input = serde_json::json!({
931            "timestamp": 1234567890,
932            "cwd": "/tmp",
933            "toolName": "test_tool",
934            "toolArgs": {"key": "value"}
935        });
936
937        let result = session
938            .handle_hooks_invoke("preToolUse", &input)
939            .await
940            .unwrap();
941        assert_eq!(
942            result.get("permissionDecision").and_then(|v| v.as_str()),
943            Some("allow")
944        );
945        assert_eq!(
946            result.get("additionalContext").and_then(|v| v.as_str()),
947            Some("extra context")
948        );
949    }
950
951    #[tokio::test]
952    async fn test_hooks_invoke_no_handler_returns_null() {
953        let session = Session::new("test".to_string(), None, mock_invoke);
954
955        // No hooks registered at all
956        let result = session
957            .handle_hooks_invoke("preToolUse", &serde_json::json!({}))
958            .await
959            .unwrap();
960        assert!(result.is_null());
961
962        // Hooks registered but not for this type
963        let hooks = crate::types::SessionHooks {
964            on_session_start: Some(Arc::new(|_input| {
965                crate::types::SessionStartHookOutput::default()
966            })),
967            ..Default::default()
968        };
969        session.register_hooks(hooks).await;
970
971        let input = serde_json::json!({
972            "timestamp": 1234567890,
973            "cwd": "/tmp",
974            "toolName": "test_tool",
975            "toolArgs": {}
976        });
977        let result = session
978            .handle_hooks_invoke("preToolUse", &input)
979            .await
980            .unwrap();
981        assert!(result.is_null());
982    }
983
984    #[tokio::test]
985    async fn test_hooks_invoke_unknown_type_returns_null() {
986        let session = Session::new("test".to_string(), None, mock_invoke);
987
988        let hooks = crate::types::SessionHooks {
989            on_pre_tool_use: Some(Arc::new(|_| crate::types::PreToolUseHookOutput::default())),
990            ..Default::default()
991        };
992        session.register_hooks(hooks).await;
993
994        let result = session
995            .handle_hooks_invoke("unknownHookType", &serde_json::json!({}))
996            .await
997            .unwrap();
998        assert!(result.is_null());
999    }
1000}