1use anyhow::{Context, Result};
2use std::sync::Arc;
3use tokio::io::AsyncWriteExt;
4use tokio::sync::Mutex;
5
6use crate::model::AuthInfo;
7use crate::network::{Connection, NetType, TcpConnection, WebSocketConnection};
8use crate::util::CLEAN_LINE;
9use tokio::sync::mpsc;
10
/// Server message that turns interactive mode on: the prompt is shown and
/// lines typed on stdin start being forwarded to the server.
const IS_START: &str = "INTERACTIVE_SIGNAL_START";
/// Server message that turns interactive mode off: the prompt line is cleared
/// and stdin input is discarded until the next [`IS_START`].
const IS_STOP: &str = "INTERACTIVE_SIGNAL_STOP";
13
14pub enum ShellConnection {
15 Tcp(TcpConnection),
16 WebSocket(WebSocketConnection),
17}
18
19impl ShellConnection {
20 async fn send(&mut self, data: &[u8]) -> Result<()> {
21 match self {
22 ShellConnection::Tcp(conn) => conn.send(data).await,
23 ShellConnection::WebSocket(conn) => conn.send(data).await,
24 }
25 }
26
27 async fn receive(&mut self) -> Result<Option<Vec<u8>>> {
28 match self {
29 ShellConnection::Tcp(conn) => conn.receive().await,
30 ShellConnection::WebSocket(conn) => conn.receive().await,
31 }
32 }
33}
34
35pub struct Shell {
36 addr: String,
37 name: String,
38 id: i64,
39 score: i64,
40 conn: Option<ShellConnection>,
41 is_active: Arc<Mutex<bool>>,
42}
43
44impl Shell {
45 pub fn new(addr: String, name: String) -> Self {
46 Self {
47 addr,
48 name,
49 id: std::time::SystemTime::now()
50 .duration_since(std::time::UNIX_EPOCH)
51 .unwrap()
52 .as_nanos() as i64,
53 score: 100,
54 conn: None,
55 is_active: Arc::new(Mutex::new(false)),
56 }
57 }
58
59 fn detect_net_type(&self) -> NetType {
60 if self.addr.ends_with("9998") {
61 NetType::WebSocket
62 } else {
63 NetType::Tcp
64 }
65 }
66
67 async fn connect(&mut self) -> Result<()> {
68 let net_type = self.detect_net_type();
69
70 self.conn = Some(match net_type {
71 NetType::Tcp => {
72 let conn = TcpConnection::connect(&self.addr).await?;
73 ShellConnection::Tcp(conn)
74 }
75 NetType::WebSocket => {
76 let conn = WebSocketConnection::connect(&self.addr).await?;
77 ShellConnection::WebSocket(conn)
78 }
79 });
80
81 Ok(())
82 }
83
84 async fn auth(&mut self) -> Result<()> {
85 let auth_info = AuthInfo {
86 id: self.id,
87 name: self.name.clone(),
88 score: self.score,
89 };
90
91 let json_data = serde_json::to_vec(&auth_info)?;
93 let len_bytes = (json_data.len() as u32).to_be_bytes();
94
95 let mut data = Vec::new();
96 data.extend_from_slice(&len_bytes);
97 data.extend_from_slice(&json_data);
98
99 if let Some(conn) = self.conn.as_mut() {
100 conn.send(&data).await?;
101 }
102
103 Ok(())
104 }
105
106 fn get_prompt(&self) -> String {
107 format!(
108 "{}[{}@ratel {}]# ",
109 CLEAN_LINE,
110 self.name.to_lowercase(),
111 "~"
112 )
113 }
114
115 async fn print(text: &str) {
116 print!("{}", text);
117 use std::io::{self, Write};
118 io::stdout().flush().ok();
119 }
120
121 pub async fn start(&mut self) -> Result<()> {
122 self.connect().await.context("Failed to connect")?;
123 self.auth().await.context("Failed to authenticate")?;
124
125 let is_active = self.is_active.clone();
126 let conn = Arc::new(Mutex::new(self.conn.take()));
127 let prompt = self.get_prompt();
128
129 let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
131
132 let is_active_stdin = is_active.clone();
133 let prompt_stdin = prompt.clone();
134
135 tokio::spawn(async move {
137 use tokio::io::{AsyncBufReadExt, BufReader};
138
139 let stdin = tokio::io::stdin();
140 let mut reader = BufReader::new(stdin);
141
142 loop {
143 let mut line = String::new();
144 match reader.read_line(&mut line).await {
145 Ok(0) => break,
146 Ok(_) => {
147 let line = line.trim();
148 let active = *is_active_stdin.lock().await;
149
150 if active {
151 print!("{}{}", CLEAN_LINE, prompt_stdin);
152 tokio::io::stdout().flush().await.ok();
153
154 let data = line.as_bytes();
156 let len_bytes = (data.len() as u32).to_be_bytes();
157
158 let mut packet = Vec::new();
159 packet.extend_from_slice(&len_bytes);
160 packet.extend_from_slice(data);
161
162 let _ = tx.send(packet);
164 }
165 }
166 Err(_) => break,
167 }
168 }
169 });
170
171 if let Some(c) = conn.lock().await.as_mut() {
173 loop {
174 tokio::select! {
175 Some(packet) = rx.recv() => {
177 if let Err(e) = c.send(&packet).await {
178 eprintln!("Send error: {:?}", e);
179 break;
180 }
181 }
182 result = c.receive() => {
184 match result {
185 Ok(Some(data)) => {
186 let text = String::from_utf8_lossy(&data);
187
188 let mut active = is_active.lock().await;
189
190 if text == IS_START {
191 if !*active {
192 Self::print(&prompt).await;
193 }
194 *active = true;
195 } else if text == IS_STOP {
196 if *active {
197 Self::print(CLEAN_LINE).await;
198 }
199 *active = false;
200 } else if *active {
201 let output = format!("{}{}{}", CLEAN_LINE, text, prompt);
202 Self::print(&output).await;
203 } else {
204 Self::print(&text).await;
205 }
206 }
207 Ok(None) => {
208 break;
209 }
210 Err(e) => {
211 eprintln!("Error receiving: {:?}", e);
212 break;
213 }
214 }
215 }
216 }
217 }
218 }
219
220 Ok(())
221 }
222}