kanata_state_machine/
tcp_server.rs

1use crate::Kanata;
2use crate::oskbd::*;
3
4#[cfg(feature = "tcp_server")]
5use kanata_tcp_protocol::*;
6use parking_lot::Mutex;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::sync::mpsc::SyncSender as Sender;
10
/// Map type used by this module (an alias for `rustc_hash::FxHashMap`).
#[cfg(feature = "tcp_server")]
type HashMap<Key, Value> = rustc_hash::FxHashMap<Key, Value>;
13#[cfg(feature = "tcp_server")]
14use kanata_parser::cfg::SimpleSExpr;
15#[cfg(feature = "tcp_server")]
16use std::io::Write;
17#[cfg(feature = "tcp_server")]
18use std::net::{TcpListener, TcpStream};
19
/// Shared registry of connected TCP clients.
///
/// Each entry maps a client key (the peer address string, or a fallback id when the peer
/// address is unavailable) to a cloned handle of that client's stream.
#[cfg(feature = "tcp_server")]
pub type Connections = Arc<Mutex<HashMap<String, std::net::TcpStream>>>;

/// Placeholder used when the `tcp_server` feature is disabled: no clients are tracked.
#[cfg(not(feature = "tcp_server"))]
pub type Connections = ();
25
26#[cfg(feature = "tcp_server")]
27use kanata_parser::custom_action::FakeKeyAction;
28
/// Writes `response` to a client's `stream`.
///
/// If the write fails, the error is logged and the client registered under `addr` is
/// removed from `connections`.
///
/// Returns `true` when the whole response was written, `false` when the client was dropped.
#[cfg(feature = "tcp_server")]
fn send_response(
    stream: &mut TcpStream,
    response: ServerResponse,
    connections: &Connections,
    addr: &str,
) -> bool {
    let payload = response.as_bytes();
    match stream.write_all(&payload) {
        Ok(()) => true,
        Err(write_err) => {
            log::error!("stream write error: {write_err}");
            connections.lock().remove(addr);
            false
        }
    }
}
43
/// Converts a fake-key action received over the TCP protocol into the parser's
/// `FakeKeyAction`.
///
/// The mapping is one-to-one: every protocol variant maps to the same-named action.
#[cfg(feature = "tcp_server")]
fn to_action(val: FakeKeyActionMessage) -> FakeKeyAction {
    use kanata_parser::custom_action::FakeKeyAction as Action;

    match val {
        FakeKeyActionMessage::Tap => Action::Tap,
        FakeKeyActionMessage::Toggle => Action::Toggle,
        FakeKeyActionMessage::Press => Action::Press,
        FakeKeyActionMessage::Release => Action::Release,
    }
}
53
54#[cfg(feature = "tcp_server")]
55pub struct TcpServer {
56    pub address: SocketAddr,
57    pub connections: Connections,
58    pub wakeup_channel: Sender<KeyEvent>,
59}
60
61#[cfg(not(feature = "tcp_server"))]
62pub struct TcpServer {
63    pub connections: Connections,
64}
65
66impl TcpServer {
67    #[cfg(feature = "tcp_server")]
68    pub fn new(address: SocketAddr, wakeup_channel: Sender<KeyEvent>) -> Self {
69        Self {
70            address,
71            connections: Arc::new(Mutex::new(HashMap::default())),
72            wakeup_channel,
73        }
74    }
75
76    #[cfg(not(feature = "tcp_server"))]
77    pub fn new(_address: SocketAddr, _wakeup_channel: Sender<KeyEvent>) -> Self {
78        Self { connections: () }
79    }
80
81    #[cfg(feature = "tcp_server")]
82    pub fn start(&mut self, kanata: Arc<Mutex<Kanata>>) {
83        use kanata_parser::cfg::FAKE_KEY_ROW;
84
85        use crate::kanata::handle_fakekey_action;
86
87        let listener = TcpListener::bind(self.address).expect("TCP server starts");
88
89        let connections = self.connections.clone();
90        let wakeup_channel = self.wakeup_channel.clone();
91
92        std::thread::spawn(move || {
93            for stream in listener.incoming() {
94                match stream {
95                    Ok(mut stream) => {
96                        {
97                            let k = kanata.lock();
98                            log::info!(
99                                "new client connection, sending initial LayerChange event to inform them of current layer"
100                            );
101                            if let Err(e) = stream.write(
102                                &ServerMessage::LayerChange {
103                                    new: k.layer_info[k.layout.b().current_layer()].name.clone(),
104                                }
105                                .as_bytes(),
106                            ) {
107                                log::warn!("failed to write to stream, dropping it: {e:?}");
108                                continue;
109                            }
110                        }
111
112                        let addr = match stream.peer_addr() {
113                            Ok(addr) => addr.to_string(),
114                            Err(e) => {
115                                log::warn!("failed to get peer address, using fallback: {e:?}");
116                                format!("unknown_{}", std::ptr::addr_of!(stream) as usize)
117                            }
118                        };
119
120                        connections.lock().insert(
121                            addr.clone(),
122                            stream.try_clone().expect("stream is clonable"),
123                        );
124                        let reader = serde_json::Deserializer::from_reader(
125                            stream.try_clone().expect("stream is clonable"),
126                        )
127                        .into_iter::<ClientMessage>();
128
129                        log::info!("listening for incoming messages {addr}");
130
131                        let connections = connections.clone();
132                        let kanata = kanata.clone();
133                        let wakeup_channel = wakeup_channel.clone();
134                        std::thread::spawn(move || {
135                            for v in reader {
136                                match v {
137                                    Ok(event) => {
138                                        log::debug!("tcp server received command: {:?}", event);
139                                        match event {
140                                            ClientMessage::ChangeLayer { new } => {
141                                                kanata.lock().change_layer(new);
142                                            }
143                                            ClientMessage::RequestLayerNames {} => {
144                                                let msg = ServerMessage::LayerNames {
145                                                    names: kanata
146                                                        .lock()
147                                                        .layer_info
148                                                        .iter()
149                                                        .map(|info| info.name.clone())
150                                                        .collect::<Vec<_>>(),
151                                                };
152                                                match stream.write_all(&msg.as_bytes()) {
153                                                    Ok(_) => {}
154                                                    Err(err) => log::error!(
155                                                        "server could not send response: {err}"
156                                                    ),
157                                                }
158                                            }
159                                            ClientMessage::ActOnFakeKey { name, action } => {
160                                                let mut k = kanata.lock();
161                                                let index = match k.virtual_keys.get(&name) {
162                                                    Some(index) => Some(*index as u16),
163                                                    None => {
164                                                        if let Err(e) = stream.write_all(
165                                                            &ServerMessage::Error {
166                                                                msg: format!(
167                                                                "unknown virtual/fake key: {name}"
168                                                            ),
169                                                            }
170                                                            .as_bytes(),
171                                                        ) {
172                                                            log::error!("stream write error: {e}");
173                                                            connections.lock().remove(&addr);
174                                                            break;
175                                                        }
176                                                        continue;
177                                                    }
178                                                };
179                                                if let Some(index) = index {
180                                                    log::info!(
181                                                        "tcp server fake-key action: {name},{action:?}"
182                                                    );
183                                                    handle_fakekey_action(
184                                                        to_action(action),
185                                                        k.layout.bm(),
186                                                        FAKE_KEY_ROW,
187                                                        index,
188                                                    );
189                                                }
190                                                drop(k);
191                                            }
192                                            ClientMessage::SetMouse { x, y } => {
193                                                log::info!(
194                                                    "tcp server SetMouse action: x {x} y {y}"
195                                                );
196                                                match kanata.lock().kbd_out.set_mouse(x, y) {
197                                                    Ok(_) => {
198                                                        log::info!(
199                                                            "sucessfully did set mouse position to: x {x} y {y}"
200                                                        );
201                                                        // Optionally send a success message to the
202                                                        // client
203                                                    }
204                                                    Err(e) => {
205                                                        log::error!(
206                                                            "Failed to set mouse position: {}",
207                                                            e
208                                                        );
209                                                        // Implement any error handling logic here,
210                                                        // such as sending an error response to
211                                                        // the client
212                                                    }
213                                                }
214                                            }
215                                            ClientMessage::RequestCurrentLayerInfo {} => {
216                                                let mut k = kanata.lock();
217                                                let cur_layer = k.layout.bm().current_layer();
218                                                let msg = ServerMessage::CurrentLayerInfo {
219                                                    name: k.layer_info[cur_layer].name.clone(),
220                                                    cfg_text: k.layer_info[cur_layer]
221                                                        .cfg_text
222                                                        .clone(),
223                                                };
224                                                drop(k);
225                                                match stream.write_all(&msg.as_bytes()) {
226                                                    Ok(_) => {}
227                                                    Err(err) => log::error!(
228                                                        "Error writing response to RequestCurrentLayerInfo: {err}"
229                                                    ),
230                                                }
231                                            }
232                                            ClientMessage::RequestCurrentLayerName {} => {
233                                                let mut k = kanata.lock();
234                                                let cur_layer = k.layout.bm().current_layer();
235                                                let msg = ServerMessage::CurrentLayerName {
236                                                    name: k.layer_info[cur_layer].name.clone(),
237                                                };
238                                                drop(k);
239                                                match stream.write_all(&msg.as_bytes()) {
240                                                    Ok(_) => {}
241                                                    Err(err) => log::error!(
242                                                        "Error writing response to RequestCurrentLayerName: {err}"
243                                                    ),
244                                                }
245                                            }
246                                            // Handle reload commands with unified response protocol
247                                            reload_cmd @ (ClientMessage::Reload {}
248                                            | ClientMessage::ReloadNext {}
249                                            | ClientMessage::ReloadPrev {}
250                                            | ClientMessage::ReloadNum { .. }
251                                            | ClientMessage::ReloadFile { .. }) => {
252                                                // Log specific action type
253                                                match &reload_cmd {
254                                                    ClientMessage::Reload {} => {
255                                                        log::info!("tcp server Reload action")
256                                                    }
257                                                    ClientMessage::ReloadNext {} => {
258                                                        log::info!("tcp server ReloadNext action")
259                                                    }
260                                                    ClientMessage::ReloadPrev {} => {
261                                                        log::info!("tcp server ReloadPrev action")
262                                                    }
263                                                    ClientMessage::ReloadNum { index } => {
264                                                        log::info!(
265                                                            "tcp server ReloadNum action: index {index}"
266                                                        )
267                                                    }
268                                                    ClientMessage::ReloadFile { path } => {
269                                                        log::info!(
270                                                            "tcp server ReloadFile action: path {path}"
271                                                        )
272                                                    }
273                                                    _ => unreachable!(),
274                                                }
275
276                                                let response = match kanata
277                                                    .lock()
278                                                    .handle_client_command(reload_cmd)
279                                                {
280                                                    Ok(_) => ServerResponse::Ok,
281                                                    Err(e) => ServerResponse::Error {
282                                                        msg: format!("{e}"),
283                                                    },
284                                                };
285                                                if !send_response(
286                                                    &mut stream,
287                                                    response,
288                                                    &connections,
289                                                    &addr,
290                                                ) {
291                                                    break;
292                                                }
293                                            }
294                                        }
295                                        use kanata_parser::keys::*;
296                                        wakeup_channel
297                                            .send(KeyEvent {
298                                                code: OsCode::KEY_RESERVED,
299                                                value: KeyValue::WakeUp,
300                                            })
301                                            .expect("write key event");
302                                    }
303                                    Err(e) => {
304                                        log::warn!(
305                                            "client sent an invalid message, disconnecting them. Err: {e:?}"
306                                        );
307                                        // Send proper error response for malformed JSON
308                                        let response = ServerResponse::Error {
309                                            msg: format!("Failed to deserialize command: {e}"),
310                                        };
311                                        let _ = stream.write_all(&response.as_bytes());
312                                        connections.lock().remove(&addr);
313                                        break;
314                                    }
315                                }
316                            }
317                        });
318                    }
319                    Err(_) => log::error!("not able to accept client connection"),
320                }
321            }
322        });
323    }
324
325    #[cfg(not(feature = "tcp_server"))]
326    pub fn start(&mut self, _kanata: Arc<Mutex<Kanata>>) {}
327}
328
/// Converts a slice of simple s-expressions into a JSON array.
///
/// Atoms become JSON strings. Nested lists are converted recursively into nested JSON
/// arrays, so the output mirrors the input's nesting exactly.
#[cfg(feature = "tcp_server")]
pub fn simple_sexpr_to_json_array(exprs: &[SimpleSExpr]) -> serde_json::Value {
    let converted = exprs
        .iter()
        .map(|expr| match expr {
            SimpleSExpr::Atom(atom) => serde_json::Value::String(atom.clone()),
            SimpleSExpr::List(items) => simple_sexpr_to_json_array(items),
        })
        .collect::<Vec<_>>();
    serde_json::Value::Array(converted)
}