1use crate::acp_protocol::ACPProtocol;
7use crate::error::{IFlowError, Result};
8use crate::logger::MessageLogger;
9use crate::process_manager::IFlowProcessManager;
10use crate::types::*;
11use crate::websocket_transport::WebSocketTransport;
12use agent_client_protocol::{
13 Agent, Client, ClientSideConnection, ContentBlock, SessionId, SessionUpdate,
14};
15use futures::{FutureExt, pin_mut, stream::Stream};
16use serde_json;
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use tokio::sync::{Mutex, mpsc};
24use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
25use tracing::debug;
26
/// Transport-specific state of an established iFlow connection.
enum Connection {
    /// ACP spoken over the stdin/stdout pipes of an iFlow process.
    Stdio {
        /// ACP client-side connection built on the process pipes.
        acp_client: ClientSideConnection,
        /// Manager of the iFlow process whose pipes are in use.
        process_manager: Option<IFlowProcessManager>,
        /// ACP session id; `None` until the first message creates a session.
        session_id: Option<SessionId>,
        /// Set to `true` once the ACP `initialize` request has succeeded.
        initialized: bool,
    },
    /// ACP spoken over a WebSocket transport.
    WebSocket {
        /// Protocol driver that owns the WebSocket transport.
        acp_protocol: ACPProtocol,
        /// Session id; `None` until the first message creates a session.
        session_id: Option<String>,
        /// Present only when this client started the iFlow process itself.
        process_manager: Option<IFlowProcessManager>,
    },
}
43
/// Client for talking to iFlow over either stdio or WebSocket.
///
/// Incoming updates are delivered through an internal unbounded channel and
/// can be consumed with `messages()` or `receive_message()`.
pub struct IFlowClient {
    /// Configuration supplied at construction time.
    options: IFlowOptions,
    /// Receiving half of the message channel, shared with every `MessageStream`.
    message_receiver: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
    /// Sending half of the message channel; cloned into the transport handlers.
    message_sender: mpsc::UnboundedSender<Message>,
    /// `true` between a successful `connect()` and `disconnect()` (or drop).
    connected: Arc<Mutex<bool>>,
    /// Active transport state; `None` while disconnected.
    connection: Option<Connection>,
    /// Optional message logger, created when logging is enabled in `options`.
    logger: Option<MessageLogger>,
}
56
/// `Stream` adapter over the client's message channel.
///
/// Every stream shares the same underlying receiver, so each message is
/// delivered to exactly one consumer.
pub struct MessageStream {
    /// Shared receiving half of the client's message channel.
    receiver: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
}
64
65impl Stream for MessageStream {
66 type Item = Message;
67
68 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 let mut receiver = match self.receiver.try_lock() {
70 Ok(guard) => guard,
71 Err(_) => {
72 cx.waker().wake_by_ref();
73 return Poll::Pending;
74 }
75 };
76
77 match receiver.try_recv() {
79 Ok(msg) => Poll::Ready(Some(msg)),
80 Err(mpsc::error::TryRecvError::Empty) => {
81 let recv_future = receiver.recv();
83 pin_mut!(recv_future);
84 match recv_future.poll_unpin(cx) {
85 Poll::Ready(msg) => Poll::Ready(msg),
86 Poll::Pending => Poll::Pending,
87 }
88 }
89 Err(mpsc::error::TryRecvError::Disconnected) => Poll::Ready(None),
90 }
91 }
92}
93
/// Handler for agent-to-client ACP calls on the stdio transport.
struct IFlowClientHandler {
    /// Channel on which converted session updates are delivered to the SDK user.
    message_sender: mpsc::UnboundedSender<Message>,
    /// Optional logger that receives every forwarded message.
    logger: Option<MessageLogger>,
}
99
100#[async_trait::async_trait(?Send)]
101impl Client for IFlowClientHandler {
102 async fn request_permission(
103 &self,
104 _args: agent_client_protocol::RequestPermissionRequest,
105 ) -> anyhow::Result<
106 agent_client_protocol::RequestPermissionResponse,
107 agent_client_protocol::Error,
108 > {
109 Ok(agent_client_protocol::RequestPermissionResponse {
111 outcome: agent_client_protocol::RequestPermissionOutcome::Cancelled,
112 meta: None,
113 })
114 }
115
116 async fn write_text_file(
117 &self,
118 _args: agent_client_protocol::WriteTextFileRequest,
119 ) -> anyhow::Result<agent_client_protocol::WriteTextFileResponse, agent_client_protocol::Error>
120 {
121 Err(agent_client_protocol::Error::method_not_found())
122 }
123
124 async fn read_text_file(
125 &self,
126 _args: agent_client_protocol::ReadTextFileRequest,
127 ) -> anyhow::Result<agent_client_protocol::ReadTextFileResponse, agent_client_protocol::Error>
128 {
129 Err(agent_client_protocol::Error::method_not_found())
130 }
131
132 async fn create_terminal(
133 &self,
134 _args: agent_client_protocol::CreateTerminalRequest,
135 ) -> anyhow::Result<agent_client_protocol::CreateTerminalResponse, agent_client_protocol::Error>
136 {
137 Err(agent_client_protocol::Error::method_not_found())
138 }
139
140 async fn terminal_output(
141 &self,
142 _args: agent_client_protocol::TerminalOutputRequest,
143 ) -> anyhow::Result<agent_client_protocol::TerminalOutputResponse, agent_client_protocol::Error>
144 {
145 Err(agent_client_protocol::Error::method_not_found())
146 }
147
148 async fn release_terminal(
149 &self,
150 _args: agent_client_protocol::ReleaseTerminalRequest,
151 ) -> anyhow::Result<agent_client_protocol::ReleaseTerminalResponse, agent_client_protocol::Error>
152 {
153 Err(agent_client_protocol::Error::method_not_found())
154 }
155
156 async fn wait_for_terminal_exit(
157 &self,
158 _args: agent_client_protocol::WaitForTerminalExitRequest,
159 ) -> anyhow::Result<
160 agent_client_protocol::WaitForTerminalExitResponse,
161 agent_client_protocol::Error,
162 > {
163 Err(agent_client_protocol::Error::method_not_found())
164 }
165
166 async fn kill_terminal_command(
167 &self,
168 _args: agent_client_protocol::KillTerminalCommandRequest,
169 ) -> anyhow::Result<
170 agent_client_protocol::KillTerminalCommandResponse,
171 agent_client_protocol::Error,
172 > {
173 Err(agent_client_protocol::Error::method_not_found())
174 }
175
176 async fn session_notification(
177 &self,
178 args: agent_client_protocol::SessionNotification,
179 ) -> anyhow::Result<(), agent_client_protocol::Error> {
180 match args.update {
181 SessionUpdate::AgentMessageChunk { content } => {
182 let text = match content {
183 ContentBlock::Text(text_content) => text_content.text,
184 ContentBlock::Image(_) => "<image>".into(),
185 ContentBlock::Audio(_) => "<audio>".into(),
186 ContentBlock::ResourceLink(resource_link) => resource_link.uri,
187 ContentBlock::Resource(_) => "<resource>".into(),
188 };
189 let msg = Message::Assistant { content: text };
190 let _ = self.message_sender.send(msg.clone());
191
192 if let Some(logger) = &self.logger {
194 let _ = logger.log_message(&msg).await;
195 }
196 }
197 SessionUpdate::UserMessageChunk { content } => {
198 let text = match content {
199 ContentBlock::Text(text_content) => text_content.text,
200 ContentBlock::Image(_) => "<image>".into(),
201 ContentBlock::Audio(_) => "<audio>".into(),
202 ContentBlock::ResourceLink(resource_link) => resource_link.uri,
203 ContentBlock::Resource(_) => "<resource>".into(),
204 };
205 let msg = Message::User { content: text };
206 let _ = self.message_sender.send(msg.clone());
207
208 if let Some(logger) = &self.logger {
210 let _ = logger.log_message(&msg).await;
211 }
212 }
213 SessionUpdate::ToolCall(tool_call) => {
214 let msg = Message::ToolCall {
215 id: tool_call.id.0.to_string(),
216 name: tool_call.title.clone(),
217 status: format!("{:?}", tool_call.status),
218 };
219 let _ = self.message_sender.send(msg.clone());
220
221 if let Some(logger) = &self.logger {
223 let _ = logger.log_message(&msg).await;
224 }
225 }
226 SessionUpdate::Plan(plan) => {
227 let entries = plan
228 .entries
229 .into_iter()
230 .map(|entry| {
231 super::types::PlanEntry {
233 content: entry.content,
234 priority: match entry.priority {
235 agent_client_protocol::PlanEntryPriority::High => {
236 super::types::PlanPriority::High
237 }
238 agent_client_protocol::PlanEntryPriority::Medium => {
239 super::types::PlanPriority::Medium
240 }
241 agent_client_protocol::PlanEntryPriority::Low => {
242 super::types::PlanPriority::Low
243 }
244 },
245 status: match entry.status {
246 agent_client_protocol::PlanEntryStatus::Pending => {
247 super::types::PlanStatus::Pending
248 }
249 agent_client_protocol::PlanEntryStatus::InProgress => {
250 super::types::PlanStatus::InProgress
251 }
252 agent_client_protocol::PlanEntryStatus::Completed => {
253 super::types::PlanStatus::Completed
254 }
255 },
256 }
257 })
258 .collect();
259
260 let msg = Message::Plan { entries };
261 let _ = self.message_sender.send(msg.clone());
262
263 if let Some(logger) = &self.logger {
265 let _ = logger.log_message(&msg).await;
266 }
267 }
268 SessionUpdate::AgentThoughtChunk { .. }
269 | SessionUpdate::ToolCallUpdate(_)
270 | SessionUpdate::CurrentModeUpdate { .. }
271 | SessionUpdate::AvailableCommandsUpdate { .. } => {
272 }
274 }
275 Ok(())
276 }
277
278 async fn ext_method(
279 &self,
280 _args: agent_client_protocol::ExtRequest,
281 ) -> anyhow::Result<agent_client_protocol::ExtResponse, agent_client_protocol::Error> {
282 Err(agent_client_protocol::Error::method_not_found())
283 }
284
285 async fn ext_notification(
286 &self,
287 _args: agent_client_protocol::ExtNotification,
288 ) -> anyhow::Result<(), agent_client_protocol::Error> {
289 Err(agent_client_protocol::Error::method_not_found())
290 }
291}
292
293impl IFlowClient {
294 pub fn new(options: Option<IFlowOptions>) -> Self {
302 let options = options.unwrap_or_default();
303 let (sender, receiver) = mpsc::unbounded_channel();
304
305 let logger = if options.logging.enabled {
307 MessageLogger::new(options.logging.logger_config.clone()).ok()
308 } else {
309 None
310 };
311
312 Self {
313 options,
314 message_receiver: Arc::new(Mutex::new(receiver)),
315 message_sender: sender,
316 connected: Arc::new(Mutex::new(false)),
317 connection: None,
318 logger,
319 }
320 }
321
322 pub async fn connect(&mut self) -> Result<()> {
331 if *self.connected.lock().await {
332 tracing::warn!("Already connected to iFlow");
333 return Ok(());
334 }
335
336 if self.options.websocket.is_some() {
338 self.connect_websocket().await
339 } else {
340 self.connect_stdio().await
341 }
342 }
343
344 async fn connect_stdio(&mut self) -> Result<()> {
346 debug!("Connecting to iFlow via stdio");
347
348 let mut process_manager = if self.options.process.auto_start {
350 let port = self.options.process.start_port.unwrap_or(8090);
352 let mut pm = IFlowProcessManager::new(port, self.options.process.debug);
353 let _url = pm.start(false).await?; debug!("iFlow process started");
355 Some(pm)
356 } else {
357 None
358 };
359
360 let stdin = process_manager
362 .as_mut()
363 .and_then(|pm| pm.take_stdin())
364 .ok_or_else(|| IFlowError::Connection("Failed to get stdin".to_string()))?;
365
366 let stdout = process_manager
367 .as_mut()
368 .and_then(|pm| pm.take_stdout())
369 .ok_or_else(|| IFlowError::Connection("Failed to get stdout".to_string()))?;
370
371 let handler = IFlowClientHandler {
373 message_sender: self.message_sender.clone(),
374 logger: self.logger.clone(),
375 };
376
377 let (conn, handle_io) =
378 ClientSideConnection::new(handler, stdin.compat_write(), stdout.compat(), |fut| {
379 tokio::task::spawn_local(fut);
380 });
381
382 tokio::task::spawn_local(handle_io);
384
385 self.connection = Some(Connection::Stdio {
387 acp_client: conn,
388 process_manager,
389 session_id: None,
390 initialized: false,
391 });
392
393 *self.connected.lock().await = true;
394 debug!("Connected to iFlow via stdio");
395
396 Ok(())
397 }
398
399 async fn connect_websocket(&mut self) -> Result<()> {
401 debug!("Connecting to iFlow via WebSocket");
402
403 let websocket_config = self.options.websocket.as_ref().ok_or_else(|| {
404 IFlowError::Connection("WebSocket configuration not provided".to_string())
405 })?;
406
407 let mut process_manager_to_keep: Option<IFlowProcessManager> = None;
409
410 let final_url = if self.options.process.auto_start {
413 if let Some(url) = &websocket_config.url {
414 if url.starts_with("ws://localhost:") {
416 debug!(
417 "iFlow auto-start enabled with provided URL, checking if iFlow is already running..."
418 );
419
420 let mut test_transport =
422 WebSocketTransport::new(url.clone(), self.options.timeout);
423 match test_transport.connect().await {
424 Ok(_) => {
425 let _ = test_transport.close().await;
427 debug!("Connected to existing iFlow process at {}", url);
428 url.clone()
429 }
430 Err(e) => {
431 let port = url
434 .split(':')
435 .nth(2)
436 .and_then(|port_str| port_str.split('/').next())
437 .and_then(|port_str| port_str.parse::<u16>().ok())
438 .unwrap_or(8090);
439
440 if IFlowProcessManager::is_port_listening(port) {
442 debug!(
448 "iFlow appears to be running on port {}, but connection failed: {}",
449 port, e
450 );
451 debug!(
452 "Since iFlow is running on the specified port, we won't start a new process. Please check if the existing iFlow instance is configured correctly for WebSocket connections."
453 );
454 return Err(IFlowError::Connection(format!(
455 "Failed to connect to existing iFlow process at {}: {}. iFlow appears to be running on port {}, but connection could not be established.",
456 url, e, port
457 )));
458 } else {
459 debug!("iFlow not running on port {}, starting process", port);
461 let mut pm =
462 IFlowProcessManager::new(port, self.options.process.debug);
463 let iflow_url = pm.start(true).await?.ok_or_else(|| {
464 IFlowError::Connection(
465 "Failed to start iFlow with WebSocket".to_string(),
466 )
467 })?;
468 debug!("Started iFlow process at {}", iflow_url);
469
470 process_manager_to_keep = Some(pm);
472
473 iflow_url
474 }
475 }
476 }
477 } else {
478 debug!("Using manual start mode or non-local WebSocket URL");
480 url.clone()
481 }
482 } else {
483 debug!("iFlow auto-start enabled with auto-generated URL...");
485 let port = self.options.process.start_port.unwrap_or(8090);
486 let mut pm = IFlowProcessManager::new(port, self.options.process.debug);
487 let iflow_url = pm.start(true).await?.ok_or_else(|| {
488 IFlowError::Connection("Failed to start iFlow with WebSocket".to_string())
489 })?;
490 debug!("Started iFlow process at {}", iflow_url);
491
492 process_manager_to_keep = Some(pm);
494
495 iflow_url
496 }
497 } else {
498 let url = websocket_config.url.as_ref().ok_or_else(|| {
500 IFlowError::Connection(
501 "WebSocket URL must be provided in manual start mode".to_string(),
502 )
503 })?;
504 debug!("Using manual start mode with WebSocket URL: {}", url);
505 url.clone()
506 };
507
508 let mut transport = WebSocketTransport::new(final_url.clone(), self.options.timeout);
510
511 let mut connect_attempts = 0;
513
514 while connect_attempts < websocket_config.reconnect_attempts {
515 match transport.connect().await {
516 Ok(_) => {
517 debug!("Successfully connected to WebSocket at {}", final_url);
518 break;
519 }
520 Err(e) => {
521 connect_attempts += 1;
522 tracing::warn!(
523 "Failed to connect to WebSocket (attempt {}): {}",
524 connect_attempts,
525 e
526 );
527
528 if connect_attempts >= websocket_config.reconnect_attempts {
529 return Err(IFlowError::Connection(format!(
530 "Failed to connect to WebSocket after {} attempts: {}",
531 websocket_config.reconnect_attempts, e
532 )));
533 }
534
535 tracing::debug!(
537 "Waiting {:?} before retry...",
538 websocket_config.reconnect_interval
539 );
540 tokio::time::sleep(websocket_config.reconnect_interval).await;
541 }
542 }
543 }
544
545 let mut acp_protocol =
547 ACPProtocol::new(transport, self.message_sender.clone(), self.options.timeout);
548 acp_protocol.set_permission_mode(self.options.permission_mode);
549
550 self.connection = Some(Connection::WebSocket {
552 acp_protocol,
553 session_id: None,
554 process_manager: process_manager_to_keep,
555 });
556
557 *self.connected.lock().await = true;
558 debug!("Connected to iFlow via WebSocket");
559
560 Ok(())
561 }
562
563 pub async fn send_message(&mut self, text: &str, _files: Option<Vec<&Path>>) -> Result<()> {
577 if !*self.connected.lock().await {
578 return Err(IFlowError::NotConnected);
579 }
580
581 let is_websocket = matches!(self.connection, Some(Connection::WebSocket { .. }));
582
583 if is_websocket {
584 if let Some(Connection::WebSocket {
586 mut acp_protocol,
587 mut session_id,
588 process_manager,
589 }) = self.connection.take()
590 {
591 let pm = process_manager;
592 let result = self
593 .send_message_websocket(&mut acp_protocol, &mut session_id, text)
594 .await;
595 self.connection = Some(Connection::WebSocket {
596 acp_protocol,
597 session_id,
598 process_manager: pm,
599 });
600 result
601 } else {
602 Err(IFlowError::NotConnected)
603 }
604 } else {
605 if let Some(Connection::Stdio {
607 acp_client,
608 process_manager,
609 mut session_id,
610 mut initialized,
611 }) = self.connection.take()
612 {
613 let result = self
614 .send_message_stdio(&acp_client, &mut session_id, &mut initialized, text)
615 .await;
616 self.connection = Some(Connection::Stdio {
617 acp_client,
618 process_manager,
619 session_id,
620 initialized,
621 });
622 result
623 } else {
624 Err(IFlowError::NotConnected)
625 }
626 }
627 }
628
629 async fn send_message_stdio(
631 &self,
632 client: &ClientSideConnection,
633 session_id: &mut Option<SessionId>,
634 initialized: &mut bool,
635 text: &str,
636 ) -> Result<()> {
637 tracing::debug!("send_message_stdio called with text: {}", text);
638
639 if !*initialized {
641 tracing::debug!("Initializing connection...");
642 client
643 .initialize(agent_client_protocol::InitializeRequest {
644 protocol_version: agent_client_protocol::V1,
645 client_capabilities: agent_client_protocol::ClientCapabilities::default(),
646 meta: None,
647 })
648 .await
649 .map_err(|e| IFlowError::Connection(format!("Failed to initialize: {}", e)))?;
650
651 *initialized = true;
652 debug!("Initialized stdio connection");
653 }
654
655 if session_id.is_none() {
657 tracing::debug!("Creating new session...");
658 let session_request = agent_client_protocol::NewSessionRequest {
659 mcp_servers: self.options.mcp_servers.clone(),
660 cwd: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
661 meta: None,
662 };
663 tracing::debug!("Session request: {:?}", session_request);
664
665 let session_response = client.new_session(session_request).await.map_err(|e| {
666 tracing::error!("Failed to create session: {}", e);
667 IFlowError::Connection(format!("Failed to create session: {}", e))
668 })?;
669
670 *session_id = Some(session_response.session_id);
671 debug!("Created new session: {:?}", session_id);
672 }
673
674 let current_session_id = session_id.as_ref().unwrap();
676
677 tracing::debug!("Sending prompt to session: {:?}", current_session_id);
679 let prompt_response = client
680 .prompt(agent_client_protocol::PromptRequest {
681 session_id: current_session_id.clone(),
682 prompt: vec![agent_client_protocol::ContentBlock::Text(
683 agent_client_protocol::TextContent {
684 text: text.to_string(),
685 annotations: None,
686 meta: None,
687 },
688 )],
689 meta: None,
690 })
691 .await
692 .map_err(|e| {
693 tracing::error!("Failed to send message: {}", e);
694 IFlowError::Connection(format!("Failed to send message: {}", e))
695 })?;
696
697 tracing::debug!(
698 "Prompt response received, stop reason: {:?}",
699 prompt_response.stop_reason
700 );
701
702 let message = Message::TaskFinish {
704 reason: Some(format!("{:?}", prompt_response.stop_reason)),
705 };
706
707 self.message_sender.send(message).map_err(|e| {
708 tracing::error!("Failed to send task finish message: {}", e);
709 IFlowError::Connection("Message channel closed".to_string())
710 })?;
711
712 debug!("Sent message to iFlow via stdio: {}", text);
713 Ok(())
714 }
715
716 async fn send_message_websocket(
718 &mut self,
719 protocol: &mut ACPProtocol,
720 session_id: &mut Option<String>,
721 text: &str,
722 ) -> Result<()> {
723 if !protocol.is_initialized() {
725 tracing::debug!("Initializing WebSocket protocol...");
726 protocol.initialize(&self.options).await.map_err(|e| {
727 tracing::error!("Failed to initialize protocol: {}", e);
728 e
729 })?;
730
731 if !protocol.is_authenticated() {
733 tracing::debug!("Authenticating...");
734 if let Some(method_id) = &self.options.auth_method_id {
735 protocol.authenticate(method_id, None).await.map_err(|e| {
736 tracing::error!("Authentication failed with method {}: {}", method_id, e);
737 e
738 })?;
739 } else {
740 protocol.authenticate("iflow", None).await.map_err(|e| {
742 tracing::error!("Default authentication failed: {}", e);
743 e
744 })?;
745 }
746 }
747
748 tracing::debug!("Creating new session...");
750 let current_dir = std::env::current_dir()
751 .unwrap_or_else(|_| std::path::PathBuf::from("."))
752 .to_string_lossy()
753 .to_string();
754
755 let mcp_servers: Vec<serde_json::Value> = self
757 .options
758 .mcp_servers
759 .iter()
760 .map(|server| {
761 serde_json::json!(server)
764 })
765 .collect();
766
767 let new_session_id = protocol
768 .create_session(¤t_dir, mcp_servers)
769 .await
770 .map_err(|e| {
771 tracing::error!("Failed to create session: {}", e);
772 e
773 })?;
774 *session_id = Some(new_session_id);
775 tracing::debug!("Session created successfully");
776 }
777
778 let current_session_id = session_id
780 .as_ref()
781 .ok_or_else(|| IFlowError::Connection("No session available".to_string()))?;
782
783 tracing::debug!("Sending prompt to session: {}", current_session_id);
785 let _request_id = protocol
786 .send_prompt(current_session_id, text)
787 .await
788 .map_err(|e| {
789 tracing::error!("Failed to send prompt: {}", e);
790 e
791 })?;
792
793 debug!("Sent message to iFlow: {}", text);
794 Ok(())
795 }
796
797 pub async fn interrupt(&self) -> Result<()> {
806 if !*self.connected.lock().await {
807 return Err(IFlowError::NotConnected);
808 }
809
810 let message = Message::TaskFinish {
811 reason: Some("interrupted".to_string()),
812 };
813
814 self.message_sender
815 .send(message)
816 .map_err(|_| IFlowError::Connection("Message channel closed".to_string()))?;
817 Ok(())
818 }
819
820 pub fn messages(&self) -> MessageStream {
827 MessageStream {
828 receiver: self.message_receiver.clone(),
829 }
830 }
831
832 pub async fn receive_message(&self) -> Result<Option<Message>> {
841 let mut receiver = self.message_receiver.lock().await;
842 Ok(receiver.recv().await)
843 }
844
845 pub async fn disconnect(&mut self) -> Result<()> {
854 *self.connected.lock().await = false;
855
856 if let Some(connection) = self.connection.take() {
858 match connection {
859 Connection::Stdio {
860 acp_client,
861 mut process_manager,
862 session_id: _,
863 initialized: _,
864 } => {
865 drop(acp_client);
867
868 if let Some(mut pm) = process_manager.take() {
870 pm.stop().await?;
871 }
872
873 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
875 }
876 Connection::WebSocket {
877 mut acp_protocol,
878 mut process_manager,
879 session_id: _,
880 } => {
881 let _ = acp_protocol.close().await;
882 if let Some(mut pm) = process_manager.take() {
884 pm.stop().await?;
885 }
886 }
887 }
888 }
889
890 debug!("Disconnected from iFlow");
891 Ok(())
892 }
893}
894
895impl Drop for IFlowClient {
896 fn drop(&mut self) {
897 if let Ok(mut connected) = self.connected.try_lock() {
899 *connected = false;
900 }
901 }
902}
903
904pub use serde_json::Value as JsonValue;