1use std::io::ErrorKind;
17
18use serde::{Deserialize, Serialize};
19
20use super::socket::Socket;
21use crate::{bencode, error::Error};
22
23#[derive(Debug)]
24pub struct Session {
25 connection: Connection,
28 session_id: Box<str>,
29 request_count: usize,
30}
31
32impl Session {
33 pub fn close(mut self) -> Result<Connection, Error> {
34 let id = format!("{}:close", self.session_id);
35 self.connection.send(WireRequest {
36 op: Op::Close,
37 id: &id,
38 session: Some(&self.session_id),
39 ns: None,
40 code: None,
41 line: None,
42 column: None,
43 file: None,
44 })?;
45 #[allow(clippy::blocks_in_if_conditions)]
46 while self
47 .connection
48 .recv(|r| Ok(!(r.matches(&id) && r.has_status("session-closed"))))?
49 {}
50 Ok(self.connection)
51 }
52
53 pub fn eval<F>(
54 &mut self,
55 code: &str,
56 file_name: Option<&str>,
57 line: Option<usize>,
58 column: Option<usize>,
59 mut handler: F,
60 ) -> Result<(), Error>
61 where
62 F: FnMut(Response) -> Result<(), Error>,
63 {
64 self.request_count += 1;
65 let id = format!("{}:{}", self.session_id, self.request_count);
66 self.connection.send(WireRequest {
67 id: &id,
68 op: Op::Eval,
69 session: Some(&self.session_id),
70 ns: None,
71 code: Some(code),
72 line: line.map(|n| n.try_into().unwrap_or_default()),
73 column: column.map(|n| n.try_into().unwrap_or_default()),
74 file: file_name,
75 })?;
76 #[allow(clippy::blocks_in_if_conditions)]
77 while self.connection.recv(|r| {
78 if !r.matches(&id) {
79 return Ok(true);
80 }
81 handler(Response {
82 value: r.value.as_deref(),
83 out: r.out.as_deref(),
84 err: r.err.as_deref(),
85 ex: r.ex.as_deref(),
86 root_ex: r.root_ex.as_deref(),
87 })?;
88 Ok(!r.has_status("done"))
89 })? {}
90 Ok(())
91 }
92}
93
/// Borrowed view of one evaluation response, as handed to an `eval` handler.
///
/// Every field is optional because a single response may carry any subset
/// of them.
#[derive(Debug)]
pub struct Response<'a> {
    /// Content of the response's `value` field, if present.
    pub value: Option<&'a str>,
    /// Content of the response's `ex` field, if present.
    pub ex: Option<&'a str>,
    /// Content of the response's `root-ex` field, if present.
    pub root_ex: Option<&'a str>,
    /// Content of the response's `out` field, if present.
    pub out: Option<&'a str>,
    /// Content of the response's `err` field, if present.
    pub err: Option<&'a str>,
}
102
/// Operations this client can request from the server.
#[derive(Debug)]
pub enum Op {
    /// Request a new session (`"clone"`).
    Clone,
    /// Close an existing session (`"close"`).
    Close,
    /// Evaluate code (`"eval"`).
    Eval,
}

impl Op {
    /// Returns the operation name as it appears on the wire.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Clone => "clone",
            Self::Close => "close",
            Self::Eval => "eval",
        }
    }
}
119
120fn serialize_op<S: serde::Serializer>(
121 op: &Op,
122 serializer: S,
123) -> Result<S::Ok, S::Error> {
124 serializer.serialize_str(op.as_str())
125}
126
127#[derive(Debug, Serialize)]
128#[serde(rename_all = "kebab-case")]
129pub struct WireRequest<'a> {
130 #[serde(serialize_with = "serialize_op")]
131 pub op: Op,
132 pub id: &'a str,
133 pub session: Option<&'a str>,
134 pub ns: Option<&'a str>,
135 pub code: Option<&'a str>,
136 pub line: Option<i32>,
137 pub column: Option<i32>,
138 pub file: Option<&'a str>,
139}
140
141#[derive(Debug, Deserialize)]
142#[serde(rename_all = "kebab-case")]
143pub struct WireResponse {
144 pub session: String,
145 pub id: Option<String>,
146 pub status: Option<Vec<String>>,
147 pub new_session: Option<String>,
148 pub value: Option<String>,
149 pub ex: Option<String>,
150 pub root_ex: Option<String>,
151 pub out: Option<String>,
152 pub err: Option<String>,
153}
154
155impl WireResponse {
156 pub fn matches(&self, id: &str) -> bool {
157 self.id.as_ref().map(|our| our == id).unwrap_or(false)
158 }
159
160 pub fn has_status(&self, label: &str) -> bool {
161 self
162 .status
163 .as_ref()
164 .map(|labels| labels.iter().any(|our| our == label))
165 .unwrap_or(false)
166 }
167}
168
169#[derive(Debug)]
170pub struct Connection {
171 socket: Socket,
172 buffer: Vec<u8>,
173}
174
175impl Connection {
176 pub fn new(socket: Socket) -> Self {
177 Self {
178 socket,
179 buffer: Default::default(),
180 }
181 }
182
183 pub fn session(mut self) -> Result<Session, Error> {
184 self.send(WireRequest {
185 op: Op::Clone,
186 id: "",
187 session: None,
188 ns: None,
189 code: None,
190 line: None,
191 column: None,
192 file: None,
193 })?;
194 let session_id = self.recv(|response| {
195 if let Some(session) = response.new_session.as_deref() {
196 Ok(session.to_owned().into_boxed_str())
197 } else {
198 Err(Error::UnexptectedResponse)
199 }
200 })?;
201 Ok(Session {
202 connection: self,
203 session_id,
204 request_count: 0,
205 })
206 }
207
208 fn send(&mut self, request: WireRequest) -> Result<(), Error> {
209 let payload = serde_bencode::to_bytes(&request).unwrap();
210 let w = self.socket.borrow_mut_write();
211 w.write_all(&payload).map_err(Error::CannotSendToHost)?;
212 w.flush().map_err(Error::CannotSendToHost)
213 }
214
215 fn recv<F, V>(&mut self, mut handler: F) -> Result<V, Error>
216 where
217 F: FnMut(&WireResponse) -> Result<V, Error>,
218 {
219 let mut buffer = [0_u8; 4096];
220 loop {
221 match bencode::scan_next(&self.buffer) {
222 Ok((_, len)) => {
223 let result = serde_bencode::from_bytes(&self.buffer[0..len])
224 .map_err(|_| Error::CorruptedResponse);
225 self.buffer.copy_within(len.., 0);
226 self.buffer.truncate(self.buffer.len() - len);
227 return handler(&result?);
228 }
229 Err(bencode::Error::BadInput) => {
230 return Err(Error::CorruptedResponse);
231 }
232 Err(bencode::Error::UnexpectedEnd) => {}
233 }
234 match self.socket.borrow_mut_read().read(&mut buffer) {
235 Ok(0) => return Err(Error::HostDisconnected),
236 Ok(len) => self.buffer.extend_from_slice(&buffer[0..len]),
237 Err(e) if e.kind() == ErrorKind::Interrupted => {}
238 Err(e) => return Err(Error::CannotReceiveFromHost(e)),
239 }
240 }
241 }
242}