1use std::io::Cursor;
2
3use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
4use quick_xml::{Reader, Writer};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use crate::capabilities;
8use crate::codec::{self, FramingMode};
9
/// Contents of the server's NETCONF `<hello>` message as returned by `parse_server_hello`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerHello {
    /// Value of the `<session-id>` element.
    pub session_id: u32,
    /// Text of each non-empty `<capability>` element, in document order.
    pub capabilities: Vec<String>,
}
15
16pub async fn exchange<S: AsyncRead + AsyncWrite + Unpin>(
17 stream: &mut S,
18 max_message_size: Option<usize>,
19) -> crate::Result<(ServerHello, FramingMode)> {
20 let server_hello_xml = codec::read_eom_message(stream, max_message_size).await?;
21 let server_hello = parse_server_hello(&server_hello_xml)?;
22
23 let client_hello_xml = build_client_hello();
24 codec::write_eom_message(stream, &client_hello_xml).await?;
25
26 let client_caps = capabilities::client_capabilities();
30 let framing_mode = if server_hello
31 .capabilities
32 .iter()
33 .any(|c| c == capabilities::BASE_1_1)
34 && client_caps.contains(&capabilities::BASE_1_1)
35 {
36 FramingMode::Chunked
37 } else {
38 FramingMode::EndOfMessage
39 };
40 Ok((server_hello, framing_mode))
41}
42
43fn build_client_hello() -> String {
46 let mut writer = Writer::new_with_indent(Cursor::new(Vec::new()), b' ', 2);
47
48 writer
49 .write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))
50 .unwrap();
51
52 let mut hello = BytesStart::new("hello");
53 hello.push_attribute(("xmlns", "urn:ietf:params:xml:ns:netconf:base:1.0"));
54 writer.write_event(Event::Start(hello)).unwrap();
55
56 writer
57 .write_event(Event::Start(BytesStart::new("capabilities")))
58 .unwrap();
59
60 for cap in capabilities::client_capabilities() {
61 writer
62 .create_element("capability")
63 .write_text_content(BytesText::new(cap))
64 .unwrap();
65 }
66
67 writer
68 .write_event(Event::End(BytesEnd::new("capabilities")))
69 .unwrap();
70 writer
71 .write_event(Event::End(BytesEnd::new("hello")))
72 .unwrap();
73 String::from_utf8(writer.into_inner().into_inner()).unwrap()
74}
75
76pub(crate) fn parse_server_hello(xml: &str) -> crate::Result<ServerHello> {
77 let mut reader = Reader::from_str(xml);
78 reader.config_mut().trim_text(true);
79
80 let mut capabilities = Vec::new();
81 let mut session_id: Option<u32> = None;
82
83 let mut in_capability = false;
85 let mut in_session_id = false;
86
87 loop {
88 match reader.read_event() {
89 Ok(Event::Start(e)) | Ok(Event::Empty(e)) => {
90 let local_name = e.local_name();
91 match local_name.as_ref() {
92 b"capability" => in_capability = true,
93 b"session-id" => in_session_id = true,
94 _ => {}
95 }
96 }
97 Ok(Event::Text(e)) => {
98 let text = e
99 .xml_content()
100 .map_err(|err| {
101 crate::error::Error::HelloFailed(format!("XML unescape error: {err}"))
102 })?
103 .to_string();
104
105 if in_capability {
106 if !text.is_empty() {
107 capabilities.push(text);
108 }
109 } else if in_session_id {
110 session_id = Some(text.parse::<u32>().map_err(|e| {
111 crate::Error::HelloFailed(format!("invalid session-id '{text}': {e}"))
112 })?);
113 }
114 }
115 Ok(Event::End(e)) => {
116 let local_name = e.local_name();
117 match local_name.as_ref() {
118 b"capability" => in_capability = false,
119 b"session-id" => in_session_id = false,
120 _ => {}
121 }
122 }
123 Ok(Event::Eof) => break,
124 Err(e) => return Err(crate::Error::HelloFailed(format!("XML parse error: {e}"))),
125 _ => {}
126 }
127 }
128
129 let session_id =
130 session_id.ok_or_else(|| crate::Error::HelloFailed("missing session-id".into()))?;
131
132 if capabilities.is_empty() {
133 return Err(crate::Error::HelloFailed("no capabilities found".into()));
134 }
135 Ok(ServerHello {
136 session_id,
137 capabilities,
138 })
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `caps` holds exactly the URI `uri`.
    fn has_cap(caps: &[String], uri: &str) -> bool {
        caps.iter().any(|cap| cap == uri)
    }

    /// Drives `exchange` against an in-memory peer. The peer sends
    /// `server_hello` followed by `]]>]]>`, then does one read of the client's
    /// reply.
    async fn exchange_against(server_hello: &'static str) -> (ServerHello, FramingMode) {
        use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

        let (mut client, mut server) = duplex(4096);

        tokio::spawn(async move {
            server
                .write_all(format!("{server_hello}]]>]]>").as_bytes())
                .await
                .unwrap();
            let mut reply = vec![0u8; 4096];
            let _ = server.read(&mut reply).await;
        });

        exchange(&mut client, None).await.unwrap()
    }

    /// Text of every `<capability>` element in `xml`; panics on malformed XML.
    fn capability_texts(xml: &str) -> Vec<String> {
        let mut reader = Reader::from_str(xml);
        reader.config_mut().trim_text(true);

        let mut found = Vec::new();
        let mut inside = false;
        loop {
            match reader.read_event() {
                Ok(Event::Start(tag)) if tag.local_name().as_ref() == b"capability" => {
                    inside = true
                }
                Ok(Event::End(tag)) if tag.local_name().as_ref() == b"capability" => {
                    inside = false
                }
                Ok(Event::Text(txt)) if inside => {
                    let content = txt.xml_content().map_err(|err| err.to_string()).unwrap();
                    found.push(content.to_string());
                }
                Ok(Event::Eof) => break,
                Err(err) => panic!("invalid XML: {err}"),
                _ => {}
            }
        }
        found
    }

    #[test]
    fn test_parse_server_hello() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
    <capability>urn:ietf:params:netconf:capability:candidate:1.0</capability>
  </capabilities>
  <session-id>42</session-id>
</hello>"#;
        let parsed = parse_server_hello(xml).unwrap();
        assert_eq!(parsed.session_id, 42);
        assert_eq!(parsed.capabilities.len(), 3);
        for uri in [
            "urn:ietf:params:netconf:base:1.0",
            "urn:ietf:params:netconf:base:1.1",
        ] {
            assert!(has_cap(&parsed.capabilities, uri));
        }
    }

    #[test]
    fn test_parse_hello_1_0_only() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>1</session-id>
</hello>"#;
        let parsed = parse_server_hello(xml).unwrap();
        assert_eq!(parsed.session_id, 1);
        assert_eq!(parsed.capabilities.len(), 1);
    }

    #[test]
    fn test_parse_hello_missing_session_id() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
</hello>"#;
        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_parse_hello_no_capabilities() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities/>
  <session-id>1</session-id>
</hello>"#;
        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_build_client_hello() {
        let xml = build_client_hello();
        assert!(xml.contains("urn:ietf:params:netconf:base:1.0"));
        if capabilities::client_capabilities().contains(&capabilities::BASE_1_1) {
            assert!(xml.contains("urn:ietf:params:netconf:base:1.1"));
        }
        for fragment in ["<hello", "</hello>", "<capability>", "</capability>"] {
            assert!(xml.contains(fragment));
        }
    }

    #[test]
    fn test_build_client_hello_is_valid_xml() {
        let caps = capability_texts(&build_client_hello());
        assert!(!caps.is_empty());
        assert!(has_cap(&caps, "urn:ietf:params:netconf:base:1.0"));
        if capabilities::client_capabilities().contains(&capabilities::BASE_1_1) {
            assert!(has_cap(&caps, "urn:ietf:params:netconf:base:1.1"));
        }
    }

    #[tokio::test]
    async fn test_exchange_negotiates_chunked() {
        let (hello, mode) = exchange_against(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
    <capability>urn:ietf:params:netconf:base:1.1</capability>
  </capabilities>
  <session-id>99</session-id>
</hello>"#,
        )
        .await;
        assert_eq!(hello.session_id, 99);

        let expected = if capabilities::client_capabilities().contains(&capabilities::BASE_1_1) {
            FramingMode::Chunked
        } else {
            FramingMode::EndOfMessage
        };
        assert_eq!(mode, expected);
    }

    #[tokio::test]
    async fn test_exchange_negotiates_eom() {
        let (hello, mode) = exchange_against(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
  <capabilities>
    <capability>urn:ietf:params:netconf:base:1.0</capability>
  </capabilities>
  <session-id>10</session-id>
</hello>"#,
        )
        .await;
        assert_eq!(hello.session_id, 10);
        assert_eq!(mode, FramingMode::EndOfMessage);
    }
}