1use crate::error::{IFlowError, Result};
8use crate::types::{IFlowOptions, Message, PermissionMode};
9use crate::websocket_transport::WebSocketTransport;
10use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::time::Duration;
14use tokio::sync::mpsc::UnboundedSender;
15use tokio::time::timeout;
16use tracing::debug;
17
18pub struct ACPProtocol {
23 transport: WebSocketTransport,
25 initialized: bool,
27 authenticated: bool,
29 request_id: u32,
31 message_sender: UnboundedSender<Message>,
33 protocol_version: u32,
35 permission_mode: PermissionMode,
37 timeout_secs: f64,
39}
40
41impl ACPProtocol {
42 pub fn new(
49 transport: WebSocketTransport,
50 message_sender: UnboundedSender<Message>,
51 timeout_secs: f64,
52 ) -> Self {
53 Self {
54 transport,
55 initialized: false,
56 authenticated: false,
57 request_id: 0,
58 message_sender,
59 protocol_version: 1,
60 permission_mode: PermissionMode::Auto,
61 timeout_secs,
62 }
63 }
64
65 pub fn is_initialized(&self) -> bool {
70 self.initialized
71 }
72
73 pub fn is_authenticated(&self) -> bool {
78 self.authenticated
79 }
80
81 pub fn set_permission_mode(&mut self, mode: PermissionMode) {
86 self.permission_mode = mode;
87 }
88
89 fn next_request_id(&mut self) -> u32 {
94 self.request_id += 1;
95 self.request_id
96 }
97
98 pub async fn initialize(&mut self, options: &IFlowOptions) -> Result<()> {
112 if self.initialized {
113 tracing::warn!("Protocol already initialized");
114 return Ok(());
115 }
116
117 debug!("Initializing ACP protocol");
118
119 debug!("Waiting for //ready signal...");
121 let ready_timeout = Duration::from_secs_f64(self.timeout_secs);
122 let start_time = std::time::Instant::now();
123
124 loop {
125 if start_time.elapsed() > ready_timeout {
126 return Err(IFlowError::Timeout(
127 "Timeout waiting for //ready signal".to_string(),
128 ));
129 }
130
131 let msg = match timeout(
132 Duration::from_secs_f64(self.timeout_secs.min(10.0)),
133 self.transport.receive(),
134 )
135 .await
136 {
137 Ok(Ok(msg)) => msg,
138 Ok(Err(e)) => {
139 tracing::error!("Transport error while waiting for //ready: {}", e);
140 tokio::time::sleep(Duration::from_millis(500)).await;
142 continue;
143 }
144 Err(_) => {
145 tracing::debug!("No message received, continuing to wait for //ready...");
146 continue;
147 }
148 };
149
150 let trimmed_msg = msg.trim();
151 if trimmed_msg == "//ready" {
152 debug!("Received //ready signal");
153 break;
154 } else if trimmed_msg.starts_with("//") {
155 tracing::debug!("Control message: {}", trimmed_msg);
157 continue;
158 } else if !trimmed_msg.is_empty() {
159 tracing::debug!(
161 "Non-control message received while waiting for //ready: {}",
162 trimmed_msg
163 );
164 continue;
165 }
166 }
167
168 tokio::time::sleep(Duration::from_millis(100)).await;
170
171 let request_id = self.next_request_id();
173 let mut params = json!({
174 "protocolVersion": self.protocol_version,
175 "clientCapabilities": {
176 "fs": {
177 "readTextFile": true,
178 "writeTextFile": true
179 }
180 }
181 });
182
183 if !options.mcp_servers.is_empty() {
185 let mcp_servers: Vec<serde_json::Value> = options
187 .mcp_servers
188 .iter()
189 .map(|server| {
190 json!(server)
193 })
194 .collect();
195 params["mcpServers"] = json!(mcp_servers);
196 }
197
198 let request = json!({
199 "jsonrpc": "2.0",
200 "id": request_id,
201 "method": "initialize",
202 "params": params,
203 });
204
205 let mut send_attempts = 0;
207 let max_send_attempts = 3;
208
209 while send_attempts < max_send_attempts {
210 match self.transport.send(&request).await {
211 Ok(_) => {
212 debug!("Sent initialize request (attempt {})", send_attempts + 1);
213 break;
214 }
215 Err(e) => {
216 send_attempts += 1;
217 tracing::warn!(
218 "Failed to send initialize request (attempt {}): {}",
219 send_attempts,
220 e
221 );
222 if send_attempts >= max_send_attempts {
223 return Err(IFlowError::Protocol(format!(
224 "Failed to send initialize request after {} attempts: {}",
225 max_send_attempts, e
226 )));
227 }
228 tokio::time::sleep(Duration::from_millis(500)).await;
229 }
230 }
231 }
232
233 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
235 let response = timeout(response_timeout, self.wait_for_response(request_id))
236 .await
237 .map_err(|_| {
238 IFlowError::Timeout("Timeout waiting for initialize response".to_string())
239 })?
240 .map_err(|e| IFlowError::Protocol(format!("Failed to initialize: {}", e)))?;
241
242 if let Some(result) = response.get("result") {
243 self.authenticated = result
244 .get("isAuthenticated")
245 .and_then(|v| v.as_bool())
246 .unwrap_or(false);
247 self.initialized = true;
248 debug!(
249 "Initialized with protocol version: {:?}, authenticated: {}",
250 result.get("protocolVersion"),
251 self.authenticated
252 );
253 } else if let Some(error) = response.get("error") {
254 return Err(IFlowError::Protocol(format!(
255 "Initialize failed: {:?}",
256 error
257 )));
258 } else {
259 return Err(IFlowError::Protocol(
260 "Invalid initialize response".to_string(),
261 ));
262 }
263
264 Ok(())
265 }
266
267 pub async fn authenticate(
280 &mut self,
281 method_id: &str,
282 method_info: Option<HashMap<String, String>>,
283 ) -> Result<()> {
284 if self.authenticated {
285 debug!("Already authenticated");
286 return Ok(());
287 }
288
289 let request_id = self.next_request_id();
290 let mut params = json!({
291 "methodId": method_id,
292 });
293
294 if let Some(info) = method_info {
295 params["methodInfo"] = json!(info);
296 }
297
298 let request = json!({
299 "jsonrpc": "2.0",
300 "id": request_id,
301 "method": "authenticate",
302 "params": params,
303 });
304
305 self.transport.send(&request).await?;
306 debug!("Sent authenticate request with method: {}", method_id);
307
308 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
310 let response = timeout(response_timeout, self.wait_for_response(request_id))
311 .await
312 .map_err(|_| {
313 IFlowError::Timeout("Timeout waiting for authentication response".to_string())
314 })?
315 .map_err(|e| IFlowError::Protocol(format!("Failed to authenticate: {}", e)))?;
316
317 if let Some(result) = response.get("result") {
318 if let Some(response_method) = result.get("methodId").and_then(|v| v.as_str()) {
319 if response_method == method_id {
320 self.authenticated = true;
321 debug!("Authentication successful with method: {}", response_method);
322 } else {
323 tracing::warn!(
324 "Unexpected methodId in response: {} (expected {})",
325 response_method,
326 method_id
327 );
328 self.authenticated = true;
330 }
331 } else {
332 self.authenticated = true;
333 }
334 } else if let Some(error) = response.get("error") {
335 return Err(IFlowError::Authentication(format!(
336 "Authentication failed: {:?}",
337 error
338 )));
339 } else {
340 return Err(IFlowError::Protocol(
341 "Invalid authenticate response".to_string(),
342 ));
343 }
344
345 Ok(())
346 }
347
348 pub async fn create_session(
358 &mut self,
359 cwd: &str,
360 mcp_servers: Vec<serde_json::Value>,
361 ) -> Result<String> {
362 if !self.initialized {
363 return Err(IFlowError::Protocol(
364 "Protocol not initialized. Call initialize() first.".to_string(),
365 ));
366 }
367
368 if !self.authenticated {
369 return Err(IFlowError::Protocol(
370 "Not authenticated. Call authenticate() first.".to_string(),
371 ));
372 }
373
374 let request_id = self.next_request_id();
375 let params = json!({
376 "cwd": cwd,
377 "mcpServers": mcp_servers,
378 });
379
380 let request = json!({
381 "jsonrpc": "2.0",
382 "id": request_id,
383 "method": "session/new",
384 "params": params,
385 });
386
387 self.transport.send(&request).await?;
388 debug!(
389 "Sent session/new request with cwd: {} and mcpServers: {:?}",
390 cwd, mcp_servers
391 );
392
393 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
395 let response = timeout(response_timeout, self.wait_for_response(request_id))
396 .await
397 .map_err(|_| {
398 IFlowError::Timeout("Timeout waiting for session creation response".to_string())
399 })?
400 .map_err(|e| IFlowError::Protocol(format!("Failed to create session: {}", e)))?;
401
402 if let Some(result) = response.get("result") {
403 if let Some(session_id) = result.get("sessionId").and_then(|v| v.as_str()) {
404 debug!("Created session: {}", session_id);
405 Ok(session_id.to_string())
406 } else {
407 debug!(
408 "Invalid session/new response, using fallback ID: session_{}",
409 request_id
410 );
411 Ok(format!("session_{}", request_id))
412 }
413 } else if let Some(error) = response.get("error") {
414 Err(IFlowError::Protocol(format!(
415 "session/new failed: {:?}",
416 error
417 )))
418 } else {
419 Err(IFlowError::Protocol(
420 "Invalid session/new response".to_string(),
421 ))
422 }
423 }
424
425 pub async fn send_prompt(&mut self, session_id: &str, prompt: &str) -> Result<u32> {
435 if !self.initialized {
436 return Err(IFlowError::Protocol(
437 "Protocol not initialized. Call initialize() first.".to_string(),
438 ));
439 }
440
441 if !self.authenticated {
442 return Err(IFlowError::Protocol(
443 "Not authenticated. Call authenticate() first.".to_string(),
444 ));
445 }
446
447 let request_id = self.next_request_id();
448 let prompt_blocks = vec![json!({
450 "type": "text",
451 "text": prompt
452 })];
453
454 let params = json!({
455 "sessionId": session_id,
456 "prompt": prompt_blocks,
457 });
458
459 let request = json!({
460 "jsonrpc": "2.0",
461 "id": request_id,
462 "method": "session/prompt",
463 "params": params,
464 });
465
466 self.transport.send(&request).await?;
467 debug!("Sent session/prompt");
468
469 let response_timeout = Duration::from_secs_f64(self.timeout_secs);
471 let response = timeout(response_timeout, self.wait_for_response(request_id))
472 .await
473 .map_err(|_| IFlowError::Timeout("Timeout waiting for prompt response".to_string()))?
474 .map_err(|e| IFlowError::Protocol(format!("Failed to send prompt: {}", e)))?;
475
476 if let Some(error) = response.get("error") {
478 return Err(IFlowError::Protocol(format!("Prompt failed: {:?}", error)));
479 }
480
481 let msg = Message::TaskFinish {
483 reason: Some("completed".to_string()),
484 };
485 let _ = self.message_sender.send(msg);
486
487 Ok(request_id)
488 }
489
490 async fn wait_for_response(&mut self, request_id: u32) -> Result<Value> {
499 let timeout_duration = Duration::from_secs_f64(self.timeout_secs);
500 let start_time = std::time::Instant::now();
501
502 loop {
503 if start_time.elapsed() > timeout_duration {
504 return Err(IFlowError::Timeout(format!(
505 "Timeout waiting for response to request {}",
506 request_id
507 )));
508 }
509
510 let msg = match timeout(
511 Duration::from_secs_f64(self.timeout_secs.min(5.0)),
512 self.transport.receive(),
513 )
514 .await
515 {
516 Ok(Ok(msg)) => msg,
517 Ok(Err(e)) => {
518 tracing::error!("Transport error while waiting for response: {}", e);
519 return Err(e);
520 }
521 Err(_) => {
522 tracing::debug!(
523 "No message received, continuing to wait for response to request {}...",
524 request_id
525 );
526 continue;
527 }
528 };
529
530 if msg.starts_with("//") {
532 tracing::debug!("Control message: {}", msg);
533 continue;
534 }
535
536 let data: Value = match serde_json::from_str(&msg) {
538 Ok(data) => data,
539 Err(e) => {
540 tracing::debug!("Failed to parse message as JSON: {}, message: {}", e, msg);
541 continue;
542 }
543 };
544
545 if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
547 if id == request_id as u64 {
548 return Ok(data);
549 }
550 }
551
552 if let Err(e) = self.handle_notification(data).await {
554 tracing::warn!("Failed to handle notification: {}", e);
555 }
557 }
558 }
559
560 async fn handle_notification(&mut self, data: Value) -> Result<()> {
569 if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
571 if data.get("result").is_none() && data.get("error").is_none() {
572 self.handle_client_method(method, data.clone()).await?;
573 }
574 }
575
576 Ok(())
577 }
578
579 async fn handle_client_method(&mut self, method: &str, data: Value) -> Result<()> {
589 let params = data.get("params").cloned().unwrap_or(Value::Null);
590 let request_id = data.get("id").and_then(|v| v.as_u64());
591
592 match method {
593 "session/update" => {
594 if let Some(update_obj) = params.get("update").and_then(|v| v.as_object()) {
595 if let Some(session_update) =
596 update_obj.get("sessionUpdate").and_then(|v| v.as_str())
597 {
598 self.handle_session_update(session_update, update_obj, request_id)
599 .await?;
600 }
601 }
602 }
603 "session/request_permission" => {
604 self.handle_permission_request(params, request_id).await?;
606 }
607 _ => {
608 tracing::warn!("Unknown method: {}", method);
609 if let Some(id) = request_id {
611 let error_response = json!({
612 "jsonrpc": "2.0",
613 "id": id,
614 "error": {
615 "code": -32601,
616 "message": "Method not found"
617 }
618 });
619 self.transport.send(&error_response).await?;
620 }
621 }
622 }
623
624 Ok(())
625 }
626
627 async fn handle_permission_request(
637 &mut self,
638 params: Value,
639 request_id: Option<u64>,
640 ) -> Result<()> {
641 let tool_call = params.get("toolCall").unwrap_or(&Value::Null);
643 let options = params.get("options").unwrap_or(&Value::Null);
644 let _session_id = params.get("sessionId").and_then(|v| v.as_str());
645
646 let tool_title = tool_call
648 .get("title")
649 .and_then(|v| v.as_str())
650 .unwrap_or("unknown");
651 let tool_type = tool_call
652 .get("type")
653 .and_then(|v| v.as_str())
654 .unwrap_or("unknown");
655
656 tracing::debug!(
657 "Permission request for tool '{}' (type: {})",
658 tool_title,
659 tool_type
660 );
661
662 let auto_approve = match self.permission_mode {
664 PermissionMode::Auto => {
665 true
667 }
668 PermissionMode::Manual => {
669 false
671 }
672 PermissionMode::Selective => {
673 tool_type == "read" || tool_type == "fetch" || tool_type == "list"
676 }
677 };
678
679 let response = if auto_approve {
680 let mut selected_option = "proceed_once".to_string();
682 if let Some(options_array) = options.as_array() {
683 for option in options_array {
684 if let Some(option_id) = option.get("optionId").and_then(|v| v.as_str()) {
685 if option_id == "proceed_once" {
686 selected_option = option_id.to_string();
687 break;
688 } else if option_id == "proceed_always" {
689 selected_option = option_id.to_string();
690 }
691 }
692 }
693
694 if selected_option == "proceed_once" && !options_array.is_empty() {
696 if let Some(first_option_id) =
697 options_array[0].get("optionId").and_then(|v| v.as_str())
698 {
699 selected_option = first_option_id.to_string();
700 }
701 }
702 }
703
704 json!({
705 "outcome": {
706 "outcome": "selected",
707 "optionId": selected_option
708 }
709 })
710 } else {
711 json!({
713 "outcome": {
714 "outcome": "cancelled"
715 }
716 })
717 };
718
719 if let Some(id) = request_id {
721 let response_message = json!({
722 "jsonrpc": "2.0",
723 "id": id,
724 "result": response
725 });
726 self.transport.send(&response_message).await?;
727 }
728
729 let outcome = response
730 .get("outcome")
731 .and_then(|o| o.get("outcome"))
732 .and_then(|o| o.as_str())
733 .unwrap_or("unknown");
734 tracing::debug!("Permission request for tool '{}': {}", tool_title, outcome);
735 Ok(())
736 }
737
738 async fn handle_session_update(
749 &mut self,
750 update_type: &str,
751 update: &serde_json::Map<String, Value>,
752 request_id: Option<u64>,
753 ) -> Result<()> {
754 match update_type {
755 "agent_message_chunk" => {
756 if let Some(content) = update.get("content") {
757 let text = match content {
758 Value::Object(obj) => {
759 if let Some(text_content) = obj.get("text").and_then(|v| v.as_str()) {
760 text_content.to_string()
761 } else {
762 "<unknown>".to_string()
763 }
764 }
765 _ => "<unknown>".to_string(),
766 };
767
768 let msg = Message::Assistant { content: text };
769 let _ = self.message_sender.send(msg);
770 }
771 }
772 "user_message_chunk" => {
773 if let Some(content) = update.get("content") {
774 let text = match content {
775 Value::Object(obj) => {
776 if let Some(text_content) = obj.get("text").and_then(|v| v.as_str()) {
777 text_content.to_string()
778 } else {
779 "<unknown>".to_string()
780 }
781 }
782 _ => "<unknown>".to_string(),
783 };
784
785 let msg = Message::User { content: text };
786 let _ = self.message_sender.send(msg);
787 }
788 }
789 "tool_call" => {
790 if let Some(tool_call) = update.get("toolCall") {
791 let id = tool_call
792 .get("id")
793 .and_then(|v| v.as_str())
794 .unwrap_or("")
795 .to_string();
796 let name = tool_call
797 .get("title")
798 .and_then(|v| v.as_str())
799 .unwrap_or("Unknown")
800 .to_string();
801 let status = tool_call
802 .get("status")
803 .and_then(|v| v.as_str())
804 .unwrap_or("unknown")
805 .to_string();
806
807 let msg = Message::ToolCall { id, name, status };
808 let _ = self.message_sender.send(msg);
809 }
810 }
811 "plan" => {
812 if let Some(entries) = update.get("entries").and_then(|v| v.as_array()) {
813 let entries: Vec<super::types::PlanEntry> = entries
814 .iter()
815 .filter_map(|entry| {
816 let content =
817 entry.get("content").and_then(|v| v.as_str())?.to_string();
818 let priority_str = entry
819 .get("priority")
820 .and_then(|v| v.as_str())
821 .unwrap_or("medium");
822 let status_str = entry
823 .get("status")
824 .and_then(|v| v.as_str())
825 .unwrap_or("pending");
826
827 let priority = match priority_str {
828 "high" => super::types::PlanPriority::High,
829 "medium" => super::types::PlanPriority::Medium,
830 "low" => super::types::PlanPriority::Low,
831 _ => super::types::PlanPriority::Medium,
832 };
833
834 let status = match status_str {
835 "pending" => super::types::PlanStatus::Pending,
836 "in_progress" => super::types::PlanStatus::InProgress,
837 "completed" => super::types::PlanStatus::Completed,
838 _ => super::types::PlanStatus::Pending,
839 };
840
841 Some(super::types::PlanEntry {
842 content,
843 priority,
844 status,
845 })
846 })
847 .collect();
848
849 let msg = Message::Plan { entries };
850 let _ = self.message_sender.send(msg);
851 }
852 }
853 "tool_call_update" => {
854 if let Some(id) = request_id {
856 let response = json!({
857 "jsonrpc": "2.0",
858 "id": id,
859 "result": null
860 });
861 self.transport.send(&response).await?;
862 }
863 }
864 "agent_thought_chunk" | "current_mode_update" | "available_commands_update" => {
865 }
867 _ => {
868 tracing::debug!("Unhandled session update type: {}", update_type);
869 }
870 }
871
872 if let Some(id) = request_id {
874 match update_type {
875 "tool_call_update" | "notifyTaskFinish" => {
876 let response = json!({
877 "jsonrpc": "2.0",
878 "id": id,
879 "result": null
880 });
881 self.transport.send(&response).await?;
882 }
883 _ => {}
884 }
885 }
886
887 Ok(())
888 }
889
890 pub async fn close(&mut self) -> Result<()> {
892 self.transport.close().await?;
893 Ok(())
894 }
895
896 pub fn is_connected(&self) -> bool {
901 self.transport.is_connected()
902 }
903}