1use std::io::Write;
2use std::convert::{TryInto, TryFrom};
3
4use crate::errors::tool::ToolError;
5use crate::spec_util::validate_tag_path;
6
7use super::tag_iterator_util::EBMLSize::{self, Known, Unknown};
8
9use super::tools::{Vint, is_vint};
10use super::specs::{EbmlSpecification, EbmlTag, TagDataType, Master};
11
12use super::errors::tag_writer::TagWriterError;
13
/// Per-write options accepted by `TagWriter::write_advanced`.
///
/// Build one with `WriteOptions::set_size_byte_count` (force the width of the size vint) or
/// `WriteOptions::is_unknown_sized_element` (write a master element with an unknown size).
/// The type is `Copy`, so a single value can be reused for any number of writes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WriteOptions {
    // Forced number of bytes for the element's size vint; `None` lets the writer pick the minimum.
    size_byte_length: Option<usize>,
    // When true, the element is written with the "unknown size" marker (master elements only).
    unknown_sized_element: bool,
}
22
23impl WriteOptions {
24 pub fn set_size_byte_count(len: usize) -> Self {
34 assert!(len > 0 && len < 9, "Size byte count for written vints must be within 1-8 (inclusive)");
35 Self {
36 size_byte_length: Some(len),
37 unknown_sized_element: false
38 }
39 }
40
41 pub fn is_unknown_sized_element() -> Self {
47 Self {
48 size_byte_length: None,
49 unknown_sized_element: true
50 }
51 }
52}
53
54pub struct TagWriter<W: Write>
60{
61 dest: W,
62 open_tags: Vec<(u64, EBMLSize, usize)>,
63 working_buffer: Vec<u8>,
64}
65
66impl<W: Write> TagWriter<W>
67{
68 pub fn new(dest: W) -> Self {
74 TagWriter {
75 dest,
76 open_tags: Vec::new(),
77 working_buffer: Vec::new(),
78 }
79 }
80
81 pub fn into_inner(mut self) -> Result<W, TagWriterError> {
87 self.flush()?;
88 Ok(self.dest)
89 }
90
91 pub fn get_mut(&mut self) -> &mut W {
95 &mut self.dest
96 }
97
98 pub fn get_ref(&self) -> &W {
102 &self.dest
103 }
104
105 fn start_tag(&mut self, id: u64, size_length: usize) {
106 self.open_tags.push((id, Known(self.working_buffer.len()), size_length));
107 }
108
109 fn start_unknown_size_tag(&mut self, id: u64) {
110 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
111 self.working_buffer.extend_from_slice(&(u64::MAX >> 7).to_be_bytes());
112 self.open_tags.push((id, Unknown, 0));
113 }
114
115 fn end_tag(&mut self, id: u64) -> Result<(), TagWriterError> {
116 match self.open_tags.pop() {
117 Some(open_tag) => {
118 if open_tag.0 == id {
119 if let Known(start) = open_tag.1 {
120 let size: u64 = self.working_buffer.len()
121 .checked_sub(start).expect("overflow subtracting tag size from working buffer length")
122 .try_into().expect("couldn't convert usize to u64");
123
124 match open_tag.2 {
125 1 => { let size_vint = size.as_vint_with_length::<1>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
126 2 => { let size_vint = size.as_vint_with_length::<2>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
127 3 => { let size_vint = size.as_vint_with_length::<3>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
128 4 => { let size_vint = size.as_vint_with_length::<4>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
129 5 => { let size_vint = size.as_vint_with_length::<5>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
130 6 => { let size_vint = size.as_vint_with_length::<6>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
131 7 => { let size_vint = size.as_vint_with_length::<7>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
132 8 => { let size_vint = size.as_vint_with_length::<8>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
133 _ => { let size_vint = size.as_vint().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?; self.working_buffer.splice(start..start, open_tag.0.to_be_bytes().iter().skip_while(|&v| *v == 0u8).chain(size_vint.iter()).copied()); }
134 };
135 }
136 Ok(())
137 } else {
138 Err(TagWriterError::UnexpectedClosingTag { tag_id: id, expected_id: Some(open_tag.0) })
139 }
140 },
141 None => Err(TagWriterError::UnexpectedClosingTag { tag_id: id, expected_id: None })
142 }
143 }
144
145 fn private_flush(&mut self) -> Result<(), TagWriterError> {
146 self.dest.write_all(self.working_buffer.drain(..).as_slice()).map_err(|source| TagWriterError::WriteError { source })?;
147 self.dest.flush().map_err(|source| TagWriterError::WriteError { source })
148 }
149
150 fn write_unsigned_int_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &u64) -> Result<(), TagWriterError> {
151 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
152 let data = *data;
153
154 u8::try_from(data).map(|n| {
155 if SIZE_LENGTH == 0 {
156 self.working_buffer.push(0x81); self.working_buffer.extend_from_slice(&n.to_be_bytes());
158 } else {
159 self.working_buffer.extend_from_slice(&1u8.as_vint_with_length::<SIZE_LENGTH>()?);
160 self.working_buffer.extend_from_slice(&n.to_be_bytes());
161 }
162 Ok(())
163 })
164 .or_else(|_| u16::try_from(data).map(|n| {
165 if SIZE_LENGTH == 0 {
166 self.working_buffer.push(0x82); self.working_buffer.extend_from_slice(&n.to_be_bytes());
168 } else {
169 self.working_buffer.extend_from_slice(&2u8.as_vint_with_length::<SIZE_LENGTH>()?);
170 self.working_buffer.extend_from_slice(&n.to_be_bytes());
171 }
172 Ok(())
173 }))
174 .or_else(|_| u32::try_from(data).map(|n| {
175 if SIZE_LENGTH == 0 {
176 self.working_buffer.push(0x84); self.working_buffer.extend_from_slice(&n.to_be_bytes());
178 } else {
179 self.working_buffer.extend_from_slice(&4u8.as_vint_with_length::<SIZE_LENGTH>()?);
180 self.working_buffer.extend_from_slice(&n.to_be_bytes());
181 }
182 Ok(())
183 }))
184 .unwrap_or_else(|_| {
185 if SIZE_LENGTH == 0 {
186 self.working_buffer.push(0x88); self.working_buffer.extend_from_slice(&data.to_be_bytes());
188 } else {
189 self.working_buffer.extend_from_slice(&8u8.as_vint_with_length::<SIZE_LENGTH>()?);
190 self.working_buffer.extend_from_slice(&data.to_be_bytes());
191 }
192 Ok(())
193 }).map_err(|err: ToolError| TagWriterError::TagSizeError(err.to_string()))
194 }
195
196 fn write_signed_int_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &i64) -> Result<(), TagWriterError> {
197 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
198 let data = *data;
199 i8::try_from(data).map(|n| {
200 if SIZE_LENGTH == 0 {
201 self.working_buffer.push(0x81); self.working_buffer.extend_from_slice(&n.to_be_bytes());
203 } else {
204 self.working_buffer.extend_from_slice(&1u8.as_vint_with_length::<SIZE_LENGTH>()?);
205 self.working_buffer.extend_from_slice(&n.to_be_bytes());
206 }
207 Ok(())
208 })
209 .or_else(|_| i16::try_from(data).map(|n| {
210 if SIZE_LENGTH == 0 {
211 self.working_buffer.push(0x82); self.working_buffer.extend_from_slice(&n.to_be_bytes());
213 } else {
214 self.working_buffer.extend_from_slice(&2u8.as_vint_with_length::<SIZE_LENGTH>()?);
215 self.working_buffer.extend_from_slice(&n.to_be_bytes());
216 }
217 Ok(())
218 }))
219 .or_else(|_| i32::try_from(data).map(|n| {
220 if SIZE_LENGTH == 0 {
221 self.working_buffer.push(0x84); self.working_buffer.extend_from_slice(&n.to_be_bytes());
223 } else {
224 self.working_buffer.extend_from_slice(&4u8.as_vint_with_length::<SIZE_LENGTH>()?);
225 self.working_buffer.extend_from_slice(&n.to_be_bytes());
226 }
227 Ok(())
228 }))
229 .unwrap_or_else(|_| {
230 if SIZE_LENGTH == 0 {
231 self.working_buffer.push(0x88); self.working_buffer.extend_from_slice(&data.to_be_bytes());
233 } else {
234 self.working_buffer.extend_from_slice(&8u8.as_vint_with_length::<SIZE_LENGTH>()?);
235 self.working_buffer.extend_from_slice(&data.to_be_bytes());
236 }
237 Ok(())
238 }).map_err(|err: ToolError| TagWriterError::TagSizeError(err.to_string()))
239 }
240
241 fn write_utf8_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &str) -> Result<(), TagWriterError> {
242 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
243
244 let slice: &[u8] = data.as_bytes();
245 let size: u64 = slice.len().try_into().expect("couldn't convert usize to u64");
246 if SIZE_LENGTH == 0 {
247 let size_vint = size.as_vint().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?;
248 self.working_buffer.extend_from_slice(&size_vint);
249 } else {
250 let size_vint = size.as_vint_with_length::<SIZE_LENGTH>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?;
251 self.working_buffer.extend_from_slice(&size_vint);
252 };
253
254 self.working_buffer.extend_from_slice(slice);
255 Ok(())
256 }
257
258 fn write_binary_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &[u8]) -> Result<(), TagWriterError> {
259 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
260
261 let size: u64 = data.len().try_into().expect("couldn't convert usize to u64");
262 if SIZE_LENGTH == 0 {
263 let size_vint = size.as_vint().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?;
264 self.working_buffer.extend_from_slice(&size_vint);
265 } else {
266 let size_vint = size.as_vint_with_length::<SIZE_LENGTH>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?;
267 self.working_buffer.extend_from_slice(&size_vint);
268 }
269
270 self.working_buffer.extend_from_slice(data);
271 Ok(())
272 }
273
274 fn write_float_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &f64) -> Result<(), TagWriterError> {
275 self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&v| *v == 0u8));
276 if SIZE_LENGTH == 0 {
277 self.working_buffer.push(0x88); } else {
279 let size_vint = 8u8.as_vint_with_length::<SIZE_LENGTH>().map_err(|e| TagWriterError::TagSizeError(e.to_string()))?;
280 self.working_buffer.extend_from_slice(&size_vint);
281 }
282 self.working_buffer.extend_from_slice(&data.to_be_bytes());
283 Ok(())
284 }
285
286 pub fn write<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec) -> Result<(), TagWriterError> {
319 self.write_advanced(tag, WriteOptions { size_byte_length: None, unknown_sized_element: false })
320 }
321
322 pub fn write_advanced<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec, options: WriteOptions) -> Result<(), TagWriterError> {
336 let tag_id = tag.get_id();
337 let tag_type = TSpec::get_tag_data_type(tag_id);
338
339 if options.unknown_sized_element {
340 match tag_type {
341 Some(TagDataType::Master) => {},
342 _ => {
343 return Err(TagWriterError::TagSizeError(format!("Cannot write an unknown size for tag of type {tag_type:?}")))
344 }
345 };
346 self.start_unknown_size_tag(tag_id);
347 } else {
348 let should_validate = tag_type.is_some() && (!matches!(tag_type, Some(TagDataType::Master)) || !matches!(tag.as_master().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was master, but could not get tag!", tag_id)), Master::End));
349 if should_validate && !validate_tag_path::<TSpec>(tag_id, self.open_tags.iter().copied()) {
350 return Err(TagWriterError::UnexpectedTag { tag_id, current_path: self.open_tags.iter().map(|t| t.0).collect() });
351 }
352
353 match options.size_byte_length {
354 Some(1) => self.write_explicit_sized::<TSpec, 1>(tag, tag_id, tag_type)?,
355 Some(2) => self.write_explicit_sized::<TSpec, 2>(tag, tag_id, tag_type)?,
356 Some(3) => self.write_explicit_sized::<TSpec, 3>(tag, tag_id, tag_type)?,
357 Some(4) => self.write_explicit_sized::<TSpec, 4>(tag, tag_id, tag_type)?,
358 Some(5) => self.write_explicit_sized::<TSpec, 5>(tag, tag_id, tag_type)?,
359 Some(6) => self.write_explicit_sized::<TSpec, 6>(tag, tag_id, tag_type)?,
360 Some(7) => self.write_explicit_sized::<TSpec, 7>(tag, tag_id, tag_type)?,
361 Some(8) => self.write_explicit_sized::<TSpec, 8>(tag, tag_id, tag_type)?,
362 _ => self.write_explicit_sized::<TSpec, 0>(tag, tag_id, tag_type)?,
363 }
364 }
365
366 Ok(())
367 }
368
369 fn write_explicit_sized<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone, const SIZE_LENGTH: usize>(&mut self, tag: &TSpec, tag_id: u64, tag_type: Option<TagDataType>) -> Result<(), TagWriterError> {
370 assert!(SIZE_LENGTH < 9, "Vint length must be less than 9 bytes");
371 match tag_type {
372 Some(TagDataType::UnsignedInt) => {
373 let val = tag.as_unsigned_int().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was unsigned int, but could not get tag!", tag_id));
374 self.write_unsigned_int_tag::<SIZE_LENGTH>(tag_id, val)?
375 },
376 Some(TagDataType::Integer) => {
377 let val = tag.as_signed_int().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was integer, but could not get tag!", tag_id));
378 self.write_signed_int_tag::<SIZE_LENGTH>(tag_id, val)?
379 },
380 Some(TagDataType::Utf8) => {
381 let val = tag.as_utf8().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was utf8, but could not get tag!", tag_id));
382 self.write_utf8_tag::<SIZE_LENGTH>(tag_id, val)?
383 },
384 Some(TagDataType::Binary) => {
385 let val = tag.as_binary().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was binary, but could not get tag!", tag_id));
386 self.write_binary_tag::<SIZE_LENGTH>(tag_id, val)?
387 },
388 Some(TagDataType::Float) => {
389 let val = tag.as_float().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was float, but could not get tag!", tag_id));
390 self.write_float_tag::<SIZE_LENGTH>(tag_id, val)?
391 },
392 Some(TagDataType::Master) => {
393 let position = tag.as_master().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was master, but could not get tag!", tag_id));
394
395 match position {
396 Master::Start => self.start_tag(tag_id, SIZE_LENGTH),
397 Master::End => self.end_tag(tag_id)?,
398 Master::Full(children) => {
399 self.start_tag(tag_id, SIZE_LENGTH);
400 for child in children {
401 self.write(child)?;
402 }
403 self.end_tag(tag_id)?;
404 }
405 }
406 },
407 None => { if !is_vint(tag_id) {
409 return Err(TagWriterError::TagIdError(tag_id));
410 } else {
411 let val = tag.as_binary().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was raw tag, but could not get binary data!", tag_id));
412 self.write_binary_tag::<SIZE_LENGTH>(tag_id, val)?
413 }
414 }
415 }
416
417 if !self.open_tags.iter().any(|t| matches!(t.1, Known(_))) {
418 self.private_flush()
419 } else {
420 Ok(())
421 }
422 }
423
424 #[deprecated(since="0.6.0", note="Please use 'write_advanced' with WriteOptions obtained using 'is_unknown_sized_element' instead")]
436 pub fn write_unknown_size<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec) -> Result<(), TagWriterError> {
437 let tag_id = tag.get_id();
438 let tag_type = TSpec::get_tag_data_type(tag_id);
439 match tag_type {
440 Some(TagDataType::Master) => {},
441 _ => {
442 return Err(TagWriterError::TagSizeError(format!("Cannot write an unknown size for tag of type {tag_type:?}")))
443 }
444 };
445 self.start_unknown_size_tag(tag_id);
446 Ok(())
447 }
448
449 pub fn write_raw(&mut self, tag_id: u64, data: &[u8]) -> Result<(), TagWriterError> {
473 self.write_binary_tag::<0>(tag_id, data)?;
474
475 if !self.open_tags.iter().any(|t| matches!(t.1, Known(_))) {
476 self.private_flush()
477 } else {
478 Ok(())
479 }
480 }
481
482 pub fn flush(&mut self) -> Result<(), TagWriterError> {
492 while let Some(id) = self.open_tags.last().map(|t| t.0) {
493 self.end_tag(id)?;
494 }
495 self.private_flush()
496 }
497
498 }
500
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::super::tools::Vint;
    use super::TagWriter;

    /// A raw EBML header element with no payload must be emitted as its 4 id bytes followed
    /// by the encoding of a zero size.
    #[test]
    fn write_ebml_tag() {
        let mut dest = Cursor::new(Vec::new());
        let mut writer = TagWriter::new(&mut dest);
        writer.write_raw(0x1a45dfa3, &[]).expect("Error writing tag");

        let zero_size_vint = 0u64.as_vint().expect("Error converting [0] to vint");
        let expected = [0x1a, 0x45, 0xdf, 0xa3, zero_size_vint[0]];
        assert_eq!(dest.get_ref().as_slice(), &expected[..]);
    }
}