1use crate::error::{IFlowError, Result};
8use crate::types::{IFlowOptions, Message, PermissionMode};
9use crate::websocket_transport::WebSocketTransport;
10use serde_json::{Value, json};
11use std::collections::HashMap;
12use std::time::Duration;
13use tokio::sync::mpsc::UnboundedSender;
14use tokio::time::timeout;
15use tracing::debug;
16
17pub struct ACPProtocol {
22 transport: WebSocketTransport,
24 initialized: bool,
26 authenticated: bool,
28 request_id: u32,
30 message_sender: UnboundedSender<Message>,
32 protocol_version: u32,
34 permission_mode: PermissionMode,
36 timeout_secs: f64,
38}
39
40impl ACPProtocol {
41 pub fn new(
48 transport: WebSocketTransport,
49 message_sender: UnboundedSender<Message>,
50 timeout_secs: f64,
51 ) -> Self {
52 Self {
53 transport,
54 initialized: false,
55 authenticated: false,
56 request_id: 0,
57 message_sender,
58 protocol_version: 1,
59 permission_mode: PermissionMode::Auto,
60 timeout_secs,
61 }
62 }
63
64 pub fn is_initialized(&self) -> bool {
69 self.initialized
70 }
71
72 pub fn is_authenticated(&self) -> bool {
77 self.authenticated
78 }
79
80 pub fn set_permission_mode(&mut self, mode: PermissionMode) {
85 self.permission_mode = mode;
86 }
87
88 fn next_request_id(&mut self) -> u32 {
93 self.request_id += 1;
94 self.request_id
95 }
96
97 pub async fn initialize(&mut self, options: &IFlowOptions) -> Result<()> {
111 if self.initialized {
112 tracing::warn!("Protocol already initialized");
113 return Ok(());
114 }
115
116 debug!("Initializing ACP protocol");
117
118 debug!("Waiting for //ready signal...");
120 let ready_timeout = Duration::from_secs_f64(self.timeout_secs);
121 let start_time = std::time::Instant::now();
122
123 loop {
124 if start_time.elapsed() > ready_timeout {
125 return Err(IFlowError::Timeout(
126 "Timeout waiting for //ready signal".to_string(),
127 ));
128 }
129
130 let msg = match timeout(
131 Duration::from_secs_f64(self.timeout_secs.min(10.0)),
132 self.transport.receive(),
133 )
134 .await
135 {
136 Ok(Ok(msg)) => msg,
137 Ok(Err(e)) => {
138 tracing::error!("Transport error while waiting for //ready: {}", e);
139 tokio::time::sleep(Duration::from_millis(500)).await;
141 continue;
142 }
143 Err(_) => {
144 tracing::debug!("No message received, continuing to wait for //ready...");
145 continue;
146 }
147 };
148
149 let trimmed_msg = msg.trim();
150 if trimmed_msg == "//ready" {
151 debug!("Received //ready signal");
152 break;
153 } else if trimmed_msg.starts_with("//") {
154 tracing::debug!("Control message: {}", trimmed_msg);
156 continue;
157 } else if !trimmed_msg.is_empty() {
158 tracing::debug!(
160 "Non-control message received while waiting for //ready: {}",
161 trimmed_msg
162 );
163 continue;
164 }
165 }
166
167 tokio::time::sleep(Duration::from_millis(100)).await;
169
170 let request_id = self.next_request_id();
172 let mut params = json!({
173 "protocolVersion": self.protocol_version,
174 "clientCapabilities": {
175 "fs": {
176 "readTextFile": true,
177 "writeTextFile": true
178 }
179 }
180 });
181
182 if !options.mcp_servers.is_empty() {
184 let mcp_servers: Vec<serde_json::Value> = options
186 .mcp_servers
187 .iter()
188 .map(|server| {
189 json!(server)
192 })
193 .collect();
194 params["mcpServers"] = json!(mcp_servers);
195 }
196
197 let request = json!({
198 "jsonrpc": "2.0",
199 "id": request_id,
200 "method": "initialize",
201 "params": params,
202 });
203
204 let mut send_attempts = 0;
206 let max_send_attempts = 3;
207
208 while send_attempts < max_send_attempts {
209 match self.transport.send(&request).await {
210 Ok(_) => {
211 debug!("Sent initialize request (attempt {})", send_attempts + 1);
212 break;
213 }
214 Err(e) => {
215 send_attempts += 1;
216 tracing::warn!(
217 "Failed to send initialize request (attempt {}): {}",
218 send_attempts,
219 e
220 );
221 if send_attempts >= max_send_attempts {
222 return Err(IFlowError::Protocol(format!(
223 "Failed to send initialize request after {} attempts: {}",
224 max_send_attempts, e
225 )));
226 }
227 tokio::time::sleep(Duration::from_millis(500)).await;
228 }
229 }
230 }
231
232 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
234 let response = timeout(response_timeout, self.wait_for_response(request_id))
235 .await
236 .map_err(|_| {
237 IFlowError::Timeout("Timeout waiting for initialize response".to_string())
238 })?
239 .map_err(|e| IFlowError::Protocol(format!("Failed to initialize: {}", e)))?;
240
241 if let Some(result) = response.get("result") {
242 self.authenticated = result
243 .get("isAuthenticated")
244 .and_then(|v| v.as_bool())
245 .unwrap_or(false);
246 self.initialized = true;
247 debug!(
248 "Initialized with protocol version: {:?}, authenticated: {}",
249 result.get("protocolVersion"),
250 self.authenticated
251 );
252 } else if let Some(error) = response.get("error") {
253 return Err(IFlowError::Protocol(format!(
254 "Initialize failed: {:?}",
255 error
256 )));
257 } else {
258 return Err(IFlowError::Protocol(
259 "Invalid initialize response".to_string(),
260 ));
261 }
262
263 Ok(())
264 }
265
266 pub async fn authenticate(
279 &mut self,
280 method_id: &str,
281 method_info: Option<HashMap<String, String>>,
282 ) -> Result<()> {
283 if self.authenticated {
284 debug!("Already authenticated");
285 return Ok(());
286 }
287
288 let request_id = self.next_request_id();
289 let mut params = json!({
290 "methodId": method_id,
291 });
292
293 if let Some(info) = method_info {
294 params["methodInfo"] = json!(info);
295 }
296
297 let request = json!({
298 "jsonrpc": "2.0",
299 "id": request_id,
300 "method": "authenticate",
301 "params": params,
302 });
303
304 self.transport.send(&request).await?;
305 debug!("Sent authenticate request with method: {}", method_id);
306
307 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
309 let response = timeout(response_timeout, self.wait_for_response(request_id))
310 .await
311 .map_err(|_| {
312 IFlowError::Timeout("Timeout waiting for authentication response".to_string())
313 })?
314 .map_err(|e| IFlowError::Protocol(format!("Failed to authenticate: {}", e)))?;
315
316 if let Some(result) = response.get("result") {
317 if let Some(response_method) = result.get("methodId").and_then(|v| v.as_str()) {
318 if response_method == method_id {
319 self.authenticated = true;
320 debug!("Authentication successful with method: {}", response_method);
321 } else {
322 tracing::warn!(
323 "Unexpected methodId in response: {} (expected {})",
324 response_method,
325 method_id
326 );
327 self.authenticated = true;
329 }
330 } else {
331 self.authenticated = true;
332 }
333 } else if let Some(error) = response.get("error") {
334 return Err(IFlowError::Authentication(format!(
335 "Authentication failed: {:?}",
336 error
337 )));
338 } else {
339 return Err(IFlowError::Protocol(
340 "Invalid authenticate response".to_string(),
341 ));
342 }
343
344 Ok(())
345 }
346
347 pub async fn create_session(
357 &mut self,
358 cwd: &str,
359 mcp_servers: Vec<serde_json::Value>,
360 ) -> Result<String> {
361 if !self.initialized {
362 return Err(IFlowError::Protocol(
363 "Protocol not initialized. Call initialize() first.".to_string(),
364 ));
365 }
366
367 if !self.authenticated {
368 return Err(IFlowError::Protocol(
369 "Not authenticated. Call authenticate() first.".to_string(),
370 ));
371 }
372
373 let request_id = self.next_request_id();
374 let params = json!({
375 "cwd": cwd,
376 "mcpServers": mcp_servers,
377 });
378
379 let request = json!({
380 "jsonrpc": "2.0",
381 "id": request_id,
382 "method": "session/new",
383 "params": params,
384 });
385
386 self.transport.send(&request).await?;
387 debug!(
388 "Sent session/new request with cwd: {} and mcpServers: {:?}",
389 cwd, mcp_servers
390 );
391
392 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
394 let response = timeout(response_timeout, self.wait_for_response(request_id))
395 .await
396 .map_err(|_| {
397 IFlowError::Timeout("Timeout waiting for session creation response".to_string())
398 })?
399 .map_err(|e| IFlowError::Protocol(format!("Failed to create session: {}", e)))?;
400
401 if let Some(result) = response.get("result") {
402 if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
403 debug!("Created session: {}", session_id);
404 Ok(session_id.to_string())
405 } else {
406 debug!(
407 "Invalid session/new response, using fallback ID: session_{}",
408 request_id
409 );
410 Ok(format!("session_{}", request_id))
411 }
412 } else if let Some(error) = response.get("error") {
413 Err(IFlowError::Protocol(format!(
414 "session/new failed: {:?}",
415 error
416 )))
417 } else {
418 Err(IFlowError::Protocol(
419 "Invalid session/new response".to_string(),
420 ))
421 }
422 }
423
424 pub async fn send_prompt(&mut self, session_id: &str, prompt: &str) -> Result<u32> {
434 if !self.initialized {
435 return Err(IFlowError::Protocol(
436 "Protocol not initialized. Call initialize() first.".to_string(),
437 ));
438 }
439
440 if !self.authenticated {
441 return Err(IFlowError::Protocol(
442 "Not authenticated. Call authenticate() first.".to_string(),
443 ));
444 }
445
446 let request_id = self.next_request_id();
447 let prompt_blocks = vec![json!({
449 "type": "text",
450 "text": prompt
451 })];
452
453 let params = json!({
454 "sessionId": session_id,
455 "prompt": prompt_blocks,
456 });
457
458 let request = json!({
459 "jsonrpc": "2.0",
460 "id": request_id,
461 "method": "session/prompt",
462 "params": params,
463 });
464
465 self.transport.send(&request).await?;
466 debug!("Sent session/prompt");
467
468 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
470 let response = timeout(response_timeout, self.wait_for_response_with_notifications(request_id))
471 .await
472 .map_err(|_| IFlowError::Timeout("Timeout waiting for prompt response".to_string()))?
473 .map_err(|e| IFlowError::Protocol(format!("Failed to send prompt: {}", e)))?;
474
475 if let Some(error) = response.get("error") {
477 return Err(IFlowError::Protocol(format!("Prompt failed: {:?}", error)));
478 }
479
480 let msg = Message::TaskFinish {
482 reason: Some("completed".to_string()),
483 };
484 let _ = self.message_sender.send(msg);
485
486 Ok(request_id)
487 }
488
489 async fn wait_for_response(&mut self, request_id: u32) -> Result<Value> {
498 let timeout_duration = Duration::from_secs_f64(self.timeout_secs);
499 let start_time = std::time::Instant::now();
500
501 loop {
502 if start_time.elapsed() > timeout_duration {
503 return Err(IFlowError::Timeout(format!(
504 "Timeout waiting for response to request {}",
505 request_id
506 )));
507 }
508
509 let receive_timeout = Duration::from_secs_f64(self.timeout_secs.min(1.0));
511 let msg = match timeout(receive_timeout, self.transport.receive()).await {
512 Ok(Ok(msg)) => msg,
513 Ok(Err(e)) => {
514 tracing::error!("Transport error while waiting for response: {}", e);
515 return Err(e);
516 }
517 Err(_) => {
518 tracing::debug!(
520 "No message received, continuing to wait for response to request {}...",
521 request_id
522 );
523 continue;
524 }
525 };
526
527 if msg.starts_with("//") {
529 tracing::debug!("Control message: {}", msg);
530 continue;
531 }
532
533 let data: Value = match serde_json::from_str(&msg) {
535 Ok(data) => data,
536 Err(e) => {
537 tracing::debug!("Failed to parse message as JSON: {}, message: {}", e, msg);
538 continue;
539 }
540 };
541
542 if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
544 if id == request_id as u64 {
545 return Ok(data);
546 }
547 }
548
549 if let Err(e) = self.handle_notification(data).await {
551 tracing::warn!("Failed to handle notification: {}", e);
552 }
554 }
555 }
556
557 async fn wait_for_response_with_notifications(&mut self, request_id: u32) -> Result<Value> {
566 let timeout_duration = Duration::from_secs_f64(self.timeout_secs);
567 let start_time = std::time::Instant::now();
568
569 loop {
570 if start_time.elapsed() > timeout_duration {
571 return Err(IFlowError::Timeout(format!(
572 "Timeout waiting for response to request {}",
573 request_id
574 )));
575 }
576
577 let receive_timeout = Duration::from_secs_f64(self.timeout_secs.min(1.0));
579 let msg = match timeout(receive_timeout, self.transport.receive()).await {
580 Ok(Ok(msg)) => msg,
581 Ok(Err(e)) => {
582 tracing::error!("Transport error while waiting for response: {}", e);
583 return Err(e);
584 }
585 Err(_) => {
586 tracing::debug!(
588 "No message received, continuing to wait for response to request {}...",
589 request_id
590 );
591 continue;
592 }
593 };
594
595 if msg.starts_with("//") {
597 tracing::debug!("Control message: {}", msg);
598 continue;
599 }
600
601 let data: Value = match serde_json::from_str(&msg) {
603 Ok(data) => data,
604 Err(e) => {
605 tracing::debug!("Failed to parse message as JSON: {}, message: {}", e, msg);
606 continue;
607 }
608 };
609
610 if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
612 if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
614 if method == "session/request_permission" {
615 tracing::debug!("Handling session/request_permission with ID: {}", id);
616 if let Err(e) = self.handle_client_method(method, data.clone()).await {
618 tracing::warn!("Failed to handle permission request: {}", e);
619 }
620 continue;
622 }
623 }
624
625 if id == request_id as u64 {
627 return Ok(data);
628 }
629 }
630
631 if let Err(e) = self.handle_notification(data).await {
633 tracing::warn!("Failed to handle notification: {}", e);
634 }
636 }
637 }
638
639 async fn handle_notification(&mut self, data: Value) -> Result<()> {
648 if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
650 if data.get("result").is_none() && data.get("error").is_none() {
651 self.handle_client_method(method, data.clone()).await?;
652 }
653 }
654
655 Ok(())
656 }
657
658 async fn handle_client_method(&mut self, method: &str, data: Value) -> Result<()> {
668 let params = data.get("params").cloned().unwrap_or(Value::Null);
669 let request_id = data.get("id").and_then(|v| v.as_u64());
670
671 match method {
672 "session/update" => {
673 if let Some(update_obj) = params.get("update").and_then(|v| v.as_object()) {
674 if let Some(session_update) =
675 update_obj.get("sessionUpdate").and_then(|v| v.as_str())
676 {
677 self.handle_session_update(session_update, update_obj, request_id)
678 .await?;
679 }
680 }
681 }
682 "session/request_permission" => {
683 tracing::debug!("Handling session/request_permission");
685 self.handle_permission_request(params, request_id).await?;
686 }
687 _ => {
688 tracing::warn!("Unknown method: {}", method);
689 if let Some(id) = request_id {
691 let error_response = json!({
692 "jsonrpc": "2.0",
693 "id": id,
694 "error": {
695 "code": -32601,
696 "message": "Method not found"
697 }
698 });
699 self.transport.send(&error_response).await?;
700 }
701 }
702 }
703
704 Ok(())
705 }
706
707 async fn handle_permission_request(
717 &mut self,
718 params: Value,
719 request_id: Option<u64>,
720 ) -> Result<()> {
721 let tool_call = params.get("toolCall").unwrap_or(&Value::Null);
723 let options = params.get("options").unwrap_or(&Value::Null);
724 let _session_id = params.get("sessionId").and_then(|v| v.as_str());
725
726 let tool_title = tool_call
728 .get("title")
729 .and_then(|v| v.as_str())
730 .unwrap_or("unknown");
731 let tool_type = tool_call
732 .get("type")
733 .and_then(|v| v.as_str())
734 .unwrap_or("unknown");
735
736 tracing::debug!(
737 "Permission request for tool '{}' (type: {})",
738 tool_title,
739 tool_type
740 );
741
742 let auto_approve = match self.permission_mode {
744 PermissionMode::Auto => {
745 true
747 }
748 PermissionMode::Manual => {
749 false
751 }
752 PermissionMode::Selective => {
753 tool_type == "read" || tool_type == "fetch" || tool_type == "list"
756 }
757 };
758
759 use agent_client_protocol::{RequestPermissionOutcome, RequestPermissionResponse};
760 let permission_response = if auto_approve {
761 let mut selected_option = "proceed_once".to_string();
763 if let Some(options_array) = options.as_array() {
764 for option in options_array {
765 if let Some(option_id) = option.get("optionId").and_then(|v| v.as_str()) {
766 if option_id == "proceed_once" {
767 selected_option = option_id.to_string();
768 break;
769 } else if option_id == "proceed_always" {
770 selected_option = option_id.to_string();
771 }
772 }
773 }
774 if selected_option == "proceed_once" && !options_array.is_empty() {
776 if let Some(first_option_id) = options_array[0].get("optionId").and_then(|v| v.as_str()) {
777 selected_option = first_option_id.to_string();
778 }
779 }
780 }
781 RequestPermissionResponse {
782 outcome: RequestPermissionOutcome::Selected {
783 option_id: agent_client_protocol::PermissionOptionId(std::sync::Arc::from(selected_option)),
784 },
785 meta: None,
786 }
787 } else {
788 RequestPermissionResponse {
789 outcome: RequestPermissionOutcome::Cancelled,
790 meta: None,
791 }
792 };
793
794 if let Some(id) = request_id {
796 let response_message = serde_json::json!({
797 "jsonrpc": "2.0",
798 "id": id,
799 "result": permission_response
800 });
801 self.transport.send(&response_message).await?;
802 }
803
804 let outcome_str = match &permission_response.outcome {
805 RequestPermissionOutcome::Cancelled => "cancelled",
806 RequestPermissionOutcome::Selected { option_id } => &*option_id.0,
807 };
808 tracing::debug!("Permission request for tool '{}': {}", tool_title, outcome_str);
809
810 Ok(())
811 }
812
813 async fn handle_session_update(
824 &mut self,
825 update_type: &str,
826 update: &serde_json::Map<String, Value>,
827 request_id: Option<u64>,
828 ) -> Result<()> {
829 match update_type {
830 "agent_message_chunk" => {
831 if let Some(content) = update.get("content") {
832 let text = match content {
833 Value::Object(obj) => {
834 if let Some(text_content) = obj.get("text").and_then(|v| v.as_str()) {
835 text_content.to_string()
836 } else {
837 "<unknown>".to_string()
838 }
839 }
840 _ => "<unknown>".to_string(),
841 };
842
843 let msg = Message::Assistant { content: text };
844 let _ = self.message_sender.send(msg);
845 }
846 }
847 "user_message_chunk" => {
848 if let Some(content) = update.get("content") {
849 let text = match content {
850 Value::Object(obj) => {
851 if let Some(text_content) = obj.get("text").and_then(|v| v.as_str()) {
852 text_content.to_string()
853 } else {
854 "<unknown>".to_string()
855 }
856 }
857 _ => "<unknown>".to_string(),
858 };
859
860 let msg = Message::User { content: text };
861 let _ = self.message_sender.send(msg);
862 }
863 }
864 "tool_call" => {
865 if let Some(tool_call) = update.get("toolCall") {
866 let id = tool_call
867 .get("id")
868 .and_then(|v| v.as_str())
869 .unwrap_or("")
870 .to_string();
871 let name = tool_call
872 .get("title")
873 .and_then(|v| v.as_str())
874 .unwrap_or("Unknown")
875 .to_string();
876 let status = tool_call
877 .get("status")
878 .and_then(|v| v.as_str())
879 .unwrap_or("unknown")
880 .to_string();
881
882 let msg = Message::ToolCall { id, name, status };
883 let _ = self.message_sender.send(msg);
884 }
885 }
886 "plan" => {
887 if let Some(entries) = update.get("entries").and_then(|v| v.as_array()) {
888 let entries: Vec<super::types::PlanEntry> = entries
889 .iter()
890 .filter_map(|entry| {
891 let content =
892 entry.get("content").and_then(|v| v.as_str())?.to_string();
893 let priority_str = entry
894 .get("priority")
895 .and_then(|v| v.as_str())
896 .unwrap_or("medium");
897 let status_str = entry
898 .get("status")
899 .and_then(|v| v.as_str())
900 .unwrap_or("pending");
901
902 let priority = match priority_str {
903 "high" => super::types::PlanPriority::High,
904 "medium" => super::types::PlanPriority::Medium,
905 "low" => super::types::PlanPriority::Low,
906 _ => super::types::PlanPriority::Medium,
907 };
908
909 let status = match status_str {
910 "pending" => super::types::PlanStatus::Pending,
911 "in_progress" => super::types::PlanStatus::InProgress,
912 "completed" => super::types::PlanStatus::Completed,
913 _ => super::types::PlanStatus::Pending,
914 };
915
916 Some(super::types::PlanEntry {
917 content,
918 priority,
919 status,
920 })
921 })
922 .collect();
923
924 let msg = Message::Plan { entries };
925 let _ = self.message_sender.send(msg);
926 }
927 }
928 "tool_call_update" => {
929 if let Some(id) = request_id {
931 let response = json!({
932 "jsonrpc": "2.0",
933 "id": id,
934 "result": null
935 });
936 self.transport.send(&response).await?;
937 }
938 }
939 "agent_thought_chunk" | "current_mode_update" | "available_commands_update" => {
940 }
942 _ => {
943 tracing::debug!("Unhandled session update type: {}", update_type);
944 }
945 }
946
947 if let Some(id) = request_id {
949 match update_type {
950 "tool_call_update" | "notifyTaskFinish" => {
951 let response = json!({
952 "jsonrpc": "2.0",
953 "id": id,
954 "result": null
955 });
956 self.transport.send(&response).await?;
957 }
958 _ => {}
959 }
960 }
961
962 Ok(())
963 }
964
965 pub async fn close(&mut self) -> Result<()> {
967 self.transport.close().await?;
968 Ok(())
969 }
970
971 pub fn is_connected(&self) -> bool {
976 self.transport.is_connected()
977 }
978}