1use std::collections::HashMap;
2
3use crate::{
4 Error,
5 message::{Message, MessageType},
6 message_stream::MessageStream,
7};
8use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader};
9
10pub struct AptStream {
12 inner: Box<dyn AsyncBufRead + Unpin + Send>,
13 message_stream: MessageStream,
14 output_stream: Box<dyn AsyncWrite + Unpin + Send>,
15 has_initialized: bool,
16}
17
18impl AptStream {
19 pub fn new() -> Self {
21 Self {
22 inner: Box::new(BufReader::new(tokio::io::stdin())),
23 message_stream: MessageStream::default(),
24 output_stream: Box::new(tokio::io::stdout()),
25 has_initialized: false,
26 }
27 }
28
29 pub fn with_input_stream(
36 mut self,
37 input_stream: Box<dyn AsyncBufRead + Unpin + Send>,
38 ) -> Result<Self, Error> {
39 if self.has_initialized {
40 return Err(Error::StreamAlreadyInitialized);
41 }
42 self.inner = Box::new(input_stream);
43 Ok(self)
44 }
45
46 pub fn with_output_stream(
53 mut self,
54 output_stream: Box<dyn AsyncWrite + Unpin + Send>,
55 ) -> Result<Self, Error> {
56 if self.has_initialized {
57 return Err(Error::StreamAlreadyInitialized);
58 }
59 self.output_stream = output_stream;
60 Ok(self)
61 }
62
63 pub async fn next<'a>(&'a mut self) -> Result<Option<AptRequest<'a>>, Error> {
65 if !self.has_initialized {
66 self.has_initialized = true;
67
68 log::debug!("synthesizing capabilities request");
70 let capabilities_req = AptRequest::Capabilities(CapabilitiesRequest { this: self });
71
72 return Ok(Some(capabilities_req));
73 }
74
75 let mut line = String::new();
77 loop {
78 log::trace!("reading line from input stream");
79 line.clear();
80 let nread = self.inner.as_mut().read_line(&mut line).await?;
81 log::trace!(
82 "read {} bytes from input stream: {:?}",
83 nread,
84 line.as_bytes()
85 );
86
87 if let Some(message_result) = self.message_stream.push_line(line.as_bytes()) {
88 log::debug!("complete message received");
89 log::trace!("complete message received: {:?}", message_result);
90
91 match message_result {
92 Ok(message) => {
93 let apt_request = AptRequest::<'a>::try_from_message((self, message)).await;
94
95 return match apt_request {
96 Ok(req) => {
97 log::debug!("APT request parse OK");
98 Ok(Some(req))
99 }
100 Err(e) => {
101 log::debug!("APT request parse error: {e:#?}");
102 Err(e)
103 }
104 };
105 }
106 Err(e) => {
107 log::debug!("APT message parse error: {e:?}");
108 return Err(e);
109 }
110 }
111 } else if nread == 0 {
112 log::debug!("EOF reached on input stream");
114 return Ok(None);
115 }
116 }
117 }
118}
119
120pub enum AptRequest<'a> {
122 Capabilities(CapabilitiesRequest<'a>),
124 Configuration(ConfigRequest),
126 Uri(UriRequest<'a>),
128}
129
130pub struct CapabilitiesRequest<'a> {
132 this: &'a mut AptStream,
133}
134
135impl<'a> CapabilitiesRequest<'a> {
136 #[must_use = "responses must be sent to the APT client with `.send().await`"]
138 pub fn respond(self) -> CapabilitiesResponse<'a> {
139 CapabilitiesResponse::new(self.this)
140 }
141}
142
143pub struct CapabilitiesResponse<'a> {
145 this: &'a mut AptStream,
146 single_instance: bool,
147 send_config: bool,
148 pipeline: bool,
149 local_only: bool,
150 removable: bool,
151 needs_cleanup: bool,
152 version: String,
153}
154
155impl CapabilitiesResponse<'_> {
156 #[must_use]
158 fn new(this: &mut AptStream) -> CapabilitiesResponse<'_> {
159 CapabilitiesResponse {
160 this,
161 single_instance: false,
162 send_config: false,
163 pipeline: false,
164 local_only: false,
165 removable: false,
166 needs_cleanup: false,
167 version: env!("CARGO_PKG_VERSION").to_string(),
168 }
169 }
170
171 #[must_use = "responses must be sent with `.send().await`"]
173 pub fn single_instance(mut self, enabled: bool) -> Self {
174 self.single_instance = enabled;
175 self
176 }
177
178 #[must_use = "responses must be sent with `.send().await`"]
180 pub fn send_config(mut self, enabled: bool) -> Self {
181 self.send_config = enabled;
182 self
183 }
184
185 #[must_use = "responses must be sent with `.send().await`"]
194 pub fn version<S: Into<String>>(mut self, version: S) -> Self {
195 self.version = version.into();
196 self
197 }
198
199 #[must_use = "responses must be sent with `.send().await`"]
201 pub fn pipeline(mut self, enabled: bool) -> Self {
202 self.pipeline = enabled;
203 self
204 }
205
206 #[must_use = "responses must be sent with `.send().await`"]
208 pub fn local_only(mut self, enabled: bool) -> Self {
209 self.local_only = enabled;
210 self
211 }
212
213 #[must_use = "responses must be sent with `.send().await`"]
215 pub fn removable(mut self, enabled: bool) -> Self {
216 self.removable = enabled;
217 self
218 }
219
220 #[must_use = "responses must be sent with `.send().await`"]
222 pub fn needs_cleanup(mut self, enabled: bool) -> Self {
223 self.needs_cleanup = enabled;
224 self
225 }
226
227 pub async fn send(self) -> Result<(), Error> {
229 let msg = Message::new(
230 MessageType::Capabilities,
231 vec![
232 ("Version", &self.version),
233 (
234 "Send-Config",
235 if self.send_config { "true" } else { "false" },
236 ),
237 (
238 "Single-Instance",
239 if self.single_instance {
240 "true"
241 } else {
242 "false"
243 },
244 ),
245 ("Pipeline", if self.pipeline { "true" } else { "false" }),
246 ("Local-Only", if self.local_only { "true" } else { "false" }),
247 ("Removable", if self.removable { "true" } else { "false" }),
248 (
249 "Needs-Cleanup",
250 if self.needs_cleanup { "true" } else { "false" },
251 ),
252 ],
253 )
254 .to_string();
255
256 log::debug!("sending capabilities response");
257 log::trace!("sending capabilities response: {}", msg.trim());
258 self.this
259 .output_stream
260 .as_mut()
261 .write_all(msg.as_bytes())
262 .await?;
263
264 log::debug!("flushing output stream");
265 self.this.output_stream.as_mut().flush().await?;
266
267 log::debug!("capabilities response sent successfully");
268 Ok(())
269 }
270}
271
/// Configuration sent by the APT client, collected from its `Config-Item` headers.
pub struct ConfigRequest {
    /// Each configuration key mapped to its value (text before / after the first `=`).
    options: HashMap<String, String>,
}
278
279impl ConfigRequest {
280 fn from(message: Message) -> Self {
281 let options = message
282 .headers
283 .into_iter()
284 .filter_map(|(key, value)| {
285 if key == "Config-Item" {
286 let (key, value) = value.split_once('=')?;
287 Some((key.to_string(), value.to_string()))
288 } else {
289 None
290 }
291 })
292 .collect::<HashMap<String, String>>();
293
294 ConfigRequest { options }
295 }
296
297 pub fn options(&self) -> &HashMap<String, String> {
299 &self.options
300 }
301}
302
303pub struct UriRequest<'a> {
305 this: &'a mut AptStream,
306 uri: String,
307 repo_uri: String,
308 filename: String,
309}
310
311impl<'a> UriRequest<'a> {
312 async fn from(this: &'a mut AptStream, message: Message) -> Result<UriRequest<'a>, Error> {
314 let Ok(uri) = message.header("URI") else {
315 let msg = Message::uri_failure("", "Missing URI header").to_string();
316 this.output_stream.write_all(msg.as_bytes()).await?;
317 this.output_stream.as_mut().flush().await?;
318 return Err(Error::HeaderNotFound("URI".to_string()));
319 };
320
321 let Ok(filename) = message.header("Filename") else {
322 let msg = Message::uri_failure(uri, "Missing Filename header").to_string();
323 this.output_stream.write_all(msg.as_bytes()).await?;
324 this.output_stream.as_mut().flush().await?;
325 return Err(Error::HeaderNotFound("Filename".to_string()));
326 };
327
328 let Ok(target_uri) = message.header("Target-Site") else {
329 let msg = Message::uri_failure(uri, "Missing Target-Site header").to_string();
330 this.output_stream.write_all(msg.as_bytes()).await?;
331 this.output_stream.as_mut().flush().await?;
332 return Err(Error::HeaderNotFound("Target-Site".to_string()));
333 };
334
335 Ok(UriRequest {
336 this,
337 uri: uri.to_string(),
338 repo_uri: target_uri.to_string(),
339 filename: filename.to_string(),
340 })
341 }
342
343 pub fn uri(&self) -> &str {
345 &self.uri
346 }
347
348 pub fn filename(&self) -> &str {
350 &self.filename
351 }
352
353 pub fn repo_uri(&self) -> &str {
359 &self.repo_uri
360 }
361
362 pub async fn send_status(&mut self, status: &str) -> Result<(), Error> {
368 let msg = Message::status(status).to_string();
369 self.this
370 .output_stream
371 .as_mut()
372 .write_all(msg.as_bytes())
373 .await?;
374 self.this.output_stream.as_mut().flush().await?;
375 Ok(())
376 }
377
378 pub async fn fail(self, reason: &str) -> Result<(), Error> {
382 let msg = Message::uri_failure(&self.uri, reason).to_string();
383 self.this
384 .output_stream
385 .as_mut()
386 .write_all(msg.as_bytes())
387 .await?;
388 self.this.output_stream.as_mut().flush().await?;
389 Ok(())
390 }
391
392 pub async fn start(
397 self,
398 size_in_bytes: u64,
399 last_modified: &str,
400 ) -> Result<UriResponse<'a>, Error> {
401 let msg = Message::uri_start(&self.uri, size_in_bytes, last_modified).to_string();
402 self.this
403 .output_stream
404 .as_mut()
405 .write_all(msg.as_bytes())
406 .await?;
407 self.this.output_stream.as_mut().flush().await?;
408 Ok(UriResponse {
409 this: self.this,
410 uri: self.uri,
411 filename: self.filename,
412 })
413 }
414}
415
416pub struct UriResponse<'a> {
418 this: &'a mut AptStream,
419 uri: String,
420 filename: String,
421}
422
423impl UriResponse<'_> {
424 pub async fn complete(self) -> Result<(), Error> {
426 let msg = Message::uri_success(&self.uri, &self.filename).to_string();
427 self.this
428 .output_stream
429 .as_mut()
430 .write_all(msg.as_bytes())
431 .await?;
432 self.this.output_stream.as_mut().flush().await?;
433 Ok(())
434 }
435
436 pub async fn fail(self, reason: &str) -> Result<(), Error> {
438 let msg = Message::uri_failure(&self.uri, reason).to_string();
439 self.this
440 .output_stream
441 .as_mut()
442 .write_all(msg.as_bytes())
443 .await?;
444 self.this.output_stream.as_mut().flush().await?;
445 Ok(())
446 }
447
448 pub async fn writer(&self) -> Result<impl tokio::io::AsyncWrite + 'static, Error> {
452 Ok(tokio::fs::OpenOptions::new()
453 .write(true)
454 .create(true)
455 .truncate(true)
456 .mode(0o600)
457 .open(self.filename.clone())
458 .await?)
459 }
460}
461
462trait TryFromMessage<'a>: Sized {
463 type Error;
464
465 async fn try_from_message(msg: (&'a mut AptStream, Message)) -> Result<Self, Self::Error>;
466}
467
468impl<'a> TryFromMessage<'a> for AptRequest<'a> {
469 type Error = Error;
470
471 async fn try_from_message(
472 (this, message): (&'a mut AptStream, Message),
473 ) -> Result<Self, Self::Error> {
474 match message.message_type {
475 MessageType::Configuration => {
476 let config_request = ConfigRequest::from(message);
477 Ok(AptRequest::Configuration(config_request))
478 }
479 MessageType::URIAcquire => {
480 let uri_request = UriRequest::from(this, message).await?;
481 Ok(AptRequest::Uri(uri_request))
482 }
483 other => Err(Error::UnexpectedMessageType(other, message)),
484 }
485 }
486}