1use std::io::Cursor;
2
3use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
4use quick_xml::{Reader, Writer};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use crate::capabilities;
8use crate::codec::{self, FramingMode};
9
/// Information extracted from the server's NETCONF `<hello>` message.
///
/// Values are produced by `parse_server_hello`, which requires a
/// `<session-id>` and at least one `<capability>` to be present.
#[derive(Debug)]
pub struct ServerHello {
    /// Numeric value of the `<session-id>` element.
    pub session_id: u32,
    /// Text of every `<capability>` element, in document order.
    pub capabilities: Vec<String>,
}
15
16pub async fn exchange<S: AsyncRead + AsyncWrite + Unpin>(
17 stream: &mut S,
18 max_message_size: Option<usize>,
19) -> crate::Result<(ServerHello, FramingMode)> {
20 let server_hello_xml = codec::read_eom_message(stream, max_message_size).await?;
21 let server_hello = parse_server_hello(&server_hello_xml)?;
22
23 let client_hello_xml = build_client_hello();
24 codec::write_eom_message(stream, &client_hello_xml).await?;
25
26 let framing_mode = if server_hello
27 .capabilities
28 .iter()
29 .any(|c| c == capabilities::BASE_1_1)
30 {
31 FramingMode::Chunked
32 } else {
33 FramingMode::EndOfMessage
34 };
35 Ok((server_hello, framing_mode))
36}
37
38fn build_client_hello() -> String {
41 let mut writer = Writer::new_with_indent(Cursor::new(Vec::new()), b' ', 2);
42
43 writer
44 .write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))
45 .unwrap();
46
47 let mut hello = BytesStart::new("hello");
48 hello.push_attribute(("xmlns", "urn:ietf:params:xml:ns:netconf:base:1.0"));
49 writer.write_event(Event::Start(hello)).unwrap();
50
51 writer
52 .write_event(Event::Start(BytesStart::new("capabilities")))
53 .unwrap();
54
55 for cap in capabilities::client_capabilities() {
56 writer
57 .create_element("capability")
58 .write_text_content(BytesText::new(cap))
59 .unwrap();
60 }
61
62 writer
63 .write_event(Event::End(BytesEnd::new("capabilities")))
64 .unwrap();
65 writer
66 .write_event(Event::End(BytesEnd::new("hello")))
67 .unwrap();
68 String::from_utf8(writer.into_inner().into_inner()).unwrap()
69}
70
71pub(crate) fn parse_server_hello(xml: &str) -> crate::Result<ServerHello> {
72 let mut reader = Reader::from_str(xml);
73 reader.config_mut().trim_text(true);
74
75 let mut capabilities = Vec::new();
76 let mut session_id: Option<u32> = None;
77
78 let mut in_capability = false;
80 let mut in_session_id = false;
81
82 loop {
83 match reader.read_event() {
84 Ok(Event::Start(e)) | Ok(Event::Empty(e)) => {
85 let local_name = e.local_name();
86 match local_name.as_ref() {
87 b"capability" => in_capability = true,
88 b"session-id" => in_session_id = true,
89 _ => {}
90 }
91 }
92 Ok(Event::Text(e)) => {
93 let text = e
94 .xml_content()
95 .map_err(|err| {
96 crate::error::Error::HelloFailed(format!("XML unescape error: {err}"))
97 })?
98 .to_string();
99
100 if in_capability {
101 if !text.is_empty() {
102 capabilities.push(text);
103 }
104 } else if in_session_id {
105 session_id = Some(text.parse::<u32>().map_err(|e| {
106 crate::Error::HelloFailed(format!("invalid session-id '{text}': {e}"))
107 })?);
108 }
109 }
110 Ok(Event::End(e)) => {
111 let local_name = e.local_name();
112 match local_name.as_ref() {
113 b"capability" => in_capability = false,
114 b"session-id" => in_session_id = false,
115 _ => {}
116 }
117 }
118 Ok(Event::Eof) => break,
119 Err(e) => return Err(crate::Error::HelloFailed(format!("XML parse error: {e}"))),
120 _ => {}
121 }
122 }
123
124 let session_id =
125 session_id.ok_or_else(|| crate::Error::HelloFailed("missing session-id".into()))?;
126
127 if capabilities.is_empty() {
128 return Err(crate::Error::HelloFailed("no capabilities found".into()));
129 }
130 Ok(ServerHello {
131 session_id,
132 capabilities,
133 })
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Capability URIs checked by several tests.
    const BASE_1_0: &str = "urn:ietf:params:netconf:base:1.0";
    const BASE_1_1: &str = "urn:ietf:params:netconf:base:1.1";

    /// Runs `exchange` against a mock peer connected through an in-memory
    /// duplex pipe.
    ///
    /// The peer sends `server_hello` followed by the end-of-message delimiter
    /// and then performs a single read to take the client's hello off the pipe.
    async fn run_exchange(server_hello: &'static str) -> (ServerHello, FramingMode) {
        use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

        let (mut client, mut server) = duplex(4096);

        tokio::spawn(async move {
            server
                .write_all(format!("{server_hello}]]>]]>").as_bytes())
                .await
                .unwrap();
            let mut buf = vec![0u8; 4096];
            let _ = server.read(&mut buf).await;
        });

        exchange(&mut client, None).await.unwrap()
    }

    #[test]
    fn test_parse_server_hello() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities>
 <capability>urn:ietf:params:netconf:base:1.0</capability>
 <capability>urn:ietf:params:netconf:base:1.1</capability>
 <capability>urn:ietf:params:netconf:capability:candidate:1.0</capability>
 </capabilities>
 <session-id>42</session-id>
</hello>"#;

        let parsed = parse_server_hello(xml).unwrap();

        assert_eq!(parsed.session_id, 42);
        assert_eq!(parsed.capabilities.len(), 3);
        assert!(parsed.capabilities.iter().any(|cap| cap == BASE_1_0));
        assert!(parsed.capabilities.iter().any(|cap| cap == BASE_1_1));
    }

    #[test]
    fn test_parse_hello_1_0_only() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities>
 <capability>urn:ietf:params:netconf:base:1.0</capability>
 </capabilities>
 <session-id>1</session-id>
</hello>"#;

        let parsed = parse_server_hello(xml).unwrap();

        assert_eq!(parsed.session_id, 1);
        assert_eq!(parsed.capabilities.len(), 1);
    }

    #[test]
    fn test_parse_hello_missing_session_id() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities>
 <capability>urn:ietf:params:netconf:base:1.0</capability>
 </capabilities>
</hello>"#;

        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_parse_hello_no_capabilities() {
        let xml = r#"<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities/>
 <session-id>1</session-id>
</hello>"#;

        assert!(parse_server_hello(xml).is_err());
    }

    #[test]
    fn test_build_client_hello() {
        let hello = build_client_hello();

        let expected_fragments = [
            BASE_1_0,
            BASE_1_1,
            "<hello",
            "</hello>",
            "<capability>",
            "</capability>",
        ];
        for fragment in expected_fragments {
            assert!(hello.contains(fragment), "missing fragment {fragment:?}");
        }
    }

    #[test]
    fn test_build_client_hello_is_valid_xml() {
        let hello = build_client_hello();
        let mut reader = Reader::from_str(&hello);
        reader.config_mut().trim_text(true);

        // Re-read the generated document and collect every capability text.
        let mut found = Vec::new();
        let mut inside_capability = false;
        loop {
            let event = match reader.read_event() {
                Ok(event) => event,
                Err(e) => panic!("invalid XML: {e}"),
            };
            match event {
                Event::Eof => break,
                Event::Start(e) if e.local_name().as_ref() == b"capability" => {
                    inside_capability = true;
                }
                Event::End(e) if e.local_name().as_ref() == b"capability" => {
                    inside_capability = false;
                }
                Event::Text(e) if inside_capability => {
                    let content = e.xml_content().map_err(|e| e.to_string()).unwrap();
                    found.push(content.to_string());
                }
                _ => {}
            }
        }

        assert!(!found.is_empty());
        assert!(found.iter().any(|cap| cap == BASE_1_0));
        assert!(found.iter().any(|cap| cap == BASE_1_1));
    }

    #[tokio::test]
    async fn test_exchange_negotiates_chunked() {
        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities>
 <capability>urn:ietf:params:netconf:base:1.0</capability>
 <capability>urn:ietf:params:netconf:base:1.1</capability>
 </capabilities>
 <session-id>99</session-id>
</hello>"#;

        let (hello, mode) = run_exchange(server_hello).await;

        assert_eq!(hello.session_id, 99);
        assert_eq!(mode, FramingMode::Chunked);
    }

    #[tokio::test]
    async fn test_exchange_negotiates_eom() {
        let server_hello = r#"<?xml version="1.0" encoding="UTF-8"?>
<hello xmlns="urn:ietf:params:xml:ns:netconf:base:1.0">
 <capabilities>
 <capability>urn:ietf:params:netconf:base:1.0</capability>
 </capabilities>
 <session-id>10</session-id>
</hello>"#;

        let (hello, mode) = run_exchange(server_hello).await;

        assert_eq!(hello.session_id, 10);
        assert_eq!(mode, FramingMode::EndOfMessage);
    }
}