openvpn_mgmt_frame/encoder.rs
//! Encoder primitives for the OpenVPN management wire format.
//!
//! These are stateless helpers: they sanitize, escape, and quote
//! user-supplied strings, and serialize single lines and multi-line
//! blocks into a caller-provided `BytesMut` buffer.
5
6use std::borrow::Cow;
7use std::io;
8
9use bytes::{BufMut, BytesMut};
10
/// The characters the line-oriented management protocol cannot carry
/// inside a field: line feed and carriage return end the current command
/// early, and NUL cuts the string short on the C side.
pub const WIRE_UNSAFE: &[char] = &['\n', '\r', '\x00'];
14
/// Policy the encoder applies to user-supplied text that contains
/// characters the line-oriented management protocol cannot carry
/// (`\n`, `\r`, `\0`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EncoderMode {
    /// Quietly drop the offending characters (the default; defensive).
    ///
    /// Every `\n`, `\r`, and `\0` is removed from user-supplied strings,
    /// and a block body line that is exactly `"END"` is written as
    /// `" END"` so it cannot close the block.
    #[default]
    Sanitize,

    /// Refuse offending input instead of repairing it.
    ///
    /// Any field holding `\n`, `\r`, or `\0`, or any block body line equal
    /// to `"END"`, yields an `Err(io::Error)` whose inner error can be
    /// downcast to [`EncodeError`].
    Strict,
}
34
35/// Structured error for encoder-side validation failures.
36#[derive(Debug, thiserror::Error)]
37pub enum EncodeError {
38 /// A field contains `\n`, `\r`, or `\0`.
39 #[error("{0} contains characters unsafe for the management protocol (\\n, \\r, or \\0)")]
40 UnsafeCharacters(&'static str),
41
42 /// A multi-line block body line equals `"END"`.
43 #[error("block body line equals \"END\", which would terminate the block early")]
44 EndInBlockBody,
45}
46
47/// Ensure a string is safe for the wire protocol.
48///
49/// In [`EncoderMode::Sanitize`]: strips `\n`, `\r`, and `\0`, returning
50/// the cleaned string (or borrowing the original if already clean).
51///
52/// In [`EncoderMode::Strict`]: returns `Err` if any unsafe characters
53/// are present.
54///
55/// ```
56/// use std::borrow::Cow;
57/// use openvpn_mgmt_frame::{wire_safe, EncoderMode};
58///
59/// // Clean input borrows the original.
60/// let clean = wire_safe("hello", "field", EncoderMode::Sanitize).unwrap();
61/// assert!(matches!(clean, Cow::Borrowed("hello")));
62///
63/// // Sanitize mode strips unsafe characters.
64/// let sanitized = wire_safe("line\none", "field", EncoderMode::Sanitize).unwrap();
65/// assert_eq!(&*sanitized, "lineone");
66///
67/// // Strict mode rejects unsafe characters.
68/// assert!(wire_safe("line\none", "field", EncoderMode::Strict).is_err());
69/// ```
70pub fn wire_safe<'a>(
71 s: &'a str,
72 field: &'static str,
73 mode: EncoderMode,
74) -> Result<Cow<'a, str>, io::Error> {
75 if !s.contains(WIRE_UNSAFE) {
76 return Ok(Cow::Borrowed(s));
77 }
78 match mode {
79 EncoderMode::Sanitize => Ok(Cow::Owned(
80 s.chars().filter(|chr| !WIRE_UNSAFE.contains(chr)).collect(),
81 )),
82 EncoderMode::Strict => Err(io::Error::other(EncodeError::UnsafeCharacters(field))),
83 }
84}
85
/// Escape a string for the OpenVPN config-file lexer.
///
/// Every backslash (`\`) and double quote (`"`) gets a backslash in front
/// of it; all other characters pass through unchanged. The result is
/// intended to be wrapped with [`quote`].
///
/// ```
/// use openvpn_mgmt_frame::escape;
///
/// assert_eq!(escape("hello"), "hello");
/// assert_eq!(escape(r#"pass"word"#), r#"pass\"word"#);
/// assert_eq!(escape(r"back\slash"), r"back\\slash");
/// ```
pub fn escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    s.chars().for_each(|ch| {
        // The lexer treats these two as special; prefix them with `\`.
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    });
    escaped
}
106
/// Surround `s` with double quotes.
///
/// `s` is inserted as-is, so it should already have been passed through
/// [`escape`]. No escaping happens here.
///
/// ```
/// use openvpn_mgmt_frame::{escape, quote};
///
/// let escaped = escape(r#"my "password""#);
/// assert_eq!(quote(&escaped), r#""my \"password\"""#);
/// ```
pub fn quote(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    quoted.push_str(s);
    quoted.push('"');
    quoted
}
118
119/// Write a single line followed by `\n`.
120///
121/// ```
122/// use bytes::BytesMut;
123/// use openvpn_mgmt_frame::write_line;
124///
125/// let mut buf = BytesMut::new();
126/// write_line(&mut buf, "hold release");
127/// assert_eq!(&buf[..], b"hold release\n");
128/// ```
129pub fn write_line(dst: &mut BytesMut, s: &str) {
130 dst.reserve(s.len() + 1);
131 dst.put_slice(s.as_bytes());
132 dst.put_u8(b'\n');
133}
134
135/// Write a multi-line block: header line, body lines, and a terminating
136/// `END`.
137///
138/// In [`EncoderMode::Sanitize`] mode, body lines have `\n`, `\r`, and
139/// `\0` stripped, and any line that would be exactly `"END"` is escaped
140/// to `" END"`.
141///
142/// In [`EncoderMode::Strict`] mode, body lines containing unsafe
143/// characters or equaling `"END"` cause an error.
144///
145/// ```
146/// use bytes::BytesMut;
147/// use openvpn_mgmt_frame::{write_block, EncoderMode};
148///
149/// let mut buf = BytesMut::new();
150/// let body = vec!["push \"route 10.0.0.0 255.0.0.0\"".to_string()];
151/// write_block(&mut buf, "client-auth 5 7", &body, EncoderMode::Sanitize).unwrap();
152/// assert_eq!(
153/// &buf[..],
154/// b"client-auth 5 7\npush \"route 10.0.0.0 255.0.0.0\"\nEND\n",
155/// );
156/// ```
157pub fn write_block(
158 dst: &mut BytesMut,
159 header: &str,
160 lines: &[String],
161 mode: EncoderMode,
162) -> Result<(), io::Error> {
163 let total: usize =
164 header.len() + 1 + lines.iter().map(|line| line.len() + 2).sum::<usize>() + 4;
165 dst.reserve(total);
166 dst.put_slice(header.as_bytes());
167 dst.put_u8(b'\n');
168 for line in lines {
169 let clean = wire_safe(line, "block body line", mode)?;
170 if *clean == *"END" {
171 match mode {
172 EncoderMode::Sanitize => {
173 dst.put_slice(b" END");
174 dst.put_u8(b'\n');
175 continue;
176 }
177 EncoderMode::Strict => {
178 return Err(io::Error::other(EncodeError::EndInBlockBody));
179 }
180 }
181 }
182 dst.put_slice(clean.as_bytes());
183 dst.put_u8(b'\n');
184 }
185 dst.put_slice(b"END\n");
186 Ok(())
187}
188
/// Cap on how many items the decoder will collect before it gives up and
/// returns an error.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccumulationLimit {
    /// Collect without any cap.
    Unlimited,

    /// Collect no more than the contained number of items.
    Max(usize),
}