Skip to main content

openvpn_mgmt_frame/
encoder.rs

//! Encoder primitives for the OpenVPN management wire format.
//!
//! These are pure functions with no state — they validate and escape
//! strings, and serialize lines and multi-line blocks into a `BytesMut` buffer.
5
6use std::borrow::Cow;
7use std::io;
8
9use bytes::{BufMut, BytesMut};
10
/// The set of characters that must never reach the line-oriented
/// management protocol unescaped.
///
/// * `\n` and `\r` end a command line, so an embedded one would split a
///   single value into several commands.
/// * `\0` truncates the string at the C layer.
pub const WIRE_UNSAFE: &[char] = &['\n', '\r', '\0'];
14
/// Policy the encoder applies to characters that are unsafe for the
/// line-oriented management protocol (`\n`, `\r`, `\0`).
///
/// The default is [`EncoderMode::Sanitize`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EncoderMode {
    /// Quietly drop unsafe characters (the defensive default).
    ///
    /// Every `\n`, `\r`, and `\0` is deleted from user-supplied strings,
    /// and a block body line that is exactly `"END"` is rewritten as
    /// `" END"`.
    #[default]
    Sanitize,

    /// Refuse unsafe input instead of altering it.
    ///
    /// Encoding fails with `Err(io::Error)` when a field contains `\n`,
    /// `\r`, or `\0`, or when a block body line is exactly `"END"`. The
    /// wrapped error can be downcast to [`EncodeError`].
    Strict,
}
34
35/// Structured error for encoder-side validation failures.
36#[derive(Debug, thiserror::Error)]
37pub enum EncodeError {
38    /// A field contains `\n`, `\r`, or `\0`.
39    #[error("{0} contains characters unsafe for the management protocol (\\n, \\r, or \\0)")]
40    UnsafeCharacters(&'static str),
41
42    /// A multi-line block body line equals `"END"`.
43    #[error("block body line equals \"END\", which would terminate the block early")]
44    EndInBlockBody,
45}
46
47/// Ensure a string is safe for the wire protocol.
48///
49/// In [`EncoderMode::Sanitize`]: strips `\n`, `\r`, and `\0`, returning
50/// the cleaned string (or borrowing the original if already clean).
51///
52/// In [`EncoderMode::Strict`]: returns `Err` if any unsafe characters
53/// are present.
54///
55/// ```
56/// use std::borrow::Cow;
57/// use openvpn_mgmt_frame::{wire_safe, EncoderMode};
58///
59/// // Clean input borrows the original.
60/// let clean = wire_safe("hello", "field", EncoderMode::Sanitize).unwrap();
61/// assert!(matches!(clean, Cow::Borrowed("hello")));
62///
63/// // Sanitize mode strips unsafe characters.
64/// let sanitized = wire_safe("line\none", "field", EncoderMode::Sanitize).unwrap();
65/// assert_eq!(&*sanitized, "lineone");
66///
67/// // Strict mode rejects unsafe characters.
68/// assert!(wire_safe("line\none", "field", EncoderMode::Strict).is_err());
69/// ```
70pub fn wire_safe<'a>(
71    s: &'a str,
72    field: &'static str,
73    mode: EncoderMode,
74) -> Result<Cow<'a, str>, io::Error> {
75    if !s.contains(WIRE_UNSAFE) {
76        return Ok(Cow::Borrowed(s));
77    }
78    match mode {
79        EncoderMode::Sanitize => Ok(Cow::Owned(
80            s.chars().filter(|chr| !WIRE_UNSAFE.contains(chr)).collect(),
81        )),
82        EncoderMode::Strict => Err(io::Error::other(EncodeError::UnsafeCharacters(field))),
83    }
84}
85
/// Escape `s` for use inside a double-quoted management-protocol argument.
///
/// Each backslash (`\`) and double quote (`"`) gets a backslash in front
/// of it, following the OpenVPN config-file lexer rules. All other
/// characters are copied through unchanged. The result is not quoted;
/// wrap it with [`quote`] when needed.
///
/// ```
/// use openvpn_mgmt_frame::escape;
///
/// assert_eq!(escape("hello"), "hello");
/// assert_eq!(escape(r#"pass"word"#), r#"pass\"word"#);
/// assert_eq!(escape(r"back\slash"), r"back\\slash");
/// ```
pub fn escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"') {
            // Prefix the special character with a single backslash.
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
106
/// Surround `s` with double quotes.
///
/// `s` is inserted as-is, so it should already have been passed through
/// [`escape`]. This function does not escape anything itself.
///
/// ```
/// use openvpn_mgmt_frame::{escape, quote};
///
/// let escaped = escape(r#"my "password""#);
/// assert_eq!(quote(&escaped), r#""my \"password\"""#);
/// ```
pub fn quote(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    quoted.push_str(s);
    quoted.push('"');
    quoted
}
118
119/// Write a single line followed by `\n`.
120///
121/// ```
122/// use bytes::BytesMut;
123/// use openvpn_mgmt_frame::write_line;
124///
125/// let mut buf = BytesMut::new();
126/// write_line(&mut buf, "hold release");
127/// assert_eq!(&buf[..], b"hold release\n");
128/// ```
129pub fn write_line(dst: &mut BytesMut, s: &str) {
130    dst.reserve(s.len() + 1);
131    dst.put_slice(s.as_bytes());
132    dst.put_u8(b'\n');
133}
134
135/// Write a multi-line block: header line, body lines, and a terminating
136/// `END`.
137///
138/// In [`EncoderMode::Sanitize`] mode, body lines have `\n`, `\r`, and
139/// `\0` stripped, and any line that would be exactly `"END"` is escaped
140/// to `" END"`.
141///
142/// In [`EncoderMode::Strict`] mode, body lines containing unsafe
143/// characters or equaling `"END"` cause an error.
144///
145/// ```
146/// use bytes::BytesMut;
147/// use openvpn_mgmt_frame::{write_block, EncoderMode};
148///
149/// let mut buf = BytesMut::new();
150/// let body = vec!["push \"route 10.0.0.0 255.0.0.0\"".to_string()];
151/// write_block(&mut buf, "client-auth 5 7", &body, EncoderMode::Sanitize).unwrap();
152/// assert_eq!(
153///     &buf[..],
154///     b"client-auth 5 7\npush \"route 10.0.0.0 255.0.0.0\"\nEND\n",
155/// );
156/// ```
157pub fn write_block(
158    dst: &mut BytesMut,
159    header: &str,
160    lines: &[String],
161    mode: EncoderMode,
162) -> Result<(), io::Error> {
163    let total: usize =
164        header.len() + 1 + lines.iter().map(|line| line.len() + 2).sum::<usize>() + 4;
165    dst.reserve(total);
166    dst.put_slice(header.as_bytes());
167    dst.put_u8(b'\n');
168    for line in lines {
169        let clean = wire_safe(line, "block body line", mode)?;
170        if *clean == *"END" {
171            match mode {
172                EncoderMode::Sanitize => {
173                    dst.put_slice(b" END");
174                    dst.put_u8(b'\n');
175                    continue;
176                }
177                EncoderMode::Strict => {
178                    return Err(io::Error::other(EncodeError::EndInBlockBody));
179                }
180            }
181        }
182        dst.put_slice(clean.as_bytes());
183        dst.put_u8(b'\n');
184    }
185    dst.put_slice(b"END\n");
186    Ok(())
187}
188
/// Upper bound on how many items the decoder collects before it reports
/// an error.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccumulationLimit {
    /// Accept any number of items.
    Unlimited,

    /// Accept no more than the given number of items.
    Max(usize),
}