1use crate::error::{CopilotError, Result};
9use crate::events::SessionEvent;
10use crate::jsonrpc::{StdioJsonRpcClient, TcpJsonRpcClient};
11use crate::process::{CopilotProcess, ProcessOptions};
12use crate::session::Session;
13use crate::types::{
14 ClientOptions, ConnectionState, GetAuthStatusResponse, GetForegroundSessionResponse,
15 GetStatusResponse, LogLevel, ModelInfo, PingResponse, ProviderConfig, ResumeSessionConfig,
16 SessionConfig, SessionLifecycleEvent, SessionMetadata, SetForegroundSessionResponse, StopError,
17 SDK_PROTOCOL_VERSION,
18};
19use serde_json::{json, Value};
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::io::{AsyncBufReadExt, BufReader};
26use tokio::sync::{Mutex, RwLock};
27
28fn resolve_cli_command(cli_path: &Path, args: &[String]) -> (PathBuf, Vec<String>) {
37 let path = cli_path.to_path_buf();
38 let args_owned = args.to_vec();
39
40 if crate::process::is_node_script(&path) {
42 if let Some(node_path) = crate::process::find_node() {
43 let mut full_args = vec![path.to_string_lossy().to_string()];
44 full_args.extend(args_owned);
45 return (node_path, full_args);
46 }
47 }
48
49 #[cfg(windows)]
50 {
51 if let Some(ext) = path.extension() {
55 let ext_lower = ext.to_string_lossy().to_lowercase();
56 if ext_lower == "cmd" {
57 if let Some(parent) = path.parent() {
60 if let Some(stem) = path.file_stem() {
62 let stem_str = stem.to_string_lossy();
63
64 let possible_paths = vec![
67 parent
68 .join("node_modules/@github")
69 .join(&*stem_str)
70 .join("npm-loader.js"),
71 parent
72 .join("node_modules")
73 .join(&*stem_str)
74 .join("npm-loader.js"),
75 parent
76 .join("node_modules/@github")
77 .join(&*stem_str)
78 .join("index.js"),
79 parent
80 .join("node_modules")
81 .join(&*stem_str)
82 .join("index.js"),
83 ];
84
85 for loader_path in possible_paths {
86 if loader_path.exists() {
87 if let Some(node_path) = crate::process::find_node() {
88 let mut full_args =
89 vec![loader_path.to_string_lossy().to_string()];
90 full_args.extend(args_owned);
91 return (node_path, full_args);
92 }
93 }
94 }
95 }
96 }
97
98 let mut full_args = vec!["/c".to_string(), path.to_string_lossy().to_string()];
100 full_args.extend(args_owned);
101 return (PathBuf::from("cmd"), full_args);
102 }
103
104 if ext_lower == "bat" {
106 let mut full_args = vec!["/c".to_string(), path.to_string_lossy().to_string()];
107 full_args.extend(args_owned);
108 return (PathBuf::from("cmd"), full_args);
109 }
110 }
111
112 if !path.is_absolute() {
114 let mut full_args = vec!["/c".to_string(), path.to_string_lossy().to_string()];
115 full_args.extend(args_owned);
116 return (PathBuf::from("cmd"), full_args);
117 }
118 }
119
120 (path, args_owned)
121}
122
123fn spawn_cli_stderr_logger(stderr: tokio::process::ChildStderr) {
124 tokio::spawn(async move {
125 let mut lines = BufReader::new(stderr).lines();
126 while let Ok(Some(line)) = lines.next_line().await {
127 tracing::debug!(target: "copilot_sdk::cli_stderr", "{line}");
128 }
129 });
130}
131
132pub type LifecycleHandler = Arc<dyn Fn(&SessionLifecycleEvent) + Send + Sync>;
134
135async fn handle_tool_call(
137 sessions: &RwLock<HashMap<String, Arc<Session>>>,
138 params: &Value,
139) -> Result<Value> {
140 let session_id = params
141 .get("sessionId")
142 .and_then(|v| v.as_str())
143 .ok_or_else(|| CopilotError::InvalidConfig("Missing sessionId".into()))?;
144
145 let tool_name = params
146 .get("toolName")
147 .and_then(|v| v.as_str())
148 .ok_or_else(|| CopilotError::InvalidConfig("Missing toolName".into()))?;
149
150 let arguments = normalize_tool_arguments(params);
151
152 let session = sessions.read().await.get(session_id).cloned();
153
154 let session = match session {
155 Some(s) => s,
156 None => {
157 return Ok(json!({
158 "result": {
159 "textResultForLlm": "Session not found",
160 "resultType": "failure",
161 "error": format!("Unknown session {}", session_id)
162 }
163 }));
164 }
165 };
166
167 if session.get_tool(tool_name).await.is_none() {
169 return Ok(json!({
170 "result": {
171 "textResultForLlm": format!("Tool '{}' is not supported.", tool_name),
172 "resultType": "failure",
173 "error": format!("tool '{}' not supported", tool_name)
174 }
175 }));
176 }
177
178 match session.invoke_tool(tool_name, &arguments).await {
180 Ok(result) => Ok(json!({ "result": result })),
181 Err(e) => Ok(json!({
182 "result": {
183 "textResultForLlm": "Tool execution failed",
184 "resultType": "failure",
185 "error": e.to_string()
186 }
187 })),
188 }
189}
190
191fn normalize_tool_arguments(params: &Value) -> Value {
192 let raw = params
193 .get("arguments")
194 .or_else(|| params.get("argumentsJson"))
195 .cloned()
196 .unwrap_or(json!({}));
197
198 match raw {
199 Value::String(s) => serde_json::from_str(&s).unwrap_or(json!({})),
200 Value::Null => json!({}),
201 other => other,
202 }
203}
204
205async fn handle_permission_request(
207 sessions: &RwLock<HashMap<String, Arc<Session>>>,
208 params: &Value,
209) -> Result<Value> {
210 let session_id = params
211 .get("sessionId")
212 .and_then(|v| v.as_str())
213 .ok_or_else(|| CopilotError::InvalidConfig("Missing sessionId".into()))?;
214
215 let perm_data = params.get("permissionRequest").unwrap_or(params);
217
218 let session = sessions.read().await.get(session_id).cloned();
219
220 let session = match session {
221 Some(s) => s,
222 None => {
223 return Ok(json!({
225 "result": {
226 "kind": "denied-no-approval-rule-and-could-not-request-from-user"
227 }
228 }));
229 }
230 };
231
232 use crate::types::PermissionRequest;
234 let kind = perm_data
235 .get("kind")
236 .and_then(|v| v.as_str())
237 .unwrap_or("unknown")
238 .to_string();
239
240 let tool_call_id = perm_data
241 .get("toolCallId")
242 .and_then(|v| v.as_str())
243 .map(|s| s.to_string());
244
245 let mut extension_data = HashMap::new();
247 if let Some(obj) = perm_data.as_object() {
248 for (key, value) in obj {
249 if key != "kind" && key != "toolCallId" {
250 extension_data.insert(key.clone(), value.clone());
251 }
252 }
253 }
254
255 let request = PermissionRequest {
256 kind,
257 tool_call_id,
258 extension_data,
259 };
260
261 let result = session.handle_permission_request(&request).await;
262
263 let mut response = json!({
265 "result": {
266 "kind": result.kind
267 }
268 });
269
270 if let Some(rules) = result.rules {
271 response["result"]["rules"] = Value::Array(rules);
272 }
273
274 Ok(response)
275}
276
277async fn handle_user_input_request(
279 sessions: &RwLock<HashMap<String, Arc<Session>>>,
280 params: &Value,
281) -> Result<Value> {
282 let session_id = params
283 .get("sessionId")
284 .and_then(|v| v.as_str())
285 .ok_or_else(|| CopilotError::InvalidConfig("Missing sessionId".into()))?;
286
287 let session = sessions.read().await.get(session_id).cloned();
288
289 let session = match session {
290 Some(s) => s,
291 None => {
292 return Err(CopilotError::Protocol(format!(
293 "Session not found for user input request: {session_id}"
294 )));
295 }
296 };
297
298 use crate::types::UserInputRequest;
299 let request = UserInputRequest {
300 question: params
301 .get("question")
302 .and_then(|v| v.as_str())
303 .unwrap_or("")
304 .to_string(),
305 choices: params.get("choices").and_then(|v| {
306 v.as_array().map(|arr| {
307 arr.iter()
308 .filter_map(|v| v.as_str().map(String::from))
309 .collect()
310 })
311 }),
312 allow_freeform: params.get("allowFreeform").and_then(|v| v.as_bool()),
313 };
314
315 let response = session.handle_user_input_request(&request).await?;
316 Ok(serde_json::to_value(response).unwrap_or(json!({})))
317}
318
319async fn handle_hooks_invoke(
320 sessions: &RwLock<HashMap<String, Arc<Session>>>,
321 params: &Value,
322) -> Result<Value> {
323 let session_id = params
324 .get("sessionId")
325 .and_then(|v| v.as_str())
326 .ok_or_else(|| CopilotError::InvalidConfig("Missing sessionId".into()))?;
327
328 let session = sessions.read().await.get(session_id).cloned();
329
330 let session = match session {
331 Some(s) => s,
332 None => {
333 return Err(CopilotError::Protocol(format!(
334 "Session not found for hooks invoke: {session_id}"
335 )));
336 }
337 };
338
339 let hook_type = params
340 .get("hookType")
341 .and_then(|v| v.as_str())
342 .unwrap_or("");
343
344 let input = params.get("input").cloned().unwrap_or(Value::Null);
345
346 session.handle_hooks_invoke(hook_type, &input).await
347}
348
349fn parse_cli_url(url: &str) -> Result<(String, u16)> {
350 let mut s = url.trim();
351 if let Some((_, rest)) = s.split_once("://") {
352 s = rest;
353 }
354 if let Some((host_port, _)) = s.split_once('/') {
355 s = host_port;
356 }
357
358 if s.chars().all(|c| c.is_ascii_digit()) {
359 let port: u16 = s.parse().map_err(|_| {
360 CopilotError::InvalidConfig(format!("Invalid port in cli_url: {}", url))
361 })?;
362 return Ok(("localhost".to_string(), port));
363 }
364
365 if let Some((host, port_str)) = s.rsplit_once(':') {
366 let host = host.trim();
367 let port: u16 = port_str.trim().parse().map_err(|_| {
368 CopilotError::InvalidConfig(format!("Invalid port in cli_url: {}", url))
369 })?;
370 if host.is_empty() {
371 return Ok(("localhost".to_string(), port));
372 }
373 return Ok((host.to_string(), port));
374 }
375
376 Err(CopilotError::InvalidConfig(format!(
377 "Invalid cli_url format (expected host:port or port): {}",
378 url
379 )))
380}
381
/// Extracts the TCP port from a CLI stdout line such as `"... listening on port 12345"`.
///
/// The marker is matched case-insensitively. The port is the first run of ASCII digits
/// after it. Returns `None` when the marker is absent, no digits follow it, or the number
/// does not fit in a `u16`.
fn parse_listening_port(line: &str) -> Option<u16> {
    const MARKER: &str = "listening on port";

    // Search and slice the *same* lowercased string. `to_lowercase` can change the UTF-8
    // length of non-ASCII characters, so a byte index into `lower` is not valid for `line`.
    // Lowercasing never alters ASCII digits, so the digits read from `lower` are the original ones.
    let lower = line.to_lowercase();
    let start = lower.find(MARKER)? + MARKER.len();

    let digits: String = lower[start..]
        .chars()
        .skip_while(|c| !c.is_ascii_digit())
        .take_while(|c| c.is_ascii_digit())
        .collect();
    digits.parse().ok()
}
399
400async fn detect_tcp_port_from_stdout(stdout: tokio::process::ChildStdout) -> Result<u16> {
401 let mut lines = BufReader::new(stdout).lines();
402 let port = tokio::time::timeout(Duration::from_secs(15), async {
403 while let Ok(Some(line)) = lines.next_line().await {
404 if let Some(port) = parse_listening_port(&line) {
405 return Ok(port);
406 }
407 }
408 Err(CopilotError::PortDetectionFailed)
409 })
410 .await
411 .map_err(|_| CopilotError::Timeout(Duration::from_secs(15)))??;
412
413 Ok(port)
414}
415
416enum RpcClient {
417 Stdio(StdioJsonRpcClient),
418 Tcp(TcpJsonRpcClient),
419}
420
421impl RpcClient {
422 async fn stop(&self) {
423 match self {
424 RpcClient::Stdio(rpc) => rpc.stop().await,
425 RpcClient::Tcp(rpc) => rpc.stop().await,
426 }
427 }
428
429 async fn set_notification_handler<F>(&self, handler: F)
430 where
431 F: Fn(&str, &Value) + Send + Sync + 'static,
432 {
433 let handler = Arc::new(handler);
434 match self {
435 RpcClient::Stdio(rpc) => {
436 let handler = Arc::clone(&handler);
437 rpc.set_notification_handler(move |method, params| {
438 (handler)(method, params);
439 })
440 .await;
441 }
442 RpcClient::Tcp(rpc) => {
443 let handler = Arc::clone(&handler);
444 rpc.set_notification_handler(move |method, params| {
445 (handler)(method, params);
446 })
447 .await;
448 }
449 }
450 }
451
452 async fn set_request_handler<F>(&self, handler: F)
453 where
454 F: Fn(&str, &Value) -> crate::jsonrpc::RequestHandlerFuture + Send + Sync + 'static,
455 {
456 let handler = Arc::new(handler);
457 match self {
458 RpcClient::Stdio(rpc) => {
459 let handler = Arc::clone(&handler);
460 rpc.set_request_handler(move |method, params| (handler)(method, params))
461 .await;
462 }
463 RpcClient::Tcp(rpc) => {
464 let handler = Arc::clone(&handler);
465 rpc.set_request_handler(move |method, params| (handler)(method, params))
466 .await;
467 }
468 }
469 }
470
471 async fn invoke(&self, method: &str, params: Option<Value>) -> Result<Value> {
472 match self {
473 RpcClient::Stdio(rpc) => rpc.invoke(method, params).await,
474 RpcClient::Tcp(rpc) => rpc.invoke(method, params).await,
475 }
476 }
477}
478
479pub struct Client {
514 options: ClientOptions,
515 state: Arc<RwLock<ConnectionState>>,
516 lifecycle: Mutex<()>,
517 process: Mutex<Option<CopilotProcess>>,
518 rpc: Arc<Mutex<Option<RpcClient>>>,
519 sessions: Arc<RwLock<HashMap<String, Arc<Session>>>>,
520 lifecycle_handlers: Arc<RwLock<HashMap<u64, LifecycleHandler>>>,
521 next_lifecycle_handler_id: AtomicU64,
522 models_cache: Arc<Mutex<Option<Vec<ModelInfo>>>>,
523}
524
525impl Client {
526 pub fn new(options: ClientOptions) -> Result<Self> {
528 let mut options = options;
529
530 if options.cli_url.is_some() {
531 options.use_stdio = false;
532 }
533
534 if options.cli_url.is_some() {
536 if options.cli_path.is_some() {
537 return Err(CopilotError::InvalidConfig(
538 "cli_url is mutually exclusive with cli_path".into(),
539 ));
540 }
541 if options.port != 0 {
542 return Err(CopilotError::InvalidConfig(
543 "cli_url is mutually exclusive with port".into(),
544 ));
545 }
546 }
547 if options.use_stdio && options.port != 0 {
548 return Err(CopilotError::InvalidConfig(
549 "port is only valid when use_stdio=false".into(),
550 ));
551 }
552 if options.cli_url.is_some() && options.github_token.is_some() {
553 return Err(CopilotError::InvalidConfig(
554 "github_token cannot be used with cli_url (external server doesn't accept token)"
555 .into(),
556 ));
557 }
558 if options.cli_url.is_some() && options.use_logged_in_user.is_some() {
559 return Err(CopilotError::InvalidConfig(
560 "use_logged_in_user cannot be used with cli_url (external server doesn't accept this option)".into(),
561 ));
562 }
563
564 Ok(Self {
565 options,
566 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
567 lifecycle: Mutex::new(()),
568 process: Mutex::new(None),
569 rpc: Arc::new(Mutex::new(None)),
570 sessions: Arc::new(RwLock::new(HashMap::new())),
571 lifecycle_handlers: Arc::new(RwLock::new(HashMap::new())),
572 next_lifecycle_handler_id: AtomicU64::new(1),
573 models_cache: Arc::new(Mutex::new(None)),
574 })
575 }
576
577 pub fn builder() -> ClientBuilder {
579 ClientBuilder::new()
580 }
581
582 pub async fn start(&self) -> Result<()> {
588 let _guard = self.lifecycle.lock().await;
589
590 let mut state = self.state.write().await;
591 if *state == ConnectionState::Connected {
592 return Ok(());
593 }
594 if *state != ConnectionState::Disconnected {
595 return Err(CopilotError::InvalidConfig(
596 "Client is already started".into(),
597 ));
598 }
599 *state = ConnectionState::Connecting;
600 drop(state);
601
602 let result = self.start_cli_server().await;
604 if let Err(e) = result {
605 *self.state.write().await = ConnectionState::Error;
606 return Err(e);
607 }
608
609 if let Err(e) = self.verify_protocol_version().await {
611 *self.state.write().await = ConnectionState::Error;
612 return Err(e);
613 }
614
615 self.setup_handlers().await?;
617
618 *self.state.write().await = ConnectionState::Connected;
619 Ok(())
620 }
621
622 pub async fn stop(&self) -> Vec<StopError> {
624 let _guard = self.lifecycle.lock().await;
625 let mut errors = Vec::new();
626
627 let state = *self.state.read().await;
628 if state == ConnectionState::Disconnected {
629 self.sessions.write().await.clear();
630 *self.rpc.lock().await = None;
631 *self.process.lock().await = None;
632 return errors;
633 }
634
635 let sessions: Vec<Arc<Session>> = self.sessions.read().await.values().cloned().collect();
637 for session in sessions {
638 if let Err(e) = session.destroy().await {
639 errors.push(StopError {
640 message: format!("Failed to destroy session {}: {}", session.session_id(), e),
641 source: Some("session.destroy".into()),
642 });
643 }
644 }
645 self.sessions.write().await.clear();
646
647 if let Some(rpc) = self.rpc.lock().await.take() {
649 rpc.stop().await;
650 }
651
652 if let Some(mut process) = self.process.lock().await.take() {
654 let _ = process.terminate();
655 let _ = process.wait().await;
656 }
657
658 *self.state.write().await = ConnectionState::Disconnected;
659 *self.models_cache.lock().await = None;
660 errors
661 }
662
663 pub async fn force_stop(&self) {
665 let _guard = self.lifecycle.lock().await;
666
667 self.sessions.write().await.clear();
668
669 if let Some(mut process) = self.process.lock().await.take() {
671 let _ = process.kill();
672 }
673
674 if let Some(rpc) = self.rpc.lock().await.take() {
676 rpc.stop().await;
677 }
678
679 *self.state.write().await = ConnectionState::Disconnected;
680 *self.models_cache.lock().await = None;
681 }
682
683 pub async fn state(&self) -> ConnectionState {
685 *self.state.read().await
686 }
687
688 pub async fn create_session(&self, mut config: SessionConfig) -> Result<Arc<Session>> {
694 self.ensure_connected().await?;
695
696 if config.auto_byok_from_env && config.model.is_none() {
698 config.model = ProviderConfig::model_from_env();
699 }
700 if config.auto_byok_from_env && config.provider.is_none() {
701 config.provider = ProviderConfig::from_env();
702 }
703
704 let params = serde_json::to_value(&config)?;
706
707 let result = self.invoke("session.create", Some(params)).await?;
709
710 let session_id = result
712 .get("sessionId")
713 .and_then(|v| v.as_str())
714 .ok_or_else(|| CopilotError::Protocol("Missing sessionId in response".into()))?
715 .to_string();
716
717 let workspace_path = result
719 .get("workspacePath")
720 .and_then(|v| v.as_str())
721 .map(|s| s.to_string());
722
723 let session = self
725 .create_session_object(session_id.clone(), workspace_path)
726 .await;
727
728 if let Some(hooks) = config.hooks.take() {
730 if hooks.has_any() {
731 session.register_hooks(hooks).await;
732 }
733 }
734
735 self.sessions
737 .write()
738 .await
739 .insert(session_id, Arc::clone(&session));
740
741 Ok(session)
742 }
743
744 pub async fn resume_session(
746 &self,
747 session_id: &str,
748 mut config: ResumeSessionConfig,
749 ) -> Result<Arc<Session>> {
750 self.ensure_connected().await?;
751
752 if config.auto_byok_from_env && config.provider.is_none() {
754 config.provider = ProviderConfig::from_env();
755 }
756
757 let mut params = serde_json::to_value(&config)?;
759 params["sessionId"] = json!(session_id);
760
761 let result = self.invoke("session.resume", Some(params)).await?;
763
764 let resumed_id = result
766 .get("sessionId")
767 .and_then(|v| v.as_str())
768 .unwrap_or(session_id)
769 .to_string();
770
771 let workspace_path = result
773 .get("workspacePath")
774 .and_then(|v| v.as_str())
775 .map(|s| s.to_string());
776
777 let session = self
779 .create_session_object(resumed_id.clone(), workspace_path)
780 .await;
781
782 if let Some(hooks) = config.hooks.take() {
784 if hooks.has_any() {
785 session.register_hooks(hooks).await;
786 }
787 }
788
789 self.sessions
791 .write()
792 .await
793 .insert(resumed_id, Arc::clone(&session));
794
795 Ok(session)
796 }
797
798 pub async fn list_sessions(&self) -> Result<Vec<SessionMetadata>> {
800 self.ensure_connected().await?;
801
802 let result = self.invoke("session.list", None).await?;
803
804 let sessions: Vec<SessionMetadata> = result
805 .get("sessions")
806 .and_then(|v| serde_json::from_value(v.clone()).ok())
807 .unwrap_or_default();
808
809 Ok(sessions)
810 }
811
812 pub async fn delete_session(&self, session_id: &str) -> Result<()> {
814 self.ensure_connected().await?;
815
816 let params = json!({ "sessionId": session_id });
817 let result = self.invoke("session.delete", Some(params)).await?;
818
819 if let Some(success) = result.get("success").and_then(|v| v.as_bool()) {
820 if !success {
821 let msg = result
822 .get("error")
823 .and_then(|v| v.as_str())
824 .unwrap_or("Unknown error")
825 .to_string();
826 return Err(CopilotError::Protocol(format!(
827 "Failed to delete session: {}",
828 msg
829 )));
830 }
831 }
832
833 self.sessions.write().await.remove(session_id);
835
836 Ok(())
837 }
838
839 pub async fn get_last_session_id(&self) -> Result<Option<String>> {
841 self.ensure_connected().await?;
842
843 let result = self.invoke("session.getLastId", None).await?;
844
845 Ok(result
846 .get("sessionId")
847 .and_then(|v| v.as_str())
848 .map(|s| s.to_string()))
849 }
850
851 pub async fn ping(&self, message: Option<String>) -> Result<PingResponse> {
857 self.ensure_connected().await?;
858
859 let params = message.map(|m| json!({ "message": m }));
860 let result = self.invoke("ping", params).await?;
861
862 Ok(PingResponse {
863 message: result
864 .get("message")
865 .and_then(|v| v.as_str())
866 .unwrap_or("")
867 .to_string(),
868 timestamp: result
869 .get("timestamp")
870 .and_then(|v| v.as_i64())
871 .unwrap_or(0),
872 protocol_version: result
873 .get("protocolVersion")
874 .and_then(|v| v.as_u64())
875 .map(|v| v as u32),
876 })
877 }
878
879 pub async fn get_status(&self) -> Result<GetStatusResponse> {
881 self.ensure_connected().await?;
882
883 let result = self.invoke("status.get", None).await?;
884 serde_json::from_value(result)
885 .map_err(|e| CopilotError::Protocol(format!("Failed to parse status response: {}", e)))
886 }
887
888 pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse> {
890 self.ensure_connected().await?;
891
892 let result = self.invoke("auth.getStatus", None).await?;
893 serde_json::from_value(result).map_err(|e| {
894 CopilotError::Protocol(format!("Failed to parse auth status response: {}", e))
895 })
896 }
897
898 pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
905 {
907 let cache = self.models_cache.lock().await;
908 if let Some(cached) = &*cache {
909 return Ok(cached.clone());
910 }
911 }
912
913 self.ensure_connected().await?;
914
915 let result = self.invoke("models.list", None).await?;
916 let models = result
917 .get("models")
918 .cloned()
919 .unwrap_or_else(|| serde_json::json!([]));
920 let models: Vec<ModelInfo> = serde_json::from_value(models).map_err(|e| {
921 CopilotError::Protocol(format!("Failed to parse models response: {}", e))
922 })?;
923
924 *self.models_cache.lock().await = Some(models.clone());
926
927 Ok(models)
928 }
929
930 pub async fn clear_models_cache(&self) {
932 *self.models_cache.lock().await = None;
933 }
934
935 pub async fn get_foreground_session_id(&self) -> Result<GetForegroundSessionResponse> {
937 self.ensure_connected().await?;
938
939 let result = self.invoke("session.getForeground", None).await?;
940 serde_json::from_value(result).map_err(|e| {
941 CopilotError::Protocol(format!("Failed to parse foreground response: {}", e))
942 })
943 }
944
945 pub async fn set_foreground_session_id(
947 &self,
948 session_id: &str,
949 ) -> Result<SetForegroundSessionResponse> {
950 self.ensure_connected().await?;
951
952 let params = json!({ "sessionId": session_id });
953 let result = self.invoke("session.setForeground", Some(params)).await?;
954 serde_json::from_value(result).map_err(|e| {
955 CopilotError::Protocol(format!("Failed to parse set foreground response: {}", e))
956 })
957 }
958
959 pub async fn on<F>(&self, handler: F) -> impl FnOnce()
968 where
969 F: Fn(&SessionLifecycleEvent) + Send + Sync + 'static,
970 {
971 let id = self
972 .next_lifecycle_handler_id
973 .fetch_add(1, Ordering::SeqCst);
974 self.lifecycle_handlers
975 .write()
976 .await
977 .insert(id, Arc::new(handler));
978
979 let handlers = Arc::clone(&self.lifecycle_handlers);
980 move || {
981 tokio::spawn(async move {
982 handlers.write().await.remove(&id);
983 });
984 }
985 }
986
987 pub(crate) async fn invoke(&self, method: &str, params: Option<Value>) -> Result<Value> {
993 let mut attempt = 0;
994
995 loop {
996 let result = {
997 let rpc = self.rpc.lock().await;
998 let rpc = rpc.as_ref().ok_or(CopilotError::NotConnected)?;
999 rpc.invoke(method, params.clone()).await
1000 };
1001
1002 match result {
1003 Ok(v) => return Ok(v),
1004 Err(e) => {
1005 if attempt == 0
1006 && *self.state.read().await == ConnectionState::Connected
1007 && self.options.auto_restart
1008 && self.should_restart_on_error(&e)
1009 {
1010 attempt += 1;
1011 self.restart().await?;
1012 continue;
1013 }
1014 return Err(e);
1015 }
1016 }
1017 }
1018 }
1019
1020 pub async fn get_session(&self, session_id: &str) -> Option<Arc<Session>> {
1022 self.sessions.read().await.get(session_id).cloned()
1023 }
1024
1025 async fn ensure_connected(&self) -> Result<()> {
1027 match *self.state.read().await {
1028 ConnectionState::Connected => Ok(()),
1029 ConnectionState::Disconnected => {
1030 if self.options.auto_start {
1031 self.start().await
1032 } else {
1033 Err(CopilotError::NotConnected)
1034 }
1035 }
1036 ConnectionState::Error => {
1037 if self.options.auto_restart {
1038 self.restart().await
1039 } else {
1040 Err(CopilotError::NotConnected)
1041 }
1042 }
1043 ConnectionState::Connecting => Err(CopilotError::NotConnected),
1044 }
1045 }
1046
1047 fn should_restart_on_error(&self, err: &CopilotError) -> bool {
1048 match err {
1049 CopilotError::ConnectionClosed | CopilotError::NotConnected => true,
1050 CopilotError::Transport(_) => true,
1051 CopilotError::ProcessExit(_) => true,
1052 CopilotError::JsonRpc { code, .. } => *code == -32801,
1053 _ => false,
1054 }
1055 }
1056
1057 async fn restart(&self) -> Result<()> {
1058 self.force_stop().await;
1059 self.start().await
1060 }
1061
1062 async fn start_cli_server(&self) -> Result<()> {
1064 if let Some(cli_url) = &self.options.cli_url {
1065 let (host, port) = parse_cli_url(cli_url)?;
1066 let addr = format!("{}:{}", host, port);
1067
1068 let rpc = TcpJsonRpcClient::connect(addr).await?;
1069 rpc.start().await?;
1070
1071 *self.rpc.lock().await = Some(RpcClient::Tcp(rpc));
1072 return Ok(());
1073 }
1074
1075 let cli_path = self
1076 .options
1077 .cli_path
1078 .clone()
1079 .or_else(crate::process::find_copilot_cli)
1080 .ok_or_else(|| {
1081 CopilotError::InvalidConfig("Could not find Copilot CLI executable".into())
1082 })?;
1083
1084 let log_level = self.options.log_level.to_string();
1085
1086 let mut args: Vec<String> = Vec::new();
1087 if let Some(extra_args) = &self.options.cli_args {
1088 args.extend(extra_args.iter().cloned());
1089 }
1090
1091 if let Some(deny_tools) = &self.options.deny_tools {
1093 for tool_spec in deny_tools {
1094 args.push("--deny-tool".to_string());
1095 args.push(tool_spec.clone());
1096 }
1097 }
1098
1099 if let Some(allow_tools) = &self.options.allow_tools {
1101 for tool_spec in allow_tools {
1102 args.push("--allow-tool".to_string());
1103 args.push(tool_spec.clone());
1104 }
1105 }
1106
1107 if self.options.allow_all_tools {
1109 args.push("--allow-all-tools".to_string());
1110 }
1111
1112 args.extend(["--server".to_string(), "--log-level".to_string(), log_level]);
1113
1114 if self.options.use_stdio {
1115 args.push("--stdio".to_string());
1116 } else if self.options.port != 0 {
1117 args.extend(["--port".to_string(), self.options.port.to_string()]);
1118 }
1119
1120 if self.options.github_token.is_some() {
1122 args.push("--auth-token-env".to_string());
1123 args.push("COPILOT_SDK_AUTH_TOKEN".to_string());
1124 }
1125
1126 if let Some(false) = self.options.use_logged_in_user {
1128 args.push("--no-auto-login".to_string());
1129 }
1130
1131 let (executable, full_args) = resolve_cli_command(&cli_path, &args);
1134
1135 let mut proc_options = ProcessOptions::new()
1137 .stdin(self.options.use_stdio)
1138 .stdout(true)
1139 .stderr(true);
1140
1141 if let Some(ref dir) = self.options.cwd {
1142 proc_options = proc_options.working_dir(dir.clone());
1143 }
1144
1145 if let Some(ref env) = self.options.environment {
1147 for (key, value) in env {
1148 proc_options = proc_options.env(key, value);
1149 }
1150 }
1151
1152 proc_options = proc_options.env("NODE_DEBUG", "");
1154
1155 if let Some(ref token) = self.options.github_token {
1157 proc_options = proc_options.env("COPILOT_SDK_AUTH_TOKEN", token);
1158 args.push("--auth-token-env".to_string());
1159 args.push("COPILOT_SDK_AUTH_TOKEN".to_string());
1160 }
1161
1162 if let Some(false) = self.options.use_logged_in_user {
1164 args.push("--no-auto-login".to_string());
1165 }
1166
1167 let args_refs: Vec<&str> = full_args.iter().map(|s| s.as_str()).collect();
1168 let mut process = CopilotProcess::spawn(&executable, &args_refs, proc_options)?;
1169
1170 if let Some(stderr) = process.take_stderr() {
1171 spawn_cli_stderr_logger(stderr);
1172 }
1173
1174 let rpc = if self.options.use_stdio {
1175 let transport = process.take_transport().ok_or_else(|| {
1176 CopilotError::InvalidConfig("Failed to get transport from process".into())
1177 })?;
1178 let rpc = StdioJsonRpcClient::new(transport);
1179 rpc.start().await?;
1180 RpcClient::Stdio(rpc)
1181 } else {
1182 let stdout = process.take_stdout().ok_or_else(|| {
1183 CopilotError::InvalidConfig("Failed to capture stdout for port detection".into())
1184 })?;
1185
1186 let detected_port = detect_tcp_port_from_stdout(stdout).await?;
1187 let addr = format!("127.0.0.1:{}", detected_port);
1188 let rpc = TcpJsonRpcClient::connect(addr).await?;
1189 rpc.start().await?;
1190 RpcClient::Tcp(rpc)
1191 };
1192
1193 *self.process.lock().await = Some(process);
1194 *self.rpc.lock().await = Some(rpc);
1195
1196 Ok(())
1197 }
1198
1199 async fn verify_protocol_version(&self) -> Result<()> {
1201 let rpc = self.rpc.lock().await;
1204 let rpc = rpc.as_ref().ok_or(CopilotError::NotConnected)?;
1205 let result = rpc
1206 .invoke("ping", Some(serde_json::json!({ "message": null })))
1207 .await?;
1208
1209 let protocol_version = result
1210 .get("protocolVersion")
1211 .and_then(|v| v.as_u64())
1212 .map(|v| v as u32);
1213
1214 if let Some(version) = protocol_version {
1215 if version != SDK_PROTOCOL_VERSION {
1216 return Err(CopilotError::ProtocolMismatch {
1217 expected: SDK_PROTOCOL_VERSION,
1218 actual: version,
1219 });
1220 }
1221 }
1222
1223 Ok(())
1224 }
1225
1226 async fn setup_handlers(&self) -> Result<()> {
1228 let rpc = self.rpc.lock().await;
1229 let rpc = rpc.as_ref().ok_or(CopilotError::NotConnected)?;
1230
1231 let sessions = Arc::clone(&self.sessions);
1233 let lifecycle_handlers = Arc::clone(&self.lifecycle_handlers);
1234
1235 rpc.set_notification_handler(move |method, params| {
1237 if method == "session.event" {
1238 let sessions = Arc::clone(&sessions);
1239 let params = params.clone();
1240
1241 tokio::spawn(async move {
1243 if let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str()) {
1244 if let Some(session) = sessions.read().await.get(session_id) {
1245 if let Some(event_data) = params.get("event") {
1246 if let Ok(event) = SessionEvent::from_json(event_data) {
1247 session.dispatch_event(event).await;
1248 }
1249 }
1250 }
1251 }
1252 });
1253 } else if method == "session.lifecycle" {
1254 let lifecycle_handlers = Arc::clone(&lifecycle_handlers);
1255 let params = params.clone();
1256
1257 tokio::spawn(async move {
1258 if let Ok(event) = serde_json::from_value::<SessionLifecycleEvent>(params) {
1259 let handlers = lifecycle_handlers.read().await;
1260 for handler in handlers.values() {
1261 handler(&event);
1262 }
1263 }
1264 });
1265 }
1266 })
1267 .await;
1268
1269 let sessions_for_requests = Arc::clone(&self.sessions);
1271
1272 rpc.set_request_handler(move |method, params| {
1274 use crate::jsonrpc::JsonRpcError;
1275
1276 let sessions = Arc::clone(&sessions_for_requests);
1277 let method = method.to_string();
1278 let params = params.clone();
1279
1280 Box::pin(async move {
1281 let result = match method.as_str() {
1282 "tool.call" => handle_tool_call(&sessions, ¶ms).await,
1283 "permission.request" => handle_permission_request(&sessions, ¶ms).await,
1284 "userInput.request" => handle_user_input_request(&sessions, ¶ms).await,
1285 "hooks.invoke" => handle_hooks_invoke(&sessions, ¶ms).await,
1286 _ => {
1287 return Err(JsonRpcError::new(
1288 -32601,
1289 format!("Unknown method: {}", method),
1290 ));
1291 }
1292 };
1293
1294 result.map_err(|e| JsonRpcError::new(-32000, e.to_string()))
1295 })
1296 })
1297 .await;
1298
1299 Ok(())
1300 }
1301
1302 async fn create_session_object(
1304 &self,
1305 session_id: String,
1306 workspace_path: Option<String>,
1307 ) -> Arc<Session> {
1308 let rpc = Arc::clone(&self.rpc);
1309
1310 let invoke_fn = move |method: &str, params: Option<Value>| {
1312 let rpc = Arc::clone(&rpc);
1313 let method = method.to_string();
1314
1315 Box::pin(async move {
1316 let rpc = rpc.lock().await;
1317 let rpc = rpc.as_ref().ok_or(CopilotError::NotConnected)?;
1318 rpc.invoke(&method, params).await
1319 }) as crate::session::InvokeFuture
1320 };
1321
1322 Arc::new(Session::new(session_id, workspace_path, invoke_fn))
1323 }
1324}
1325
1326#[derive(Debug, Default)]
1332pub struct ClientBuilder {
1333 options: ClientOptions,
1334}
1335
1336impl ClientBuilder {
1337 pub fn new() -> Self {
1339 Self::default()
1340 }
1341
1342 pub fn cli_path(mut self, path: impl Into<PathBuf>) -> Self {
1344 self.options.cli_path = Some(path.into());
1345 self
1346 }
1347
1348 pub fn cli_args<I, S>(mut self, args: I) -> Self
1350 where
1351 I: IntoIterator<Item = S>,
1352 S: Into<String>,
1353 {
1354 self.options.cli_args = Some(args.into_iter().map(Into::into).collect());
1355 self
1356 }
1357
1358 pub fn cli_arg(mut self, arg: impl Into<String>) -> Self {
1360 self.options
1361 .cli_args
1362 .get_or_insert_with(Vec::new)
1363 .push(arg.into());
1364 self
1365 }
1366
1367 pub fn use_stdio(mut self, use_stdio: bool) -> Self {
1369 self.options.use_stdio = use_stdio;
1370 self
1371 }
1372
1373 pub fn cli_url(mut self, url: impl Into<String>) -> Self {
1377 self.options.cli_url = Some(url.into());
1378 self.options.use_stdio = false;
1379 self
1380 }
1381
1382 pub fn port(mut self, port: u16) -> Self {
1386 self.options.port = port;
1387 self
1388 }
1389
1390 pub fn auto_start(mut self, auto_start: bool) -> Self {
1392 self.options.auto_start = auto_start;
1393 self
1394 }
1395
1396 pub fn auto_restart(mut self, auto_restart: bool) -> Self {
1398 self.options.auto_restart = auto_restart;
1399 self
1400 }
1401
1402 pub fn log_level(mut self, level: LogLevel) -> Self {
1404 self.options.log_level = level;
1405 self
1406 }
1407
1408 pub fn cwd(mut self, dir: impl Into<PathBuf>) -> Self {
1410 self.options.cwd = Some(dir.into());
1411 self
1412 }
1413
1414 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1416 self.options
1417 .environment
1418 .get_or_insert_with(HashMap::new)
1419 .insert(key.into(), value.into());
1420 self
1421 }
1422
1423 pub fn github_token(mut self, token: impl Into<String>) -> Self {
1425 self.options.github_token = Some(token.into());
1426 self
1427 }
1428
1429 pub fn use_logged_in_user(mut self, value: bool) -> Self {
1431 self.options.use_logged_in_user = Some(value);
1432 self
1433 }
1434
1435 pub fn deny_tool(mut self, tool_spec: impl Into<String>) -> Self {
1452 self.options
1453 .deny_tools
1454 .get_or_insert_with(Vec::new)
1455 .push(tool_spec.into());
1456 self
1457 }
1458
1459 pub fn deny_tools<I, S>(mut self, tool_specs: I) -> Self
1463 where
1464 I: IntoIterator<Item = S>,
1465 S: Into<String>,
1466 {
1467 self.options.deny_tools = Some(tool_specs.into_iter().map(Into::into).collect());
1468 self
1469 }
1470
1471 pub fn allow_tool(mut self, tool_spec: impl Into<String>) -> Self {
1475 self.options
1476 .allow_tools
1477 .get_or_insert_with(Vec::new)
1478 .push(tool_spec.into());
1479 self
1480 }
1481
1482 pub fn allow_tools<I, S>(mut self, tool_specs: I) -> Self
1486 where
1487 I: IntoIterator<Item = S>,
1488 S: Into<String>,
1489 {
1490 self.options.allow_tools = Some(tool_specs.into_iter().map(Into::into).collect());
1491 self
1492 }
1493
1494 pub fn allow_all_tools(mut self, allow: bool) -> Self {
1514 self.options.allow_all_tools = allow;
1515 self
1516 }
1517
1518 pub fn build(self) -> Result<Client> {
1520 Client::new(self.options)
1521 }
1522}
1523
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_builder() {
        let client = Client::builder()
            .cli_path("/usr/bin/copilot")
            .cli_arg("--foo")
            .use_stdio(true)
            .log_level(LogLevel::Debug)
            .cwd("/tmp")
            .env("FOO", "bar")
            .build();

        assert!(client.is_ok());
    }

    #[test]
    fn test_client_builder_deny_allow_tools() {
        let client = Client::builder()
            .allow_all_tools(true)
            .deny_tool("shell(git push)")
            .deny_tool("shell(git commit)")
            .deny_tool("shell(rm)")
            .allow_tool("shell(ls)")
            .build()
            .unwrap();

        assert!(client.options.allow_all_tools);
        assert_eq!(
            client.options.deny_tools,
            Some(vec![
                "shell(git push)".to_string(),
                "shell(git commit)".to_string(),
                "shell(rm)".to_string(),
            ])
        );
        assert_eq!(
            client.options.allow_tools,
            Some(vec!["shell(ls)".to_string()])
        );
    }

    #[test]
    fn test_client_builder_deny_tools_batch() {
        let client = Client::builder()
            .deny_tools(vec!["shell(git push)", "shell(git add)"])
            .build()
            .unwrap();

        assert_eq!(
            client.options.deny_tools,
            Some(vec![
                "shell(git push)".to_string(),
                "shell(git add)".to_string(),
            ])
        );
    }

    /// `cli_args` replaces the whole list; `cli_arg` appends to it.
    #[test]
    fn test_client_builder_cli_args_replace_then_append() {
        let builder = ClientBuilder::new()
            .cli_arg("--dropped")
            .cli_args(["--a", "--b"])
            .cli_arg("--c");

        assert_eq!(
            builder.options.cli_args,
            Some(vec![
                "--a".to_string(),
                "--b".to_string(),
                "--c".to_string(),
            ])
        );
    }

    /// Setting a server URL must switch the builder away from stdio.
    #[test]
    fn test_client_builder_cli_url_disables_stdio() {
        let builder = ClientBuilder::new()
            .use_stdio(true)
            .cli_url("localhost:8080");

        assert!(!builder.options.use_stdio);
        assert_eq!(builder.options.cli_url.as_deref(), Some("localhost:8080"));
    }

    /// Every conflicting combination must be rejected with `InvalidConfig`.
    /// Each case carries a label so a failure names the offending combination.
    #[test]
    fn test_client_mutually_exclusive_options() {
        let conflicting = [
            (
                "cli_path + cli_url",
                ClientOptions {
                    cli_path: Some("/usr/bin/copilot".into()),
                    cli_url: Some("http://localhost:8080".into()),
                    ..Default::default()
                },
            ),
            (
                "cli_url + port",
                ClientOptions {
                    cli_url: Some("localhost:8080".into()),
                    port: 1234,
                    ..Default::default()
                },
            ),
            (
                "use_stdio + port",
                ClientOptions {
                    use_stdio: true,
                    port: 1234,
                    ..Default::default()
                },
            ),
            (
                "cli_url + github_token",
                ClientOptions {
                    cli_url: Some("localhost:8080".into()),
                    github_token: Some("ghp_abc123".into()),
                    ..Default::default()
                },
            ),
            (
                "cli_url + use_logged_in_user",
                ClientOptions {
                    cli_url: Some("localhost:8080".into()),
                    use_logged_in_user: Some(true),
                    ..Default::default()
                },
            ),
        ];

        for (label, options) in conflicting {
            assert!(
                matches!(Client::new(options), Err(CopilotError::InvalidConfig(_))),
                "expected InvalidConfig for {}",
                label
            );
        }
    }

    #[tokio::test]
    async fn test_client_state_initial() {
        let client = Client::new(ClientOptions::default()).unwrap();
        assert_eq!(client.state().await, ConnectionState::Disconnected);
    }

    #[test]
    fn test_normalize_tool_arguments_object() {
        let params = json!({
            "arguments": { "n": 42 }
        });
        assert_eq!(normalize_tool_arguments(&params), json!({ "n": 42 }));
    }

    #[test]
    fn test_normalize_tool_arguments_string() {
        let params = json!({
            "arguments": "{\"n\":42}"
        });
        assert_eq!(normalize_tool_arguments(&params), json!({ "n": 42 }));
    }

    #[test]
    fn test_normalize_tool_arguments_fallback_arguments_json() {
        let params = json!({
            "argumentsJson": "{\"text\":\"hello\",\"shift\":-5}"
        });
        assert_eq!(
            normalize_tool_arguments(&params),
            json!({ "text": "hello", "shift": -5 })
        );
    }

    /// Malformed JSON in `arguments` falls back to an empty object.
    #[test]
    fn test_normalize_tool_arguments_invalid_json_string() {
        let params = json!({
            "arguments": "{not valid json"
        });
        assert_eq!(normalize_tool_arguments(&params), json!({}));
    }
}