1use std::collections::HashMap;
9use std::sync::Arc;
10
11use serde_json::Value;
12use tokio::net::TcpStream;
13use tokio::process::{Child, Command};
14use tokio::sync::Mutex;
15
16use crate::jsonrpc::{JsonRpcClient, NotificationHandler, RequestHandler};
17use crate::sdk_protocol_version::get_sdk_protocol_version;
18use crate::session::CopilotSession;
19use crate::types::*;
20use crate::CopilotError;
21
22pub struct CopilotClient {
66 options: CopilotClientOptions,
67 state: Arc<Mutex<ConnectionState>>,
68 rpc_client: Arc<Mutex<Option<Arc<JsonRpcClient>>>>,
69 cli_process: Arc<Mutex<Option<Child>>>,
70 sessions: Arc<Mutex<HashMap<String, Arc<CopilotSession>>>>,
71 is_external_server: bool,
72 models_cache: Arc<Mutex<Option<Vec<ModelInfo>>>>,
73 lifecycle_handlers: Arc<Mutex<Vec<(u64, Arc<dyn Fn(SessionLifecycleEvent) + Send + Sync>)>>>,
74 next_lifecycle_handler_id: Arc<Mutex<u64>>,
75}
76
77impl CopilotClient {
78 pub fn new(options: CopilotClientOptions) -> Self {
83 let is_external = options.cli_url.is_some();
84
85 Self {
86 options,
87 state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
88 rpc_client: Arc::new(Mutex::new(None)),
89 cli_process: Arc::new(Mutex::new(None)),
90 sessions: Arc::new(Mutex::new(HashMap::new())),
91 is_external_server: is_external,
92 models_cache: Arc::new(Mutex::new(None)),
93 lifecycle_handlers: Arc::new(Mutex::new(Vec::new())),
94 next_lifecycle_handler_id: Arc::new(Mutex::new(0)),
95 }
96 }
97
98 pub async fn get_state(&self) -> ConnectionState {
100 *self.state.lock().await
101 }
102
103 pub async fn start(&self) -> Result<(), CopilotError> {
114 {
115 let state = self.state.lock().await;
116 if *state == ConnectionState::Connected {
117 return Ok(());
118 }
119 }
120
121 {
122 let mut state = self.state.lock().await;
123 *state = ConnectionState::Connecting;
124 }
125
126 let result = self.do_start().await;
127 match &result {
128 Ok(()) => {
129 let mut state = self.state.lock().await;
130 *state = ConnectionState::Connected;
131 }
132 Err(_) => {
133 let mut state = self.state.lock().await;
134 *state = ConnectionState::Error;
135 }
136 }
137 result
138 }
139
140 async fn do_start(&self) -> Result<(), CopilotError> {
141 if self.is_external_server {
142 self.connect_to_external_server().await?;
144 } else if self.options.use_stdio {
145 self.start_cli_stdio().await?;
147 } else {
148 self.start_cli_tcp().await?;
150 }
151
152 self.verify_protocol_version().await?;
154
155 Ok(())
156 }
157
158 pub async fn stop(&self) -> Result<Vec<CopilotError>, CopilotError> {
162 let mut errors = Vec::new();
163
164 let session_ids: Vec<String> = {
166 let sessions = self.sessions.lock().await;
167 sessions.keys().cloned().collect()
168 };
169
170 for session_id in session_ids {
171 let session = {
172 let sessions = self.sessions.lock().await;
173 sessions.get(&session_id).cloned()
174 };
175 if let Some(session) = session {
176 for attempt in 1..=3 {
177 match session.destroy().await {
178 Ok(()) => break,
179 Err(e) => {
180 if attempt == 3 {
181 errors.push(CopilotError::SessionError(format!(
182 "Failed to destroy session {} after 3 attempts: {}",
183 session_id, e
184 )));
185 } else {
186 let delay = 100 * (1u64 << (attempt - 1));
187 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
188 }
189 }
190 }
191 }
192 }
193 }
194
195 {
196 let mut sessions = self.sessions.lock().await;
197 sessions.clear();
198 }
199
200 {
202 let mut rpc = self.rpc_client.lock().await;
203 if let Some(client) = rpc.take() {
204 drop(client);
208 }
209 }
210
211 {
213 let mut cache = self.models_cache.lock().await;
214 *cache = None;
215 }
216
217 if !self.is_external_server {
219 let mut proc = self.cli_process.lock().await;
220 if let Some(ref mut child) = *proc {
221 let _ = child.kill().await;
222 }
223 *proc = None;
224 }
225
226 {
227 let mut state = self.state.lock().await;
228 *state = ConnectionState::Disconnected;
229 }
230
231 Ok(errors)
232 }
233
234 pub async fn force_stop(&self) {
236 {
238 let mut sessions = self.sessions.lock().await;
239 sessions.clear();
240 }
241
242 {
244 let mut rpc = self.rpc_client.lock().await;
245 *rpc = None;
246 }
247
248 {
250 let mut cache = self.models_cache.lock().await;
251 *cache = None;
252 }
253
254 if !self.is_external_server {
256 let mut proc = self.cli_process.lock().await;
257 if let Some(ref mut child) = *proc {
258 let _ = child.kill().await;
259 }
260 *proc = None;
261 }
262
263 {
264 let mut state = self.state.lock().await;
265 *state = ConnectionState::Disconnected;
266 }
267 }
268
269 pub async fn create_session(
275 &self,
276 config: SessionConfig,
277 ) -> Result<Arc<CopilotSession>, CopilotError> {
278 self.ensure_connected().await?;
279
280 let rpc = self.get_rpc_client().await?;
281
282 let params = serde_json::to_value(&config)
283 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
284
285 let response = rpc.request("session.create", params, None).await?;
286 let session_id = response
287 .get("sessionId")
288 .and_then(|v| v.as_str())
289 .ok_or_else(|| CopilotError::Protocol("Missing sessionId in response".to_string()))?
290 .to_string();
291 let workspace_path = response
292 .get("workspacePath")
293 .and_then(|v| v.as_str())
294 .map(|s| s.to_string());
295
296 let session = Arc::new(CopilotSession::new(
297 session_id.clone(),
298 rpc.clone(),
299 workspace_path,
300 ));
301
302 {
303 let mut sessions = self.sessions.lock().await;
304 sessions.insert(session_id, Arc::clone(&session));
305 }
306
307 Ok(session)
308 }
309
310 pub async fn resume_session(
312 &self,
313 config: ResumeSessionConfig,
314 ) -> Result<Arc<CopilotSession>, CopilotError> {
315 self.ensure_connected().await?;
316
317 let rpc = self.get_rpc_client().await?;
318
319 let params = serde_json::to_value(&config)
320 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
321
322 let response = rpc.request("session.resume", params, None).await?;
323 let session_id = response
324 .get("sessionId")
325 .and_then(|v| v.as_str())
326 .ok_or_else(|| CopilotError::Protocol("Missing sessionId in response".to_string()))?
327 .to_string();
328 let workspace_path = response
329 .get("workspacePath")
330 .and_then(|v| v.as_str())
331 .map(|s| s.to_string());
332
333 let session = Arc::new(CopilotSession::new(
334 session_id.clone(),
335 rpc.clone(),
336 workspace_path,
337 ));
338
339 {
340 let mut sessions = self.sessions.lock().await;
341 sessions.insert(session_id, Arc::clone(&session));
342 }
343
344 Ok(session)
345 }
346
347 pub async fn get_last_session_id(&self) -> Result<Option<String>, CopilotError> {
349 let rpc = self.get_rpc_client().await?;
350 let response = rpc
351 .request("session.getLastId", serde_json::json!({}), None)
352 .await?;
353 Ok(response
354 .get("sessionId")
355 .and_then(|v| v.as_str())
356 .map(|s| s.to_string()))
357 }
358
359 pub async fn delete_session(&self, session_id: &str) -> Result<(), CopilotError> {
361 let rpc = self.get_rpc_client().await?;
362 let response = rpc
363 .request(
364 "session.delete",
365 serde_json::json!({ "sessionId": session_id }),
366 None,
367 )
368 .await?;
369
370 let success = response.get("success").and_then(|v| v.as_bool()).unwrap_or(false);
371 if !success {
372 let error = response
373 .get("error")
374 .and_then(|v| v.as_str())
375 .unwrap_or("Unknown error");
376 return Err(CopilotError::SessionError(format!(
377 "Failed to delete session {}: {}",
378 session_id, error
379 )));
380 }
381
382 {
383 let mut sessions = self.sessions.lock().await;
384 sessions.remove(session_id);
385 }
386
387 Ok(())
388 }
389
390 pub async fn list_sessions(&self) -> Result<Vec<SessionMetadata>, CopilotError> {
392 let rpc = self.get_rpc_client().await?;
393 let response = rpc
394 .request("session.list", serde_json::json!({}), None)
395 .await?;
396 let sessions: Vec<SessionMetadata> = serde_json::from_value(
397 response
398 .get("sessions")
399 .cloned()
400 .unwrap_or(Value::Array(vec![])),
401 )
402 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
403 Ok(sessions)
404 }
405
406 pub async fn ping(&self, message: Option<&str>) -> Result<PingResponse, CopilotError> {
412 let rpc = self.get_rpc_client().await?;
413 let params = serde_json::json!({ "message": message });
414 let response = rpc.request("ping", params, None).await?;
415 serde_json::from_value(response).map_err(|e| CopilotError::Serialization(e.to_string()))
416 }
417
418 pub async fn get_status(&self) -> Result<GetStatusResponse, CopilotError> {
420 let rpc = self.get_rpc_client().await?;
421 let response = rpc
422 .request("status.get", serde_json::json!({}), None)
423 .await?;
424 serde_json::from_value(response).map_err(|e| CopilotError::Serialization(e.to_string()))
425 }
426
427 pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse, CopilotError> {
429 let rpc = self.get_rpc_client().await?;
430 let response = rpc
431 .request("auth.getStatus", serde_json::json!({}), None)
432 .await?;
433 serde_json::from_value(response).map_err(|e| CopilotError::Serialization(e.to_string()))
434 }
435
436 pub async fn list_models(&self) -> Result<Vec<ModelInfo>, CopilotError> {
440 {
442 let cache = self.models_cache.lock().await;
443 if let Some(ref models) = *cache {
444 return Ok(models.clone());
445 }
446 }
447
448 let rpc = self.get_rpc_client().await?;
449 let response = rpc
450 .request("models.list", serde_json::json!({}), None)
451 .await?;
452 let models_response: HashMap<String, Vec<ModelInfo>> =
453 serde_json::from_value(response)
454 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
455 let models = models_response.get("models").cloned().unwrap_or_default();
456
457 {
459 let mut cache = self.models_cache.lock().await;
460 *cache = Some(models.clone());
461 }
462
463 Ok(models)
464 }
465
466 pub async fn get_foreground_session_id(&self) -> Result<Option<String>, CopilotError> {
472 let rpc = self.get_rpc_client().await?;
473 let response = rpc
474 .request("session.getForeground", serde_json::json!({}), None)
475 .await?;
476 Ok(response
477 .get("sessionId")
478 .and_then(|v| v.as_str())
479 .map(|s| s.to_string()))
480 }
481
482 pub async fn set_foreground_session_id(
484 &self,
485 session_id: &str,
486 ) -> Result<(), CopilotError> {
487 let rpc = self.get_rpc_client().await?;
488 let response = rpc
489 .request(
490 "session.setForeground",
491 serde_json::json!({ "sessionId": session_id }),
492 None,
493 )
494 .await?;
495 let success = response.get("success").and_then(|v| v.as_bool()).unwrap_or(false);
496 if !success {
497 let error = response
498 .get("error")
499 .and_then(|v| v.as_str())
500 .unwrap_or("Unknown error");
501 return Err(CopilotError::SessionError(error.to_string()));
502 }
503 Ok(())
504 }
505
506 pub async fn on_lifecycle<F>(&self, handler: F) -> u64
514 where
515 F: Fn(SessionLifecycleEvent) + Send + Sync + 'static,
516 {
517 let handler_id = {
518 let mut id = self.next_lifecycle_handler_id.lock().await;
519 *id += 1;
520 *id
521 };
522
523 let mut handlers = self.lifecycle_handlers.lock().await;
524 handlers.push((handler_id, Arc::new(handler)));
525 handler_id
526 }
527
528 pub async fn off_lifecycle(&self, handler_id: u64) {
530 let mut handlers = self.lifecycle_handlers.lock().await;
531 handlers.retain(|(id, _)| *id != handler_id);
532 }
533
534 async fn ensure_connected(&self) -> Result<(), CopilotError> {
539 let state = self.state.lock().await;
540 if *state == ConnectionState::Connected {
541 return Ok(());
542 }
543 drop(state);
544
545 if self.options.auto_start {
546 self.start().await
547 } else {
548 Err(CopilotError::NotConnected)
549 }
550 }
551
552 async fn get_rpc_client(&self) -> Result<Arc<JsonRpcClient>, CopilotError> {
553 let rpc = self.rpc_client.lock().await;
554 rpc.clone().ok_or(CopilotError::NotConnected)
555 }
556
557 async fn start_cli_stdio(&self) -> Result<(), CopilotError> {
558 let cli_path = self
559 .options
560 .cli_path
561 .as_deref()
562 .ok_or_else(|| CopilotError::Configuration("cli_path is required".to_string()))?;
563
564 let mut args = self.options.cli_args.clone();
565 args.extend_from_slice(&[
566 "--headless".to_string(),
567 "--no-auto-update".to_string(),
568 "--log-level".to_string(),
569 self.options.log_level.clone(),
570 "--stdio".to_string(),
571 ]);
572
573 if self.options.github_token.is_some() {
575 args.push("--auth-token-env".to_string());
576 args.push("COPILOT_SDK_AUTH_TOKEN".to_string());
577 }
578 let use_logged_in = self
579 .options
580 .use_logged_in_user
581 .unwrap_or(self.options.github_token.is_none());
582 if !use_logged_in {
583 args.push("--no-auto-login".to_string());
584 }
585
586 let mut cmd = Command::new(cli_path);
587 cmd.args(&args)
588 .stdin(std::process::Stdio::piped())
589 .stdout(std::process::Stdio::piped())
590 .stderr(std::process::Stdio::piped());
591
592 if let Some(ref cwd) = self.options.cwd {
593 cmd.current_dir(cwd);
594 }
595
596 if let Some(ref env) = self.options.env {
598 cmd.envs(env.iter());
599 }
600 if let Some(ref token) = self.options.github_token {
601 cmd.env("COPILOT_SDK_AUTH_TOKEN", token);
602 }
603
604 let mut child = cmd.spawn().map_err(|e| {
605 CopilotError::ProcessSpawn(format!("Failed to spawn CLI process: {}", e))
606 })?;
607
608 let stdin = child.stdin.take().ok_or_else(|| {
609 CopilotError::ProcessSpawn("Failed to capture stdin".to_string())
610 })?;
611 let stdout = child.stdout.take().ok_or_else(|| {
612 CopilotError::ProcessSpawn("Failed to capture stdout".to_string())
613 })?;
614
615 let stderr = child.stderr.take();
617 if let Some(stderr) = stderr {
618 tokio::spawn(async move {
619 use tokio::io::AsyncBufReadExt;
620 let reader = tokio::io::BufReader::new(stderr);
621 let mut lines = reader.lines();
622 while let Ok(Some(line)) = lines.next_line().await {
623 if !line.trim().is_empty() {
624 eprintln!("[CLI subprocess] {}", line);
625 }
626 }
627 });
628 }
629
630 let rpc_client = Arc::new(JsonRpcClient::new(stdout, stdin));
631
632 self.attach_connection_handlers(&rpc_client).await;
634
635 {
636 let mut rpc = self.rpc_client.lock().await;
637 *rpc = Some(rpc_client);
638 }
639 {
640 let mut proc = self.cli_process.lock().await;
641 *proc = Some(child);
642 }
643
644 Ok(())
645 }
646
647 async fn start_cli_tcp(&self) -> Result<(), CopilotError> {
648 let cli_path = self
649 .options
650 .cli_path
651 .as_deref()
652 .ok_or_else(|| CopilotError::Configuration("cli_path is required".to_string()))?;
653
654 let mut args = self.options.cli_args.clone();
655 args.extend_from_slice(&[
656 "--headless".to_string(),
657 "--no-auto-update".to_string(),
658 "--log-level".to_string(),
659 self.options.log_level.clone(),
660 ]);
661
662 if self.options.port > 0 {
663 args.push("--port".to_string());
664 args.push(self.options.port.to_string());
665 }
666
667 let mut cmd = Command::new(cli_path);
668 cmd.args(&args)
669 .stdin(std::process::Stdio::null())
670 .stdout(std::process::Stdio::piped())
671 .stderr(std::process::Stdio::piped());
672
673 if let Some(ref cwd) = self.options.cwd {
674 cmd.current_dir(cwd);
675 }
676 if let Some(ref env) = self.options.env {
677 cmd.envs(env.iter());
678 }
679
680 let mut child = cmd.spawn().map_err(|e| {
681 CopilotError::ProcessSpawn(format!("Failed to spawn CLI process: {}", e))
682 })?;
683
684 let stdout = child.stdout.take().ok_or_else(|| {
686 CopilotError::ProcessSpawn("Failed to capture stdout".to_string())
687 })?;
688
689 let port = {
690 use tokio::io::AsyncBufReadExt;
691 let reader = tokio::io::BufReader::new(stdout);
692 let mut lines = reader.lines();
693 let mut found_port = None;
694
695 let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), async {
696 while let Ok(Some(line)) = lines.next_line().await {
697 if let Some(idx) = line.to_lowercase().find("listening on port ") {
698 let port_str = &line[idx + "listening on port ".len()..];
699 if let Ok(p) = port_str.trim().parse::<u16>() {
700 found_port = Some(p);
701 break;
702 }
703 }
704 }
705 found_port
706 })
707 .await;
708
709 match timeout {
710 Ok(Some(p)) => p,
711 _ => {
712 let _ = child.kill().await;
713 return Err(CopilotError::Timeout(10000));
714 }
715 }
716 };
717
718 let stream = TcpStream::connect(format!("localhost:{}", port))
720 .await
721 .map_err(|e| CopilotError::Connection(format!("Failed to connect via TCP: {}", e)))?;
722
723 let (reader, writer) = stream.into_split();
724 let rpc_client = Arc::new(JsonRpcClient::new(reader, writer));
725
726 self.attach_connection_handlers(&rpc_client).await;
727
728 {
729 let mut rpc = self.rpc_client.lock().await;
730 *rpc = Some(rpc_client);
731 }
732 {
733 let mut proc = self.cli_process.lock().await;
734 *proc = Some(child);
735 }
736
737 Ok(())
738 }
739
740 async fn connect_to_external_server(&self) -> Result<(), CopilotError> {
741 let url = self
742 .options
743 .cli_url
744 .as_deref()
745 .ok_or_else(|| CopilotError::Configuration("cli_url is required".to_string()))?;
746
747 let (host, port) = Self::parse_cli_url(url)?;
748
749 let stream = TcpStream::connect(format!("{}:{}", host, port))
750 .await
751 .map_err(|e| {
752 CopilotError::Connection(format!("Failed to connect to {}: {}", url, e))
753 })?;
754
755 let (reader, writer) = stream.into_split();
756 let rpc_client = Arc::new(JsonRpcClient::new(reader, writer));
757
758 self.attach_connection_handlers(&rpc_client).await;
759
760 {
761 let mut rpc = self.rpc_client.lock().await;
762 *rpc = Some(rpc_client);
763 }
764
765 Ok(())
766 }
767
768 fn parse_cli_url(url: &str) -> Result<(String, u16), CopilotError> {
769 let clean = url
770 .trim_start_matches("http://")
771 .trim_start_matches("https://");
772
773 if let Ok(port) = clean.parse::<u16>() {
775 return Ok(("localhost".to_string(), port));
776 }
777
778 let parts: Vec<&str> = clean.split(':').collect();
780 if parts.len() != 2 {
781 return Err(CopilotError::Configuration(format!(
782 "Invalid cli_url format: {}. Expected host:port, http://host:port, or port",
783 url
784 )));
785 }
786
787 let host = if parts[0].is_empty() {
788 "localhost".to_string()
789 } else {
790 parts[0].to_string()
791 };
792 let port: u16 = parts[1].parse().map_err(|_| {
793 CopilotError::Configuration(format!("Invalid port in cli_url: {}", url))
794 })?;
795
796 Ok((host, port))
797 }
798
799 async fn verify_protocol_version(&self) -> Result<(), CopilotError> {
800 let expected_version = get_sdk_protocol_version();
801 let ping_response = self.ping(None).await?;
802
803 match ping_response.protocol_version {
804 None => Err(CopilotError::ProtocolMismatch {
805 expected: expected_version,
806 actual: None,
807 }),
808 Some(server_version) if server_version != expected_version => {
809 Err(CopilotError::ProtocolMismatch {
810 expected: expected_version,
811 actual: Some(server_version),
812 })
813 }
814 _ => Ok(()),
815 }
816 }
817
818 async fn attach_connection_handlers(&self, rpc_client: &Arc<JsonRpcClient>) {
823 let sessions = Arc::clone(&self.sessions);
825 let lifecycle_handlers = Arc::clone(&self.lifecycle_handlers);
826 let notification_handler: NotificationHandler =
827 Arc::new(move |method: String, params: Value| {
828 let sessions = Arc::clone(&sessions);
829 let lifecycle_handlers = Arc::clone(&lifecycle_handlers);
830
831 match method.as_str() {
832 "session.event" => {
833 let session_id = params
834 .get("sessionId")
835 .and_then(|v| v.as_str())
836 .map(|s| s.to_string());
837 let event = params.get("event").cloned();
838
839 if let (Some(session_id), Some(event_value)) = (session_id, event) {
840 if let Ok(event) =
841 serde_json::from_value::<SessionEvent>(event_value)
842 {
843 tokio::spawn(async move {
844 let sessions = sessions.lock().await;
845 if let Some(session) = sessions.get(&session_id) {
846 session.dispatch_event(event).await;
847 }
848 });
849 }
850 }
851 }
852 "session.lifecycle" => {
853 if let Ok(event) =
854 serde_json::from_value::<SessionLifecycleEvent>(params)
855 {
856 tokio::spawn(async move {
857 let handlers = lifecycle_handlers.lock().await;
858 for (_, handler) in handlers.iter() {
859 handler(event.clone());
860 }
861 });
862 }
863 }
864 _ => {}
865 }
866 });
867
868 rpc_client
869 .set_notification_handler(notification_handler)
870 .await;
871
872 let sessions_for_tools = Arc::clone(&self.sessions);
874 let tool_handler: RequestHandler = Arc::new(move |params: Value| {
875 let sessions = Arc::clone(&sessions_for_tools);
876 Box::pin(async move {
877 let session_id = params
878 .get("sessionId")
879 .and_then(|v| v.as_str())
880 .ok_or_else(|| CopilotError::Protocol("Missing sessionId".to_string()))?;
881 let tool_call_id = params
882 .get("toolCallId")
883 .and_then(|v| v.as_str())
884 .ok_or_else(|| CopilotError::Protocol("Missing toolCallId".to_string()))?;
885 let tool_name = params
886 .get("toolName")
887 .and_then(|v| v.as_str())
888 .ok_or_else(|| CopilotError::Protocol("Missing toolName".to_string()))?;
889 let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
890
891 let session = {
892 let sessions = sessions.lock().await;
893 sessions.get(session_id).cloned()
894 };
895
896 let session = session.ok_or_else(|| {
897 CopilotError::SessionError(format!("Unknown session {}", session_id))
898 })?;
899
900 let handler = session.get_tool_handler(tool_name).await;
901
902 let result = if let Some(handler) = handler {
903 let invocation = ToolInvocation {
904 session_id: session_id.to_string(),
905 tool_call_id: tool_call_id.to_string(),
906 tool_name: tool_name.to_string(),
907 arguments: arguments.clone(),
908 };
909
910 match handler(arguments, invocation).await {
911 Ok(value) => normalize_tool_result(value),
912 Err(e) => ToolResultObject {
913 text_result_for_llm:
914 "Invoking this tool produced an error. Detailed information is not available."
915 .to_string(),
916 binary_results_for_llm: None,
917 result_type: ToolResultType::Failure,
918 error: Some(e.to_string()),
919 session_log: None,
920 tool_telemetry: Some(HashMap::new()),
921 },
922 }
923 } else {
924 ToolResultObject {
925 text_result_for_llm: format!(
926 "Tool '{}' is not supported by this client instance.",
927 tool_name
928 ),
929 binary_results_for_llm: None,
930 result_type: ToolResultType::Failure,
931 error: Some(format!("tool '{}' not supported", tool_name)),
932 session_log: None,
933 tool_telemetry: Some(HashMap::new()),
934 }
935 };
936
937 let response =
938 serde_json::to_value(ToolCallResponsePayload { result })
939 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
940 Ok(response)
941 })
942 });
943 rpc_client.set_request_handler("tool.call", tool_handler).await;
944
945 let sessions_for_perm = Arc::clone(&self.sessions);
947 let permission_handler: RequestHandler = Arc::new(move |params: Value| {
948 let sessions = Arc::clone(&sessions_for_perm);
949 Box::pin(async move {
950 let session_id = params
951 .get("sessionId")
952 .and_then(|v| v.as_str())
953 .ok_or_else(|| CopilotError::Protocol("Missing sessionId".to_string()))?;
954 let perm_request = params
955 .get("permissionRequest")
956 .cloned()
957 .unwrap_or(Value::Null);
958
959 let session = {
960 let sessions = sessions.lock().await;
961 sessions.get(session_id).cloned()
962 };
963
964 let session = session.ok_or_else(|| {
965 CopilotError::SessionError(format!("Session not found: {}", session_id))
966 })?;
967
968 let result = match session.handle_permission_request(perm_request).await {
969 Ok(result) => result,
970 Err(_) => PermissionRequestResult {
971 kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
972 rules: None,
973 },
974 };
975
976 let response = serde_json::json!({ "result": result });
977 Ok(response)
978 })
979 });
980 rpc_client
981 .set_request_handler("permission.request", permission_handler)
982 .await;
983
984 let sessions_for_input = Arc::clone(&self.sessions);
986 let user_input_handler: RequestHandler = Arc::new(move |params: Value| {
987 let sessions = Arc::clone(&sessions_for_input);
988 Box::pin(async move {
989 let session_id = params
990 .get("sessionId")
991 .and_then(|v| v.as_str())
992 .ok_or_else(|| CopilotError::Protocol("Missing sessionId".to_string()))?;
993
994 let session = {
995 let sessions = sessions.lock().await;
996 sessions.get(session_id).cloned()
997 };
998
999 let session = session.ok_or_else(|| {
1000 CopilotError::SessionError(format!("Session not found: {}", session_id))
1001 })?;
1002
1003 let result = session.handle_user_input_request(params).await?;
1004 let response = serde_json::to_value(result)
1005 .map_err(|e| CopilotError::Serialization(e.to_string()))?;
1006 Ok(response)
1007 })
1008 });
1009 rpc_client
1010 .set_request_handler("userInput.request", user_input_handler)
1011 .await;
1012
1013 let sessions_for_hooks = Arc::clone(&self.sessions);
1015 let hooks_handler: RequestHandler = Arc::new(move |params: Value| {
1016 let sessions = Arc::clone(&sessions_for_hooks);
1017 Box::pin(async move {
1018 let session_id = params
1019 .get("sessionId")
1020 .and_then(|v| v.as_str())
1021 .ok_or_else(|| CopilotError::Protocol("Missing sessionId".to_string()))?;
1022 let hook_type = params
1023 .get("hookType")
1024 .and_then(|v| v.as_str())
1025 .ok_or_else(|| CopilotError::Protocol("Missing hookType".to_string()))?;
1026 let input = params.get("input").cloned().unwrap_or(Value::Null);
1027
1028 let session = {
1029 let sessions = sessions.lock().await;
1030 sessions.get(session_id).cloned()
1031 };
1032
1033 let session = session.ok_or_else(|| {
1034 CopilotError::SessionError(format!("Session not found: {}", session_id))
1035 })?;
1036
1037 let output = session.handle_hooks_invoke(hook_type, input).await?;
1038 let response = serde_json::json!({ "output": output });
1039 Ok(response)
1040 })
1041 });
1042 rpc_client
1043 .set_request_handler("hooks.invoke", hooks_handler)
1044 .await;
1045 }
1046}
1047
1048fn normalize_tool_result(value: Value) -> ToolResultObject {
1054 if value.is_null() {
1055 return ToolResultObject {
1056 text_result_for_llm: "Tool returned no result".to_string(),
1057 binary_results_for_llm: None,
1058 result_type: ToolResultType::Failure,
1059 error: Some("tool returned no result".to_string()),
1060 session_log: None,
1061 tool_telemetry: Some(HashMap::new()),
1062 };
1063 }
1064
1065 if value.get("textResultForLlm").is_some() && value.get("resultType").is_some() {
1067 if let Ok(result) = serde_json::from_value::<ToolResultObject>(value.clone()) {
1068 return result;
1069 }
1070 }
1071
1072 let text_result = if let Some(s) = value.as_str() {
1074 s.to_string()
1075 } else {
1076 serde_json::to_string(&value).unwrap_or_else(|_| "".to_string())
1077 };
1078
1079 ToolResultObject {
1080 text_result_for_llm: text_result,
1081 binary_results_for_llm: None,
1082 result_type: ToolResultType::Success,
1083 error: None,
1084 session_log: None,
1085 tool_telemetry: Some(HashMap::new()),
1086 }
1087}