1use std::io;
14use std::sync::atomic::AtomicBool;
15use std::sync::Arc;
16
17use crate::server::ipc::{ClientConnection, SocketPaths};
18use crate::server::protocol::{
19 ClientControl, ClientHello, ServerControl, TermSize, PROTOCOL_VERSION,
20};
21
22#[cfg(unix)]
23mod relay_unix;
24#[cfg(windows)]
25mod relay_windows;
26
27pub struct ClientConfig {
29 pub socket_paths: SocketPaths,
31 pub term_size: TermSize,
33}
34
/// Outcome reported when a client session stops.
#[derive(Debug)]
pub enum ClientExitReason {
    /// The server side ended the session. Not constructed in this file;
    /// presumably returned by the platform relay loop (confirm in
    /// `relay_unix` / `relay_windows`).
    ServerQuit,
    /// The client detached from the session. Not constructed in this file;
    /// presumably returned by the platform relay loop.
    Detached,
    /// The handshake found an incompatible protocol. This is produced either
    /// when the server's hello carries a different protocol version or when
    /// the server answers with an explicit version-mismatch message.
    VersionMismatch {
        /// Version string the server reported about itself.
        server_version: String,
    },
    /// The session ended because of an I/O failure. Not constructed in this
    /// file; presumably reported by the relay loop.
    Error(io::Error),
}
47
48pub fn run_client(config: ClientConfig) -> io::Result<ClientExitReason> {
58 let conn = ClientConnection::connect(&config.socket_paths)?;
59 run_client_with_connection(config, conn)
60}
61
62pub fn run_client_with_connection(
67 config: ClientConfig,
68 conn: ClientConnection,
69) -> io::Result<ClientExitReason> {
70 let hello = ClientHello::new(config.term_size);
72 let hello_json = serde_json::to_string(&ClientControl::Hello(hello))
73 .map_err(|e| io::Error::other(e.to_string()))?;
74 conn.write_control(&hello_json)?;
75
76 let response = conn
78 .read_control()?
79 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "Server closed connection"))?;
80
81 let server_msg: ServerControl =
82 serde_json::from_str(&response).map_err(|e| io::Error::other(e.to_string()))?;
83
84 match server_msg {
85 ServerControl::Hello(server_hello) => {
86 if server_hello.protocol_version != PROTOCOL_VERSION {
87 return Ok(ClientExitReason::VersionMismatch {
88 server_version: server_hello.server_version,
89 });
90 }
91 tracing::info!(
92 "Connected to session '{}' (server {})",
93 server_hello.session_id,
94 server_hello.server_version
95 );
96 }
97 ServerControl::VersionMismatch(mismatch) => {
98 return Ok(ClientExitReason::VersionMismatch {
99 server_version: mismatch.server_version,
100 });
101 }
102 ServerControl::Error { message } => {
103 return Err(io::Error::other(format!("Server error: {}", message)));
104 }
105 _ => {
106 return Err(io::Error::other("Unexpected server response"));
107 }
108 }
109
110 run_client_relay(conn)
111}
112
113pub fn run_client_relay(
118 #[allow(unused_mut)] mut conn: ClientConnection,
119) -> io::Result<ClientExitReason> {
120 #[cfg(not(windows))]
124 conn.set_data_nonblocking(true)?;
125
126 let resize_flag = Arc::new(AtomicBool::new(false));
128 #[cfg(unix)]
129 relay_unix::setup_resize_handler(resize_flag.clone())?;
130
131 #[cfg(unix)]
133 return relay_unix::relay_loop(&mut conn, resize_flag);
134
135 #[cfg(windows)]
136 return relay_windows::relay_loop(&mut conn);
137}
138
139pub fn get_terminal_size() -> io::Result<TermSize> {
141 #[cfg(unix)]
142 {
143 let mut size: libc::winsize = unsafe { std::mem::zeroed() };
144 let result = unsafe { libc::ioctl(libc::STDOUT_FILENO, libc::TIOCGWINSZ, &mut size) };
145 if result == -1 {
146 return Err(io::Error::last_os_error());
147 }
148 Ok(TermSize::new(size.ws_col, size.ws_row))
149 }
150
151 #[cfg(windows)]
152 {
153 use windows_sys::Win32::System::Console::{
154 GetConsoleScreenBufferInfo, GetStdHandle, CONSOLE_SCREEN_BUFFER_INFO, STD_OUTPUT_HANDLE,
155 };
156
157 unsafe {
158 let handle = GetStdHandle(STD_OUTPUT_HANDLE);
159 let mut info: CONSOLE_SCREEN_BUFFER_INFO = std::mem::zeroed();
160 if GetConsoleScreenBufferInfo(handle, &mut info) == 0 {
161 return Err(io::Error::last_os_error());
162 }
163 let cols = (info.srWindow.Right - info.srWindow.Left + 1) as u16;
164 let rows = (info.srWindow.Bottom - info.srWindow.Top + 1) as u16;
165 Ok(TermSize::new(cols, rows))
166 }
167 }
168}