Skip to main content

netconf_rust/
hello.rs

1use std::io::Cursor;
2
3use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
4use quick_xml::{Reader, Writer};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use crate::capabilities;
8use crate::codec::{self, FramingMode};
9
/// The server's `<hello>` message, reduced to the fields the client needs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerHello {
    /// Session identifier taken from the `<session-id>` element.
    pub session_id: u32,
    /// Capability URIs advertised by the server, in the order they appeared.
    pub capabilities: Vec<String>,
}
15
16pub async fn exchange<S: AsyncRead + AsyncWrite + Unpin>(
17    stream: &mut S,
18    max_message_size: Option<usize>,
19) -> crate::Result<(ServerHello, FramingMode)> {
20    let server_hello_xml = codec::read_eom_message(stream, max_message_size).await?;
21    let server_hello = parse_server_hello(&server_hello_xml)?;
22
23    let client_hello_xml = build_client_hello();
24    codec::write_eom_message(stream, &client_hello_xml).await?;
25
26    let framing_mode = if server_hello
27        .capabilities
28        .iter()
29        .any(|c| c == capabilities::BASE_1_1)
30    {
31        FramingMode::Chunked
32    } else {
33        FramingMode::EndOfMessage
34    };
35    Ok((server_hello, framing_mode))
36}
37
// Plain string formatting would have been simpler here, but this deliberately exercises the
// quick-xml writer API; the client hello rarely changes, so the extra verbosity costs little.
40fn build_client_hello() -> String {
41    let mut writer = Writer::new_with_indent(Cursor::new(Vec::new()), b' ', 2);
42
43    writer
44        .write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))
45        .unwrap();
46
47    let mut hello = BytesStart::new("hello");
48    hello.push_attribute(("xmlns", "urn:ietf:params:xml:ns:netconf:base:1.0"));
49    writer.write_event(Event::Start(hello)).unwrap();
50
51    writer
52        .write_event(Event::Start(BytesStart::new("capabilities")))
53        .unwrap();
54
55    for cap in capabilities::client_capabilities() {
56        writer
57            .create_element("capability")
58            .write_text_content(BytesText::new(cap))
59            .unwrap();
60    }
61
62    writer
63        .write_event(Event::End(BytesEnd::new("capabilities")))
64        .unwrap();
65    writer
66        .write_event(Event::End(BytesEnd::new("hello")))
67        .unwrap();
68    String::from_utf8(writer.into_inner().into_inner()).unwrap()
69}
70
71pub(crate) fn parse_server_hello(xml: &str) -> crate::Result<ServerHello> {
72    let mut reader = Reader::from_str(xml);
73    reader.config_mut().trim_text(true);
74
75    let mut capabilities = Vec::new();
76    let mut session_id: Option<u32> = None;
77
78    // use these to track with section of the response we are in
79    let mut in_capability = false;
80    let mut in_session_id = false;
81
82    loop {
83        match reader.read_event() {
84            Ok(Event::Start(e)) | Ok(Event::Empty(e)) => {
85                let local_name = e.local_name();
86                match local_name.as_ref() {
87                    b"capability" => in_capability = true,
88                    b"session-id" => in_session_id = true,
89                    _ => {}
90                }
91            }
92            Ok(Event::Text(e)) => {
93                let text = e
94                    .xml_content()
95                    .map_err(|err| {
96                        crate::error::Error::HelloFailed(format!("XML unescape error: {err}"))
97                    })?
98                    .to_string();
99
100                if in_capability {
101                    if !text.is_empty() {
102                        capabilities.push(text);
103                    }
104                } else if in_session_id {
105                    session_id = Some(text.parse::<u32>().map_err(|e| {
106                        crate::Error::HelloFailed(format!("invalid session-id '{text}': {e}"))
107                    })?);
108                }
109            }
110            Ok(Event::End(e)) => {
111                let local_name = e.local_name();
112                match local_name.as_ref() {
113                    b"capability" => in_capability = false,
114                    b"session-id" => in_session_id = false,
115                    _ => {}
116                }
117            }
118            Ok(Event::Eof) => break,
119            Err(e) => return Err(crate::Error::HelloFailed(format!("XML parse error: {e}"))),
120            _ => {}
121        }
122    }
123
124    let session_id =
125        session_id.ok_or_else(|| crate::Error::HelloFailed("missing session-id".into()))?;
126
127    if capabilities.is_empty() {
128        return Err(crate::Error::HelloFailed("no capabilities found".into()));
129    }
130    Ok(ServerHello {
131        session_id,
132        capabilities,
133    })
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    const URN_BASE_1_0: &str = "urn:ietf:params:netconf:base:1.0";
    const URN_BASE_1_1: &str = "urn:ietf:params:netconf:base:1.1";

    /// Whether `caps` contains exactly the capability URI `uri`.
    fn has_cap(caps: &[String], uri: &str) -> bool {
        caps.iter().any(|cap| cap == uri)
    }

    /// Runs `exchange` against an in-memory peer.
    ///
    /// The peer sends `server_hello` followed by the `]]>]]>` delimiter, then
    /// performs a single read of the client's reply.
    async fn exchange_against(server_hello: &'static str) -> (ServerHello, FramingMode) {
        let (mut client, mut server) = duplex(4096);

        tokio::spawn(async move {
            server
                .write_all(format!("{server_hello}]]>]]>").as_bytes())
                .await
                .unwrap();
            let mut reply = vec![0u8; 4096];
            let _ = server.read(&mut reply).await;
        });

        exchange(&mut client, None).await.unwrap()
    }

    #[test]
    fn test_parse_server_hello() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
    <capability>urn:ietf:params:netconf:capability:candidate:1.0</capability>
  </capabilities>
  <session-id>42</session-id>
</hello>"#;
        let parsed = parse_server_hello(xml).unwrap();
        assert_eq!(parsed.session_id, 42);
        assert_eq!(parsed.capabilities.len(), 3);
        assert!(has_cap(&parsed.capabilities, URN_BASE_1_0));
        assert!(has_cap(&parsed.capabilities, URN_BASE_1_1));
    }

    #[test]
    fn test_parse_hello_1_0_only() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>1</session-id>
</hello>"#;
        let parsed = parse_server_hello(xml).unwrap();
        assert_eq!(parsed.session_id, 1);
        assert_eq!(parsed.capabilities.len(), 1);
    }

    #[test]
    fn test_parse_hello_missing_session_id() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
</hello>"#;
        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_parse_hello_no_capabilities() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities/>
  <session-id>1</session-id>
</hello>"#;
        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_build_client_hello() {
        let hello = build_client_hello();
        for fragment in [
            URN_BASE_1_0,
            URN_BASE_1_1,
            "<hello",
            "</hello>",
            "<capability>",
            "</capability>",
        ] {
            assert!(hello.contains(fragment), "client hello is missing {fragment:?}");
        }
    }

    #[test]
    fn test_build_client_hello_is_valid_xml() {
        let hello = build_client_hello();
        // Walk the generated document; any parse error fails the test.
        let mut reader = Reader::from_str(&hello);
        reader.config_mut().trim_text(true);
        let mut caps = Vec::new();
        let mut inside_capability = false;
        loop {
            match reader.read_event() {
                Ok(Event::Start(e)) if e.local_name().as_ref() == b"capability" => {
                    inside_capability = true;
                }
                Ok(Event::Text(e)) if inside_capability => {
                    let text = e
                        .xml_content()
                        .unwrap_or_else(|err| panic!("undecodable capability text: {err}"));
                    caps.push(text.to_string());
                }
                Ok(Event::End(e)) if e.local_name().as_ref() == b"capability" => {
                    inside_capability = false;
                }
                Ok(Event::Eof) => break,
                Err(e) => panic!("invalid XML: {e}"),
                _ => {}
            }
        }
        assert!(!caps.is_empty());
        assert!(has_cap(&caps, URN_BASE_1_0));
        assert!(has_cap(&caps, URN_BASE_1_1));
    }

    #[tokio::test]
    async fn test_exchange_negotiates_chunked() {
        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
  </capabilities>
  <session-id>99</session-id>
</hello>"#;

        let (hello, mode) = exchange_against(server_hello).await;
        assert_eq!(hello.session_id, 99);
        assert_eq!(mode, FramingMode::Chunked);
    }

    #[tokio::test]
    async fn test_exchange_negotiates_eom() {
        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>10</session-id>
</hello>"#;

        let (hello, mode) = exchange_against(server_hello).await;
        assert_eq!(hello.session_id, 10);
        assert_eq!(mode, FramingMode::EndOfMessage);
    }
}