libwild/
version_script.rs1use crate::error::Result;
8use crate::hash::PassThroughHasher;
9use crate::hash::PreHashed;
10use crate::input_data::VersionScriptData;
11use crate::linker_script::skip_comments_and_whitespace;
12use crate::symbol::UnversionedSymbolName;
13use anyhow::anyhow;
14use std::collections::HashSet;
15use winnow::BStr;
16use winnow::Parser;
17use winnow::error::ContextError;
18use winnow::error::FromExternalError;
19use winnow::token::take_until;
20use winnow::token::take_while;
21
22#[derive(Default)]
24pub(crate) struct VersionScript<'data> {
25 globals: MatchRules<'data>,
27 locals: MatchRules<'data>,
28 versions: Vec<Version<'data>>,
29}
30
31pub(crate) struct Version<'data> {
32 pub(crate) name: &'data [u8],
33 pub(crate) parent_index: Option<u16>,
34 symbols: MatchRules<'data>,
35}
36
37#[derive(Default)]
38struct MatchRules<'data> {
39 matches_all: bool,
40 exact: HashSet<PreHashed<UnversionedSymbolName<'data>>, PassThroughHasher>,
41 prefixes: Vec<&'data [u8]>,
42}
43
44impl<'data> MatchRules<'data> {
45 fn push(&mut self, pattern: SymbolMatcher<'data>) {
46 match pattern {
47 SymbolMatcher::All => self.matches_all = true,
48 SymbolMatcher::Prefix(prefix) => self.prefixes.push(prefix),
49 SymbolMatcher::Exact(exact) => {
50 self.exact.insert(UnversionedSymbolName::prehashed(exact));
51 }
52 }
53 }
54
55 fn matches(&self, name: &PreHashed<UnversionedSymbolName>) -> bool {
56 self.matches_all
57 || self.exact.contains(name)
58 || self
59 .prefixes
60 .iter()
61 .any(|prefix| name.bytes().starts_with(prefix))
62 }
63
64 fn merge(&mut self, other: &MatchRules<'data>) {
65 if other.matches_all {
66 self.matches_all = true;
67 }
68
69 if self.matches_all {
70 self.exact.clear();
71 self.prefixes.clear();
72 return;
73 }
74
75 self.exact.extend(&other.exact);
76 self.prefixes.extend(&other.prefixes);
77 }
78}
79
/// One symbol pattern as written in a version script.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SymbolMatcher<'data> {
    /// The bare `*` pattern.
    All,

    /// A pattern ending in `*`; holds the bytes that precede the star.
    Prefix(&'data [u8]),

    /// A pattern without any `*`; holds the full name.
    Exact(&'data [u8]),
}
86
87fn parse_version_script<'input>(input: &mut &'input BStr) -> winnow::Result<VersionScript<'input>> {
88 let mut version_names: Vec<&[u8]> = Vec::new();
90
91 skip_comments_and_whitespace(input)?;
92
93 if input.starts_with(b"{") {
95 let script = parse_version_section(input)?;
96
97 ";".parse_next(input)?;
98
99 skip_comments_and_whitespace(input)?;
100
101 return Ok(script);
102 }
103
104 let mut version_script = VersionScript::default();
105
106 version_names.push(b"");
108 version_script.versions.push(Version {
109 name: b"",
110 symbols: MatchRules::default(),
111 parent_index: None,
112 });
113
114 while !input.is_empty() {
115 let name = parse_token(input)?;
116
117 skip_comments_and_whitespace(input)?;
118
119 let version = parse_version_section(input)?;
120
121 let parent_name = take_until(0.., b';').parse_next(input)?;
122
123 let parent_index = if parent_name.is_empty() {
124 None
125 } else {
126 Some(
128 version_names
129 .iter()
130 .position(|v| v == &parent_name)
131 .ok_or_else(|| {
132 ContextError::from_external_error(
133 input,
134 VersionScriptError::UnknownParentVersion,
135 )
136 })? as u16,
137 )
138 };
139
140 ";".parse_next(input)?;
141
142 skip_comments_and_whitespace(input)?;
143
144 version_script.globals.merge(&version.globals);
145 version_script.locals.merge(&version.locals);
146
147 let mut version_symbols = MatchRules::default();
148 version_symbols.merge(&version.globals);
149 version_symbols.merge(&version.locals);
150
151 version_names.push(name);
152
153 version_script.versions.push(Version {
154 name,
155 parent_index,
156 symbols: version_symbols,
157 });
158 }
159
160 Ok(version_script)
161}
162
163impl<'data> VersionScript<'data> {
164 #[tracing::instrument(skip_all, name = "Parse version script")]
165 pub(crate) fn parse(data: VersionScriptData<'data>) -> Result<VersionScript<'data>> {
166 parse_version_script
167 .parse(BStr::new(data.raw))
168 .map_err(|err| anyhow!("Failed to parse version script:\n{err}",))
169 }
170
171 pub(crate) fn is_local(&self, name: &PreHashed<UnversionedSymbolName>) -> bool {
172 if self.globals.matches(name) {
173 return false;
174 }
175 self.locals.matches(name)
176 }
177
178 pub(crate) fn version_count(&self) -> u16 {
180 self.versions.len() as u16
181 }
182
183 pub(crate) fn parent_count(&self) -> u16 {
184 self.versions
185 .iter()
186 .filter(|v| v.parent_index.is_some())
187 .count() as u16
188 }
189
190 pub(crate) fn version_iter(&self) -> impl Iterator<Item = &Version> {
191 self.versions.iter()
192 }
193
194 pub(crate) fn version_for_symbol(
195 &self,
196 name: &PreHashed<UnversionedSymbolName>,
197 ) -> Option<u16> {
198 self.versions.iter().enumerate().find_map(|(number, ver)| {
199 ver.is_present(name)
200 .then(|| number as u16 + object::elf::VER_NDX_GLOBAL)
201 })
202 }
203}
204
/// Which labelled part of a version node the parser is currently reading.
enum VersionRuleSection {
    /// Patterns following a `global:` label.
    Global,

    /// Patterns following a `local:` label.
    Local,
}
209
210fn parse_version_section<'data>(input: &mut &'data BStr) -> winnow::Result<VersionScript<'data>> {
211 let mut section = None;
212
213 let mut out = VersionScript::default();
214
215 '{'.parse_next(input)?;
216
217 loop {
218 skip_comments_and_whitespace(input)?;
219
220 if input.starts_with(b"}") {
221 '}'.parse_next(input)?;
222 skip_comments_and_whitespace(input)?;
223 break;
224 }
225
226 if input.starts_with(b"global:") {
227 "global:".parse_next(input)?;
228 section = Some(VersionRuleSection::Global);
229 } else if input.starts_with(b"local:") {
230 "local:".parse_next(input)?;
231 section = Some(VersionRuleSection::Local);
232 } else {
233 let matcher = parse_matcher(input)?;
234
235 match section {
236 Some(VersionRuleSection::Global) | None => {
237 out.globals.push(matcher);
238 }
239 Some(VersionRuleSection::Local) => {
240 out.locals.push(matcher);
241 }
242 }
243 }
244 }
245
246 Ok(out)
247}
248
249impl Version<'_> {
250 fn is_present(&self, name: &PreHashed<UnversionedSymbolName>) -> bool {
251 self.symbols.matches(name)
252 }
253}
254
255fn parse_matcher<'data>(input: &mut &'data BStr) -> winnow::Result<SymbolMatcher<'data>> {
256 let token = take_until(1.., b';').parse_next(input)?;
257
258 skip_comments_and_whitespace(input)?;
259
260 if input.starts_with(b";") {
261 ";".parse_next(input)?;
262 }
263
264 if token == b"*" {
265 return Ok(SymbolMatcher::All);
266 }
267
268 if let Some(prefix) = token.strip_suffix(b"*") {
269 if prefix.contains(&b'*') {
270 return Err(ContextError::new());
271 }
272 return Ok(SymbolMatcher::Prefix(prefix));
273 }
274
275 if token.contains(&b'*') {
276 return Err(ContextError::new());
277 }
278
279 Ok(SymbolMatcher::Exact(token))
280}
281
282fn parse_token<'input>(input: &mut &'input BStr) -> winnow::Result<&'input [u8]> {
283 take_while(1.., |b| !b" (){}\n\t".contains(&b)).parse_next(input)
284}
285
/// Failures specific to version scripts that the generic parser errors cannot express.
#[derive(Debug)]
enum VersionScriptError {
    /// A version node named a parent that was not declared earlier in the script.
    UnknownParentVersion,
}

impl std::error::Error for VersionScriptError {}

impl std::fmt::Display for VersionScriptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            VersionScriptError::UnknownParentVersion => "Unknown parent version",
        };
        f.write_str(message)
    }
}
298
299impl std::fmt::Debug for Version<'_> {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.debug_struct("Version")
302 .field("name", &String::from_utf8_lossy(self.name))
303 .field("parent_index", &self.parent_index)
304 .field("symbols", &self.symbols)
305 .finish()
306 }
307}
308
309impl std::fmt::Debug for MatchRules<'_> {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 f.debug_struct("MatchRules")
312 .field("matches_all", &self.matches_all)
313 .field("exact", &self.exact)
314 .field("prefixes", &self.prefixes)
315 .finish()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;

    /// Parses `src`, panicking when the script is rejected.
    fn parse_ok(src: &[u8]) -> VersionScript<'_> {
        VersionScript::parse(VersionScriptData { raw: src }).unwrap()
    }

    /// Exact-match names of `rules` as sorted UTF-8 strings.
    fn exact_names<'a>(rules: &'a MatchRules<'_>) -> Vec<&'a str> {
        rules
            .exact
            .iter()
            .map(|symbol| std::str::from_utf8(symbol.bytes()).unwrap())
            .sorted()
            .collect()
    }

    /// Prefixes of `rules` as UTF-8 strings, in stored order.
    fn prefix_names<'a>(rules: &'a MatchRules<'_>) -> Vec<&'a str> {
        rules
            .prefixes
            .iter()
            .map(|prefix| std::str::from_utf8(prefix).unwrap())
            .collect()
    }

    #[test]
    fn test_parse_simple_version_script() {
        let script = parse_ok(
            br#"
                # Comment starting with a hash
                {global:
                    /* Single-line comment */
                    foo; /* Trailing comment */
                    bar*;
                local:
                    /* Multi-line
                    comment */
                    *;
                };"#,
        );

        assert_eq!(exact_names(&script.globals), ["foo"]);
        assert_eq!(prefix_names(&script.globals), ["bar"]);
        assert!(script.locals.matches_all);
    }

    #[test]
    fn test_parse_version_script() {
        let script = parse_ok(
            br#"
                VERS_1.1 {
                    global:
                        foo1;
                    local:
                        old*;
                };

                VERS_1.2 {
                    foo2;
                } VERS_1.1;
            "#,
        );

        // Two named versions plus the unnamed placeholder at index 0.
        assert_eq!(script.versions.len(), 3);
        assert_eq!(exact_names(&script.globals), ["foo1", "foo2"]);
        assert_eq!(prefix_names(&script.locals), ["old"]);

        let first = &script.versions[1];
        assert_eq!(first.name, b"VERS_1.1");
        assert_eq!(first.parent_index, None);
        assert_eq!(exact_names(&first.symbols), ["foo1"]);
        assert_eq!(prefix_names(&first.symbols), ["old"]);

        let second = &script.versions[2];
        assert_eq!(second.name, b"VERS_1.2");
        assert_eq!(second.parent_index, Some(1));
        assert_eq!(exact_names(&second.symbols), ["foo2"]);
    }

    #[test]
    fn single_line_version_script() {
        let script = parse_ok(br#"VERSION42 { global: *; };"#);
        assert!(script.globals.matches_all);
    }

    #[test]
    fn invalid_version_scripts() {
        let rejected = [
            // No `;` after the closing brace.
            "{}",
            // Patterns that are not terminated by `;`.
            "{*};",
            "{foo};",
            // No closing brace.
            "{foo;",
            // No `;` after the parent name.
            "VER1 {foo;}; VER2 {bar;} VER1",
            // Parent version that was never declared.
            "VER2 {bar;} VER1;",
        ];

        for src in rejected {
            let data = VersionScriptData {
                raw: src.as_bytes(),
            };
            assert!(
                VersionScript::parse(data).is_err(),
                "expected {src:?} to be rejected"
            );
        }
    }
}