1use std::collections::BTreeMap;
4use std::io;
5
6use bytes::{Buf, BytesMut};
7use tokio_util::codec::Decoder;
8use tracing::warn;
9
10use crate::encoder::AccumulationLimit;
11use crate::frame::Frame;
12
/// Password prompt emitted by the management interface.
///
/// It may arrive without a trailing newline, so `FrameDecoder::decode` also
/// matches it at the start of a buffer that does not yet hold a complete line.
const PW_PROMPT: &[u8] = b"ENTER PASSWORD:";
16
/// A multi-line `>CLIENT:` notification whose `>CLIENT:ENV,...` lines are
/// being collected until the terminating `>CLIENT:ENV,END`.
#[derive(Debug)]
struct ClientEnvAccumulator {
    /// Event name from the header line, e.g. `CONNECT` in `>CLIENT:CONNECT,1,2`.
    event: String,
    /// Everything after the first comma of the header payload (e.g. `1,2`);
    /// empty when the header had no comma.
    args: String,
    /// `key=value` pairs from the ENV lines, split on the first `=`. A line
    /// without `=` is stored with the whole text as key and an empty value.
    env: BTreeMap<String, String>,
}
24
/// `tokio_util` [`Decoder`] that splits a line-oriented management-interface
/// stream (OpenVPN-style, per the tests in this file) into [`Frame`]s.
///
/// The decoder is stateful: it remembers whether the first `>INFO:` line has
/// been seen and buffers multi-line `>CLIENT:` notifications until their
/// `>CLIENT:ENV,END` terminator.
#[derive(Debug)]
pub struct FrameDecoder {
    /// `Some` while a multi-line `>CLIENT:` block is being collected.
    client_notification: Option<ClientEnvAccumulator>,

    /// Maximum number of ENV entries a single client block may hold.
    /// `AccumulationLimit::Unlimited` by default.
    max_client_env_entries: AccumulationLimit,

    /// Set once the first `>INFO:` line has been decoded; only that first
    /// line becomes `Frame::Info`, later ones are plain notifications.
    seen_info: bool,
}
61
62impl Default for FrameDecoder {
63 fn default() -> Self {
64 Self {
65 client_notification: None,
66 max_client_env_entries: AccumulationLimit::Unlimited,
67 seen_info: false,
68 }
69 }
70}
71
72impl FrameDecoder {
73 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn with_max_client_env_entries(mut self, limit: AccumulationLimit) -> Self {
80 self.max_client_env_entries = limit;
81 self
82 }
83}
84
85fn check_accumulation_limit(
86 current_len: usize,
87 limit: AccumulationLimit,
88 what: &'static str,
89) -> Result<(), io::Error> {
90 if let AccumulationLimit::Max(max) = limit
91 && current_len >= max
92 {
93 return Err(io::Error::other(AccumulationLimitExceeded { what, max }));
94 }
95 Ok(())
96}
97
/// Error raised when an accumulating collection reaches its configured limit.
///
/// `check_accumulation_limit` wraps it in an [`io::Error`] of kind `Other`;
/// the displayed message reads e.g. `client ENV accumulation limit exceeded (2)`.
#[derive(Debug, thiserror::Error)]
#[error("{what} accumulation limit exceeded ({max})")]
struct AccumulationLimitExceeded {
    /// Human-readable name of what was being accumulated (e.g. `"client ENV"`).
    what: &'static str,
    /// The configured maximum that was reached.
    max: usize,
}
104
105impl Decoder for FrameDecoder {
106 type Item = Frame;
107 type Error = io::Error;
108
109 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
110 loop {
111 let Some(newline_pos) = src.iter().position(|&b| b == b'\n') else {
113 if src.starts_with(PW_PROMPT) {
116 let mut consume = PW_PROMPT.len();
117 if src.get(consume) == Some(&b'\r') {
118 consume += 1;
119 }
120 src.advance(consume);
121 return Ok(Some(Frame::PasswordPrompt));
122 }
123 if src.capacity() - src.len() < 256 {
124 src.reserve(256);
125 }
126 return Ok(None);
127 };
128
129 let line_bytes = src.split_to(newline_pos + 1);
131 let line = match std::str::from_utf8(&line_bytes) {
132 Ok(text) => text,
133 Err(error) => {
134 self.client_notification = None;
135 return Err(io::Error::new(io::ErrorKind::InvalidData, error));
136 }
137 }
138 .trim_end_matches(['\r', '\n'])
139 .to_string();
140
141 if line.is_empty() && self.client_notification.is_none() {
147 return Ok(Some(Frame::Line(line)));
148 }
149
150 if let Some(ref mut accum) = self.client_notification
152 && let Some(rest) = line.strip_prefix(">CLIENT:ENV,")
153 {
154 if rest == "END" {
155 let finished = self.client_notification.take().expect("guarded by if-let");
156 return Ok(Some(Frame::ClientEnv {
157 event: finished.event,
158 args: finished.args,
159 env: finished.env,
160 }));
161 }
162 let (k, v) = rest
163 .split_once('=')
164 .map(|(k, v)| (k.to_string(), v.to_string()))
165 .unwrap_or_else(|| (rest.to_string(), String::new()));
166 check_accumulation_limit(
167 accum.env.len(),
168 self.max_client_env_entries,
169 "client ENV",
170 )?;
171 accum.env.insert(k, v);
172 continue;
173 }
174
175 if let Some(rest) = line.strip_prefix("SUCCESS:") {
178 return Ok(Some(Frame::Success(
179 rest.strip_prefix(' ').unwrap_or(rest).to_string(),
180 )));
181 }
182
183 if let Some(rest) = line.strip_prefix("ERROR:") {
184 return Ok(Some(Frame::Error(
185 rest.strip_prefix(' ').unwrap_or(rest).to_string(),
186 )));
187 }
188
189 if line == "ENTER PASSWORD:" {
190 return Ok(Some(Frame::PasswordPrompt));
191 }
192
193 if line == "END" {
194 return Ok(Some(Frame::End));
195 }
196
197 if let Some(inner) = line.strip_prefix('>') {
199 let Some((kind, payload)) = inner.split_once(':') else {
200 warn!(line = %line, "malformed notification (no colon)");
201 return Ok(Some(Frame::Line(line)));
204 };
205
206 if kind == "INFO" {
208 if !self.seen_info {
209 self.seen_info = true;
210 return Ok(Some(Frame::Info(payload.to_string())));
211 }
212 return Ok(Some(Frame::Notification {
213 kind: kind.to_string(),
214 payload: payload.to_string(),
215 }));
216 }
217
218 if kind == "CLIENT" {
221 let (event, args) = payload
222 .split_once(',')
223 .map(|(e, a)| (e.to_string(), a.to_string()))
224 .unwrap_or_else(|| (payload.to_string(), String::new()));
225
226 if event == "ADDRESS" {
227 return Ok(Some(Frame::Notification {
229 kind: "CLIENT".to_string(),
230 payload: payload.to_string(),
231 }));
232 }
233
234 self.client_notification = Some(ClientEnvAccumulator {
236 event,
237 args,
238 env: BTreeMap::new(),
239 });
240 continue; }
242
243 return Ok(Some(Frame::Notification {
244 kind: kind.to_string(),
245 payload: payload.to_string(),
246 }));
247 }
248
249 return Ok(Some(Frame::Line(line)));
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes exactly one frame from a fresh decoder, panicking if none is produced.
    fn decode_one(input: &str) -> Frame {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from(input);
        decoder.decode(&mut buf).unwrap().unwrap()
    }

    #[test]
    fn success_line() {
        assert_eq!(
            decode_one("SUCCESS: pid=42\n"),
            Frame::Success("pid=42".to_string())
        );
    }

    #[test]
    fn error_line() {
        assert_eq!(
            decode_one("ERROR: unknown command\n"),
            Frame::Error("unknown command".to_string())
        );
    }

    #[test]
    fn end_line() {
        assert_eq!(decode_one("END\n"), Frame::End);
    }

    #[test]
    fn plain_line() {
        assert_eq!(
            decode_one("TITLE\tOpenVPN 2.6.8\n"),
            Frame::Line("TITLE\tOpenVPN 2.6.8".to_string())
        );
    }

    #[test]
    fn notification() {
        assert_eq!(
            decode_one(">HOLD:Waiting for hold release:0\n"),
            Frame::Notification {
                kind: "HOLD".to_string(),
                payload: "Waiting for hold release:0".to_string(),
            }
        );
    }

    #[test]
    fn info_banner() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from(">INFO:OpenVPN Management Interface\n>INFO:second\n");

        let first = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(
            first,
            Frame::Info("OpenVPN Management Interface".to_string())
        );

        let second = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(
            second,
            Frame::Notification {
                kind: "INFO".to_string(),
                payload: "second".to_string(),
            }
        );
    }

    #[test]
    fn state_notification() {
        let frame = decode_one(">STATE:1711000000,CONNECTED,SUCCESS,10.8.0.6,1.2.3.4,,,,\n");
        assert!(matches!(frame, Frame::Notification { kind, .. } if kind == "STATE"));
    }

    #[test]
    fn client_env_accumulation() {
        let mut decoder = FrameDecoder::new();
        let input = "\
            >CLIENT:CONNECT,1,2\n\
            >CLIENT:ENV,common_name=alice\n\
            >CLIENT:ENV,password=secret\n\
            >CLIENT:ENV,END\n";
        let mut buf = BytesMut::from(input);

        let frame = decoder.decode(&mut buf).unwrap().unwrap();
        match frame {
            Frame::ClientEnv { event, args, env } => {
                assert_eq!(event, "CONNECT");
                assert_eq!(args, "1,2");
                assert_eq!(env.get("common_name").unwrap(), "alice");
                assert_eq!(env.get("password").unwrap(), "secret");
            }
            other => panic!("expected ClientEnv, got {other:?}"),
        }
    }

    #[test]
    fn client_env_value_split_on_first_equals() {
        let mut decoder = FrameDecoder::new();
        let mut buf =
            BytesMut::from(">CLIENT:CONNECT,1,2\n>CLIENT:ENV,opt=a=b\n>CLIENT:ENV,END\n");
        match decoder.decode(&mut buf).unwrap().unwrap() {
            Frame::ClientEnv { env, .. } => assert_eq!(env["opt"], "a=b"),
            other => panic!("expected ClientEnv, got {other:?}"),
        }
    }

    #[test]
    fn client_env_line_without_equals_has_empty_value() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from(">CLIENT:CONNECT,1,2\n>CLIENT:ENV,flag\n>CLIENT:ENV,END\n");
        match decoder.decode(&mut buf).unwrap().unwrap() {
            Frame::ClientEnv { env, .. } => assert_eq!(env["flag"], ""),
            other => panic!("expected ClientEnv, got {other:?}"),
        }
    }

    #[test]
    fn client_address_is_single_line() {
        let frame = decode_one(">CLIENT:ADDRESS,1,10.8.0.6,1\n");
        assert!(matches!(frame, Frame::Notification { kind, .. } if kind == "CLIENT"));
    }

    #[test]
    fn password_prompt_with_newline() {
        assert_eq!(decode_one("ENTER PASSWORD:\n"), Frame::PasswordPrompt);
    }

    #[test]
    fn password_prompt_without_newline() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("ENTER PASSWORD:");
        let frame = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame, Frame::PasswordPrompt);
    }

    #[test]
    fn partial_password_prompt_waits_for_more_input() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("ENTER PASS");
        assert_eq!(decoder.decode(&mut buf).unwrap(), None);
        assert_eq!(&buf[..], b"ENTER PASS");
    }

    #[test]
    fn empty_lines_emitted_as_line() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("\n\n\nSUCCESS: ok\n");
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Line(String::new())
        );
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Line(String::new())
        );
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Line(String::new())
        );
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Success("ok".to_string())
        );
    }

    #[test]
    fn multi_frame_sequence() {
        let mut decoder = FrameDecoder::new();
        let mut buf =
            BytesMut::from("SUCCESS: pid=42\n>STATE:0,CONNECTING,,,,,,,\nERROR: unknown\nEND\n");

        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Success("pid=42".to_string())
        );
        assert!(matches!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Notification { ref kind, .. } if kind == "STATE"
        ));
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Error("unknown".to_string())
        );
        assert_eq!(decoder.decode(&mut buf).unwrap().unwrap(), Frame::End);
        assert_eq!(decoder.decode(&mut buf).unwrap(), None);
    }

    #[test]
    fn line_then_end_sequence() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("TITLE\tOpenVPN 2.6\nManagement Version: 5\nEND\n");

        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Line("TITLE\tOpenVPN 2.6".to_string())
        );
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Line("Management Version: 5".to_string())
        );
        assert_eq!(decoder.decode(&mut buf).unwrap().unwrap(), Frame::End);
    }

    #[test]
    fn partial_line_returns_none() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("SUCCESS: pi");
        assert_eq!(decoder.decode(&mut buf).unwrap(), None);

        buf.extend_from_slice(b"d=42\n");
        assert_eq!(
            decoder.decode(&mut buf).unwrap().unwrap(),
            Frame::Success("pid=42".to_string())
        );
    }

    #[test]
    fn partial_client_env_accumulates_across_calls() {
        let mut decoder = FrameDecoder::new();

        let mut buf = BytesMut::from(">CLIENT:CONNECT,5,3\n");
        assert_eq!(decoder.decode(&mut buf).unwrap(), None);

        buf.extend_from_slice(b">CLIENT:ENV,user=alice\n");
        assert_eq!(decoder.decode(&mut buf).unwrap(), None);

        buf.extend_from_slice(b">CLIENT:ENV,END\n");
        let frame = decoder.decode(&mut buf).unwrap().unwrap();
        match frame {
            Frame::ClientEnv { event, args, env } => {
                assert_eq!(event, "CONNECT");
                assert_eq!(args, "5,3");
                assert_eq!(env.len(), 1);
                assert_eq!(env["user"], "alice");
            }
            other => panic!("expected ClientEnv, got {other:?}"),
        }
    }

    #[test]
    fn client_cr_response_starts_accumulation() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from(">CLIENT:CR_RESPONSE,10,2,dGVzdA==\n>CLIENT:ENV,END\n");
        let frame = decoder.decode(&mut buf).unwrap().unwrap();
        match frame {
            Frame::ClientEnv { event, args, .. } => {
                assert_eq!(event, "CR_RESPONSE");
                assert!(args.contains("10,2,dGVzdA=="));
            }
            other => panic!("expected ClientEnv, got {other:?}"),
        }
    }

    #[test]
    fn invalid_utf8_returns_error() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from(&b"SUCCESS: \xff\xfe\n"[..]);
        let err = decoder.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn crlf_line_endings_stripped() {
        assert_eq!(
            decode_one("SUCCESS: ok\r\n"),
            Frame::Success("ok".to_string())
        );
    }

    #[test]
    fn success_bare_no_payload() {
        assert_eq!(decode_one("SUCCESS:\n"), Frame::Success(String::new()));
    }

    #[test]
    fn error_bare_no_payload() {
        assert_eq!(decode_one("ERROR:\n"), Frame::Error(String::new()));
    }

    #[test]
    fn notification_without_colon_emitted_as_line() {
        let frame = decode_one(">GARBAGE\n");
        assert_eq!(frame, Frame::Line(">GARBAGE".to_string()));
    }

    #[test]
    fn client_env_limit_exceeded() {
        let mut decoder =
            FrameDecoder::new().with_max_client_env_entries(crate::AccumulationLimit::Max(2));
        let mut buf = BytesMut::from(
            ">CLIENT:CONNECT,1,0\n\
             >CLIENT:ENV,a=1\n\
             >CLIENT:ENV,b=2\n\
             >CLIENT:ENV,c=3\n",
        );

        // Fail loudly instead of spinning forever if the input runs out
        // before the limit triggers.
        let err = loop {
            match decoder.decode(&mut buf) {
                Ok(Some(_)) => continue,
                Ok(None) => panic!("input exhausted without hitting the client ENV limit"),
                Err(e) => break e,
            }
        };
        assert!(err.to_string().contains("limit exceeded"));
    }

    #[test]
    fn non_env_line_during_client_accumulation_falls_through() {
        let mut decoder = FrameDecoder::new();
        let mut buf =
            BytesMut::from(">CLIENT:CONNECT,1,0\n>STATE:0,CONNECTING,,,,,,,\n>CLIENT:ENV,END\n");

        // The interleaved STATE line is emitted while the CLIENT block stays open.
        let first = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(matches!(
            first,
            Frame::Notification { ref kind, .. } if kind == "STATE"
        ));

        let second = decoder.decode(&mut buf).unwrap().unwrap();
        assert!(matches!(second, Frame::ClientEnv { .. }));
    }

    #[test]
    fn password_prompt_with_carriage_return() {
        let mut decoder = FrameDecoder::new();
        let mut buf = BytesMut::from("ENTER PASSWORD:\r");
        let frame = decoder.decode(&mut buf).unwrap().unwrap();
        assert_eq!(frame, Frame::PasswordPrompt);
        assert!(buf.is_empty());
    }
}