// ratel_rust/shell.rs

1use anyhow::{Context, Result};
2use std::sync::Arc;
3use tokio::io::AsyncWriteExt;
4use tokio::sync::Mutex;
5
6use crate::model::AuthInfo;
7use crate::network::{Connection, NetType, TcpConnection, WebSocketConnection};
8use crate::util::CLEAN_LINE;
9use tokio::sync::mpsc;
10
/// Server message that switches the shell into interactive mode: the prompt is drawn and
/// stdin lines start being forwarded to the server.
const IS_START: &str = "INTERACTIVE_SIGNAL_START";
/// Server message that leaves interactive mode: the prompt line is cleared and stdin lines
/// are no longer forwarded.
const IS_STOP: &str = "INTERACTIVE_SIGNAL_STOP";
13
14pub enum ShellConnection {
15    Tcp(TcpConnection),
16    WebSocket(WebSocketConnection),
17}
18
19impl ShellConnection {
20    async fn send(&mut self, data: &[u8]) -> Result<()> {
21        match self {
22            ShellConnection::Tcp(conn) => conn.send(data).await,
23            ShellConnection::WebSocket(conn) => conn.send(data).await,
24        }
25    }
26
27    async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
28        match self {
29            ShellConnection::Tcp(conn) => conn.receive().await,
30            ShellConnection::WebSocket(conn) => conn.receive().await,
31        }
32    }
33}
34
35pub struct Shell {
36    addr: String,
37    name: String,
38    id: i64,
39    score: i64,
40    conn: Option<ShellConnection>,
41    is_active: Arc<Mutex<bool>>,
42}
43
44impl Shell {
45    pub fn new(addr: String, name: String) -> Self {
46        Self {
47            addr,
48            name,
49            id: std::time::SystemTime::now()
50                .duration_since(std::time::UNIX_EPOCH)
51                .unwrap()
52                .as_nanos() as i64,
53            score: 100,
54            conn: None,
55            is_active: Arc::new(Mutex::new(false)),
56        }
57    }
58
59    fn detect_net_type(&self) -> NetType {
60        if self.addr.ends_with("9998") {
61            NetType::WebSocket
62        } else {
63            NetType::Tcp
64        }
65    }
66
67    async fn connect(&mut self) -> Result<()> {
68        let net_type = self.detect_net_type();
69
70        self.conn = Some(match net_type {
71            NetType::Tcp => {
72                let conn = TcpConnection::connect(&self.addr).await?;
73                ShellConnection::Tcp(conn)
74            }
75            NetType::WebSocket => {
76                let conn = WebSocketConnection::connect(&self.addr).await?;
77                ShellConnection::WebSocket(conn)
78            }
79        });
80
81        Ok(())
82    }
83
84    async fn auth(&mut self) -> Result<()> {
85        let auth_info = AuthInfo {
86            id: self.id,
87            name: self.name.clone(),
88            score: self.score,
89        };
90
91        // Send auth info directly as JSON with 4-byte big-endian length prefix
92        let json_data = serde_json::to_vec(&auth_info)?;
93        let len_bytes = (json_data.len() as u32).to_be_bytes();
94
95        let mut data = Vec::new();
96        data.extend_from_slice(&len_bytes);
97        data.extend_from_slice(&json_data);
98
99        if let Some(conn) = self.conn.as_mut() {
100            conn.send(&data).await?;
101        }
102
103        Ok(())
104    }
105
106    fn get_prompt(&self) -> String {
107        format!(
108            "{}[{}@ratel {}]# ",
109            CLEAN_LINE,
110            self.name.to_lowercase(),
111            "~"
112        )
113    }
114
115    async fn print(text: &str) {
116        print!("{}", text);
117        use std::io::{self, Write};
118        io::stdout().flush().ok();
119    }
120
121    pub async fn start(&mut self) -> Result<()> {
122        self.connect().await.context("Failed to connect")?;
123        self.auth().await.context("Failed to authenticate")?;
124
125        let is_active = self.is_active.clone();
126        let conn = Arc::new(Mutex::new(self.conn.take()));
127        let prompt = self.get_prompt();
128
129        // Create channel for stdin -> network communication
130        let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
131
132        let is_active_stdin = is_active.clone();
133        let prompt_stdin = prompt.clone();
134
135        // Spawn stdin handler
136        tokio::spawn(async move {
137            use tokio::io::{AsyncBufReadExt, BufReader};
138
139            let stdin = tokio::io::stdin();
140            let mut reader = BufReader::new(stdin);
141
142            loop {
143                let mut line = String::new();
144                match reader.read_line(&mut line).await {
145                    Ok(0) => break,
146                    Ok(_) => {
147                        let line = line.trim();
148                        let active = *is_active_stdin.lock().await;
149
150                        if active {
151                            print!("{}{}", CLEAN_LINE, prompt_stdin);
152                            tokio::io::stdout().flush().await.ok();
153
154                            // Create packet with 4-byte big-endian length prefix
155                            let data = line.as_bytes();
156                            let len_bytes = (data.len() as u32).to_be_bytes();
157
158                            let mut packet = Vec::new();
159                            packet.extend_from_slice(&len_bytes);
160                            packet.extend_from_slice(data);
161
162                            // Send to channel
163                            let _ = tx.send(packet);
164                        }
165                    }
166                    Err(_) => break,
167                }
168            }
169        });
170
171        // Main network loop: handle both receiving and sending
172        if let Some(c) = conn.lock().await.as_mut() {
173            loop {
174                tokio::select! {
175                    // Handle incoming messages from stdin
176                    Some(packet) = rx.recv() => {
177                        if let Err(e) = c.send(&packet).await {
178                            eprintln!("Send error: {:?}", e);
179                            break;
180                        }
181                    }
182                    // Handle incoming messages from server
183                    result = c.receive() => {
184                        match result {
185                            Ok(Some(data)) => {
186                                let text = String::from_utf8_lossy(&data);
187
188                                let mut active = is_active.lock().await;
189
190                                if text == IS_START {
191                                    if !*active {
192                                        Self::print(&prompt).await;
193                                    }
194                                    *active = true;
195                                } else if text == IS_STOP {
196                                    if *active {
197                                        Self::print(CLEAN_LINE).await;
198                                    }
199                                    *active = false;
200                                } else if *active {
201                                    let output = format!("{}{}{}", CLEAN_LINE, text, prompt);
202                                    Self::print(&output).await;
203                                } else {
204                                    Self::print(&text).await;
205                                }
206                            }
207                            Ok(None) => {
208                                break;
209                            }
210                            Err(e) => {
211                                eprintln!("Error receiving: {:?}", e);
212                                break;
213                            }
214                        }
215                    }
216                }
217            }
218        }
219
220        Ok(())
221    }
222}