1use crate::{
16 conn::{KeepAlive, Mode, ShortConn},
17 meta::{BeginRequestRec, EndRequestRec, Header, ParamPairs, RequestType, Role},
18 params::Params,
19 request::Request,
20 response::ResponseStream,
21 ClientError, ClientResult, Response,
22};
23use std::marker::PhantomData;
24use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
25use tracing::debug;
26
/// Request id used for every request issued by this module: it is passed both
/// when writing the request records and when reading the response back.
const REQUEST_ID: u16 = 1;
31
/// Client that drives a request/response exchange over a stream of type `S`.
///
/// `M` is a compile-time marker only: it selects which `impl` block (and so
/// which set of public methods) is available, and no value of `M` is stored.
pub struct Client<S, M> {
    /// Transport the request records are written to and the response records
    /// are read from.
    stream: S,
    /// Zero-sized carrier for the mode marker `M`.
    _mode: PhantomData<M>,
}
37
38impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S, ShortConn> {
39 pub fn new(stream: S) -> Self {
42 Self {
43 stream,
44 _mode: PhantomData,
45 }
46 }
47
48 pub async fn execute_once<I: AsyncRead + Unpin + Send>(
51 mut self, request: Request<'_, I>,
52 ) -> ClientResult<Response> {
53 self.inner_execute(request).await
54 }
55
56 pub async fn execute_once_stream<'a, I: AsyncRead + Unpin + Send>(
84 mut self, request: Request<'_, I>,
85 ) -> ClientResult<ResponseStream<S>> {
86 Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
87 Ok(ResponseStream::new(self.stream, REQUEST_ID))
88 }
89}
90
91impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S, KeepAlive> {
92 pub fn new_keep_alive(stream: S) -> Self {
95 Self {
96 stream,
97 _mode: PhantomData,
98 }
99 }
100
101 pub async fn execute<I: AsyncRead + Unpin + Send>(
104 &mut self, request: Request<'_, I>,
105 ) -> ClientResult<Response> {
106 self.inner_execute(request).await
107 }
108
109 pub async fn execute_stream<I: AsyncRead + Unpin + Send>(
140 &mut self, request: Request<'_, I>,
141 ) -> ClientResult<ResponseStream<&mut S>> {
142 Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
143 Ok(ResponseStream::new(&mut self.stream, REQUEST_ID))
144 }
145}
146
147impl<S: AsyncRead + AsyncWrite + Unpin + Send, M: Mode> Client<S, M> {
148 async fn inner_execute<I: AsyncRead + Unpin + Send>(
149 &mut self, request: Request<'_, I>,
150 ) -> ClientResult<Response> {
151 Self::handle_request(&mut self.stream, REQUEST_ID, request.params, request.stdin).await?;
152 Self::handle_response(&mut self.stream, REQUEST_ID).await
153 }
154
155 async fn handle_request<'a, I: AsyncRead + Unpin + Send>(
156 stream: &mut S, id: u16, params: Params<'a>, mut body: I,
157 ) -> ClientResult<()> {
158 Self::handle_request_start(stream, id).await?;
159 Self::handle_request_params(stream, id, params).await?;
160 Self::handle_request_body(stream, id, &mut body).await?;
161 Self::handle_request_flush(stream).await?;
162 Ok(())
163 }
164
165 async fn handle_request_start(stream: &mut S, id: u16) -> ClientResult<()> {
166 debug!(id, "Start handle request");
167
168 let begin_request_rec =
169 BeginRequestRec::new(id, Role::Responder, <M>::is_keep_alive()).await?;
170
171 debug!(id, ?begin_request_rec, "Send to stream.");
172
173 begin_request_rec.write_to_stream(stream).await?;
174
175 Ok(())
176 }
177
178 async fn handle_request_params<'a>(
179 stream: &mut S, id: u16, params: Params<'a>,
180 ) -> ClientResult<()> {
181 let param_pairs = ParamPairs::new(params);
182 debug!(id, ?param_pairs, "Params will be sent.");
183
184 Header::write_to_stream_batches(
185 RequestType::Params,
186 id,
187 stream,
188 &mut ¶m_pairs.to_content().await?[..],
189 Some(|header| {
190 debug!(id, ?header, "Send to stream for Params.");
191 header
192 }),
193 )
194 .await?;
195
196 Header::write_to_stream_batches(
197 RequestType::Params,
198 id,
199 stream,
200 &mut tokio::io::empty(),
201 Some(|header| {
202 debug!(id, ?header, "Send to stream for Params.");
203 header
204 }),
205 )
206 .await?;
207
208 Ok(())
209 }
210
211 async fn handle_request_body<I: AsyncRead + Unpin + Send>(
212 stream: &mut S, id: u16, body: &mut I,
213 ) -> ClientResult<()> {
214 Header::write_to_stream_batches(
215 RequestType::Stdin,
216 id,
217 stream,
218 body,
219 Some(|header| {
220 debug!(id, ?header, "Send to stream for Stdin.");
221 header
222 }),
223 )
224 .await?;
225
226 Header::write_to_stream_batches(
227 RequestType::Stdin,
228 id,
229 stream,
230 &mut tokio::io::empty(),
231 Some(|header| {
232 debug!(id, ?header, "Send to stream for Stdin.");
233 header
234 }),
235 )
236 .await?;
237
238 Ok(())
239 }
240
241 async fn handle_request_flush(stream: &mut S) -> ClientResult<()> {
242 stream.flush().await?;
243
244 Ok(())
245 }
246
247 async fn handle_response(stream: &mut S, id: u16) -> ClientResult<Response> {
248 let mut response = Response::default();
249
250 let mut stderr = Vec::new();
251 let mut stdout = Vec::new();
252
253 loop {
254 let header = Header::new_from_stream(stream).await?;
255 if header.request_id != id {
256 return Err(ClientError::ResponseNotFound { id });
257 }
258 debug!(id, ?header, "Receive from stream.");
259
260 match header.r#type {
261 RequestType::Stdout => {
262 stdout.extend(header.read_content_from_stream(stream).await?);
263 }
264 RequestType::Stderr => {
265 stderr.extend(header.read_content_from_stream(stream).await?);
266 }
267 RequestType::EndRequest => {
268 let end_request_rec = EndRequestRec::from_header(&header, stream).await?;
269 debug!(id, ?end_request_rec, "Receive from stream.");
270
271 end_request_rec
272 .end_request
273 .protocol_status
274 .convert_to_client_result(end_request_rec.end_request.app_status)?;
275
276 response.stdout = if stdout.is_empty() {
277 None
278 } else {
279 Some(stdout)
280 };
281 response.stderr = if stderr.is_empty() {
282 None
283 } else {
284 Some(stderr)
285 };
286
287 return Ok(response);
288 }
289 r#type => {
290 return Err(ClientError::UnknownRequestType {
291 request_type: r#type,
292 })
293 }
294 }
295 }
296 }
297}