1use std::{convert::TryInto, io::Write};
2
3use crate::{entry, extension, write::util::CountBytes, State, Version};
4
/// Selects which optional index extensions are emitted when writing a [`State`].
///
/// Defaults to [`Extensions::All`].
#[derive(Debug, Copy, Clone, Default)]
pub enum Extensions {
    /// Emit every extension that is applicable.
    #[default]
    All,
    /// Emit only the extensions whose flag is `true`; extensions not listed here are skipped.
    Given {
        /// Whether the tree-cache extension may be written.
        tree_cache: bool,
        /// Whether the end-of-index-entry extension may be written.
        end_of_index_entry: bool,
    },
    /// Emit no extension that is gated by this setting.
    None,
}
31
32impl Extensions {
33 pub fn should_write(&self, signature: extension::Signature) -> Option<extension::Signature> {
35 match self {
36 Extensions::None => None,
37 Extensions::All => Some(signature),
38 Extensions::Given {
39 tree_cache,
40 end_of_index_entry,
41 } => match signature {
42 extension::tree::SIGNATURE => tree_cache,
43 extension::end_of_index_entry::SIGNATURE => end_of_index_entry,
44 _ => &false,
45 }
46 .then(|| signature),
47 }
48 }
49}
50
51#[derive(Debug, Default, Clone, Copy)]
55pub struct Options {
56 pub extensions: Extensions,
58}
59
60impl State {
61 pub fn write_to(&self, out: impl std::io::Write, Options { extensions }: Options) -> std::io::Result<Version> {
63 let version = self.detect_required_version();
64
65 let mut write = CountBytes::new(out);
66 let num_entries: u32 = self
67 .entries()
68 .len()
69 .try_into()
70 .expect("definitely not 4billion entries");
71 let removed_entries: u32 = self
72 .entries()
73 .iter()
74 .filter(|e| e.flags.contains(entry::Flags::REMOVE))
75 .count()
76 .try_into()
77 .expect("definitely not too many entries");
78
79 let offset_to_entries = header(&mut write, version, num_entries - removed_entries)?;
80 let offset_to_extensions = entries(&mut write, self, offset_to_entries)?;
81 let (extension_toc, out) = self.write_extensions(write, offset_to_extensions, extensions)?;
82
83 if num_entries > 0
84 && extensions
85 .should_write(extension::end_of_index_entry::SIGNATURE)
86 .is_some()
87 && !extension_toc.is_empty()
88 {
89 extension::end_of_index_entry::write_to(out, self.object_hash, offset_to_extensions, extension_toc)?
90 }
91
92 Ok(version)
93 }
94
95 fn write_extensions<T>(
96 &self,
97 mut write: CountBytes<T>,
98 offset_to_extensions: u32,
99 extensions: Extensions,
100 ) -> std::io::Result<(Vec<(extension::Signature, u32)>, T)>
101 where
102 T: std::io::Write,
103 {
104 type WriteExtFn<'a> = &'a dyn Fn(&mut dyn std::io::Write) -> Option<std::io::Result<extension::Signature>>;
105 let extensions: &[WriteExtFn<'_>] = &[
106 &|write| {
107 extensions
108 .should_write(extension::tree::SIGNATURE)
109 .and_then(|signature| self.tree().map(|tree| tree.write_to(write).map(|_| signature)))
110 },
111 &|write| {
112 self.is_sparse()
113 .then(|| extension::sparse::write_to(write).map(|_| extension::sparse::SIGNATURE))
114 },
115 ];
116
117 let mut offset_to_previous_ext = offset_to_extensions;
118 let mut out = Vec::with_capacity(5);
119 for write_ext in extensions {
120 if let Some(signature) = write_ext(&mut write).transpose()? {
121 let offset_past_ext = write.count;
122 let ext_size = offset_past_ext - offset_to_previous_ext - (extension::MIN_SIZE as u32);
123 offset_to_previous_ext = offset_past_ext;
124 out.push((signature, ext_size));
125 }
126 }
127 Ok((out, write.inner))
128 }
129}
130
131impl State {
132 fn detect_required_version(&self) -> Version {
133 self.entries
134 .iter()
135 .find_map(|e| e.flags.contains(entry::Flags::EXTENDED).then_some(Version::V3))
136 .unwrap_or(Version::V2)
137 }
138}
139
140fn header<T: std::io::Write>(
141 out: &mut CountBytes<T>,
142 version: Version,
143 num_entries: u32,
144) -> Result<u32, std::io::Error> {
145 let version = match version {
146 Version::V2 => 2_u32.to_be_bytes(),
147 Version::V3 => 3_u32.to_be_bytes(),
148 Version::V4 => 4_u32.to_be_bytes(),
149 };
150
151 out.write_all(crate::decode::header::SIGNATURE)?;
152 out.write_all(&version)?;
153 out.write_all(&num_entries.to_be_bytes())?;
154
155 Ok(out.count)
156}
157
158fn entries<T: std::io::Write>(out: &mut CountBytes<T>, state: &State, header_size: u32) -> Result<u32, std::io::Error> {
159 for entry in state.entries() {
160 if entry.flags.contains(entry::Flags::REMOVE) {
161 continue;
162 }
163 entry.write_to(&mut *out, state)?;
164 match (out.count - header_size) % 8 {
165 0 => {}
166 n => {
167 let eight_null_bytes = [0u8; 8];
168 out.write_all(&eight_null_bytes[n as usize..])?;
169 }
170 };
171 }
172
173 Ok(out.count)
174}
175
mod util {
    use std::convert::TryFrom;

    /// An `std::io::Write` adapter that forwards writes to `inner` and counts, in `count`,
    /// how many bytes the inner writer reported as written.
    ///
    /// The counter is a `u32`. Writes that would push it past `u32::MAX` fail with an
    /// `io::Error` instead of wrapping or panicking.
    pub struct CountBytes<T> {
        /// Total bytes written through this adapter so far.
        pub count: u32,
        /// The wrapped writer.
        pub inner: T,
    }

    impl<T> CountBytes<T>
    where
        T: std::io::Write,
    {
        /// Wrap `inner`, starting the byte count at zero.
        pub fn new(inner: T) -> Self {
            CountBytes { inner, count: 0 }
        }
    }

    impl<T> std::io::Write for CountBytes<T>
    where
        T: std::io::Write,
    {
        /// Forward `buf` to the inner writer and add the reported byte count to `count`.
        ///
        /// # Errors
        /// Propagates errors of the inner writer. Returns an error if the new total does
        /// not fit into a `u32`, including when a single write reports more than `u32::MAX`
        /// bytes. That case used to panic even though it depends only on the caller's buffer
        /// and the inner writer.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let written = self.inner.write(buf)?;
            let current = self.count;
            self.count = u32::try_from(written)
                .ok()
                .and_then(|written| current.checked_add(written))
                .ok_or_else(too_large)?;
            Ok(written)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.inner.flush()
        }
    }

    /// The error reported once the counted output no longer fits into a `u32`.
    fn too_large() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::Other,
            "Cannot write indices larger than 4 gigabytes",
        )
    }
}