1use std::fmt::Debug;
2use std::io::{BufRead, BufReader, BufWriter, Read, Write};
3use std::sync::{Arc, Mutex};
4
5use serde_json;
6
7use crate::{
8 base_message::{BaseMessage, Sendable},
9 errors::{DeserializationError, ServerError},
10 events::Event,
11 requests::Request,
12 responses::Response,
13 reverse_requests::ReverseRequest,
14};
15
/// Phase of the incoming-message parser used by `Server::poll_request`.
#[derive(Debug)]
enum ServerState {
    /// Header lines are being read.
    Header,
    /// A `Content-Length` header has been parsed and the message body is to be read.
    Content,
}
23
24pub struct Server<R: Read, W: Write> {
29 input_buffer: BufReader<R>,
30
31 pub output: Arc<Mutex<ServerOutput<W>>>,
34}
35
/// Writing half of a `Server`: owns the output stream and the outgoing sequence counter.
pub struct ServerOutput<W>
where
    W: Write,
{
    /// Buffered sink that framed messages are written to.
    output_buffer: BufWriter<W>,
    /// Sequence number of the most recently sent message (`0` before anything is sent).
    sequence_number: i64,
}
46
47impl<R: Read, W: Write> Server<R, W> {
48 pub fn new(input: BufReader<R>, output: BufWriter<W>) -> Self {
50 let server_output = Arc::new(Mutex::new(ServerOutput {
51 output_buffer: output,
52 sequence_number: 0,
53 }));
54
55 Self {
56 input_buffer: input,
57 output: server_output,
58 }
59 }
60
61 pub fn poll_request(&mut self) -> Result<Option<Request>, ServerError> {
66 let mut state = ServerState::Header;
67 let mut buffer = String::new();
68 let mut content_length: usize = 0;
69
70 loop {
71 match self.input_buffer.read_line(&mut buffer) {
72 Ok(read_size) => {
73 if read_size == 0 {
74 break Ok(None);
75 }
76 match state {
77 ServerState::Header => {
78 let parts: Vec<&str> = buffer.trim_end().split(':').collect();
79 if parts.len() == 2 {
80 match parts[0] {
81 "Content-Length" => {
82 content_length = match parts[1].trim().parse() {
83 Ok(val) => val,
84 Err(_) => return Err(ServerError::HeaderParseError { line: buffer }),
85 };
86 buffer.clear();
87 buffer.reserve(content_length);
88 state = ServerState::Content;
89 }
90 other => {
91 return Err(ServerError::UnknownHeader {
92 header: other.to_string(),
93 })
94 }
95 }
96 } else {
97 return Err(ServerError::HeaderParseError { line: buffer });
98 }
99 }
100 ServerState::Content => {
101 buffer.clear();
102 let mut content = vec![0; content_length];
103 self
104 .input_buffer
105 .read_exact(content.as_mut_slice())
106 .map_err(ServerError::IoError)?;
107
108 let content = std::str::from_utf8(content.as_slice())
109 .map_err(|e| ServerError::ParseError(DeserializationError::DecodingError(e)))?;
110 let request: Request = serde_json::from_str(content)
111 .map_err(|e| ServerError::ParseError(DeserializationError::SerdeError(e)))?;
112 return Ok(Some(request));
113 }
114 }
115 }
116 Err(e) => return Err(ServerError::IoError(e)),
117 }
118 }
119 }
120
121 pub fn send(&mut self, body: Sendable) -> Result<(), ServerError> {
122 let mut output = self
123 .output
124 .lock()
125 .map_err(|_| ServerError::OutputLockError)?;
126 output.send(body)
127 }
128
129 pub fn respond(&mut self, response: Response) -> Result<(), ServerError> {
130 self.send(Sendable::Response(response))
131 }
132
133 pub fn send_event(&mut self, event: Event) -> Result<(), ServerError> {
134 self.send(Sendable::Event(event))
135 }
136
137 pub fn send_reverse_request(&mut self, request: ReverseRequest) -> Result<(), ServerError> {
138 self.send(Sendable::ReverseRequest(request))
139 }
140}
141
142impl<W: Write> ServerOutput<W> {
143 pub fn send(&mut self, body: Sendable) -> Result<(), ServerError> {
144 self.sequence_number += 1;
145
146 let message = BaseMessage {
147 seq: self.sequence_number,
148 message: body,
149 };
150
151 let resp_json = serde_json::to_string(&message).map_err(ServerError::SerializationError)?;
152 write!(
153 self.output_buffer,
154 "Content-Length: {}\r\n\r\n",
155 resp_json.len()
156 )
157 .map_err(ServerError::IoError)?;
158
159 write!(self.output_buffer, "{}\r\n", resp_json).map_err(ServerError::IoError)?;
160 self.output_buffer.flush().map_err(ServerError::IoError)?;
161 Ok(())
162 }
163
164 pub fn respond(&mut self, response: Response) -> Result<(), ServerError> {
165 self.send(Sendable::Response(response))
166 }
167
168 pub fn send_event(&mut self, event: Event) -> Result<(), ServerError> {
169 self.send(Sendable::Event(event))
170 }
171
172 pub fn send_reverse_request(&mut self, request: ReverseRequest) -> Result<(), ServerError> {
173 self.send(Sendable::ReverseRequest(request))
174 }
175}
176
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use serde_json::Value;

    use super::*;
    use crate::requests::{AttachOrLaunchArguments, Command, RestartArguments};

    /// Feeds `raw` to a fresh `Server` and returns the single request it parses.
    ///
    /// Panics if polling fails or the input ends before a request is produced.
    fn simulate_poll_request(raw: &str) -> Request {
        let mut reader = Cursor::new(raw.as_bytes().to_vec());
        let mut server = Server::new(BufReader::new(&mut reader), BufWriter::new(Vec::new()));
        server.poll_request().unwrap().unwrap()
    }

    #[test]
    fn test_server_init_request() {
        let parsed = simulate_poll_request(concat!(
            "Content-Length: 155\r\n\r\n",
            r#"{"seq": 152,"type": "request","command": "initialize","arguments": {"adapterID": "0001e357-72c7-4f03-ae8f-c5b54bd8dabf", "clientName": "Some Cool Editor"}}"#
        ));

        assert_eq!(parsed.seq, 152);
        assert!(matches!(parsed.command, Command::Initialize { .. }));
    }

    #[test]
    fn test_server_restart_request() {
        let without_args = simulate_poll_request(concat!(
            "Content-Length: 67\r\n\r\n",
            r#"{"seq": 152,"type": "request","command": "restart","arguments": {}}"#
        ));
        assert!(matches!(
            without_args.command,
            Command::Restart(RestartArguments { arguments: None })
        ));

        let with_no_debug = simulate_poll_request(concat!(
            "Content-Length: 96\r\n\r\n",
            r#"{"seq": 152,"type": "request","command": "restart","arguments": {"arguments": {"noDebug":true}}}"#
        ));
        assert!(matches!(
            with_no_debug.command,
            Command::Restart(RestartArguments {
                arguments: Some(AttachOrLaunchArguments {
                    no_debug: Some(_),
                    ..
                })
            })
        ));

        let with_restart_data = simulate_poll_request(concat!(
            "Content-Length: 98\r\n\r\n",
            r#"{"seq": 152,"type": "request","command": "restart","arguments": {"arguments": {"__restart":true}}}"#
        ));
        assert!(matches!(
            with_restart_data.command,
            Command::Restart(RestartArguments {
                arguments: Some(AttachOrLaunchArguments {
                    restart_data: Some(Value::Bool(true)),
                    ..
                })
            })
        ));
    }
}