Skip to main content

netconf_rust/
hello.rs

1use std::io::Cursor;
2
3use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
4use quick_xml::{Reader, Writer};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use crate::capabilities;
8use crate::codec::{self, FramingMode};
9
/// Contents of the `<hello>` message received from a NETCONF server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerHello {
    /// Session identifier assigned by the server (the `<session-id>` element).
    pub session_id: u32,
    /// Capability URIs advertised by the server, in the order they appeared.
    pub capabilities: Vec<String>,
}
15
16pub async fn exchange<S: AsyncRead + AsyncWrite + Unpin>(
17    stream: &mut S,
18    max_message_size: Option<usize>,
19) -> crate::Result<(ServerHello, FramingMode)> {
20    let server_hello_xml = codec::read_eom_message(stream, max_message_size).await?;
21    let server_hello = parse_server_hello(&server_hello_xml)?;
22
23    let client_hello_xml = build_client_hello();
24    codec::write_eom_message(stream, &client_hello_xml).await?;
25
26    // RFC 6241 ยง8.1: chunked framing is used only when BOTH sides
27    // advertise base:1.1. When the client restricts itself to base:1.0
28    // (e.g. via NETCONF_BASE_VERSION=1.0), we must fall back to EOM.
29    let client_caps = capabilities::client_capabilities();
30    let framing_mode = if server_hello
31        .capabilities
32        .iter()
33        .any(|c| c == capabilities::BASE_1_1)
34        && client_caps.contains(&capabilities::BASE_1_1)
35    {
36        FramingMode::Chunked
37    } else {
38        FramingMode::EndOfMessage
39    };
40    Ok((server_hello, framing_mode))
41}
42
43// Definitely would have been way easier with basic string stuff but i wanted to try out the API
44// and this basically stays the same anyways.
45fn build_client_hello() -> String {
46    let mut writer = Writer::new_with_indent(Cursor::new(Vec::new()), b' ', 2);
47
48    writer
49        .write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))
50        .unwrap();
51
52    let mut hello = BytesStart::new("hello");
53    hello.push_attribute(("xmlns", "urn:ietf:params:xml:ns:netconf:base:1.0"));
54    writer.write_event(Event::Start(hello)).unwrap();
55
56    writer
57        .write_event(Event::Start(BytesStart::new("capabilities")))
58        .unwrap();
59
60    for cap in capabilities::client_capabilities() {
61        writer
62            .create_element("capability")
63            .write_text_content(BytesText::new(cap))
64            .unwrap();
65    }
66
67    writer
68        .write_event(Event::End(BytesEnd::new("capabilities")))
69        .unwrap();
70    writer
71        .write_event(Event::End(BytesEnd::new("hello")))
72        .unwrap();
73    String::from_utf8(writer.into_inner().into_inner()).unwrap()
74}
75
76pub(crate) fn parse_server_hello(xml: &str) -> crate::Result<ServerHello> {
77    let mut reader = Reader::from_str(xml);
78    reader.config_mut().trim_text(true);
79
80    let mut capabilities = Vec::new();
81    let mut session_id: Option<u32> = None;
82
83    // use these to track with section of the response we are in
84    let mut in_capability = false;
85    let mut in_session_id = false;
86
87    loop {
88        match reader.read_event() {
89            Ok(Event::Start(e)) | Ok(Event::Empty(e)) => {
90                let local_name = e.local_name();
91                match local_name.as_ref() {
92                    b"capability" => in_capability = true,
93                    b"session-id" => in_session_id = true,
94                    _ => {}
95                }
96            }
97            Ok(Event::Text(e)) => {
98                let text = e
99                    .xml_content()
100                    .map_err(|err| {
101                        crate::error::Error::HelloFailed(format!("XML unescape error: {err}"))
102                    })?
103                    .to_string();
104
105                if in_capability {
106                    if !text.is_empty() {
107                        capabilities.push(text);
108                    }
109                } else if in_session_id {
110                    session_id = Some(text.parse::<u32>().map_err(|e| {
111                        crate::Error::HelloFailed(format!("invalid session-id '{text}': {e}"))
112                    })?);
113                }
114            }
115            Ok(Event::End(e)) => {
116                let local_name = e.local_name();
117                match local_name.as_ref() {
118                    b"capability" => in_capability = false,
119                    b"session-id" => in_session_id = false,
120                    _ => {}
121                }
122            }
123            Ok(Event::Eof) => break,
124            Err(e) => return Err(crate::Error::HelloFailed(format!("XML parse error: {e}"))),
125            _ => {}
126        }
127    }
128
129    let session_id =
130        session_id.ok_or_else(|| crate::Error::HelloFailed("missing session-id".into()))?;
131
132    if capabilities.is_empty() {
133        return Err(crate::Error::HelloFailed("no capabilities found".into()));
134    }
135    Ok(ServerHello {
136        session_id,
137        capabilities,
138    })
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a task that plays the server side of the hello exchange.
    ///
    /// The task sends `hello_xml` followed by the `]]>]]>` end-of-message marker,
    /// then reads once. The read keeps the server half of the duplex open until the
    /// client has written its own hello; writes to a tokio `DuplexStream` fail once
    /// the other half has been dropped.
    fn spawn_server_hello(mut server: tokio::io::DuplexStream, hello_xml: &'static str) {
        tokio::spawn(async move {
            use tokio::io::{AsyncReadExt, AsyncWriteExt};

            server
                .write_all(format!("{hello_xml}]]>]]>").as_bytes())
                .await
                .unwrap();
            let mut buf = vec![0u8; 4096];
            let _ = server.read(&mut buf).await;
        });
    }

    #[test]
    fn test_parse_server_hello() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
    <capability>urn:ietf:params:netconf:capability:candidate:1.0</capability>
  </capabilities>
  <session-id>42</session-id>
</hello>"#;
        let hello = parse_server_hello(xml).unwrap();
        assert_eq!(hello.session_id, 42);
        assert_eq!(hello.capabilities.len(), 3);
        assert!(
            hello
                .capabilities
                .contains(&"urn:ietf:params:netconf:base:1.0".to_string())
        );
        assert!(
            hello
                .capabilities
                .contains(&"urn:ietf:params:netconf:base:1.1".to_string())
        );
    }

    #[test]
    fn test_parse_hello_1_0_only() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>1</session-id>
</hello>"#;
        let hello = parse_server_hello(xml).unwrap();
        assert_eq!(hello.session_id, 1);
        assert_eq!(hello.capabilities.len(), 1);
    }

    #[test]
    fn test_parse_hello_missing_session_id() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
</hello>"#;
        let result = parse_server_hello(xml);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_hello_no_capabilities() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities/>
  <session-id>1</session-id>
</hello>"#;
        let result = parse_server_hello(xml);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_client_hello() {
        let hello = build_client_hello();
        assert!(hello.contains("urn:ietf:params:netconf:base:1.0"));
        // base:1.1 is included unless NETCONF_BASE_VERSION=1.0
        let expected_caps = capabilities::client_capabilities();
        if expected_caps.contains(&capabilities::BASE_1_1) {
            assert!(hello.contains("urn:ietf:params:netconf:base:1.1"));
        }
        assert!(hello.contains("<hello"));
        assert!(hello.contains("</hello>"));
        assert!(hello.contains("<capability>"));
        assert!(hello.contains("</capability>"));
    }

    #[test]
    fn test_build_client_hello_is_valid_xml() {
        let hello = build_client_hello();
        // Verify it parses without error and collect the advertised capabilities.
        let mut reader = Reader::from_str(&hello);
        reader.config_mut().trim_text(true);
        let mut caps = Vec::new();
        let mut in_cap = false;
        loop {
            match reader.read_event() {
                Ok(Event::Start(e)) if e.local_name().as_ref() == b"capability" => in_cap = true,
                Ok(Event::Text(e)) if in_cap => caps.push(
                    e.xml_content()
                        .expect("capability text should decode")
                        .into_owned(),
                ),
                Ok(Event::End(e)) if e.local_name().as_ref() == b"capability" => in_cap = false,
                Ok(Event::Eof) => break,
                Err(e) => panic!("invalid XML: {e}"),
                _ => {}
            }
        }
        assert!(!caps.is_empty());
        assert!(caps.contains(&"urn:ietf:params:netconf:base:1.0".to_string()));
        let expected_caps = capabilities::client_capabilities();
        if expected_caps.contains(&capabilities::BASE_1_1) {
            assert!(caps.contains(&"urn:ietf:params:netconf:base:1.1".to_string()));
        }
    }

    #[tokio::test]
    async fn test_exchange_negotiates_chunked() {
        use tokio::io::duplex;

        let (mut client, server) = duplex(4096);

        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
  </capabilities>
  <session-id>99</session-id>
</hello>"#;
        spawn_server_hello(server, server_hello);

        let (hello, mode) = exchange(&mut client, None).await.unwrap();
        assert_eq!(hello.session_id, 99);
        // Chunked only if the client also advertises base:1.1
        let client_caps = capabilities::client_capabilities();
        if client_caps.contains(&capabilities::BASE_1_1) {
            assert_eq!(mode, FramingMode::Chunked);
        } else {
            assert_eq!(mode, FramingMode::EndOfMessage);
        }
    }

    #[tokio::test]
    async fn test_exchange_negotiates_eom() {
        use tokio::io::duplex;

        let (mut client, server) = duplex(4096);

        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>10</session-id>
</hello>"#;
        spawn_server_hello(server, server_hello);

        let (hello, mode) = exchange(&mut client, None).await.unwrap();
        assert_eq!(hello.session_id, 10);
        assert_eq!(mode, FramingMode::EndOfMessage);
    }
}