1use crate::error::{CopilotError, Result};
9use crate::events::{SessionEvent, SessionEventData};
10use crate::types::{
11 ErrorOccurredHookInput, MessageOptions, PermissionRequest, PermissionRequestResult,
12 PostToolUseHookInput, PreToolUseHookInput, SessionEndHookInput, SessionHooks,
13 SessionStartHookInput, Tool, ToolResultObject, UserInputInvocation, UserInputRequest,
14 UserInputResponse, UserPromptSubmittedHookInput,
15};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::sync::{broadcast, RwLock};
22
/// Shared callback that receives a reference to each event handed to
/// `Session::dispatch_event`.
pub type EventHandler = Arc<dyn Fn(&SessionEvent) + Send + Sync>;

/// Shared callback that turns a permission request into a decision.
pub type PermissionHandler =
    Arc<dyn Fn(&PermissionRequest) -> PermissionRequestResult + Send + Sync>;

/// Shared callback that executes a tool.
///
/// `Session::invoke_tool` calls it with the tool name and the JSON arguments.
pub type ToolHandler = Arc<dyn Fn(&str, &Value) -> ToolResultObject + Send + Sync>;

/// Shared callback that answers a user-input request; the second argument
/// carries the invocation context built by the session.
pub type UserInputHandler =
    Arc<dyn Fn(&UserInputRequest, &UserInputInvocation) -> UserInputResponse + Send + Sync>;

/// Boxed, sendable future produced by the session's transport function.
pub type InvokeFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send>>;

/// Transport function used for outgoing calls: takes a method name
/// (e.g. `"session.send"`) and optional JSON parameters.
type InvokeFn = dyn Fn(&str, Option<Value>) -> InvokeFuture + Send + Sync;
45
/// Handle returned by `Session::subscribe`; wraps a broadcast receiver of
/// session events.
pub struct EventSubscription {
    /// Underlying broadcast receiver, exposed for callers that need the raw
    /// tokio API.
    pub receiver: broadcast::Receiver<SessionEvent>,
}
56
impl EventSubscription {
    /// Waits for the next event on this subscription.
    ///
    /// Forwards directly to the wrapped broadcast receiver, so its
    /// `RecvError` is returned unchanged.
    pub async fn recv(&mut self) -> std::result::Result<SessionEvent, broadcast::error::RecvError> {
        self.receiver.recv().await
    }
}
63
/// A tool definition together with the optional local handler that runs it.
#[derive(Clone)]
pub struct RegisteredTool {
    /// The tool definition; its `name` is used as the registry key.
    pub tool: Tool,
    /// Handler called by `Session::invoke_tool`; `None` means the tool is
    /// registered without a local implementation.
    pub handler: Option<ToolHandler>,
}
76
/// Mutable per-session registry, kept behind the session's `RwLock`.
struct SessionState {
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, RegisteredTool>,
    /// Handler consulted for permission requests, if any.
    permission_handler: Option<PermissionHandler>,
    /// Handler consulted for user-input requests, if any.
    user_input_handler: Option<UserInputHandler>,
    /// Lifecycle hooks, if any were registered.
    hooks: Option<SessionHooks>,
    /// Callback-style event handlers, keyed by the id assigned in `Session::on`.
    event_handlers: HashMap<u64, EventHandler>,
    /// Source of ids for `event_handlers`; starts at 1 and only increases.
    next_handler_id: AtomicU64,
}
96
/// A single session: identity, event fan-out, local registries and the
/// transport used for outgoing calls.
pub struct Session {
    /// Identifier sent as `sessionId` with every outgoing call.
    session_id: String,
    /// Optional workspace path supplied at construction.
    workspace_path: Option<String>,
    /// Sending half of the broadcast channel that feeds `subscribe()`.
    event_tx: broadcast::Sender<SessionEvent>,
    /// Tools, handlers and hooks; shared so unsubscribe closures can reach it.
    state: Arc<RwLock<SessionState>>,
    /// Transport function used for every outgoing call.
    invoke_fn: Arc<InvokeFn>,
}
142
143impl Session {
144 pub fn new<F>(session_id: String, workspace_path: Option<String>, invoke_fn: F) -> Self
148 where
149 F: Fn(&str, Option<Value>) -> InvokeFuture + Send + Sync + 'static,
150 {
151 let (event_tx, _) = broadcast::channel(1024);
152
153 Self {
154 session_id,
155 workspace_path,
156 event_tx,
157 state: Arc::new(RwLock::new(SessionState {
158 tools: HashMap::new(),
159 permission_handler: None,
160 user_input_handler: None,
161 hooks: None,
162 event_handlers: HashMap::new(),
163 next_handler_id: AtomicU64::new(1),
164 })),
165 invoke_fn: Arc::new(invoke_fn),
166 }
167 }
168
    /// Returns the identifier this session was created with.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }
177
    /// Returns the workspace path supplied at construction, if any.
    pub fn workspace_path(&self) -> Option<&str> {
        self.workspace_path.as_deref()
    }
185
    /// Creates a new subscription to this session's event broadcast channel.
    ///
    /// Each call creates an independent receiver on the shared sender.
    pub fn subscribe(&self) -> EventSubscription {
        EventSubscription {
            receiver: self.event_tx.subscribe(),
        }
    }
198
199 pub async fn on<F>(&self, handler: F) -> impl FnOnce()
204 where
205 F: Fn(&SessionEvent) + Send + Sync + 'static,
206 {
207 let mut state = self.state.write().await;
208 let id = state.next_handler_id.fetch_add(1, Ordering::SeqCst);
209 state.event_handlers.insert(id, Arc::new(handler));
210
211 let state_ref = Arc::clone(&self.state);
212 move || {
213 tokio::spawn(async move {
214 state_ref.write().await.event_handlers.remove(&id);
215 });
216 }
217 }
218
219 pub async fn off(&self, handler_id: u64) {
221 let mut state = self.state.write().await;
222 state.event_handlers.remove(&handler_id);
223 }
224
225 pub async fn dispatch_event(&self, event: SessionEvent) {
229 let _ = self.event_tx.send(event.clone());
231
232 let state = self.state.read().await;
234 for handler in state.event_handlers.values() {
235 handler(&event);
236 }
237 }
238
239 pub async fn send(&self, options: impl Into<MessageOptions>) -> Result<String> {
247 let options = options.into();
248 let params = serde_json::json!({
249 "sessionId": self.session_id,
250 "prompt": options.prompt,
251 "attachments": options.attachments,
252 "mode": options.mode,
253 });
254
255 let result = (self.invoke_fn)("session.send", Some(params)).await?;
256
257 result
258 .get("messageId")
259 .and_then(|v| v.as_str())
260 .map(|s| s.to_string())
261 .ok_or_else(|| CopilotError::Protocol("Missing messageId in response".into()))
262 }
263
264 pub async fn abort(&self) -> Result<()> {
266 let params = serde_json::json!({
267 "sessionId": self.session_id,
268 });
269
270 (self.invoke_fn)("session.abort", Some(params)).await?;
271 Ok(())
272 }
273
274 pub async fn get_messages(&self) -> Result<Vec<SessionEvent>> {
276 let params = serde_json::json!({
277 "sessionId": self.session_id,
278 });
279
280 let result = (self.invoke_fn)("session.getMessages", Some(params)).await?;
281
282 let events: Vec<SessionEvent> = result
283 .get("events")
284 .and_then(|v| v.as_array())
285 .map(|arr| {
286 arr.iter()
287 .filter_map(|v| SessionEvent::from_json(v).ok())
288 .collect()
289 })
290 .or_else(|| {
291 result
292 .get("messages")
293 .and_then(|v| v.as_array())
294 .map(|arr| {
295 arr.iter()
296 .filter_map(|v| SessionEvent::from_json(v).ok())
297 .collect()
298 })
299 })
300 .ok_or_else(|| {
301 CopilotError::Protocol("Missing events in getMessages response".into())
302 })?;
303
304 Ok(events)
305 }
306
    /// Registers `tool` without a local handler.
    ///
    /// Equivalent to `register_tool_with_handler(tool, None)`.
    pub async fn register_tool(&self, tool: Tool) {
        self.register_tool_with_handler(tool, None).await;
    }
315
316 pub async fn register_tool_with_handler(&self, tool: Tool, handler: Option<ToolHandler>) {
318 let mut state = self.state.write().await;
319 let name = tool.name.clone();
320 state.tools.insert(name, RegisteredTool { tool, handler });
321 }
322
323 pub async fn register_tools(&self, tools: Vec<Tool>) {
325 let mut state = self.state.write().await;
326 for tool in tools {
327 let name = tool.name.clone();
328 state.tools.insert(
329 name,
330 RegisteredTool {
331 tool,
332 handler: None,
333 },
334 );
335 }
336 }
337
338 pub async fn get_tool(&self, name: &str) -> Option<Tool> {
340 let state = self.state.read().await;
341 state.tools.get(name).map(|rt| rt.tool.clone())
342 }
343
344 pub async fn get_tools(&self) -> Vec<Tool> {
346 let state = self.state.read().await;
347 state.tools.values().map(|rt| rt.tool.clone()).collect()
348 }
349
350 pub async fn invoke_tool(&self, name: &str, arguments: &Value) -> Result<ToolResultObject> {
352 let state = self.state.read().await;
353 let registered = state
354 .tools
355 .get(name)
356 .ok_or_else(|| CopilotError::ToolNotFound(name.to_string()))?;
357
358 let handler = registered
359 .handler
360 .as_ref()
361 .ok_or_else(|| CopilotError::ToolError(format!("No handler for tool: {}", name)))?;
362
363 Ok(handler(name, arguments))
364 }
365
366 pub async fn register_permission_handler<F>(&self, handler: F)
372 where
373 F: Fn(&PermissionRequest) -> PermissionRequestResult + Send + Sync + 'static,
374 {
375 let mut state = self.state.write().await;
376 state.permission_handler = Some(Arc::new(handler));
377 }
378
379 pub async fn handle_permission_request(
384 &self,
385 request: &PermissionRequest,
386 ) -> PermissionRequestResult {
387 let state = self.state.read().await;
388
389 if let Some(handler) = &state.permission_handler {
390 handler(request)
391 } else {
392 PermissionRequestResult::denied()
394 }
395 }
396
397 pub async fn register_user_input_handler<F>(&self, handler: F)
403 where
404 F: Fn(&UserInputRequest, &UserInputInvocation) -> UserInputResponse + Send + Sync + 'static,
405 {
406 let mut state = self.state.write().await;
407 state.user_input_handler = Some(Arc::new(handler));
408 }
409
410 pub async fn handle_user_input_request(
412 &self,
413 request: &UserInputRequest,
414 ) -> Result<UserInputResponse> {
415 let state = self.state.read().await;
416 if let Some(handler) = &state.user_input_handler {
417 let invocation = UserInputInvocation {
418 session_id: self.session_id.clone(),
419 };
420 Ok(handler(request, &invocation))
421 } else {
422 Err(CopilotError::Protocol(
423 "No user input handler registered".into(),
424 ))
425 }
426 }
427
428 pub async fn has_user_input_handler(&self) -> bool {
430 let state = self.state.read().await;
431 state.user_input_handler.is_some()
432 }
433
434 pub async fn register_hooks(&self, hooks: SessionHooks) {
440 let mut state = self.state.write().await;
441 state.hooks = Some(hooks);
442 }
443
444 pub async fn has_hooks(&self) -> bool {
446 let state = self.state.read().await;
447 state.hooks.as_ref().is_some_and(|h| h.has_any())
448 }
449
450 pub async fn handle_hooks_invoke(&self, hook_type: &str, input: &Value) -> Result<Value> {
455 let state = self.state.read().await;
456 let hooks = match &state.hooks {
457 Some(h) => h,
458 None => return Ok(Value::Null),
459 };
460
461 match hook_type {
462 "preToolUse" => {
463 if let Some(handler) = &hooks.on_pre_tool_use {
464 let hook_input: PreToolUseHookInput = serde_json::from_value(input.clone())
465 .map_err(|e| {
466 CopilotError::Protocol(format!("Invalid preToolUse input: {}", e))
467 })?;
468 let output = handler(hook_input);
469 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
470 } else {
471 Ok(Value::Null)
472 }
473 }
474 "postToolUse" => {
475 if let Some(handler) = &hooks.on_post_tool_use {
476 let hook_input: PostToolUseHookInput = serde_json::from_value(input.clone())
477 .map_err(|e| {
478 CopilotError::Protocol(format!("Invalid postToolUse input: {}", e))
479 })?;
480 let output = handler(hook_input);
481 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
482 } else {
483 Ok(Value::Null)
484 }
485 }
486 "userPromptSubmitted" => {
487 if let Some(handler) = &hooks.on_user_prompt_submitted {
488 let hook_input: UserPromptSubmittedHookInput =
489 serde_json::from_value(input.clone()).map_err(|e| {
490 CopilotError::Protocol(format!(
491 "Invalid userPromptSubmitted input: {}",
492 e
493 ))
494 })?;
495 let output = handler(hook_input);
496 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
497 } else {
498 Ok(Value::Null)
499 }
500 }
501 "sessionStart" => {
502 if let Some(handler) = &hooks.on_session_start {
503 let hook_input: SessionStartHookInput = serde_json::from_value(input.clone())
504 .map_err(|e| {
505 CopilotError::Protocol(format!("Invalid sessionStart input: {}", e))
506 })?;
507 let output = handler(hook_input);
508 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
509 } else {
510 Ok(Value::Null)
511 }
512 }
513 "sessionEnd" => {
514 if let Some(handler) = &hooks.on_session_end {
515 let hook_input: SessionEndHookInput = serde_json::from_value(input.clone())
516 .map_err(|e| {
517 CopilotError::Protocol(format!("Invalid sessionEnd input: {}", e))
518 })?;
519 let output = handler(hook_input);
520 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
521 } else {
522 Ok(Value::Null)
523 }
524 }
525 "errorOccurred" => {
526 if let Some(handler) = &hooks.on_error_occurred {
527 let hook_input: ErrorOccurredHookInput = serde_json::from_value(input.clone())
528 .map_err(|e| {
529 CopilotError::Protocol(format!("Invalid errorOccurred input: {}", e))
530 })?;
531 let output = handler(hook_input);
532 Ok(serde_json::to_value(output).unwrap_or(Value::Null))
533 } else {
534 Ok(Value::Null)
535 }
536 }
537 _ => Ok(Value::Null),
538 }
539 }
540
541 pub async fn destroy(&self) -> Result<()> {
547 let params = serde_json::json!({
548 "sessionId": self.session_id,
549 });
550
551 (self.invoke_fn)("session.destroy", Some(params)).await?;
552 Ok(())
553 }
554}
555
556impl Session {
561 const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
563
564 pub async fn wait_for_idle(&self, timeout: Option<Duration>) -> Result<Option<SessionEvent>> {
569 let timeout = timeout.unwrap_or(Self::DEFAULT_TIMEOUT);
570 let mut subscription = self.subscribe();
571 let mut last_assistant_message: Option<SessionEvent> = None;
572
573 let result = tokio::time::timeout(timeout, async {
574 loop {
575 match subscription.recv().await {
576 Ok(event) => match &event.data {
577 SessionEventData::AssistantMessage(_) => {
578 last_assistant_message = Some(event);
579 }
580 SessionEventData::AssistantMessageDelta(_) => {
581 }
583 SessionEventData::SessionIdle(_) => {
584 break;
585 }
586 SessionEventData::SessionError(err) => {
587 return Err(CopilotError::Protocol(format!(
588 "Session error: {}",
589 err.message
590 )));
591 }
592 _ => {}
593 },
594 Err(broadcast::error::RecvError::Closed) => {
595 return Err(CopilotError::ConnectionClosed);
596 }
597 Err(broadcast::error::RecvError::Lagged(_)) => {
598 }
600 }
601 }
602 Ok(())
603 })
604 .await;
605
606 match result {
607 Ok(Ok(())) => Ok(last_assistant_message),
608 Ok(Err(e)) => Err(e),
609 Err(_) => Err(CopilotError::Timeout(timeout)),
610 }
611 }
612
613 pub async fn send_and_wait(
619 &self,
620 options: impl Into<MessageOptions>,
621 timeout: Option<Duration>,
622 ) -> Result<Option<SessionEvent>> {
623 self.send(options).await?;
624 self.wait_for_idle(timeout).await
625 }
626
627 pub async fn send_and_collect(
632 &self,
633 options: impl Into<MessageOptions>,
634 timeout: Option<Duration>,
635 ) -> Result<String> {
636 let timeout = timeout.unwrap_or(Self::DEFAULT_TIMEOUT);
637 self.send(options).await?;
638
639 let mut subscription = self.subscribe();
640 let mut content = String::new();
641
642 let result = tokio::time::timeout(timeout, async {
643 loop {
644 match subscription.recv().await {
645 Ok(event) => match &event.data {
646 SessionEventData::AssistantMessage(msg) => {
647 content.push_str(&msg.content);
648 }
649 SessionEventData::AssistantMessageDelta(delta) => {
650 content.push_str(&delta.delta_content);
651 }
652 SessionEventData::SessionIdle(_) => {
653 break;
654 }
655 SessionEventData::SessionError(err) => {
656 return Err(CopilotError::Protocol(format!(
657 "Session error: {}",
658 err.message
659 )));
660 }
661 _ => {}
662 },
663 Err(broadcast::error::RecvError::Closed) => {
664 return Err(CopilotError::ConnectionClosed);
665 }
666 Err(broadcast::error::RecvError::Lagged(_)) => {}
667 }
668 }
669 Ok(())
670 })
671 .await;
672
673 match result {
674 Ok(Ok(())) => Ok(content),
675 Ok(Err(e)) => Err(e),
676 Err(_) => Err(CopilotError::Timeout(timeout)),
677 }
678 }
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Transport stub: every call succeeds with a fixed `messageId`.
    fn mock_invoke(_method: &str, _params: Option<Value>) -> InvokeFuture {
        Box::pin(async { Ok(serde_json::json!({"messageId": "test-msg-123"})) })
    }

    /// Transport stub that answers `session.getMessages` with one idle event
    /// under the `events` key, and behaves like `mock_invoke` otherwise.
    fn mock_invoke_with_events(method: &str, _params: Option<Value>) -> InvokeFuture {
        let method = method.to_string();
        Box::pin(async move {
            if method == "session.getMessages" {
                return Ok(serde_json::json!({
                    "events": [{
                        "id": "evt-1",
                        "timestamp": "2024-01-01T00:00:00Z",
                        "type": "session.idle",
                        "data": {}
                    }]
                }));
            }
            Ok(serde_json::json!({"messageId": "test-msg-123"}))
        })
    }

    /// The id passed to `Session::new` is returned unchanged.
    #[tokio::test]
    async fn test_session_id() {
        let session = Session::new("test-session-123".to_string(), None, mock_invoke);
        assert_eq!(session.session_id(), "test-session-123");
    }

    /// The workspace path passed to `Session::new` is returned unchanged.
    #[tokio::test]
    async fn test_workspace_path() {
        let session = Session::new(
            "test".to_string(),
            Some("/tmp/workspace".to_string()),
            mock_invoke,
        );
        assert_eq!(session.workspace_path(), Some("/tmp/workspace"));
    }

    /// A registered tool can be looked up by name.
    #[tokio::test]
    async fn test_register_tool() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let tool = Tool::new("my_tool").description("A test tool");

        session.register_tool(tool.clone()).await;

        let retrieved = session.get_tool("my_tool").await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "my_tool");
    }

    /// `invoke_tool` runs the registered handler with the given arguments.
    #[tokio::test]
    async fn test_register_tool_with_handler() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let tool = Tool::new("echo").description("Echo tool");
        let handler: ToolHandler = Arc::new(|_name, args| {
            let text = args.get("text").and_then(|v| v.as_str()).unwrap_or("empty");
            ToolResultObject::text(text)
        });

        session
            .register_tool_with_handler(tool, Some(handler))
            .await;

        let result = session
            .invoke_tool("echo", &serde_json::json!({"text": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.text_result_for_llm, "hello");
    }

    /// Invoking an unregistered tool yields `ToolNotFound`.
    #[tokio::test]
    async fn test_invoke_unknown_tool() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let result = session.invoke_tool("unknown", &serde_json::json!({})).await;

        assert!(matches!(result, Err(CopilotError::ToolNotFound(_))));
    }

    /// Every subscription created before a dispatch receives the event.
    #[tokio::test]
    async fn test_event_subscription() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let mut sub1 = session.subscribe();
        let mut sub2 = session.subscribe();

        let event = SessionEvent::from_json(&serde_json::json!({
            "id": "evt-1",
            "timestamp": "2024-01-01T00:00:00Z",
            "type": "session.idle",
            "data": {}
        }))
        .unwrap();

        session.dispatch_event(event).await;

        let received1 = sub1.recv().await.unwrap();
        let received2 = sub2.recv().await.unwrap();

        assert_eq!(received1.id, "evt-1");
        assert_eq!(received2.id, "evt-1");
    }

    /// A callback registered with `on` is invoked once per dispatched event.
    #[tokio::test]
    async fn test_callback_handler() {
        let session = Session::new("test".to_string(), None, mock_invoke);
        let call_count = Arc::new(AtomicUsize::new(0));

        let count_clone = Arc::clone(&call_count);
        let unsubscribe = session
            .on(move |_event| {
                count_clone.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        let event = SessionEvent::from_json(&serde_json::json!({
            "id": "evt-callback-1",
            "timestamp": "2024-01-01T00:00:00Z",
            "type": "session.idle",
            "data": {}
        }))
        .unwrap();

        session.dispatch_event(event).await;

        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Only checks that unsubscribing can be called; the removal itself
        // happens on a spawned task and is not observed here.
        unsubscribe();
    }

    /// Requests are denied without a handler and follow the handler once set.
    #[tokio::test]
    async fn test_permission_handler() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let request = PermissionRequest {
            kind: "tool_execution".to_string(),
            tool_call_id: Some("call-123".to_string()),
            extension_data: HashMap::new(),
        };
        // No handler registered yet: the default answer is a denial.
        let result = session.handle_permission_request(&request).await;
        assert!(result.kind.contains("denied"));

        session
            .register_permission_handler(|_req| PermissionRequestResult::approved())
            .await;

        let result = session.handle_permission_request(&request).await;
        assert_eq!(result.kind, "approved");
    }

    /// `get_messages` parses events delivered under the `events` key.
    #[tokio::test]
    async fn test_get_messages_with_events_field() {
        let session = Session::new("test".to_string(), None, mock_invoke_with_events);
        let messages = session.get_messages().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            messages[0].data,
            crate::events::SessionEventData::SessionIdle(_)
        ));
    }

    /// The registered user-input handler receives the request and its answer
    /// is returned.
    #[tokio::test]
    async fn test_user_input_handler() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        session
            .register_user_input_handler(|req, _inv| {
                assert_eq!(req.question, "What color?");
                UserInputResponse {
                    answer: "blue".into(),
                    was_freeform: Some(true),
                }
            })
            .await;

        let request = UserInputRequest {
            question: "What color?".into(),
            choices: Some(vec!["red".into(), "blue".into()]),
            allow_freeform: Some(true),
        };

        let response = session.handle_user_input_request(&request).await.unwrap();
        assert_eq!(response.answer, "blue");
        assert_eq!(response.was_freeform, Some(true));
    }

    /// Without a user-input handler the request fails.
    #[tokio::test]
    async fn test_user_input_no_handler_errors() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let request = UserInputRequest {
            question: "?".into(),
            choices: None,
            allow_freeform: None,
        };

        let result = session.handle_user_input_request(&request).await;
        assert!(result.is_err());
    }

    /// `has_hooks` flips from false to true once a hook is registered.
    #[tokio::test]
    async fn test_register_hooks() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        assert!(!session.has_hooks().await);

        let hooks = crate::types::SessionHooks {
            on_pre_tool_use: Some(Arc::new(|input| {
                assert_eq!(input.tool_name, "my_tool");
                crate::types::PreToolUseHookOutput {
                    permission_decision: Some("allow".into()),
                    ..Default::default()
                }
            })),
            ..Default::default()
        };

        session.register_hooks(hooks).await;
        assert!(session.has_hooks().await);
    }

    /// The `preToolUse` hook output is serialized with camelCase keys.
    #[tokio::test]
    async fn test_hooks_invoke_pre_tool_use() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let hooks = crate::types::SessionHooks {
            on_pre_tool_use: Some(Arc::new(|_input| crate::types::PreToolUseHookOutput {
                permission_decision: Some("allow".into()),
                additional_context: Some("extra context".into()),
                ..Default::default()
            })),
            ..Default::default()
        };

        session.register_hooks(hooks).await;

        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "test_tool",
            "toolArgs": {"key": "value"}
        });

        let result = session
            .handle_hooks_invoke("preToolUse", &input)
            .await
            .unwrap();
        assert_eq!(
            result.get("permissionDecision").and_then(|v| v.as_str()),
            Some("allow")
        );
        assert_eq!(
            result.get("additionalContext").and_then(|v| v.as_str()),
            Some("extra context")
        );
    }

    /// Null is returned both when no hooks exist and when the requested hook
    /// is unset while a different one is registered.
    #[tokio::test]
    async fn test_hooks_invoke_no_handler_returns_null() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        // Case 1: no hooks registered at all.
        let result = session
            .handle_hooks_invoke("preToolUse", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.is_null());

        // Case 2: hooks registered, but not the one being invoked.
        let hooks = crate::types::SessionHooks {
            on_session_start: Some(Arc::new(|_input| {
                crate::types::SessionStartHookOutput::default()
            })),
            ..Default::default()
        };
        session.register_hooks(hooks).await;

        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "test_tool",
            "toolArgs": {}
        });
        let result = session
            .handle_hooks_invoke("preToolUse", &input)
            .await
            .unwrap();
        assert!(result.is_null());
    }

    /// An unrecognised hook type is ignored and yields null.
    #[tokio::test]
    async fn test_hooks_invoke_unknown_type_returns_null() {
        let session = Session::new("test".to_string(), None, mock_invoke);

        let hooks = crate::types::SessionHooks {
            on_pre_tool_use: Some(Arc::new(|_| crate::types::PreToolUseHookOutput::default())),
            ..Default::default()
        };
        session.register_hooks(hooks).await;

        let result = session
            .handle_hooks_invoke("unknownHookType", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.is_null());
    }
}